├── tiny_recursive_model
│   ├── __init__.py
│   ├── mlp_mixer_1d.py
│   ├── trainer.py
│   └── trm.py
├── .github
│   └── workflows
│       ├── test.yml
│       └── python-publish.yml
├── LICENSE
├── pyproject.toml
├── README.md
├── tests
│   └── test_trm.py
└── .gitignore
/tiny_recursive_model/__init__.py:
--------------------------------------------------------------------------------
1 | from tiny_recursive_model.trm import (
2 | TinyRecursiveModel,
3 | )
4 |
5 | from tiny_recursive_model.trainer import (
6 | Trainer
7 | )
8 |
9 | from tiny_recursive_model.mlp_mixer_1d import (
10 | MLPMixer1D
11 | )
12 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Pytest
2 | on: [push, pull_request]
3 |
4 | jobs:
5 | build:
6 |
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - uses: actions/checkout@v4
11 | - name: Set up Python 3.10
12 | uses: actions/setup-python@v5
13 | with:
14 | python-version: "3.10"
15 | - name: Install dependencies
16 | run: |
17 | python -m pip install --upgrade pip
18 | python -m pip install -e .[test]
19 | - name: Test with pytest
20 | run: |
21 | python -m pytest tests/
22 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "tiny-recursive-model"
3 | version = "0.0.12"
4 | description = "Tiny Recursive Model"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'reasoning',
15 | ]
16 |
17 | classifiers=[
18 | 'Development Status :: 4 - Beta',
19 | 'Intended Audience :: Developers',
20 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
21 | 'License :: OSI Approved :: MIT License',
22 | 'Programming Language :: Python :: 3.9',
23 | ]
24 |
25 | dependencies = [
26 | "accelerate",
27 | "adam-atan2-pytorch>=0.2.2",
28 | "einops>=0.8.1",
29 | "ema-pytorch",
30 | "torch>=2.4",
31 | "x-transformers>=2.9.0",
32 | ]
33 |
34 | [project.urls]
35 | Homepage = "https://pypi.org/project/tiny-recursive-model/"
36 | Repository = "https://github.com/lucidrains/tiny-recursive-model"
37 |
38 | [project.optional-dependencies]
39 | examples = []
40 | test = [
41 | "pytest"
42 | ]
43 |
44 | [tool.pytest.ini_options]
45 | pythonpath = [
46 | "."
47 | ]
48 |
49 | [build-system]
50 | requires = ["hatchling"]
51 | build-backend = "hatchling.build"
52 |
53 | [tool.rye]
54 | managed = true
55 | dev-dependencies = []
56 |
57 | [tool.hatch.metadata]
58 | allow-direct-references = true
59 |
60 | [tool.hatch.build.targets.wheel]
61 | packages = ["tiny_recursive_model"]
62 |
--------------------------------------------------------------------------------
/tiny_recursive_model/mlp_mixer_1d.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from torch import nn
4 | from torch.nn import Module, LayerNorm
5 | from einops.layers.torch import Rearrange, Reduce
6 |
7 | pair = lambda x: x if isinstance(x, tuple) else (x, x)
8 |
9 | class PreNormResidual(Module):
10 | def __init__(self, dim, fn):
11 | super().__init__()
12 | self.fn = fn
13 | self.norm = LayerNorm(dim, bias = False)
14 |
15 | def forward(self, x):
16 | return self.fn(self.norm(x)) + x
17 |
18 | def FeedForward(dim, dim_hidden, dropout = 0., dense = nn.Linear):
19 | return nn.Sequential(
20 | dense(dim, dim_hidden),
21 | nn.GELU(),
22 | nn.Dropout(dropout),
23 | dense(dim_hidden, dim),
24 | nn.Dropout(dropout)
25 | )
26 |
27 | def MLPMixer1D(*, dim, depth, seq_len, expansion_factor = 4, expansion_factor_token = 0.5, dropout = 0.):
28 | chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear
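# a Conv1d with kernel_size 1 treats the sequence axis as channels, so it mixes information across tokens
# nn.Linear acts on the feature axis as usual - together they give the two Mixer sublayers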
29 |
30 | return nn.Sequential(
31 | *[nn.Sequential(
32 | PreNormResidual(dim, FeedForward(seq_len, int(expansion_factor * dim), dropout, chan_first)),
33 | PreNormResidual(dim, FeedForward(dim, int(expansion_factor_token * dim), dropout, chan_last))
34 | ) for _ in range(depth)],
35 | LayerNorm(dim, bias = False)
36 | )
37 |
38 | # quick test
39 |
40 | if __name__ == '__main__':
41 |
42 | import torch
43 | tokens = torch.randn(1, 1024, 512)
44 | mixer = MLPMixer1D(dim = 512, depth = 4, seq_len = 1024)
45 |
46 | assert mixer(tokens).shape == tokens.shape
47 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## Tiny Recursive Model (TRM)
5 |
6 | Implementation of [Tiny Recursive Model](https://arxiv.org/abs/2510.04871) (TRM) by [Alexia Jolicoeur-Martineau](https://ajolicoeur.wordpress.com/about/), an improvement on [HRM](https://github.com/lucidrains/hrm) from Sapient AI.
7 |
8 | The official repository is [here](https://github.com/SamsungSAILMontreal/TinyRecursiveModels).
9 |
10 | [Interview with Alexia!](https://www.youtube.com/watch?v=P9zzUM0PrBM)
11 |
12 | [Paper review by bycloud](https://www.youtube.com/watch?v=ZgwHaI2C-9s)
13 |
14 |
15 |
16 | ## Install
17 |
18 | ```bash
19 | $ pip install tiny-recursive-model
20 | ```
21 |
22 | ## Usage
23 |
24 | ```python
25 | import torch
26 | from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D, Trainer
27 |
28 | trm = TinyRecursiveModel(
29 | dim = 16,
30 | num_tokens = 256,
31 | network = MLPMixer1D(
32 | dim = 16,
33 | depth = 2,
34 | seq_len = 256
35 | ),
36 | )
37 |
38 | # mock dataset
39 |
40 | from torch.utils.data import Dataset
41 | class MockDataset(Dataset):
42 | def __len__(self):
43 | return 16
44 |
45 | def __getitem__(self, idx):
46 | inp = torch.randint(0, 256, (256,))
47 | out = torch.randint(0, 256, (256,))
48 | return inp, out
49 |
50 | mock_dataset = MockDataset()
51 |
52 | # trainer
53 |
54 | trainer = Trainer(
55 | trm,
56 | mock_dataset,
57 | epochs = 1,
58 | batch_size = 16,
59 | cpu = True
60 | )
61 |
62 | trainer()
63 |
64 | # inference
65 |
66 | pred_answer, exit_indices = trm.predict(
67 | torch.randint(0, 256, (1, 256)),
68 | max_deep_refinement_steps = 12,
69 | halt_prob_thres = 0.1
70 | )
71 |
72 | # save the trained network, e.g. to a collection of specialized networks for tool calling
73 |
74 | torch.save(trm.state_dict(), 'saved-trm.pt')
75 |
76 | ```
77 |
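To reuse the saved network later, a minimal sketch - rebuild the model with the same hyperparameters as at training time, then load the weights:

```python
import torch
from tiny_recursive_model import TinyRecursiveModel, MLPMixer1D

# must match the configuration the checkpoint was trained with

trm = TinyRecursiveModel(
    dim = 16,
    num_tokens = 256,
    network = MLPMixer1D(
        dim = 16,
        depth = 2,
        seq_len = 256
    ),
)

trm.load_state_dict(torch.load('saved-trm.pt'))

pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))
```
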
78 | ## Citations
79 |
80 | ```bibtex
81 | @misc{jolicoeurmartineau2025morerecursivereasoningtiny,
82 | title = {Less is More: Recursive Reasoning with Tiny Networks},
83 | author = {Alexia Jolicoeur-Martineau},
84 | year = {2025},
85 | eprint = {2510.04871},
86 | archivePrefix = {arXiv},
87 | primaryClass = {cs.LG},
88 | url = {https://arxiv.org/abs/2510.04871},
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/tests/test_trm.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | param = pytest.mark.parametrize
3 |
4 | import torch
5 |
6 | from tiny_recursive_model.trm import TinyRecursiveModel
7 | from tiny_recursive_model.trainer import Trainer
8 |
9 | @param('use_self_attn', (False, True))
10 | @param('registers', (0, 4))
11 | def test_trm(
12 | use_self_attn,
13 | registers
14 | ):
15 | from torch.optim import AdamW
16 |
17 | if use_self_attn:
18 | from x_transformers import Encoder
19 | network = Encoder(dim = 512, depth = 2)
20 | else:
21 | from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D
22 | network = MLPMixer1D(dim = 512, depth = 2, seq_len = 1024 + registers)
23 |
24 | trm = TinyRecursiveModel(
25 | dim = 512,
26 | num_tokens = 256,
27 | num_register_tokens = registers,
28 | network = network
29 | )
30 |
31 | optim = AdamW(trm.parameters(), lr = 1e-4)
32 |
33 | seq = torch.randint(0, 256, (2, 1024))
34 | answer = torch.randint(0, 256, (2, 1024))
35 |
36 | outputs, latents = trm.get_initial()
37 |
38 | for _ in range(3):
39 | loss, losses, outputs, latents, pred, halt = trm(seq, outputs, latents, labels = answer)
40 |
41 | loss.backward()
42 | optim.step()
43 | optim.zero_grad()
44 |
45 | pred_answer, exit_indices = trm.predict(seq)
46 |
47 | def test_trainer():
48 | from torch.utils.data import Dataset
49 | from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D
50 |
51 | trm = TinyRecursiveModel(
52 | dim = 16,
53 | num_tokens = 256,
54 | network = MLPMixer1D(
55 | dim = 16,
56 | depth = 2,
57 | seq_len = 256
58 | ),
59 | )
60 |
61 | class MockDataset(Dataset):
62 | def __len__(self):
63 | return 16
64 |
65 | def __getitem__(self, idx):
66 | inp = torch.randint(0, 256, (256,))
67 | out = torch.randint(0, 256, (256,))
68 | return inp, out
69 |
70 | trainer = Trainer(
71 | trm,
72 | MockDataset(),
73 | epochs = 1,
74 | batch_size = 16,
75 | cpu = True
76 | )
77 |
78 | trainer()
79 |
80 | pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))
81 |
82 | def test_gpt():
83 | from torch.utils.data import Dataset
84 | from x_transformers import Decoder
85 |
86 | trm = TinyRecursiveModel(
87 | dim = 16,
88 | num_tokens = 256,
89 | network = Decoder(
90 | dim = 16,
91 | depth = 2
92 | ),
93 | )
94 |
95 | class MockDataset(Dataset):
96 | def __len__(self):
97 | return 16
98 |
99 | def __getitem__(self, idx):
100 | seq = torch.randint(0, 256, (257,))
101 | return seq[:-1], seq[1:]
102 |
103 | trainer = Trainer(
104 | trm,
105 | MockDataset(),
106 | epochs = 1,
107 | batch_size = 16,
108 | cpu = True
109 | )
110 |
111 | trainer()
112 |
113 | pred_answer, exit_indices = trm.predict(torch.randint(0, 256, (1, 256)))
114 |
--------------------------------------------------------------------------------
/tiny_recursive_model/trainer.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch.nn import Module
5 | from torch.optim import AdamW, Optimizer
6 | from torch.optim.lr_scheduler import LambdaLR
7 | from torch.utils.data import Dataset, DataLoader
8 |
9 | from einops import pack, unpack
10 |
11 | from accelerate import Accelerator
12 |
13 | # ema - apparently greatly helped with results
14 |
15 | from ema_pytorch import EMA
16 |
17 | from tiny_recursive_model.trm import TinyRecursiveModel
18 |
19 | from adam_atan2_pytorch import MuonAdamAtan2
20 |
21 | from x_transformers import Encoder, Decoder
22 |
23 | # helpers
24 |
25 | def exists(v):
26 | return v is not None
27 |
28 | def range_from_one(n):
29 | return range(1, n + 1)
30 |
31 | def is_empty(t):
32 | return t.numel() == 0
33 |
34 | # trainer
35 |
36 | class Trainer(Module):
37 | def __init__(
38 | self,
39 | model: TinyRecursiveModel | Module,
40 | dataset: Dataset,
41 | optim_klass = AdamW,
42 | optim: Optimizer | None = None,
43 | learning_rate = 1e-4,
44 | muon_learning_rate = 1e-3,
45 | weight_decay = 1.,
46 | batch_size = 16,
47 | epochs = 2,
48 | halt_prob_thres = 0.5,
49 | max_recurrent_steps = 12,
50 | warmup_steps = 2000,
51 | ema_decay_rate = 0.999,
52 | switch_ema_every = 10000, # switch ema https://arxiv.org/abs/2402.09240
53 | accelerate_kwargs: dict = dict(),
54 | cpu = False
55 | ):
56 | super().__init__()
57 |
58 | self.accelerator = Accelerator(**accelerate_kwargs, cpu = cpu)
59 |
60 | self.batch_size = batch_size
61 | self.epochs = epochs
62 |
63 | # data
64 |
65 | self.dataset = dataset
66 | self.dataloader = dataloader = DataLoader(self.dataset, batch_size = self.batch_size, shuffle = True)
67 |
68 | # optim
69 |
70 | if not exists(optim):
71 |
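# learning rates are scaled down by batch size and recurrent steps, presumably since the
# loss is summed over the batch and an optimizer step is taken on every recurrent pass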
72 | if isinstance(model.network, (Encoder, Decoder)):
73 | optim = MuonAdamAtan2(
74 | model.network.muon_parameters(),
75 | model.parameters(),
76 | lr = learning_rate / (batch_size * max_recurrent_steps),
77 | muon_lr = muon_learning_rate / (batch_size * max_recurrent_steps),
78 | weight_decay = weight_decay
79 | )
80 | else:
81 | optim = optim_klass(
82 | model.parameters(),
83 | lr = learning_rate / (batch_size * max_recurrent_steps),
84 | weight_decay = weight_decay
85 | )
86 |
87 | self.optim = optim
88 |
89 | # scheduler
90 |
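# linear warmup from 0 to the base learning rate over warmup_steps, constant thereafter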
91 | self.scheduler = LambdaLR(self.optim, lambda step: min((step + 1) / warmup_steps, 1.0))
92 |
93 | # model
94 |
95 | self.model = model
96 |
97 | # ema model
98 |
99 | self.ema_model = None
100 |
101 | if self.accelerator.is_main_process:
102 | self.ema_model = EMA(
103 | model,
104 | beta = ema_decay_rate,
105 | update_model_with_ema_every = switch_ema_every,
106 | forward_method_names = ('predict',)
107 | )
108 |
109 | # recurrent and act related variables
110 |
111 | self.halt_prob_thres = halt_prob_thres
112 |
113 | self.max_recurrent_steps = max_recurrent_steps
114 |
115 | # prepare maybe distributed
116 |
117 | self.model, self.optim, self.dataloader, self.scheduler = self.accelerator.prepare(self.model, self.optim, self.dataloader, self.scheduler)
118 |
119 | def forward(self):
120 |
121 | for epoch in range_from_one(self.epochs):
122 |
123 | for dataset_input, dataset_output in self.dataloader:
124 |
125 | outputs, latents = self.model.get_initial()
126 |
127 | for recurrent_step in range_from_one(self.max_recurrent_steps):
128 |
129 | loss, (main_loss, halt_loss), outputs, latents, pred, halt = self.model(dataset_input, outputs, latents, labels = dataset_output)
130 |
131 | self.accelerator.print(f'[{epoch} ({recurrent_step} / {self.max_recurrent_steps})] loss: {main_loss.mean().item():.3f} | halt loss: {halt_loss.mean().item():.3f}')
132 |
133 | self.accelerator.backward(loss)
134 |
135 | self.optim.step()
136 | self.optim.zero_grad()
137 |
138 | self.scheduler.step()
139 |
140 | if self.accelerator.is_main_process:
141 | self.ema_model.update()
142 |
143 | # handle halting
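# samples whose predicted halt probability crosses the threshold exit early and are dropped from the batch below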
144 |
145 | halt_mask = halt >= self.halt_prob_thres
146 |
147 | if not halt_mask.any():
148 | continue
149 |
150 | outputs = outputs[~halt_mask]
151 | latents = latents[~halt_mask]
152 | dataset_input = dataset_input[~halt_mask]
153 | dataset_output = dataset_output[~halt_mask]
154 |
155 | if is_empty(outputs):
156 | break
157 |
158 | self.accelerator.print('complete')
159 |
160 | if self.accelerator.is_main_process:
161 | self.ema_model.copy_params_from_ema_to_model()
162 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[codz]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py.cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # UV
100 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | #uv.lock
104 |
105 | # poetry
106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
107 | # This is especially recommended for binary packages to ensure reproducibility, and is more
108 | # commonly ignored for libraries.
109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
110 | #poetry.lock
111 | #poetry.toml
112 |
113 | # pdm
114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
115 | # pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
116 | # https://pdm-project.org/en/latest/usage/project/#working-with-version-control
117 | #pdm.lock
118 | #pdm.toml
119 | .pdm-python
120 | .pdm-build/
121 |
122 | # pixi
123 | # Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
124 | #pixi.lock
125 | # Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
126 | # in the .venv directory. It is recommended not to include this directory in version control.
127 | .pixi
128 |
129 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
130 | __pypackages__/
131 |
132 | # Celery stuff
133 | celerybeat-schedule
134 | celerybeat.pid
135 |
136 | # SageMath parsed files
137 | *.sage.py
138 |
139 | # Environments
140 | .env
141 | .envrc
142 | .venv
143 | env/
144 | venv/
145 | ENV/
146 | env.bak/
147 | venv.bak/
148 |
149 | # Spyder project settings
150 | .spyderproject
151 | .spyproject
152 |
153 | # Rope project settings
154 | .ropeproject
155 |
156 | # mkdocs documentation
157 | /site
158 |
159 | # mypy
160 | .mypy_cache/
161 | .dmypy.json
162 | dmypy.json
163 |
164 | # Pyre type checker
165 | .pyre/
166 |
167 | # pytype static type analyzer
168 | .pytype/
169 |
170 | # Cython debug symbols
171 | cython_debug/
172 |
173 | # PyCharm
174 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
175 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
176 | # and can be added to the global gitignore or merged into this file. For a more nuclear
177 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
178 | #.idea/
179 |
180 | # Abstra
181 | # Abstra is an AI-powered process automation framework.
182 | # Ignore directories containing user credentials, local state, and settings.
183 | # Learn more at https://abstra.io/docs
184 | .abstra/
185 |
186 | # Visual Studio Code
187 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
188 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
189 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
190 | # you could uncomment the following to ignore the entire vscode folder
191 | # .vscode/
192 |
193 | # Ruff stuff:
194 | .ruff_cache/
195 |
196 | # PyPI configuration file
197 | .pypirc
198 |
199 | # Cursor
200 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
201 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
202 | # refer to https://docs.cursor.com/context/ignore-files
203 | .cursorignore
204 | .cursorindexingignore
205 |
206 | # Marimo
207 | marimo/_static/
208 | marimo/_lsp/
209 | __marimo__/
210 |
--------------------------------------------------------------------------------
/tiny_recursive_model/trm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from contextlib import nullcontext
3 |
4 | import torch
5 | from torch import nn, cat, arange, tensor
6 | import torch.nn.functional as F
7 | from torch.nn import Module, ModuleList
8 |
9 | from einops import rearrange, repeat, reduce, pack, unpack
10 | from einops.layers.torch import Reduce, Rearrange
11 |
12 | # network related
13 |
14 | from tiny_recursive_model.mlp_mixer_1d import MLPMixer1D
15 |
16 | # helpers
17 |
18 | def exists(v):
19 | return v is not None
20 |
21 | def default(v, d):
22 | return v if exists(v) else d
23 |
24 | def is_empty(t):
25 | return t.numel() == 0
26 |
27 | def range_from_one(n):
28 | return range(1, n + 1)
29 |
30 | # classes
31 |
32 | class TinyRecursiveModel(Module):
33 | def __init__(
34 | self,
35 | *,
36 | dim,
37 | num_tokens,
38 | network: Module,
39 | num_refinement_blocks = 3, # T in paper
40 | num_latent_refinements = 6, # n in paper - one output refinement per n latent refinements
41 | halt_loss_weight = 1.,
42 | num_register_tokens = 0
43 | ):
44 | super().__init__()
45 | assert num_refinement_blocks > 1
46 |
47 | self.input_embed = nn.Embedding(num_tokens, dim)
48 | self.output_init_embed = nn.Parameter(torch.randn(dim) * 1e-2)
49 | self.latent_init_embed = nn.Parameter(torch.randn(dim) * 1e-2)
50 |
51 | self.network = network
52 |
53 | self.num_latent_refinements = num_latent_refinements
54 | self.num_refinement_blocks = num_refinement_blocks
55 |
56 | # register tokens for the self attend version
57 |
58 | self.register_tokens = nn.Parameter(torch.randn(num_register_tokens, dim) * 1e-2)
59 |
60 | # prediction heads
61 |
62 | self.to_pred = nn.Linear(dim, num_tokens, bias = False)
63 |
64 | self.to_halt_pred = nn.Sequential(
65 | Reduce('b n d -> b d', 'mean'),
66 | nn.Linear(dim, 1, bias = False),
67 | nn.Sigmoid(),
68 | Rearrange('... 1 -> ...')
69 | )
70 |
71 | self.halt_loss_weight = halt_loss_weight
72 |
73 | @property
74 | def device(self):
75 | return next(self.parameters()).device
76 |
77 | def get_initial(self):
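# the initial output and latent states are learned d-dimensional embeddings
# they broadcast against the (batch, seq, dim) inputs on the first refinement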
78 | outputs = self.output_init_embed
79 | latents = self.latent_init_embed
80 |
81 | return outputs, latents
82 |
83 | def embed_inputs_with_registers(
84 | self,
85 | seq
86 | ):
87 | batch = seq.shape[0]
88 |
89 | inputs = self.input_embed(seq)
90 |
91 | # maybe registers
92 |
93 | registers = repeat(self.register_tokens, 'n d -> b n d', b = batch)
94 |
95 | inputs, packed_shape = pack([registers, inputs], 'b * d')
96 |
97 | return inputs, packed_shape
98 |
99 | def refine_latent_then_output_once(
100 | self,
101 | inputs, # (b n d)
102 | outputs, # (b n d)
103 | latents, # (b n d)
104 | ):
105 |
106 | # this work uses only a single shared network
107 | # it learns to refine the latents when the input is added into the sum, and to refine the output otherwise
108 |
109 | for _ in range(self.num_latent_refinements):
110 |
111 | latents = self.network(outputs + latents + inputs)
112 |
113 | outputs = self.network(outputs + latents)
114 |
115 | return outputs, latents
116 |
117 | def deep_refinement(
118 | self,
119 | inputs, # (b n d)
120 | outputs, # (b n d)
121 | latents, # (b n d)
122 | ):
123 |
124 | for step in range_from_one(self.num_refinement_blocks):
125 |
126 | # only the last round of refinement receives gradients
127 |
128 | is_last = step == self.num_refinement_blocks
129 | context = torch.no_grad if not is_last else nullcontext
130 |
131 | with context():
132 | outputs, latents = self.refine_latent_then_output_once(inputs, outputs, latents)
133 |
134 | return outputs, latents
135 |
136 | @torch.no_grad()
137 | def predict(
138 | self,
139 | seq,
140 | halt_prob_thres = 0.5,
141 | max_deep_refinement_steps = 12
142 | ):
143 | batch = seq.shape[0]
144 |
145 | inputs, packed_shape = self.embed_inputs_with_registers(seq)
146 |
147 | # initial outputs and latents
148 |
149 | outputs, latents = self.get_initial()
150 |
151 | # track the active batch indices, the step each sample exited at, and the final output predictions
152 |
153 | active_batch_indices = arange(batch, device = self.device, dtype = torch.float32)
154 |
155 | preds = []
156 | exited_step_indices = []
157 | exited_batch_indices = []
158 |
159 | for step in range_from_one(max_deep_refinement_steps):
160 | is_last = step == max_deep_refinement_steps
161 |
162 | outputs, latents = self.deep_refinement(inputs, outputs, latents)
163 |
164 | halt_prob = self.to_halt_pred(outputs)
165 |
166 | should_halt = (halt_prob >= halt_prob_thres) | is_last
167 |
168 | if not should_halt.any():
169 | continue
170 |
171 | # maybe remove registers
172 |
173 | registers, outputs_for_pred = unpack(outputs, packed_shape, 'b * d')
174 |
175 | # append to exited predictions
176 |
177 | pred = self.to_pred(outputs_for_pred[should_halt])
178 | preds.append(pred)
179 |
180 | # record the step at which each sample halted early
181 |
182 | exited_step_indices.extend([step] * should_halt.sum().item())
183 |
184 | # record batch indices for restoring the original order
185 |
186 | exited_batch_indices.append(active_batch_indices[should_halt])
187 |
188 | if is_last:
189 | continue
190 |
191 | # ready for next round
192 |
193 | inputs = inputs[~should_halt]
194 | outputs = outputs[~should_halt]
195 | latents = latents[~should_halt]
196 | active_batch_indices = active_batch_indices[~should_halt]
197 |
198 | if is_empty(outputs):
199 | break
200 |
201 | preds = cat(preds).argmax(dim = -1)
202 | exited_step_indices = tensor(exited_step_indices, device = self.device) # keep on the same device as the predictions for the indexing below
203 |
204 | exited_batch_indices = cat(exited_batch_indices)
205 | sort_indices = exited_batch_indices.argsort(dim = -1)
206 |
207 | return preds[sort_indices], exited_step_indices[sort_indices]
208 |
209 | def forward(
210 | self,
211 | seq,
212 | outputs,
213 | latents,
214 | labels = None
215 | ):
216 |
217 | inputs, packed_shape = self.embed_inputs_with_registers(seq)
218 |
219 | outputs, latents = self.deep_refinement(inputs, outputs, latents)
220 |
221 | registers, outputs_for_pred = unpack(outputs, packed_shape, 'b * d')
222 |
223 | pred = self.to_pred(outputs_for_pred)
224 |
225 | halt_prob = self.to_halt_pred(outputs)
226 |
227 | outputs, latents = outputs.detach(), latents.detach()
228 |
229 | return_package = (outputs, latents, pred, halt_prob)
230 |
231 | if not exists(labels):
232 | return return_package
233 |
234 | # calculate the losses if labels are passed in
235 |
236 | loss = F.cross_entropy(rearrange(pred, 'b n l -> b l n'), labels, reduction = 'none')
237 | loss = reduce(loss, 'b ... -> b', 'mean')
238 |
239 | is_all_correct = (pred.argmax(dim = -1) == labels).all(dim = -1)
240 |
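# the halt head is trained to predict whether the current prediction is already fully correct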
241 | halt_loss = F.binary_cross_entropy(halt_prob, is_all_correct.float(), reduction = 'none')
242 |
243 | # total loss and loss breakdown
244 |
245 | total_loss = (
246 | loss +
247 | halt_loss * self.halt_loss_weight
248 | )
249 |
250 | losses = (loss, halt_loss)
251 |
252 | return (total_loss.sum(), losses, *return_package)
253 |
--------------------------------------------------------------------------------