├── .github
│   └── workflows
│       ├── lint.yaml
│       ├── python-publish.yml
│       └── test.yaml
├── .gitignore
├── LICENSE
├── README.md
├── fig3.png
├── pi_zero_pytorch
│   ├── __init__.py
│   ├── mock_env.py
│   ├── pi_zero.py
│   └── tensor_typing.py
├── plot_time_from_beta.py
├── pyproject.toml
└── tests
    └── test_pi_zero.py
/.github/workflows/lint.yaml:
--------------------------------------------------------------------------------
1 | name: Ruff
2 | on: [push, pull_request]
3 |
4 | jobs:
5 | build:
6 |
7 | runs-on: ubuntu-latest
8 |
9 | steps:
10 | - uses: actions/checkout@v4
11 | - name: Set up Python 3.10
12 | uses: actions/setup-python@v5
13 | with:
14 | python-version: "3.10"
15 | - name: Install dependencies
16 | run: |
17 | python -m pip install uv
18 | python -m uv pip install ruff
19 | - name: Lint with Ruff
20 | run: |
21 | ruff check pi_zero_pytorch/
22 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: tests
2 | on: [push, pull_request]
3 |
4 | env:
5 | TYPECHECK: True
6 |
7 | jobs:
8 | test:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - uses: actions/checkout@v4
12 | - name: Install Python
13 | uses: actions/setup-python@v5
14 | with:
15 | python-version: "3.11"
16 | - name: Install dependencies
17 | run: |
18 | python -m pip install uv
19 | python -m uv pip install --upgrade pip
20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21 | python -m uv pip install -e .[test]
22 | - name: Test with pytest
23 | run: |
24 | python -m pytest tests/
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## pi-zero-pytorch (wip)
4 |
5 | Implementation of π₀, the robotic foundation model architecture proposed by Physical Intelligence
6 |
7 | In summary, this work is a simplified Transfusion (Zhou et al.) with influence from Stable Diffusion 3 (Esser et al.): mainly the adoption of flow matching instead of diffusion for policy generation, as well as the separation of parameters (joint attention from mmDiT). It builds on top of a pretrained vision-language model, PaliGemma 2B.
8 |
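A minimal sketch of the flow matching objective used for the action head, mirroring the noising logic in `pi_zero.py` (variable names are illustrative only, not the library API):

```python
import torch

# a clean action chunk, pure noise, and a sampled time in [0, 1]
actions = torch.randn(1, 32, 6)
noise   = torch.randn_like(actions)
times   = torch.rand(1, 1, 1)

# straight-line interpolation between noise (t = 0) and data (t = 1)
noised_actions = noise.lerp(actions, times)

# the flow (velocity) target the transformer regresses with MSE
flow_target = actions - noise
```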
9 | Update: The [official repository](https://github.com/Physical-Intelligence/openpi) has been open sourced!
10 |
11 | ### Appreciation
12 |
13 | - [Einops](https://github.com/arogozhnikov/einops) for the amazing [pack and unpack](https://einops.rocks/4-pack-and-unpack/), used extensively here for managing various token sets
14 |
15 | - [Flex Attention](https://pytorch.org/blog/flexattention/) for making it easy to mix autoregressive and bidirectional attention (see the sketch after this list)
16 |
17 | - [@Wonder1905](https://github.com/Wonder1905) for the code review and identifying issues
18 |
19 | - You? Maybe a PhD student who wants to contribute to the latest SOTA architecture for behavioral cloning?
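
A rough sketch of that mixed attention pattern, using plain tensors for illustration (the repository builds the equivalent mask inside `pi_zero.py`, with a flex attention variant when available):

```python
import torch

prefix_len, action_len = 6, 4          # state tokens, then action tokens
n = prefix_len + action_len

# causal (lower triangular) attention over the whole sequence ...
allowed = torch.ones(n, n, dtype = torch.bool).tril()

# ... but the trailing action block attends bidirectionally within itself
allowed[prefix_len:, prefix_len:] = True

print(allowed.int())
```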
20 |
21 | ### Install
22 |
23 | ```bash
24 | $ pip install pi-zero-pytorch
25 | ```
26 |
27 | ### Usage
28 |
29 | ```python
30 | import torch
31 | from pi_zero_pytorch import π0
32 |
33 | model = π0(
34 | dim = 512,
35 | dim_action_input = 6,
36 | dim_joint_state = 12,
37 | num_tokens = 20_000
38 | )
39 |
40 | vision = torch.randn(1, 1024, 512)
41 | commands = torch.randint(0, 20_000, (1, 1024))
42 | joint_state = torch.randn(1, 12)
43 | actions = torch.randn(1, 32, 6)
44 |
45 | loss, _ = model(vision, commands, joint_state, actions)
46 | loss.backward()
47 |
48 | # after much training
49 |
50 | sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) # (1, 32, 6)
51 | ```
52 |
53 | To do online learning, just wrap the model with the `Agent` class
54 |
55 | ```python
56 | from pi_zero_pytorch import π0, Agent, EPO
57 |
58 | # wrap the model with `Agent`, which will instantiate actor and critic for PPO
59 |
60 | agent = Agent(model)
61 |
62 | # you'll want to supply your own environment
63 |
64 | from pi_zero_pytorch.mock_env import Env
65 | mock_env = Env((256, 256), 2, 32, 1024, 12)
66 |
67 | # pass your agent and environment to EPO for learning to be orchestrated
68 |
69 | epo = EPO(agent, mock_env)
70 |
71 | # gather memories from environment
72 |
73 | memories = epo.gather_experience_from_env(steps = 10)
74 |
75 | # learn from memories
76 |
77 | epo.learn_agent(memories, batch_size = 2)
78 | ```
79 |
80 | ### Contributing
81 |
82 | At the project root, run
83 |
84 | ```bash
85 | $ pip install '.[test]' # or `uv pip install '.[test]'`
86 | ```
87 |
88 | Then add your tests to `tests/test_pi_zero.py` and run
89 |
90 | ```bash
91 | $ pytest tests/
92 | ```
93 |
94 | That's it
95 |
96 | ### Citation
97 |
98 | ```bibtex
99 | @misc{Black2024,
100 | author = {Kevin Black and Noah Brown and Danny Driess and Adnan Esmail and Michael Equi and Chelsea Finn and Niccolo Fusai and Lachy Groom and Karol Hausman and Brian Ichter and Szymon Jakubczak and Tim Jones and Liyiming Ke and Sergey Levine and Adrian Li-Bell and Mohith Mothukuri and Suraj Nair and Karl Pertsch and Lucy Xiaoyang Shi and James Tanner and Quan Vuong and Anna Walling and Haohuan Wang and Ury Zhilinsky},
101 | url = {https://www.physicalintelligence.company/download/pi0.pdf}
102 | }
103 | ```
104 |
105 | ```bibtex
106 | @inproceedings{Zhou2024ValueRL,
107 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
108 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
109 | year = {2024},
110 | url = {https://api.semanticscholar.org/CorpusID:273532030}
111 | }
112 | ```
113 |
114 | ```bibtex
115 | @inproceedings{Darcet2023VisionTN,
116 | title = {Vision Transformers Need Registers},
117 | author = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
118 | year = {2023},
119 | url = {https://api.semanticscholar.org/CorpusID:263134283}
120 | }
121 | ```
122 |
123 | ```bibtex
124 | @article{Li2024ImmiscibleDA,
125 | title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment},
126 | author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu},
127 | journal = {ArXiv},
128 | year = {2024},
129 | volume = {abs/2406.12303},
130 | url = {https://api.semanticscholar.org/CorpusID:270562607}
131 | }
132 | ```
133 |
134 | ```bibtex
135 | @inproceedings{Sadat2024EliminatingOA,
136 | title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models},
137 | author = {Seyedmorteza Sadat and Otmar Hilliges and Romann M. Weber},
138 | year = {2024},
139 | url = {https://api.semanticscholar.org/CorpusID:273098845}
140 | }
141 | ```
142 |
143 | ```bibtex
144 | @article{Bulatov2022RecurrentMT,
145 | title = {Recurrent Memory Transformer},
146 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev},
147 | journal = {ArXiv},
148 | year = {2022},
149 | volume = {abs/2207.06881},
150 | url = {https://api.semanticscholar.org/CorpusID:250526424}
151 | }
152 | ```
153 |
154 | ```bibtex
155 | @inproceedings{Bessonov2023RecurrentAT,
156 | title = {Recurrent Action Transformer with Memory},
157 | author = {A. B. Bessonov and Alexey Staroverov and Huzhenyu Zhang and Alexey K. Kovalev and D. Yudin and Aleksandr I. Panov},
158 | year = {2023},
159 | url = {https://api.semanticscholar.org/CorpusID:259188030}
160 | }
161 | ```
162 |
163 | ```bibtex
164 | @article{Zhu2024HyperConnections,
165 | title = {Hyper-Connections},
166 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
167 | journal = {ArXiv},
168 | year = {2024},
169 | volume = {abs/2409.19606},
170 | url = {https://api.semanticscholar.org/CorpusID:272987528}
171 | }
172 | ```
173 |
174 | ```bibtex
175 | @inproceedings{Sun2025F5RTTSIF,
176 | title = {F5R-TTS: Improving Flow-Matching based Text-to-Speech with Group Relative Policy Optimization},
177 | author = {Xiaohui Sun and Ruitong Xiao and Jianye Mo and Bowen Wu and Qun Yu and Baoxun Wang},
178 | year = {2025},
179 | url = {https://api.semanticscholar.org/CorpusID:277510064}
180 | }
181 | ```
182 |
183 | ```bibtex
184 | @inproceedings{Wang2025EvolutionaryPO,
185 | title = {Evolutionary Policy Optimization},
186 | author = {Jianren Wang and Yifan Su and Abhinav Gupta and Deepak Pathak},
187 | year = {2025},
188 | url = {https://api.semanticscholar.org/CorpusID:277313729}
189 | }
190 | ```
191 |
192 | [*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4)
193 |
--------------------------------------------------------------------------------
/fig3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/pi-zero-pytorch/a766537ea23fd91123cd054955127df6e1bd37c8/fig3.png
--------------------------------------------------------------------------------
/pi_zero_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from pi_zero_pytorch.pi_zero import PiZero, π0
2 |
--------------------------------------------------------------------------------
/pi_zero_pytorch/mock_env.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from random import choice
3 |
4 | import torch
5 | from torch import tensor, randn, randint
6 | from torch.nn import Module
7 |
8 | # functions
9 |
10 | def cast_tuple(v):
11 | return v if isinstance(v, tuple) else (v,)
12 |
13 | # mock env
14 |
15 | class Env(Module):
16 | def __init__(
17 | self,
18 | image_shape,
19 | num_images,
20 | num_text_tokens,
21 | max_text_len,
22 | joint_dim,
23 | can_terminate_after = 2
24 | ):
25 | super().__init__()
26 | self.image_shape = image_shape
27 | self.num_images = num_images
28 | self.num_text_tokens = num_text_tokens
29 | self.max_text_len = max_text_len
30 | self.joint_dim = joint_dim
31 |
32 | self.can_terminate_after = can_terminate_after
33 | self.register_buffer('_step', tensor(0))
34 |
35 | def get_random_state(self):
36 | return (
37 | randn(3, self.num_images, *self.image_shape),
38 | randint(0, self.num_text_tokens, (self.max_text_len,)),
39 | randn(self.joint_dim)
40 | )
41 |
42 | def reset(
43 | self,
44 | seed = None
45 | ):
46 | self._step.zero_()
47 | return self.get_random_state()
48 |
49 | def step(
50 | self,
51 | actions,
52 | ):
53 | state = self.get_random_state()
54 | reward = randint(-100, 100, ()).float()
55 |
56 | if self._step > self.can_terminate_after:
57 | truncated = tensor(choice((True, False)))
58 | terminated = tensor(choice((True, False)))
59 | else:
60 | truncated = terminated = tensor(False)
61 |
62 | self._step.add_(1)
63 |
64 | return state, reward, truncated, terminated
65 |
--------------------------------------------------------------------------------
/pi_zero_pytorch/pi_zero.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from random import random, randrange
4 |
5 | from beartype import beartype
6 | from beartype.typing import Callable, Literal
7 |
8 | from functools import partial, wraps
9 | from itertools import count
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import pi, nn, stack, tensor, is_tensor, Tensor
14 | from torch.nn import Module, ModuleList
15 | from torch.distributions import Normal
16 | from torch.distributions.beta import Beta
17 |
18 | from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
19 |
20 | from torch.utils.data import TensorDataset, DataLoader
21 |
22 | from torchdiffeq import odeint
23 |
24 | from scipy.optimize import linear_sum_assignment
25 |
26 | from ema_pytorch import EMA
27 |
28 | from adam_atan2_pytorch import AdoptAtan2
29 |
30 | from rotary_embedding_torch import (
31 | RotaryEmbedding,
32 | apply_rotary_emb
33 | )
34 |
35 | import einx
36 | from einops.layers.torch import Rearrange
37 | from einops import rearrange, repeat, reduce, einsum, pack, unpack
38 |
39 | from pi_zero_pytorch.tensor_typing import Float, Int, Bool
40 |
41 | from hyper_connections import HyperConnections
42 |
43 | from hl_gauss_pytorch import HLGaussLayer
44 |
45 | from assoc_scan import AssocScan
46 |
47 | from evolutionary_policy_optimization import LatentGenePool
48 |
49 | import tqdm
50 |
51 | from accelerate import Accelerator
52 |
53 | # ein notation
54 |
55 | # b - batch
56 | # n - sequence
57 | # na - seq of actions
58 | # nt - seq of text tokens
59 | # nv - seq of visual tokens
60 | # ns - seq of additional internal state tokens
61 | # nm - seq of memory tokens
62 | # d - dimension
63 | # da - action dimension
64 | # djs - joint state dimension
65 | # c - image channels
66 | # h - image height
67 | # w - image width
68 | # f - image frames
69 | # s - residual streams (hyper connections paper)
70 |
71 | # token layout for transformer
72 | # vision and language tokens attend with an autoregressive causal mask; actions, internal states + joint state are bidirectional amongst their own tokens, but still autoregressive with respect to the other tokens
73 |
74 | # [state token groups] [action token groups] -> [autoregressive masking] [bidirectional]
75 | # [external state] [visual tokens] [language tokens] [maybe reward / condition token] [action registers] [joint state + internal state] [actions]
76 |
77 | # for an attempt to introduce recurrence, all tokens above can be flanked by read and write memory tokens
78 | # [read memory tokens] [...] [write memory tokens]
79 |
80 | # constants
81 |
82 | LinearNoBias = partial(nn.Linear, bias = False)
83 |
84 | # flex attention related
85 | # https://pytorch.org/blog/flexattention/
86 |
87 | flex_attention = None
88 |
89 | if torch.cuda.is_available():
90 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask
91 | flex_attention = torch.compile(flex_attention)
92 |
93 | def create_pizero_attn_mask(
94 | prefix_causal_length,
95 | mask: Bool['b n']
96 | ):
97 | # pi-zero attention uses a triangular causal mask, but with bidirectional attention for the actions at the very right hand side
98 |
99 | def mask_fn(batch_index, head_index, query_index, key_index):
100 | key_mask = mask[batch_index, key_index] # variable length states
101 | causal_mask = query_index >= key_index # causal
102 |
103 | bidirectional_action_mask = ( # bidirectional action mask
104 | (key_index >= prefix_causal_length) &
105 | (query_index >= prefix_causal_length)
106 | )
107 |
108 | return (key_mask & causal_mask) | bidirectional_action_mask # tensor ops (&, |), as flex attention traces this mask function
109 |
110 | return mask_fn
111 |
112 | def softclamp_score_mod(value):
113 | def identity(score, b, h, q, k):
114 | return score
115 |
116 | def softclamped(score, b, h, q, k):
117 | score = score / value
118 | score = torch.tanh(score)
119 | score = score * value
120 | return score
121 |
122 | return softclamped if value > 0. else identity
123 |
124 | # helper functions
125 |
126 | def exists(v):
127 | return v is not None
128 |
129 | def default(v, d):
130 | return v if exists(v) else d
131 |
132 | def identity(t):
133 | return t
134 |
135 | def xnor(x, y):
136 | return not (x ^ y)
137 |
138 | def maybe(fn):
139 | @wraps(fn)
140 | def inner(t, *args, **kwargs):
141 | if not exists(t):
142 | return None
143 |
144 | return fn(t, *args, **kwargs)
145 |
146 | return inner
147 |
148 | def save_args_kwargs(fn):
149 | @wraps(fn)
150 | def decorated(self, *args, **kwargs):
151 | self._init_args_kwargs = (args, kwargs)
152 | return fn(self, *args, **kwargs)
153 |
154 | return decorated
155 |
156 | def to_device(t, device):
157 | return tree_map(lambda el: el.to(device) if is_tensor(el) else el, t)
158 |
159 | def move_input_tensors_to_device(fn):
160 |
161 | @wraps(fn)
162 | def decorated_fn(self, *args, **kwargs):
163 | args, kwargs = to_device((args, kwargs), self.device)
164 | return fn(self, *args, **kwargs)
165 |
166 | return decorated_fn
167 |
168 | def temp_batch_dim(fn):
169 |
170 | @wraps(fn)
171 | def inner(*args, **kwargs):
172 | args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs))
173 |
174 | out = fn(*args, **kwargs)
175 |
176 | out = tree_map(lambda t: rearrange(t, '1 ... -> ...') if is_tensor(t) else t, out)
177 | return out
178 |
179 | return inner
180 |
181 | # tensor helpers
182 |
183 | def log(t, eps = 1e-20):
184 | return t.clamp(min = eps).log()
185 |
186 | def l2norm(t, dim = -1):
187 | return F.normalize(t, dim = dim)
188 |
189 | def softclamp(t, value):
190 | if value <= 0.:
191 | return t
192 |
193 | return (t / value).tanh() * value
194 |
195 | def max_neg_value(t):
196 | return -torch.finfo(t.dtype).max
197 |
198 | def pack_with_inverse(t, pattern):
199 | packed, packed_shape = pack(t, pattern)
200 |
201 | def inverse(out, inv_pattern = None):
202 | inv_pattern = default(inv_pattern, pattern)
203 | out = unpack(out, packed_shape, inv_pattern)
204 | return out
205 |
206 | return packed, inverse
207 |
208 | def pack_one_with_inverse(t, pattern):
209 | packed, inverse = pack_with_inverse([t], pattern)
210 |
211 | def inverse_one(out, inv_pattern = None):
212 | out, = inverse(out, inv_pattern)
213 | return out
214 |
215 | return packed, inverse_one
216 |
217 | def tree_flatten_with_inverse(input):
218 | out, tree_spec = tree_flatten(input)
219 |
220 | def inverse(output):
221 | return tree_unflatten(output, tree_spec)
222 |
223 | return out, inverse
224 |
225 | def project(x, y):
226 | x, inverse = pack_one_with_inverse(x, 'b *')
227 | y, _ = pack_one_with_inverse(y, 'b *')
228 |
229 | dtype = x.dtype
230 | x, y = x.double(), y.double()
231 | unit = l2norm(y, dim = -1)
232 |
233 | parallel = (x * unit).sum(dim = -1, keepdim = True) * unit
234 | orthogonal = x - parallel
235 |
236 | return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype)
237 |
238 | def pad_at_dim(
239 | t,
240 | pad: tuple[int, int],
241 | *,
242 | dim = -1,
243 | value = 0.
244 | ):
245 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
246 | zeros = ((0, 0) * dims_from_right)
247 | return F.pad(t, (*zeros, *pad), value = value)
248 |
249 | # flow related
250 |
251 | def default_sample_times(
252 | shape,
253 | s = 0.999,
254 | alpha = 1.5,
255 | beta = 1,
256 | device = None
257 | ):
258 | """ they propose to sample times from Beta distribution - last part of appendix part B """
259 |
260 | alpha = torch.full(shape, alpha, device = device)
261 | beta = torch.full(shape, beta, device = device)
262 | sampled = Beta(alpha, beta).sample()
263 | return (1. - sampled) * s
264 |
265 | def noise_assignment(data, noise):
266 | device = data.device
267 | data, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (data, noise))
268 | dist = torch.cdist(data, noise)
269 | _, assign = linear_sum_assignment(dist.cpu())
270 | return torch.from_numpy(assign).to(device)
271 |
272 | # policy optimization related
273 |
274 | class GaussianNLL(Module):
275 | def forward(self, mu_sigma, target):
276 | mean, variance = mu_sigma.unbind(dim = -1)
277 | return F.gaussian_nll_loss(mean, target, variance)
278 |
279 | class LinearToMeanStd(Module):
280 | def __init__(
281 | self,
282 | dim,
283 | dim_out,
284 | eps = 1e-5
285 | ):
286 | super().__init__()
287 | self.linear = LinearNoBias(dim, dim_out * 2)
288 | self.eps = eps
289 |
290 | def forward(self, embed):
291 | out = self.linear(embed)
292 |
293 | mean, log_variance = rearrange(out, '... (d mu_sigma) -> mu_sigma ... d', mu_sigma = 2)
294 | variance = log_variance.exp()
295 | std = variance.clamp(min = self.eps).sqrt()
296 |
297 | return stack((mean, std), dim = -1)
298 |
299 | # attention
300 |
301 | class Attention(Module):
302 | @beartype
303 | def __init__(
304 | self,
305 | dim,
306 | dim_head = 64,
307 | heads = 8,
308 | dropout = 0.,
309 | softclamp_value = 50.,
310 | accept_memories = False,
311 | learned_value_action_residual_mix = False,
312 | rotary_emb: RotaryEmbedding | None = None
313 | ):
314 | super().__init__()
315 | self.scale = dim_head ** -0.5
316 | dim_inner = dim_head * heads
317 |
318 | self.rotary_emb = rotary_emb
319 |
320 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads)
321 | self.merge_heads = Rearrange('b h n d -> b n (h d)')
322 |
323 | self.rmsnorm = nn.RMSNorm(dim)
324 |
325 | # state parameters
326 |
327 | self.to_qkv = LinearNoBias(dim, 3 * dim_inner)
328 | self.to_out = LinearNoBias(dim_inner, dim)
329 |
330 | # maybe memory parameters
331 |
332 | self.accept_memories = accept_memories
333 |
334 | self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None
335 | self.to_mem_qkv = LinearNoBias(dim, 3 * dim_inner) if accept_memories else None
336 | self.to_mem_out = LinearNoBias(dim_inner, dim) if accept_memories else None
337 |
338 | # action parameters
339 |
340 | self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner)
341 |
342 | self.to_action_value_residual_mix = nn.Sequential(
343 | LinearNoBias(dim, heads),
344 | nn.Sigmoid(),
345 | Rearrange('b n h -> b h n 1')
346 | ) if learned_value_action_residual_mix else (lambda _: 0.5)
347 |
348 | self.to_actions_out = LinearNoBias(dim_inner, dim)
349 |
350 | self.softclamp_value = softclamp_value
351 |
352 | def forward_actions_with_cached_state(
353 | self,
354 | actions,
355 | cached_state_keys_values: tuple[Tensor, Tensor],
356 | memories: tuple[Tensor, Tensor] | None = None,
357 | rotary_emb = None,
358 | mask: Bool['b n'] | None = None,
359 | actions_value_residual: Tensor | None = None,
360 | return_keys_values = False,
361 | flex_attn_fn: Callable | None = None
362 | ):
363 | aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1)
364 |
365 | aq, ak, av, ag = tuple(self.split_heads(t) for t in (aq, ak, av, ag))
366 |
367 | if exists(actions_value_residual):
368 | mix = self.to_action_value_residual_mix(actions)
369 | av = av * mix + actions_value_residual * (1. - mix)
370 |
371 | q = aq
372 | mk, mv = cached_state_keys_values
373 |
374 | # concat cache key / values with action key / values
375 |
376 | k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))
377 |
378 | # handle read, write memories
379 |
380 | assert not (self.accept_memories ^ exists(memories))
381 |
382 | if exists(memories):
383 | _, write_memories = memories
384 | write_memories = self.mem_rmsnorm(write_memories)
385 | # mqkv_write = self.to_mem_qkv(write_memories)
386 |
387 | if exists(rotary_emb):
388 | q = apply_rotary_emb(rotary_emb, q, freqs_seq_dim = -2)
389 | k = apply_rotary_emb(rotary_emb, k)
390 |
391 | elif exists(self.rotary_emb):
392 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
393 |
394 | # attention
395 |
396 | if exists(flex_attn_fn):
397 | out = flex_attn_fn(q, k, v)
398 | else:
399 | q = q * self.scale
400 |
401 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
402 |
403 | sim = softclamp(sim, self.softclamp_value)
404 |
405 | if exists(mask):
406 | sim = einx.where('b j, b h i j, -> b h i j', mask, sim, max_neg_value(sim))
407 |
408 | attn = sim.softmax(dim = -1)
409 |
410 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
411 |
412 | # gate
413 |
414 | out = out * ag.sigmoid()
415 |
416 | # merge attention heads
417 |
418 | out = self.merge_heads(out)
419 |
420 | actions_out = self.to_actions_out(out)
421 |
422 | if not return_keys_values:
423 | return actions_out
424 |
425 | return actions_out, (mk, mv, ak, av)
426 |
427 | def forward_only_vision_language(
428 | self,
429 | state: Float['b n d'],
430 | rotary_emb = None
431 | ) -> Float['b n d']:
432 |
433 | device = state.device
434 |
435 | q, k, v = self.to_qkv(state).chunk(3, dim = -1)
436 |
437 | q, k, v = tuple(self.split_heads(t) for t in (q, k, v))
438 |
439 | if exists(rotary_emb):
440 | q = apply_rotary_emb(rotary_emb, q)
441 | k = apply_rotary_emb(rotary_emb, k)
442 |
443 | elif exists(self.rotary_emb):
444 | q = self.rotary_emb.rotate_queries_or_keys(q)
445 | k = self.rotary_emb.rotate_queries_or_keys(k)
446 |
447 | # attention
448 |
449 | q = q * self.scale
450 |
451 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
452 |
453 | sim = softclamp(sim, self.softclamp_value)
454 |
455 | causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1)
456 |
457 | sim = sim.masked_fill(causal_mask, max_neg_value(sim))
458 |
459 | attn = sim.softmax(dim = -1)
460 |
461 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
462 |
463 | # merge attention heads
464 |
465 | out = self.merge_heads(out)
466 |
467 | return self.to_out(out)
468 |
469 | def forward(
470 | self,
471 | multimodal_seq,
472 | actions,
473 | rotary_emb = None,
474 | memories: tuple[Tensor, Tensor] | None = None,
475 | mask: Bool['b n'] | None = None,
476 | actions_value_residual: Tensor | None = None,
477 | return_keys_values = False,
478 | flex_attn_fn: Callable | None = None
479 | ):
480 | seq_len, device = multimodal_seq.shape[-2], multimodal_seq.device
481 |
482 | multimodal_seq = self.rmsnorm(multimodal_seq)
483 |
484 | # separate projections for multimodal seq vs actions
485 |
486 | mq, mk, mv = self.to_qkv(multimodal_seq).chunk(3, dim = -1)
487 |
488 | aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1)
489 |
490 | mq, mk, mv, aq, ak, av, ag = tuple(self.split_heads(t) for t in (mq, mk, mv, aq, ak, av, ag))
491 |
492 | if exists(actions_value_residual):
493 | mix = self.to_action_value_residual_mix(actions)
494 | av = av * mix + actions_value_residual * (1. - mix)
495 |
496 | q, k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mq, mk, mv), (aq, ak, av)))
497 |
498 | # handle read, write memories
499 |
500 | has_memories = exists(memories) and any([m.numel() > 0 for m in memories])
501 |
502 | assert not (self.accept_memories ^ has_memories)
503 |
504 | if has_memories:
505 | memories, unpack_memories = pack_with_inverse(memories, 'b * d')
506 | memories = self.mem_rmsnorm(memories)
507 | mqkv = self.to_mem_qkv(memories)
508 | mqkv_read, mqkv_write = unpack_memories(mqkv, 'b * d')
509 |
510 | mqr, mkr, mvr, mqw, mkw, mvw = tuple(self.split_heads(t) for t in (*mqkv_read.chunk(3, dim = -1), *mqkv_write.chunk(3, dim = -1)))
511 |
512 | k = torch.cat((mkr, k, mkw), dim = -2)
513 | v = torch.cat((mvr, v, mvw), dim = -2)
514 | q, attn_output_unpack_memories = pack_with_inverse((mqr, q, mqw), 'b h * d')
515 |
516 | # rotary embedding
517 |
518 | if exists(rotary_emb):
519 | q = apply_rotary_emb(rotary_emb, q)
520 | k = apply_rotary_emb(rotary_emb, k)
521 | elif exists(self.rotary_emb):
522 | q = self.rotary_emb.rotate_queries_or_keys(q)
523 | k = self.rotary_emb.rotate_queries_or_keys(k)
524 |
525 | if exists(flex_attn_fn):
526 | out = flex_attn_fn(q, k, v)
527 |
528 | else:
529 | # attention
530 |
531 | q = q * self.scale
532 |
533 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
534 |
535 | sim = softclamp(sim, self.softclamp_value)
536 |
537 | causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1)
538 |
539 | if exists(mask):
540 | causal_mask = einx.logical_or('b j, i j -> b 1 i j', ~mask, causal_mask)
541 |
542 | causal_mask[..., seq_len:, seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper
543 |
544 | sim = sim.masked_fill(causal_mask, max_neg_value(sim))
545 |
546 | attn = sim.softmax(dim = -1)
547 |
548 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
549 |
550 | # gating of values, used in alphafold line of work
551 |
552 | gates = pad_at_dim(ag.sigmoid(), (out.shape[-2] - ag.shape[-2], 0), value = 1., dim = -2)
553 |
554 | out = out * gates
555 |
556 | # split out memories
557 |
558 | if self.accept_memories:
559 | mem_read_out, out, mem_write_out = attn_output_unpack_memories(out)
560 |
561 | # merge attention heads
562 |
563 | out = self.merge_heads(out)
564 |
565 | # separate projections for multimodal seq vs actions
566 |
567 | mout, aout = out[:, :seq_len], out[:, seq_len:]
568 |
569 | output = self.to_out(mout), self.to_actions_out(aout)
570 |
571 | if self.accept_memories:
572 | mem_out, unpack_memories = pack_with_inverse((mem_read_out, mem_write_out), 'b h * d')
573 | mem_out = self.merge_heads(mem_out)
574 | mem_out = self.to_mem_out(mem_out)
575 |
576 | output = (*output, unpack_memories(mem_out, 'b * d'))
577 |
578 | if not return_keys_values:
579 | return output
580 |
581 | return output, (mk, mv, ak, av)
582 |
583 | # feedforward
584 |
585 | class SwiGLUFeedForward(Module):
586 | def __init__(
587 | self,
588 | dim,
589 | expand_factor = 4.,
590 | dim_inner = None,
591 | rmsnorm = True
592 | ):
593 | super().__init__()
594 | dim_inner = default(dim_inner, int(dim * expand_factor * 2 / 3))
595 |
596 | self.rmsnorm = nn.RMSNorm(dim) if rmsnorm else nn.Identity()
597 | self.proj_in = LinearNoBias(dim, dim_inner * 2)
598 | self.proj_out = LinearNoBias(dim_inner, dim)
599 |
600 | def forward(
601 | self,
602 | seq
603 | ):
604 | seq = self.rmsnorm(seq)
605 | seq, gates = self.proj_in(seq).chunk(2, dim = -1)
606 | seq = seq * F.gelu(gates)
607 | return self.proj_out(seq)
608 |
609 | # actions need time conditioning
610 | # ada-ln zero from DiT - here we will improvise with adaptive rmsnorm
611 |
612 | class RandomFourierEmbed(Module):
613 | def __init__(self, dim):
614 | super().__init__()
615 | self.proj = nn.Linear(1, dim)
616 | self.proj.requires_grad_(False)
617 |
618 | def forward(
619 | self,
620 | times,
621 | ):
622 | times = rearrange(times, '... -> ... 1')
623 | rand_proj = self.proj(times)
624 | return torch.cos(2 * pi * rand_proj)
625 |
626 | class AdaptiveRMSNorm(Module):
627 | def __init__(
628 | self,
629 | dim,
630 | dim_cond
631 | ):
632 | super().__init__()
633 | self.norm = nn.RMSNorm(dim, elementwise_affine = False)
634 |
635 | self.to_gamma = nn.Sequential(
636 | nn.Linear(dim_cond, dim),
637 | nn.Sigmoid()
638 | )
639 |
640 | self.to_beta = LinearNoBias(dim_cond, dim)
641 |
642 | def forward(self, actions, cond):
643 |
644 | if cond.ndim == 2:
645 | cond = rearrange(cond, 'b d -> b 1 d')
646 |
647 | normed = self.norm(actions)
648 | gamma = self.to_gamma(cond)
649 | beta = self.to_beta(cond)
650 | return normed * gamma + beta
651 |
652 | class AdaptiveLayerscale(Module):
653 | def __init__(
654 | self,
655 | dim,
656 | dim_cond,
657 | adaln_zero_bias_init_value = -2.
658 | ):
659 | super().__init__()
660 | adaln_zero_gamma_linear = nn.Linear(dim_cond, dim)
661 | nn.init.zeros_(adaln_zero_gamma_linear.weight)
662 | nn.init.constant_(adaln_zero_gamma_linear.bias, adaln_zero_bias_init_value)
663 |
664 | self.to_adaln_zero_gamma = adaln_zero_gamma_linear
665 |
666 | def forward(self, actions, cond):
667 |
668 | if cond.ndim == 2:
669 | cond = rearrange(cond, 'b d -> b 1 d')
670 |
671 | gamma = self.to_adaln_zero_gamma(cond)
672 | return actions * gamma.sigmoid()
673 |
674 | # main class
675 |
676 | class PiZero(Module):
677 | @beartype
678 | @save_args_kwargs
679 | def __init__(
680 | self,
681 | dim,
682 | num_tokens,
683 | dim_action_input,
684 | dim_joint_state,
685 | dim_time_cond = None,
686 | depth = 12,
687 | dim_head = 64,
688 | heads = 8,
689 | use_flex_attn = False,
690 | ff_expand_factor = 4.,
691 | attn_softclamp_value = 50.,
692 | final_norm_softclamp_value = 30.,
693 | vit: Module | None = None,
694 | vit_dim = None,
695 | external_state_encoders: Module | list[Module] | None = None,
696 | dim_internal_state: int | None = None,
697 | num_action_register_tokens = 4,
698 | attn_kwargs: dict = dict(),
699 | ff_kwargs: dict = dict(),
700 | lm_pad_id = -1,
701 | lm_loss_weight = 1.,
702 | flow_loss_weight = 1.,
703 | immiscible_flow = False, # https://arxiv.org/abs/2406.12303
704 | sample_times_fn = default_sample_times,
705 | reward_tokens_dropout_prob = 0.,
706 | num_recurrent_memory_tokens = 0,
707 | num_residual_streams = 1,
708 | dim_latent = None,
709 | policy_optimizable = False, # if set to True, will use mean variance network for access to log prob
710 | is_critic = False, # whether this model is used as the critic, with the histogram classification loss from Imani et al. https://arxiv.org/html/2402.13425v1
711 | critic_value_kwargs: dict = dict(
712 | min_value = -10.,
713 | max_value = 10.,
714 | num_bins = 50
715 | ),
716 | odeint_kwargs: dict = dict(
717 | atol = 1e-5,
718 | rtol = 1e-5,
719 | method = 'midpoint'
720 | ),
721 | ):
722 | super().__init__()
723 | dim_time_cond = default(dim_time_cond, dim * 2)
724 |
725 | self.dim = dim
726 |
727 | # flex attention related
728 |
729 | assert not (use_flex_attn and not exists(flex_attention)), 'flex attention cannot be used'
730 | self.use_flex_attn = use_flex_attn
731 | self.attn_softclamp_value = attn_softclamp_value
732 |
733 | # vit
734 |
735 | self.vit = vit
736 |
737 | self.maybe_to_image_tokens = nn.Linear(vit_dim, dim) if exists(vit_dim) and vit_dim != dim else nn.Identity()
738 |
739 | # embedding
740 |
741 | self.token_emb = nn.Embedding(num_tokens, dim)
742 |
743 | # internal states
744 |
745 | self.to_joint_state_tokens = nn.Linear(dim_joint_state, dim)
746 |
747 | self.dim_internal_state = default(dim_internal_state, dim)
748 | self.to_internal_state_tokens = nn.Linear(dim_internal_state, dim) if exists(dim_internal_state) else nn.Identity()
749 |
750 | # additional external states
751 |
752 | external_state_encoders = default(external_state_encoders, [])
753 | self.external_state_encoders = ModuleList(external_state_encoders)
754 |
755 | # actions
756 |
757 | self.dim_action_input = dim_action_input
758 |
759 | self.action_register_tokens = nn.Parameter(torch.zeros(num_action_register_tokens, dim))
760 | nn.init.normal_(self.action_register_tokens, std = 0.02)
761 |
762 | self.to_action_tokens = nn.Linear(dim_action_input, dim)
763 |
764 | # time conditioning
765 |
766 | self.to_time_cond = nn.Sequential(
767 | RandomFourierEmbed(dim),
768 | nn.Linear(dim, dim_time_cond),
769 | nn.SiLU(),
770 | )
771 |
772 | # latent variable / gene conditioning
773 |
774 | can_accept_latent = exists(dim_latent)
775 | self.can_accept_latent = can_accept_latent
776 |
777 | if can_accept_latent:
778 | self.to_latent_cond = nn.Sequential(
779 | nn.Linear(dim_latent, dim_time_cond * 2),
780 | nn.SiLU(),
781 | nn.Linear(dim_time_cond * 2, dim_time_cond),
782 | )
783 |
784 | nn.init.zeros_(self.to_latent_cond[-1].weight)
785 | nn.init.zeros_(self.to_latent_cond[-1].bias)
786 |
787 | # positional embedding
788 |
789 | self.rotary_emb = RotaryEmbedding(dim_head)
790 |
791 | # recurrent memory parameters and logic
792 |
793 | self.has_recurrent_memories = num_recurrent_memory_tokens > 0
794 |
795 | self.memory_tokens = nn.Parameter(torch.zeros(num_recurrent_memory_tokens, dim))
796 | nn.init.normal_(self.memory_tokens, std = 0.02)
797 |
798 | self.final_norm_write_memories = nn.RMSNorm(dim) if self.has_recurrent_memories else None
799 |
800 | # residual functions, with maybe hyper connections
801 |
802 | assert num_residual_streams >= 1
803 | init_residual_fn, self.maybe_expand_residuals, self.maybe_reduce_residuals = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
804 |
805 | residual_fns = []
806 | counter = count()
807 |
808 | # attention and feedforward
809 |
810 | layers = []
811 | cond_layers = []
812 |
813 | for i in range(depth):
814 | is_first_block = i == 0
815 |
816 | layers.append(ModuleList([
817 | Attention(dim = dim, dim_head = dim_head, heads = heads, accept_memories = self.has_recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
818 | SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
819 | SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs),
820 | SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
821 | ]))
822 |
823 | residual_fns.append(ModuleList([
824 | init_residual_fn(dim = dim, layer_index = next(counter)),
825 | init_residual_fn(dim = dim, layer_index = next(counter)),
826 | ]))
827 |
828 | cond_layers.append(ModuleList([
829 | AdaptiveRMSNorm(dim, dim_time_cond),
830 | AdaptiveLayerscale(dim, dim_time_cond),
831 | AdaptiveRMSNorm(dim, dim_time_cond),
832 | AdaptiveLayerscale(dim, dim_time_cond)
833 | ]))
834 |
835 | self.layers = ModuleList(layers)
836 | self.cond_layers = ModuleList(cond_layers)
837 |
838 | self.residual_layers = ModuleList(residual_fns)
839 |
840 | self.final_norm_softclamp = partial(softclamp, value = final_norm_softclamp_value)
841 |
842 | self.final_norm = nn.RMSNorm(dim)
843 | self.final_actions_norm = nn.RMSNorm(dim)
844 |
845 | # unembedding
846 |
847 | self.state_to_logits = LinearNoBias(dim, num_tokens)
848 |
849 | # actor related
850 |
851 | self.actions_to_pred_flow = None
852 | self.loss_fn = None
853 |
854 | if not is_critic:
855 | if not policy_optimizable:
856 | self.actions_to_pred_flow = LinearNoBias(dim, dim_action_input)
857 | self.loss_fn = nn.MSELoss()
858 | else:
859 | self.actions_to_pred_flow = LinearToMeanStd(dim, dim_action_input)
860 | self.loss_fn = GaussianNLL()
861 |
862 | self.is_mean_std_output = policy_optimizable
863 | self.policy_optimizable = policy_optimizable
864 |
865 | # critic related
866 |
867 | self.is_critic = is_critic
868 |
869 | self.to_critic_value = HLGaussLayer(
870 | dim,
871 | hl_gauss_loss = critic_value_kwargs
872 | )
873 |
874 | # the language token id padding id, for fine-tuning as well as taking care of the masking on top of causal mask
875 |
876 | self.lm_pad_id = lm_pad_id
877 |
878 | # flow related
879 |
880 | self.immiscible_flow = immiscible_flow
881 |
882 | # reward classifier free guidance
883 |
884 | self.reward_tokens_dropout_prob = reward_tokens_dropout_prob
885 |
886 | # time sampling related
887 |
888 | self.sample_times_fn = default(sample_times_fn, torch.rand)
889 |
890 | # loss related
891 |
892 | self.lm_loss_weight = lm_loss_weight
893 | self.flow_loss_weight = flow_loss_weight
894 |
895 | # sampling related
896 |
897 | self.odeint_fn = partial(odeint, **odeint_kwargs)
898 |
899 | self.register_buffer('zero', torch.tensor(0.), persistent = False)
900 |
901 | # tensor typing
902 |
903 | self._nm = num_recurrent_memory_tokens
904 |
905 | @property
906 | def can_cfg(self):
907 | return self.reward_tokens_dropout_prob > 0.
908 |
909 | @property
910 | def device(self):
911 | return next(self.parameters()).device
912 |
913 | @beartype
914 | def load_pretrained_vlm_weights_(
915 | self,
916 | weights: dict[str, Tensor]
917 | ):
918 | raise NotImplementedError
919 |
920 | def create_ema(
921 | self,
922 | beta = 0.99,
923 | **ema_kwargs
924 | ) -> EMA:
925 |
926 | ema_pi_zero = EMA(
927 | self,
928 | beta = beta,
929 | include_online_model = False,
930 | forward_method_names = (
931 | 'sample_actions',
932 | ),
933 | **ema_kwargs
934 | )
935 |
936 | return ema_pi_zero
937 |
938 | def create_actor(self, **kwargs) -> PiZero:
939 | assert not self.is_critic, 'base model must not be a critic'
940 |
941 | # make probabilistic flow if not already
942 |
943 | if not self.policy_optimizable:
944 | assert 'policy_optimizable' not in kwargs
945 | kwargs.update(policy_optimizable = True)
946 |
947 | orig_args, orig_kwargs = self._init_args_kwargs
948 | actor = PiZero(*orig_args, **orig_kwargs, **kwargs)
949 |
950 | # load all possible shared parameters except for output head to logits (for histogram loss)
951 |
952 | state_dict = self.state_dict()
953 | actor.load_state_dict(state_dict, strict = False)
954 |
955 | # now, initialize the actor with variance of 1.
956 | # https://arxiv.org/abs/2302.08875
957 |
958 | if not self.policy_optimizable:
959 | orig_mean_weight = self.actions_to_pred_flow.weight
960 |
961 | actor_mean_std_weight = actor.actions_to_pred_flow.linear.weight
962 |
963 | actor_mean_std_weight.data.copy_(rearrange([orig_mean_weight, torch.zeros_like(orig_mean_weight)], 'mu_sigma o i -> (o mu_sigma) i'))
964 |
965 | return actor.to(self.device)
966 |
967 | def create_critic(self, **kwargs) -> PiZero:
968 | assert self.policy_optimizable and not self.is_critic, 'base model must be policy optimizable as well as not a critic already'
969 |
970 | assert 'is_critic' not in kwargs
971 | kwargs.update(is_critic = True)
972 |
973 | orig_args, orig_kwargs = self._init_args_kwargs
974 | critic = PiZero(*orig_args, **orig_kwargs, **kwargs)
975 |
976 | # load all possible shared parameters except for output head to logits (for histogram loss)
977 |
978 | state_dict = self.state_dict()
979 | critic.load_state_dict(state_dict, strict = False)
980 |
981 | return critic.to(self.device)
982 |
983 | @torch.no_grad()
984 | def sample_actions(
985 | self,
986 | images,
987 | token_ids,
988 | joint_states,
989 | trajectory_length: int,
990 | latents: Float['d'] | Float['b d'] = None,
991 | reward_tokens: Float['b d'] | None = None,
992 | internal_state_tokens: Float['b ns d'] | None = None,
993 | steps = 18,
994 | show_pbar = True,
995 | cond_scale = 0.,
996 | temperature = 1.,
997 | remove_parallel_component = True,
998 | keep_parallel_frac = 0.,
999 | cache_kv = True,
1000 | return_states_for_replay = False,
1001 | critic: Module | None = None,
1002 | ):
1003 | assert not self.is_critic
1004 |
1005 | batch_size = token_ids.shape[0]
1006 |
1007 | was_training = self.training
1008 | self.eval()
1009 |
1010 | pbar = tqdm.tqdm(desc = 'sampling action trajectory', disable = not show_pbar, total = steps)
1011 |
1012 | # accumulate log probs for ppo
1013 |
1014 | assert not (return_states_for_replay and not self.is_mean_std_output), 'only pi-zero with `policy_optimizable` turned on can return log probs'
1015 |
1016 | timesteps = []
1017 | log_probs = []
1018 | sampled_flows = []
1019 | denoised_actions_across_time = []
1020 |
1021 | critic_values = []
1022 |
1023 | # ode step function
1024 |
1025 | cached_state_kv = None
1026 | null_cached_state_kv = None
1027 |
1028 | def ode_fn(timestep, denoised_actions):
1029 | nonlocal cached_state_kv
1030 | nonlocal null_cached_state_kv
1031 |
1032 | input_args = (
1033 | images,
1034 | token_ids,
1035 | joint_states,
1036 | denoised_actions
1037 | )
1038 |
1039 | input_kwargs = dict(
1040 | times = timestep,
1041 | latents = latents,
1042 | reward_tokens = reward_tokens,
1043 | internal_state_tokens = internal_state_tokens,
1044 | cached_state_keys_values = (cached_state_kv, null_cached_state_kv),
1045 | cond_scale = cond_scale,
1046 | remove_parallel_component = remove_parallel_component,
1047 | keep_parallel_frac = keep_parallel_frac
1048 | )
1049 |
1050 | output, (new_cached_state_kv, new_null_cached_state_kv) = self.forward_with_reward_cfg(*input_args, **input_kwargs)
1051 |
1052 | if exists(critic):
1053 | critic_value, _ = critic.forward_with_reward_cfg(*input_args, **input_kwargs)
1054 | critic_values.append(critic_value)
1055 |
1056 | flow = output
1057 |
1058 | if self.is_mean_std_output:
1059 | mean, std = output.unbind(dim = -1)
1060 |
1061 | flow = torch.normal(mean, std * temperature)
1062 |
1063 | log_prob = Normal(mean, std).log_prob(flow)
1064 |
1065 | # save for replaying for optimizing actor
1066 |
1067 | denoised_actions_across_time.append(denoised_actions)
1068 | timesteps.append(repeat(timestep, ' -> b', b = batch_size))
1069 | log_probs.append(log_prob)
1070 | sampled_flows.append(flow)
1071 |
1072 | if cache_kv:
1073 | cached_state_kv = new_cached_state_kv
1074 | null_cached_state_kv = new_null_cached_state_kv
1075 |
1076 | pbar.update(1)
1077 |
1078 | return flow
1079 |
1080 | # start with random gaussian noise - y0
1081 |
1082 | noise = torch.randn((batch_size, trajectory_length, self.dim_action_input), device = self.device)
1083 |
1084 | # time steps
1085 |
1086 | times = torch.linspace(0., 1., steps, device = self.device)
1087 |
1088 | # ode
1089 |
1090 | trajectory = self.odeint_fn(ode_fn, noise, times)
1091 |
1092 | sampled_actions = trajectory[-1]
1093 |
1094 | self.train(was_training)
1095 |
1096 | pbar.close()
1097 |
1098 | if not return_states_for_replay:
1099 | return sampled_actions
1100 |
1101 | # place the time step dimension after batch
1102 |
1103 | timesteps = stack(timesteps, dim = 1)
1104 | log_probs = stack(log_probs, dim = 1)
1105 | sampled_flow = stack(sampled_flows, dim = 1)
1106 | denoised_actions_across_time = stack(denoised_actions_across_time, dim = 1)
1107 |
1108 | policy_optimization_outputs = (denoised_actions_across_time, timesteps, log_probs, sampled_flow)
1109 |
1110 | # return critic value predictions if passed in - will just deepcopy pi-zero + a critic head
1111 |
1112 | if exists(critic):
1113 | critic_values = stack(critic_values, dim = 1)
1114 | policy_optimization_outputs = (*policy_optimization_outputs, critic_values)
1115 |
1116 | return sampled_actions, policy_optimization_outputs
1117 |
1118 | @torch.no_grad()
1119 | def forward_with_reward_cfg(
1120 | self,
1121 | *args,
1122 | reward_tokens: Float['b d'] | None = None,
1123 | cached_state_keys_values = (None, None),
1124 | cond_scale = 0.,
1125 | remove_parallel_component = False,
1126 | keep_parallel_frac = 0.,
1127 | **kwargs
1128 | ):
1129 |
1130 | with_reward_cache, without_reward_cache = cached_state_keys_values
1131 |
1132 | forward_kwargs = dict(
1133 | return_state_keys_values = True,
1134 | return_actions_flow = True,
1135 | )
1136 |
1137 | action_flow_with_reward, with_reward_cache_kv = self.forward(
1138 | *args,
1139 | reward_tokens = reward_tokens,
1140 | cached_state_keys_values = with_reward_cache,
1141 | **forward_kwargs,
1142 | **kwargs
1143 | )
1144 |
1145 | if not exists(reward_tokens) or cond_scale == 0.:
1146 | return action_flow_with_reward, (with_reward_cache_kv, None)
1147 |
1148 | assert self.can_cfg, 'you need to train with reward token dropout'
1149 |
1150 | action_flow_without_reward, without_reward_cache_kv = self.forward(
1151 | *args,
1152 | cached_state_keys_values = without_reward_cache,
1153 | **forward_kwargs,
1154 | **kwargs
1155 | )
1156 |
1157 | update = action_flow_with_reward - action_flow_without_reward
1158 |
1159 | if remove_parallel_component:
1160 | # from https://arxiv.org/abs/2410.02416
1161 |
1162 | update_parallel, update_orthog = project(update, action_flow_with_reward)
1163 | update = update_orthog + update_parallel * keep_parallel_frac
1164 |
1165 | flow_with_reward_cfg = action_flow_with_reward + cond_scale * update
1166 |
1167 | return flow_with_reward_cfg, (with_reward_cache_kv, without_reward_cache_kv)
1168 |
1169 | @move_input_tensors_to_device
1170 | def forward_only_vision_language(
1171 | self,
1172 | images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision
1173 | token_ids: Int['b nt'], # language
1174 | ) -> Float['b n d']:
1175 |
1176 | device = token_ids.device
1177 |
1178 | language_tokens = self.token_emb(token_ids)
1179 |
1180 | # vision
1181 |
1182 | if exists(self.vit):
1183 | assert images.ndim in {4, 5}
1184 | is_multiple_images = images.ndim == 5
1185 |
1186 | if is_multiple_images:
1187 | images = rearrange(images, 'b c f h w -> b f c h w')
1188 | images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w')
1189 |
1190 | with torch.no_grad():
1191 | self.vit.eval()
1192 | visual_tokens = self.vit(images)
1193 |
1194 | if is_multiple_images:
1195 | visual_tokens, = inverse_pack_image_frames(visual_tokens, '* n d')
1196 | visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')
1197 |
1198 | else:
1199 | assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)'
1200 | visual_tokens = images
1201 |
1202 | visual_tokens = self.maybe_to_image_tokens(visual_tokens)
1203 |
1204 | # concat visual rep with language
1205 |
1206 | state_tokens, _ = pack_with_inverse([
1207 | visual_tokens,
1208 | language_tokens,
1209 | ], 'b * d')
1210 |
1211 | # rotary embeddings
1212 |
1213 | seq_len = state_tokens.shape[-2]
1214 |
1215 | seq = torch.arange(seq_len, device = device)
1216 |
1217 | rotary_emb = self.rotary_emb(seq)
1218 |
1219 | # transformer
1220 |
1221 | for attn, ff, _, _ in self.layers:
1222 |
1223 | state_attn_out = attn.forward_only_vision_language(state_tokens, rotary_emb = rotary_emb)
1224 |
1225 | state_tokens = state_tokens + state_attn_out
1226 |
1227 | state_tokens = ff(state_tokens) + state_tokens
1228 |
1229 | embed = self.final_norm_softclamp(state_tokens)
1230 |
1231 | logits = self.state_to_logits(embed)
1232 |
1233 | return logits
1234 |
1235 | @move_input_tensors_to_device
1236 | def forward_for_policy_loss(
1237 | self,
1238 | images,
1239 | commands,
1240 | joint_state,
1241 | actions,
1242 | flow,
1243 | times,
1244 | old_log_probs: Float['b na'],
1245 | advantages: Float['b t'],
1246 | clip_eps = 0.2,
1247 | entropy_weight = 1e-2,
1248 | norm_eps = 1e-5,
1249 | **kwargs,
1250 | ):
1251 | assert not self.is_critic
1252 | assert self.policy_optimizable
1253 | assert 'return_actions_flow' not in kwargs
1254 |
1255 | # flatten the time into the batch for actions at timestep, sampled flow, and log prob
1256 |
1257 | if times.ndim == 2:
1258 | times = rearrange(times, 'b t -> (b t)')
1259 |
1260 | if flow.ndim == 4:
1261 | flow = rearrange(flow, 'b t ... -> (b t) ...')
1262 |
1263 | if old_log_probs.ndim == 4:
1264 | old_log_probs = rearrange(old_log_probs, 'b t ... -> (b t) ...')
1265 |
1266 | if actions.ndim == 4:
1267 | actions = rearrange(actions, 'b t ... -> (b t) ...')
1268 |
1269 | # expand inputs across timesteps if need be
1270 |
1271 | (
1272 | images,
1273 | commands,
1274 | joint_state,
1275 | ) = tuple(repeat(inp, 'b ... -> (b t) ...', t = times.shape[0] // inp.shape[0]) for inp in (
1276 | images,
1277 | commands,
1278 | joint_state
1279 | ))
1280 |
1281 | mean_std = self.forward(
1282 | images,
1283 | commands,
1284 | joint_state,
1285 | actions,
1286 | return_actions_flow = True,
1287 | **kwargs
1288 | )
1289 |
1290 | normal_dist = Normal(*mean_std.unbind(dim = -1))
1291 |
1292 | new_log_probs = normal_dist.log_prob(actions)
1293 |
1294 | # ppo surrogate loss
1295 |
1296 | ratio = (new_log_probs - old_log_probs).exp()
1297 |
1298 | advantages = F.layer_norm(advantages, advantages.shape, eps = norm_eps)
1299 |
1300 | advantages = rearrange(advantages, 'b t -> (b t) 1 1')
1301 |
1302 | surr1 = ratio * advantages
1303 | surr2 = ratio.clamp(1. - clip_eps, 1. + clip_eps) * advantages
1304 |
1305 | clipped_surr_loss = torch.min(surr1, surr2).sum(dim = -1)
1306 |
1307 | # entropy
1308 |
1309 |         entropy = normal_dist.entropy().sum(dim = -1)
1310 |
1311 | return -(clipped_surr_loss + entropy * entropy_weight).mean()
1312 |
1313 | @move_input_tensors_to_device
1314 | def forward_for_critic_loss(
1315 | self,
1316 | *args,
1317 | old_values: Float['b t'],
1318 | advantages: Float['b t'],
1319 | clip_eps = 0.4,
1320 | **kwargs
1321 | ):
1322 | assert self.is_critic
1323 |
1324 | eps = clip_eps
1325 | loss_fn = self.to_critic_value.loss_fn
1326 |
1327 | critic_value, critic_logits = self.forward(*args, **kwargs)
1328 |
1329 | # value clipping
1330 |
1331 | advantages = rearrange(advantages, 'b t -> (b t)')
1332 | old_values = rearrange(old_values, 'b t -> (b t)')
1333 |
1334 | returns = old_values + advantages
1335 |
1336 | clipped_value = old_values + (critic_value - old_values).clamp(-eps, eps)
1337 |
1338 | clipped_loss = loss_fn(clipped_value, returns, reduction = 'none')
1339 | loss = loss_fn(critic_logits, returns, reduction = 'none')
1340 |
1341 | return torch.max(clipped_loss, loss).mean()
1342 |
1343 | @move_input_tensors_to_device
1344 | def forward(
1345 | self,
1346 | images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision
1347 | token_ids: Int['b nt'], # language
1348 | joint_state: Float['b djs'], # joint state
1349 | actions: Float['b na da'] | None = None, # action
1350 |         times: Float['b'] | None = None,
1351 |         latents: Float['d'] | Float['b d'] | None = None,
1352 | reward_tokens: Float['b d'] | None = None,
1353 | internal_state_tokens: Float['b ns d'] | None = None,
1354 | external_states: tuple[Float['b ...']] | None = None,
1355 | record_and_return_memory_tokens = False,
1356 | past_recurrent_memory_tokens: Float['b {self._nm} d'] | None = None,
1357 | return_actions_flow = False,
1358 | return_state_keys_values = False,
1359 | cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None,
1360 | return_language_loss = True,
1361 | return_action_flow_loss = True,
1362 | **kwargs
1363 | ):
1364 | inferencing = exists(cached_state_keys_values)
1365 | assert not (inferencing and not return_actions_flow), 'must be generating action trajectory if receiving cached state key values'
1366 |
1367 | if not exists(actions) and not self.is_critic:
1368 | return self.sample_actions(images, token_ids, joint_state, **kwargs)
1369 |
1370 | batch, device = token_ids.shape[0], token_ids.device
1371 |
1372 | # noising the action for flow matching
1373 |
1374 | if not exists(times):
1375 | times = self.sample_times_fn((batch,), device = device)
1376 |
1377 | if times.ndim == 0:
1378 | times = repeat(times, '-> b', b = batch)
1379 |
1380 | # handle latent genes
1381 |
1382 | if exists(latents) and latents.ndim == 1:
1383 | latents = repeat(latents, 'd -> b d', b = batch)
1384 |
1385 |         # if not returning the predicted flow for the actions, assume training and noise the actions for the loss
1386 |
1387 | if not return_actions_flow and not self.is_critic:
1388 | noise = torch.randn_like(actions)
1389 |
1390 | if self.immiscible_flow:
1391 | assignment = noise_assignment(actions, noise)
1392 | noise = noise[assignment]
1393 |
1394 | flow = actions - noise
1395 | padded_times = rearrange(times, 'b -> b 1 1')
1396 |
1397 | actions = noise.lerp(actions, padded_times)
1398 |
1399 | # actions
1400 |
1401 | time_cond = self.to_time_cond(times)
1402 | action_tokens = self.to_action_tokens(actions)
1403 |
1404 | # handle maybe latents
1405 |
1406 | if exists(latents):
1407 | assert self.can_accept_latent
1408 |
1409 | latent_cond = self.to_latent_cond(latents)
1410 |
1411 | time_cond = time_cond * (latent_cond + 1.)
1412 |
1413 | # register tokens
1414 |
1415 | action_register_tokens = repeat(self.action_register_tokens, '... -> b ...', b = batch)
1416 |
1417 | # take care of maybe recurrent memory tokens
1418 |
1419 | assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
1420 | assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'
1421 |
1422 | if not exists(past_recurrent_memory_tokens):
1423 | past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim))
1424 |
1425 | if self.has_recurrent_memories:
1426 | write_memory_tokens = repeat(self.memory_tokens, 'nm d -> b nm d', b = batch)
1427 | else:
1428 | write_memory_tokens = actions.new_empty((batch, 0, self.dim))
1429 |
1430 | # joint state + additional internal states
1431 |
1432 | joint_state_tokens = self.to_joint_state_tokens(joint_state)
1433 |
1434 | # additional internal state tokens
1435 |
1436 | if not exists(internal_state_tokens):
1437 | internal_state_tokens = joint_state_tokens.new_empty((batch, 0, self.dim_internal_state))
1438 |
1439 | internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens)
1440 |
1441 | # handle memory tokens, both read and write as a tuple of two tensors
1442 |
1443 | memory_tokens = (past_recurrent_memory_tokens, write_memory_tokens)
1444 |
1445 | # mem_length = past_recurrent_memory_tokens.shape[-2] + write_memory_tokens.shape[-2]
1446 |
1447 | # pack into [action registers] [internal + joint states] [actions]
1448 |
1449 | action_tokens, inverse_pack_action_registers = pack_with_inverse([
1450 | action_register_tokens,
1451 | joint_state_tokens,
1452 | internal_state_tokens,
1453 | action_tokens
1454 | ], 'b * d')
1455 |
1456 | action_with_registers_length = action_tokens.shape[-2]
1457 |
1458 | state_tokens = None
1459 |
1460 | if not inferencing:
1461 | # language
1462 |
1463 | labels = token_ids[:, 1:]
1464 |
1465 | language_tokens = self.token_emb(token_ids)
1466 |
1467 | # vision
1468 |
1469 | if exists(self.vit):
1470 | assert images.ndim in {4, 5}
1471 | is_multiple_images = images.ndim == 5
1472 |
1473 | if is_multiple_images:
1474 | images = rearrange(images, 'b c f h w -> b f c h w')
1475 | images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w')
1476 |
1477 | with torch.no_grad():
1478 | self.vit.eval()
1479 | visual_tokens = self.vit(images)
1480 |
1481 | if is_multiple_images:
1482 | visual_tokens, = inverse_pack_image_frames(visual_tokens, '* n d')
1483 | visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')
1484 |
1485 | else:
1486 | assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)'
1487 | visual_tokens = images
1488 |
1489 | visual_tokens = self.maybe_to_image_tokens(visual_tokens)
1490 |
1491 | # maybe reward tokens
1492 |
1493 | if not exists(reward_tokens):
1494 | reward_tokens = visual_tokens.new_empty((batch, 0, self.dim))
1495 |
1496 | # maybe dropout reward tokens
1497 |
1498 | if self.training and random() < self.reward_tokens_dropout_prob:
1499 | reward_tokens = reward_tokens[:, 0:0]
1500 |
1501 | # additional external states
1502 |
1503 | if exists(external_states):
1504 | external_state_tokens = [encode(external_state) for encode, external_state in zip(self.external_state_encoders, external_states)]
1505 |                 external_state_tokens, _ = pack(external_state_tokens, 'b * d')
1506 |
1507 | else:
1508 | external_state_tokens = visual_tokens.new_empty((batch, 0, self.dim))
1509 |
1510 | # concat visual rep with language
1511 |
1512 | state_tokens, inverse_packed_states = pack_with_inverse([
1513 | external_state_tokens,
1514 | visual_tokens,
1515 | language_tokens,
1516 | reward_tokens
1517 | ], 'b * d')
1518 |
1519 |         # take care of masking for variable length states, starting with the language tokens
1520 |
1521 | # which then leads to proper rotary embeddings
1522 |
1523 | command_length = token_ids.shape[-1]
1524 |
1525 | language_mask = token_ids != self.lm_pad_id
1526 |
1527 | if inferencing:
1528 | state_length = cached_state_keys_values[0][0].shape[-2]
1529 | else:
1530 | state_length = state_tokens.shape[-2]
1531 |
1532 | mask = F.pad(language_mask, (state_length - command_length, action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later
1533 |
1534 | # memory
1535 |
1536 | mask = F.pad(mask, (past_recurrent_memory_tokens.shape[-2], write_memory_tokens.shape[-2]), value = True)
1537 |
1538 | # rotary embeddings
1539 |
1540 | seq = mask.float().cumsum(dim = -1)
1541 | rotary_emb = self.rotary_emb(seq)
1542 |
1543 | rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d')
1544 |
1545 | # prepare maybe flex attention
1546 |
1547 | flex_attn_fn = None
1548 |
1549 | if not inferencing and self.use_flex_attn and state_tokens.is_cuda:
1550 |
1551 | prefix_length = state_tokens.shape[-2]
1552 | seq_len = prefix_length + action_tokens.shape[-2]
1553 |
1554 | block_mask = create_block_mask(
1555 | create_pizero_attn_mask(
1556 | prefix_length,
1557 | mask = mask,
1558 | ),
1559 |                 B = batch, H = None, Q_LEN = seq_len,
1560 | KV_LEN = seq_len,
1561 | device = state_tokens.device,
1562 | _compile = True,
1563 | )
1564 |
1565 | score_mod_fn = softclamp_score_mod(self.attn_softclamp_value)
1566 |
1567 | flex_attn_fn = partial(
1568 | flex_attention,
1569 | block_mask = block_mask,
1570 | score_mod = score_mod_fn
1571 | )
1572 |
1573 | # state keys and values for caching during inference
1574 |
1575 | cached_state_key_values_iter = iter(default(cached_state_keys_values, []))
1576 |
1577 | # value residual learning
1578 |
1579 | actions_value_residual = None
1580 |
1581 | # maybe expand residual streams
1582 |
1583 | action_tokens = self.maybe_expand_residuals(action_tokens)
1584 |
1585 | # transformer
1586 |
1587 | if not inferencing:
1588 |
1589 | next_state_cached_keys_values = []
1590 |
1591 | for (
1592 | (attn, state_ff, actions_ff, memories_ff),
1593 | (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale),
1594 | (attn_residual, actions_ff_residual),
1595 | ) in zip(self.layers, self.cond_layers, self.residual_layers):
1596 |
1597 | # joint attention
1598 |
1599 | action_tokens, add_action_residual = attn_residual(action_tokens)
1600 |
1601 | action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
1602 |
1603 | (state_attn_out, actions_attn_out, *maybe_mem_out), (state_keys, state_values, action_keys, action_values) = attn(
1604 | state_tokens,
1605 | action_tokens,
1606 | rotary_emb = rotary_emb,
1607 | flex_attn_fn = flex_attn_fn,
1608 | actions_value_residual = actions_value_residual,
1609 | mask = mask,
1610 | return_keys_values = True,
1611 | memories = memory_tokens
1612 | )
1613 |
1614 | next_state_cached_keys_values.append((state_keys, state_values))
1615 |
1616 | actions_value_residual = default(actions_value_residual, action_values)
1617 |
1618 | action_attn_out = attn_ada_layerscale(actions_attn_out, time_cond)
1619 |
1620 | state_tokens = state_tokens + state_attn_out
1621 | action_tokens = add_action_residual(action_attn_out)
1622 |
1623 | if self.has_recurrent_memories:
1624 | (read_mem_attn_out, write_mem_attn_out), = maybe_mem_out
1625 | read_mem, write_mem = memory_tokens
1626 |
1627 | memory_tokens = (read_mem + read_mem_attn_out, write_mem + write_mem_attn_out)
1628 |
1629 | # state feedforward
1630 |
1631 | state_tokens_out = state_ff(state_tokens)
1632 |
1633 | state_tokens = state_tokens + state_tokens_out
1634 |
1635 | # action feedforward
1636 |
1637 | action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)
1638 |
1639 | action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
1640 |
1641 | action_tokens_out = actions_ff(action_tokens)
1642 |
1643 | action_tokens_out = ff_ada_layerscale(action_tokens_out, time_cond)
1644 |
1645 | action_tokens = add_action_ff_residual(action_tokens_out)
1646 |
1647 | # maybe memory feedforward
1648 |
1649 | if self.has_recurrent_memories:
1650 | memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
1651 |
1652 | memory_tokens = memories_ff(memory_tokens) + memory_tokens
1653 |
1654 | memory_tokens = unpack_memory(memory_tokens)
1655 |
1656 | else:
1657 |
1658 | assert exists(cached_state_keys_values) and len(cached_state_keys_values) > 0
1659 |
1660 | next_state_cached_keys_values = cached_state_keys_values
1661 |
1662 | for (
1663 | (attn, state_ff, actions_ff, memories_ff),
1664 | (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale),
1665 | (attn_residual, actions_ff_residual),
1666 | ) in zip(self.layers, self.cond_layers, self.residual_layers):
1667 |
1668 | # actions attention
1669 |
1670 | action_tokens, add_action_residual = attn_residual(action_tokens)
1671 |
1672 | action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
1673 |
1674 | actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(
1675 | action_tokens,
1676 | cached_state_keys_values = next(cached_state_key_values_iter),
1677 | rotary_emb = rotary_emb,
1678 | mask = mask,
1679 | return_keys_values = True
1680 | )
1681 |
1682 | actions_value_residual = default(actions_value_residual, action_values)
1683 |
1684 | actions_attn_out = attn_ada_layerscale(actions_attn_out, time_cond)
1685 | action_tokens = add_action_residual(actions_attn_out)
1686 |
1687 | # actions feed forward
1688 |
1689 | action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)
1690 |
1691 | action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
1692 |
1693 | action_out = actions_ff(action_tokens)
1694 |
1695 | action_out = ff_ada_layerscale(action_out, time_cond)
1696 |
1697 |                 action_tokens = add_action_ff_residual(action_out)
1698 |
1699 | # maybe memory feed forward
1700 |
1701 | if self.has_recurrent_memories:
1702 | memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
1703 |
1704 | memory_tokens = memories_ff(memory_tokens) + memory_tokens
1705 |
1706 | memory_tokens = unpack_memory(memory_tokens)
1707 |
1708 | # maybe reduce residual streams
1709 |
1710 | action_tokens = self.maybe_reduce_residuals(action_tokens)
1711 |
1712 | if not inferencing:
1713 | # unpack and unembed to predictions
1714 |
1715 | _, visual_tokens, tokens, *_ = inverse_packed_states(state_tokens, 'b * d')
1716 |
1717 | # gemma uses a final softclamp before norm
1718 |
1719 | tokens = self.final_norm_softclamp(tokens)
1720 |
1721 | *_, action_tokens = inverse_pack_action_registers(action_tokens)
1722 |
1723 | action_tokens = self.final_norm_softclamp(action_tokens)
1724 |
1725 | # memories
1726 |
1727 | read_memories, written_memory_tokens = memory_tokens
1728 |
1729 | # writeable memories norm
1730 |
1731 | if self.has_recurrent_memories:
1732 | written_memory_tokens = self.final_norm_write_memories(written_memory_tokens)
1733 |
1734 | # final actions norm
1735 |
1736 | action_embeds = self.final_actions_norm(action_tokens)
1737 |
1738 | # pool the action embeds and project if critic loss
1739 |
1740 | if self.is_critic:
1741 | action_embeds = reduce(action_embeds, 'b n d -> b d', 'mean')
1742 |
1743 | return self.to_critic_value(action_embeds, return_value_and_logits = True)
1744 |
1745 | # validate loss being returned
1746 |
1747 | assert return_language_loss or return_action_flow_loss
1748 |
1749 | # flow loss for actions tokens
1750 |
1751 | pred_actions_flow = self.actions_to_pred_flow(action_embeds)
1752 |
1753 | if return_actions_flow:
1754 |
1755 | if not return_state_keys_values and not record_and_return_memory_tokens:
1756 | return pred_actions_flow
1757 |
1758 | if not return_state_keys_values:
1759 | return pred_actions_flow, written_memory_tokens
1760 |
1761 | return pred_actions_flow, next_state_cached_keys_values
1762 |
1763 | flow_loss = self.zero
1764 |
1765 | if return_action_flow_loss:
1766 | flow_loss = self.loss_fn(pred_actions_flow, flow)
1767 |
1768 | # language cross entropy loss
1769 |
1770 | language_loss = self.zero
1771 |
1772 | if return_language_loss:
1773 | tokens = self.final_norm(tokens)
1774 |
1775 | language_logits = self.state_to_logits(tokens)
1776 |
1777 | language_loss = F.cross_entropy(
1778 | rearrange(language_logits[:, :-1], 'b n l -> b l n'),
1779 | labels,
1780 | ignore_index = self.lm_pad_id
1781 | )
1782 |
1783 | # loss breakdown
1784 |
1785 | loss_breakdown = (language_loss, flow_loss)
1786 |
1787 | # total loss and return breakdown
1788 |
1789 | total_loss = (
1790 | language_loss * self.lm_loss_weight +
1791 | flow_loss * self.flow_loss_weight
1792 | )
1793 |
1794 | if not record_and_return_memory_tokens:
1795 | return total_loss, loss_breakdown
1796 |
1797 | return total_loss, loss_breakdown, written_memory_tokens
1798 |
1799 | # generalized advantage estimate
1800 |
1801 | @torch.no_grad()
1802 | def calc_generalized_advantage_estimate(
1803 | rewards,
1804 | values,
1805 | masks,
1806 | gamma = 0.99,
1807 | lam = 0.95,
1808 | use_accelerated = None
1809 | ):
1810 | use_accelerated = default(use_accelerated, rewards.is_cuda)
1811 |
1812 | values = F.pad(values, (0, 1), value = 0.)
1813 | values, values_next = values[:, :-1], values[:, 1:]
1814 |
1815 | delta = rewards + gamma * values_next * masks - values
1816 | gates = gamma * lam * masks
1817 |
1818 | scan = AssocScan(reverse = True, use_accelerated = use_accelerated)
1819 |
1820 | return scan(gates, delta)
1821 |
1822 | # agent
1823 |
1824 | class Agent(Module):
1825 | def __init__(
1826 | self,
1827 | model: PiZero,
1828 | optim_klass = AdoptAtan2,
1829 | num_latent_genes = 1,
1830 | actor_lr = 3e-4,
1831 | critic_lr = 3e-4,
1832 | actor_weight_decay = 1e-3,
1833 | critic_weight_decay = 1e-3,
1834 | max_grad_norm = 0.5,
1835 | actor_optim_kwargs: dict = dict(),
1836 | critic_optim_kwargs: dict = dict(),
1837 | latent_gene_pool_kwargs: dict = dict(
1838 | frac_tournaments = 0.5
1839 | )
1840 | ):
1841 | super().__init__()
1842 |
1843 | # evolutionary policy optimization related
1844 | # Wang et al. https://web3.arxiv.org/abs/2503.19037
1845 |
1846 | assert num_latent_genes >= 1
1847 | evolutionary_learning = num_latent_genes > 1
1848 |
1849 | dim_latent = model.dim if evolutionary_learning else None
1850 |
1851 | self.latent_gene_pool = LatentGenePool(dim_latent = dim_latent, num_latents = num_latent_genes, **latent_gene_pool_kwargs) if evolutionary_learning else None
1852 | self.has_gene_pool = evolutionary_learning
1853 |
1854 | # init actor critic, taking into account model may not have probabilistic flow to start off with, and determine whether it needs to be reinstantiated for latent conditioning
1855 |
1856 | actor = model
1857 |
1858 | if not model.policy_optimizable or evolutionary_learning:
1859 | actor = model.create_actor(dim_latent = dim_latent)
1860 |
1861 | self.actor = actor
1862 | self.critic = actor.create_critic()
1863 |
1864 | # gradient clipping
1865 |
1866 | self.max_grad_norm = max_grad_norm
1867 |
1868 | # optimizers
1869 |
1870 | self.actor_optim = optim_klass(self.actor.parameters(), lr = actor_lr, weight_decay = actor_weight_decay, **actor_optim_kwargs)
1871 | self.critic_optim = optim_klass(self.critic.parameters(), lr = critic_lr, weight_decay = critic_weight_decay, **critic_optim_kwargs)
1872 |
1873 | def take_genetic_algorithm_step_(self, fitnesses):
1874 | if not self.has_gene_pool:
1875 | return
1876 |
1877 | self.latent_gene_pool.genetic_algorithm_step(fitnesses)
1878 |
1879 | def forward(
1880 | self,
1881 | memories
1882 | ):
1883 | raise NotImplementedError
1884 |
1885 | class EPO(Module):
1886 | def __init__(
1887 | self,
1888 | agent: Agent,
1889 | env,
1890 | accelerate_kwargs: dict = dict()
1891 | ):
1892 | super().__init__()
1893 | self.accelerate = Accelerator(**accelerate_kwargs)
1894 |
1895 | self.agent = agent
1896 | self.env = env
1897 |
1898 | (
1899 | agent.actor,
1900 | agent.critic,
1901 | agent.actor_optim,
1902 | agent.critic_optim
1903 | ) = self.accelerate.prepare(
1904 | agent.actor,
1905 | agent.critic,
1906 | agent.actor_optim,
1907 | agent.critic_optim
1908 | )
1909 |
1910 | self.register_buffer('step', tensor(0))
1911 |
1912 | @property
1913 | def unwrapped_actor(self):
1914 | return self.accelerate.unwrap_model(self.agent.actor)
1915 |
1916 | @property
1917 | def unwrapped_critic(self):
1918 | return self.accelerate.unwrap_model(self.agent.critic)
1919 |
1920 | def log(self, **data_kwargs):
1921 | return self.accelerate.log(data_kwargs, step = self.step.item())
1922 |
1923 | @torch.no_grad()
1924 | def gather_experience_from_env(
1925 | self,
1926 | steps,
1927 | trajectory_length = 16,
1928 | flow_sampling_steps = 4,
1929 | temperature = 1.,
1930 | **sampling_kwargs
1931 | ):
1932 | self.agent.eval()
1933 |
1934 | actor = self.unwrapped_actor
1935 |
1936 | states = self.env.reset()
1937 |
1938 | memories = []
1939 |
1940 | for _ in range(steps):
1941 |
1942 | sampled_actions, replay_tensors = temp_batch_dim(actor)(
1943 | *states,
1944 | trajectory_length = trajectory_length,
1945 | steps = flow_sampling_steps,
1946 | return_states_for_replay = True,
1947 | temperature = temperature,
1948 | **sampling_kwargs
1949 | )
1950 |
1951 | next_states, reward, truncated, terminated = self.env.step(sampled_actions)
1952 |
1953 | memories.append(to_device([*states, reward, terminated, *replay_tensors], torch.device('cpu')))
1954 |
1955 | states = next_states
1956 |
1957 | self.accelerate.wait_for_everyone()
1958 |
1959 | return memories
1960 |
1961 | def learn_agent(
1962 | self,
1963 | memories,
1964 | fitnesses = None,
1965 | epochs = 2,
1966 | batch_size = 16
1967 | ):
1968 | self.agent.train()
1969 |
1970 | (
1971 | images,
1972 | commands,
1973 | joint_state,
1974 | rewards,
1975 | terminated,
1976 | actions,
1977 | timesteps,
1978 | sampled_flows,
1979 | log_probs
1980 | ) = map(torch.stack, zip(*memories))
1981 |
1982 | flow_timesteps = actions.shape[1]
1983 |
1984 | values, _ = self.agent.critic(
1985 | repeat(images, 't ... -> (t ft) ...', ft = flow_timesteps),
1986 | repeat(commands, 't ... -> (t ft) ...', ft = flow_timesteps),
1987 | repeat(joint_state, 't ... -> (t ft) ...', ft = flow_timesteps),
1988 | actions = rearrange(actions, 't ft ... -> (t ft) ...'),
1989 | times = rearrange(timesteps, 't ft ... -> (t ft) ...')
1990 | )
1991 |
1992 | values = rearrange(values, '(t ft) -> ft t', ft = flow_timesteps)
1993 | values = values.detach().cpu()
1994 |
1995 | # actions go out into the environment, rewards are received, generalized advantage calculated with critic values
1996 |
1997 | boundaries = repeat(terminated, 't -> ft t', ft = flow_timesteps)
1998 |
1999 | advantages = calc_generalized_advantage_estimate(rewards, values, boundaries, use_accelerated = False).detach()
2000 |
2001 | # move time back to first dimension to be batched for learning
2002 |
2003 | advantages = rearrange(advantages, 'ft t -> t ft')
2004 | values = rearrange(values, 'ft t -> t ft')
2005 |
2006 | # dataset and dataloader
2007 |
2008 | dataset = TensorDataset(
2009 | images,
2010 | commands,
2011 | joint_state,
2012 | rewards,
2013 | terminated,
2014 | actions,
2015 | timesteps,
2016 | sampled_flows,
2017 | log_probs,
2018 | values,
2019 | advantages
2020 | )
2021 |
2022 | dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True)
2023 |
2024 | # training loop
2025 |
2026 | for _ in range(epochs):
2027 | for (
2028 | images,
2029 | commands,
2030 | joint_state,
2031 | rewards,
2032 | terminated,
2033 | actions,
2034 | timesteps,
2035 | sampled_flows,
2036 | log_probs,
2037 | values,
2038 | advantages
2039 | ) in dataloader:
2040 |
2041 | # optimize policy with replay tensors from above
2042 |
2043 | actor_loss = self.agent.actor.forward_for_policy_loss(
2044 | images,
2045 | commands,
2046 | joint_state,
2047 | actions,
2048 | times = timesteps,
2049 | flow = sampled_flows,
2050 | old_log_probs = log_probs,
2051 | advantages = advantages,
2052 | )
2053 |
2054 | actor_loss.backward()
2055 |
2056 | self.log(actor_loss = actor_loss.item())
2057 |
2058 | self.accelerate.clip_grad_norm_(self.agent.actor.parameters(), self.agent.max_grad_norm)
2059 |
2060 | self.agent.actor_optim.step()
2061 | self.agent.actor_optim.zero_grad()
2062 |
2063 | critic_loss = self.agent.critic.forward_for_critic_loss(
2064 |                     repeat(images, 't ... -> (t ft) ...', ft = flow_timesteps),
2065 |                     repeat(commands, 't ... -> (t ft) ...', ft = flow_timesteps),
2066 |                     repeat(joint_state, 't ... -> (t ft) ...', ft = flow_timesteps),
2067 |                     rearrange(actions, 't ft ... -> (t ft) ...'),
2068 | old_values = values,
2069 | advantages = advantages,
2070 | )
2071 |
2072 | critic_loss.backward()
2073 |
2074 | self.log(critic_loss = critic_loss.item())
2075 |
2076 | self.accelerate.clip_grad_norm_(self.agent.critic.parameters(), self.agent.max_grad_norm)
2077 |
2078 | self.agent.critic_optim.step()
2079 | self.agent.critic_optim.zero_grad()
2080 |
2081 | if exists(fitnesses):
2082 | self.log(fitnesses = fitnesses)
2083 |
2084 | self.agent.take_genetic_algorithm_step_(fitnesses)
2085 |
2086 | self.step.add_(1)
2087 |
2088 | # fun
2089 |
2090 | π0 = PiZero
2091 |
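2092 | # minimal usage sketch (illustrative only, mirroring tests/test_pi_zero.py; the shapes and
2093 | # hyperparameters below are assumptions, not requirements)
2094 | #
2095 | #   model = π0(dim = 32, vit = vit, vit_dim = 32, depth = 1, dim_action_input = 6, dim_joint_state = 12, num_tokens = 32)
2096 | #
2097 | #   loss, (language_loss, flow_loss) = model(images, commands, joint_state, actions)
2098 | #   loss.backward()
2099 | #
2100 | #   # once trained, omit `actions` and pass `trajectory_length` to sample an action trajectory
2101 | #   sampled_actions = model(images, commands, joint_state, trajectory_length = 32)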
--------------------------------------------------------------------------------
/pi_zero_pytorch/tensor_typing.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 | from jaxtyping import (
4 | Float,
5 | Int,
6 | Bool
7 | )
8 |
9 | # jaxtyping is a misnomer, works for pytorch
10 |
11 | class TorchTyping:
12 | def __init__(self, abstract_dtype):
13 | self.abstract_dtype = abstract_dtype
14 |
15 | def __getitem__(self, shapes: str):
16 | return self.abstract_dtype[Tensor, shapes]
17 |
18 | Float = TorchTyping(Float)
19 | Int = TorchTyping(Int)
20 | Bool = TorchTyping(Bool)
21 |
22 | __all__ = [
23 |     'Float',
24 |     'Int',
25 |     'Bool'
26 | ]
27 |
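28 | # example (illustrative): Float['b n d'] expands to jaxtyping's Float[Tensor, 'b n d'],
29 | # i.e. a torch.Tensor of floats with shape (batch, sequence, feature dimension)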
--------------------------------------------------------------------------------
/plot_time_from_beta.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from pi_zero_pytorch.pi_zero import default_sample_times
3 |
4 | times = default_sample_times((10000,), s = 0.9)
5 |
6 | plt.hist(times.numpy())
7 | plt.show()
8 |
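9 | # the histogram shows the empirical distribution of times drawn by default_sample_times
10 | # (presumably the flow matching time schedule used in pi_zero.py) for s = 0.9;
11 | # plt.savefig(...) could be used instead of plt.show() to write the figure to disk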
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "pi-zero-pytorch"
3 | version = "0.1.32"
4 | description = "π0 in Pytorch"
5 | authors = [
6 | { name = "Phil Wang", email = "lucidrains@gmail.com" }
7 | ]
8 | readme = "README.md"
9 | requires-python = ">= 3.9"
10 | license = { file = "LICENSE" }
11 | keywords = [
12 | 'artificial intelligence',
13 | 'deep learning',
14 | 'transformers',
15 | 'flow policy',
16 | 'robotic foundation model',
17 | ]
18 |
19 | classifiers=[
20 | 'Development Status :: 4 - Beta',
21 | 'Intended Audience :: Developers',
22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
23 | 'License :: OSI Approved :: MIT License',
24 | 'Programming Language :: Python :: 3.9',
25 | ]
26 |
27 | dependencies = [
28 | "accelerate>=1.6.0",
29 | "assoc-scan>=0.0.2",
30 | "beartype",
31 | "einx>=0.3.0",
32 | "einops>=0.8.0",
33 | "ema-pytorch>=0.7.3",
34 | "evolutionary-policy-optimization>=0.1.19",
35 | "jaxtyping",
36 | 'hyper-connections>=0.0.10',
37 | "hl-gauss-pytorch>=0.1.21",
38 | "rotary-embedding-torch>=0.8.5",
39 | 'scipy',
40 | "torch>=2.5",
41 | 'torchdiffeq',
42 | 'torchtyping>=0.1.5',
43 | "tqdm"
44 | ]
45 |
46 | [project.urls]
47 | Homepage = "https://pypi.org/project/pi-zero-pytorch/"
48 | Repository = "https://github.com/lucidrains/pi-zero-pytorch"
49 |
50 | [project.optional-dependencies]
51 | examples = []
52 | test = [
53 | "pytest",
54 | "ruff>=0.4.2",
55 | "vit-pytorch>=1.8.7"
56 | ]
57 |
58 | [tool.pytest.ini_options]
59 | pythonpath = [
60 | "."
61 | ]
62 |
63 | [tool.ruff]
64 | line-length = 1000
65 |
66 | lint.ignore = [
67 | "F722", # for jaxtyping shape annotation
68 | "F401",
69 | "F821"
70 | ]
71 |
72 | lint.extend-select = [
73 | "W291"
74 | ]
75 |
76 | [build-system]
77 | requires = ["hatchling"]
78 | build-backend = "hatchling.build"
79 |
80 | [tool.rye]
81 | managed = true
82 | dev-dependencies = []
83 |
84 | [tool.hatch.metadata]
85 | allow-direct-references = true
86 |
87 | [tool.hatch.build.targets.wheel]
88 | packages = ["pi_zero_pytorch"]
89 |
--------------------------------------------------------------------------------
/tests/test_pi_zero.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import torch
4 | from pi_zero_pytorch import π0
5 | from einops import repeat, rearrange
6 |
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 |
9 | @pytest.mark.parametrize('only_vlm', (True, False))
10 | @pytest.mark.parametrize('num_residual_streams', (1, 4))
11 | def test_pi_zero_with_vit(
12 | only_vlm: bool,
13 | num_residual_streams: int,
14 | ):
15 | from vit_pytorch import ViT
16 | from vit_pytorch.extractor import Extractor
17 |
18 | v = ViT(
19 | image_size = 256,
20 | patch_size = 32,
21 | num_classes = 1000,
22 | dim = 32,
23 | depth = 1,
24 | heads = 16,
25 | dim_head = 16,
26 | mlp_dim = 64,
27 | dropout = 0.1,
28 | emb_dropout = 0.1
29 | ).to(device)
30 |
31 | v = Extractor(v, return_embeddings_only = True)
32 |
33 | model = π0(
34 | dim = 32,
35 | vit = v,
36 | vit_dim = 32,
37 | depth = 1,
38 | dim_action_input = 6,
39 | dim_joint_state = 12,
40 | num_tokens = 32,
41 | num_residual_streams = num_residual_streams,
42 | ).to(device)
43 |
44 | images = torch.randn(2, 3, 2, 256, 256)
45 | commands = torch.randint(0, 32, (2, 1024))
46 |
47 | if only_vlm:
48 | vlm_logits = model.forward_only_vision_language(images, commands)
49 | assert vlm_logits.ndim == 3
50 | return
51 |
52 | joint_state = torch.randn(2, 12)
53 | actions = torch.randn(2, 32, 6)
54 |
55 | loss, _ = model(images, commands, joint_state, actions)
56 | loss.backward()
57 |
58 | # after much training
59 |
60 |     sampled_actions = model(images, commands, joint_state, trajectory_length = 32) # (2, 32, 6)
61 |
62 | assert sampled_actions.shape == (2, 32, 6)
63 |
64 | @pytest.mark.parametrize('num_latent_genes', (1, 16))
65 | def test_policy_optimization(
66 | num_latent_genes
67 | ):
68 |
69 | from vit_pytorch import ViT
70 | from vit_pytorch.extractor import Extractor
71 |
72 | from pi_zero_pytorch.pi_zero import (
73 | Agent,
74 | EPO,
75 | )
76 |
77 | from pi_zero_pytorch.mock_env import Env
78 |
79 | v = ViT(
80 | image_size = 256,
81 | patch_size = 32,
82 | num_classes = 1000,
83 | dim = 32,
84 | depth = 1,
85 | heads = 2,
86 | dim_head = 8,
87 | mlp_dim = 16,
88 | dropout = 0.1,
89 | emb_dropout = 0.1
90 | )
91 |
92 | v = Extractor(v, return_embeddings_only = True)
93 |
94 | model = π0(
95 | dim = 32,
96 | vit = v,
97 | vit_dim = 32,
98 | depth = 1,
99 | dim_action_input = 6,
100 | dim_joint_state = 12,
101 | num_tokens = 32,
102 | policy_optimizable = True
103 | ).to(device)
104 |
105 | images = torch.randn(2, 3, 2, 256, 256)
106 | commands = torch.randint(0, 32, (2, 1024))
107 |
108 | joint_state = torch.randn(2, 12)
109 | actions = torch.randn(2, 32, 6)
110 |
111 | loss, _ = model(images, commands, joint_state, actions)
112 | loss.backward()
113 |
114 | # agent
115 |
116 | agent = Agent(
117 | model,
118 | num_latent_genes = num_latent_genes
119 | )
120 |
121 | mock_env = Env((256, 256), 2, 32, 1024, 12)
122 |
123 | epo = EPO(
124 | agent,
125 | mock_env,
126 | accelerate_kwargs = dict(
127 | cpu = True
128 | )
129 | )
130 |
131 | memories = epo.gather_experience_from_env(steps = 10)
132 |
133 | epo.learn_agent(memories, batch_size = 2)
134 |
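135 | # the first test covers the vision-language-only forward, supervised training (language +
136 | # action flow matching losses) and action trajectory sampling with a ViT backbone; the second
137 | # covers the Agent / EPO reinforcement learning loop, optionally with a latent gene pool,
138 | # against the mock environment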
--------------------------------------------------------------------------------