├── .github └── workflows │ ├── lint.yaml │ ├── python-publish.yml │ └── test.yaml ├── .gitignore ├── LICENSE ├── README.md ├── fig3.png ├── pi_zero_pytorch ├── __init__.py ├── mock_env.py ├── pi_zero.py └── tensor_typing.py ├── plot_time_from_beta.py ├── pyproject.toml └── tests └── test_pi_zero.py /.github/workflows/lint.yaml: -------------------------------------------------------------------------------- 1 | name: Ruff 2 | on: [push, pull_request] 3 | 4 | jobs: 5 | build: 6 | 7 | runs-on: ubuntu-latest 8 | 9 | steps: 10 | - uses: actions/checkout@v4 11 | - name: Set up Python 3.10 12 | uses: actions/setup-python@v5 13 | with: 14 | python-version: "3.10" 15 | - name: Install dependencies 16 | run: | 17 | python -m pip install uv 18 | python -m uv pip install ruff 19 | - name: Lint with Ruff 20 | run: | 21 | ruff check pi_zero_pytorch/ 22 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: tests 2 | on: [push, pull_request] 3 | 4 | env: 5 | TYPECHECK: True 6 | 7 | jobs: 8 | test: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Install Python 13 | uses: actions/setup-python@v5 14 | with: 15 | python-version: "3.11" 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install uv 19 | python -m uv pip install --upgrade pip 20 | python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu 21 | python -m uv pip install -e .[test] 22 | - name: Test with pytest 23 | run: | 24 | python -m pytest tests/ 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 
28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/
163 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Phil Wang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | ## pi-zero-pytorch (wip)
4 | 
5 | Implementation of π₀, the robotic foundation model architecture proposed by Physical Intelligence
6 | 
7 | In summary, this work is a simplified Transfusion (Zhou et al.) with influence from Stable Diffusion 3 (Esser et al.): it adopts flow matching in place of diffusion for policy generation (sketched briefly below), and keeps a separate set of parameters for the action stream (joint attention, as in mmDiT). It is built on top of a pretrained vision-language model, PaliGemma 2B.
8 | 
9 | Update: The [official repository](https://github.com/Physical-Intelligence/openpi) has been open sourced!
10 | 
11 | ### Appreciation
12 | 
13 | - [Einops](https://github.com/arogozhnikov/einops) for the amazing [pack and unpack](https://einops.rocks/4-pack-and-unpack/), used extensively here for managing various token sets
14 | 
15 | - [Flex Attention](https://pytorch.org/blog/flexattention/) for allowing for easy mixture of autoregressive and bidirectional attention
16 | 
17 | - [@Wonder1905](https://github.com/Wonder1905) for the code review and identifying issues
18 | 
19 | - You? Maybe a PhD student who wants to contribute to the latest SOTA architecture for behavioral cloning?
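### Flow matching in brief

For intuition, below is a minimal, self-contained sketch of the flow matching objective the model is trained on, mirroring the noising done inside `PiZero.forward`. Everything here is an illustrative stand-in: in the actual model the flow is predicted by the transformer from the noised actions, conditioned on vision, language and joint state, and `sample_actions` recovers an action trajectory by integrating that predicted flow from pure noise with an ODE solver (midpoint method by default).

```python
import torch
import torch.nn.functional as F

# a batch of action chunks: (batch, trajectory length, action dimension)
actions = torch.randn(4, 32, 6)

# pair each chunk with gaussian noise and a random time in [0, 1)
noise = torch.randn_like(actions)
times = torch.rand(4, 1, 1)

# linearly interpolate from noise (t = 0) to data (t = 1)
noised_actions = noise.lerp(actions, times)  # x_t = (1 - t) * noise + t * actions

# the regression target is the constant velocity pointing from noise to data
target_flow = actions - noise

# stand-in for the transformer's prediction of the flow from (x_t, t, observations)
pred_flow = torch.zeros_like(noised_actions)

loss = F.mse_loss(pred_flow, target_flow)
print(loss.item())
```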
20 | 21 | ### Install 22 | 23 | ```bash 24 | $ pip install pi-zero-pytorch 25 | ``` 26 | 27 | ### Usage 28 | 29 | ```python 30 | import torch 31 | from pi_zero_pytorch import π0 32 | 33 | model = π0( 34 | dim = 512, 35 | dim_action_input = 6, 36 | dim_joint_state = 12, 37 | num_tokens = 20_000 38 | ) 39 | 40 | vision = torch.randn(1, 1024, 512) 41 | commands = torch.randint(0, 20_000, (1, 1024)) 42 | joint_state = torch.randn(1, 12) 43 | actions = torch.randn(1, 32, 6) 44 | 45 | loss, _ = model(vision, commands, joint_state, actions) 46 | loss.backward() 47 | 48 | # after much training 49 | 50 | sampled_actions = model(vision, commands, joint_state, trajectory_length = 32) # (1, 32, 6) 51 | ``` 52 | 53 | To do online learning, just wrap the model with the `Agent` class 54 | 55 | ```python 56 | from pi_zero_pytorch import π0, Agent, EPO 57 | 58 | # wrap the model with `Agent`, which will instantiate actor and critic for PPO 59 | 60 | agent = Agent(model) 61 | 62 | # you'll want to supply your own environment 63 | 64 | from pi_zero_pytorch.mock_env import Env 65 | mock_env = Env((256, 256), 2, 32, 1024, 12) 66 | 67 | # pass your agent and environment to EPO for learning to be orchestrated 68 | 69 | epo = EPO(agent, mock_env) 70 | 71 | # gather memories from environment 72 | 73 | memories = epo.gather_experience_from_env(steps = 10) 74 | 75 | # learn from memories 76 | 77 | epo.learn_agent(memories, batch_size = 2) 78 | ``` 79 | 80 | ### Contributing 81 | 82 | At the project root, run 83 | 84 | ```bash 85 | $ pip install '.[test]' # or `uv pip install '.[test]'` 86 | ``` 87 | 88 | Then add your tests to `tests/test_pi_zero.py` and run 89 | 90 | ```bash 91 | $ pytest tests/ 92 | ``` 93 | 94 | That's it 95 | 96 | ### Citation 97 | 98 | ```bibtex 99 | @misc{Black2024, 100 | author = {Kevin Black, Noah Brown, Danny Driess, Adnan Esmail, Michael Equi, Chelsea Finn, Niccolo Fusai, Lachy Groom, Karol Hausman, Brian Ichter, Szymon Jakubczak, Tim Jones, Liyiming Ke, Sergey Levine, Adrian Li-Bell, Mohith Mothukuri, Suraj Nair, Karl Pertsch, Lucy Xiaoyang Shi, James Tanner, Quan Vuong, Anna Walling, Haohuan Wang, Ury Zhilinsky}, 101 | url = {https://www.physicalintelligence.company/download/pi0.pdf} 102 | } 103 | ``` 104 | 105 | ```bibtex 106 | @inproceedings{Zhou2024ValueRL, 107 | title = {Value Residual Learning For Alleviating Attention Concentration In Transformers}, 108 | author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan}, 109 | year = {2024}, 110 | url = {https://api.semanticscholar.org/CorpusID:273532030} 111 | } 112 | ``` 113 | 114 | ```bibtex 115 | @inproceedings{Darcet2023VisionTN, 116 | title = {Vision Transformers Need Registers}, 117 | author = {Timoth'ee Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski}, 118 | year = {2023}, 119 | url = {https://api.semanticscholar.org/CorpusID:263134283} 120 | } 121 | ``` 122 | 123 | ```bibtex 124 | @article{Li2024ImmiscibleDA, 125 | title = {Immiscible Diffusion: Accelerating Diffusion Training with Noise Assignment}, 126 | author = {Yiheng Li and Heyang Jiang and Akio Kodaira and Masayoshi Tomizuka and Kurt Keutzer and Chenfeng Xu}, 127 | journal = {ArXiv}, 128 | year = {2024}, 129 | volume = {abs/2406.12303}, 130 | url = {https://api.semanticscholar.org/CorpusID:270562607} 131 | } 132 | ``` 133 | 134 | ```bibtex 135 | @inproceedings{Sadat2024EliminatingOA, 136 | title = {Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion Models}, 137 | author = {Seyedmorteza Sadat and Otmar Hilliges and 
Romann M. Weber}, 138 | year = {2024}, 139 | url = {https://api.semanticscholar.org/CorpusID:273098845} 140 | } 141 | ``` 142 | 143 | ```bibtex 144 | @article{Bulatov2022RecurrentMT, 145 | title = {Recurrent Memory Transformer}, 146 | author = {Aydar Bulatov and Yuri Kuratov and Mikhail S. Burtsev}, 147 | journal = {ArXiv}, 148 | year = {2022}, 149 | volume = {abs/2207.06881}, 150 | url = {https://api.semanticscholar.org/CorpusID:250526424} 151 | } 152 | ``` 153 | 154 | ```bibtex 155 | @inproceedings{Bessonov2023RecurrentAT, 156 | title = {Recurrent Action Transformer with Memory}, 157 | author = {A. B. Bessonov and Alexey Staroverov and Huzhenyu Zhang and Alexey K. Kovalev and D. Yudin and Aleksandr I. Panov}, 158 | year = {2023}, 159 | url = {https://api.semanticscholar.org/CorpusID:259188030} 160 | } 161 | ``` 162 | 163 | ```bibtex 164 | @article{Zhu2024HyperConnections, 165 | title = {Hyper-Connections}, 166 | author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou}, 167 | journal = {ArXiv}, 168 | year = {2024}, 169 | volume = {abs/2409.19606}, 170 | url = {https://api.semanticscholar.org/CorpusID:272987528} 171 | } 172 | ``` 173 | 174 | ```bibtex 175 | @inproceedings{Sun2025F5RTTSIF, 176 | title = {F5R-TTS: Improving Flow-Matching based Text-to-Speech with Group Relative Policy Optimization}, 177 | author = {Xiaohui Sun and Ruitong Xiao and Jianye Mo and Bowen Wu and Qun Yu and Baoxun Wang}, 178 | year = {2025}, 179 | url = {https://api.semanticscholar.org/CorpusID:277510064} 180 | } 181 | ``` 182 | 183 | ```bibtex 184 | @inproceedings{Wang2025EvolutionaryPO, 185 | title = {Evolutionary Policy Optimization}, 186 | author = {Jianren Wang and Yifan Su and Abhinav Gupta and Deepak Pathak}, 187 | year = {2025}, 188 | url = {https://api.semanticscholar.org/CorpusID:277313729} 189 | } 190 | ``` 191 | 192 | [*dear alice*](https://www.youtube.com/watch?v=z-Ng5ZvrDm4) 193 | -------------------------------------------------------------------------------- /fig3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/pi-zero-pytorch/a766537ea23fd91123cd054955127df6e1bd37c8/fig3.png -------------------------------------------------------------------------------- /pi_zero_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from pi_zero_pytorch.pi_zero import PiZero, π0 2 | -------------------------------------------------------------------------------- /pi_zero_pytorch/mock_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from random import choice 3 | 4 | import torch 5 | from torch import tensor, randn, randint 6 | from torch.nn import Module 7 | 8 | # functions 9 | 10 | def cast_tuple(v): 11 | return v if isinstance(v, tuple) else (v,) 12 | 13 | # mock env 14 | 15 | class Env(Module): 16 | def __init__( 17 | self, 18 | image_shape, 19 | num_images, 20 | num_text_tokens, 21 | max_text_len, 22 | joint_dim, 23 | can_terminate_after = 2 24 | ): 25 | super().__init__() 26 | self.image_shape = image_shape 27 | self.num_images = num_images 28 | self.num_text_tokens = num_text_tokens 29 | self.max_text_len = max_text_len 30 | self.joint_dim = joint_dim 31 | 32 | self.can_terminate_after = can_terminate_after 33 | self.register_buffer('_step', tensor(0)) 34 | 35 | def get_random_state(self): 36 | return ( 37 | randn(3, 
self.num_images, *self.image_shape), 38 | randint(0, self.num_text_tokens, (self.max_text_len,)), 39 | randn(self.joint_dim) 40 | ) 41 | 42 | def reset( 43 | self, 44 | seed = None 45 | ): 46 | self._step.zero_() 47 | return self.get_random_state() 48 | 49 | def step( 50 | self, 51 | actions, 52 | ): 53 | state = self.get_random_state() 54 | reward = randint(-100, 100, ()).float() 55 | 56 | if self._step > self.can_terminate_after: 57 | truncated = tensor(choice((True, False))) 58 | terminated = tensor(choice((True, False))) 59 | else: 60 | truncated = terminated = tensor(False) 61 | 62 | self._step.add_(1) 63 | 64 | return state, reward, truncated, terminated 65 | -------------------------------------------------------------------------------- /pi_zero_pytorch/pi_zero.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from random import random, randrange 4 | 5 | from beartype import beartype 6 | from beartype.typing import Callable, Literal 7 | 8 | from functools import partial, wraps 9 | from itertools import count 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import pi, nn, stack, tensor, is_tensor 14 | from torch.nn import Module, ModuleList 15 | from torch.distributions import Normal 16 | from torch.distributions.beta import Beta 17 | 18 | from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten 19 | 20 | from torch.utils.data import TensorDataset, DataLoader 21 | 22 | from torchdiffeq import odeint 23 | 24 | from scipy.optimize import linear_sum_assignment 25 | 26 | from ema_pytorch import EMA 27 | 28 | from adam_atan2_pytorch import AdoptAtan2 29 | 30 | from rotary_embedding_torch import ( 31 | RotaryEmbedding, 32 | apply_rotary_emb 33 | ) 34 | 35 | import einx 36 | from einops.layers.torch import Rearrange 37 | from einops import rearrange, repeat, reduce, einsum, pack, unpack 38 | 39 | from pi_zero_pytorch.tensor_typing import Float, Int, Bool 40 | 41 | from hyper_connections import HyperConnections 42 | 43 | from hl_gauss_pytorch import HLGaussLayer 44 | 45 | from assoc_scan import AssocScan 46 | 47 | from evolutionary_policy_optimization import LatentGenePool 48 | 49 | import tqdm 50 | 51 | from accelerate import Accelerator 52 | 53 | # ein notation 54 | 55 | # b - batch 56 | # n - sequence 57 | # na - seq of actions 58 | # nt - seq of text tokens 59 | # nv - seq of visual tokens 60 | # ns - seq of additional internal state tokens 61 | # nm - seq of memory tokens 62 | # d - dimension 63 | # da - action dimension 64 | # djs - joint state dimension 65 | # c - image channels 66 | # h - image height 67 | # w - image width 68 | # f - image frames 69 | # s - residual streams (hyper connections paper) 70 | 71 | # token layout for transformer 72 | # vision and language tokens are autoregressive causal mask, actions, interal states + joint bidirectional amongst own tokens, but still autoregressive with respect to other tokens 73 | 74 | # [state token groups] [action token groups] -> [autoregressive masking] [bidirectional] 75 | # [external state] [visual tokens] [language tokens] [maybe reward / condition token] [action registers] [joint state + internal state] [actions] 76 | 77 | # for an attempt to introduce recurrence, all tokens above can be flanked by read and write memory tokens 78 | # [read memory tokens] [...] 
[write memory tokens] 79 | 80 | # constants 81 | 82 | LinearNoBias = partial(nn.Linear, bias = False) 83 | 84 | # flex attention related 85 | # https://pytorch.org/blog/flexattention/ 86 | 87 | flex_attention = None 88 | 89 | if torch.cuda.is_available(): 90 | from torch.nn.attention.flex_attention import flex_attention, create_block_mask 91 | flex_attention = torch.compile(flex_attention) 92 | 93 | def create_pizero_attn_mask( 94 | prefix_causal_length, 95 | mask: Bool['b n'] 96 | ): 97 | # the pi-zero attention is a triangular causal mask, but bidirectional attention for the actions at the very right hand side 98 | 99 | def mask_fn(batch_index, head_index, query_index, key_index): 100 | key_mask = mask[batch_index, key_index] # variable length states 101 | causal_mask = query_index >= key_index # causal 102 | 103 | bidirectional_action_mask = ( # bidirectional action mask 104 | key_index >= prefix_causal_length and 105 | query_index >= prefix_causal_length 106 | ) 107 | 108 | return (key_mask and causal_mask) or bidirectional_action_mask 109 | 110 | return mask_fn 111 | 112 | def softclamp_score_mod(value): 113 | def identity(score, b, h, q, k): 114 | return score 115 | 116 | def softclamped(score, b, h, q, k): 117 | score = score / value 118 | score = torch.tanh(score) 119 | score = score * value 120 | return score 121 | 122 | return softclamped if value > 0. else identity 123 | 124 | # helper functions 125 | 126 | def exists(v): 127 | return v is not None 128 | 129 | def default(v, d): 130 | return v if exists(v) else d 131 | 132 | def identity(t): 133 | return t 134 | 135 | def xnor(x, y): 136 | return not (x ^ y) 137 | 138 | def maybe(fn): 139 | @wraps(fn) 140 | def inner(t, *args, **kwargs): 141 | if not exists(t): 142 | return None 143 | 144 | return fn(t, *args, **kwargs) 145 | 146 | return inner 147 | 148 | def save_args_kwargs(fn): 149 | @wraps(fn) 150 | def decorated(self, *args, **kwargs): 151 | self._init_args_kwargs = (args, kwargs) 152 | return fn(self, *args, **kwargs) 153 | 154 | return decorated 155 | 156 | def to_device(t, device): 157 | return tree_map(lambda el: el.to(device) if is_tensor(el) else el, t) 158 | 159 | def move_input_tensors_to_device(fn): 160 | 161 | @wraps(fn) 162 | def decorated_fn(self, *args, **kwargs): 163 | args, kwargs = to_device((args, kwargs), self.device) 164 | return fn(self, *args, **kwargs) 165 | 166 | return decorated_fn 167 | 168 | def temp_batch_dim(fn): 169 | 170 | @wraps(fn) 171 | def inner(*args, **kwargs): 172 | args, kwargs = tree_map(lambda t: rearrange(t, '... -> 1 ...') if is_tensor(t) else t, (args, kwargs)) 173 | 174 | out = fn(*args, **kwargs) 175 | 176 | out = tree_map(lambda t: rearrange(t, '1 ... 
-> ...') if is_tensor(t) else t, out) 177 | return out 178 | 179 | return inner 180 | 181 | # tensor helpers 182 | 183 | def log(t, eps = 1e-20): 184 | return t.clamp(min = eps).log() 185 | 186 | def l2norm(t, dim = -1): 187 | return F.normalize(t, dim = dim) 188 | 189 | def softclamp(t, value): 190 | if value <= 0.: 191 | return t 192 | 193 | return (t / value).tanh() * value 194 | 195 | def max_neg_value(t): 196 | return -torch.finfo(t.dtype).max 197 | 198 | def pack_with_inverse(t, pattern): 199 | packed, packed_shape = pack(t, pattern) 200 | 201 | def inverse(out, inv_pattern = None): 202 | inv_pattern = default(inv_pattern, pattern) 203 | out = unpack(out, packed_shape, inv_pattern) 204 | return out 205 | 206 | return packed, inverse 207 | 208 | def pack_one_with_inverse(t, pattern): 209 | packed, inverse = pack_with_inverse([t], pattern) 210 | 211 | def inverse_one(out, inv_pattern = None): 212 | out, = inverse(out, inv_pattern) 213 | return out 214 | 215 | return packed, inverse_one 216 | 217 | def tree_flatten_with_inverse(input): 218 | out, tree_spec = tree_flatten(input) 219 | 220 | def inverse(output): 221 | return tree_unflatten(output, tree_spec) 222 | 223 | return out, inverse 224 | 225 | def project(x, y): 226 | x, inverse = pack_one_with_inverse(x, 'b *') 227 | y, _ = pack_one_with_inverse(y, 'b *') 228 | 229 | dtype = x.dtype 230 | x, y = x.double(), y.double() 231 | unit = l2norm(y, dim = -1) 232 | 233 | parallel = (x * unit).sum(dim = -1, keepdim = True) * unit 234 | orthogonal = x - parallel 235 | 236 | return inverse(parallel).to(dtype), inverse(orthogonal).to(dtype) 237 | 238 | def pad_at_dim( 239 | t, 240 | pad: tuple[int, int], 241 | *, 242 | dim = -1, 243 | value = 0. 244 | ): 245 | dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) 246 | zeros = ((0, 0) * dims_from_right) 247 | return F.pad(t, (*zeros, *pad), value = value) 248 | 249 | # flow related 250 | 251 | def default_sample_times( 252 | shape, 253 | s = 0.999, 254 | alpha = 1.5, 255 | beta = 1, 256 | device = None 257 | ): 258 | """ they propose to sample times from Beta distribution - last part of appendix part B """ 259 | 260 | alpha = torch.full(shape, alpha, device = device) 261 | beta = torch.full(shape, beta, device = device) 262 | sampled = Beta(alpha, beta).sample() 263 | return (1. - sampled) * s 264 | 265 | def noise_assignment(data, noise): 266 | device = data.device 267 | data, noise = tuple(rearrange(t, 'b ... -> b (...)') for t in (data, noise)) 268 | dist = torch.cdist(data, noise) 269 | _, assign = linear_sum_assignment(dist.cpu()) 270 | return torch.from_numpy(assign).to(device) 271 | 272 | # policy optimization related 273 | 274 | class GaussianNLL(Module): 275 | def forward(self, mu_sigma, target): 276 | mean, variance = mu_sigma.unbind(dim = -1) 277 | return F.gaussian_nll_loss(mean, target, variance) 278 | 279 | class LinearToMeanStd(Module): 280 | def __init__( 281 | self, 282 | dim, 283 | dim_out, 284 | eps = 1e-5 285 | ): 286 | super().__init__() 287 | self.linear = LinearNoBias(dim, dim_out * 2) 288 | self.eps = eps 289 | 290 | def forward(self, embed): 291 | out = self.linear(embed) 292 | 293 | mean, log_variance = rearrange(out, '... (d mu_sigma) -> mu_sigma ... 
d', mu_sigma = 2) 294 | variance = log_variance.exp() 295 | std = variance.clamp(min = self.eps).sqrt() 296 | 297 | return stack((mean, std), dim = -1) 298 | 299 | # attention 300 | 301 | class Attention(Module): 302 | @beartype 303 | def __init__( 304 | self, 305 | dim, 306 | dim_head = 64, 307 | heads = 8, 308 | dropout = 0., 309 | softclamp_value = 50., 310 | accept_memories = False, 311 | learned_value_action_residual_mix = False, 312 | rotary_emb: RotaryEmbedding | None = None 313 | ): 314 | super().__init__() 315 | self.scale = dim_head ** -0.5 316 | dim_inner = dim_head * heads 317 | 318 | self.rotary_emb = rotary_emb 319 | 320 | self.split_heads = Rearrange('b n (h d) -> b h n d', h = heads) 321 | self.merge_heads = Rearrange('b h n d -> b n (h d)') 322 | 323 | self.rmsnorm = nn.RMSNorm(dim) 324 | 325 | # state parameters 326 | 327 | self.to_qkv = LinearNoBias(dim, 3 * dim_inner) 328 | self.to_out = LinearNoBias(dim_inner, dim) 329 | 330 | # maybe memory parameters 331 | 332 | self.accept_memories = accept_memories 333 | 334 | self.mem_rmsnorm = nn.RMSNorm(dim) if accept_memories else None 335 | self.to_mem_qkv = LinearNoBias(dim, 3 * dim_inner) if accept_memories else None 336 | self.to_mem_out = LinearNoBias(dim_inner, dim) if accept_memories else None 337 | 338 | # action parameters 339 | 340 | self.to_actions_qkvg = LinearNoBias(dim, 4 * dim_inner) 341 | 342 | self.to_action_value_residual_mix = nn.Sequential( 343 | LinearNoBias(dim, heads), 344 | nn.Sigmoid(), 345 | Rearrange('b n h -> b h n 1') 346 | ) if learned_value_action_residual_mix else (lambda _: 0.5) 347 | 348 | self.to_actions_out = LinearNoBias(dim_inner, dim) 349 | 350 | self.softclamp_value = softclamp_value 351 | 352 | def forward_actions_with_cached_state( 353 | self, 354 | actions, 355 | cached_state_keys_values: tuple[Tensor, Tensor], 356 | memories: tuple[Tensor, Tensor] | None = None, 357 | rotary_emb = None, 358 | mask: Bool['b n'] | None = None, 359 | actions_value_residual: Tensor | None = None, 360 | return_keys_values = False, 361 | flex_attn_fn: Callable | None = None 362 | ): 363 | aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1) 364 | 365 | aq, ak, av, ag = tuple(self.split_heads(t) for t in (aq, ak, av, ag)) 366 | 367 | if exists(actions_value_residual): 368 | mix = self.to_action_value_residual_mix(actions) 369 | av = av * mix + actions_value_residual * (1. 
- mix) 370 | 371 | q = aq 372 | mk, mv = cached_state_keys_values 373 | 374 | # concat cache key / values with action key / values 375 | 376 | k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av))) 377 | 378 | # handle read, write memories 379 | 380 | assert not (self.accept_memories ^ exists(memories)) 381 | 382 | if exists(memories): 383 | _, write_memories = memories 384 | write_memories = self.mem_rmsnorm(write_memories) 385 | # mqkv_write = self.to_mem_qkv(write_memories) 386 | 387 | if exists(rotary_emb): 388 | q = apply_rotary_emb(rotary_emb, q, freqs_seq_dim = -2) 389 | k = apply_rotary_emb(rotary_emb, k) 390 | 391 | elif exists(self.rotary_emb): 392 | q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k) 393 | 394 | # attention 395 | 396 | if exists(flex_attn_fn): 397 | out = flex_attn_fn(q, k, v) 398 | else: 399 | q = q * self.scale 400 | 401 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j') 402 | 403 | sim = softclamp(sim, self.softclamp_value) 404 | 405 | if exists(mask): 406 | sim = einx.where('b j, b h i j, -> b h i j', mask, sim, max_neg_value(sim)) 407 | 408 | attn = sim.softmax(dim = -1) 409 | 410 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d') 411 | 412 | # gate 413 | 414 | out = out * ag.sigmoid() 415 | 416 | # merge attention heads 417 | 418 | out = self.merge_heads(out) 419 | 420 | actions_out = self.to_actions_out(out) 421 | 422 | if not return_keys_values: 423 | return actions_out 424 | 425 | return actions_out, (mk, mv, ak, av) 426 | 427 | def forward_only_vision_language( 428 | self, 429 | state: Float['b n d'], 430 | rotary_emb = None 431 | ) -> Float['b n d']: 432 | 433 | device = state.device 434 | 435 | q, k, v = self.to_qkv(state).chunk(3, dim = -1) 436 | 437 | q, k, v = tuple(self.split_heads(t) for t in (q, k, v)) 438 | 439 | if exists(rotary_emb): 440 | q = apply_rotary_emb(rotary_emb, q) 441 | k = apply_rotary_emb(rotary_emb, k) 442 | 443 | elif exists(self.rotary_emb): 444 | q = self.rotary_emb.rotate_queries_or_keys(q) 445 | k = self.rotary_emb.rotate_queries_or_keys(k) 446 | 447 | # attention 448 | 449 | q = q * self.scale 450 | 451 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j') 452 | 453 | sim = softclamp(sim, self.softclamp_value) 454 | 455 | causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1) 456 | 457 | sim = sim.masked_fill(causal_mask, max_neg_value(sim)) 458 | 459 | attn = sim.softmax(dim = -1) 460 | 461 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d') 462 | 463 | # merge attention heads 464 | 465 | out = self.merge_heads(out) 466 | 467 | return self.to_out(out) 468 | 469 | def forward( 470 | self, 471 | multimodal_seq, 472 | actions, 473 | rotary_emb = None, 474 | memories: tuple[Tensor, Tensor] | None = None, 475 | mask: Bool['b n'] | None = None, 476 | actions_value_residual: Tensor | None = None, 477 | return_keys_values = False, 478 | flex_attn_fn: Callable | None = None 479 | ): 480 | seq_len, device = multimodal_seq.shape[-2], multimodal_seq.device 481 | 482 | multimodal_seq = self.rmsnorm(multimodal_seq) 483 | 484 | # separate projections for multimodal seq vs actions 485 | 486 | mq, mk, mv = self.to_qkv(multimodal_seq).chunk(3, dim = -1) 487 | 488 | aq, ak, av, ag = self.to_actions_qkvg(actions).chunk(4, dim = -1) 489 | 490 | mq, mk, mv, aq, ak, av, ag = tuple(self.split_heads(t) for t in (mq, mk, mv, aq, ak, av, ag)) 491 | 492 | if exists(actions_value_residual): 493 | mix = self.to_action_value_residual_mix(actions) 494 | av = av * mix + 
actions_value_residual * (1. - mix) 495 | 496 | q, k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mq, mk, mv), (aq, ak, av))) 497 | 498 | # handle read, write memories 499 | 500 | has_memories = exists(memories) and any([m.numel() > 0 for m in memories]) 501 | 502 | assert not (self.accept_memories ^ has_memories) 503 | 504 | if has_memories: 505 | memories, unpack_memories = pack_with_inverse(memories, 'b * d') 506 | memories = self.mem_rmsnorm(memories) 507 | mqkv = self.to_mem_qkv(memories) 508 | mqkv_read, mqkv_write = unpack_memories(mqkv, 'b * d') 509 | 510 | mqr, mkr, mvr, mqw, mkw, mvw = tuple(self.split_heads(t) for t in (*mqkv_read.chunk(3, dim = -1), *mqkv_write.chunk(3, dim = -1))) 511 | 512 | k = torch.cat((mkr, k, mkw), dim = -2) 513 | v = torch.cat((mvr, v, mvw), dim = -2) 514 | q, attn_output_unpack_memories = pack_with_inverse((mqr, q, mqw), 'b h * d') 515 | 516 | # rotary embedding 517 | 518 | if exists(rotary_emb): 519 | q = apply_rotary_emb(rotary_emb, q) 520 | k = apply_rotary_emb(rotary_emb, k) 521 | elif exists(self.rotary_emb): 522 | q = self.rotary_emb.rotate_queries_or_keys(q) 523 | k = self.rotary_emb.rotate_queries_or_keys(k) 524 | 525 | if exists(flex_attn_fn): 526 | out = flex_attn_fn(q, k, v) 527 | 528 | else: 529 | # attention 530 | 531 | q = q * self.scale 532 | 533 | sim = einsum(q, k, 'b h i d, b h j d -> b h i j') 534 | 535 | sim = softclamp(sim, self.softclamp_value) 536 | 537 | causal_mask = torch.ones(sim.shape[-2:], dtype = torch.bool, device = device).triu(1) 538 | 539 | if exists(mask): 540 | causal_mask = einx.logical_or('b j, i j -> b 1 i j', ~mask, causal_mask) 541 | 542 | causal_mask[..., seq_len:, seq_len:] = False # actions have bidirectional attention, lining up with Transfusion paper 543 | 544 | sim = sim.masked_fill(causal_mask, max_neg_value(sim)) 545 | 546 | attn = sim.softmax(dim = -1) 547 | 548 | out = einsum(attn, v, 'b h i j, b h j d -> b h i d') 549 | 550 | # gating of values, used in alphafold line of work 551 | 552 | gates = pad_at_dim(ag.sigmoid(), (out.shape[-2] - ag.shape[-2], 0), value = 1., dim = -2) 553 | 554 | out = out * gates 555 | 556 | # split out memories 557 | 558 | if self.accept_memories: 559 | mem_read_out, out, mem_write_out = attn_output_unpack_memories(out) 560 | 561 | # merge attention heads 562 | 563 | out = self.merge_heads(out) 564 | 565 | # separate projections for multimodal seq vs actions 566 | 567 | mout, aout = out[:, :seq_len], out[:, seq_len:] 568 | 569 | output = self.to_out(mout), self.to_actions_out(aout) 570 | 571 | if self.accept_memories: 572 | mem_out, unpack_memories = pack_with_inverse((mem_read_out, mem_write_out), 'b h * d') 573 | mem_out = self.merge_heads(mem_out) 574 | mem_out = self.to_mem_out(mem_out) 575 | 576 | output = (*output, unpack_memories(mem_out, 'b * d')) 577 | 578 | if not return_keys_values: 579 | return output 580 | 581 | return output, (mk, mv, ak, av) 582 | 583 | # attention 584 | 585 | class SwiGLUFeedForward(Module): 586 | def __init__( 587 | self, 588 | dim, 589 | expand_factor = 4., 590 | dim_inner = None, 591 | rmsnorm = True 592 | ): 593 | super().__init__() 594 | dim_inner = default(dim_inner, int(dim * expand_factor * 2 / 3)) 595 | 596 | self.rmsnorm = nn.RMSNorm(dim) if rmsnorm else nn.Identity() 597 | self.proj_in = LinearNoBias(dim, dim_inner * 2) 598 | self.proj_out = LinearNoBias(dim_inner, dim) 599 | 600 | def forward( 601 | self, 602 | seq 603 | ): 604 | seq = self.rmsnorm(seq) 605 | seq, gates = self.proj_in(seq).chunk(2, dim = -1) 606 | 
seq = seq * F.gelu(gates) 607 | return self.proj_out(seq) 608 | 609 | # actions need time conditioning 610 | # ada-ln zero from DiT - here we will improvise with adaptive rmsnorm 611 | 612 | class RandomFourierEmbed(Module): 613 | def __init__(self, dim): 614 | super().__init__() 615 | self.proj = nn.Linear(1, dim) 616 | self.proj.requires_grad_(False) 617 | 618 | def forward( 619 | self, 620 | times, 621 | ): 622 | times = rearrange(times, '... -> ... 1') 623 | rand_proj = self.proj(times) 624 | return torch.cos(2 * pi * rand_proj) 625 | 626 | class AdaptiveRMSNorm(Module): 627 | def __init__( 628 | self, 629 | dim, 630 | dim_cond 631 | ): 632 | super().__init__() 633 | self.norm = nn.RMSNorm(dim, elementwise_affine = False) 634 | 635 | self.to_gamma = nn.Sequential( 636 | nn.Linear(dim_cond, dim), 637 | nn.Sigmoid() 638 | ) 639 | 640 | self.to_beta = LinearNoBias(dim_cond, dim) 641 | 642 | def forward(self, actions, cond): 643 | 644 | if cond.ndim == 2: 645 | cond = rearrange(cond, 'b d -> b 1 d') 646 | 647 | normed = self.norm(actions) 648 | gamma = self.to_gamma(cond) 649 | beta = self.to_beta(cond) 650 | return normed * gamma + beta 651 | 652 | class AdaptiveLayerscale(Module): 653 | def __init__( 654 | self, 655 | dim, 656 | dim_cond, 657 | adaln_zero_bias_init_value = -2. 658 | ): 659 | super().__init__() 660 | adaln_zero_gamma_linear = nn.Linear(dim_cond, dim) 661 | nn.init.zeros_(adaln_zero_gamma_linear.weight) 662 | nn.init.constant_(adaln_zero_gamma_linear.bias, adaln_zero_bias_init_value) 663 | 664 | self.to_adaln_zero_gamma = adaln_zero_gamma_linear 665 | 666 | def forward(self, actions, cond): 667 | 668 | if cond.ndim == 2: 669 | cond = rearrange(cond, 'b d -> b 1 d') 670 | 671 | gamma = self.to_adaln_zero_gamma(cond) 672 | return actions * gamma.sigmoid() 673 | 674 | # main class 675 | 676 | class PiZero(Module): 677 | @beartype 678 | @save_args_kwargs 679 | def __init__( 680 | self, 681 | dim, 682 | num_tokens, 683 | dim_action_input, 684 | dim_joint_state, 685 | dim_time_cond = None, 686 | depth = 12, 687 | dim_head = 64, 688 | heads = 8, 689 | use_flex_attn = False, 690 | ff_expand_factor = 4., 691 | attn_softclamp_value = 50., 692 | final_norm_softclamp_value = 30., 693 | vit: Module | None = None, 694 | vit_dim = None, 695 | external_state_encoders: Module | list[Module] | None = None, 696 | dim_internal_state: int | None = None, 697 | num_action_register_tokens = 4, 698 | attn_kwargs: dict = dict(), 699 | ff_kwargs: dict = dict(), 700 | lm_pad_id = -1, 701 | lm_loss_weight = 1., 702 | flow_loss_weight = 1., 703 | immiscible_flow = False, # https://arxiv.org/abs/2406.12303 704 | sample_times_fn = default_sample_times, 705 | reward_tokens_dropout_prob = 0., 706 | num_recurrent_memory_tokens = 0, 707 | num_residual_streams = 1, 708 | dim_latent = None, 709 | policy_optimizable = False, # if set to True, will use mean variance network for access to log prob 710 | is_critic = False, # whether this model is used as the critic, with the histogram classification loss from Imani et al. 
https://arxiv.org/html/2402.13425v1 711 | critic_value_kwargs: dict = dict( 712 | min_value = -10., 713 | max_value = 10., 714 | num_bins = 50 715 | ), 716 | odeint_kwargs: dict = dict( 717 | atol = 1e-5, 718 | rtol = 1e-5, 719 | method = 'midpoint' 720 | ), 721 | ): 722 | super().__init__() 723 | dim_time_cond = default(dim_time_cond, dim * 2) 724 | 725 | self.dim = dim 726 | 727 | # flex attention related 728 | 729 | assert not (use_flex_attn and not exists(flex_attention)), 'flex attention cannot be used' 730 | self.use_flex_attn = use_flex_attn 731 | self.attn_softclamp_value = attn_softclamp_value 732 | 733 | # vit 734 | 735 | self.vit = vit 736 | 737 | self.maybe_to_image_tokens = nn.Linear(vit_dim, dim) if exists(vit_dim) and vit_dim != dim else nn.Identity() 738 | 739 | # embedding 740 | 741 | self.token_emb = nn.Embedding(num_tokens, dim) 742 | 743 | # internal states 744 | 745 | self.to_joint_state_tokens = nn.Linear(dim_joint_state, dim) 746 | 747 | self.dim_internal_state = default(dim_internal_state, dim) 748 | self.to_internal_state_tokens = nn.Linear(dim_internal_state, dim) if exists(dim_internal_state) else nn.Identity() 749 | 750 | # additional external states 751 | 752 | external_state_encoders = default(external_state_encoders, []) 753 | self.external_state_encoders = ModuleList(external_state_encoders) 754 | 755 | # actions 756 | 757 | self.dim_action_input = dim_action_input 758 | 759 | self.action_register_tokens = nn.Parameter(torch.zeros(num_action_register_tokens, dim)) 760 | nn.init.normal_(self.action_register_tokens, std = 0.02) 761 | 762 | self.to_action_tokens = nn.Linear(dim_action_input, dim) 763 | 764 | # time conditioning 765 | 766 | self.to_time_cond = nn.Sequential( 767 | RandomFourierEmbed(dim), 768 | nn.Linear(dim, dim_time_cond), 769 | nn.SiLU(), 770 | ) 771 | 772 | # latent variable / gene conditioning 773 | 774 | can_accept_latent = exists(dim_latent) 775 | self.can_accept_latent = can_accept_latent 776 | 777 | if can_accept_latent: 778 | self.to_latent_cond = nn.Sequential( 779 | nn.Linear(dim_latent, dim_time_cond * 2), 780 | nn.SiLU(), 781 | nn.Linear(dim_time_cond * 2, dim_time_cond), 782 | ) 783 | 784 | nn.init.zeros_(self.to_latent_cond[-1].weight) 785 | nn.init.zeros_(self.to_latent_cond[-1].bias) 786 | 787 | # positional embedding 788 | 789 | self.rotary_emb = RotaryEmbedding(dim_head) 790 | 791 | # recurrent memory parameters and logic 792 | 793 | self.has_recurrent_memories = num_recurrent_memory_tokens > 0 794 | 795 | self.memory_tokens = nn.Parameter(torch.zeros(num_recurrent_memory_tokens, dim)) 796 | nn.init.normal_(self.memory_tokens, std = 0.02) 797 | 798 | self.final_norm_write_memories = nn.RMSNorm(dim) if self.has_recurrent_memories else None 799 | 800 | # residual functions, with maybe hyper connections 801 | 802 | assert num_residual_streams >= 1 803 | init_residual_fn, self.maybe_expand_residuals, self.maybe_reduce_residuals = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1) 804 | 805 | residual_fns = [] 806 | counter = count() 807 | 808 | # attention and feedforward 809 | 810 | layers = [] 811 | cond_layers = [] 812 | 813 | for i in range(depth): 814 | is_first_block = i == 0 815 | 816 | layers.append(ModuleList([ 817 | Attention(dim = dim, dim_head = dim_head, heads = heads, accept_memories = self.has_recurrent_memories, learned_value_action_residual_mix = not is_first_block, **attn_kwargs), 818 | SwiGLUFeedForward(dim = dim, expand_factor = 
ff_expand_factor, **ff_kwargs), 819 | SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs), 820 | SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None 821 | ])) 822 | 823 | residual_fns.append(ModuleList([ 824 | init_residual_fn(dim = dim, layer_index = next(counter)), 825 | init_residual_fn(dim = dim, layer_index = next(counter)), 826 | ])) 827 | 828 | cond_layers.append(ModuleList([ 829 | AdaptiveRMSNorm(dim, dim_time_cond), 830 | AdaptiveLayerscale(dim, dim_time_cond), 831 | AdaptiveRMSNorm(dim, dim_time_cond), 832 | AdaptiveLayerscale(dim, dim_time_cond) 833 | ])) 834 | 835 | self.layers = ModuleList(layers) 836 | self.cond_layers = ModuleList(cond_layers) 837 | 838 | self.residual_layers = ModuleList(residual_fns) 839 | 840 | self.final_norm_softclamp = partial(softclamp, value = final_norm_softclamp_value) 841 | 842 | self.final_norm = nn.RMSNorm(dim) 843 | self.final_actions_norm = nn.RMSNorm(dim) 844 | 845 | # unembedding 846 | 847 | self.state_to_logits = LinearNoBias(dim, num_tokens) 848 | 849 | # actor related 850 | 851 | self.actions_to_pred_flow = None 852 | self.loss_fn = None 853 | 854 | if not is_critic: 855 | if not policy_optimizable: 856 | self.actions_to_pred_flow = LinearNoBias(dim, dim_action_input) 857 | self.loss_fn = nn.MSELoss() 858 | else: 859 | self.actions_to_pred_flow = LinearToMeanStd(dim, dim_action_input) 860 | self.loss_fn = GaussianNLL() 861 | 862 | self.is_mean_std_output = policy_optimizable 863 | self.policy_optimizable = policy_optimizable 864 | 865 | # critic related 866 | 867 | self.is_critic = is_critic 868 | 869 | self.to_critic_value = HLGaussLayer( 870 | dim, 871 | hl_gauss_loss = critic_value_kwargs 872 | ) 873 | 874 | # the language token id padding id, for fine-tuning as well as taking care of the masking on top of causal mask 875 | 876 | self.lm_pad_id = lm_pad_id 877 | 878 | # flow related 879 | 880 | self.immiscible_flow = immiscible_flow 881 | 882 | # reward classifier free guidance 883 | 884 | self.reward_tokens_dropout_prob = reward_tokens_dropout_prob 885 | 886 | # time sampling related 887 | 888 | self.sample_times_fn = default(sample_times_fn, torch.rand) 889 | 890 | # loss related 891 | 892 | self.lm_loss_weight = lm_loss_weight 893 | self.flow_loss_weight = flow_loss_weight 894 | 895 | # sampling related 896 | 897 | self.odeint_fn = partial(odeint, **odeint_kwargs) 898 | 899 | self.register_buffer('zero', torch.tensor(0.), persistent = False) 900 | 901 | # tensor typing 902 | 903 | self._nm = num_recurrent_memory_tokens 904 | 905 | @property 906 | def can_cfg(self): 907 | return self.reward_tokens_dropout_prob > 0. 
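    # note: `can_cfg` gates classifier free guidance over the reward tokens. guidance is
    # only usable if the model was trained with `reward_tokens_dropout_prob > 0.`, since
    # `forward_with_reward_cfg` needs an additional reward-dropped forward pass to form the
    # guided flow, roughly: flow_with_reward + cond_scale * (flow_with_reward - flow_without_reward)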
908 | 909 | @property 910 | def device(self): 911 | return next(self.parameters()).device 912 | 913 | @beartype 914 | def load_pretrained_vlm_weights_( 915 | self, 916 | weights: dict[str, Tensor] 917 | ): 918 | raise NotImplementedError 919 | 920 | def create_ema( 921 | self, 922 | beta = 0.99, 923 | **ema_kwargs 924 | ) -> EMA: 925 | 926 | ema_pi_zero = EMA( 927 | self, 928 | beta = beta, 929 | include_online_model = False, 930 | forward_method_names = ( 931 | 'sample_actions', 932 | ), 933 | **ema_kwargs 934 | ) 935 | 936 | return ema_pi_zero 937 | 938 | def create_actor(self, **kwargs) -> PiZero: 939 | assert not self.is_critic, 'base model must not be a critic' 940 | 941 | # make probabilistic flow if not already 942 | 943 | if not self.policy_optimizable: 944 | assert 'policy_optimizable' not in kwargs 945 | kwargs.update(policy_optimizable = True) 946 | 947 | orig_args, orig_kwargs = self._init_args_kwargs 948 | actor = PiZero(*orig_args, **orig_kwargs, **kwargs) 949 | 950 | # load all possible shared parameters except for output head to logits (for histogram loss) 951 | 952 | state_dict = self.state_dict() 953 | actor.load_state_dict(state_dict, strict = False) 954 | 955 | # now, initialize the actor with variance of 1. 956 | # https://arxiv.org/abs/2302.08875 957 | 958 | if not self.policy_optimizable: 959 | orig_mean_weight = self.actions_to_pred_flow.weight 960 | 961 | actor_mean_std_weight = actor.actions_to_pred_flow.linear.weight 962 | 963 | actor_mean_std_weight.data.copy_(rearrange([orig_mean_weight, torch.zeros_like(orig_mean_weight)], 'mu_sigma o i -> (o mu_sigma) i')) 964 | 965 | return actor.to(self.device) 966 | 967 | def create_critic(self, **kwargs) -> PiZero: 968 | assert self.policy_optimizable and not self.is_critic, 'base model must be policy optimizable as well as not a critic already' 969 | 970 | assert 'is_critic' not in kwargs 971 | kwargs.update(is_critic = True) 972 | 973 | orig_args, orig_kwargs = self._init_args_kwargs 974 | critic = PiZero(*orig_args, **orig_kwargs, **kwargs) 975 | 976 | # load all possible shared parameters except for output head to logits (for histogram loss) 977 | 978 | state_dict = self.state_dict() 979 | critic.load_state_dict(state_dict, strict = False) 980 | 981 | return critic.to(self.device) 982 | 983 | @torch.no_grad() 984 | def sample_actions( 985 | self, 986 | images, 987 | token_ids, 988 | joint_states, 989 | trajectory_length: int, 990 | latents: Float['d'] | Float['b d'] = None, 991 | reward_tokens: Float['b d'] | None = None, 992 | internal_state_tokens: Float['b ns d'] | None = None, 993 | steps = 18, 994 | show_pbar = True, 995 | cond_scale = 0., 996 | temperature = 1., 997 | remove_parallel_component = True, 998 | keep_parallel_frac = 0., 999 | cache_kv = True, 1000 | return_states_for_replay = False, 1001 | critic: Module | None = None, 1002 | ): 1003 | assert not self.is_critic 1004 | 1005 | batch_size = token_ids.shape[0] 1006 | 1007 | was_training = self.training 1008 | self.eval() 1009 | 1010 | pbar = tqdm.tqdm(desc = 'sampling action trajectory', disable = not show_pbar, total = steps) 1011 | 1012 | # accumulate log probs for ppo 1013 | 1014 | assert not (return_states_for_replay and not self.is_mean_std_output), 'only pi-zero with `policy_optimizable` turned on can return log probs' 1015 | 1016 | timesteps = [] 1017 | log_probs = [] 1018 | sampled_flows = [] 1019 | denoised_actions_across_time = [] 1020 | 1021 | critic_values = [] 1022 | 1023 | # ode step function 1024 | 1025 | cached_state_kv = None 1026 | 
null_cached_state_kv = None 1027 | 1028 | def ode_fn(timestep, denoised_actions): 1029 | nonlocal cached_state_kv 1030 | nonlocal null_cached_state_kv 1031 | 1032 | input_args = ( 1033 | images, 1034 | token_ids, 1035 | joint_states, 1036 | denoised_actions 1037 | ) 1038 | 1039 | input_kwargs = dict( 1040 | times = timestep, 1041 | latents = latents, 1042 | reward_tokens = reward_tokens, 1043 | internal_state_tokens = internal_state_tokens, 1044 | cached_state_keys_values = (cached_state_kv, null_cached_state_kv), 1045 | cond_scale = cond_scale, 1046 | remove_parallel_component = remove_parallel_component, 1047 | keep_parallel_frac = keep_parallel_frac 1048 | ) 1049 | 1050 | output, (new_cached_state_kv, new_null_cached_state_kv) = self.forward_with_reward_cfg(*input_args, **input_kwargs) 1051 | 1052 | if exists(critic): 1053 | critic_value, _ = critic.forward_with_reward_cfg(*input_args, **input_kwargs) 1054 | critic_values.append(critic_value) 1055 | 1056 | flow = output 1057 | 1058 | if self.is_mean_std_output: 1059 | mean, std = output.unbind(dim = -1) 1060 | 1061 | flow = torch.normal(mean, std * temperature) 1062 | 1063 | log_prob = Normal(mean, std).log_prob(flow) 1064 | 1065 | # save for replaying for optimizing actor 1066 | 1067 | denoised_actions_across_time.append(denoised_actions) 1068 | timesteps.append(repeat(timestep, ' -> b', b = batch_size)) 1069 | log_probs.append(log_prob) 1070 | sampled_flows.append(flow) 1071 | 1072 | if cache_kv: 1073 | cached_state_kv = new_cached_state_kv 1074 | null_cached_state_kv = new_null_cached_state_kv 1075 | 1076 | pbar.update(1) 1077 | 1078 | return flow 1079 | 1080 | # start with random gaussian noise - y0 1081 | 1082 | noise = torch.randn((batch_size, trajectory_length, self.dim_action_input), device = self.device) 1083 | 1084 | # time steps 1085 | 1086 | times = torch.linspace(0., 1., steps, device = self.device) 1087 | 1088 | # ode 1089 | 1090 | trajectory = self.odeint_fn(ode_fn, noise, times) 1091 | 1092 | sampled_actions = trajectory[-1] 1093 | 1094 | self.train(was_training) 1095 | 1096 | pbar.close() 1097 | 1098 | if not return_states_for_replay: 1099 | return sampled_actions 1100 | 1101 | # place the time step dimension after batch 1102 | 1103 | timesteps = stack(timesteps, dim = 1) 1104 | log_probs = stack(log_probs, dim = 1) 1105 | sampled_flow = stack(sampled_flows, dim = 1) 1106 | denoised_actions_across_time = stack(denoised_actions_across_time, dim = 1) 1107 | 1108 | policy_optimization_outputs = (denoised_actions_across_time, timesteps, log_probs, sampled_flow) 1109 | 1110 | # return critic value predictions if passed in - will just deepcopy pi-zero + a critic head 1111 | 1112 | if exists(critic): 1113 | critic_values = stack(critic_values, dim = 1) 1114 | policy_optimization_outputs = (*policy_optimization_outputs, critic_values) 1115 | 1116 | return sampled_actions, policy_optimization_outputs 1117 | 1118 | @torch.no_grad() 1119 | def forward_with_reward_cfg( 1120 | self, 1121 | *args, 1122 | reward_tokens: Float['b d'] | None = None, 1123 | cached_state_keys_values = (None, None), 1124 | cond_scale = 0., 1125 | remove_parallel_component = False, 1126 | keep_parallel_frac = 0., 1127 | **kwargs 1128 | ): 1129 | 1130 | with_reward_cache, without_reward_cache = cached_state_keys_values 1131 | 1132 | forward_kwargs = dict( 1133 | return_state_keys_values = True, 1134 | return_actions_flow = True, 1135 | ) 1136 | 1137 | action_flow_with_reward, with_reward_cache_kv = self.forward( 1138 | *args, 1139 | reward_tokens = 
reward_tokens, 1140 | cached_state_keys_values = with_reward_cache, 1141 | **forward_kwargs, 1142 | **kwargs 1143 | ) 1144 | 1145 | if not exists(reward_tokens) or cond_scale == 0.: 1146 | return action_flow_with_reward, (with_reward_cache_kv, None) 1147 | 1148 | assert self.can_cfg, 'you need to train with reward token dropout' 1149 | 1150 | action_flow_without_reward, without_reward_cache_kv = self.forward( 1151 | *args, 1152 | cached_state_keys_values = without_reward_cache, 1153 | **forward_kwargs, 1154 | **kwargs 1155 | ) 1156 | 1157 | update = action_flow_with_reward - action_flow_without_reward 1158 | 1159 | if remove_parallel_component: 1160 | # from https://arxiv.org/abs/2410.02416 1161 | 1162 | update_parallel, update_orthog = project(update, action_flow_with_reward) 1163 | update = update_orthog + update_parallel * keep_parallel_frac 1164 | 1165 | flow_with_reward_cfg = action_flow_with_reward + cond_scale * update 1166 | 1167 | return flow_with_reward_cfg, (with_reward_cache_kv, without_reward_cache_kv) 1168 | 1169 | @move_input_tensors_to_device 1170 | def forward_only_vision_language( 1171 | self, 1172 | images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision 1173 | token_ids: Int['b nt'], # language 1174 | ) -> Float['b n d']: 1175 | 1176 | device = token_ids.device 1177 | 1178 | language_tokens = self.token_emb(token_ids) 1179 | 1180 | # vision 1181 | 1182 | if exists(self.vit): 1183 | assert images.ndim in {4, 5} 1184 | is_multiple_images = images.ndim == 5 1185 | 1186 | if is_multiple_images: 1187 | images = rearrange(images, 'b c f h w -> b f c h w') 1188 | images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w') 1189 | 1190 | with torch.no_grad(): 1191 | self.vit.eval() 1192 | visual_tokens = self.vit(images) 1193 | 1194 | if is_multiple_images: 1195 | visual_tokens, = inverse_pack_image_frames(visual_tokens, '* n d') 1196 | visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d') 1197 | 1198 | else: 1199 | assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)' 1200 | visual_tokens = images 1201 | 1202 | visual_tokens = self.maybe_to_image_tokens(visual_tokens) 1203 | 1204 | # concat visual rep with language 1205 | 1206 | state_tokens, _ = pack_with_inverse([ 1207 | visual_tokens, 1208 | language_tokens, 1209 | ], 'b * d') 1210 | 1211 | # rotary embeddings 1212 | 1213 | seq_len = state_tokens.shape[-2] 1214 | 1215 | seq = torch.arange(seq_len, device = device) 1216 | 1217 | rotary_emb = self.rotary_emb(seq) 1218 | 1219 | # transformer 1220 | 1221 | for attn, ff, _, _ in self.layers: 1222 | 1223 | state_attn_out = attn.forward_only_vision_language(state_tokens, rotary_emb = rotary_emb) 1224 | 1225 | state_tokens = state_tokens + state_attn_out 1226 | 1227 | state_tokens = ff(state_tokens) + state_tokens 1228 | 1229 | embed = self.final_norm_softclamp(state_tokens) 1230 | 1231 | logits = self.state_to_logits(embed) 1232 | 1233 | return logits 1234 | 1235 | @move_input_tensors_to_device 1236 | def forward_for_policy_loss( 1237 | self, 1238 | images, 1239 | commands, 1240 | joint_state, 1241 | actions, 1242 | flow, 1243 | times, 1244 | old_log_probs: Float['b na'], 1245 | advantages: Float['b t'], 1246 | clip_eps = 0.2, 1247 | entropy_weight = 1e-2, 1248 | norm_eps = 1e-5, 1249 | **kwargs, 1250 | ): 1251 | assert not self.is_critic 1252 | assert self.policy_optimizable 1253 | assert 'return_actions_flow' not in kwargs 1254 | 1255 | # flatten the time into the batch for actions at timestep, 
sampled flow, and log prob 1256 | 1257 | if times.ndim == 2: 1258 | times = rearrange(times, 'b t -> (b t)') 1259 | 1260 | if flow.ndim == 4: 1261 | flow = rearrange(flow, 'b t ... -> (b t) ...') 1262 | 1263 | if old_log_probs.ndim == 4: 1264 | old_log_probs = rearrange(old_log_probs, 'b t ... -> (b t) ...') 1265 | 1266 | if actions.ndim == 4: 1267 | actions = rearrange(actions, 'b t ... -> (b t) ...') 1268 | 1269 | # expand inputs across timesteps if need be 1270 | 1271 | ( 1272 | images, 1273 | commands, 1274 | joint_state, 1275 | ) = tuple(repeat(inp, 'b ... -> (b t) ...', t = times.shape[0] // inp.shape[0]) for inp in ( 1276 | images, 1277 | commands, 1278 | joint_state 1279 | )) 1280 | 1281 | mean_std = self.forward( 1282 | images, 1283 | commands, 1284 | joint_state, 1285 | actions, 1286 | return_actions_flow = True, 1287 | **kwargs 1288 | ) 1289 | 1290 | normal_dist = Normal(*mean_std.unbind(dim = -1)) 1291 | 1292 | new_log_probs = normal_dist.log_prob(actions) 1293 | 1294 | # ppo surrogate loss 1295 | 1296 | ratio = (new_log_probs - old_log_probs).exp() 1297 | 1298 | advantages = F.layer_norm(advantages, advantages.shape, eps = norm_eps) 1299 | 1300 | advantages = rearrange(advantages, 'b t -> (b t) 1 1') 1301 | 1302 | surr1 = ratio * advantages 1303 | surr2 = ratio.clamp(1. - clip_eps, 1. + clip_eps) * advantages 1304 | 1305 | clipped_surr_loss = torch.min(surr1, surr2).sum(dim = -1) 1306 | 1307 | # entropy 1308 | 1309 | entropy = (normal_dist.entropy() * entropy_weight).sum(dim = -1) 1310 | 1311 | return -(clipped_surr_loss + entropy * entropy_weight).mean() 1312 | 1313 | @move_input_tensors_to_device 1314 | def forward_for_critic_loss( 1315 | self, 1316 | *args, 1317 | old_values: Float['b t'], 1318 | advantages: Float['b t'], 1319 | clip_eps = 0.4, 1320 | **kwargs 1321 | ): 1322 | assert self.is_critic 1323 | 1324 | eps = clip_eps 1325 | loss_fn = self.to_critic_value.loss_fn 1326 | 1327 | critic_value, critic_logits = self.forward(*args, **kwargs) 1328 | 1329 | # value clipping 1330 | 1331 | advantages = rearrange(advantages, 'b t -> (b t)') 1332 | old_values = rearrange(old_values, 'b t -> (b t)') 1333 | 1334 | returns = old_values + advantages 1335 | 1336 | clipped_value = old_values + (critic_value - old_values).clamp(-eps, eps) 1337 | 1338 | clipped_loss = loss_fn(clipped_value, returns, reduction = 'none') 1339 | loss = loss_fn(critic_logits, returns, reduction = 'none') 1340 | 1341 | return torch.max(clipped_loss, loss).mean() 1342 | 1343 | @move_input_tensors_to_device 1344 | def forward( 1345 | self, 1346 | images: Float['b nv d'] | Float['b c h w'] | Float['b c f h w'], # vision 1347 | token_ids: Int['b nt'], # language 1348 | joint_state: Float['b djs'], # joint state 1349 | actions: Float['b na da'] | None = None, # action 1350 | times: Float['b'] = None, 1351 | latents: Float['d'] | Float['b d'] = None, 1352 | reward_tokens: Float['b d'] | None = None, 1353 | internal_state_tokens: Float['b ns d'] | None = None, 1354 | external_states: tuple[Float['b ...']] | None = None, 1355 | record_and_return_memory_tokens = False, 1356 | past_recurrent_memory_tokens: Float['b {self._nm} d'] | None = None, 1357 | return_actions_flow = False, 1358 | return_state_keys_values = False, 1359 | cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None, 1360 | return_language_loss = True, 1361 | return_action_flow_loss = True, 1362 | **kwargs 1363 | ): 1364 | inferencing = exists(cached_state_keys_values) 1365 | assert not (inferencing and not return_actions_flow), 'must be 
generating action trajectory if receiving cached state key values' 1366 | 1367 | if not exists(actions) and not self.is_critic: 1368 | return self.sample_actions(images, token_ids, joint_state, **kwargs) 1369 | 1370 | batch, device = token_ids.shape[0], token_ids.device 1371 | 1372 | # noising the action for flow matching 1373 | 1374 | if not exists(times): 1375 | times = self.sample_times_fn((batch,), device = device) 1376 | 1377 | if times.ndim == 0: 1378 | times = repeat(times, '-> b', b = batch) 1379 | 1380 | # handle latent genes 1381 | 1382 | if exists(latents) and latents.ndim == 1: 1383 | latents = repeat(latents, 'd -> b d', b = batch) 1384 | 1385 | # if not returning the actions predicted flow, assume training and noise the actions for loss 1386 | 1387 | if not return_actions_flow and not self.is_critic: 1388 | noise = torch.randn_like(actions) 1389 | 1390 | if self.immiscible_flow: 1391 | assignment = noise_assignment(actions, noise) 1392 | noise = noise[assignment] 1393 | 1394 | flow = actions - noise 1395 | padded_times = rearrange(times, 'b -> b 1 1') 1396 | 1397 | actions = noise.lerp(actions, padded_times) 1398 | 1399 | # actions 1400 | 1401 | time_cond = self.to_time_cond(times) 1402 | action_tokens = self.to_action_tokens(actions) 1403 | 1404 | # handle maybe latents 1405 | 1406 | if exists(latents): 1407 | assert self.can_accept_latent 1408 | 1409 | latent_cond = self.to_latent_cond(latents) 1410 | 1411 | time_cond = time_cond * (latent_cond + 1.) 1412 | 1413 | # register tokens 1414 | 1415 | action_register_tokens = repeat(self.action_register_tokens, '... -> b ...', b = batch) 1416 | 1417 | # take care of maybe recurrent memory tokens 1418 | 1419 | assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0' 1420 | assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0' 1421 | 1422 | if not exists(past_recurrent_memory_tokens): 1423 | past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim)) 1424 | 1425 | if self.has_recurrent_memories: 1426 | write_memory_tokens = repeat(self.memory_tokens, 'nm d -> b nm d', b = batch) 1427 | else: 1428 | write_memory_tokens = actions.new_empty((batch, 0, self.dim)) 1429 | 1430 | # joint state + additional internal states 1431 | 1432 | joint_state_tokens = self.to_joint_state_tokens(joint_state) 1433 | 1434 | # additional internal state tokens 1435 | 1436 | if not exists(internal_state_tokens): 1437 | internal_state_tokens = joint_state_tokens.new_empty((batch, 0, self.dim_internal_state)) 1438 | 1439 | internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens) 1440 | 1441 | # handle memory tokens, both read and write as a tuple of two tensors 1442 | 1443 | memory_tokens = (past_recurrent_memory_tokens, write_memory_tokens) 1444 | 1445 | # mem_length = past_recurrent_memory_tokens.shape[-2] + write_memory_tokens.shape[-2] 1446 | 1447 | # pack into [action registers] [internal + joint states] [actions] 1448 | 1449 | action_tokens, inverse_pack_action_registers = pack_with_inverse([ 1450 | action_register_tokens, 1451 | joint_state_tokens, 1452 | internal_state_tokens, 1453 | action_tokens 1454 | ], 'b * d') 1455 | 1456 | action_with_registers_length = action_tokens.shape[-2] 1457 | 1458 | state_tokens = None 1459 | 1460 | if not inferencing: 1461 | # language 1462 | 1463 | labels = token_ids[:, 1:] 1464 | 
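# note: the language objective is standard teacher-forced next token prediction -
# `labels` above are the command token ids shifted left by one, and the cross entropy
# further below drops the final-position logits and ignores the padding id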
1465 | language_tokens = self.token_emb(token_ids)
1466 | 
1467 | # vision
1468 | 
1469 | if exists(self.vit):
1470 | assert images.ndim in {4, 5}
1471 | is_multiple_images = images.ndim == 5
1472 | 
1473 | if is_multiple_images:
1474 | images = rearrange(images, 'b c f h w -> b f c h w')
1475 | images, inverse_pack_image_frames = pack_with_inverse([images], '* c h w')
1476 | 
1477 | with torch.no_grad():
1478 | self.vit.eval()
1479 | visual_tokens = self.vit(images)
1480 | 
1481 | if is_multiple_images:
1482 | visual_tokens, = inverse_pack_image_frames(visual_tokens, '* n d')
1483 | visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')
1484 | 
1485 | else:
1486 | assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)'
1487 | visual_tokens = images
1488 | 
1489 | visual_tokens = self.maybe_to_image_tokens(visual_tokens)
1490 | 
1491 | # maybe reward tokens
1492 | 
1493 | if not exists(reward_tokens):
1494 | reward_tokens = visual_tokens.new_empty((batch, 0, self.dim))
1495 | 
1496 | # maybe dropout reward tokens
1497 | 
1498 | if self.training and random() < self.reward_tokens_dropout_prob:
1499 | reward_tokens = reward_tokens[:, 0:0]
1500 | 
1501 | # additional external states
1502 | 
1503 | if exists(external_states):
1504 | external_state_tokens = [encode(external_state) for encode, external_state in zip(self.external_state_encoders, external_states)]
1505 | external_state_tokens, _ = pack(external_state_tokens, 'b * d') # einops pack returns (packed tensor, packed shapes)
1506 | 
1507 | else:
1508 | external_state_tokens = visual_tokens.new_empty((batch, 0, self.dim))
1509 | 
1510 | # concat visual rep with language
1511 | 
1512 | state_tokens, inverse_packed_states = pack_with_inverse([
1513 | external_state_tokens,
1514 | visual_tokens,
1515 | language_tokens,
1516 | reward_tokens
1517 | ], 'b * d')
1518 | 
1519 | # take care of masking for variable length states, starting with the language tokens
1520 | 
1521 | # which then leads to proper rotary embeddings
1522 | 
1523 | command_length = token_ids.shape[-1]
1524 | 
1525 | language_mask = token_ids != self.lm_pad_id
1526 | 
1527 | if inferencing:
1528 | state_length = cached_state_keys_values[0][0].shape[-2]
1529 | else:
1530 | state_length = state_tokens.shape[-2]
1531 | 
1532 | mask = F.pad(language_mask, (state_length - command_length, action_with_registers_length), value = True) # assume fixed number of images for now, but address variable length modality states later
1533 | 
1534 | # memory
1535 | 
1536 | mask = F.pad(mask, (past_recurrent_memory_tokens.shape[-2], write_memory_tokens.shape[-2]), value = True)
1537 | 
1538 | # rotary embeddings
1539 | 
1540 | seq = mask.float().cumsum(dim = -1)
1541 | rotary_emb = self.rotary_emb(seq)
1542 | 
1543 | rotary_emb = rearrange(rotary_emb, 'b n d -> b 1 n d')
1544 | 
1545 | # prepare maybe flex attention
1546 | 
1547 | flex_attn_fn = None
1548 | 
1549 | if not inferencing and self.use_flex_attn and state_tokens.is_cuda:
1550 | 
1551 | prefix_length = state_tokens.shape[-2]
1552 | seq_len = prefix_length + action_tokens.shape[-2]
1553 | 
1554 | block_mask = create_block_mask(
1555 | create_pizero_attn_mask(
1556 | prefix_length,
1557 | mask = mask,
1558 | ),
1559 | Q_LEN = seq_len,
1560 | KV_LEN = seq_len,
1561 | device = state_tokens.device,
1562 | _compile = True,
1563 | )
1564 | 
1565 | score_mod_fn = softclamp_score_mod(self.attn_softclamp_value)
1566 | 
1567 | flex_attn_fn = partial(
1568 | flex_attention,
1569 | block_mask = block_mask,
1570 | score_mod = score_mod_fn
1571 | )
1572 | 
1573 | # state keys and values for 
caching during inference 1574 | 1575 | cached_state_key_values_iter = iter(default(cached_state_keys_values, [])) 1576 | 1577 | # value residual learning 1578 | 1579 | actions_value_residual = None 1580 | 1581 | # maybe expand residual streams 1582 | 1583 | action_tokens = self.maybe_expand_residuals(action_tokens) 1584 | 1585 | # transformer 1586 | 1587 | if not inferencing: 1588 | 1589 | next_state_cached_keys_values = [] 1590 | 1591 | for ( 1592 | (attn, state_ff, actions_ff, memories_ff), 1593 | (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale), 1594 | (attn_residual, actions_ff_residual), 1595 | ) in zip(self.layers, self.cond_layers, self.residual_layers): 1596 | 1597 | # joint attention 1598 | 1599 | action_tokens, add_action_residual = attn_residual(action_tokens) 1600 | 1601 | action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) 1602 | 1603 | (state_attn_out, actions_attn_out, *maybe_mem_out), (state_keys, state_values, action_keys, action_values) = attn( 1604 | state_tokens, 1605 | action_tokens, 1606 | rotary_emb = rotary_emb, 1607 | flex_attn_fn = flex_attn_fn, 1608 | actions_value_residual = actions_value_residual, 1609 | mask = mask, 1610 | return_keys_values = True, 1611 | memories = memory_tokens 1612 | ) 1613 | 1614 | next_state_cached_keys_values.append((state_keys, state_values)) 1615 | 1616 | actions_value_residual = default(actions_value_residual, action_values) 1617 | 1618 | action_attn_out = attn_ada_layerscale(actions_attn_out, time_cond) 1619 | 1620 | state_tokens = state_tokens + state_attn_out 1621 | action_tokens = add_action_residual(action_attn_out) 1622 | 1623 | if self.has_recurrent_memories: 1624 | (read_mem_attn_out, write_mem_attn_out), = maybe_mem_out 1625 | read_mem, write_mem = memory_tokens 1626 | 1627 | memory_tokens = (read_mem + read_mem_attn_out, write_mem + write_mem_attn_out) 1628 | 1629 | # state feedforward 1630 | 1631 | state_tokens_out = state_ff(state_tokens) 1632 | 1633 | state_tokens = state_tokens + state_tokens_out 1634 | 1635 | # action feedforward 1636 | 1637 | action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens) 1638 | 1639 | action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) 1640 | 1641 | action_tokens_out = actions_ff(action_tokens) 1642 | 1643 | action_tokens_out = ff_ada_layerscale(action_tokens_out, time_cond) 1644 | 1645 | action_tokens = add_action_ff_residual(action_tokens_out) 1646 | 1647 | # maybe memory feedforward 1648 | 1649 | if self.has_recurrent_memories: 1650 | memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d') 1651 | 1652 | memory_tokens = memories_ff(memory_tokens) + memory_tokens 1653 | 1654 | memory_tokens = unpack_memory(memory_tokens) 1655 | 1656 | else: 1657 | 1658 | assert exists(cached_state_keys_values) and len(cached_state_keys_values) > 0 1659 | 1660 | next_state_cached_keys_values = cached_state_keys_values 1661 | 1662 | for ( 1663 | (attn, state_ff, actions_ff, memories_ff), 1664 | (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale), 1665 | (attn_residual, actions_ff_residual), 1666 | ) in zip(self.layers, self.cond_layers, self.residual_layers): 1667 | 1668 | # actions attention 1669 | 1670 | action_tokens, add_action_residual = attn_residual(action_tokens) 1671 | 1672 | action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) 1673 | 1674 | actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state( 1675 | action_tokens, 1676 | 
cached_state_keys_values = next(cached_state_key_values_iter),
1677 | rotary_emb = rotary_emb,
1678 | mask = mask,
1679 | return_keys_values = True
1680 | )
1681 | 
1682 | actions_value_residual = default(actions_value_residual, action_values)
1683 | 
1684 | actions_attn_out = attn_ada_layerscale(actions_attn_out, time_cond)
1685 | action_tokens = add_action_residual(actions_attn_out)
1686 | 
1687 | # actions feed forward
1688 | 
1689 | action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)
1690 | 
1691 | action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
1692 | 
1693 | action_out = actions_ff(action_tokens)
1694 | 
1695 | action_out = ff_ada_layerscale(action_out, time_cond)
1696 | 
1697 | action_tokens = add_action_ff_residual(action_out) # use the feedforward residual closure, mirroring the training branch above
1698 | 
1699 | # maybe memory feed forward
1700 | 
1701 | if self.has_recurrent_memories:
1702 | memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
1703 | 
1704 | memory_tokens = memories_ff(memory_tokens) + memory_tokens
1705 | 
1706 | memory_tokens = unpack_memory(memory_tokens)
1707 | 
1708 | # maybe reduce residual streams
1709 | 
1710 | action_tokens = self.maybe_reduce_residuals(action_tokens)
1711 | 
1712 | if not inferencing:
1713 | # unpack and unembed to predictions
1714 | 
1715 | _, visual_tokens, tokens, *_ = inverse_packed_states(state_tokens, 'b * d')
1716 | 
1717 | # gemma uses a final softclamp before norm
1718 | 
1719 | tokens = self.final_norm_softclamp(tokens)
1720 | 
1721 | *_, action_tokens = inverse_pack_action_registers(action_tokens)
1722 | 
1723 | action_tokens = self.final_norm_softclamp(action_tokens)
1724 | 
1725 | # memories
1726 | 
1727 | read_memories, written_memory_tokens = memory_tokens
1728 | 
1729 | # writeable memories norm
1730 | 
1731 | if self.has_recurrent_memories:
1732 | written_memory_tokens = self.final_norm_write_memories(written_memory_tokens)
1733 | 
1734 | # final actions norm
1735 | 
1736 | action_embeds = self.final_actions_norm(action_tokens)
1737 | 
1738 | # pool the action embeds and project if critic loss
1739 | 
1740 | if self.is_critic:
1741 | action_embeds = reduce(action_embeds, 'b n d -> b d', 'mean')
1742 | 
1743 | return self.to_critic_value(action_embeds, return_value_and_logits = True)
1744 | 
1745 | # validate loss being returned
1746 | 
1747 | assert return_language_loss or return_action_flow_loss
1748 | 
1749 | # flow loss for action tokens
1750 | 
1751 | pred_actions_flow = self.actions_to_pred_flow(action_embeds)
1752 | 
1753 | if return_actions_flow:
1754 | 
1755 | if not return_state_keys_values and not record_and_return_memory_tokens:
1756 | return pred_actions_flow
1757 | 
1758 | if not return_state_keys_values:
1759 | return pred_actions_flow, written_memory_tokens
1760 | 
1761 | return pred_actions_flow, next_state_cached_keys_values
1762 | 
1763 | flow_loss = self.zero
1764 | 
1765 | if return_action_flow_loss:
1766 | flow_loss = self.loss_fn(pred_actions_flow, flow)
1767 | 
1768 | # language cross entropy loss
1769 | 
1770 | language_loss = self.zero
1771 | 
1772 | if return_language_loss:
1773 | tokens = self.final_norm(tokens)
1774 | 
1775 | language_logits = self.state_to_logits(tokens)
1776 | 
1777 | language_loss = F.cross_entropy(
1778 | rearrange(language_logits[:, :-1], 'b n l -> b l n'),
1779 | labels,
1780 | ignore_index = self.lm_pad_id
1781 | )
1782 | 
1783 | # loss breakdown
1784 | 
1785 | loss_breakdown = (language_loss, flow_loss)
1786 | 
1787 | # total loss and return breakdown
1788 | 
1789 | total_loss = (
1790 | language_loss * self.lm_loss_weight +
1791 | 
flow_loss * self.flow_loss_weight 1792 | ) 1793 | 1794 | if not record_and_return_memory_tokens: 1795 | return total_loss, loss_breakdown 1796 | 1797 | return total_loss, loss_breakdown, written_memory_tokens 1798 | 1799 | # generalized advantage estimate 1800 | 1801 | @torch.no_grad() 1802 | def calc_generalized_advantage_estimate( 1803 | rewards, 1804 | values, 1805 | masks, 1806 | gamma = 0.99, 1807 | lam = 0.95, 1808 | use_accelerated = None 1809 | ): 1810 | use_accelerated = default(use_accelerated, rewards.is_cuda) 1811 | 1812 | values = F.pad(values, (0, 1), value = 0.) 1813 | values, values_next = values[:, :-1], values[:, 1:] 1814 | 1815 | delta = rewards + gamma * values_next * masks - values 1816 | gates = gamma * lam * masks 1817 | 1818 | scan = AssocScan(reverse = True, use_accelerated = use_accelerated) 1819 | 1820 | return scan(gates, delta) 1821 | 1822 | # agent 1823 | 1824 | class Agent(Module): 1825 | def __init__( 1826 | self, 1827 | model: PiZero, 1828 | optim_klass = AdoptAtan2, 1829 | num_latent_genes = 1, 1830 | actor_lr = 3e-4, 1831 | critic_lr = 3e-4, 1832 | actor_weight_decay = 1e-3, 1833 | critic_weight_decay = 1e-3, 1834 | max_grad_norm = 0.5, 1835 | actor_optim_kwargs: dict = dict(), 1836 | critic_optim_kwargs: dict = dict(), 1837 | latent_gene_pool_kwargs: dict = dict( 1838 | frac_tournaments = 0.5 1839 | ) 1840 | ): 1841 | super().__init__() 1842 | 1843 | # evolutionary policy optimization related 1844 | # Wang et al. https://web3.arxiv.org/abs/2503.19037 1845 | 1846 | assert num_latent_genes >= 1 1847 | evolutionary_learning = num_latent_genes > 1 1848 | 1849 | dim_latent = model.dim if evolutionary_learning else None 1850 | 1851 | self.latent_gene_pool = LatentGenePool(dim_latent = dim_latent, num_latents = num_latent_genes, **latent_gene_pool_kwargs) if evolutionary_learning else None 1852 | self.has_gene_pool = evolutionary_learning 1853 | 1854 | # init actor critic, taking into account model may not have probabilistic flow to start off with, and determine whether it needs to be reinstantiated for latent conditioning 1855 | 1856 | actor = model 1857 | 1858 | if not model.policy_optimizable or evolutionary_learning: 1859 | actor = model.create_actor(dim_latent = dim_latent) 1860 | 1861 | self.actor = actor 1862 | self.critic = actor.create_critic() 1863 | 1864 | # gradient clipping 1865 | 1866 | self.max_grad_norm = max_grad_norm 1867 | 1868 | # optimizers 1869 | 1870 | self.actor_optim = optim_klass(self.actor.parameters(), lr = actor_lr, weight_decay = actor_weight_decay, **actor_optim_kwargs) 1871 | self.critic_optim = optim_klass(self.critic.parameters(), lr = critic_lr, weight_decay = critic_weight_decay, **critic_optim_kwargs) 1872 | 1873 | def take_genetic_algorithm_step_(self, fitnesses): 1874 | if not self.has_gene_pool: 1875 | return 1876 | 1877 | self.latent_gene_pool.genetic_algorithm_step(fitnesses) 1878 | 1879 | def forward( 1880 | self, 1881 | memories 1882 | ): 1883 | raise NotImplementedError 1884 | 1885 | class EPO(Module): 1886 | def __init__( 1887 | self, 1888 | agent: Agent, 1889 | env, 1890 | accelerate_kwargs: dict = dict() 1891 | ): 1892 | super().__init__() 1893 | self.accelerate = Accelerator(**accelerate_kwargs) 1894 | 1895 | self.agent = agent 1896 | self.env = env 1897 | 1898 | ( 1899 | agent.actor, 1900 | agent.critic, 1901 | agent.actor_optim, 1902 | agent.critic_optim 1903 | ) = self.accelerate.prepare( 1904 | agent.actor, 1905 | agent.critic, 1906 | agent.actor_optim, 1907 | agent.critic_optim 1908 | ) 1909 | 1910 | 
self.register_buffer('step', tensor(0)) 1911 | 1912 | @property 1913 | def unwrapped_actor(self): 1914 | return self.accelerate.unwrap_model(self.agent.actor) 1915 | 1916 | @property 1917 | def unwrapped_critic(self): 1918 | return self.accelerate.unwrap_model(self.agent.critic) 1919 | 1920 | def log(self, **data_kwargs): 1921 | return self.accelerate.log(data_kwargs, step = self.step.item()) 1922 | 1923 | @torch.no_grad() 1924 | def gather_experience_from_env( 1925 | self, 1926 | steps, 1927 | trajectory_length = 16, 1928 | flow_sampling_steps = 4, 1929 | temperature = 1., 1930 | **sampling_kwargs 1931 | ): 1932 | self.agent.eval() 1933 | 1934 | actor = self.unwrapped_actor 1935 | 1936 | states = self.env.reset() 1937 | 1938 | memories = [] 1939 | 1940 | for _ in range(steps): 1941 | 1942 | sampled_actions, replay_tensors = temp_batch_dim(actor)( 1943 | *states, 1944 | trajectory_length = trajectory_length, 1945 | steps = flow_sampling_steps, 1946 | return_states_for_replay = True, 1947 | temperature = temperature, 1948 | **sampling_kwargs 1949 | ) 1950 | 1951 | next_states, reward, truncated, terminated = self.env.step(sampled_actions) 1952 | 1953 | memories.append(to_device([*states, reward, terminated, *replay_tensors], torch.device('cpu'))) 1954 | 1955 | states = next_states 1956 | 1957 | self.accelerate.wait_for_everyone() 1958 | 1959 | return memories 1960 | 1961 | def learn_agent( 1962 | self, 1963 | memories, 1964 | fitnesses = None, 1965 | epochs = 2, 1966 | batch_size = 16 1967 | ): 1968 | self.agent.train() 1969 | 1970 | ( 1971 | images, 1972 | commands, 1973 | joint_state, 1974 | rewards, 1975 | terminated, 1976 | actions, 1977 | timesteps, 1978 | sampled_flows, 1979 | log_probs 1980 | ) = map(torch.stack, zip(*memories)) 1981 | 1982 | flow_timesteps = actions.shape[1] 1983 | 1984 | values, _ = self.agent.critic( 1985 | repeat(images, 't ... -> (t ft) ...', ft = flow_timesteps), 1986 | repeat(commands, 't ... -> (t ft) ...', ft = flow_timesteps), 1987 | repeat(joint_state, 't ... -> (t ft) ...', ft = flow_timesteps), 1988 | actions = rearrange(actions, 't ft ... -> (t ft) ...'), 1989 | times = rearrange(timesteps, 't ft ... 
-> (t ft) ...') 1990 | ) 1991 | 1992 | values = rearrange(values, '(t ft) -> ft t', ft = flow_timesteps) 1993 | values = values.detach().cpu() 1994 | 1995 | # actions go out into the environment, rewards are received, generalized advantage calculated with critic values 1996 | 1997 | boundaries = repeat(terminated, 't -> ft t', ft = flow_timesteps) 1998 | 1999 | advantages = calc_generalized_advantage_estimate(rewards, values, boundaries, use_accelerated = False).detach() 2000 | 2001 | # move time back to first dimension to be batched for learning 2002 | 2003 | advantages = rearrange(advantages, 'ft t -> t ft') 2004 | values = rearrange(values, 'ft t -> t ft') 2005 | 2006 | # dataset and dataloader 2007 | 2008 | dataset = TensorDataset( 2009 | images, 2010 | commands, 2011 | joint_state, 2012 | rewards, 2013 | terminated, 2014 | actions, 2015 | timesteps, 2016 | sampled_flows, 2017 | log_probs, 2018 | values, 2019 | advantages 2020 | ) 2021 | 2022 | dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = True) 2023 | 2024 | # training loop 2025 | 2026 | for _ in range(epochs): 2027 | for ( 2028 | images, 2029 | commands, 2030 | joint_state, 2031 | rewards, 2032 | terminated, 2033 | actions, 2034 | timesteps, 2035 | sampled_flows, 2036 | log_probs, 2037 | values, 2038 | advantages 2039 | ) in dataloader: 2040 | 2041 | # optimize policy with replay tensors from above 2042 | 2043 | actor_loss = self.agent.actor.forward_for_policy_loss( 2044 | images, 2045 | commands, 2046 | joint_state, 2047 | actions, 2048 | times = timesteps, 2049 | flow = sampled_flows, 2050 | old_log_probs = log_probs, 2051 | advantages = advantages, 2052 | ) 2053 | 2054 | actor_loss.backward() 2055 | 2056 | self.log(actor_loss = actor_loss.item()) 2057 | 2058 | self.accelerate.clip_grad_norm_(self.agent.actor.parameters(), self.agent.max_grad_norm) 2059 | 2060 | self.agent.actor_optim.step() 2061 | self.agent.actor_optim.zero_grad() 2062 | 2063 | critic_loss = self.agent.critic.forward_for_critic_loss( 2064 | repeat(images, 't ... -> (ft t) ...', ft = flow_timesteps), 2065 | repeat(commands, 't ... -> (ft t) ...', ft = flow_timesteps), 2066 | repeat(joint_state, 't ... -> (ft t) ...', ft = flow_timesteps), 2067 | rearrange(actions, 't ft ... 
-> (ft t) ...'), 2068 | old_values = values, 2069 | advantages = advantages, 2070 | ) 2071 | 2072 | critic_loss.backward() 2073 | 2074 | self.log(critic_loss = critic_loss.item()) 2075 | 2076 | self.accelerate.clip_grad_norm_(self.agent.critic.parameters(), self.agent.max_grad_norm) 2077 | 2078 | self.agent.critic_optim.step() 2079 | self.agent.critic_optim.zero_grad() 2080 | 2081 | if exists(fitnesses): 2082 | self.log(fitnesses = fitnesses) 2083 | 2084 | self.agent.take_genetic_algorithm_step_(fitnesses) 2085 | 2086 | self.step.add_(1) 2087 | 2088 | # fun 2089 | 2090 | π0 = PiZero 2091 | -------------------------------------------------------------------------------- /pi_zero_pytorch/tensor_typing.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | from jaxtyping import ( 4 | Float, 5 | Int, 6 | Bool 7 | ) 8 | 9 | # jaxtyping is a misnomer, works for pytorch 10 | 11 | class TorchTyping: 12 | def __init__(self, abstract_dtype): 13 | self.abstract_dtype = abstract_dtype 14 | 15 | def __getitem__(self, shapes: str): 16 | return self.abstract_dtype[Tensor, shapes] 17 | 18 | Float = TorchTyping(Float) 19 | Int = TorchTyping(Int) 20 | Bool = TorchTyping(Bool) 21 | 22 | __all__ = [ 23 | Float, 24 | Int, 25 | Bool 26 | ] 27 | -------------------------------------------------------------------------------- /plot_time_from_beta.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from pi_zero_pytorch.pi_zero import default_sample_times 3 | 4 | times = default_sample_times((10000,), s = 0.9) 5 | 6 | plt.hist(times.numpy()) 7 | plt.show() 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "pi-zero-pytorch" 3 | version = "0.1.32" 4 | description = "π0 in Pytorch" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'transformers', 15 | 'flow policy', 16 | 'robotic foundation model', 17 | ] 18 | 19 | classifiers=[ 20 | 'Development Status :: 4 - Beta', 21 | 'Intended Audience :: Developers', 22 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 23 | 'License :: OSI Approved :: MIT License', 24 | 'Programming Language :: Python :: 3.9', 25 | ] 26 | 27 | dependencies = [ 28 | "accelerate>=1.6.0", 29 | "assoc-scan>=0.0.2", 30 | "beartype", 31 | "einx>=0.3.0", 32 | "einops>=0.8.0", 33 | "ema-pytorch>=0.7.3", 34 | "evolutionary-policy-optimization>=0.1.19", 35 | "jaxtyping", 36 | 'hyper-connections>=0.0.10', 37 | "hl-gauss-pytorch>=0.1.21", 38 | "rotary-embedding-torch>=0.8.5", 39 | 'scipy', 40 | "torch>=2.5", 41 | 'torchdiffeq', 42 | 'torchtyping>=0.1.5', 43 | "tqdm" 44 | ] 45 | 46 | [project.urls] 47 | Homepage = "https://pypi.org/project/pi-zero-pytorch/" 48 | Repository = "https://github.com/lucidrains/pi-zero-pytorch" 49 | 50 | [project.optional-dependencies] 51 | examples = [] 52 | test = [ 53 | "pytest", 54 | "ruff>=0.4.2", 55 | "vit-pytorch>=1.8.7" 56 | ] 57 | 58 | [tool.pytest.ini_options] 59 | pythonpath = [ 60 | "." 
61 | ] 62 | 63 | [tool.ruff] 64 | line-length = 1000 65 | 66 | lint.ignore = [ 67 | "F722", # for jaxtyping shape annotation 68 | "F401", 69 | "F821" 70 | ] 71 | 72 | lint.extend-select = [ 73 | "W291" 74 | ] 75 | 76 | [build-system] 77 | requires = ["hatchling"] 78 | build-backend = "hatchling.build" 79 | 80 | [tool.rye] 81 | managed = true 82 | dev-dependencies = [] 83 | 84 | [tool.hatch.metadata] 85 | allow-direct-references = true 86 | 87 | [tool.hatch.build.targets.wheel] 88 | packages = ["pi_zero_pytorch"] 89 | -------------------------------------------------------------------------------- /tests/test_pi_zero.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import torch 4 | from pi_zero_pytorch import π0 5 | from einops import repeat, rearrange 6 | 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | 9 | @pytest.mark.parametrize('only_vlm', (True, False)) 10 | @pytest.mark.parametrize('num_residual_streams', (1, 4)) 11 | def test_pi_zero_with_vit( 12 | only_vlm: bool, 13 | num_residual_streams: int, 14 | ): 15 | from vit_pytorch import ViT 16 | from vit_pytorch.extractor import Extractor 17 | 18 | v = ViT( 19 | image_size = 256, 20 | patch_size = 32, 21 | num_classes = 1000, 22 | dim = 32, 23 | depth = 1, 24 | heads = 16, 25 | dim_head = 16, 26 | mlp_dim = 64, 27 | dropout = 0.1, 28 | emb_dropout = 0.1 29 | ).to(device) 30 | 31 | v = Extractor(v, return_embeddings_only = True) 32 | 33 | model = π0( 34 | dim = 32, 35 | vit = v, 36 | vit_dim = 32, 37 | depth = 1, 38 | dim_action_input = 6, 39 | dim_joint_state = 12, 40 | num_tokens = 32, 41 | num_residual_streams = num_residual_streams, 42 | ).to(device) 43 | 44 | images = torch.randn(2, 3, 2, 256, 256) 45 | commands = torch.randint(0, 32, (2, 1024)) 46 | 47 | if only_vlm: 48 | vlm_logits = model.forward_only_vision_language(images, commands) 49 | assert vlm_logits.ndim == 3 50 | return 51 | 52 | joint_state = torch.randn(2, 12) 53 | actions = torch.randn(2, 32, 6) 54 | 55 | loss, _ = model(images, commands, joint_state, actions) 56 | loss.backward() 57 | 58 | # after much training 59 | 60 | sampled_actions = model(images, commands, joint_state, trajectory_length = 32) # (1, 32, 6) 61 | 62 | assert sampled_actions.shape == (2, 32, 6) 63 | 64 | @pytest.mark.parametrize('num_latent_genes', (1, 16)) 65 | def test_policy_optimization( 66 | num_latent_genes 67 | ): 68 | 69 | from vit_pytorch import ViT 70 | from vit_pytorch.extractor import Extractor 71 | 72 | from pi_zero_pytorch.pi_zero import ( 73 | Agent, 74 | EPO, 75 | ) 76 | 77 | from pi_zero_pytorch.mock_env import Env 78 | 79 | v = ViT( 80 | image_size = 256, 81 | patch_size = 32, 82 | num_classes = 1000, 83 | dim = 32, 84 | depth = 1, 85 | heads = 2, 86 | dim_head = 8, 87 | mlp_dim = 16, 88 | dropout = 0.1, 89 | emb_dropout = 0.1 90 | ) 91 | 92 | v = Extractor(v, return_embeddings_only = True) 93 | 94 | model = π0( 95 | dim = 32, 96 | vit = v, 97 | vit_dim = 32, 98 | depth = 1, 99 | dim_action_input = 6, 100 | dim_joint_state = 12, 101 | num_tokens = 32, 102 | policy_optimizable = True 103 | ).to(device) 104 | 105 | images = torch.randn(2, 3, 2, 256, 256) 106 | commands = torch.randint(0, 32, (2, 1024)) 107 | 108 | joint_state = torch.randn(2, 12) 109 | actions = torch.randn(2, 32, 6) 110 | 111 | loss, _ = model(images, commands, joint_state, actions) 112 | loss.backward() 113 | 114 | # agent 115 | 116 | agent = Agent( 117 | model, 118 | num_latent_genes = num_latent_genes 119 | ) 120 | 121 | 
mock_env = Env((256, 256), 2, 32, 1024, 12) 122 | 123 | epo = EPO( 124 | agent, 125 | mock_env, 126 | accelerate_kwargs = dict( 127 | cpu = True 128 | ) 129 | ) 130 | 131 | memories = epo.gather_experience_from_env(steps = 10) 132 | 133 | epo.learn_agent(memories, batch_size = 2) 134 | --------------------------------------------------------------------------------