├── tiny_recursive_model
│   ├── __init__.py
│   ├── mlp_mixer_1d.py
│   ├── trainer.py
│   └── trm.py
├── .github
│   └── workflows
│       ├── test.yml
│       └── python-publish.yml
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   └── test_trm.py
└── .gitignore

--------------------------------------------------------------------------------
/tiny_recursive_model/__init__.py:
--------------------------------------------------------------------------------
from tiny_recursive_model.trm import (
    TinyRecursiveModel,
)

from tiny_recursive_model.trainer import (
    Trainer
)

from tiny_recursive_model.mlp_mixer_1d import (
    MLPMixer1D
)
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
name: Pytest
on: [push, pull_request]

jobs:
  build:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v4
    - name: Set up Python 3.10
      uses: actions/setup-python@v5
      with:
        python-version: "3.10"
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install -e .[test]
    - name: Test with pytest
      run: |
        python -m pytest tests/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2025 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "tiny-recursive-model"
version = "0.0.12"
description = "Tiny Recursive Model"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = [
    'artificial intelligence',
    'deep learning',
    'reasoning',
]

classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "accelerate",
    "adam-atan2-pytorch>=0.2.2",
    "einops>=0.8.1",
    "ema-pytorch",
    "torch>=2.4",
    "x-transformers>=2.9.0",
]

[project.urls]
Homepage = "https://pypi.org/project/tiny-recursive-model/"
Repository = "https://github.com/lucidrains/tiny-recursive-model"

[project.optional-dependencies]
examples = []
test = [
    "pytest"
]

[tool.pytest.ini_options]
pythonpath = [
    "."
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["tiny_recursive_model"]
--------------------------------------------------------------------------------
/tiny_recursive_model/mlp_mixer_1d.py:
--------------------------------------------------------------------------------
# a 1d mlp-mixer, the default tiny network used by TRM

from functools import partial

from torch import nn
from torch.nn import Module, LayerNorm
from einops.layers.torch import Rearrange, Reduce

pair = lambda x: x if isinstance(x, tuple) else (x, x)

class PreNormResidual(Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim, bias = False)

    def forward(self, x):
        return self.fn(self.norm(x)) + x

def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
    return nn.Sequential(
        dense(dim, dim_hidden),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(dim_hidden, dim),
        nn.Dropout(dropout)
    )

def MLPMixer1D(*, dim, depth, seq_len, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
    # token mixing uses a pointwise conv across the sequence, channel mixing a plain linear
    chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear

    return nn.Sequential(
        *[nn.Sequential(
            PreNormResidual(dim, FeedForward(seq_len, int(expansion_factor * dim), dropout, chan_first)),
            PreNormResidual(dim, FeedForward(dim, int(expansion_factor_token * dim), dropout, chan_last))
        ) for _ in range(depth)],
        LayerNorm(dim, bias = False)
    )

# quick test

if __name__ == '__main__':

    import torch

    tokens = torch.randn(1, 1024, 512)
    mixer = MLPMixer1D(dim = 512, depth = 4, seq_len = 1024)

    assert mixer(tokens).shape == tokens.shape
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
[figure: trm-fig1]

## Tiny Recursive Model (TRM)

Implementation of the [Tiny Recursive Model](https://arxiv.org/abs/2510.04871) (TRM) by [Alexia Jolicoeur-Martineau](https://ajolicoeur.wordpress.com/about/), an improvement over [HRM](https://github.com/lucidrains/hrm) from Sapient AI

The official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursiveModels)

[Interview with Alexia!](https://www.youtube.com/watch?v=P9zzUM0PrBM)

[Paper review by bycloud](https://www.youtube.com/watch?v=ZgwHaI2C-9s)

[figure: trm-fig3]

## Install

```bash
$ pip install tiny-recursive-model
```

## Usage

```python
import torch
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    network = MLPMixer1D(
        dim = 16,
        depth = 2,
        seq_len = 256
    ),
)

# mock dataset

from torch.utils.data import Dataset

class MockDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        inp = torch.randint(0, 256, (256,))
        out = torch.randint(0, 256, (256,))
        return inp, out

mock_dataset = MockDataset()

# trainer

trainer = Trainer(
    trm,
    mock_dataset,
    epochs = 1,
    batch_size = 16,
    cpu = True
)

trainer()

# inference

pred_answer, exit_indices = trm.predict(
    torch.randint(0, 256, (1, 256)),
    max_deep_refinement_steps = 12,
    halt_prob_thres = 0.1
)

# save to a collection of specialized networks for tool calls

torch.save(trm.state_dict(), 'saved-trm.pt')
```

## Citations

```bibtex
@misc{jolicoeurmartineau2025morerecursivereasoningtiny,
    title = {Less is More: Recursive Reasoning with Tiny Networks},
    author = {Alexia Jolicoeur-Martineau},
    year = {2025},
    eprint = {2510.04871},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG},
    url = {https://arxiv.org/abs/2510.04871},
}
```
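To reload the saved network later, a minimal sketch: it assumes a freshly constructed model with the same hyperparameters as above, and the `saved-trm.pt` file written by `torch.save`:

```python
import torch
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D

# must match the hyperparameters the checkpoint was trained with

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    network = MLPMixer1D(
        dim = 16,
        depth = 2,
        seq_len = 256
    ),
)

trm.load_state_dict(torch.load('saved-trm.pt'))

pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))
```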
--------------------------------------------------------------------------------
/tests/test_trm.py:
--------------------------------------------------------------------------------
import pytest
param = pytest.mark.parametrize

import torch

from tiny_recursive_model.trm import TinyRecursiveModel
from tiny_recursive_model.trainer import Trainer

@param('use_self_attn', (False, True))
@param('registers', (0, 4))
def test_trm(
    use_self_attn,
    registers
):
    from torch.optim import AdamW

    if use_self_attn:
        from x_transformers import Encoder
        network = Encoder(dim = 512, depth = 2)
    else:
        from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D
        network = MLPMixer1D(dim = 512, depth = 2, seq_len = 1024 + registers)

    trm = TinyRecursiveModel(
        dim = 512,
        num_tokens = 256,
        num_register_tokens = registers,
        network = network
    )

    optim = AdamW(trm.parameters(), lr = 1e-4)

    seq = torch.randint(0, 256, (2, 1024))
    answer = torch.randint(0, 256, (2, 1024))

    outputs, latents = trm.get_initial()

    for _ in range(3):
        loss, losses, outputs, latents, pred, halt = trm(seq, outputs, latents, labels = answer)

        loss.backward()
        optim.step()
        optim.zero_grad()

    pred_answer, exit_indices = trm.predict(seq)

def test_trainer():
    from torch.utils.data import Dataset
    from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D

    trm = TinyRecursiveModel(
        dim = 16,
        num_tokens = 256,
        network = MLPMixer1D(
            dim = 16,
            depth = 2,
            seq_len = 256
        ),
    )

    class MockDataset(Dataset):
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            inp = torch.randint(0, 256, (256,))
            out = torch.randint(0, 256, (256,))
            return inp, out

    trainer = Trainer(
        trm,
        MockDataset(),
        epochs = 1,
        batch_size = 16,
        cpu = True
    )

    trainer()

    pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))

def test_gpt():
    from torch.utils.data import Dataset
    from x_transformers import Decoder

    trm = TinyRecursiveModel(
        dim = 16,
        num_tokens = 256,
        network = Decoder(
            dim = 16,
            depth = 2
        ),
    )

    class MockDataset(Dataset):
        def __len__(self):
            return 16

        def __getitem__(self, idx):
            seq = torch.randint(0, 256, (257,))
            return seq[:-1], seq[1:]

    trainer = Trainer(
        trm,
        MockDataset(),
        epochs = 1,
        batch_size = 16,
        cpu = True
    )

    trainer()

    pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))
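One detail the parametrized test above encodes: register tokens are packed in front of the sequence before the network runs, so a fixed-length network such as `MLPMixer1D` must be built with `seq_len` equal to the sequence length plus `num_register_tokens`, while an attention-based `Encoder` needs no adjustment. A minimal sketch of that bookkeeping (the concrete sizes are illustrative):

```python
import torch
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D

registers = 4
seq_len = 256

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    num_register_tokens = registers,
    network = MLPMixer1D(
        dim = 16,
        depth = 2,
        seq_len = seq_len + registers   # the mixer sees registers + input tokens
    ),
)

pred, exits = trm.predict(torch.randint(0, 256, (1, seq_len)))
assert pred.shape == (1, seq_len)       # registers are unpacked away before prediction
```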
--------------------------------------------------------------------------------
/tiny_recursive_model/trainer.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import torch
from torch.nn import Module
from torch.optim import AdamW, Optimizer
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

from einops import pack, unpack

from accelerate import Accelerator

# ema - apparently greatly helped with results

from ema_pytorch import EMA

from tiny_recursive_model.trm import TinyRecursiveModel

from adam_atan2_pytorch import MuonAdamAtan2

from x_transformers import Encoder, Decoder

# helpers

def exists(v):
    return v is not None

def range_from_one(n):
    return range(1, n + 1)

def is_empty(t):
    return t.numel() == 0

# trainer

class Trainer(Module):
    def __init__(
        self,
        model: TinyRecursiveModel | Module,
        dataset: Dataset,
        optim_klass = AdamW,
        optim: Optimizer | None = None,
        learning_rate = 1e-4,
        muon_learning_rate = 1e-3,
        weight_decay = 1.,
        batch_size = 16,
        epochs = 2,
        halt_prob_thres = 0.5,
        max_recurrent_steps = 12,
        warmup_steps = 2000,
        ema_decay_rate = 0.999,
        switch_ema_every = 10000, # switch ema https://arxiv.org/abs/2402.09240
        accelerate_kwargs: dict = dict(),
        cpu = False
    ):
        super().__init__()

        self.accelerator = Accelerator(**accelerate_kwargs, cpu = cpu)

        self.batch_size = batch_size
        self.epochs = epochs

        # data

        self.dataset = dataset
        self.dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)

        # optim - muon for the attention-based networks, otherwise plain AdamW

        if not exists(optim):

            if isinstance(model.network, (Encoder, Decoder)):
                optim = MuonAdamAtan2(
                    model.network.muon_parameters(),
                    model.parameters(),
                    lr = learning_rate / (batch_size * max_recurrent_steps),
                    muon_lr = muon_learning_rate / (batch_size * max_recurrent_steps),
                    weight_decay = weight_decay
                )
            else:
                optim = optim_klass(
                    model.parameters(),
                    lr = learning_rate / (batch_size * max_recurrent_steps),
                    weight_decay = weight_decay
                )

        self.optim = optim

        # scheduler

        self.scheduler = LambdaLR(self.optim, lambda step: min((step + 1) / warmup_steps, 1.0))

        # model

        self.model = model

        # ema model

        self.ema_model = None

        if self.accelerator.is_main_process:
            self.ema_model = EMA(
                model,
                beta = ema_decay_rate,
                update_model_with_ema_every = switch_ema_every,
                forward_method_names = ('predict',)
            )

        # recurrent and act related variables

        self.halt_prob_thres = halt_prob_thres

        self.max_recurrent_steps = max_recurrent_steps

        # prepare maybe distributed

        self.model, self.optim, self.dataloader, self.scheduler = self.accelerator.prepare(self.model, self.optim, self.dataloader, self.scheduler)

    def forward(self):

        for epoch in range_from_one(self.epochs):

            for dataset_input, dataset_output in self.dataloader:

                outputs, latents = self.model.get_initial()

                # deep supervision steps, carrying the detached outputs and latents forward

                for recurrent_step in range_from_one(self.max_recurrent_steps):

                    loss, (main_loss, halt_loss), outputs, latents, pred, halt = self.model(dataset_input, outputs, latents, labels = dataset_output)

                    self.accelerator.print(f'[{epoch} ({recurrent_step} / {self.max_recurrent_steps})] loss: {main_loss.mean().item():.3f} | halt loss: {halt_loss.mean().item():.3f}')

                    self.accelerator.backward(loss)

                    self.optim.step()
                    self.optim.zero_grad()

                    self.scheduler.step()

                    if self.accelerator.is_main_process:
                        self.ema_model.update()

                    # handle halting - drop samples whose halt probability crossed the threshold

                    halt_mask = halt >= self.halt_prob_thres

                    if not halt_mask.any():
                        continue

                    outputs = outputs[~halt_mask]
                    latents = latents[~halt_mask]
                    dataset_input = dataset_input[~halt_mask]
                    dataset_output = dataset_output[~halt_mask]

                    if is_empty(outputs):
                        break

        self.accelerator.print('complete')

        if self.accelerator.is_main_process:
            self.ema_model.copy_params_from_ema_to_model()
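Two optional hooks in the trainer are worth a sketch: a ready-made optimizer can be passed through `optim` to bypass the AdamW / Muon selection above, and predictions can be read from the EMA copy. Note `ema_model.predict` leans on the `forward_method_names` passthrough the trainer configures, so treat this as a sketch under that assumption; the dataset is just the mock one from the readme:

```python
import torch
from torch.optim import AdamW
from torch.utils.data import Dataset
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer

class MockDataset(Dataset):
    def __len__(self):
        return 16

    def __getitem__(self, idx):
        return torch.randint(0, 256, (256,)), torch.randint(0, 256, (256,))

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    network = MLPMixer1D(dim = 16, depth = 2, seq_len = 256)
)

trainer = Trainer(
    trm,
    MockDataset(),
    optim = AdamW(trm.parameters(), lr = 3e-4), # overrides the internally constructed optimizer
    epochs = 1,
    batch_size = 16,
    cpu = True
)

trainer()

# predictions can also be read off the smoothed EMA copy the trainer maintains

pred_answer, exit_indices = trainer.ema_model.predict(torch.randint(0, 256, (1, 256)))
```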
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[codz]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py.cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
#poetry.toml

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
#pdm.lock
#pdm.toml
.pdm-python
.pdm-build/

# pixi
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
#pixi.lock
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
# in the .venv directory. It is recommended not to include this directory in version control.
.pixi

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.envrc
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Abstra
# Abstra is an AI-powered process automation framework.
# Ignore directories containing user credentials, local state, and settings.
# Learn more at https://abstra.io/docs
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc

# Cursor
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore

# Marimo
marimo/_static/
marimo/_lsp/
__marimo__/
--------------------------------------------------------------------------------
/tiny_recursive_model/trm.py:
--------------------------------------------------------------------------------
from __future__ import annotations
from contextlib import nullcontext

import torch
from torch import nn, cat, arange, tensor
import torch.nn.functional as F
from torch.nn import Module, ModuleList

from einops import rearrange, repeat, reduce, pack, unpack
from einops.layers.torch import Reduce, Rearrange

# network related

from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D

# helpers

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def is_empty(t):
    return t.numel() == 0

def range_from_one(n):
    return range(1, n + 1)

# classes

class TinyRecursiveModel(Module):
    def __init__(
        self,
        *,
        dim,
        num_tokens,
        network: Module,
        num_refinement_blocks = 3,  # T in paper
        num_latent_refinements = 6, # n in paper - 1 output refinement per n latent refinements
        halt_loss_weight = 1.,
        num_register_tokens = 0
    ):
        super().__init__()
        assert num_refinement_blocks > 1

        self.input_embed = nn.Embedding(num_tokens, dim)
        self.output_init_embed = nn.Parameter(torch.randn(dim) * 1e-2)
        self.latent_init_embed = nn.Parameter(torch.randn(dim) * 1e-2)

        self.network = network

        self.num_latent_refinements = num_latent_refinements
        self.num_refinement_blocks = num_refinement_blocks

        # register tokens for the self attend version

        self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)

        # prediction heads

        self.to_pred = nn.Linear(dim, num_tokens, bias = False)

        self.to_halt_pred = nn.Sequential(
            Reduce('b n d -> b d', 'mean'),
            nn.Linear(dim, 1, bias = False),
            nn.Sigmoid(),
            Rearrange('... 1 -> ...')
        )

        self.halt_loss_weight = halt_loss_weight
    @property
    def device(self):
        return next(self.parameters()).device

    def get_initial(self):
        outputs = self.output_init_embed
        latents = self.latent_init_embed

        return outputs, latents

    def embed_inputs_with_registers(
        self,
        seq
    ):
        batch = seq.shape[0]

        inputs = self.input_embed(seq)

        # maybe registers

        registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)

        inputs, packed_shape = pack([registers, inputs], 'b * d')

        return inputs, packed_shape

    def refine_latent_then_output_once(
        self,
        inputs,  # (b n d)
        outputs, # (b n d)
        latents, # (b n d)
    ):

        # this work uses a single shared network - it learns to refine the latents when the input is passed in, and to refine the output otherwise

        for _ in range(self.num_latent_refinements):

            latents = self.network(outputs + latents + inputs)

        outputs = self.network(outputs + latents)

        return outputs, latents

    def deep_refinement(
        self,
        inputs,  # (b n d)
        outputs, # (b n d)
        latents, # (b n d)
    ):

        for step in range_from_one(self.num_refinement_blocks):

            # only last round of refinement receives gradients

            is_last = step == self.num_refinement_blocks
            context = torch.no_grad if not is_last else nullcontext

            with context():
                outputs, latents = self.refine_latent_then_output_once(inputs, outputs, latents)

        return outputs, latents

    @torch.no_grad()
    def predict(
        self,
        seq,
        halt_prob_thres = 0.5,
        max_deep_refinement_steps = 12
    ):
        batch = seq.shape[0]

        inputs, packed_shape = self.embed_inputs_with_registers(seq)

        # initial outputs and latents

        outputs, latents = self.get_initial()

        # active batch indices, the step it exited at, and the final output predictions

        active_batch_indices = arange(batch, device = self.device, dtype = torch.float32)

        preds = []
        exited_step_indices = []
        exited_batch_indices = []

        for step in range_from_one(max_deep_refinement_steps):
            is_last = step == max_deep_refinement_steps

            outputs, latents = self.deep_refinement(inputs, outputs, latents)

            halt_prob = self.to_halt_pred(outputs)

            should_halt = (halt_prob >= halt_prob_thres) | is_last

            if not should_halt.any():
                continue

            # maybe remove registers

            registers, outputs_for_pred = unpack(outputs, packed_shape, 'b * d')

            # append to exited predictions

            pred = self.to_pred(outputs_for_pred[should_halt])
            preds.append(pred)

            # append the step at which early halted

            exited_step_indices.extend([step] * should_halt.sum().item())

            # append indices for sorting back

            exited_batch_indices.append(active_batch_indices[should_halt])

            if is_last:
                continue

            # ready for next round

            inputs = inputs[~should_halt]
            outputs = outputs[~should_halt]
            latents = latents[~should_halt]
            active_batch_indices = active_batch_indices[~should_halt]

            if is_empty(outputs):
                break

        preds = cat(preds).argmax(dim = -1)
        exited_step_indices = tensor(exited_step_indices)

        exited_batch_indices = cat(exited_batch_indices)
        sort_indices = exited_batch_indices.argsort(dim = -1)

        return preds[sort_indices], exited_step_indices[sort_indices]
    def forward(
        self,
        seq,
        outputs,
        latents,
        labels = None
    ):

        inputs, packed_shape = self.embed_inputs_with_registers(seq)

        outputs, latents = self.deep_refinement(inputs, outputs, latents)

        registers, outputs_for_pred = unpack(outputs, packed_shape, 'b * d')

        pred = self.to_pred(outputs_for_pred)

        halt_prob = self.to_halt_pred(outputs)

        # detach the carries for the next deep supervision step

        outputs, latents = outputs.detach(), latents.detach()

        return_package = (outputs, latents, pred, halt_prob)

        if not exists(labels):
            return return_package

        # calculate loss if labels passed in

        loss = F.cross_entropy(rearrange(pred, 'b n l -> b l n'), labels, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')

        is_all_correct = (pred.argmax(dim = -1) == labels).all(dim = -1)

        halt_loss = F.binary_cross_entropy(halt_prob, is_all_correct.float(), reduction = 'none')

        # total loss and loss breakdown

        total_loss = (
            loss +
            halt_loss * self.halt_loss_weight
        )

        losses = (loss, halt_loss)

        return (total_loss.sum(), losses, *return_package)
--------------------------------------------------------------------------------
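For readers who want the training loop without the `Trainer`, a condensed sketch of deep supervision as `Trainer.forward` implements it: carry the detached `outputs` and `latents` between rounds, step the optimizer each round, and drop rows whose halt probability crosses the threshold. Sizes here are illustrative:

```python
import torch
from torch.optim import AdamW
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    network = MLPMixer1D(dim = 16, depth = 2, seq_len = 256)
)

optim = AdamW(trm.parameters(), lr = 1e-4)

seq = torch.randint(0, 256, (4, 256))
labels = torch.randint(0, 256, (4, 256))

outputs, latents = trm.get_initial()

for step in range(12): # max recurrent steps

    loss, (main_loss, halt_loss), outputs, latents, pred, halt = trm(seq, outputs, latents, labels = labels)

    loss.backward()    # gradients flow only through the last refinement block
    optim.step()
    optim.zero_grad()

    # drop samples that are confident enough to halt

    halt_mask = halt >= 0.5

    if not halt_mask.any():
        continue

    outputs, latents = outputs[~halt_mask], latents[~halt_mask]
    seq, labels = seq[~halt_mask], labels[~halt_mask]

    if outputs.numel() == 0:
        break
```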