├── .gitignore ├── LICENSE ├── README.md ├── assets └── activations.gif ├── data └── custom_datasets.py ├── examples ├── 01_mnist.ipynb └── 02_inference.ipynb ├── models ├── README.md ├── constants.py ├── ctm.py ├── ctm_qamnist.py ├── ctm_rl.py ├── ctm_sort.py ├── ff.py ├── lstm.py ├── lstm_qamnist.py ├── lstm_rl.py ├── modules.py ├── resnet.py └── utils.py ├── requirements.txt ├── tasks ├── image_classification │ ├── README.md │ ├── analysis │ │ ├── README.md │ │ └── run_imagenet_analysis.py │ ├── imagenet_classes.py │ ├── plotting.py │ ├── scripts │ │ ├── train_cifar10.sh │ │ └── train_imagenet.sh │ ├── train.py │ └── train_distributed.py ├── mazes │ ├── README.md │ ├── analysis │ │ ├── README.md │ │ └── run.py │ ├── plotting.py │ ├── scripts │ │ └── train_ctm.sh │ ├── train.py │ └── train_distributed.py ├── parity │ ├── README.md │ ├── analysis │ │ ├── make_blog_gifs.py │ │ └── run.py │ ├── plotting.py │ ├── scripts │ │ ├── train_ctm_100_50.sh │ │ ├── train_ctm_10_5.sh │ │ ├── train_ctm_1_1.sh │ │ ├── train_ctm_25_10.sh │ │ ├── train_ctm_50_25.sh │ │ ├── train_ctm_75_25.sh │ │ ├── train_lstm_1.sh │ │ ├── train_lstm_10.sh │ │ ├── train_lstm_100.sh │ │ ├── train_lstm_10_certain.sh │ │ ├── train_lstm_25.sh │ │ ├── train_lstm_25_certain.sh │ │ ├── train_lstm_50.sh │ │ └── train_lstm_75.sh │ ├── train.py │ └── utils.py ├── qamnist │ ├── README.md │ ├── analysis │ │ ├── make_blog_gifs.py │ │ ├── make_blog_gifs_equation_animation.py │ │ └── run.py │ ├── plotting.py │ ├── scripts │ │ ├── train_ctm_1.sh │ │ ├── train_ctm_10.sh │ │ ├── train_lstm_1.sh │ │ └── train_lstm_10.sh │ ├── train.py │ └── utils.py ├── rl │ ├── README.md │ ├── analysis │ │ ├── make_blog_gifs.py │ │ └── run.py │ ├── envs.py │ ├── plotting.py │ ├── scripts │ │ ├── 4rooms │ │ │ ├── train_ctm_1.sh │ │ │ ├── train_ctm_2.sh │ │ │ ├── train_lstm_1.sh │ │ │ └── train_lstm_2.sh │ │ ├── acrobot │ │ │ ├── train_ctm_1.sh │ │ │ ├── train_ctm_2.sh │ │ │ ├── train_ctm_5.sh │ │ │ ├── train_lstm_1.sh │ │ │ ├── train_lstm_2.sh │ │ │ └── train_lstm_5.sh │ │ └── cartpole │ │ │ ├── train_ctm_1.sh │ │ │ ├── train_ctm_2.sh │ │ │ ├── train_ctm_5.sh │ │ │ ├── train_lstm_1.sh │ │ │ ├── train_lstm_2.sh │ │ │ └── train_lstm_5.sh │ ├── train.py │ └── utils.py └── sort │ ├── train.py │ └── utils.py ├── tests ├── README.md ├── __init__.py ├── conftest.py ├── test_data.py └── tests.py └── utils ├── __init__.py ├── housekeeping.py ├── losses.py ├── samplers.py └── schedulers.py /.gitignore: -------------------------------------------------------------------------------- 1 | */__pycache__ 2 | logs 3 | .DS_Store 4 | *.png 5 | *.pdf 6 | *.gif 7 | *.out 8 | *.pyc 9 | *.env 10 | *.pt 11 | *.mp4 12 | .vscode* 13 | *outputs* 14 | data/* 15 | !assets/*.gif 16 | !data/custom_datasets.py 17 | examples/* 18 | !examples/01_mnist.ipynb 19 | !examples/02_inference.ipynb 20 | checkpoints 21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 Rémi Louf 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🕰️ The Continuous Thought Machine 2 | 3 | 📚 [PAPER: Technical Report](https://arxiv.org/abs/2505.05522) | 📝 [Blog](https://sakana.ai/ctm/) | 🕹️ [Interactive Website](https://pub.sakana.ai/ctm) | ✏️ [Tutorial](examples/01_mnist.ipynb) 4 | 5 | ![Activations](assets/activations.gif) 6 | 7 | We present the Continuous Thought Machine (CTM), a model designed to unfold and then leverage neural activity as the underlying mechanism for observation and action. Our contributions are: 8 | 9 | 1. An internal temporal axis, decoupled from any input data, that enables neuron activity to unfold. 10 | 11 | 2. Neuron-level temporal processing, where each neuron uses unique weight parameters to process a history of incoming signals, enabling fine-grained temporal dynamics. 12 | 13 | 3. Neural synchronisation, employed as a direct latent representation for modulating data and producing outputs, thus directly encoding information in the timing of neural activity. 14 | 15 | We demonstrate the CTM's strong performance and versatility across a range of challenging tasks, including ImageNet classification, solving 2D mazes, sorting, parity computation, question-answering, and RL tasks. 16 | 17 | We provide all necessary code to reproduce our results and invite others to build upon and use CTMs in their own work. 18 | 19 | ## [Interactive Website](https://pub.sakana.ai/ctm) 20 | Please see our [Interactive Website](https://pub.sakana.ai/ctm) for a maze-solving demo, many demonstrative videos of the method, results, and other findings. 
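
## Minimal usage sketch

The three contributions above map directly onto the constructor of the main model in `models/ctm.py`. The sketch below is illustrative only: the argument names follow the calls made in `models/ctm_sort.py` and `models/ctm_qamnist.py`, the values are placeholders rather than recommended settings, and the returned shapes follow the task models in this repo. See the [tutorial notebook](examples/01_mnist.ipynb) and the per-task scripts for configurations actually used in the paper.

```python
import torch
from models.ctm import ContinuousThoughtMachine

model = ContinuousThoughtMachine(
    iterations=50,                 # T: internal 'thought' steps, the temporal axis decoupled from the data (1)
    d_model=512,                   # D: number of neurons in the latent space
    d_input=128,                   # width of projected attention outputs
    heads=4,
    n_synch_out=128,               # No: neurons used for the output synchronisation representation (3)
    n_synch_action=128,            # Ni: neurons used for the action/attention synchronisation (3)
    synapse_depth=1,
    memory_length=25,              # M: history length seen by the neuron-level models (2)
    deep_nlms=True,
    memory_hidden_dims=16,
    do_layernorm_nlm=False,
    backbone_type='resnet18-2',    # see models/constants.py for valid options
    positional_embedding_type='learnable-fourier',
    out_dims=10,                   # e.g. 10 classes
    prediction_reshaper=[-1],
    dropout=0.0,
    neuron_select_type='random-pairing',
    n_random_pairing_self=0,
)

x = torch.randn(8, 3, 32, 32)                        # a batch of images
predictions, certainties, synchronisation = model(x)
# predictions: (batch, out_dims, iterations) -- one prediction per internal tick
# certainties: (batch, 2, iterations)        -- (normalised entropy, 1 - normalised entropy) per tick
```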
21 | 22 | 23 | ## Repo structure 24 | ``` 25 | ├── tasks 26 | │   ├── image_classification 27 | │   │   ├── train.py # Training code for image classification (cifar, imagenet) 28 | │   │   ├── imagenet_classes.py # Helper for imagenet class names 29 | │   │   ├── plotting.py # Plotting utils specific to this task 30 | │   │   └── analysis 31 | │   │   ├──run_imagenet_analysis.py # ImageNet eval and visualisation code 32 | │   │      └──outputs/ # Folder for outputs of analysis 33 | │   ├── mazes 34 | │   │   ├── train.py # Training code for solving 2D mazes (by way of a route; see paper) 35 | │   │   └── plotting.py # Plotting utils specific to this task 36 | │   │   └── analysis 37 | │   │   ├──run.py # Maze analysis code 38 | │   │      └──outputs/ # Folder for outputs of analysis 39 | │   ├── sort 40 | │   │   ├── train.py # Training code for sorting 41 | │   │   └── utils.py # Sort specific utils (e.g., CTC decode) 42 | │   ├── parity 43 | │   │   ├── train.py # Training code for parity task 44 | │   │   ├── utils.py # Parity-specific helper functions 45 | │   │   ├── plotting.py # Plotting utils specific to this task 46 | │   │   ├── scripts/ 47 | │   │   │   └── *.sh # Training scripts for different experimental setups 48 | │   │   └── analysis/ 49 | │   │   └── run.py # Entry point for parity analysis 50 | │   ├── qamnist 51 | │   │   ├── train.py # Training code for QAMNIST task (quantized MNIST) 52 | │   │   ├── utils.py # QAMNIST-specific helper functions 53 | │   │   ├── plotting.py # Plotting utils specific to this task 54 | │   │   ├── scripts/ 55 | │   │   │   └── *.sh # Training scripts for different experimental setups 56 | │   │   └── analysis/ 57 | │   │   └── run.py # Entry point for QAMNIST analysis 58 | │   └── rl 59 | │      ├── train.py # Training code for RL environments 60 | │      ├── utils.py # RL-specific helper functions 61 | │      ├── plotting.py # Plotting utils specific to this task 62 | │      ├── envs.py # Custom RL environment wrappers 63 | │      ├── scripts/ 64 | │      │   ├── 4rooms/ 65 | │      │   │   └── *.sh # Training scripts for MiniGrid-FourRooms-v0 environment 66 | │      │   ├── acrobot/ 67 | │      │   │   └── *.sh # Training scripts for Acrobot-v1 environment 68 | │      │   └── cartpole/ 69 | │      │   └── *.sh # Training scripts for CartPole-v1 environment 70 | │      └── analysis/ 71 | │      └── run.py # Entry point for RL analysis 72 | ├── data # This is where data will be saved and downloaded to 73 | │   └── custom_datasets.py # Custom datasets (e.g., Mazes), sort 74 | ├── models 75 | │   ├── ctm.py # Main model code, used for: image classification, solving mazes, sort 76 | │   ├── ctm_*.py # Other model code, standalone adjustments for other tasks 77 | │   ├── ff.py # feed-forward (simple) baseline code (e.g., for image classification) 78 | │   ├── lstm.py # LSTM baseline code (e.g., for image classification) 79 | │   ├── lstm_*.py # Other baseline code, standalone adjustments for other tasks 80 | │   ├── modules.py # Helper modules, including Neuron-level models and the Synapse UNET 81 | │   ├── utils.py # Helper functions (e.g., synch decay) 82 | │   └── resnet.py # Wrapper for ResNet featuriser 83 | ├── utils 84 | │   ├── housekeeping.py # Helper functions for keeping things neat 85 | │   ├── losses.py # Loss functions for various tasks (mostly with reshaping stuff) 86 | │   └── schedulers.py # Helper wrappers for learning rate schedulers 87 | └── checkpoints 88 |    └── imagenet, mazes, ... 
# Checkpoint directories (see google drive link for files) 89 | 90 | ``` 91 | 92 | ## Setup 93 | To set up the environment using conda: 94 | 95 | ``` 96 | conda create --name=ctm python=3.12 97 | conda activate ctm 98 | pip install -r requirements.txt 99 | ``` 100 | 101 | If there are issues with PyTorch versions, the following can be run: 102 | ``` 103 | pip uninstall torch 104 | pip install torch --index-url https://download.pytorch.org/whl/cu121 105 | ``` 106 | 107 | ## Model training 108 | Each task has its own (set of) training code. See for instance [tasks/image_classification/train.py](tasks/image_classification/train.py). We have set it up like this to prioritise ease of use over clinical efficiency. This code is for researchers, and we hope that sharing it in this way fosters collaboration and learning. 109 | 110 | While we have provided reasonable defaults in the argparsers of each training setup, scripts to replicate the setups in the paper will typically be found in the accompanying script folders. If you simply want to dive in, run the following as a module (set up like this to make it easy to run many high-level training scripts from the top directory): 111 | 112 | ``` 113 | python -m tasks.image_classification.train 114 | ``` 115 | For debugging in VSCode, this configuration example might be helpful to you: 116 | ``` 117 | { 118 | "name": "Debug: train image classifier", 119 | "type": "debugpy", 120 | "request": "launch", 121 | "module": "tasks.image_classification.train", 122 | "console": "integratedTerminal", 123 | "justMyCode": false 124 | } 125 | ``` 126 | 127 | 128 | ## Running analyses 129 | 130 | We also provide analysis and plotting code to replicate many of the plots in our paper. See `tasks/.../analysis/*` for more details. We also provide some data (e.g., the mazes we generated for training) and checkpoints (see [here](#checkpoints-and-data)). Note that ffmpeg is required for generating mp4 files from the analysis scripts. It can be installed with: 131 | ``` 132 | conda install -c conda-forge ffmpeg 133 | ``` 134 | 135 | 136 | ## Checkpoints and data 137 | You can download the data and checkpoints from here: 138 | - checkpoints: https://drive.google.com/drive/folders/1vSg8T7FqP-guMDk1LU7_jZaQtXFP9sZg 139 | - maze data: https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk 140 | 141 | Checkpoints go in the `checkpoints` folder. For instance, when properly populated, the checkpoints folder will have the maze checkpoint in `checkpoints/mazes/...` 142 | -------------------------------------------------------------------------------- /assets/activations.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/continuous-thought-machines/aecfb63ac42db7a20903ee27489ff71671669474/assets/activations.gif -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | # Continuous Thought Machines 2 | ## Models 3 | 4 | This folder contains all model-related code. 5 | 6 | Some notes for clarity: 7 | 1. The resnet structure we used (see resnet.py) has a few minor changes that enable constraining the receptive field of the features it yields. We do this because we want the CTM (or baseline methods) to learn a process whereby they gather information. Neural networks that use SGD will find the [path of least resistance](https://era.ed.ac.uk/handle/1842/39606), even if that path doesn't result in actually intelligent behaviour. Constraining the receptive field helps to prevent this, a bit. The short sketch below shows how these constrained backbones are selected.
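
For reference, backbone strings follow the pattern `resnet{depth}-{blocks}`, where the trailing number is how many ResNet hyper-blocks are kept. The snippet below is a small sketch (paths and helper names as used elsewhere in this repo, e.g. `models/constants.py` and `models/lstm.py`; run from the repo root):

```python
from models.constants import VALID_BACKBONE_TYPES
from models.resnet import prepare_resnet_backbone

print(VALID_BACKBONE_TYPES[:4])   # ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4']

# 'resnet18-2' keeps only the first two hyper-blocks of a ResNet-18, so the features
# it produces have a deliberately limited receptive field.
backbone = prepare_resnet_backbone('resnet18-2')
```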
-------------------------------------------------------------------------------- /models/constants.py: -------------------------------------------------------------------------------- 1 | VALID_NEURON_SELECT_TYPES = ['first-last', 'random', 'random-pairing'] 2 | 3 | VALID_BACKBONE_TYPES = [ 4 | f'resnet{depth}-{i}' for depth in [18, 34, 50, 101, 152] for i in range(1, 5) 5 | ] + ['shallow-wide', 'parity_backbone'] 6 | 7 | VALID_POSITIONAL_EMBEDDING_TYPES = [ 8 | 'learnable-fourier', 'multi-learnable-fourier', 9 | 'custom-rotational', 'custom-rotational-1d' 10 | ] 11 | -------------------------------------------------------------------------------- /models/ctm_qamnist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from models.ctm import ContinuousThoughtMachine 4 | from models.modules import MNISTBackbone, QAMNISTIndexEmbeddings, QAMNISTOperatorEmbeddings 5 | 6 | class ContinuousThoughtMachineQAMNIST(ContinuousThoughtMachine): 7 | def __init__(self, 8 | iterations, 9 | d_model, 10 | d_input, 11 | heads, 12 | n_synch_out, 13 | n_synch_action, 14 | synapse_depth, 15 | memory_length, 16 | deep_nlms, 17 | memory_hidden_dims, 18 | do_layernorm_nlm, 19 | out_dims, 20 | iterations_per_digit, 21 | iterations_per_question_part, 22 | iterations_for_answering, 23 | prediction_reshaper=[-1], 24 | dropout=0, 25 | neuron_select_type='first-last', 26 | n_random_pairing_self=256 27 | ): 28 | super().__init__( 29 | iterations=iterations, 30 | d_model=d_model, 31 | d_input=d_input, 32 | heads=heads, 33 | n_synch_out=n_synch_out, 34 | n_synch_action=n_synch_action, 35 | synapse_depth=synapse_depth, 36 | memory_length=memory_length, 37 | deep_nlms=deep_nlms, 38 | memory_hidden_dims=memory_hidden_dims, 39 | do_layernorm_nlm=do_layernorm_nlm, 40 | out_dims=out_dims, 41 | prediction_reshaper=prediction_reshaper, 42 | dropout=dropout, 43 | neuron_select_type=neuron_select_type, 44 | n_random_pairing_self=n_random_pairing_self, 45 | backbone_type='none', 46 | positional_embedding_type='none', 47 | ) 48 | 49 | # --- Core Parameters --- 50 | self.iterations_per_digit = iterations_per_digit 51 | self.iterations_per_question_part = iterations_per_question_part 52 | self.iterations_for_answering = iterations_for_answering 53 | 54 | # --- Setup Methods --- 55 | 56 | def set_initial_rgb(self): 57 | """Set the initial RGB values for the backbone.""" 58 | return None 59 | 60 | def get_d_backbone(self): 61 | """Get the dimensionality of the backbone output.""" 62 | return self.d_input 63 | 64 | def set_backbone(self): 65 | """Set the backbone module based on the specified type.""" 66 | self.backbone_digit = MNISTBackbone(self.d_input) 67 | self.index_backbone = QAMNISTIndexEmbeddings(50, self.d_input) 68 | self.operator_backbone = QAMNISTOperatorEmbeddings(2, self.d_input) 69 | pass 70 | 71 | # --- Utility Methods --- 72 | 73 | def determine_step_type(self, total_iterations_for_digits, total_iterations_for_question, stepi: int): 74 | """Determine whether the current step is for digits, questions, or answers.""" 75 | is_digit_step = stepi < total_iterations_for_digits 76 | is_question_step = total_iterations_for_digits <= stepi < total_iterations_for_digits + total_iterations_for_question 77 | is_answer_step
= stepi >= total_iterations_for_digits + total_iterations_for_question 78 | return is_digit_step, is_question_step, is_answer_step 79 | 80 | def determine_index_operator_step_type(self, total_iterations_for_digits, stepi: int): 81 | """Determine whether the current step is for index or operator.""" 82 | step_within_questions = stepi - total_iterations_for_digits 83 | if step_within_questions % (2 * self.iterations_per_question_part) < self.iterations_per_question_part: 84 | is_index_step = True 85 | is_operator_step = False 86 | else: 87 | is_index_step = False 88 | is_operator_step = True 89 | return is_index_step, is_operator_step 90 | 91 | def get_kv_for_step(self, total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input=None, prev_kv=None): 92 | """Get the key-value for the current step.""" 93 | is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi) 94 | 95 | if is_digit_step: 96 | current_input = x[:, stepi] 97 | if prev_input is not None and torch.equal(current_input, prev_input): 98 | return prev_kv, prev_input 99 | kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1)) 100 | 101 | elif is_question_step: 102 | offset = stepi - total_iterations_for_digits 103 | current_input = z[:, offset] 104 | if prev_input is not None and torch.equal(current_input, prev_input): 105 | return prev_kv, prev_input 106 | is_index_step, is_operator_step = self.determine_index_operator_step_type(total_iterations_for_digits, stepi) 107 | if is_index_step: 108 | kv = self.index_backbone(current_input) 109 | elif is_operator_step: 110 | kv = self.operator_backbone(current_input) 111 | else: 112 | raise ValueError("Invalid step type for question processing.") 113 | 114 | elif is_answer_step: 115 | current_input = None 116 | kv = torch.zeros((x.size(0), self.d_input), device=x.device) 117 | 118 | else: 119 | raise ValueError("Invalid step type.") 120 | 121 | return kv, current_input 122 | 123 | 124 | 125 | 126 | def forward(self, x, z, track=False): 127 | B = x.size(0) 128 | device = x.device 129 | 130 | # --- Tracking Initialization --- 131 | pre_activations_tracking = [] 132 | post_activations_tracking = [] 133 | attention_tracking = [] 134 | embedding_tracking = [] 135 | 136 | total_iterations_for_digits = x.size(1) 137 | total_iterations_for_question = z.size(1) 138 | total_iterations = total_iterations_for_digits + total_iterations_for_question + self.iterations_for_answering 139 | 140 | # --- Initialise Recurrent State --- 141 | state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T) 142 | activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H) 143 | 144 | # --- Storage for outputs per iteration --- 145 | predictions = torch.empty(B, self.out_dims, total_iterations, device=device, dtype=x.dtype) 146 | certainties = torch.empty(B, 2, total_iterations, device=device, dtype=x.dtype) 147 | 148 | # --- Initialise Recurrent Synch Values --- 149 | decay_alpha_action, decay_beta_action = None, None 150 | self.decay_params_action.data = torch.clamp(self.decay_params_action, 0, 15) # Fix from github user: kuviki 151 | self.decay_params_out.data = torch.clamp(self.decay_params_out, 0, 15) 152 | r_action, r_out = torch.exp(-self.decay_params_action).unsqueeze(0).repeat(B, 1), torch.exp(-self.decay_params_out).unsqueeze(0).repeat(B, 1) 153 | 154 | _, decay_alpha_out, decay_beta_out = 
self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out') 155 | 156 | prev_input = None 157 | prev_kv = None 158 | 159 | # --- Recurrent Loop --- 160 | for stepi in range(total_iterations): 161 | is_digit_step, is_question_step, is_answer_step = self.determine_step_type(total_iterations_for_digits, total_iterations_for_question, stepi) 162 | 163 | kv, prev_input = self.get_kv_for_step(total_iterations_for_digits, total_iterations_for_question, stepi, x, z, prev_input, prev_kv) 164 | prev_kv = kv 165 | 166 | synchronization_action, decay_alpha_action, decay_beta_action = self.compute_synchronisation(activated_state, decay_alpha_action, decay_beta_action, r_action, synch_type='action') 167 | 168 | # --- Interact with Data via Attention --- 169 | attn_weights = None 170 | if is_digit_step: 171 | q = self.q_proj(synchronization_action).unsqueeze(1) 172 | attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True) 173 | attn_out = attn_out.squeeze(1) 174 | pre_synapse_input = torch.concatenate((attn_out, activated_state), dim=-1) 175 | else: 176 | kv = kv.squeeze(1) 177 | pre_synapse_input = torch.concatenate((kv, activated_state), dim=-1) 178 | 179 | # --- Apply Synapses --- 180 | state = self.synapses(pre_synapse_input) 181 | state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1) 182 | 183 | # --- Apply NLMs --- 184 | activated_state = self.trace_processor(state_trace) 185 | 186 | # --- Calculate Synchronisation for Output Predictions --- 187 | synchronization_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out') 188 | 189 | # --- Get Predictions and Certainties --- 190 | current_prediction = self.output_projector(synchronization_out) 191 | current_certainty = self.compute_certainty(current_prediction) 192 | 193 | predictions[..., stepi] = current_prediction 194 | certainties[..., stepi] = current_certainty 195 | 196 | # --- Tracking --- 197 | if track: 198 | pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy()) 199 | post_activations_tracking.append(activated_state.detach().cpu().numpy()) 200 | if attn_weights is not None: 201 | attention_tracking.append(attn_weights.detach().cpu().numpy()) 202 | if is_question_step: 203 | embedding_tracking.append(kv.detach().cpu().numpy()) 204 | 205 | # --- Return Values --- 206 | if track: 207 | return predictions, certainties, synchronization_out, np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking), np.array(embedding_tracking) 208 | return predictions, certainties, synchronization_out -------------------------------------------------------------------------------- /models/ctm_rl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from models.ctm import ContinuousThoughtMachine 6 | from models.modules import MiniGridBackbone, ClassicControlBackbone, SynapseUNET 7 | from models.utils import compute_decay 8 | from models.constants import VALID_NEURON_SELECT_TYPES 9 | 10 | class ContinuousThoughtMachineRL(ContinuousThoughtMachine): 11 | def __init__(self, 12 | iterations, 13 | d_model, 14 | d_input, 15 | n_synch_out, 16 | synapse_depth, 17 | memory_length, 18 | deep_nlms, 19 | memory_hidden_dims, 20 | do_layernorm_nlm, 21 | backbone_type, 22 | prediction_reshaper=[-1], 23 | dropout=0, 24 | 
neuron_select_type='first-last', 25 | ): 26 | super().__init__( 27 | iterations=iterations, 28 | d_model=d_model, 29 | d_input=d_input, 30 | heads=0, # Setting heads to 0 means the attention modules are created as None 31 | n_synch_out=n_synch_out, 32 | n_synch_action=0, 33 | synapse_depth=synapse_depth, 34 | memory_length=memory_length, 35 | deep_nlms=deep_nlms, 36 | memory_hidden_dims=memory_hidden_dims, 37 | do_layernorm_nlm=do_layernorm_nlm, 38 | out_dims=0, 39 | prediction_reshaper=prediction_reshaper, 40 | dropout=dropout, 41 | neuron_select_type=neuron_select_type, 42 | backbone_type=backbone_type, 43 | n_random_pairing_self=0, 44 | positional_embedding_type='none', 45 | ) 46 | 47 | # --- Use a minimal CTM w/out input (action) synch --- 48 | self.neuron_select_type_action = None 49 | self.synch_representation_size_action = None 50 | 51 | # --- Start dynamics with a learned activated state trace --- 52 | self.register_parameter('start_activated_trace', nn.Parameter(torch.zeros((d_model, memory_length)).uniform_(-math.sqrt(1/(d_model+memory_length)), math.sqrt(1/(d_model+memory_length))), requires_grad=True)) 53 | self.start_activated_state = None 54 | 55 | self.register_buffer('diagonal_mask_out', torch.triu(torch.ones(self.n_synch_out, self.n_synch_out, dtype=torch.bool))) 56 | 57 | self.attention = None # Should already be None because super(... heads=0... ) 58 | self.q_proj = None # Should already be None because super(... heads=0... ) 59 | self.kv_proj = None # Should already be None because super(... heads=0... ) 60 | self.output_projector = None 61 | 62 | # --- Core CTM Methods --- 63 | 64 | def compute_synchronisation(self, activated_state_trace): 65 | """Compute the synchronisation between neurons.""" 66 | assert self.neuron_select_type == "first-last", "only first-last neuron selection is supported here" 67 | # For RL tasks we track a sliding window of activations from which we compute synchronisation 68 | S = activated_state_trace.permute(0, 2, 1) 69 | diagonal_mask = self.diagonal_mask_out.to(S.device) 70 | decay = compute_decay(S.size(1), self.decay_params_out, clamp_lims=(0, 4)) 71 | synchronisation = ((decay.unsqueeze(0) *(S[:,:,-self.n_synch_out:].unsqueeze(-1) * S[:,:,-self.n_synch_out:].unsqueeze(-2))[:,:,diagonal_mask]).sum(1))/torch.sqrt(decay.unsqueeze(0).sum(1,)) 72 | return synchronisation 73 | 74 | # --- Setup Methods --- 75 | 76 | def set_initial_rgb(self): 77 | """Set the initial RGB values for the backbone.""" 78 | return None 79 | 80 | def get_d_backbone(self): 81 | """Get the dimensionality of the backbone output.""" 82 | return self.d_input 83 | 84 | def set_backbone(self): 85 | """Set the backbone module based on the specified type.""" 86 | if self.backbone_type == 'navigation-backbone': 87 | self.backbone = MiniGridBackbone(self.d_input) 88 | elif self.backbone_type == 'classic-control-backbone': 89 | self.backbone = ClassicControlBackbone(self.d_input) 90 | else: 91 | raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).') 92 | pass 93 | 94 | def get_positional_embedding(self, d_backbone): 95 | """Get the positional embedding module.""" 96 | return None 97 | 98 | 99 | def get_synapses(self, synapse_depth, d_model, dropout): 100 | """ 101 | Get the synapse module. 102 | 103 | We found in our early experimentation that a single Linear, GLU and LayerNorm block performed worse than two blocks. 104 | For that reason we set the default synapse depth to two blocks.
105 | 106 | TODO: This is legacy and needs further experimentation to iron out. 107 | """ 108 | if synapse_depth == 1: 109 | return nn.Sequential( 110 | nn.Dropout(dropout), 111 | nn.LazyLinear(d_model*2), 112 | nn.GLU(), 113 | nn.LayerNorm(d_model), 114 | nn.LazyLinear(d_model*2), 115 | nn.GLU(), 116 | nn.LayerNorm(d_model) 117 | ) 118 | else: 119 | return SynapseUNET(d_model, synapse_depth, 16, dropout) 120 | 121 | def set_synchronisation_parameters(self, synch_type: str, n_synch: int, n_random_pairing_self: int = 0): 122 | """Set the parameters for the synchronisation of neurons.""" 123 | if synch_type == 'action': 124 | pass 125 | elif synch_type == 'out': 126 | left, right = self.initialize_left_right_neurons("out", self.d_model, n_synch, n_random_pairing_self) 127 | self.register_buffer(f'out_neuron_indices_left', left) 128 | self.register_buffer(f'out_neuron_indices_right', right) 129 | self.register_parameter(f'decay_params_out', nn.Parameter(torch.zeros(self.synch_representation_size_out), requires_grad=True)) 130 | pass 131 | else: 132 | raise ValueError(f"Invalid synch_type: {synch_type}") 133 | 134 | # --- Utilty Methods --- 135 | 136 | def verify_args(self): 137 | """Verify the validity of the input arguments.""" 138 | assert self.neuron_select_type in VALID_NEURON_SELECT_TYPES, \ 139 | f"Invalid neuron selection type: {self.neuron_select_type}" 140 | assert self.neuron_select_type != 'random-pairing', \ 141 | f"Random pairing is not supported for RL." 142 | assert self.backbone_type in ('navigation-backbone', 'classic-control-backbone'), \ 143 | f"Invalid backbone_type: {self.backbone_type}" 144 | assert self.d_model >= (self.n_synch_out), \ 145 | "d_model must be >= n_synch_out for neuron subsets" 146 | pass 147 | 148 | 149 | 150 | 151 | def forward(self, x, hidden_states, track=False): 152 | 153 | # --- Tracking Initialization --- 154 | pre_activations_tracking = [] 155 | post_activations_tracking = [] 156 | 157 | # --- Featurise Input Data --- 158 | features = self.backbone(x) 159 | 160 | # --- Get Recurrent State --- 161 | state_trace, activated_state_trace = hidden_states 162 | 163 | # --- Recurrent Loop --- 164 | for stepi in range(self.iterations): 165 | 166 | pre_synapse_input = torch.concatenate((features.reshape(x.size(0), -1), activated_state_trace[:,:,-1]), -1) 167 | 168 | # --- Apply Synapses --- 169 | state = self.synapses(pre_synapse_input) 170 | state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1) 171 | 172 | # --- Apply NLMs --- 173 | activated_state = self.trace_processor(state_trace) 174 | activated_state_trace = torch.concatenate((activated_state_trace[:,:,1:], activated_state.unsqueeze(-1)), -1) 175 | 176 | # --- Tracking --- 177 | if track: 178 | pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy()) 179 | post_activations_tracking.append(activated_state.detach().cpu().numpy()) 180 | 181 | hidden_states = ( 182 | state_trace, 183 | activated_state_trace, 184 | ) 185 | 186 | # --- Calculate Output Synchronisation --- 187 | synchronisation_out = self.compute_synchronisation(activated_state_trace) 188 | 189 | # --- Return Values --- 190 | if track: 191 | return synchronisation_out, hidden_states, np.array(pre_activations_tracking), np.array(post_activations_tracking) 192 | return synchronisation_out, hidden_states -------------------------------------------------------------------------------- /models/ctm_sort.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import numpy as np 3 | from models.ctm import ContinuousThoughtMachine 4 | 5 | class ContinuousThoughtMachineSORT(ContinuousThoughtMachine): 6 | """ 7 | Slight adaption of the CTM to work with the sort task. 8 | """ 9 | 10 | def __init__(self, 11 | iterations, 12 | d_model, 13 | d_input, 14 | heads, 15 | n_synch_out, 16 | n_synch_action, 17 | synapse_depth, 18 | memory_length, 19 | deep_nlms, 20 | memory_hidden_dims, 21 | do_layernorm_nlm, 22 | backbone_type, 23 | positional_embedding_type, 24 | out_dims, 25 | prediction_reshaper=[-1], 26 | dropout=0, 27 | dropout_nlm=None, 28 | neuron_select_type='random-pairing', 29 | n_random_pairing_self=0, 30 | ): 31 | super().__init__( 32 | iterations=iterations, 33 | d_model=d_model, 34 | d_input=d_input, 35 | heads=0, 36 | n_synch_out=n_synch_out, 37 | n_synch_action=0, 38 | synapse_depth=synapse_depth, 39 | memory_length=memory_length, 40 | deep_nlms=deep_nlms, 41 | memory_hidden_dims=memory_hidden_dims, 42 | do_layernorm_nlm=do_layernorm_nlm, 43 | backbone_type='none', 44 | positional_embedding_type='none', 45 | out_dims=out_dims, 46 | prediction_reshaper=prediction_reshaper, 47 | dropout=dropout, 48 | dropout_nlm=dropout_nlm, 49 | neuron_select_type=neuron_select_type, 50 | n_random_pairing_self=n_random_pairing_self, 51 | ) 52 | 53 | # --- Use a minimal CTM w/out input (action) synch --- 54 | self.neuron_select_type_action = None 55 | self.synch_representation_size_action = None 56 | 57 | self.attention = None # Should already be None because super(... heads=0... ) 58 | self.q_proj = None # Should already be None because super(... heads=0... ) 59 | self.kv_proj = None # Should already be None because super(... heads=0... ) 60 | 61 | 62 | 63 | 64 | def forward(self, x, track=False): 65 | B = x.size(0) 66 | device = x.device 67 | 68 | # --- Tracking Initialization --- 69 | pre_activations_tracking = [] 70 | post_activations_tracking = [] 71 | synch_out_tracking = [] 72 | attention_tracking = [] 73 | 74 | # --- For SORT: no need to featurise data --- 75 | 76 | 77 | # --- Initialise Recurrent State --- 78 | state_trace = self.start_trace.unsqueeze(0).expand(B, -1, -1) # Shape: (B, H, T) 79 | activated_state = self.start_activated_state.unsqueeze(0).expand(B, -1) # Shape: (B, H) 80 | 81 | # --- Prepare Storage for Outputs per Iteration --- 82 | predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype) 83 | certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype) 84 | 85 | # --- Initialise Recurrent Synch Values --- 86 | r_out = torch.exp(-torch.clamp(self.decay_params_out, 0, 15)).unsqueeze(0).repeat(B, 1) 87 | _, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, None, None, r_out, synch_type='out') 88 | # Compute learned weighting for synchronisation 89 | 90 | 91 | # --- Recurrent Loop --- 92 | for stepi in range(self.iterations): 93 | 94 | pre_synapse_input = torch.concatenate((x, activated_state), dim=-1) 95 | 96 | # --- Apply Synapses --- 97 | state = self.synapses(pre_synapse_input) 98 | # The 'state_trace' is the history of incoming pre-activations 99 | state_trace = torch.cat((state_trace[:, :, 1:], state.unsqueeze(-1)), dim=-1) 100 | 101 | # --- Apply Neuron-Level Models --- 102 | activated_state = self.trace_processor(state_trace) 103 | # One would also keep an 'activated_state_trace' as the history of outgoing post-activations 104 | # BUT, this is unnecessary because the synchronisation calculation is fully linear and can be 105 | # done using only the 
correct activated state (see compute_synchronisation method for explanation) 106 | 107 | # --- Calculate Synchronisation for Output Predictions --- 108 | synchronisation_out, decay_alpha_out, decay_beta_out = self.compute_synchronisation(activated_state, decay_alpha_out, decay_beta_out, r_out, synch_type='out') 109 | 110 | # --- Get Predictions and Certainties --- 111 | current_prediction = self.output_projector(synchronisation_out) 112 | current_certainty = self.compute_certainty(current_prediction) 113 | 114 | predictions[..., stepi] = current_prediction 115 | certainties[..., stepi] = current_certainty 116 | 117 | # --- Tracking --- 118 | if track: 119 | pre_activations_tracking.append(state_trace[:,:,-1].detach().cpu().numpy()) 120 | post_activations_tracking.append(activated_state.detach().cpu().numpy()) 121 | synch_out_tracking.append(synchronisation_out.detach().cpu().numpy()) 122 | 123 | # --- Return Values --- 124 | if track: 125 | return predictions, certainties, np.array(synch_out_tracking), np.array(pre_activations_tracking), np.array(post_activations_tracking), np.array(attention_tracking) 126 | return predictions, certainties, synchronisation_out 127 | -------------------------------------------------------------------------------- /models/ff.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | # Local imports (Assuming these contain necessary custom modules) 4 | from models.modules import * 5 | from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152 6 | 7 | 8 | class FFBaseline(nn.Module): 9 | """ 10 | Feed-forward (FF) baseline. 11 | 12 | Wrapper that lets us use the same backbone as the CTM and LSTM baselines, with a final projection sized via d_model so that parameter counts can plausibly be matched. 13 | 14 | 15 | Args: 16 | d_model (int): workaround that projects the final layer to this space so that parameter-matching is plausible. 17 | backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none'). 18 | out_dims (int): Dimensionality of the final output projection.
19 | dropout (float): dropout in last layer 20 | """ 21 | 22 | def __init__(self, 23 | d_model, 24 | backbone_type, 25 | out_dims, 26 | dropout=0, 27 | ): 28 | super(FFBaseline, self).__init__() 29 | 30 | # --- Core Parameters --- 31 | self.d_model = d_model 32 | self.backbone_type = backbone_type 33 | self.out_dims = out_dims 34 | 35 | # --- Input Assertions --- 36 | assert backbone_type in ['resnet18-1', 'resnet18-2', 'resnet18-3', 'resnet18-4', 37 | 'resnet34-1', 'resnet34-2', 'resnet34-3', 'resnet34-4', 38 | 'resnet50-1', 'resnet50-2', 'resnet50-3', 'resnet50-4', 39 | 'resnet101-1', 'resnet101-2', 'resnet101-3', 'resnet101-4', 40 | 'resnet152-1', 'resnet152-2', 'resnet152-3', 'resnet152-4', 41 | 'none', 'shallow-wide', 'parity_backbone'], f"Invalid backbone_type: {backbone_type}" 42 | 43 | # --- Backbone / Feature Extraction --- 44 | self.initial_rgb = Identity() # Placeholder, potentially replaced if using ResNet 45 | 46 | 47 | self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily 48 | resnet_family = resnet18 # Default 49 | if '34' in self.backbone_type: resnet_family = resnet34 50 | if '50' in self.backbone_type: resnet_family = resnet50 51 | if '101' in self.backbone_type: resnet_family = resnet101 52 | if '152' in self.backbone_type: resnet_family = resnet152 53 | 54 | # Determine which ResNet blocks to keep 55 | block_num_str = self.backbone_type.split('-')[-1] 56 | hyper_blocks_to_keep = list(range(1, int(block_num_str) + 1)) if block_num_str.isdigit() else [1, 2, 3, 4] 57 | 58 | self.backbone = resnet_family( 59 | 3, # initial_rgb handles input channels now 60 | hyper_blocks_to_keep, 61 | stride=2, 62 | pretrained=False, 63 | progress=True, 64 | device="cpu", # Initialise on CPU, move later via .to(device) 65 | do_initial_max_pool=True, 66 | ) 67 | 68 | 69 | # At this point we will have a 4D tensor of features: [B, C, H, W] 70 | # The following lets us scale up the resnet with d_model until it matches the CTM 71 | self.output_projector = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), Squeeze(-1), Squeeze(-1), nn.LazyLinear(d_model), nn.ReLU(), nn.Dropout(dropout), nn.Linear(d_model, out_dims)) 72 | 73 | 74 | def forward(self, x): 75 | return self.output_projector((self.backbone(self.initial_rgb(x)))) 76 | -------------------------------------------------------------------------------- /models/lstm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import math 5 | 6 | from models.modules import ParityBackbone, LearnableFourierPositionalEncoding, MultiLearnableFourierPositionalEncoding, CustomRotationalEmbedding, CustomRotationalEmbedding1D, ShallowWide 7 | from models.resnet import prepare_resnet_backbone 8 | from models.utils import compute_normalized_entropy 9 | 10 | from models.constants import ( 11 | VALID_BACKBONE_TYPES, 12 | VALID_POSITIONAL_EMBEDDING_TYPES 13 | ) 14 | 15 | class LSTMBaseline(nn.Module): 16 | """ 17 | LSTM Baseline 18 | 19 | Args: 20 | iterations (int): Number of internal 'thought' steps (T, in paper). 21 | d_model (int): Core dimensionality of the latent space. 22 | d_input (int): Dimensionality of projected attention outputs or direct input features. 23 | heads (int): Number of attention heads. 24 | backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none'). 25 | positional_embedding_type (str): Type of positional embedding for backbone features. 
26 | out_dims (int): Dimensionality of the final output projection. 27 | prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific). 28 | dropout (float): Dropout rate. 29 | """ 30 | 31 | def __init__(self, 32 | iterations, 33 | d_model, 34 | d_input, 35 | heads, 36 | backbone_type, 37 | num_layers, 38 | positional_embedding_type, 39 | out_dims, 40 | prediction_reshaper=[-1], 41 | dropout=0, 42 | ): 43 | super(LSTMBaseline, self).__init__() 44 | 45 | # --- Core Parameters --- 46 | self.iterations = iterations 47 | self.d_model = d_model 48 | self.d_input = d_input 49 | self.prediction_reshaper = prediction_reshaper 50 | self.backbone_type = backbone_type 51 | self.positional_embedding_type = positional_embedding_type 52 | self.out_dims = out_dims 53 | 54 | # --- Assertions --- 55 | self.verify_args() 56 | 57 | # --- Input Processing --- 58 | d_backbone = self.get_d_backbone() 59 | 60 | self.set_initial_rgb() 61 | self.set_backbone() 62 | self.positional_embedding = self.get_positional_embedding(d_backbone) 63 | self.kv_proj = self.get_kv_proj() 64 | self.lstm = nn.LSTM(d_input, d_model, num_layers, batch_first=True, dropout=dropout) 65 | self.q_proj = self.get_q_proj() 66 | self.attention = self.get_attention(heads, dropout) 67 | self.output_projector = nn.Sequential(nn.LazyLinear(out_dims)) 68 | 69 | # --- Start States --- 70 | self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 71 | self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((num_layers, d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 72 | 73 | 74 | 75 | # --- Core LSTM Methods --- 76 | 77 | def compute_features(self, x): 78 | """Applies backbone and positional embedding to input.""" 79 | x = self.initial_rgb(x) 80 | self.kv_features = self.backbone(x) 81 | pos_emb = self.positional_embedding(self.kv_features) 82 | combined_features = (self.kv_features + pos_emb).flatten(2).transpose(1, 2) 83 | kv = self.kv_proj(combined_features) 84 | return kv 85 | 86 | def compute_certainty(self, current_prediction): 87 | """Compute the certainty of the current prediction.""" 88 | B = current_prediction.size(0) 89 | reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper) 90 | ne = compute_normalized_entropy(reshaped_pred) 91 | current_certainty = torch.stack((ne, 1-ne), -1) 92 | return current_certainty 93 | 94 | # --- Setup Methods --- 95 | 96 | def set_initial_rgb(self): 97 | """Set the initial RGB processing module based on the backbone type.""" 98 | if 'resnet' in self.backbone_type: 99 | self.initial_rgb = nn.LazyConv2d(3, 1, 1) # Adapts input channels lazily 100 | else: 101 | self.initial_rgb = nn.Identity() 102 | 103 | def get_d_backbone(self): 104 | """ 105 | Get the dimensionality of the backbone output, to be used for positional embedding setup. 106 | 107 | This is a little bit complicated for resnets, but the logic should be easy enough to read below. 
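For example, 'resnet18-2' and 'resnet34-2' yield 128 feature channels, whereas 'resnet50-2' (a bottleneck variant) yields 512.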
108 | """ 109 | if self.backbone_type == 'shallow-wide': 110 | return 2048 111 | elif self.backbone_type == 'parity_backbone': 112 | return self.d_input 113 | elif 'resnet' in self.backbone_type: 114 | if '18' in self.backbone_type or '34' in self.backbone_type: 115 | if self.backbone_type.split('-')[1]=='1': return 64 116 | elif self.backbone_type.split('-')[1]=='2': return 128 117 | elif self.backbone_type.split('-')[1]=='3': return 256 118 | elif self.backbone_type.split('-')[1]=='4': return 512 119 | else: 120 | raise NotImplementedError 121 | else: 122 | if self.backbone_type.split('-')[1]=='1': return 256 123 | elif self.backbone_type.split('-')[1]=='2': return 512 124 | elif self.backbone_type.split('-')[1]=='3': return 1024 125 | elif self.backbone_type.split('-')[1]=='4': return 2048 126 | else: 127 | raise NotImplementedError 128 | elif self.backbone_type == 'none': 129 | return None 130 | else: 131 | raise ValueError(f"Invalid backbone_type: {self.backbone_type}") 132 | 133 | def set_backbone(self): 134 | """Set the backbone module based on the specified type.""" 135 | if self.backbone_type == 'shallow-wide': 136 | self.backbone = ShallowWide() 137 | elif self.backbone_type == 'parity_backbone': 138 | d_backbone = self.get_d_backbone() 139 | self.backbone = ParityBackbone(n_embeddings=2, d_embedding=d_backbone) 140 | elif 'resnet' in self.backbone_type: 141 | self.backbone = prepare_resnet_backbone(self.backbone_type) 142 | elif self.backbone_type == 'none': 143 | self.backbone = nn.Identity() 144 | else: 145 | raise ValueError(f"Invalid backbone_type: {self.backbone_type}") 146 | 147 | def get_positional_embedding(self, d_backbone): 148 | """Get the positional embedding module.""" 149 | if self.positional_embedding_type == 'learnable-fourier': 150 | return LearnableFourierPositionalEncoding(d_backbone, gamma=1 / 2.5) 151 | elif self.positional_embedding_type == 'multi-learnable-fourier': 152 | return MultiLearnableFourierPositionalEncoding(d_backbone) 153 | elif self.positional_embedding_type == 'custom-rotational': 154 | return CustomRotationalEmbedding(d_backbone) 155 | elif self.positional_embedding_type == 'custom-rotational-1d': 156 | return CustomRotationalEmbedding1D(d_backbone) 157 | elif self.positional_embedding_type == 'none': 158 | return lambda x: 0 # Default no-op 159 | else: 160 | raise ValueError(f"Invalid positional_embedding_type: {self.positional_embedding_type}") 161 | 162 | def get_attention(self, heads, dropout): 163 | """Get the attention module.""" 164 | return nn.MultiheadAttention(self.d_input, heads, dropout, batch_first=True) 165 | 166 | def get_kv_proj(self): 167 | """Get the key-value projection module.""" 168 | return nn.Sequential(nn.LazyLinear(self.d_input), nn.LayerNorm(self.d_input)) 169 | 170 | def get_q_proj(self): 171 | """Get the query projection module.""" 172 | return nn.LazyLinear(self.d_input) 173 | 174 | 175 | def verify_args(self): 176 | """Verify the validity of the input arguments.""" 177 | 178 | assert self.backbone_type in VALID_BACKBONE_TYPES + ['none'], \ 179 | f"Invalid backbone_type: {self.backbone_type}" 180 | 181 | assert self.positional_embedding_type in VALID_POSITIONAL_EMBEDDING_TYPES + ['none'], \ 182 | f"Invalid positional_embedding_type: {self.positional_embedding_type}" 183 | 184 | if self.backbone_type=='none' and self.positional_embedding_type!='none': 185 | raise AssertionError("There should be no positional embedding if there is no backbone.") 186 | 187 | pass 188 | 189 | 190 | 191 | 192 | def forward(self, x, 
track=False): 193 | """ 194 | Forward pass - Reverted to structure closer to user's working version. 195 | Executes T=iterations steps. 196 | """ 197 | B = x.size(0) 198 | device = x.device 199 | 200 | # --- Tracking Initialization --- 201 | activations_tracking = [] 202 | attention_tracking = [] 203 | 204 | # --- Featurise Input Data --- 205 | kv = self.compute_features(x) 206 | 207 | # --- Initialise Recurrent State --- 208 | hn = torch.repeat_interleave(self.start_hidden_state.unsqueeze(1), x.size(0), 1) 209 | cn = torch.repeat_interleave(self.start_cell_state.unsqueeze(1), x.size(0), 1) 210 | state_trace = [hn[-1]] 211 | 212 | # --- Prepare Storage for Outputs per Iteration --- 213 | predictions = torch.empty(B, self.out_dims, self.iterations, device=device, dtype=x.dtype) 214 | certainties = torch.empty(B, 2, self.iterations, device=device, dtype=x.dtype) 215 | 216 | # --- Recurrent Loop --- 217 | for stepi in range(self.iterations): 218 | 219 | # --- Interact with Data via Attention --- 220 | q = self.q_proj(hn[-1].unsqueeze(1)) 221 | attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True) 222 | lstm_input = attn_out 223 | 224 | # --- Apply LSTM --- 225 | hidden_state, (hn,cn) = self.lstm(lstm_input, (hn, cn)) 226 | hidden_state = hidden_state.squeeze(1) 227 | state_trace.append(hidden_state) 228 | 229 | # --- Get Predictions and Certainties --- 230 | current_prediction = self.output_projector(hidden_state) 231 | current_certainty = self.compute_certainty(current_prediction) 232 | 233 | predictions[..., stepi] = current_prediction 234 | certainties[..., stepi] = current_certainty 235 | 236 | # --- Tracking --- 237 | if track: 238 | activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy()) 239 | attention_tracking.append(attn_weights.detach().cpu().numpy()) 240 | 241 | # --- Return Values --- 242 | if track: 243 | return predictions, certainties, None, np.zeros_like(activations_tracking), np.array(activations_tracking), np.array(attention_tracking) 244 | return predictions, certainties, None -------------------------------------------------------------------------------- /models/lstm_qamnist.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F # Used for GLU if not in modules 4 | import numpy as np 5 | import math 6 | 7 | # Local imports (Assuming these contain necessary custom modules) 8 | from models.modules import * 9 | from models.utils import * # Assuming compute_decay, compute_normalized_entropy are here 10 | 11 | class LSTMBaseline(nn.Module): 12 | """ 13 | LSTM Baseline 14 | 15 | Args: 16 | iterations (int): Number of internal 'thought' steps (T, in paper). 17 | d_model (int): Core dimensionality of the CTM's latent space (D, in paper). 18 | d_input (int): Dimensionality of projected attention outputs or direct input features. 19 | heads (int): Number of attention heads. 20 | n_synch_out (int): Number of neurons used for output synchronisation (No, in paper). 21 | n_synch_action (int): Number of neurons used for action/attention synchronisation (Ni, in paper). 22 | synapse_depth (int): Depth of the synapse model (U-Net if > 1, else MLP). 23 | memory_length (int): History length for Neuron-Level Models (M, in paper). 24 | deep_nlms (bool): Use deeper (2-layer) NLMs if True, else linear. 25 | memory_hidden_dims (int): Hidden dimension size for deep NLMs. 
26 | do_layernorm_nlm (bool): Apply LayerNorm within NLMs. 27 | backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none'). 28 | positional_embedding_type (str): Type of positional embedding for backbone features. 29 | out_dims (int): Dimensionality of the final output projection. 30 | prediction_reshaper (list): Shape for reshaping predictions before certainty calculation (task-specific). 31 | dropout (float): Dropout rate. 32 | """ 33 | 34 | def __init__(self, 35 | iterations, 36 | d_model, 37 | d_input, 38 | heads, 39 | out_dims, 40 | iterations_per_digit, 41 | iterations_per_question_part, 42 | iterations_for_answering, 43 | prediction_reshaper=[-1], 44 | dropout=0, 45 | ): 46 | super(LSTMBaseline, self).__init__() 47 | 48 | # --- Core Parameters --- 49 | self.iterations = iterations 50 | self.d_model = d_model 51 | self.prediction_reshaper = prediction_reshaper 52 | self.out_dims = out_dims 53 | self.d_input = d_input 54 | self.backbone_type = 'qamnist_backbone' 55 | self.iterations_per_digit = iterations_per_digit 56 | self.iterations_per_question_part = iterations_per_question_part 57 | self.total_iterations_for_answering = iterations_for_answering 58 | 59 | # --- Backbone / Feature Extraction --- 60 | self.backbone_digit = MNISTBackbone(d_input) 61 | self.index_backbone = QAMNISTIndexEmbeddings(50, d_input) 62 | self.operator_backbone = QAMNISTOperatorEmbeddings(2, d_input) 63 | 64 | # --- Core CTM Modules --- 65 | self.lstm_cell = nn.LSTMCell(d_input, d_model) 66 | self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 67 | self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 68 | 69 | # Attention 70 | self.q_proj = nn.LazyLinear(d_input) 71 | self.kv_proj = nn.Sequential(nn.LazyLinear(d_input), nn.LayerNorm(d_input)) 72 | self.attention = nn.MultiheadAttention(d_input, heads, dropout, batch_first=True) 73 | 74 | # Output Projection 75 | self.output_projector = nn.Sequential(nn.LazyLinear(out_dims)) 76 | 77 | def compute_certainty(self, current_prediction): 78 | """Compute the certainty of the current prediction.""" 79 | B = current_prediction.size(0) 80 | reshaped_pred = current_prediction.reshape([B] +self.prediction_reshaper) 81 | ne = compute_normalized_entropy(reshaped_pred) 82 | current_certainty = torch.stack((ne, 1-ne), -1) 83 | return current_certainty 84 | 85 | def get_kv_for_step(self, stepi, x, z, thought_steps, prev_input=None, prev_kv=None): 86 | is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi) 87 | 88 | if is_digit_step: 89 | current_input = x[:, stepi] 90 | if prev_input is not None and torch.equal(current_input, prev_input): 91 | return prev_kv, prev_input 92 | kv = self.kv_proj(self.backbone_digit(current_input).flatten(2).permute(0, 2, 1)) 93 | 94 | elif is_question_step: 95 | offset = stepi - thought_steps.total_iterations_for_digits 96 | current_input = z[:, offset].squeeze(0) 97 | if prev_input is not None and torch.equal(current_input, prev_input): 98 | return prev_kv, prev_input 99 | is_index_step, is_operator_step = thought_steps.determine_answer_step_type(stepi) 100 | if is_index_step: 101 | kv = self.kv_proj(self.index_backbone(current_input)) 102 | elif is_operator_step: 103 | kv = self.kv_proj(self.operator_backbone(current_input)) 104 | else: 105 | raise ValueError("Invalid 
step type for question processing.") 106 | 107 | elif is_answer_step: 108 | current_input = None 109 | kv = torch.zeros((x.size(0), self.d_input), device=x.device) 110 | 111 | else: 112 | raise ValueError("Invalid step type.") 113 | 114 | return kv, current_input 115 | 116 | def forward(self, x, z, track=False): 117 | """ 118 | Forward pass - Reverted to structure closer to user's working version. 119 | Executes T=iterations steps. 120 | """ 121 | B = x.size(0) # Batch size 122 | 123 | # --- Tracking Initialization --- 124 | activations_tracking = [] 125 | attention_tracking = [] # Note: reshaping this correctly requires knowing num_heads 126 | embedding_tracking = [] 127 | 128 | thought_steps = ThoughtSteps(self.iterations_per_digit, self.iterations_per_question_part, self.total_iterations_for_answering, x.size(1), z.size(1)) 129 | 130 | # --- Step 2: Initialise Recurrent State --- 131 | hidden_state = torch.repeat_interleave(self.start_hidden_state.unsqueeze(0), x.size(0), 0) 132 | cell_state = torch.repeat_interleave(self.start_cell_state.unsqueeze(0), x.size(0), 0) 133 | 134 | state_trace = [hidden_state] 135 | 136 | device = hidden_state.device 137 | 138 | # Storage for outputs per iteration 139 | predictions = torch.empty(B, self.out_dims, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed 140 | certainties = torch.empty(B, 2, thought_steps.total_iterations, device=device, dtype=x.dtype) # Adjust dtype if needed 141 | 142 | prev_input = None 143 | prev_kv = None 144 | 145 | # --- Recurrent Loop (T=iterations steps) --- 146 | for stepi in range(thought_steps.total_iterations): 147 | 148 | is_digit_step, is_question_step, is_answer_step = thought_steps.determine_step_type(stepi) 149 | kv, prev_input = self.get_kv_for_step(stepi, x, z, thought_steps, prev_input, prev_kv) 150 | prev_kv = kv 151 | 152 | # --- Interact with Data via Attention --- 153 | attn_weights = None 154 | if is_digit_step: 155 | q = self.q_proj(hidden_state).unsqueeze(1) 156 | attn_out, attn_weights = self.attention(q, kv, kv, average_attn_weights=False, need_weights=True) 157 | lstm_input = attn_out.squeeze(1) 158 | else: 159 | lstm_input = kv 160 | 161 | 162 | 163 | hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state)) 164 | state_trace.append(hidden_state) 165 | 166 | # --- Get Predictions and Certainties --- 167 | current_prediction = self.output_projector(hidden_state) 168 | current_certainty = self.compute_certainty(current_prediction) 169 | 170 | predictions[..., stepi] = current_prediction 171 | certainties[..., stepi] = current_certainty 172 | 173 | # --- Tracking --- 174 | if track: 175 | activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy()) 176 | if attn_weights is not None: 177 | attention_tracking.append(attn_weights.detach().cpu().numpy()) 178 | if is_question_step: 179 | embedding_tracking.append(kv.detach().cpu().numpy()) 180 | 181 | # --- Return Values --- 182 | if track: 183 | return predictions, certainties, None, np.array(activations_tracking), np.array(activations_tracking), np.array(attention_tracking), np.array(embedding_tracking) 184 | return predictions, certainties, None -------------------------------------------------------------------------------- /models/lstm_rl.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F # Used for GLU if not in modules 4 | import numpy as np 5 | import 
math 6 | 7 | # Local imports (custom modules and helpers) 8 | from models.modules import * 9 | from models.utils import * # compute_decay, compute_normalized_entropy live here 10 | 11 | 12 | class LSTMBaseline(nn.Module): 13 | """ 14 | 15 | LSTM Baseline 16 | 17 | Args: 18 | iterations (int): Number of internal 'thought' steps (T, in paper). 19 | d_model (int): Core dimensionality of the CTM's latent space (D, in paper). 20 | d_input (int): Dimensionality of projected attention outputs or direct input features. 21 | backbone_type (str): Type of feature extraction backbone (e.g., 'resnet18-2', 'none'). 22 | """ 23 | 24 | def __init__(self, 25 | iterations, 26 | d_model, 27 | d_input, 28 | backbone_type, 29 | ): 30 | super(LSTMBaseline, self).__init__() 31 | 32 | # --- Core Parameters --- 33 | self.iterations = iterations 34 | self.d_model = d_model 35 | self.backbone_type = backbone_type 36 | 37 | # --- Input Assertions --- 38 | assert backbone_type in ('navigation-backbone', 'classic-control-backbone'), f"Invalid backbone_type: {backbone_type}" 39 | 40 | # --- Backbone / Feature Extraction --- 41 | if self.backbone_type == 'navigation-backbone': 42 | grid_size = 7 43 | self.backbone = MiniGridBackbone(d_input=d_input, grid_size=grid_size) 44 | lstm_cell_input_dim = grid_size * grid_size * d_input 45 | 46 | elif self.backbone_type == 'classic-control-backbone': 47 | self.backbone = ClassicControlBackbone(d_input=d_input) 48 | lstm_cell_input_dim = d_input 49 | 50 | else: 51 | raise NotImplementedError('The only backbones supported for RL are navigation (symbolic C x H x W inputs) and classic control (vectors of length D).') 52 | 53 | # --- Core LSTM Modules --- 54 | self.lstm_cell = nn.LSTMCell(lstm_cell_input_dim, d_model) 55 | self.register_parameter('start_hidden_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 56 | self.register_parameter('start_cell_state', nn.Parameter(torch.zeros((d_model)).uniform_(-math.sqrt(1/(d_model)), math.sqrt(1/(d_model))), requires_grad=True)) 57 | 58 | def compute_features(self, x): 59 | """Applies the backbone to the input observation.""" 60 | return self.backbone(x) 61 | 62 | 63 | def forward(self, x, hidden_states, track=False): 64 | """ 65 | Forward pass. 66 | Executes T=iterations internal LSTM steps on the featurised observation.
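        Args (shapes inferred from the constructor above):
            x: observation batch fed to the backbone (symbolic grid for 'navigation-backbone', state vector for 'classic-control-backbone').
            hidden_states: tuple (hidden_state, cell_state), each of shape (B, d_model).
            track: if True, additionally return the per-step hidden-state history as a numpy array.

        Returns:
            hidden_state and the updated (hidden_state, cell_state) tuple; with track=True the activation
            history is returned as well (twice, matching the tracking interface of the other baselines).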
67 | """ 68 | 69 | # --- Tracking Initialization --- 70 | activations_tracking = [] 71 | 72 | # --- Featurise Input Data --- 73 | features = self.compute_features(x) 74 | 75 | hidden_state = hidden_states[0] 76 | cell_state = hidden_states[1] 77 | 78 | # --- Recurrent Loop --- 79 | for stepi in range(self.iterations): 80 | 81 | lstm_input = features.reshape(x.size(0), -1) 82 | hidden_state, cell_state = self.lstm_cell(lstm_input.squeeze(1), (hidden_state, cell_state)) 83 | 84 | # --- Tracking --- 85 | if track: 86 | activations_tracking.append(hidden_state.squeeze(1).detach().cpu().numpy()) 87 | 88 | hidden_states = ( 89 | hidden_state, 90 | cell_state 91 | ) 92 | 93 | # --- Return Values --- 94 | if track: 95 | return hidden_state, hidden_states, np.array(activations_tracking), np.array(activations_tracking) 96 | return hidden_state, hidden_states -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import re 4 | import os 5 | 6 | def compute_decay(T, params, clamp_lims=(0, 15)): 7 | """ 8 | This function computes exponential decays for learnable synchronisation 9 | interactions between pairs of neurons. 10 | """ 11 | assert len(clamp_lims), 'Clamp lims should be length 2' 12 | assert type(clamp_lims) == tuple, 'Clamp lims should be tuple' 13 | 14 | indices = torch.arange(T-1, -1, -1, device=params.device).reshape(T, 1).expand(T, params.shape[0]) 15 | out = torch.exp(-indices * torch.clamp(params, clamp_lims[0], clamp_lims[1]).unsqueeze(0)) 16 | return out 17 | 18 | def add_coord_dim(x, scaled=True): 19 | """ 20 | Adds a final dimension to the tensor representing 2D coordinates. 21 | 22 | Args: 23 | tensor: A PyTorch tensor of shape (B, D, H, W). 24 | 25 | Returns: 26 | A PyTorch tensor of shape (B, D, H, W, 2) with the last dimension 27 | representing the 2D coordinates within the HW dimensions. 28 | """ 29 | B, H, W = x.shape 30 | # Create coordinate grids 31 | x_coords = torch.arange(W, device=x.device, dtype=x.dtype).repeat(H, 1) # Shape (H, W) 32 | y_coords = torch.arange(H, device=x.device, dtype=x.dtype).unsqueeze(-1).repeat(1, W) # Shape (H, W) 33 | if scaled: 34 | x_coords /= (W-1) 35 | y_coords /= (H-1) 36 | # Stack coordinates and expand dimensions 37 | coords = torch.stack((x_coords, y_coords), dim=-1) # Shape (H, W, 2) 38 | coords = coords.unsqueeze(0) # Shape (1, 1, H, W, 2) 39 | coords = coords.repeat(B, 1, 1, 1) # Shape (B, D, H, W, 2) 40 | return coords 41 | 42 | def compute_normalized_entropy(logits, reduction='mean'): 43 | """ 44 | Calculates the normalized entropy of a PyTorch tensor of logits along the 45 | final dimension. 46 | 47 | Args: 48 | logits: A PyTorch tensor of logits. 49 | 50 | Returns: 51 | A PyTorch tensor containing the normalized entropy values. 
52 | """ 53 | 54 | # Apply softmax to get probabilities 55 | preds = F.softmax(logits, dim=-1) 56 | 57 | # Calculate the log probabilities 58 | log_preds = torch.log_softmax(logits, dim=-1) 59 | 60 | # Calculate the entropy 61 | entropy = -torch.sum(preds * log_preds, dim=-1) 62 | 63 | # Calculate the maximum possible entropy 64 | num_classes = preds.shape[-1] 65 | max_entropy = torch.log(torch.tensor(num_classes, dtype=torch.float32)) 66 | 67 | # Normalize the entropy 68 | normalized_entropy = entropy / max_entropy 69 | if len(logits.shape)>2 and reduction == 'mean': 70 | normalized_entropy = normalized_entropy.flatten(1).mean(-1) 71 | 72 | return normalized_entropy 73 | 74 | def reshape_predictions(predictions, prediction_reshaper): 75 | B, T = predictions.size(0), predictions.size(-1) 76 | new_shape = [B] + prediction_reshaper + [T] 77 | rehaped_predictions = predictions.reshape(new_shape) 78 | return rehaped_predictions 79 | 80 | def get_all_log_dirs(root_dir): 81 | folders = [] 82 | for dirpath, dirnames, filenames in os.walk(root_dir): 83 | if any(f.endswith(".pt") for f in filenames): 84 | folders.append(dirpath) 85 | return folders 86 | 87 | def get_latest_checkpoint(log_dir): 88 | files = [f for f in os.listdir(log_dir) if re.match(r'checkpoint_\d+\.pt', f)] 89 | return os.path.join(log_dir, max(files, key=lambda f: int(re.search(r'\d+', f).group()))) if files else None 90 | 91 | def get_latest_checkpoint_file(filepath, limit=300000): 92 | checkpoint_files = get_checkpoint_files(filepath) 93 | checkpoint_files = [ 94 | f for f in checkpoint_files if int(re.search(r'checkpoint_(\d+)\.pt', f).group(1)) <= limit 95 | ] 96 | if not checkpoint_files: 97 | return None 98 | return checkpoint_files[-1] 99 | 100 | def get_checkpoint_files(filepath): 101 | regex = r'checkpoint_(\d+)\.pt' 102 | files = [f for f in os.listdir(filepath) if re.match(regex, f)] 103 | files = sorted(files, key=lambda f: int(re.search(regex, f).group(1))) 104 | return [os.path.join(filepath, f) for f in files] 105 | 106 | def load_checkpoint(checkpoint_path, device): 107 | checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) 108 | return checkpoint 109 | 110 | def get_model_args_from_checkpoint(checkpoint): 111 | if "args" in checkpoint: 112 | return(checkpoint["args"]) 113 | else: 114 | raise ValueError("Checkpoint does not contain saved args.") 115 | 116 | def get_accuracy_and_loss_from_checkpoint(checkpoint, device="cpu"): 117 | training_iteration = checkpoint.get('training_iteration', 0) 118 | train_losses = checkpoint.get('train_losses', []) 119 | test_losses = checkpoint.get('test_losses', []) 120 | train_accuracies = checkpoint.get('train_accuracies_most_certain', []) 121 | test_accuracies = checkpoint.get('test_accuracies_most_certain', []) 122 | return training_iteration, train_losses, test_losses, train_accuracies, test_accuracies 123 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | torchvision 4 | matplotlib 5 | seaborn 6 | tqdm 7 | opencv-python 8 | imageio 9 | scikit-learn 10 | umap-learn 11 | python-dotenv 12 | gymnasium 13 | minigrid 14 | datasets 15 | autoclip -------------------------------------------------------------------------------- /tasks/image_classification/README.md: -------------------------------------------------------------------------------- 1 | # Image classification 2 | 3 | This folder contains 
code for training and analysing the ImageNet and CIFAR experiments. 4 | 5 | ## Accessing and loading ImageNet 6 | 7 | We use the [ILSVRC/imagenet-1k](https://huggingface.co/datasets/ILSVRC/imagenet-1k) dataset in our paper. 8 | 9 | To get this to work for you, you will need to do the following: 10 | 1. Log in to Hugging Face (create an account if needed) and agree to the terms and conditions of this dataset. 11 | 2. Create a new access token. 12 | 3. Install huggingface_hub on the target machine with ```pip install huggingface_hub``` 13 | 4. Run ```huggingface-cli login``` and use your token. This will authenticate you on the backend and allow the code to run. 14 | 5. Run an ImageNet experiment. The dataset will be downloaded automatically on first use. 15 | 16 | 17 | ## Training 18 | There are two training files: `train.py` and `train_distributed.py`. The training code uses mixed precision. For the settings in the paper, the following command was used for distributed training: 19 | 20 | ``` 21 | torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp 22 | ``` 23 | 24 | You can run the same setup on a single GPU with: 25 | ``` 26 | python -m tasks.image_classification.train --d_model 4096 --d_input 1024 --synapse_depth 12 --heads 16 --n_synch_out 150 --n_synch_action 150 --neuron_select_type random --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 64 --dropout 0.05 --no-do_normalisation --positional_embedding_type none --backbone_type resnet152-4 --batch_size 60 --batch_size_test 64 --lr 5e-4 --training_iterations 500001 --warmup_steps 10000 --use_scheduler --scheduler_type cosine --weight_decay 0.0 --log_dir logs-lambda/imagenet-distributed-4april/d=4096--i=1024--h=16--ns=150-random--iters=75x25--h=64--drop=0.05--pos=none--back=152x4--seed=42 --dataset imagenet --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 --use_amp --device 0 27 | ``` 28 | 29 | ## Checkpoint 30 | 31 | The checkpoint for the model used in the paper can be found [here](https://drive.google.com/file/d/1Lr_3RZU9X9SS8lBhAhECBiSZDKfKhDkJ/view?usp=drive_link). 32 | -------------------------------------------------------------------------------- /tasks/image_classification/analysis/README.md: -------------------------------------------------------------------------------- 1 | # Analysis 2 | 3 | This folder contains the analysis code for the image classification experiments.
Running the following from the base directory will generate figures, gifs and mp4 files: 4 | 5 | ``` 6 | python -m tasks.image_classification.analysis.run_imagenet_analysis 7 | ``` -------------------------------------------------------------------------------- /tasks/image_classification/scripts/train_cifar10.sh: -------------------------------------------------------------------------------- 1 | python -m tasks.image_classification.train \ 2 | --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \ 3 | --model ctm 4 | --dataset cifar10 \ 5 | --d_model 256 \ 6 | --d_input 64 \ 7 | --synapse_depth 5 \ 8 | --heads 16 \ 9 | --n_synch_out 256 \ 10 | --n_synch_action 512 \ 11 | --n_random_pairing_self 0 \ 12 | --neuron_select_type random-pairing \ 13 | --iterations 50 \ 14 | --memory_length 15 \ 15 | --deep_memory \ 16 | --memory_hidden_dims 64 \ 17 | --dropout 0.0 \ 18 | --dropout_nlm 0 \ 19 | --no-do_normalisation \ 20 | --positional_embedding_type none \ 21 | --backbone_type resnet18-1 \ 22 | --training_iterations 600001 \ 23 | --warmup_steps 1000 \ 24 | --use_scheduler \ 25 | --scheduler_type cosine \ 26 | --weight_decay 0.0001 \ 27 | --save_every 1000 \ 28 | --track_every 2000 \ 29 | --n_test_batches 50 \ 30 | --num_workers_train 8 \ 31 | --batch_size 512 \ 32 | --batch_size_test 512 \ 33 | --lr 1e-4 \ 34 | --device 0 \ 35 | --seed 1 36 | 37 | 38 | python -m tasks.image_classification.train \ 39 | --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \ 40 | --model ctm 41 | --dataset cifar10 \ 42 | --d_model 256 \ 43 | --d_input 64 \ 44 | --synapse_depth 5 \ 45 | --heads 16 \ 46 | --n_synch_out 256 \ 47 | --n_synch_action 512 \ 48 | --n_random_pairing_self 0 \ 49 | --neuron_select_type random-pairing \ 50 | --iterations 50 \ 51 | --memory_length 15 \ 52 | --deep_memory \ 53 | --memory_hidden_dims 64 \ 54 | --dropout 0.0 \ 55 | --dropout_nlm 0 \ 56 | --no-do_normalisation \ 57 | --positional_embedding_type none \ 58 | --backbone_type resnet18-1 \ 59 | --training_iterations 600001 \ 60 | --warmup_steps 1000 \ 61 | --use_scheduler \ 62 | --scheduler_type cosine \ 63 | --weight_decay 0.0001 \ 64 | --save_every 1000 \ 65 | --track_every 2000 \ 66 | --n_test_batches 50 \ 67 | --num_workers_train 8 \ 68 | --batch_size 512 \ 69 | --batch_size_test 512 \ 70 | --lr 1e-4 \ 71 | --device 0 \ 72 | --seed 2 73 | 74 | python -m tasks.image_classification.train \ 75 | --log_dir logs/cifar10-versus-humans/ctm/d=256--i=64--heads=16--sd=5--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \ 76 | --model ctm 77 | --dataset cifar10 \ 78 | --d_model 256 \ 79 | --d_input 64 \ 80 | --synapse_depth 5 \ 81 | --heads 16 \ 82 | --n_synch_out 256 \ 83 | --n_synch_action 512 \ 84 | --n_random_pairing_self 0 \ 85 | --neuron_select_type random-pairing \ 86 | --iterations 50 \ 87 | --memory_length 15 \ 88 | --deep_memory \ 89 | --memory_hidden_dims 64 \ 90 | --dropout 0.0 \ 91 | --dropout_nlm 0 \ 92 | --no-do_normalisation \ 93 | --positional_embedding_type none \ 94 | --backbone_type resnet18-1 \ 95 | --training_iterations 600001 \ 96 | --warmup_steps 1000 \ 97 | --use_scheduler \ 98 | --scheduler_type cosine \ 99 | --weight_decay 0.0001 \ 100 | --save_every 1000 \ 101 | --track_every 2000 \ 102 | --n_test_batches 50 \ 103 | --num_workers_train 8 \ 104 | --batch_size 512 \ 105 | --batch_size_test 512 \ 106 | --lr 1e-4 \ 107 
| --device 0 \ 108 | --seed 42 109 | 110 | 111 | 112 | 113 | 114 | 115 | python -m tasks.image_classification.train \ 116 | --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=1 \ 117 | --dataset cifar10 \ 118 | --model lstm \ 119 | --num_layers 2 \ 120 | --d_model 256 \ 121 | --d_input 64 \ 122 | --heads 16 \ 123 | --iterations 50 \ 124 | --dropout 0.0 \ 125 | --positional_embedding_type none \ 126 | --backbone_type resnet18-1 \ 127 | --training_iterations 600001 \ 128 | --warmup_steps 2000 \ 129 | --use_scheduler \ 130 | --scheduler_type cosine \ 131 | --weight_decay 0.0001 \ 132 | --save_every 1000 \ 133 | --track_every 2000 \ 134 | --n_test_batches 50 \ 135 | --reload \ 136 | --num_workers_train 8 \ 137 | --batch_size 512 \ 138 | --batch_size_test 512 \ 139 | --lr 1e-4 \ 140 | --device 0 \ 141 | --seed 1 \ 142 | --no-reload 143 | 144 | 145 | python -m tasks.image_classification.train \ 146 | --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=2 \ 147 | --dataset cifar10 \ 148 | --model lstm \ 149 | --num_layers 2 \ 150 | --d_model 256 \ 151 | --d_input 64 \ 152 | --heads 16 \ 153 | --iterations 50 \ 154 | --dropout 0.0 \ 155 | --positional_embedding_type none \ 156 | --backbone_type resnet18-1 \ 157 | --training_iterations 600001 \ 158 | --warmup_steps 2000 \ 159 | --use_scheduler \ 160 | --scheduler_type cosine \ 161 | --weight_decay 0.0001 \ 162 | --save_every 1000 \ 163 | --track_every 2000 \ 164 | --n_test_batches 50 \ 165 | --reload \ 166 | --num_workers_train 8 \ 167 | --batch_size 512 \ 168 | --batch_size_test 512 \ 169 | --lr 1e-4 \ 170 | --device 0 \ 171 | --seed 2 \ 172 | --no-reload 173 | 174 | 175 | python -m tasks.image_classification.train \ 176 | --log_dir logs/cifar10-versus-humans/lstm/nlayers=2--d=256--i=64--heads=16--synch=256-512-0-h=64-random-pairing--iters=50x15--backbone=18-1--seed=42 \ 177 | --dataset cifar10 \ 178 | --model lstm \ 179 | --num_layers 2 \ 180 | --d_model 256 \ 181 | --d_input 64 \ 182 | --heads 16 \ 183 | --iterations 50 \ 184 | --dropout 0.0 \ 185 | --positional_embedding_type none \ 186 | --backbone_type resnet18-1 \ 187 | --training_iterations 600001 \ 188 | --warmup_steps 2000 \ 189 | --use_scheduler \ 190 | --scheduler_type cosine \ 191 | --weight_decay 0.0001 \ 192 | --save_every 1000 \ 193 | --track_every 2000 \ 194 | --n_test_batches 50 \ 195 | --reload \ 196 | --num_workers_train 8 \ 197 | --batch_size 512 \ 198 | --batch_size_test 512 \ 199 | --lr 1e-4 \ 200 | --device 0 \ 201 | --seed 42 \ 202 | --no-reload 203 | 204 | 205 | 206 | 207 | 208 | python -m tasks.image_classification.train \ 209 | --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=1 \ 210 | --dataset cifar10 \ 211 | --model ff \ 212 | --d_model 256 \ 213 | --memory_hidden_dims 64 \ 214 | --dropout 0.0 \ 215 | --dropout_nlm 0 \ 216 | --backbone_type resnet18-1 \ 217 | --training_iterations 600001 \ 218 | --warmup_steps 1000 \ 219 | --use_scheduler \ 220 | --scheduler_type cosine \ 221 | --weight_decay 0.0001 \ 222 | --save_every 1000 \ 223 | --track_every 2000 \ 224 | --n_test_batches 50 \ 225 | --num_workers_train 8 \ 226 | --batch_size 512 \ 227 | --batch_size_test 512 \ 228 | --lr 1e-4 \ 229 | --device 0 \ 230 | --seed 1 231 | 232 | 233 | python -m tasks.image_classification.train \ 234 | --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=2 \ 235 | 
--dataset cifar10 \ 236 | --model ff \ 237 | --d_model 256 \ 238 | --memory_hidden_dims 64 \ 239 | --dropout 0.0 \ 240 | --dropout_nlm 0 \ 241 | --backbone_type resnet18-1 \ 242 | --training_iterations 600001 \ 243 | --warmup_steps 1000 \ 244 | --use_scheduler \ 245 | --scheduler_type cosine \ 246 | --weight_decay 0.0001 \ 247 | --save_every 1000 \ 248 | --track_every 2000 \ 249 | --n_test_batches 50 \ 250 | --num_workers_train 8 \ 251 | --batch_size 512 \ 252 | --batch_size_test 512 \ 253 | --lr 1e-4 \ 254 | --device 0 \ 255 | --seed 2 256 | 257 | python -m tasks.image_classification.train \ 258 | --log_dir logs/cifar10-versus-humans/ff/d=256--backbone=18-1--seed=42 \ 259 | --dataset cifar10 \ 260 | --model ff \ 261 | --d_model 256 \ 262 | --memory_hidden_dims 64 \ 263 | --dropout 0.0 \ 264 | --dropout_nlm 0 \ 265 | --backbone_type resnet18-1 \ 266 | --training_iterations 600001 \ 267 | --warmup_steps 1000 \ 268 | --use_scheduler \ 269 | --scheduler_type cosine \ 270 | --weight_decay 0.0001 \ 271 | --save_every 1000 \ 272 | --track_every 2000 \ 273 | --n_test_batches 50 \ 274 | --num_workers_train 8 \ 275 | --batch_size 512 \ 276 | --batch_size_test 512 \ 277 | --lr 1e-4 \ 278 | --device 0 \ 279 | --seed 42 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | -------------------------------------------------------------------------------- /tasks/image_classification/scripts/train_imagenet.sh: -------------------------------------------------------------------------------- 1 | torchrun --standalone --nnodes=1 --nproc_per_node=8 -m tasks.image_classification.train_distributed \ 2 | --log_dir logs/imagenet/d=4096--i=1024--heads=16--sd=8--nlm=64--synch=8192-2048-32-h=64-random-pairing--iters=50x25--backbone=152x4 \ 3 | --model ctm \ 4 | --dataset imagenet \ 5 | --d_model 4096 \ 6 | --d_input 1024 \ 7 | --synapse_depth 8 \ 8 | --heads 16 \ 9 | --n_synch_out 8196 \ 10 | --n_synch_action 2048 \ 11 | --n_random_pairing_self 32 \ 12 | --neuron_select_type random-pairing \ 13 | --iterations 50 \ 14 | --memory_length 25 \ 15 | --deep_memory \ 16 | --memory_hidden_dims 64 \ 17 | --dropout 0.2 \ 18 | --dropout_nlm 0 \ 19 | --no-do_normalisation \ 20 | --positional_embedding_type none \ 21 | --backbone_type resnet152-4 \ 22 | --batch_size 64 \ 23 | --batch_size_test 64 \ 24 | --n_test_batches 200 \ 25 | --lr 5e-4 \ 26 | --gradient_clipping 20 \ 27 | --training_iterations 500001 \ 28 | --save_every 1000 \ 29 | --track_every 5000 \ 30 | --warmup_steps 10000 \ 31 | --use_scheduler \ 32 | --scheduler_type cosine \ 33 | --weight_decay 0.0 \ 34 | --seed 1 \ 35 | --use_amp \ 36 | --reload \ 37 | --num_workers_train 8 \ 38 | --use_custom_sampler 39 | -------------------------------------------------------------------------------- /tasks/mazes/README.md: -------------------------------------------------------------------------------- 1 | # Mazes 2 | 3 | This folder contains code for training and analysing 2D maze solving experiments 4 | 5 | 6 | ## Training 7 | To run the maze training that we used for the paper, run the following command from the parent directory: 8 | ``` 9 | python -m tasks.mazes.train --d_model 2048 --d_input 512 --synapse_depth 4 --heads 8 --n_synch_out 64 --n_synch_action 32 --neuron_select_type first-last --iterations 75 --memory_length 25 --deep_memory --memory_hidden_dims 32 --dropout 0.1 --no-do_normalisation --positional_embedding_type none --backbone_type resnet34-2 --batch_size 64 --batch_size_test 64 --lr 1e-4 --training_iterations 1000001 --warmup_steps 10000 --use_scheduler 
--scheduler_type cosine --weight_decay 0.0 --log_dir logs/mazes/d=2048--i=512--h=8--ns=64-32--iters=75x25--h=32--drop=0.1--pos=none--back=34-2--seed=42 --dataset mazes-medium --save_every 2000 --track_every 5000 --seed 42 --n_test_batches 50 10 | ``` 11 | 12 | ## Small training run 13 | We also provide a 'mazes-small' dataset (see [here](https://drive.google.com/file/d/1cBgqhaUUtsrll8-o2VY42hPpyBcfFv86/view?usp=drivesdk)) for fast iteration and testing ideas. The following command can train a CTM locally without a GPU in 12-24 hours: 14 | ``` 15 | python -m tasks.mazes.train --dataset mazes-small --maze_route_length 50 --cirriculum_lookahead 5 --model ctm --d_model 1024 --d_input 256 --backbone_type resnet18-1 --synapse_depth 8 --heads 4 --n_synch_out 128 --n_synch_action 128 --neuron_select_type random-pairing --memory_length 25 --iterations 50 --training_iterations 100001 --lr 1e-4 --batch_size 64 --batch_size_test 32 --n_test_batches 50 --log_dir logs/mazes-small-tester --track_every 2000 16 | ``` -------------------------------------------------------------------------------- /tasks/mazes/analysis/README.md: -------------------------------------------------------------------------------- 1 | # Analysis 2 | 3 | This folder contains analysis code for the 2D maze experiments. 4 | 5 | To run the maze analysis and generate figures and GIFs, run the following command from the parent directory: 6 | ``` 7 | python -m tasks.mazes.analysis.run --actions viz viz --checkpoint checkpoints/mazes/ctm_mazeslarge_D=2048_T=75_M=25.pt 8 | ``` 9 | 10 | You will need to download the checkpoint from here: https://drive.google.com/file/d/1vGiMaQCxzKVT68SipxDCW0W5n5jjEQnC/view?usp=drive_link . Extract it to the appropriate directory (`checkpoints/mazes/...`), or use your own checkpoint after training. -------------------------------------------------------------------------------- /tasks/mazes/plotting.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import os 6 | import matplotlib.pyplot as plt 7 | import imageio 8 | 9 | from tqdm.auto import tqdm 10 | 11 | def find_center_of_mass(array_2d): 12 | """ 13 | Computes the centre of mass of a 2D array of non-negative weights, 14 | using np.average over a coordinate grid. 15 | 16 | Args: 17 | array_2d: A 2D numpy array of values between 0 and 1. 18 | 19 | Returns: 20 | A tuple (y, x) representing the coordinates of the center of mass. 21 | """ 22 | total_mass = np.sum(array_2d) 23 | if total_mass == 0: 24 | return (np.nan, np.nan) 25 | 26 | y_coords, x_coords = np.mgrid[:array_2d.shape[0], :array_2d.shape[1]] 27 | x_center = np.average(x_coords, weights=array_2d) 28 | y_center = np.average(y_coords, weights=array_2d) 29 | return (round(y_center, 4), round(x_center, 4)) 30 | 31 | def draw_path(x, route, valid_only=False, gt=False, cmap=None): 32 | """ 33 | Draws a path on a maze image based on a given route. 34 | 35 | Args: 36 | x: A numpy array of shape (H, W, 3) representing the maze image, with the start cell marked [1, 0, 0] and the goal marked [0, 1, 0]. 37 | route: A list of integers representing the route, where 0 is up, 1 is down, 2 is left, 3 is right, and 4 is stay in place. 38 | valid_only: A boolean indicating whether to only draw valid steps (i.e., steps that don't go into walls). 39 | 40 | Returns: 41 | A numpy array representing the maze image with the route overlaid in colour.
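    Example (a minimal illustrative call; the maze array here is made up for demonstration):
        >>> maze = np.ones((9, 9, 3), dtype=np.float32)
        >>> maze[1, 1] = [1, 0, 0]   # start cell (red)
        >>> maze[7, 7] = [0, 1, 0]   # goal cell (green)
        >>> overlay = draw_path(maze, route=[1, 1, 3, 3])   # down, down, right, right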
42 | """ 43 | x = np.copy(x) 44 | start = np.argwhere((x == [1, 0, 0]).all(axis=2)) 45 | end = np.argwhere((x == [0, 1, 0]).all(axis=2)) 46 | if cmap is None: 47 | cmap = plt.get_cmap('winter') if not valid_only else plt.get_cmap('summer') 48 | 49 | # Initialize the current position 50 | current_pos = start[0] 51 | 52 | # Draw the path 53 | colors = cmap(np.linspace(0, 1, len(route))) 54 | si = 0 55 | for step in route: 56 | new_pos = current_pos 57 | if step == 0: # Up 58 | new_pos = (current_pos[0] - 1, current_pos[1]) 59 | elif step == 1: # Down 60 | new_pos = (current_pos[0] + 1, current_pos[1]) 61 | elif step == 2: # Left 62 | new_pos = (current_pos[0], current_pos[1] - 1) 63 | elif step == 3: # Right 64 | new_pos = (current_pos[0], current_pos[1] + 1) 65 | elif step == 4: # Do nothing 66 | pass 67 | else: 68 | raise ValueError("Invalid step: {}".format(step)) 69 | 70 | # Check if the new position is valid 71 | if valid_only: 72 | try: 73 | if np.all(x[new_pos] == [0,0,0]): # Check if it's a wall 74 | continue # Skip this step if it's invalid 75 | except IndexError: 76 | continue # Skip this step if it's out of bounds 77 | 78 | # Draw the step 79 | if new_pos[0] >= 0 and new_pos[0] < x.shape[0] and new_pos[1] >= 0 and new_pos[1] < x.shape[1]: 80 | if not ((x[new_pos] == [1,0,0]).all() or (x[new_pos] == [0,1,0]).all()): 81 | colour = colors[si][:3] 82 | si += 1 83 | x[new_pos] = x[new_pos]*0.5 + colour*0.5 84 | 85 | # Update the current position 86 | current_pos = new_pos 87 | # cv2.imwrite('maze2.png', x[:,:,::-1]*255) 88 | 89 | return x 90 | 91 | def make_maze_gif(inputs, predictions, targets, attention_tracking, save_location): 92 | """ 93 | Expect inputs, predictions, targets as numpy arrays 94 | """ 95 | route_steps = [] 96 | route_colours = [] 97 | solution_maze = draw_path(np.moveaxis(inputs, 0, -1), targets) 98 | 99 | n_heads = attention_tracking.shape[1] 100 | mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 101 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 102 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 103 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 104 | ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'], 105 | ['head_8', 'head_9', 'head_10', 'head_11', 'head_12', 'head_13', 'head_14', 'head_15'], 106 | ] 107 | if n_heads == 8: 108 | mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 109 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 110 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 111 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 112 | ['head_0', 'head_1', 'head_2', 'head_3', 'head_4', 'head_5', 'head_6', 'head_7'], 113 | ] 114 | elif n_heads == 4: 115 | mosaic = [['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 116 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 117 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 118 | ['overlay', 'overlay', 'overlay', 'overlay', 'route', 'route', 'route', 'route'], 119 | ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'], 120 | ['head_0', 'head_0', 'head_1', 'head_1', 'head_2', 'head_2', 'head_3', 'head_3'], 121 | ] 122 | 123 | img_aspect = 1 124 | figscale = 1 125 
| aspect_ratio = (len(mosaic[0]) * figscale, len(mosaic) * figscale * img_aspect) # W, H 126 | 127 | route_steps = [np.unravel_index(np.argmax((inputs == np.reshape(np.array([1, 0, 0]), (3, 1, 1))).all(0)), inputs.shape[1:])] # Starting point 128 | frames = [] 129 | cmap = plt.get_cmap('gist_rainbow') 130 | cmap_viridis = plt.get_cmap('viridis') 131 | step_linspace = np.linspace(0, 1, predictions.shape[-1]) # For sampling colours 132 | with tqdm(total=predictions.shape[-1], initial=0, leave=True, position=1, dynamic_ncols=True) as pbar: 133 | pbar.set_description('Processing frames for maze plotting') 134 | for stepi in np.arange(0, predictions.shape[-1], 1): 135 | fig, axes = plt.subplot_mosaic(mosaic, figsize=aspect_ratio) 136 | for ax in axes.values(): 137 | ax.axis('off') 138 | guess_maze = draw_path(np.moveaxis(inputs, 0, -1), predictions.argmax(1)[:,stepi], cmap=cmap) 139 | attention_now = attention_tracking[stepi] 140 | for hi in range(min((attention_tracking.shape[1], 16))): 141 | ax = axes[f'head_{hi}'] 142 | attn = attention_tracking[stepi, hi] 143 | attn = (attn - attn.min())/(np.ptp(attn)) 144 | ax.imshow(attn, cmap=cmap_viridis) 145 | # Upsample attention just for visualisation 146 | aggregated_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy() 147 | 148 | # Get approximate center of mass 149 | com_attn = np.copy(aggregated_attention) 150 | com_attn[com_attn < np.percentile(com_attn, 96)] = 0.0 151 | aggregated_attention[aggregated_attention < np.percentile(aggregated_attention, 80)] = 0.0 152 | route_steps.append(find_center_of_mass(com_attn)) 153 | 154 | 155 | colour = list(cmap(step_linspace[stepi])) 156 | route_colours.append(colour) 157 | 158 | mapped_attention = torch.nn.functional.interpolate(torch.from_numpy(attention_now).unsqueeze(0), inputs.shape[-1], mode='bilinear')[0].mean(0).numpy() 159 | mapped_attention = (mapped_attention - mapped_attention.min())/np.ptp(mapped_attention) 160 | # np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.5) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.3, 0, 1) 161 | overlay_img = np.clip(guess_maze * (1-mapped_attention[...,np.newaxis]*0.6) + (cmap_viridis(mapped_attention)[:,:,:3] * mapped_attention[...,np.newaxis])*1.1, 0, 1)#np.clip((np.copy(guess_maze)*(1-aggregated_attention[:,:,np.newaxis])*0.7 + (aggregated_attention[:,:,np.newaxis]*3 * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1) 162 | axes['overlay'].imshow(overlay_img) 163 | 164 | y_coords, x_coords = zip(*route_steps) 165 | y_coords = inputs.shape[-1] - np.array(list(y_coords))-1 166 | 167 | 168 | axes['route'].imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower') 169 | # ax.imshow(np.flip(solution_maze, axis=0), origin='lower') 170 | arrow_scale = 2 171 | for i in range(len(route_steps)-1): 172 | dx = x_coords[i+1] - x_coords[i] 173 | dy = y_coords[i+1] - y_coords[i] 174 | axes['route'].arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True) 175 | 176 | fig.tight_layout(pad=0.1) # Adjust spacing 177 | 178 | # Render the plot to a numpy array 179 | canvas = fig.canvas 180 | canvas.draw() 181 | image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8') 182 | image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] # Get RGB 183 | 184 | 
frames.append(image_numpy) # Add to list for GIF 185 | 186 | # fig.savefig(f'{save_location}/frame.png', dpi=200) 187 | 188 | plt.close(fig) 189 | 190 | # # frame = np.clip((np.copy(guess_maze)*0.5 + (aggregated_attention[:,:,np.newaxis] * np.reshape(np.array(colour)[:3], (1, 1, 3)))), 0, 1) 191 | # frame = torch.nn.functional.interpolate(torch.from_numpy(frame).permute(2,0,1).unsqueeze(0), 256)[0].permute(1,2,0).detach().cpu().numpy() 192 | # frames.append((frame*255).astype(np.uint8)) 193 | pbar.update(1) 194 | 195 | 196 | y_coords, x_coords = zip(*route_steps) 197 | y_coords = inputs.shape[-1] - np.array(list(y_coords))-1 198 | 199 | fig = plt.figure(figsize=(5,5)) 200 | ax = fig.add_subplot(111) 201 | 202 | ax.imshow(np.flip(np.moveaxis(inputs, 0, -1), axis=0), origin='lower') 203 | # ax.imshow(np.flip(solution_maze, axis=0), origin='lower') 204 | arrow_scale = 2 205 | for i in range(len(route_steps)-1): 206 | dx = x_coords[i+1] - x_coords[i] 207 | dy = y_coords[i+1] - y_coords[i] 208 | plt.arrow(x_coords[i], y_coords[i], dx, dy, linewidth=2*arrow_scale, head_width=0.2*arrow_scale, head_length=0.3*arrow_scale, fc=route_colours[i], ec=route_colours[i], length_includes_head = True) 209 | 210 | ax.axis('off') 211 | fig.tight_layout(pad=0) 212 | fig.savefig(f'{save_location}/route_approximation.png', dpi=200) 213 | imageio.mimsave(f'{save_location}/prediction.gif', frames, fps=15, loop=100) 214 | plt.close(fig) 215 | -------------------------------------------------------------------------------- /tasks/mazes/scripts/train_ctm.sh: -------------------------------------------------------------------------------- 1 | python -m tasks.mazes.train \ 2 | --model ctm \ 3 | --log_dir logs/mazes/ctm/d=2048--i=512--heads=16--sd=8--nlm=32--synch=64-32-h=32-first-last--iters=75x25--backbone=34-2 \ 4 | --neuron_select_type first-last \ 5 | --dataset mazes-large \ 6 | --synapse_depth 8 \ 7 | --heads 16 \ 8 | --iterations 75 \ 9 | --memory_length 25 \ 10 | --d_model 2048 \ 11 | --d_input 512 \ 12 | --backbone_type resnet34-2 \ 13 | --n_synch_out 64 \ 14 | --n_synch_action 32 \ 15 | --memory_hidden_dims 32 \ 16 | --deep_memory \ 17 | --weight_decay 0.000 \ 18 | --batch_size 64 \ 19 | --batch_size_test 128 \ 20 | --n_test_batches 20 \ 21 | --gradient_clipping -1 \ 22 | --use_scheduler \ 23 | --scheduler_type cosine \ 24 | --warmup_steps 10000 \ 25 | --training_iterations 1000001 \ 26 | --no-do_normalisation \ 27 | --track_every 1000 \ 28 | --lr 1e-4 \ 29 | --no-reload \ 30 | --dropout 0.1 \ 31 | --positional_embedding_type none \ 32 | --maze_route_length 100 \ 33 | --cirriculum_lookahead 5 \ 34 | --device 0 \ 35 | --no-expand_range -------------------------------------------------------------------------------- /tasks/parity/README.md: -------------------------------------------------------------------------------- 1 | # Parity 2 | 3 | ## Training 4 | To run the parity training that we used for the paper, run bash scripts from the root level of the repository. For example, to train the 75-iteration, 25-memory-length CTM, run: 5 | 6 | ``` 7 | bash tasks/parity/scripts/train_ctm_75_25.sh 8 | ``` 9 | 10 | 11 | ## Analysis 12 | To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). The checkpoints can be obtained by either running the training code, or downloading them from [this link](https://drive.google.com/file/d/1itUS5_i9AyUo_7awllTx8X0PXYw9fnaG/view?usp=drive_link). 
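If you want to sanity-check a downloaded checkpoint before running the full analysis, the helpers in `models/utils.py` can be used as in the sketch below. The log directory path is a placeholder (point it at wherever your checkpoints live), and the fields read from the saved `args` are assumed to follow the training script arguments:

```
from models.utils import get_latest_checkpoint, load_checkpoint, get_model_args_from_checkpoint

log_dir = "logs/parity/run1/ctm_75_25"        # placeholder: your checkpoint directory
ckpt_path = get_latest_checkpoint(log_dir)     # highest-numbered checkpoint_<N>.pt in that directory
checkpoint = load_checkpoint(ckpt_path, device="cpu")
args = get_model_args_from_checkpoint(checkpoint)
print(ckpt_path, args.iterations, args.memory_length)  # fields assumed from the training script args
```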
13 | 14 | ``` 15 | python -m tasks.parity.analysis.run --log_dir 16 | ``` 17 | -------------------------------------------------------------------------------- /tasks/parity/analysis/make_blog_gifs.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import os 4 | import math 5 | import imageio 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | from matplotlib.patches import FancyArrowPatch 9 | from scipy.special import softmax 10 | import matplotlib.cm as cm 11 | from data.custom_datasets import ParityDataset 12 | import umap 13 | from tqdm import tqdm 14 | 15 | 16 | from models.utils import reshape_predictions 17 | from tasks.parity.utils import reshape_inputs 18 | from tasks.parity.analysis.run import build_model_from_checkpoint_path 19 | 20 | from tasks.image_classification.plotting import save_frames_to_mp4 21 | 22 | 23 | def make_parity_gif( 24 | predictions, 25 | targets, 26 | post_activations, 27 | attention_weights, 28 | inputs_to_model, 29 | save_path, 30 | umap_positions, 31 | umap_point_scaler=1.0, 32 | ): 33 | batch_index = 0 34 | figscale = 0.32 35 | n_steps, n_heads, seqLen = attention_weights.shape[:3] 36 | grid_side = int(np.sqrt(seqLen)) 37 | frames = [] 38 | 39 | inputs_this_batch = inputs_to_model[:, batch_index] 40 | preds_this_batch = predictions[batch_index] 41 | targets_this_batch = targets[batch_index] 42 | post_act_this_batch = post_activations[:, batch_index] 43 | 44 | # build a flexible mosaic 45 | mosaic = [ 46 | [f"att_0", f"in_0", "probs", "probs", "target", "target"], 47 | [f"att_1", f"in_1", "probs", "probs", "target", "target"], 48 | ] 49 | for h in range(2, n_heads): 50 | mosaic.append( 51 | [f"att_{h}", f"in_{h}", "umap", "umap", 52 | "umap", "umap"] 53 | ) 54 | 55 | for t in range(n_steps): 56 | rows = len(mosaic) 57 | cell_size = figscale * 4 58 | fig_h = rows * cell_size 59 | 60 | fig, ax = plt.subplot_mosaic( 61 | mosaic, 62 | figsize=(6 * cell_size, fig_h), 63 | constrained_layout=False, 64 | gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps 65 | ) 66 | # restore a little margin 67 | fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02) 68 | 69 | # probabilities heatmap 70 | logits_t = preds_this_batch[:, :, t] 71 | probs_t = softmax(logits_t, axis=1)[:, 0].reshape(grid_side, grid_side) 72 | ax["probs"].imshow(probs_t, cmap="gray", vmin=0, vmax=1) 73 | ax["probs"].axis("off") 74 | 75 | # target overlay 76 | ax["target"].imshow( 77 | targets_this_batch.reshape(grid_side, grid_side), 78 | cmap="gray_r", vmin=0, vmax=1 79 | ) 80 | ax["target"].axis("off") 81 | ax["target"].grid(which="minor", color="black", linestyle="-", linewidth=0.5) 82 | 83 | z = post_act_this_batch[t] 84 | low, high = np.percentile(z, 5), np.percentile(z, 95) 85 | z_norm = np.clip((z - low) / (high - low), 0, 1) 86 | point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler 87 | cmap = plt.get_cmap("Spectral") 88 | ax["umap"].scatter( 89 | umap_positions[:, 0], 90 | umap_positions[:, 1], 91 | s=point_sizes, 92 | c=cmap(z_norm), 93 | alpha=0.8 94 | ) 95 | ax["umap"].axis("off") 96 | 97 | 98 | # normalize attention 99 | att_t = attention_weights[t, :, :] 100 | a_min, a_max = att_t.min(), att_t.max() 101 | if not np.isclose(a_min, a_max): 102 | att_t = (att_t - a_min) / (a_max - a_min + 1e-8) 103 | else: 104 | att_t = np.zeros_like(att_t) 105 | 106 | # input image for arrows 107 | img_t = inputs_this_batch[t].transpose(1, 2, 0) 108 | 109 | if t == 0: 110 | route_history = [[] for _ in 
range(n_heads)] 111 | 112 | img_h, img_w = img_t.shape[:2] 113 | cell_h = img_h // grid_side 114 | cell_w = img_w // grid_side 115 | 116 | for h in range(n_heads): 117 | head_map = att_t[h].reshape(grid_side, grid_side) 118 | ax[f"att_{h}"].imshow(head_map, cmap="viridis", vmin=0, vmax=1) 119 | ax[f"att_{h}"].axis("off") 120 | ax[f"in_{h}"].imshow(img_t, cmap="gray", vmin=0, vmax=1) 121 | ax[f"in_{h}"].axis("off") 122 | 123 | # track argmax center 124 | flat_idx = np.argmax(head_map) 125 | gy, gx = divmod(flat_idx, grid_side) 126 | cx = int((gx + 0.5) * cell_w) 127 | cy = int((gy + 0.5) * cell_h) 128 | route_history[h].append((cx, cy)) 129 | 130 | cmap_steps = plt.colormaps.get_cmap("Spectral") 131 | colors = [cmap_steps(i / (n_steps - 1)) for i in range(n_steps)] 132 | for i in range(len(route_history[h]) - 1): 133 | x0, y0 = route_history[h][i] 134 | x1, y1 = route_history[h][i + 1] 135 | color = colors[i] 136 | is_last = (i == len(route_history[h]) - 2) 137 | style = '->' if is_last else '-' 138 | lw = 2.0 if is_last else 1.6 139 | alpha = 1.0 if is_last else 0.9 140 | scale = 10 if is_last else 1 141 | 142 | # draw arrow 143 | arr = FancyArrowPatch( 144 | (x0, y0), (x1, y1), 145 | arrowstyle=style, 146 | linewidth=lw, 147 | mutation_scale=scale, 148 | alpha=alpha, 149 | facecolor=color, 150 | edgecolor=color, 151 | shrinkA=0, shrinkB=0, 152 | capstyle='round', joinstyle='round', 153 | zorder=3 if is_last else 2, 154 | clip_on=False, 155 | ) 156 | ax[f"in_{h}"].add_patch(arr) 157 | 158 | ax[f"in_{h}"].scatter( 159 | x1, y1, 160 | marker='x', 161 | s=40, 162 | color=color, 163 | linewidths=lw, 164 | zorder=4 165 | ) 166 | 167 | canvas = fig.canvas 168 | canvas.draw() 169 | frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) 170 | w, h = canvas.get_width_height() 171 | frames.append(frame.reshape(h, w, 4)[..., :3]) 172 | plt.close(fig) 173 | 174 | # save gif 175 | imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0) 176 | 177 | # save mp4 178 | save_frames_to_mp4( 179 | [fm[:, :, ::-1] for fm in frames], # RGB→BGR 180 | f"{save_path}/activation.mp4", 181 | fps=15, 182 | gop_size=1, 183 | preset="slow" 184 | ) 185 | 186 | def run_umap(model, testloader): 187 | all_post_activations = [] 188 | point_counts = 150 189 | sampled = 0 190 | with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar: 191 | for inputs, _ in testloader: 192 | for i in range(inputs.size(0)): 193 | if sampled >= point_counts: 194 | break 195 | input_i = inputs[i].unsqueeze(0).to(device) 196 | _, _, _, _, post_activations, _ = model(input_i, track=True) 197 | all_post_activations.append(post_activations) 198 | sampled += 1 199 | pbar.update(1) 200 | if sampled >= point_counts: 201 | break 202 | 203 | stacked = np.stack(all_post_activations, 1) 204 | umap_features = stacked.reshape(-1, stacked.shape[-1]) 205 | reducer = umap.UMAP( 206 | n_components=2, 207 | n_neighbors=20, 208 | min_dist=1, 209 | spread=1, 210 | metric='cosine', 211 | local_connectivity=1 212 | ) 213 | positions = reducer.fit_transform(umap_features.T) 214 | return positions 215 | 216 | 217 | def run_model_and_make_gif(checkpoint_path, save_path, device): 218 | 219 | parity_sequence_length = 64 220 | iterations = 75 221 | 222 | test_data = ParityDataset(sequence_length=parity_sequence_length, length=10000) 223 | testloader = torch.utils.data.DataLoader(test_data, batch_size=256, shuffle=True, num_workers=0, drop_last=False) 224 | 225 | 226 | model, _ = build_model_from_checkpoint_path(checkpoint_path, "ctm", 
device=device) 227 | 228 | input = torch.randint(0, 2, (64,), dtype=torch.float32, device=device) * 2 - 1 229 | input = input.unsqueeze(0) 230 | 231 | target = torch.cumsum((input == -1).to(torch.long), dim=1) % 2 232 | target = target.unsqueeze(0) 233 | 234 | positions = run_umap(model, testloader) 235 | 236 | model.eval() 237 | with torch.inference_mode(): 238 | predictions, _, _, _, post_activations, attention = model(input, track=True) 239 | predictons = reshape_predictions(predictions, prediction_reshaper=[parity_sequence_length, 2]) 240 | input_images = reshape_inputs(input, iterations, grid_size=int(math.sqrt(parity_sequence_length))) 241 | 242 | make_parity_gif( 243 | predictions=predictons.detach().cpu().numpy(), 244 | targets=target.detach().cpu().numpy(), 245 | post_activations=post_activations, 246 | attention_weights=attention.squeeze(1).squeeze(2), 247 | inputs_to_model=input_images, 248 | save_path=save_path, 249 | umap_positions=positions, 250 | umap_point_scaler=1.0, 251 | ) 252 | 253 | 254 | 255 | if __name__ == "__main__": 256 | 257 | CHECKPOINT_PATH = "checkpoints/parity/run1/ctm_75_25/checkpoint_200000.pt" 258 | SAVE_PATH = f"tasks/parity/analysis/outputs/blog_gifs/" 259 | os.makedirs(SAVE_PATH, exist_ok=True) 260 | 261 | device = "cuda" if torch.cuda.is_available() else "cpu" 262 | 263 | run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, device) 264 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_100_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=100 4 | MEMORY_LENGTH=50 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | --memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_10_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=10 4 | MEMORY_LENGTH=5 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | 
--memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_1_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=1 4 | MEMORY_LENGTH=1 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | --memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_25_10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=25 4 | MEMORY_LENGTH=10 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | --memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" 
-------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_50_25.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=50 4 | MEMORY_LENGTH=25 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | --memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/parity/scripts/train_ctm_75_25.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=75 4 | MEMORY_LENGTH=25 5 | LOG_DIR="logs/parity/run${RUN}/ctm_${ITERATIONS}_${MEMORY_LENGTH}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 1024 \ 16 | --d_input 512 \ 17 | --n_synch_out 32 \ 18 | --n_synch_action 32 \ 19 | --synapse_depth 1 \ 20 | --heads 8 \ 21 | --memory_hidden_dims 16 \ 22 | --dropout 0.0 \ 23 | --deep_memory \ 24 | --no-do_normalisation \ 25 | --positional_embedding_type="custom-rotational-1d" \ 26 | --backbone_type="parity_backbone" \ 27 | --no-full_eval \ 28 | --weight_decay 0.0 \ 29 | --gradient_clipping 0.9 \ 30 | --use_scheduler \ 31 | --scheduler_type "cosine" \ 32 | --milestones 0 0 0 \ 33 | --gamma 0 \ 34 | --dataset "parity" \ 35 | --batch_size 64 \ 36 | --batch_size_test 256 \ 37 | --lr=0.0001 \ 38 | --training_iterations 200001 \ 39 | --warmup_steps 500 \ 40 | --track_every 1000 \ 41 | --save_every 10000 \ 42 | --no-reload \ 43 | --no-reload_model_only \ 44 | --device 0 \ 45 | --no-use_amp \ 46 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=1 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 669 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 
| --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=10 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 686 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=100 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 857 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_10_certain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=3 3 | ITERATIONS=10 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | 
--d_model 686 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | --use_most_certain_with_lstm \ 41 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_25.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=25 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 706 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_25_certain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=3 3 | ITERATIONS=25 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}_certain" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 706 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | --use_most_certain_with_lstm \ -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_50.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=50 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m 
tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 765 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/scripts/train_lstm_75.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=75 4 | MODEL_TYPE="lstm" 5 | LOG_DIR="logs/parity/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 6 | SEED=$((RUN - 1)) 7 | 8 | python -m tasks.parity.train \ 9 | --log_dir $LOG_DIR \ 10 | --seed $SEED \ 11 | --iterations $ITERATIONS \ 12 | --model_type $MODEL_TYPE \ 13 | --parity_sequence_length 64 \ 14 | --n_test_batches 20 \ 15 | --d_model 765 \ 16 | --d_input 512 \ 17 | --heads 8 \ 18 | --dropout 0.0 \ 19 | --positional_embedding_type="custom-rotational-1d" \ 20 | --backbone_type="parity_backbone" \ 21 | --no-full_eval \ 22 | --weight_decay 0.0 \ 23 | --gradient_clipping -1 \ 24 | --use_scheduler \ 25 | --scheduler_type "cosine" \ 26 | --milestones 0 0 0 \ 27 | --gamma 0 \ 28 | --dataset "parity" \ 29 | --batch_size 64 \ 30 | --batch_size_test 256 \ 31 | --lr=0.0001 \ 32 | --training_iterations 200001 \ 33 | --warmup_steps 500 \ 34 | --track_every 1000 \ 35 | --save_every 10000 \ 36 | --no-reload \ 37 | --no-reload_model_only \ 38 | --device 0 \ 39 | --no-use_amp \ 40 | -------------------------------------------------------------------------------- /tasks/parity/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import math 4 | from models.ctm import ContinuousThoughtMachine 5 | from models.lstm import LSTMBaseline 6 | 7 | def prepare_model(prediction_reshaper, args, device): 8 | if args.model_type == 'ctm': 9 | model = ContinuousThoughtMachine( 10 | iterations=args.iterations, 11 | d_model=args.d_model, 12 | d_input=args.d_input, 13 | heads=args.heads, 14 | n_synch_out=args.n_synch_out, 15 | n_synch_action=args.n_synch_action, 16 | synapse_depth=args.synapse_depth, 17 | memory_length=args.memory_length, 18 | deep_nlms=args.deep_memory, 19 | memory_hidden_dims=args.memory_hidden_dims, 20 | do_layernorm_nlm=args.do_normalisation, 21 | backbone_type=args.backbone_type, 22 | positional_embedding_type=args.positional_embedding_type, 23 | out_dims=args.out_dims, 24 | prediction_reshaper=prediction_reshaper, 25 | dropout=args.dropout, 26 | neuron_select_type=args.neuron_select_type, 27 | n_random_pairing_self=args.n_random_pairing_self, 28 | ).to(device) 29 | elif args.model_type == 'lstm': 30 | model = LSTMBaseline( 31 | iterations=args.iterations, 32 | d_model=args.d_model, 33 | num_layers=1, 34 | d_input=args.d_input, 35 | heads=args.heads, 36 | backbone_type=args.backbone_type, 37 | positional_embedding_type=args.positional_embedding_type, 38 | 
out_dims=args.out_dims, 39 | prediction_reshaper=prediction_reshaper, 40 | dropout=args.dropout, 41 | ).to(device) 42 | else: 43 | raise ValueError(f"Model must be either ctm or lstm, not {args.model_type}") 44 | 45 | return model 46 | 47 | def reshape_attention_weights(attention_weights): 48 | T, B = attention_weights.shape[0], attention_weights.shape[1] 49 | grid_size = math.sqrt(attention_weights.shape[-1]) 50 | assert grid_size.is_integer(), f'Grid size should be a perfect square, but got {attention_weights.shape[-1]}' 51 | H_ATTENTION = W_ATTENTION = int(grid_size) 52 | attn_weights_reshaped = attention_weights.reshape(T, B, -1, H_ATTENTION, W_ATTENTION) 53 | return attn_weights_reshaped.mean(2) 54 | 55 | def reshape_inputs(inputs, iterations, grid_size): 56 | reshaped_inputs = inputs.reshape(-1, grid_size, grid_size).unsqueeze(0).repeat(iterations, 1, 1, 1).unsqueeze(2).detach().cpu().numpy() 57 | return reshaped_inputs 58 | 59 | def get_where_most_certain(certainties): 60 | return certainties[:,1].argmax(-1) 61 | 62 | def parse_folder_name(folder_path): 63 | folder = os.path.basename(folder_path) 64 | 65 | lstm_match = re.match(r"lstm_(\d+)", folder) 66 | if lstm_match: 67 | model_type = "LSTM" 68 | iters = int(lstm_match.group(1)) 69 | return f"{model_type}, {iters} Iters.", model_type, iters 70 | 71 | ctm_full_match = re.match(r"ctm(\d+)_(\d+)", folder) 72 | if ctm_full_match: 73 | model_type = "CTM" 74 | iters = int(ctm_full_match.group(1)) 75 | mem_len = int(ctm_full_match.group(2)) 76 | return f"{model_type}, {iters} Iters., {mem_len} Mem. Len.", model_type, iters 77 | 78 | ctm_partial_match = re.match(r"ctm_(\d+)", folder) 79 | if ctm_partial_match: 80 | model_type = "CTM" 81 | iters = int(ctm_partial_match.group(1)) 82 | return f"{model_type}, {iters} Iters.", model_type, iters 83 | 84 | return "Unknown", None, None 85 | -------------------------------------------------------------------------------- /tasks/qamnist/README.md: -------------------------------------------------------------------------------- 1 | # Q&A MNIST 2 | 3 | ## Training 4 | To run the Q&A MNIST training that we used for the paper, run bash scripts from the root level of the repository. For example, to train the 10-iteration CTM, run: 5 | 6 | ``` 7 | bash tasks/qamnist/scripts/train_ctm_10.sh 8 | ``` 9 | 10 | ## Analysis 11 | To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). The checkpoints can be obtained by either running the training code, or downloading them from [this link](https://drive.google.com/file/d/1-ycgRYxOlZ9-TJ_n3xvUonRvvf5Lh0r3/view?usp=drive_link). 
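For example, assuming the checkpoints come from the 10-iteration CTM script above (which sets `LOG_DIR=logs/qamnist/run3/ctm_10`), the analysis would be pointed at that directory; this is only a sketch, so substitute whichever run directory you actually want to analyse:

```
python -m tasks.qamnist.analysis.run --log_dir logs/qamnist/run3/ctm_10
```

The general form is: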
12 | 13 | ``` 14 | python -m tasks.qamnist.analysis.run --log_dir 15 | ``` 16 | -------------------------------------------------------------------------------- /tasks/qamnist/analysis/make_blog_gifs_equation_animation.py: -------------------------------------------------------------------------------- 1 | from manim import config 2 | config.background_color = "WHITE" 3 | 4 | from manim import Scene, MathTex, Indicate, RED, BLACK 5 | 6 | class ModularEquationHighlight(Scene): 7 | def construct(self): 8 | tokens = [ 9 | "(", "(", "(", "(", "(", "(", "(", # 0-6 10 | "5", "-", "6", # 7-9 11 | ")", ")", # 10-11 12 | "\mod", "10", # 12-13 13 | ")", "+", "5", # 14-16 14 | ")", ")", "\mod", "10", # 17-20 15 | ")", "+", "5", # 21-23 16 | ")", "\mod", "10", "=", "9" # 24-28 17 | ] 18 | 19 | eq = MathTex(*tokens) 20 | eq.set_color(BLACK) 21 | eq.scale(1.1) 22 | 23 | self.add(eq) 24 | self.wait(40 / 15) 25 | 26 | highlight_sequence = [7, 8, 9, 15, 16, 22, 23] # 9 + 9, +1, +3 27 | 28 | for idx in highlight_sequence: 29 | eq[idx].set_stroke(color=RED, width=6) 30 | self.play(Indicate(eq[idx], color=RED, scale_factor=1.3), run_time=10 / 15) 31 | eq[idx].set_stroke(color=BLACK, width=0) 32 | 33 | self.wait(10 / 15) 34 | -------------------------------------------------------------------------------- /tasks/qamnist/plotting.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | import seaborn as sns 5 | import numpy as np 6 | import imageio 7 | from scipy.special import softmax 8 | sns.set_style('darkgrid') 9 | mpl.use('Agg') 10 | 11 | 12 | 13 | 14 | def make_qamnist_gif(predictions, certainties, targets, pre_activations, post_activations, input_gates, inputs_to_model, filename, question_readable=None): 15 | 16 | # Config 17 | batch_index = 0 18 | n_neurons_to_visualise = 16 19 | figscale = 0.28 20 | n_steps = len(pre_activations) 21 | heatmap_cmap = sns.color_palette("viridis", as_cmap=True) 22 | frames = [] 23 | 24 | these_pre_acts = pre_activations[:, batch_index, :] # Shape: (T, H) 25 | these_post_acts = post_activations[:, batch_index, :] # Shape: (T, H) 26 | these_inputs = inputs_to_model[:, batch_index, :, :, :] # Shape: (T, C, H, W) 27 | these_input_gates = input_gates[:, batch_index, :, :] # Shape: (T, H, W) 28 | these_predictions = predictions[batch_index, :, :] # Shape: (C, T) 29 | these_certainties = certainties[batch_index, :, :] # Shape: (C, T) 30 | this_target = targets[batch_index] # Shape: (C) 31 | 32 | logits_min, logits_max = np.min(these_predictions), np.max(these_predictions) 33 | probs_min, probs_max = 0, 1 34 | 35 | class_labels = [str(i) for i in range(10)] 36 | pad = 0.1 37 | if question_readable: 38 | this_question = question_readable[batch_index] 39 | pad = 1.6 40 | class_labels = ["" for i in range(len(these_predictions))] 41 | 42 | # Create mosaic layout 43 | mosaic = [['img_data', 'img_data', 'attention', 'attention', 'logits', 'logits', 'probs', 'probs'] for _ in range(2)] + \ 44 | [['img_data', 'img_data', 'attention', 'attention', 'logits', 'logits', 'probs', 'probs'] for _ in range(2)] + \ 45 | [['certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty', 'certainty']] + \ 46 | [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)] 47 | 48 | for stepi in range(n_steps): 49 | fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, 
figsize=(31*figscale*8/4, 76*figscale)) 50 | 51 | if question_readable: 52 | if this_question: 53 | fig_gif.suptitle(this_question, fontsize=24) 54 | 55 | # Plot action log probs 56 | colors = [('g' if i == this_target else ('b' if e >= 0 else 'r')) for i, e in enumerate(these_predictions[:, stepi])] 57 | sort_idxs = np.arange(len(these_predictions[:, stepi])) 58 | bars = axes_gif['logits'].bar(np.arange(len(these_predictions[:,stepi])), these_predictions[:,stepi][sort_idxs], color=np.array(colors)[sort_idxs], width=0.9, alpha=0.5) 59 | axes_gif['logits'].set_title('Logits') 60 | axes_gif['logits'].axis('off') 61 | for bar, label in zip(bars, class_labels): 62 | x = bar.get_x() + bar.get_width() / 2 63 | axes_gif['logits'].annotate(label, xy=(x, 0), xytext=(1, 0), 64 | textcoords="offset points", 65 | ha='center', va='bottom', rotation=90) 66 | axes_gif['logits'].set_ylim([logits_min - 0.1 * abs(logits_min), logits_max + 0.1 * abs(logits_max)]) 67 | 68 | # Add probability plot 69 | probs = softmax(these_predictions[:, stepi]) 70 | bars_prob = axes_gif['probs'].bar(np.arange(len(probs)), probs[sort_idxs], 71 | color=np.array(colors)[sort_idxs], width=0.9, alpha=0.5) 72 | axes_gif['probs'].set_title('Probabilities') 73 | axes_gif['probs'].axis('off') 74 | axes_gif['probs'].set_ylim([0, 1]) 75 | for bar, label in zip(bars_prob, class_labels): 76 | x = bar.get_x() + bar.get_width() / 2 77 | axes_gif['probs'].annotate(label, xy=(x, 0), xytext=(1, 0), textcoords="offset points", ha='center', va='bottom', rotation=90) 78 | 79 | axes_gif['probs'].set_ylim([probs_min, probs_max]) 80 | 81 | # Add certainty plot 82 | axes_gif['certainty'].plot(np.arange(n_steps), these_certainties[1], 'k-', linewidth=2) 83 | axes_gif['certainty'].set_xlim([0, n_steps-1]) 84 | axes_gif['certainty'].axvline(x=stepi, color='black', linewidth=1, alpha=0.5) 85 | axes_gif['certainty'].set_xticklabels([]) 86 | axes_gif['certainty'].set_yticklabels([]) 87 | axes_gif['certainty'].grid(False) 88 | 89 | # Plot neuron traces 90 | for neuroni in range(n_neurons_to_visualise): 91 | ax = axes_gif[f'trace_{neuroni}'] 92 | 93 | pre_activation = these_pre_acts[:, neuroni] 94 | post_activation = these_post_acts[:, neuroni] 95 | 96 | ax_pre = ax.twinx() 97 | 98 | pre_min, pre_max = np.min(pre_activation), np.max(pre_activation) 99 | post_min, post_max = np.min(post_activation), np.max(post_activation) 100 | 101 | ax_pre.plot(np.arange(n_steps), pre_activation, 102 | color='grey', 103 | linestyle='--', 104 | linewidth=1, 105 | alpha=0.4, 106 | label='Pre-activation') 107 | 108 | color = 'blue' if neuroni % 2 else 'red' 109 | ax.plot(np.arange(n_steps), post_activation, 110 | color=color, 111 | linestyle='-', 112 | linewidth=2, 113 | alpha=1.0, 114 | label='Post-activation') 115 | 116 | ax.set_xlim([0, n_steps-1]) 117 | ax_pre.set_xlim([0, n_steps-1]) 118 | ax.set_ylim([post_min, post_max]) 119 | ax_pre.set_ylim([pre_min, pre_max]) 120 | 121 | ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5) 122 | 123 | ax.set_xticklabels([]) 124 | ax.set_yticklabels([]) 125 | ax.grid(False) 126 | 127 | ax_pre.set_xticklabels([]) 128 | ax_pre.set_yticklabels([]) 129 | ax_pre.grid(False) 130 | 131 | # Show input image 132 | this_image = these_inputs[stepi].transpose(1, 2, 0) 133 | # this_image = (this_image - this_image.min()) / (this_image.max() - this_image.min() + 1e-8) # Normalize to [0,1] 134 | axes_gif['img_data'].imshow(this_image, cmap='binary', vmin=0, vmax=1) 135 | axes_gif['img_data'].grid(False) 136 | 
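        # hide the tick marks on the input-image panel (the frame itself was drawn with imshow above)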
axes_gif['img_data'].set_xticks([]) 137 | axes_gif['img_data'].set_yticks([]) 138 | 139 | # Create and show attention heatmap 140 | try: 141 | this_input_gate = these_input_gates[stepi] 142 | except (IndexError, TypeError): 143 | this_input_gate = np.zeros_like(these_input_gates[0]) 144 | gate_min, gate_max = np.nanmin(this_input_gate), np.nanmax(this_input_gate) 145 | if not np.isclose(gate_min, gate_max): 146 | normalized_gate = (this_input_gate - gate_min) / (gate_max - gate_min + 1e-8) 147 | else: 148 | normalized_gate = np.zeros_like(this_input_gate) 149 | input_heatmap = heatmap_cmap(normalized_gate)[:,:,:3] 150 | # Show heatmaps 151 | axes_gif['attention'].imshow(input_heatmap, vmin=0, vmax=1) 152 | axes_gif['attention'].axis('off') 153 | axes_gif['attention'].set_title('Attention') 154 | 155 | # Save frames 156 | fig_gif.tight_layout(pad=pad) 157 | if stepi == 0: 158 | fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100) 159 | if stepi == 1: 160 | fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100) 161 | if stepi == n_steps-1: 162 | fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100) 163 | 164 | # Convert to frame 165 | canvas = fig_gif.canvas 166 | canvas.draw() 167 | image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8') 168 | image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] 169 | frames.append(image_numpy) 170 | plt.close(fig_gif) 171 | 172 | imageio.mimsave(filename, frames, fps=15, loop=100) 173 | 174 | pass -------------------------------------------------------------------------------- /tasks/qamnist/scripts/train_ctm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | MEMORY_LENGTH=3 4 | MODEL_TYPE="ctm" 5 | Q_NUM_REPEATS_PER_INPUT=1 6 | LOG_DIR="logs/qamnist/run${RUN}/${MODEL_TYPE}_${Q_NUM_REPEATS_PER_INPUT}" 7 | SEED=$((RUN - 1)) 8 | 9 | python -m tasks.qamnist.train \ 10 | --log_dir $LOG_DIR \ 11 | --seed $SEED \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --model_type $MODEL_TYPE \ 14 | --q_num_images 3 \ 15 | --q_num_images_delta 2 \ 16 | --q_num_repeats_per_input $Q_NUM_REPEATS_PER_INPUT \ 17 | --q_num_operations 3 \ 18 | --q_num_operations_delta 2 \ 19 | --q_num_answer_steps 1 \ 20 | --n_test_batches 20 \ 21 | --d_model 1024 \ 22 | --d_input 64 \ 23 | --n_synch_out 32 \ 24 | --n_synch_action 32 \ 25 | --synapse_depth 1 \ 26 | --heads 4 \ 27 | --memory_hidden_dims 16 \ 28 | --dropout 0.0 \ 29 | --deep_memory \ 30 | --no-do_normalisation \ 31 | --weight_decay 0.0 \ 32 | --use_scheduler \ 33 | --scheduler_type "cosine" \ 34 | --milestones 0 0 0 \ 35 | --gamma 0 \ 36 | --batch_size 64 \ 37 | --batch_size_test 256 \ 38 | --lr=0.0001 \ 39 | --training_iterations 300001 \ 40 | --warmup_steps 500 \ 41 | --track_every 1000 \ 42 | --save_every 10000 \ 43 | --no-reload \ 44 | --no-reload_model_only \ 45 | --device 0 \ 46 | --no-use_amp \ 47 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/qamnist/scripts/train_ctm_10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=3 3 | MEMORY_LENGTH=30 4 | MODEL_TYPE="ctm" 5 | Q_NUM_REPEATS_PER_INPUT=10 6 | LOG_DIR="logs/qamnist/run${RUN}/${MODEL_TYPE}_${Q_NUM_REPEATS_PER_INPUT}" 7 | SEED=$((RUN - 1)) 8 | 9 | python -m tasks.qamnist.train \ 10 | --log_dir $LOG_DIR \ 11 | --seed $SEED \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --model_type $MODEL_TYPE \ 14 | 
--q_num_images 3 \ 15 | --q_num_images_delta 2 \ 16 | --q_num_repeats_per_input $Q_NUM_REPEATS_PER_INPUT \ 17 | --q_num_operations 3 \ 18 | --q_num_operations_delta 2 \ 19 | --q_num_answer_steps 10 \ 20 | --n_test_batches 20 \ 21 | --d_model 1024 \ 22 | --d_input 64 \ 23 | --n_synch_out 32 \ 24 | --n_synch_action 32 \ 25 | --synapse_depth 1 \ 26 | --heads 4 \ 27 | --memory_hidden_dims 16 \ 28 | --dropout 0.0 \ 29 | --deep_memory \ 30 | --no-do_normalisation \ 31 | --weight_decay 0.0 \ 32 | --use_scheduler \ 33 | --scheduler_type "cosine" \ 34 | --milestones 0 0 0 \ 35 | --gamma 0 \ 36 | --batch_size 64 \ 37 | --batch_size_test 256 \ 38 | --lr=0.0001 \ 39 | --training_iterations 300001 \ 40 | --warmup_steps 500 \ 41 | --track_every 1000 \ 42 | --save_every 10000 \ 43 | --no-reload \ 44 | --no-reload_model_only \ 45 | --device 0 \ 46 | --no-use_amp \ 47 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/qamnist/scripts/train_lstm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | MEMORY_LENGTH=3 4 | MODEL_TYPE="lstm" 5 | Q_NUM_REPEATS_PER_INPUT=1 6 | LOG_DIR="logs/qamnist/run${RUN}/${MODEL_TYPE}_${Q_NUM_REPEATS_PER_INPUT}" 7 | SEED=$((RUN - 1)) 8 | 9 | python -m tasks.qamnist.train \ 10 | --log_dir $LOG_DIR \ 11 | --seed $SEED \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --model_type $MODEL_TYPE \ 14 | --q_num_images 3 \ 15 | --q_num_images_delta 2 \ 16 | --q_num_repeats_per_input $Q_NUM_REPEATS_PER_INPUT \ 17 | --q_num_operations 3 \ 18 | --q_num_operations_delta 2 \ 19 | --q_num_answer_steps 1 \ 20 | --n_test_batches 20 \ 21 | --d_model 741 \ 22 | --d_input 64 \ 23 | --n_synch_out 32 \ 24 | --n_synch_action 32 \ 25 | --synapse_depth 1 \ 26 | --heads 4 \ 27 | --memory_hidden_dims 16 \ 28 | --dropout 0.0 \ 29 | --deep_memory \ 30 | --no-do_normalisation \ 31 | --weight_decay 0.0 \ 32 | --use_scheduler \ 33 | --scheduler_type "cosine" \ 34 | --milestones 0 0 0 \ 35 | --gamma 0 \ 36 | --batch_size 64 \ 37 | --batch_size_test 256 \ 38 | --lr=0.0001 \ 39 | --training_iterations 300001 \ 40 | --warmup_steps 500 \ 41 | --track_every 1000 \ 42 | --save_every 10000 \ 43 | --no-reload \ 44 | --no-reload_model_only \ 45 | --device 0 \ 46 | --no-use_amp \ 47 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/qamnist/scripts/train_lstm_10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=3 3 | MEMORY_LENGTH=30 4 | MODEL_TYPE="lstm" 5 | Q_NUM_REPEATS_PER_INPUT=10 6 | LOG_DIR="logs/qamnist/run${RUN}/${MODEL_TYPE}_${Q_NUM_REPEATS_PER_INPUT}" 7 | SEED=$((RUN - 1)) 8 | 9 | python -m tasks.qamnist.train \ 10 | --log_dir $LOG_DIR \ 11 | --seed $SEED \ 12 | --memory_length $MEMORY_LENGTH \ 13 | --model_type $MODEL_TYPE \ 14 | --q_num_images 3 \ 15 | --q_num_images_delta 2 \ 16 | --q_num_repeats_per_input $Q_NUM_REPEATS_PER_INPUT \ 17 | --q_num_operations 3 \ 18 | --q_num_operations_delta 2 \ 19 | --q_num_answer_steps 10 \ 20 | --n_test_batches 20 \ 21 | --d_model 875 \ 22 | --d_input 64 \ 23 | --n_synch_out 32 \ 24 | --n_synch_action 32 \ 25 | --synapse_depth 1 \ 26 | --heads 4 \ 27 | --memory_hidden_dims 16 \ 28 | --dropout 0.0 \ 29 | --deep_memory \ 30 | --no-do_normalisation \ 31 | --weight_decay 0.0 \ 32 | --use_scheduler \ 33 | --scheduler_type "cosine" \ 34 | --milestones 0 0 0 \ 35 | --gamma 0 \ 36 | --batch_size 64 \ 37 | --batch_size_test 256 \ 
38 | --lr=0.0001 \ 39 | --training_iterations 300001 \ 40 | --warmup_steps 500 \ 41 | --track_every 1000 \ 42 | --save_every 10000 \ 43 | --no-reload \ 44 | --no-reload_model_only \ 45 | --device 0 \ 46 | --no-use_amp \ 47 | --neuron_select_type "random" -------------------------------------------------------------------------------- /tasks/qamnist/utils.py: -------------------------------------------------------------------------------- 1 | from models.ctm_qamnist import ContinuousThoughtMachineQAMNIST 2 | from models.lstm_qamnist import LSTMBaseline 3 | from data.custom_datasets import QAMNISTDataset 4 | from torchvision import datasets 5 | from torchvision import transforms 6 | import numpy as np 7 | 8 | def get_dataset(q_num_images, q_num_images_delta, q_num_repeats_per_input, q_num_operations, q_num_operations_delta): 9 | dataset_mean = 0.1307 10 | dataset_std = 0.3081 11 | transform = transforms.Compose( 12 | [transforms.Resize(32), 13 | transforms.ToTensor(), 14 | transforms.Normalize(dataset_mean, dataset_std) 15 | ]) 16 | train_data = QAMNISTDataset(datasets.MNIST("data/", train=True, transform=transform, download=True), num_images=q_num_images, num_images_delta=q_num_images_delta, num_repeats_per_input=q_num_repeats_per_input, num_operations=q_num_operations, num_operations_delta=q_num_operations_delta) 17 | test_data = QAMNISTDataset(datasets.MNIST("data/", train=False, transform=transform, download=True), num_images=q_num_images, num_images_delta=q_num_images_delta, num_repeats_per_input=q_num_repeats_per_input, num_operations=q_num_operations, num_operations_delta=q_num_operations_delta) 18 | class_labels = [str(i) for i in np.arange(train_data.output_range[0], train_data.output_range[1]+1)] 19 | return train_data, test_data, class_labels, dataset_mean, dataset_std 20 | 21 | def prepare_model(args, device): 22 | if args.model_type == 'ctm': 23 | model = ContinuousThoughtMachineQAMNIST( 24 | iterations=args.iterations, 25 | d_model=args.d_model, 26 | d_input=args.d_input, 27 | heads=args.heads, 28 | n_synch_out=args.n_synch_out, 29 | n_synch_action=args.n_synch_action, 30 | synapse_depth=args.synapse_depth, 31 | memory_length=args.memory_length, 32 | deep_nlms=args.deep_memory, 33 | memory_hidden_dims=args.memory_hidden_dims, 34 | do_layernorm_nlm=args.do_normalisation, 35 | out_dims=args.out_dims, 36 | prediction_reshaper=[-1], 37 | dropout=args.dropout, 38 | neuron_select_type=args.neuron_select_type, 39 | n_random_pairing_self=args.n_random_pairing_self, 40 | iterations_per_digit=args.q_num_repeats_per_input, 41 | iterations_per_question_part=args.q_num_repeats_per_input, 42 | iterations_for_answering=args.q_num_answer_steps, 43 | ).to(device) 44 | elif args.model_type == 'lstm': 45 | model = LSTMBaseline( 46 | iterations=args.iterations, 47 | d_model=args.d_model, 48 | d_input=args.d_input, 49 | heads=args.heads, 50 | out_dims=args.out_dims, 51 | prediction_reshaper=[-1], 52 | iterations_per_digit=args.q_num_repeats_per_input, 53 | iterations_per_question_part=args.q_num_repeats_per_input, 54 | iterations_for_answering=args.q_num_answer_steps, 55 | ).to(device) 56 | else: 57 | raise ValueError(f"Model must be either ctm or lstm, not {args.model_type}") 58 | 59 | return model 60 | -------------------------------------------------------------------------------- /tasks/rl/README.md: -------------------------------------------------------------------------------- 1 | # RL 2 | 3 | ## Training 4 | To run the RL training that we used for the paper, run bash scripts from the root 
level of the repository. For example, to train the 2-iteration CTM on the Acrobot task, run: 5 | 6 | ``` 7 | bash tasks/rl/scripts/acrobot/train_ctm_2.sh 8 | ``` 9 | 10 | Note that tensorboard is used for monitoring training. It should be installed with: 11 | ``` 12 | pip install tensorboard 13 | ``` 14 | 15 | 16 | ## Analysis 17 | To run the analysis, first make sure the checkpoints are saved in the log directory (specified by the `log_dir` argument). The checkpoints can be obtained by either running the training code, or downloading them from [this link](https://drive.google.com/file/d/1VRl6qA5lX690A1X0emNg0nRH758XJEXJ/view?usp=drive_link). 18 | 19 | ``` 20 | python -m tasks.rl.analysis.run --log_dir 21 | ``` 22 | -------------------------------------------------------------------------------- /tasks/rl/analysis/make_blog_gifs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from tqdm import tqdm 4 | import numpy as np 5 | import umap 6 | from matplotlib import pyplot as plt 7 | import imageio 8 | from scipy.special import softmax 9 | 10 | 11 | from tasks.rl.train import Agent 12 | from tasks.rl.analysis.run import get_training_data_from_checkpoint_path, get_size_action_space, prepare_env 13 | from tasks.rl.utils import combine_tracking_data 14 | from tasks.image_classification.plotting import save_frames_to_mp4 15 | 16 | 17 | def load_model(agent, checkpoint_path, device): 18 | checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) 19 | agent.load_state_dict(checkpoint['model_state_dict']) 20 | pass 21 | 22 | def interpolate_post_activations(arrays, target_length): 23 | interpolated = [] 24 | for arr in arrays: 25 | arr = arr.squeeze(1) 26 | T, D = arr.shape 27 | if T == target_length: 28 | interpolated.append(arr) 29 | continue 30 | x_old = np.linspace(0, 1, T) 31 | x_new = np.linspace(0, 1, target_length) 32 | arr_interp = np.array([ 33 | np.interp(x_new, x_old, arr[:, d]) for d in range(D) 34 | ]).T 35 | interpolated.append(arr_interp) 36 | return interpolated 37 | 38 | def make_rl_gif(post_activations, inputs_to_model, action_probs, actions, save_path, umap_positions, umap_point_scaler=1.0): 39 | 40 | batch_index = 0 41 | figscale = 0.32 42 | n_steps = action_probs.shape[0] 43 | frames = [] 44 | 45 | inputs_this_batch = inputs_to_model # Already shape (T, H, W, C) 46 | 47 | class_labels = ["Left", "Right", "Forward", "Pickup", "Drop", "Toggle", "Done"] 48 | 49 | post_act_this_batch = post_activations[:, batch_index] 50 | 51 | mosaic = [ 52 | [f"obs", f"obs", f"obs", f"obs", "probs", "probs", "probs", "probs"], 53 | [f"obs", f"obs", f"obs", f"obs", "probs", "probs","probs", "probs"], 54 | ] 55 | for _ in range(2, 8): 56 | mosaic.append( 57 | ["umap", "umap", "umap", "umap", "umap", "umap", "umap", "umap"] 58 | ) 59 | 60 | for t in range(n_steps): 61 | rows = len(mosaic) 62 | cell_size = figscale * 4 63 | fig_h = rows * cell_size 64 | 65 | probs_t = action_probs[t] 66 | 67 | fig, ax = plt.subplot_mosaic( 68 | mosaic, 69 | figsize=(6 * cell_size, fig_h), 70 | constrained_layout=False, 71 | gridspec_kw={'wspace': 0.05, 'hspace': 0.05}, # small gaps 72 | ) 73 | # restore a little margin 74 | fig.subplots_adjust(left=0.02, right=0.98, top=0.98, bottom=0.02) 75 | 76 | this_image = inputs_this_batch[t] 77 | if this_image.dtype != np.uint8: 78 | this_image = (np.clip(this_image, 0, 1) * 255).astype(np.uint8) 79 | if this_image.shape[-1] == 1: 80 | this_image = np.repeat(this_image, 3, axis=-1) 
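        # this_image is now a uint8 RGB render of the environment; show it in the observation panel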
81 | 82 | 83 | ax["obs"].imshow(this_image) 84 | ax["obs"].axis("off") 85 | 86 | probs_t = action_probs[t] 87 | colors = ['black' if i == actions[t] else 'gray' for i in range(len(probs_t))] 88 | bars = ax["probs"].bar(np.arange(len(probs_t)), probs_t, color=colors, width=0.9, alpha=0.8) 89 | ax["probs"].axis('off') 90 | for bar, label in zip(bars, class_labels): 91 | x = bar.get_x() + bar.get_width() / 2 92 | ax["probs"].annotate(label, xy=(x, 0), xytext=(1, 0), 93 | textcoords="offset points", 94 | ha='center', va='bottom', rotation=90) 95 | ax["probs"].set_ylim([0, 1]) 96 | 97 | z = post_act_this_batch[t] 98 | low, high = np.percentile(z, 5), np.percentile(z, 95) 99 | z_norm = np.clip((z - low) / (high - low), 0, 1) 100 | point_sizes = (np.abs(z_norm - 0.5) * 100 + 5) * umap_point_scaler 101 | cmap = plt.get_cmap("Spectral") 102 | ax["umap"].scatter( 103 | umap_positions[:, 0], 104 | umap_positions[:, 1], 105 | s=point_sizes, 106 | c=cmap(z_norm), 107 | alpha=0.8 108 | ) 109 | ax["umap"].axis("off") 110 | 111 | canvas = fig.canvas 112 | canvas.draw() 113 | frame = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8) 114 | w, h = canvas.get_width_height() 115 | frames.append(frame.reshape(h, w, 4)[..., :3]) 116 | plt.close(fig) 117 | 118 | # save gif 119 | imageio.mimsave(f"{save_path}/activation.gif", frames, fps=15, loop=0) 120 | 121 | # save mp4 122 | save_frames_to_mp4( 123 | [fm[:, :, ::-1] for fm in frames], # RGB→BGR 124 | f"{save_path}/activation.mp4", 125 | fps=15, 126 | gop_size=1, 127 | preset="slow" 128 | ) 129 | 130 | 131 | def run_umap(agent, model_args): 132 | 133 | all_post_activations = [] 134 | point_counts = 150 135 | 136 | eval_env = prepare_env(model_args.env_id, model_args.max_environment_steps, mask_velocity=model_args.mask_velocity, render_mode="rgb_array") 137 | with tqdm(total=point_counts, desc="Collecting UMAP data") as pbar: 138 | for idx in range(point_counts): 139 | eval_next_obs, _ = eval_env.reset(seed=idx) 140 | eval_next_done = False 141 | eval_state = agent.get_initial_state(1) 142 | tracking_data_by_world_step = [] 143 | for environment_step_i in range(model_args.max_environment_steps): 144 | with torch.no_grad(): 145 | action, _, _, value, eval_state, tracking_data, action_logits, action_probs = agent.get_action_and_value( 146 | torch.Tensor(eval_next_obs).to(device).unsqueeze(0), 147 | eval_state, 148 | torch.Tensor([eval_next_done]).to(device), 149 | track=True 150 | ) 151 | eval_next_obs, reward, termination, truncation, _ = eval_env.step(action.cpu().numpy()[0]) 152 | eval_next_done = termination or truncation 153 | 154 | tracking_data['actions'] = np.tile(action.detach().cpu().numpy(), (model_args.iterations)) # Shape T 155 | tracking_data['values'] = np.tile(value.squeeze(-1).detach().cpu().numpy(), (model_args.iterations)) # Shape T 156 | tracking_data['action_logits'] = np.tile(action_logits.detach().cpu().numpy(), (model_args.iterations, 1)) # Shape T, A 157 | tracking_data['action_probs'] = np.tile(action_probs.detach().cpu().numpy(), (model_args.iterations, 1))# Shape T, A 158 | tracking_data['rewards'] = np.tile(np.array(reward), (model_args.iterations)) # Shape T 159 | tracking_data['inputs'] = np.tile(np.array(eval_env.render()), (model_args.iterations, 1, 1, 1)) # Shape T, H, W, C 160 | 161 | tracking_data_by_world_step.append(tracking_data) 162 | 163 | if eval_next_done: 164 | break 165 | 166 | eval_env.close() 167 | 168 | combined_tracking_data = combine_tracking_data(tracking_data_by_world_step) 169 | 
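            # keep this episode's post-activation trace; episodes are later interpolated to a common length and jointly embedded with UMAP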
all_post_activations.append(combined_tracking_data['post_activations']) 170 | pbar.update(1) 171 | 172 | all_post_activations = interpolate_post_activations(all_post_activations, all_post_activations[-1].shape[0]) 173 | stacked = np.stack(all_post_activations, 1) 174 | umap_features = stacked.reshape(-1, stacked.shape[-1]) 175 | reducer = umap.UMAP( 176 | n_components=2, 177 | n_neighbors=40, 178 | min_dist=1, 179 | spread=1, 180 | metric='cosine', 181 | local_connectivity=1 182 | ) 183 | positions = reducer.fit_transform(umap_features.T) 184 | return combined_tracking_data, positions 185 | 186 | def run_model_and_make_gif(checkpoint_path, save_path, env_id, device): 187 | 188 | # Load the model 189 | _, _, _, _, _, model_args = get_training_data_from_checkpoint_path(checkpoint_path, device) 190 | agent = Agent(size_action_space=get_size_action_space(env_id), args=model_args, device=device).to(device) 191 | load_model(agent, checkpoint_path, device) 192 | 193 | # Run the umapping 194 | tracking_data, positions = run_umap(agent, model_args) 195 | 196 | make_rl_gif( 197 | post_activations=tracking_data['post_activations'], 198 | inputs_to_model=tracking_data['inputs'], 199 | action_probs=tracking_data['action_probs'], 200 | actions=tracking_data['actions'], 201 | save_path=save_path, 202 | umap_positions=positions, 203 | umap_point_scaler=1.0, 204 | ) 205 | 206 | 207 | pass 208 | 209 | 210 | if __name__ == "__main__": 211 | 212 | env_id = "MiniGrid-FourRooms-v0" 213 | 214 | CHECKPOINT_PATH = f"logs/rl/{env_id}/run1/ctm_2/checkpoint.pt" 215 | SAVE_PATH = f"tasks/rl/analysis/outputs/blog_gifs/{env_id}" 216 | os.makedirs(SAVE_PATH, exist_ok=True) 217 | 218 | device = "cuda" if torch.cuda.is_available() else "cpu" 219 | 220 | run_model_and_make_gif(CHECKPOINT_PATH, SAVE_PATH, env_id, device) 221 | -------------------------------------------------------------------------------- /tasks/rl/envs.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | 4 | class MaskVelocityWrapper(gym.Wrapper): 5 | """ 6 | Observation wrapper that zeroes out the velocity components of the 7 | observation, so the agent has to infer velocities from the history of 8 | positions it observes. Supports CartPole-v1 and Acrobot-v1.
9 | """ 10 | def __init__(self, env: gym.Env) -> None: 11 | super().__init__(env) 12 | 13 | def reset(self, **kwargs): 14 | obs, info = self.env.reset(**kwargs) 15 | return self._apply_velocity_mask(obs), info 16 | 17 | def step(self, action): 18 | obs, reward, terminated, truncated, info = self.env.step(action) 19 | return self._apply_velocity_mask(obs), reward, terminated, truncated, info 20 | 21 | def _apply_velocity_mask(self, observation): 22 | gym_id = self.env.spec.id 23 | if gym_id == "CartPole-v1": 24 | return self._apply_velocity_mask_cartpole(observation) 25 | elif gym_id == "Acrobot-v1": 26 | return self._apply_velocity_mask_acrobot(observation) 27 | else: 28 | raise NotImplementedError 29 | 30 | def _apply_velocity_mask_cartpole(self, observation): 31 | return observation * np.array([1, 0, 1, 0], dtype="float32") 32 | 33 | def _apply_velocity_mask_acrobot(self, observation): 34 | return observation * np.array([1, 1, 1, 1, 0, 0], dtype="float32") 35 | -------------------------------------------------------------------------------- /tasks/rl/plotting.py: -------------------------------------------------------------------------------- 1 | 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | mpl.use('Agg') 5 | import seaborn as sns 6 | import numpy as np 7 | sns.set_style('darkgrid') 8 | import imageio 9 | 10 | 11 | def make_rl_gif(action_logits, action_probs, actions, values, rewards, pre_activations, post_activations, inputs, filename): 12 | 13 | n_steps = len(pre_activations) 14 | pre_activations = pre_activations[:,0,:] 15 | post_activations = post_activations[:,0,:] 16 | 17 | if action_logits.shape[1] == 5: 18 | class_labels = ['W', 'U', 'D', 'L', 'R'] 19 | elif action_logits.shape[1] == 2: 20 | class_labels = ['L', 'R'] 21 | else: 22 | class_labels = [str(i) for i in range(action_logits.shape[1])] 23 | 24 | max_target = len(class_labels) 25 | 26 | figscale = 0.28 27 | frames = [] 28 | n_neurons_to_visualise = 15 29 | 30 | # Create mosaic layout 31 | mosaic = [['img_data', 'img_data', 'img_data', 'img_data', 'action_logits', 'action_logits', 'action_log_probs', 'action_log_probs'] for _ in range(2)] + \ 32 | [['img_data', 'img_data', 'img_data', 'img_data', 'action_logits', 'action_logits', 'action_log_probs', 'action_log_probs'] for _ in range(2)] + \ 33 | [['value', 'value', 'value', 'value', 'value', 'value', 'value', 'value']] + \ 34 | [['reward', 'reward', 'reward', 'reward', 'reward', 'reward', 'reward', 'reward']] + \ 35 | [[f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}', f'trace_{ti}'] for ti in range(n_neurons_to_visualise)] 36 | 37 | 38 | # Main plotting loop 39 | for stepi in range(n_steps): 40 | fig_gif, axes_gif = plt.subplot_mosaic(mosaic=mosaic, figsize=(31*figscale*8/4, 76*figscale)) 41 | 42 | # Plot action logits 43 | these_action_logits = np.array(action_logits)[:, :max_target] 44 | colors = ['black' if i == actions[stepi] else ('b' if e >= 0 else 'r') 45 | for i, e in enumerate(these_action_logits[stepi])] 46 | sort_idxs = np.arange(len(these_action_logits[stepi])) 47 | bars = axes_gif['action_logits'].bar(np.arange(len(these_action_logits[stepi][sort_idxs])), these_action_logits[stepi][sort_idxs], color=np.array(colors)[sort_idxs],width=0.9, alpha=0.5) 48 | axes_gif['action_logits'].axis('off') 49 | for bar, label in zip(bars, class_labels): 50 | x = bar.get_x() + bar.get_width() / 2 51 | axes_gif['action_logits'].annotate(label, xy=(x, 0), xytext=(1, 0), 52 | textcoords="offset 
points", 53 | ha='center', va='bottom', rotation=90) 54 | axes_gif['action_logits'].set_ylim([np.min(these_action_logits), np.max(these_action_logits)]) 55 | 56 | 57 | # Plot action probs 58 | these_action_log_probs = np.array(action_probs)[:, :max_target] 59 | colors = ['black' if i == actions[stepi] else ('b' if e >= 0 else 'r') 60 | for i, e in enumerate(these_action_log_probs[stepi])] 61 | sort_idxs = np.arange(len(these_action_log_probs[stepi])) 62 | bars = axes_gif['action_log_probs'].bar(np.arange(len(these_action_log_probs[stepi][sort_idxs])), these_action_log_probs[stepi][sort_idxs], color=np.array(colors)[sort_idxs],width=0.9, alpha=0.5) 63 | axes_gif['action_log_probs'].axis('off') 64 | for bar, label in zip(bars, class_labels): 65 | x = bar.get_x() + bar.get_width() / 2 66 | axes_gif['action_log_probs'].annotate(label, xy=(x, 0), xytext=(1, 0), 67 | textcoords="offset points", 68 | ha='center', va='bottom', rotation=90) 69 | axes_gif['action_log_probs'].set_ylim([0,1]) 70 | 71 | # Plot value trace 72 | ax_value = axes_gif['value'] 73 | ax_value.plot(np.arange(n_steps), values, 'b-', linewidth=2) 74 | ax_value.axvline(x=stepi, color='k', linewidth=2, alpha=0.3) 75 | ax_value.set_xticklabels([]) 76 | ax_value.set_yticklabels([]) 77 | ax_value.grid(False) 78 | ax_value.set_xlim([0, n_steps-1]) 79 | 80 | # Plot reward trace 81 | ax_reward = axes_gif['reward'] 82 | ax_reward.plot(np.arange(n_steps), rewards, 'g-', linewidth=2) 83 | ax_reward.axvline(x=stepi, color='k', linewidth=2, alpha=0.3) 84 | ax_reward.set_xticklabels([]) 85 | ax_reward.set_yticklabels([]) 86 | ax_reward.grid(False) 87 | ax_reward.set_xlim([0, n_steps-1]) 88 | 89 | # Plot neuron traces 90 | for neuroni in range(n_neurons_to_visualise): 91 | ax = axes_gif[f'trace_{neuroni}'] 92 | 93 | pre_activation = pre_activations[:, neuroni] 94 | post_activation = post_activations[:, neuroni] 95 | 96 | ax_pre = ax.twinx() 97 | 98 | pre_min, pre_max = np.min(pre_activation), np.max(pre_activation) 99 | post_min, post_max = np.min(post_activation), np.max(post_activation) 100 | 101 | ax_pre.plot(np.arange(n_steps), pre_activation, 102 | color='grey', 103 | linestyle='--', 104 | linewidth=1, 105 | alpha=0.4, 106 | label='Pre-activation') 107 | 108 | color = 'blue' if neuroni % 2 else 'red' 109 | ax.plot(np.arange(n_steps), post_activation, 110 | color=color, 111 | linestyle='-', 112 | linewidth=2, 113 | alpha=1.0, 114 | label='Post-activation') 115 | 116 | ax.set_xlim([0, n_steps-1]) 117 | ax_pre.set_xlim([0, n_steps-1]) 118 | ax.set_ylim([post_min, post_max]) 119 | ax_pre.set_ylim([pre_min, pre_max]) 120 | 121 | ax.axvline(x=stepi, color='black', linewidth=1, alpha=0.5) 122 | 123 | ax.set_xticklabels([]) 124 | ax.set_yticklabels([]) 125 | ax.grid(False) 126 | 127 | ax_pre.set_xticklabels([]) 128 | ax_pre.set_yticklabels([]) 129 | ax_pre.grid(False) 130 | 131 | 132 | ax.set_xlim([0, n_steps-1]) 133 | ax.set_xticklabels([]) 134 | ax.grid(False) 135 | 136 | # Show input image 137 | this_image = inputs[stepi] 138 | axes_gif['img_data'].imshow(this_image, cmap='binary', vmin=0, vmax=1) 139 | axes_gif['img_data'].grid(False) 140 | axes_gif['img_data'].set_xticks([]) 141 | axes_gif['img_data'].set_yticks([]) 142 | 143 | # Save frames 144 | fig_gif.tight_layout(pad=0.1) 145 | if stepi == 0: 146 | fig_gif.savefig(filename.split('.gif')[0]+'_frame0.png', dpi=100) 147 | if stepi == 1: 148 | fig_gif.savefig(filename.split('.gif')[0]+'_frame1.png', dpi=100) 149 | if stepi == n_steps-1: 150 | 
fig_gif.savefig(filename.split('.gif')[0]+'_frame-1.png', dpi=100) 151 | 152 | # Convert to frame 153 | canvas = fig_gif.canvas 154 | canvas.draw() 155 | image_numpy = np.frombuffer(canvas.buffer_rgba(), dtype='uint8') 156 | image_numpy = image_numpy.reshape(*reversed(canvas.get_width_height()), 4)[:,:,:3] 157 | frames.append(image_numpy) 158 | plt.close(fig_gif) 159 | 160 | imageio.mimsave(filename, frames, fps=15, loop=100) 161 | 162 | -------------------------------------------------------------------------------- /tasks/rl/scripts/4rooms/train_ctm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=1 4 | MODEL_TYPE="ctm" 5 | ENV_ID="MiniGrid-FourRooms-v0" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 512 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 16 \ 22 | --n_synch_out 32 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 10 \ 28 | --max_environment_steps 300 \ 29 | --total_timesteps 300000000 \ 30 | --num_steps 50 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.0001 \ 38 | --track_every 100 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" -------------------------------------------------------------------------------- /tasks/rl/scripts/4rooms/train_ctm_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=2 4 | MODEL_TYPE="ctm" 5 | ENV_ID="MiniGrid-FourRooms-v0" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 512 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 16 \ 22 | --n_synch_out 32 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 20 \ 28 | --max_environment_steps 300 \ 29 | --total_timesteps 300000000 \ 30 | --num_steps 50 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.0001 \ 38 | --track_every 100 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" -------------------------------------------------------------------------------- /tasks/rl/scripts/4rooms/train_lstm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=1 4 | MODEL_TYPE="lstm" 5 | ENV_ID="MiniGrid-FourRooms-v0" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m 
tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 294 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 16 \ 22 | --n_synch_out 32 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 10 \ 28 | --max_environment_steps 300 \ 29 | --total_timesteps 300000000 \ 30 | --num_steps 50 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.0001 \ 38 | --track_every 100 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 -------------------------------------------------------------------------------- /tasks/rl/scripts/4rooms/train_lstm_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | RUN=1 3 | ITERATIONS=2 4 | MODEL_TYPE="lstm" 5 | ENV_ID="MiniGrid-FourRooms-v0" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 300 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 16 \ 22 | --n_synch_out 32 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 20 \ 28 | --max_environment_steps 300 \ 29 | --total_timesteps 300000000 \ 30 | --num_steps 50 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.0001 \ 38 | --track_every 100 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_ctm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=1 4 | MODEL_TYPE="ctm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 256 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 5 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_ctm_2.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=2 4 | MODEL_TYPE="ctm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 256 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 10 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_ctm_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=5 4 | MODEL_TYPE="ctm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 256 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 25 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_lstm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=1 4 | MODEL_TYPE="lstm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 243 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 5 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | 
--anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_lstm_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=2 4 | MODEL_TYPE="lstm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 249 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 5 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/acrobot/train_lstm_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=5 4 | MODEL_TYPE="lstm" 5 | ENV_ID="Acrobot-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 265 \ 20 | --d_input 64 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 5 \ 28 | --max_environment_steps 500 \ 29 | --total_timesteps 2000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 12 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr="5e-4" \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_ctm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=1 4 | MODEL_TYPE="ctm" 5 | ENV_ID="CartPole-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 128 \ 20 | 
--d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 10 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 4 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_ctm_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=2 4 | MODEL_TYPE="ctm" 5 | ENV_ID="CartPole-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 128 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 20 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 4 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_ctm_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=5 4 | MODEL_TYPE="ctm" 5 | ENV_ID="CartPole-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 128 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 50 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 4 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 \ 42 | --neuron_select_type "first-last" 43 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_lstm_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=1 4 | MODEL_TYPE="lstm" 5 | ENV_ID="CartPole-v1" 6 | 
LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 118 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 10 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 4 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_lstm_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=2 4 | MODEL_TYPE="lstm" 5 | ENV_ID="CartPole-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 126 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 20 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 4 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done -------------------------------------------------------------------------------- /tasks/rl/scripts/cartpole/train_lstm_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | for RUN in 1 2 3; do 3 | ITERATIONS=5 4 | MODEL_TYPE="lstm" 5 | ENV_ID="CartPole-v1" 6 | LOG_DIR="logs/rl/${ENV_ID}/run${RUN}/${MODEL_TYPE}_${ITERATIONS}" 7 | RUN_NAME="run${RUN}_${ENV_ID}_${MODEL_TYPE}_${ITERATIONS}" 8 | TB_LOG_DIR="logs/runs/" 9 | SEED=$RUN 10 | 11 | python -m tasks.rl.train \ 12 | --model_type $MODEL_TYPE \ 13 | --env_id $ENV_ID \ 14 | --log_dir $LOG_DIR \ 15 | --tb_log_dir $TB_LOG_DIR \ 16 | --seed $SEED \ 17 | --iterations $ITERATIONS \ 18 | --run_name $RUN_NAME \ 19 | --d_model 148 \ 20 | --d_input 128 \ 21 | --memory_hidden_dims 4 \ 22 | --n_synch_out 16 \ 23 | --discount_gamma 0.99 \ 24 | --gae_lambda 0.95 \ 25 | --ent_coef 0.1 \ 26 | --vf_coef 0.25 \ 27 | --memory_length 50 \ 28 | --max_environment_steps 200 \ 29 | --total_timesteps 10000000 \ 30 | --num_steps 100 \ 31 | --anneal_lr \ 32 | --num_envs 256 \ 33 | --update_epochs 1 \ 34 | --mask_velocity \ 35 | --continuous_state_trace \ 36 | --dropout 0.0 \ 37 | --lr=0.001 \ 38 | --track_every 1000 \ 39 | --save_every 100 \ 40 | --no-reload \ 41 | --device 0 42 | done 
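The scripts above pair CTM and LSTM agents with roughly matched model sizes (note the smaller `--d_model` values in the LSTM runs) and sweep the number of internal ticks (`ITERATIONS`) together with `--memory_length` for each environment. As a convenience, a minimal sketch of driving the whole sweep from Python is shown below; the `run_sweep` helper is not part of the repository and assumes the scripts are invoked from the repository root with `bash` available.

```python
# Hypothetical helper (not part of the repository): runs the provided RL
# training scripts back to back so the CTM-vs-LSTM comparison for each
# environment can be reproduced with a single command from the repo root.
import subprocess
from pathlib import Path

SCRIPT_ROOT = Path("tasks/rl/scripts")

def run_sweep(env_dir: str, ticks=(1, 2, 5)) -> None:
    """Launch the CTM and LSTM training scripts for one environment."""
    for model in ("ctm", "lstm"):
        for t in ticks:
            script = SCRIPT_ROOT / env_dir / f"train_{model}_{t}.sh"
            if not script.exists():  # e.g. 4rooms only ships 1- and 2-tick scripts
                continue
            print(f"Running {script} ...")
            subprocess.run(["bash", str(script)], check=True)

if __name__ == "__main__":
    run_sweep("cartpole")              # CartPole-v1: 1, 2 and 5 internal ticks
    run_sweep("acrobot")               # Acrobot-v1: 1, 2 and 5 internal ticks
    run_sweep("4rooms", ticks=(1, 2))  # MiniGrid-FourRooms-v0
```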
-------------------------------------------------------------------------------- /tasks/rl/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def combine_tracking_data(tracking_history): 4 | combined_data = {} 5 | keys = tracking_history[0].keys() 6 | for key in keys: 7 | arrays_to_concat = [data[key] for data in tracking_history] 8 | combined_data[key] = np.concatenate(arrays_to_concat, axis=0) 9 | 10 | return combined_data 11 | -------------------------------------------------------------------------------- /tasks/sort/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def decode_predictions(predictions, blank_label=0, return_wait_times=False): 5 | """ 6 | Decodes the predictions using greedy decoding (best path), correctly handling duplicates. 7 | 8 | Args: 9 | predictions: A tensor of shape [B, C, L] representing the logits. 10 | blank_label: The index of the blank label. 11 | 12 | Returns: 13 | A list of tensors, where each tensor is the decoded sequence. 14 | """ 15 | 16 | batch_size, num_classes, prediction_length = predictions.shape 17 | decoded_sequences = [] 18 | wait_times_all = [] 19 | probs = F.softmax(predictions, dim=1) # Probabilities 20 | for b in range(batch_size): 21 | best_path = torch.argmax(probs[b], dim=0) # Best path indices 22 | decoded = [] 23 | wait_times = [] 24 | 25 | prev_char = -1 # Keep track of the previous character 26 | wait_time_now = 0 27 | for t in range(prediction_length): 28 | char_idx = best_path[t].item() # Get index as integer 29 | if char_idx != blank_label and char_idx != prev_char: # Skip blanks and duplicates 30 | decoded.append(char_idx) 31 | prev_char = char_idx # Update previous character 32 | wait_times.append(wait_time_now) 33 | wait_time_now = 0 34 | else: 35 | wait_time_now += 1 36 | decoded_sequences.append(torch.tensor(decoded, device=predictions.device)) 37 | if return_wait_times: wait_times_all.append(torch.tensor(wait_times, device=predictions.device)) 38 | 39 | if return_wait_times: return decoded_sequences, wait_times_all 40 | 41 | return decoded_sequences 42 | 43 | def compute_ctc_accuracy(predictions, targets, blank_label=0): 44 | """ 45 | Computes the accuracy of the predictions given the targets, considering CTC decoding. 46 | 47 | Args: 48 | predictions: A tensor of shape [B, C, L] representing the logits. 49 | targets: A list of tensors, each of shape [T_i], representing a target sequence. 50 | blank_label: The index of the blank label. 51 | 52 | Returns: 53 | The accuracy (a float). 54 | """ 55 | 56 | batch_size, num_classes, prediction_length = predictions.shape 57 | total_correct = 0 58 | 59 | # 1. Get predicted sequences (decoded from logits): 60 | predicted_sequences = decode_predictions(predictions, blank_label) 61 | 62 | # 2. Compare predicted sequences to targets: 63 | for i in range(batch_size): 64 | target = targets[i] 65 | predicted = predicted_sequences[i] 66 | 67 | if torch.equal(predicted, target): # Direct comparison of tensors 68 | total_correct += 1 69 | 70 | accuracy = total_correct / batch_size if batch_size > 0 else 0.0 71 | return accuracy -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | # Tests 2 | 3 | To execute all tests, run the following: 4 | 5 | ```bash 6 | PYTHONPATH=. 
pytest tests/tests.py 7 | ``` 8 | 9 | ## Golden Test 10 | The golden test is designed to check if changes in the code have resulted in a difference in model behaviour. It takes a pre-defined input, runs the model, and compares the output to what is expected. If something changed in the model and the outputs are different, the test will fail. 11 | 12 | The golden test can be run with: 13 | ```bash 14 | PYTHONPATH=. pytest tests/tests.py::test_golden_<TASK> 15 | ``` 16 | 17 | where `<TASK>` is one of `parity`, `qamnist`, `rl`. -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/continuous-thought-machines/aecfb63ac42db7a20903ee27489ff71671669474/tests/__init__.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SakanaAI/continuous-thought-machines/aecfb63ac42db7a20903ee27489ff71671669474/utils/__init__.py -------------------------------------------------------------------------------- /utils/housekeeping.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | 5 | 6 | import os 7 | import zipfile 8 | import glob 9 | 10 | def zip_python_code(output_filename): 11 | """ 12 | Zips all .py files in the current repository and saves the archive to 13 | the specified output filename. 14 | 15 | Args: 16 | output_filename: The name of the output zip file, 17 | e.g. "python_code_backup.zip". 18 | """ 19 | 20 | with zipfile.ZipFile(output_filename, 'w') as zipf: 21 | files = glob.glob('models/**/*.py', recursive=True) + glob.glob('utils/**/*.py', recursive=True) + glob.glob('tasks/**/*.py', recursive=True) + glob.glob('*.py', recursive=True) 22 | for file in files: 23 | root = '/'.join(file.split('/')[:-1]) 24 | nm = file.split('/')[-1] 25 | zipf.write(os.path.join(root, nm)) 26 | 27 | def set_seed(seed=42, deterministic=True): 28 | """ 29 | ... and the answer is ... 30 | """ 31 | random.seed(seed) 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | torch.cuda.manual_seed_all(seed) 35 | torch.backends.cudnn.deterministic = deterministic 36 | torch.backends.cudnn.benchmark = False 37 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | def compute_ctc_loss(predictions, targets, blank_label=0): 7 | """ 8 | Computes the Connectionist Temporal Classification (CTC) loss. 9 | 10 | Args: 11 | predictions: A tensor of shape [B, C, L] representing the logits of the 12 | predicted sequences. B is the batch size, C is the number 13 | of classes (including the blank label), and L is the sequence 14 | length of the predictions. 15 | targets: A tensor of shape [B, T] representing the target sequences. 16 | B is the batch size and T is the target sequence length. 17 | Note that T can vary within the batch. 18 | blank_label: The index of the blank label. Defaults to 0. 19 | 20 | Returns: 21 | The CTC loss (a scalar tensor). 22 | """ 23 | 24 | batch_size, num_classes, prediction_length = predictions.shape 25 | _, target_length = targets.shape 26 | 27 | # 1.
Log softmax on predictions: Crucially, CTC loss requires log probabilities. 28 | log_probs = F.log_softmax(predictions, dim=1) # Shape: [B, C, L] 29 | 30 | # 2. Prepare inputs for torch.nn.CTCLoss: 31 | # a. Convert log_probs to shape (L, B, C): CTCLoss expects time first. 32 | log_probs = log_probs.permute(2, 0, 1) # Shape: [L, B, C] 33 | 34 | # b. Get lengths of the predicted sequences (all L in this case). 35 | input_lengths = torch.full(size=(batch_size,), fill_value=prediction_length, dtype=torch.long) 36 | 37 | # c. Get lengths of the target sequences. 38 | target_lengths = torch.tensor([t.shape[0] for t in targets], dtype=torch.long) # Handle variable target lengths 39 | 40 | # 3. Create the CTCLoss criterion. `blank=blank_label` is essential! 41 | ctc_loss = torch.nn.CTCLoss(blank=blank_label, reduction='mean') # 'mean' for averaging over the batch 42 | 43 | # 4. Calculate the loss. `targets` needs to be a concatenated tensor. 44 | # We handle padding by only passing the valid lengths to CTCLoss. 45 | concatenated_targets = torch.cat(list(targets)) # Concatenate targets 46 | 47 | loss = ctc_loss(log_probs, concatenated_targets, input_lengths, target_lengths) 48 | 49 | return loss 50 | 51 | def sort_loss(predictions, targets): 52 | """ 53 | The sort task was used partly to show that CTC loss can work. 54 | """ 55 | loss = compute_ctc_loss(predictions, targets, blank_label=predictions.shape[1]-1) 56 | return loss 57 | 58 | def image_classification_loss(predictions, certainties, targets, use_most_certain=True): 59 | """ 60 | Computes the image classification loss. 61 | 62 | Predictions are of shape: (B, class, internal_ticks), 63 | Certainties are of shape: (B, 2, internal_ticks), 64 | where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy] 65 | Targets are of shape: [B] 66 | 67 | use_most_certain will select either the most certain point or the final point. 68 | """ 69 | targets_expanded = torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1) 70 | # Losses are of shape [B, internal_ticks] 71 | losses = nn.CrossEntropyLoss(reduction='none')(predictions, targets_expanded) 72 | 73 | loss_index_1 = losses.argmin(dim=1) 74 | loss_index_2 = certainties[:,1].argmax(-1) 75 | if not use_most_certain: # Revert to final loss if set 76 | loss_index_2[:] = -1 77 | 78 | batch_indexer = torch.arange(predictions.size(0), device=predictions.device) 79 | loss_minimum_ce = losses[batch_indexer, loss_index_1].mean() 80 | loss_selected = losses[batch_indexer, loss_index_2].mean() 81 | 82 | loss = (loss_minimum_ce + loss_selected)/2 83 | return loss, loss_index_2 84 | 85 | def maze_loss(predictions, certainties, targets, cirriculum_lookahead=5, use_most_certain=True): 86 | """ 87 | Computes the maze loss with an auto-extending curriculum. 88 | 89 | Predictions are of shape: (B, route_length, class, internal_ticks), 90 | where classes are in [0,1,2,3,4] for [Up, Down, Left, Right, Wait] 91 | Certainties are of shape: (B, 2, internal_ticks), 92 | where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy] 93 | Targets are of shape: [B, route_length] 94 | 95 | cirriculum_lookahead: how far to look ahead in the auto-curriculum 96 | 97 | use_most_certain will select either the most certain point or the final point. For baselines, 98 | the final point proved the only usable option.
99 | 100 | """ 101 | # Predictions reshaped to: [B*route_length, 5, internal_ticks] 102 | predictions_reshaped = predictions.flatten(0,1) 103 | # Targets reshaped to: [B*route_length, internal_ticks] 104 | targets_reshaped = torch.repeat_interleave(targets.unsqueeze(-1), 105 | predictions.size(-1), -1).flatten(0,1).long() 106 | 107 | # Losses are of shape [B, route_length, internal_ticks] 108 | losses = nn.CrossEntropyLoss(reduction='none')(predictions_reshaped, targets_reshaped) 109 | losses = losses.reshape(predictions[:,:,0].shape) 110 | 111 | # Below is the code for the auto-curriculum 112 | # Find where correct, and make sure to always push +5 beyond that 113 | iscorrects = (predictions.argmax(2) == targets.unsqueeze(-1)).cumsum(1) 114 | correct_mask = (iscorrects == torch.arange(1, iscorrects.size(1)+1, device=iscorrects.device).reshape(1, -1, 1)) 115 | correct_mask[:,0,:] = 1 116 | upto_where = correct_mask.cumsum(1).argmax(1).max(-1)[0]+cirriculum_lookahead 117 | loss_mask = torch.zeros_like(losses) 118 | for bi in range(predictions.size(0)): 119 | loss_mask[bi, :upto_where[bi]] = 1 120 | 121 | # Reduce losses along route dimension 122 | # Will now be of shape [B, internal_ticks] 123 | losses = (losses * loss_mask).sum(1)/(loss_mask.sum(1)) 124 | 125 | loss_index_1 = losses.argmin(dim=1) 126 | loss_index_2 = certainties[:,1].argmax(-1) 127 | if not use_most_certain: 128 | loss_index_2[:] = -1 129 | 130 | batch_indexer = torch.arange(predictions.size(0), device=predictions.device) 131 | loss_minimum_ce = losses[batch_indexer, loss_index_1] 132 | loss_selected = losses[batch_indexer, loss_index_2] 133 | 134 | loss = ((loss_minimum_ce + loss_selected)/2).mean() 135 | return loss, loss_index_2, upto_where.detach().cpu().numpy() 136 | 137 | def parity_loss(predictions, certainties, targets, use_most_certain=True): 138 | """ 139 | Computes the parity loss. 140 | 141 | Predictions are of shape: (B, parity_sequence_length, class, internal_ticks), 142 | where the class dimension holds the parity classes 143 | Certainties are of shape: (B, 2, internal_ticks), 144 | where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy] 145 | Targets are of shape: [B, parity_sequence_length] 146 | 147 | use_most_certain will select either the most certain point or the final point. For baselines, 148 | the final point proved the only usable option. 149 | """ 150 | 151 | # Losses are of shape [B, parity_sequence_length, internal_ticks] 152 | losses = nn.CrossEntropyLoss(reduction='none')(predictions.flatten(0,1), 153 | torch.repeat_interleave(targets.unsqueeze(-1), 154 | predictions.size(-1), -1).flatten(0,1).long()).reshape(predictions[:,:,0].shape) 155 | 156 | # Average the loss over the parity sequence dimension 157 | losses = losses.mean(1) 158 | 159 | loss_index_1 = losses.argmin(dim=1) 160 | loss_index_2 = certainties[:,1].argmax(-1) 161 | if not use_most_certain: 162 | loss_index_2[:] = -1 163 | 164 | batch_indexer = torch.arange(predictions.size(0), device=predictions.device) 165 | loss_minimum_ce = losses[batch_indexer, loss_index_1].mean() 166 | loss_selected = losses[batch_indexer, loss_index_2].mean() 167 | 168 | loss = (loss_minimum_ce + loss_selected)/2 169 | return loss, loss_index_2 170 | 171 | 172 | def qamnist_loss(predictions, certainties, targets, use_most_certain=True): 173 | """ 174 | Computes the qamnist loss over the last num_answer_steps steps.
175 | 176 | Predictions are of shape: (B, class, internal_ticks), 177 | Certainties are of shape: (B, 2, internal_ticks), 178 | where the inside dimension (2) is [normalised_entropy, 1-normalised_entropy] 179 | Targets are of shape: [B] 180 | num_answer_steps: number of steps to consider for the loss 181 | 182 | use_most_certain will select either the most certain point or the final point. 183 | """ 184 | 185 | losses = nn.CrossEntropyLoss(reduction='none')(predictions, 186 | torch.repeat_interleave(targets.unsqueeze(-1), predictions.size(-1), -1)) 187 | 188 | loss_index_1 = losses.argmin(dim=1) 189 | loss_index_2 = certainties[:,1].argmax(-1) 190 | if not use_most_certain: 191 | loss_index_2[:] = -1 192 | 193 | batch_indexer = torch.arange(predictions.size(0), device=predictions.device) 194 | loss_minimum_ce = losses[batch_indexer, loss_index_1].mean() 195 | loss_selected = losses[batch_indexer, loss_index_2].mean() 196 | 197 | loss = (loss_minimum_ce + loss_selected)/2 198 | return loss, loss_index_2 -------------------------------------------------------------------------------- /utils/samplers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from torch.utils.data import Sampler 4 | import math 5 | import itertools 6 | import numpy as np 7 | 8 | class FastRandomDistributedSampler(Sampler[int]): 9 | r""" 10 | A distributed sampler that continuously yields random indices with replacement, 11 | avoiding frequent iterator recreation overhead for DataLoader. 12 | 13 | Instead of stopping after one pass through the dataset, this sampler's 14 | iterator yields a specified number of indices (`epoch_steps`) before 15 | stopping. This significantly reduces the frequency of DataLoader worker 16 | restarts when the underlying dataset is small. 17 | 18 | Args: 19 | dataset: Dataset used for sampling. 20 | num_replicas (int, optional): Number of processes participating in 21 | distributed training. Defaults to current world size. 22 | rank (int, optional): Rank of the current process. Defaults to current rank. 23 | seed (int): Base seed for the random number generator. Each epoch/rank 24 | gets a different derived seed. Defaults to 0. 25 | epoch_steps (int): The number of indices this sampler should yield per 26 | __iter__ call (per replica). Set this to a large number 27 | to reduce iterator recreation frequency. If None, it defaults 28 | to ceil(len(dataset) / num_replicas). 
29 | """ 30 | def __init__(self, dataset, num_replicas=None, rank=None, seed=0, epoch_steps=None): 31 | if num_replicas is None: 32 | if not dist.is_available() or not dist.is_initialized(): 33 | raise RuntimeError("Requires distributed package to be available and initialized") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available() or not dist.is_initialized(): 37 | raise RuntimeError("Requires distributed package to be available and initialized") 38 | rank = dist.get_rank() 39 | if rank >= num_replicas or rank < 0: 40 | raise ValueError( 41 | f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") 42 | 43 | self.dataset = dataset 44 | self.num_replicas = num_replicas 45 | self.rank = rank 46 | self.seed = seed 47 | self.epoch = 0 48 | self.dataset_len = len(self.dataset) 49 | 50 | # Determine the number of steps/indices per iterator cycle for this rank 51 | if epoch_steps is None: 52 | # Default behavior: roughly one pass over the data 53 | self.num_samples_per_epoch = math.ceil(self.dataset_len / self.num_replicas) 54 | else: 55 | # User-defined length for the iterator cycle 56 | self.num_samples_per_epoch = epoch_steps 57 | 58 | if not isinstance(self.num_samples_per_epoch, int) or self.num_samples_per_epoch <= 0: 59 | raise ValueError("epoch_steps must be a positive integer") 60 | 61 | def _infinite_indices(self): 62 | """A generator that yields random indices indefinitely.""" 63 | g = torch.Generator() 64 | # Ensure distinct seeds based on rank, epoch, and base seed 65 | current_seed = self.seed + self.epoch * self.num_replicas + self.rank 66 | g.manual_seed(current_seed) 67 | while True: 68 | yield torch.randint(low=0, high=self.dataset_len, size=(1,), generator=g).item() 69 | 70 | def __iter__(self): 71 | """ 72 | Returns an iterator that yields 'num_samples_per_epoch' indices. 73 | It uses itertools.islice to take a finite slice from the 74 | infinite generator, avoiding expensive list creation. 75 | """ 76 | # Create the infinite generator and slice it 77 | # The generator state is preserved across calls to next() by the DataLoader 78 | # The expensive DataLoader setup only happens when this sliced iterator is exhausted 79 | return itertools.islice(self._infinite_indices(), self.num_samples_per_epoch) 80 | 81 | def __len__(self): 82 | """The number of samples produced by the iterator per __iter__ call.""" 83 | return self.num_samples_per_epoch 84 | 85 | def set_epoch(self, epoch: int) -> None: 86 | """ 87 | Sets the epoch for this sampler. This is used to vary the random seed sequence 88 | each time __iter__ is called. 
89 | """ 90 | self.epoch = epoch 91 | 92 | class QAMNISTSampler(Sampler): 93 | def __init__(self, dataset, batch_size): 94 | self.dataset = dataset 95 | self.batch_size = batch_size 96 | self.num_samples = len(dataset) 97 | 98 | def __iter__(self): 99 | indices = torch.randperm(self.num_samples).tolist() 100 | for i in range(0, self.num_samples, self.batch_size): 101 | batch_indices = indices[i:i + self.batch_size] 102 | 103 | if self.dataset.num_images_range[0] == self.dataset.num_images_range[1]: 104 | batch_num_digits = self.dataset.num_images_range[0] 105 | else: 106 | batch_num_digits = np.random.randint(self.dataset.num_images_range[0], self.dataset.num_images_range[1]) 107 | 108 | if self.dataset.num_operations_range[0] == self.dataset.num_operations_range[1]: 109 | batch_num_operations = self.dataset.num_operations_range[0] 110 | else: 111 | batch_num_operations = np.random.randint(self.dataset.num_operations_range[0], self.dataset.num_operations_range[1]) 112 | 113 | self.dataset.set_num_digits(batch_num_digits) 114 | self.dataset.set_num_operations(batch_num_operations) 115 | 116 | yield batch_indices 117 | 118 | def __len__(self): 119 | return self.num_samples // self.batch_size -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from torch.optim.lr_scheduler import LambdaLR, SequentialLR, MultiStepLR 5 | 6 | class warmup(): 7 | def __init__(self, warmup_steps): 8 | self.warmup_steps = warmup_steps 9 | 10 | def step(self, current_step): 11 | if current_step < self.warmup_steps: # current_step / warmup_steps * base_lr 12 | return float(current_step / self.warmup_steps) 13 | else: # (num_training_steps - current_step) / (num_training_steps - warmup_steps) * base_lr 14 | return 1.0 15 | 16 | class WarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): 17 | def __init__( 18 | self, 19 | optimizer: torch.optim.Optimizer, 20 | warmup_epochs: int, 21 | max_epochs: int, 22 | warmup_start_lr: float = 0.00001, 23 | eta_min: float = 0.00001, 24 | last_epoch: int = -1, 25 | ): 26 | """ 27 | Args: 28 | optimizer (torch.optim.Optimizer): 29 | The optimizer instance. 30 | warmup_epochs (int): 31 | Number of epochs over which linear warmup is applied. 32 | max_epochs (int): 33 | Total number of training epochs, used as the end point of the cosine curve. 34 | warmup_start_lr (float): 35 | Learning rate at epoch 0 of the linear warmup. 36 | eta_min (float): 37 | Lower bound of the cosine curve. 38 | last_epoch (int): 39 | Phase offset of the cosine curve. 40 | Schedules the learning rate along a cosine curve until max_epochs is reached. 41 | From epoch 0 to warmup_epochs, the learning rate follows a linear warmup. 42 | https://pytorch-lightning-bolts.readthedocs.io/en/stable/schedulers/warmup_cosine_annealing.html 43 | """ 44 | self.warmup_epochs = warmup_epochs 45 | self.max_epochs = max_epochs 46 | self.warmup_start_lr = warmup_start_lr 47 | self.eta_min = eta_min 48 | super().__init__(optimizer, last_epoch) 49 | return None 50 | 51 | def get_lr(self): 52 | if self.last_epoch == 0: 53 | return [self.warmup_start_lr] * len(self.base_lrs) 54 | if self.last_epoch < self.warmup_epochs: 55 | return [ 56 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | if self.last_epoch == self.warmup_epochs: 60 | return self.base_lrs 61 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 62 | return [ 63 | group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs -
self.warmup_epochs))) / 2 64 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 65 | ] 66 | 67 | return [ 68 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 69 | / (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs))) 70 | * (group["lr"] - self.eta_min) 71 | + self.eta_min 72 | for group in self.optimizer.param_groups 73 | ] 74 | 75 | class WarmupMultiStepLR(object): 76 | def __init__(self, optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1, verbose=False): 77 | self.warmup_steps = warmup_steps 78 | self.milestones = milestones 79 | self.gamma = gamma 80 | 81 | # Define the warmup scheduler 82 | lambda_func = lambda step: step / warmup_steps if step < warmup_steps else 1.0 83 | warmup_scheduler = LambdaLR(optimizer, lr_lambda=lambda_func, last_epoch=last_epoch) 84 | 85 | # Define the multi-step scheduler 86 | multistep_scheduler = MultiStepLR(optimizer, milestones=[m - warmup_steps for m in milestones], gamma=gamma, last_epoch=last_epoch) 87 | 88 | # Chain the schedulers 89 | self.scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, multistep_scheduler], milestones=[warmup_steps]) 90 | 91 | def step(self, epoch=None): 92 | self.scheduler.step() 93 | 94 | def state_dict(self): 95 | return self.scheduler.state_dict() 96 | 97 | def load_state_dict(self, state_dict): 98 | self.scheduler.load_state_dict(state_dict) --------------------------------------------------------------------------------
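For reference, a minimal usage sketch of the schedulers defined above. The model, optimizer and step counts are placeholder values, and the import path assumes the file lives at `utils/schedulers.py` as in this repository.

```python
# Minimal usage sketch for utils/schedulers.py (illustrative values only).
import torch
from utils.schedulers import WarmupMultiStepLR, WarmupCosineAnnealingLR

model = torch.nn.Linear(10, 2)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Step-based: linear warmup for 500 steps, then decay the LR by 10x at steps 5000 and 8000.
scheduler = WarmupMultiStepLR(optimizer, warmup_steps=500, milestones=[5000, 8000], gamma=0.1)

for step in range(10000):
    # ... forward pass, loss.backward() and optimizer.step() would go here ...
    scheduler.step()  # advance the LR schedule once per training step

# Epoch-based alternative (would replace the scheduler above): linear warmup for
# 5 epochs, then cosine decay until epoch 100.
cosine = WarmupCosineAnnealingLR(optimizer, warmup_epochs=5, max_epochs=100)
```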