├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── robotic_transformer_pytorch ├── __init__.py └── robotic_transformer_pytorch.py ├── rt1.png └── setup.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Robotic Transformer - Pytorch 4 | 5 | Implementation of RT1 (Robotic Transformer), from the Robotics at Google team, in Pytorch 6 | 7 | ## Install 8 | 9 | ```bash 10 | $ pip install robotic-transformer-pytorch 11 | ``` 12 | 13 | ## Usage 14 | 15 | ```python 16 | import torch 17 | from robotic_transformer_pytorch import MaxViT, RT1 18 | 19 | vit = MaxViT( 20 | num_classes = 1000, 21 | dim_conv_stem = 64, 22 | dim = 96, 23 | dim_head = 32, 24 | depth = (2, 2, 5, 2), 25 | window_size = 7, 26 | mbconv_expansion_rate = 4, 27 | mbconv_shrinkage_rate = 0.25, 28 | dropout = 0.1 29 | ) 30 | 31 | model = RT1( 32 | vit = vit, 33 | num_actions = 11, 34 | depth = 6, 35 | heads = 8, 36 | dim_head = 64, 37 | cond_drop_prob = 0.2 38 | ) 39 | 40 | video = torch.randn(2, 3, 6, 224, 224) 41 | 42 | instructions = [ 43 | 'bring me that apple sitting on the table', 44 | 'please pass the butter' 45 | ] 46 | 47 | train_logits = model(video, instructions) # (2, 6, 11, 256) # (batch, frames, actions, bins) 48 | 49 | # after much training 50 | 51 | model.eval() 52 | eval_logits = model(video, instructions, cond_scale = 3.) 
# classifier free guidance with conditional scale of 3 53 | 54 | ``` 55 | 56 | ## Appreciation 57 | 58 | - Stability.ai for the generous sponsorship to work and open source cutting edge artificial intelligence research 59 | 60 | 61 | ## Todo 62 | 63 | - [x] add classifier free guidance option 64 | - [x] add cross attention based conditioning 65 | 66 | ## Citations 67 | 68 | ```bibtex 69 | @inproceedings{rt12022arxiv, 70 | title = {RT-1: Robotics Transformer for Real-World Control at Scale}, 71 | author = {Anthony Brohan and Noah Brown and Justice Carbajal and Yevgen Chebotar and Joseph Dabis and Chelsea Finn and Keerthana Gopalakrishnan and Karol Hausman and Alex Herzog and Jasmine Hsu and Julian Ibarz and Brian Ichter and Alex Irpan and Tomas Jackson and Sally Jesmonth and Nikhil Joshi and Ryan Julian and Dmitry Kalashnikov and Yuheng Kuang and Isabel Leal and Kuang-Huei Lee and Sergey Levine and Yao Lu and Utsav Malla and Deeksha Manjunath and Igor Mordatch and Ofir Nachum and Carolina Parada and Jodilyn Peralta and Emily Perez and Karl Pertsch and Jornell Quiambao and Kanishka Rao and Michael Ryoo and Grecia Salazar and Pannag Sanketi and Kevin Sayed and Jaspiar Singh and Sumedh Sontakke and Austin Stone and Clayton Tan and Huong Tran and Vincent Vanhoucke and Steve Vega and Quan Vuong and Fei Xia and Ted Xiao and Peng Xu and Sichun Xu and Tianhe Yu and Brianna Zitkovich}, 72 | booktitle = {arXiv preprint arXiv:2212.06817}, 73 | year = {2022} 74 | } 75 | ``` 76 | 77 | ```bibtex 78 | @inproceedings{Tu2022MaxViTMV, 79 | title = {MaxViT: Multi-Axis Vision Transformer}, 80 | author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li}, 81 | year = {2022} 82 | } 83 | ``` 84 | 85 | ```bibtex 86 | @misc{peebles2022scalable, 87 | title = {Scalable Diffusion Models with Transformers}, 88 | author = {William Peebles and Saining Xie}, 89 | year = {2022}, 90 | eprint = {2212.09748}, 91 | archivePrefix = {arXiv}, 92 | primaryClass = {cs.CV} 93 | } 94 | ``` 95 | -------------------------------------------------------------------------------- /robotic_transformer_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from robotic_transformer_pytorch.robotic_transformer_pytorch import RT1, TokenLearner, MaxViT 2 | -------------------------------------------------------------------------------- /robotic_transformer_pytorch/robotic_transformer_pytorch.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch 4 | from torch.nn import Module, ModuleList 5 | import torch.nn.functional as F 6 | from torch import nn, einsum, Tensor 7 | 8 | from typing import Callable 9 | from beartype import beartype 10 | 11 | from einops import pack, unpack, repeat, reduce, rearrange 12 | from einops.layers.torch import Rearrange, Reduce 13 | 14 | from functools import partial 15 | 16 | from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance 17 | 18 | # helpers 19 | 20 | def exists(val): 21 | return val is not None 22 | 23 | def default(val, d): 24 | return val if exists(val) else d 25 | 26 | def cast_tuple(val, length = 1): 27 | return val if isinstance(val, tuple) else ((val,) * length) 28 | 29 | def pack_one(x, pattern): 30 | return pack([x], pattern) 31 | 32 | def unpack_one(x, ps, pattern): 33 | return unpack(x, ps, pattern)[0] 34 | 35 | # sinusoidal positions
36 | 37 | def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32): 38 | n = torch.arange(seq, device = device) 39 | omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1) 40 | omega = 1. / (temperature ** omega) 41 | 42 | n = n[:, None] * omega[None, :] 43 | pos_emb = torch.cat((n.sin(), n.cos()), dim = 1) 44 | return pos_emb.type(dtype) 45 | 46 | # helper classes 47 | 48 | class Residual(Module): 49 | def __init__(self, fn): 50 | super().__init__() 51 | self.fn = fn 52 | 53 | def forward(self, x): 54 | return self.fn(x) + x 55 | 56 | class LayerNorm(Module): 57 | def __init__(self, dim): 58 | super().__init__() 59 | self.gamma = nn.Parameter(torch.ones(dim)) 60 | self.register_buffer("beta", torch.zeros(dim)) 61 | 62 | def forward(self, x): 63 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 64 | 65 | class FeedForward(Module): 66 | def __init__(self, dim, mult = 4, dropout = 0.): 67 | super().__init__() 68 | inner_dim = int(dim * mult) 69 | self.norm = LayerNorm(dim) 70 | 71 | self.net = nn.Sequential( 72 | nn.Linear(dim, inner_dim), 73 | nn.GELU(), 74 | nn.Dropout(dropout), 75 | nn.Linear(inner_dim, dim), 76 | nn.Dropout(dropout) 77 | ) 78 | def forward(self, x, cond_fn = None): 79 | x = self.norm(x) 80 | 81 | if exists(cond_fn): 82 | # adaptive layernorm 83 | x = cond_fn(x) 84 | 85 | return self.net(x) 86 | 87 | # MBConv 88 | 89 | class SqueezeExcitation(Module): 90 | def __init__(self, dim, shrinkage_rate = 0.25): 91 | super().__init__() 92 | hidden_dim = int(dim * shrinkage_rate) 93 | 94 | self.gate = nn.Sequential( 95 | Reduce('b c h w -> b c', 'mean'), 96 | nn.Linear(dim, hidden_dim, bias = False), 97 | nn.SiLU(), 98 | nn.Linear(hidden_dim, dim, bias = False), 99 | nn.Sigmoid(), 100 | Rearrange('b c -> b c 1 1') 101 | ) 102 | 103 | def forward(self, x): 104 | return x * self.gate(x) 105 | 106 | 107 | class MBConvResidual(Module): 108 | def __init__(self, fn, dropout = 0.): 109 | super().__init__() 110 | self.fn = fn 111 | self.dropsample = Dropsample(dropout) 112 | 113 | def forward(self, x): 114 | out = self.fn(x) 115 | out = self.dropsample(out) 116 | return out + x 117 | 118 | class Dropsample(Module): 119 | def __init__(self, prob = 0): 120 | super().__init__() 121 | self.prob = prob 122 | 123 | def forward(self, x): 124 | device = x.device 125 | 126 | if self.prob == 0. or (not self.training): 127 | return x 128 | 129 | keep_mask = torch.FloatTensor((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob 130 | return x * keep_mask / (1 - self.prob) 131 | 132 | def MBConv( 133 | dim_in, 134 | dim_out, 135 | *, 136 | downsample, 137 | expansion_rate = 4, 138 | shrinkage_rate = 0.25, 139 | dropout = 0. 
140 | ): 141 | hidden_dim = int(expansion_rate * dim_out) 142 | stride = 2 if downsample else 1 143 | 144 | net = nn.Sequential( 145 | nn.Conv2d(dim_in, hidden_dim, 1), 146 | nn.BatchNorm2d(hidden_dim), 147 | nn.GELU(), 148 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim), 149 | nn.BatchNorm2d(hidden_dim), 150 | nn.GELU(), 151 | SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate), 152 | nn.Conv2d(hidden_dim, dim_out, 1), 153 | nn.BatchNorm2d(dim_out) 154 | ) 155 | 156 | if dim_in == dim_out and not downsample: 157 | net = MBConvResidual(net, dropout = dropout) 158 | 159 | return net 160 | 161 | # attention related classes 162 | 163 | class Attention(Module): 164 | def __init__( 165 | self, 166 | dim, 167 | dim_head = 32, 168 | dropout = 0., 169 | window_size = 7, 170 | num_mem_kv = 4 171 | ): 172 | super().__init__() 173 | assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head' 174 | 175 | self.norm = LayerNorm(dim) 176 | 177 | self.heads = dim // dim_head 178 | self.scale = dim_head ** -0.5 179 | 180 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False) 181 | 182 | self.mem_kv = nn.Parameter(torch.randn(2, self.heads, num_mem_kv, dim_head)) 183 | 184 | self.attend = nn.Sequential( 185 | nn.Softmax(dim = -1), 186 | nn.Dropout(dropout) 187 | ) 188 | 189 | self.to_out = nn.Sequential( 190 | nn.Linear(dim, dim, bias = False), 191 | nn.Dropout(dropout) 192 | ) 193 | 194 | # relative positional bias 195 | 196 | self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads) 197 | 198 | pos = torch.arange(window_size) 199 | grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij')) 200 | grid = rearrange(grid, 'c i j -> (i j) c') 201 | rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...') 202 | rel_pos += window_size - 1 203 | rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1) 204 | 205 | self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False) 206 | 207 | def forward(self, x): 208 | batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads 209 | 210 | x = self.norm(x) 211 | 212 | # flatten 213 | 214 | x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d') 215 | 216 | # project for queries, keys, values 217 | 218 | q, k, v = self.to_qkv(x).chunk(3, dim = -1) 219 | 220 | # split heads 221 | 222 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v)) 223 | 224 | # scale 225 | 226 | q = q * self.scale 227 | 228 | # null / memory / register kv 229 | 230 | mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv) 231 | num_mem = mk.shape[-2] 232 | 233 | k = torch.cat((mk, k), dim = -2) 234 | v = torch.cat((mv, v), dim = -2) 235 | 236 | # sim 237 | 238 | sim = einsum('b h i d, b h j d -> b h i j', q, k) 239 | 240 | # add positional bias 241 | 242 | bias = self.rel_pos_bias(self.rel_pos_indices) 243 | 244 | bias = F.pad(bias, (0, 0, num_mem, 0), value = 0.) 245 | 246 | sim = sim + rearrange(bias, 'i j h -> h i j') 247 | 248 | # attention 249 | 250 | attn = self.attend(sim) 251 | 252 | # aggregate 253 | 254 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 255 | 256 | # merge heads 257 | 258 | out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width) 259 | 260 | # combine heads out 261 | 262 | out = self.to_out(out) 263 | return rearrange(out, '(b x y) ... 
-> b x y ...', x = height, y = width) 264 | 265 | class MaxViT(Module): 266 | def __init__( 267 | self, 268 | *, 269 | num_classes, 270 | dim, 271 | depth, 272 | dim_head = 32, 273 | dim_conv_stem = None, 274 | window_size = 7, 275 | mbconv_expansion_rate = 4, 276 | mbconv_shrinkage_rate = 0.25, 277 | dropout = 0.1, 278 | channels = 3 279 | ): 280 | super().__init__() 281 | assert isinstance(depth, tuple), 'depth needs to be tuple if integers indicating number of transformer blocks at that stage' 282 | 283 | # convolutional stem 284 | 285 | dim_conv_stem = default(dim_conv_stem, dim) 286 | 287 | self.conv_stem = nn.Sequential( 288 | nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1), 289 | nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1) 290 | ) 291 | 292 | # variables 293 | 294 | num_stages = len(depth) 295 | 296 | dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages))) 297 | dims = (dim_conv_stem, *dims) 298 | dim_pairs = tuple(zip(dims[:-1], dims[1:])) 299 | 300 | self.layers = ModuleList([]) 301 | 302 | # shorthand for window size for efficient block - grid like attention 303 | 304 | w = window_size 305 | 306 | # iterate through stages 307 | 308 | cond_hidden_dims = [] 309 | 310 | for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)): 311 | for stage_ind in range(layer_depth): 312 | is_first = stage_ind == 0 313 | stage_dim_in = layer_dim_in if is_first else layer_dim 314 | 315 | cond_hidden_dims.append(stage_dim_in) 316 | 317 | block = nn.Sequential( 318 | MBConv( 319 | stage_dim_in, 320 | layer_dim, 321 | downsample = is_first, 322 | expansion_rate = mbconv_expansion_rate, 323 | shrinkage_rate = mbconv_shrinkage_rate 324 | ), 325 | Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention 326 | Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)), 327 | Residual(FeedForward(dim = layer_dim, dropout = dropout)), 328 | Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'), 329 | 330 | Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention 331 | Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)), 332 | Residual(FeedForward(dim = layer_dim, dropout = dropout)), 333 | Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'), 334 | ) 335 | 336 | self.layers.append(block) 337 | 338 | embed_dim = dims[-1] 339 | self.embed_dim = dims[-1] 340 | 341 | self.cond_hidden_dims = cond_hidden_dims 342 | 343 | # mlp head out 344 | 345 | self.mlp_head = nn.Sequential( 346 | Reduce('b d h w -> b d', 'mean'), 347 | LayerNorm(embed_dim), 348 | nn.Linear(embed_dim, num_classes) 349 | ) 350 | 351 | @beartype 352 | def forward( 353 | self, 354 | x, 355 | texts: list[str] | None = None, 356 | cond_fns: tuple[Callable, ...] 
| None = None, 357 | cond_drop_prob = 0., 358 | return_embeddings = False 359 | ): 360 | x = self.conv_stem(x) 361 | 362 | cond_fns = iter(default(cond_fns, [])) 363 | 364 | for stage in self.layers: 365 | cond_fn = next(cond_fns, None) 366 | 367 | if exists(cond_fn): 368 | x = cond_fn(x) 369 | 370 | x = stage(x) 371 | 372 | if return_embeddings: 373 | return x 374 | 375 | return self.mlp_head(x) 376 | 377 | # attention 378 | 379 | class TransformerAttention(Module): 380 | def __init__( 381 | self, 382 | dim, 383 | causal = False, 384 | dim_head = 64, 385 | dim_context = None, 386 | heads = 8, 387 | norm_context = False, 388 | dropout = 0.1 389 | ): 390 | super().__init__() 391 | self.heads = heads 392 | self.scale = dim_head ** -0.5 393 | self.causal = causal 394 | inner_dim = dim_head * heads 395 | 396 | dim_context = default(dim_context, dim) 397 | 398 | self.norm = LayerNorm(dim) 399 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity() 400 | 401 | self.attn_dropout = nn.Dropout(dropout) 402 | 403 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 404 | self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False) 405 | self.to_out = nn.Sequential( 406 | nn.Linear(inner_dim, dim, bias = False), 407 | nn.Dropout(dropout) 408 | ) 409 | 410 | def forward( 411 | self, 412 | x, 413 | context = None, 414 | mask = None, 415 | attn_bias = None, 416 | attn_mask = None, 417 | cond_fn: Callable | None = None 418 | ): 419 | b = x.shape[0] 420 | 421 | if exists(context): 422 | context = self.context_norm(context) 423 | 424 | kv_input = default(context, x) 425 | 426 | x = self.norm(x) 427 | 428 | if exists(cond_fn): 429 | # adaptive layer-norm 430 | x = cond_fn(x) 431 | 432 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1) 433 | 434 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 435 | 436 | q = q * self.scale 437 | 438 | sim = einsum('b h i d, b j d -> b h i j', q, k) 439 | 440 | if exists(attn_bias): 441 | sim = sim + attn_bias 442 | 443 | if exists(attn_mask): 444 | sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max) 445 | 446 | if exists(mask): 447 | mask = rearrange(mask, 'b j -> b 1 1 j') 448 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 449 | 450 | if self.causal: 451 | i, j = sim.shape[-2:] 452 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1) 453 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 454 | 455 | attn = sim.softmax(dim = -1) 456 | attn = self.attn_dropout(attn) 457 | 458 | out = einsum('b h i j, b j d -> b h i d', attn, v) 459 | 460 | out = rearrange(out, 'b h n d -> b n (h d)') 461 | return self.to_out(out) 462 | 463 | class Transformer(Module): 464 | @beartype 465 | def __init__( 466 | self, 467 | dim, 468 | dim_head = 64, 469 | heads = 8, 470 | depth = 6, 471 | attn_dropout = 0., 472 | ff_dropout = 0. 473 | ): 474 | super().__init__() 475 | self.layers = ModuleList([]) 476 | for _ in range(depth): 477 | self.layers.append(ModuleList([ 478 | TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout), 479 | FeedForward(dim = dim, dropout = ff_dropout) 480 | ])) 481 | 482 | @beartype 483 | def forward( 484 | self, 485 | x, 486 | cond_fns: tuple[Callable, ...] 
| None = None, 487 | attn_mask = None 488 | ): 489 | cond_fns = iter(default(cond_fns, [])) 490 | 491 | for attn, ff in self.layers: 492 | x = attn(x, attn_mask = attn_mask, cond_fn = next(cond_fns, None)) + x 493 | x = ff(x, cond_fn = next(cond_fns, None)) + x 494 | return x 495 | 496 | # token learner module 497 | 498 | class TokenLearner(Module): 499 | """ 500 | https://arxiv.org/abs/2106.11297 501 | using the 1.1 version with the MLP (2 dense layers with gelu) for generating attention map 502 | """ 503 | 504 | def __init__( 505 | self, 506 | *, 507 | dim, 508 | ff_mult = 2, 509 | num_output_tokens = 8, 510 | num_layers = 2 511 | ): 512 | super().__init__() 513 | inner_dim = dim * ff_mult * num_output_tokens 514 | 515 | self.num_output_tokens = num_output_tokens 516 | self.net = nn.Sequential( 517 | nn.Conv2d(dim * num_output_tokens, inner_dim, 1, groups = num_output_tokens), 518 | nn.GELU(), 519 | nn.Conv2d(inner_dim, num_output_tokens, 1, groups = num_output_tokens), 520 | ) 521 | 522 | def forward(self, x): 523 | x, ps = pack_one(x, '* c h w') 524 | x = repeat(x, 'b c h w -> b (g c) h w', g = self.num_output_tokens) 525 | attn = self.net(x) 526 | 527 | attn = rearrange(attn, 'b g h w -> b 1 g h w') 528 | x = rearrange(x, 'b (g c) h w -> b c g h w', g = self.num_output_tokens) 529 | 530 | x = reduce(x * attn, 'b c g h w -> b c g', 'mean') 531 | x = unpack_one(x, ps, '* c n') 532 | return x 533 | 534 | # Robotic Transformer 535 | 536 | class RT1(Module): 537 | @beartype 538 | def __init__( 539 | self, 540 | *, 541 | vit: MaxViT, 542 | num_actions = 11, 543 | action_bins = 256, 544 | depth = 6, 545 | heads = 8, 546 | dim_head = 64, 547 | token_learner_ff_mult = 2, 548 | token_learner_num_layers = 2, 549 | token_learner_num_output_tokens = 8, 550 | cond_drop_prob = 0.2, 551 | use_attn_conditioner = False, 552 | conditioner_kwargs: dict = dict() 553 | ): 554 | super().__init__() 555 | self.vit = vit 556 | 557 | self.num_vit_stages = len(vit.cond_hidden_dims) 558 | 559 | conditioner_klass = AttentionTextConditioner if use_attn_conditioner else TextConditioner 560 | 561 | self.conditioner = conditioner_klass( 562 | hidden_dims = (*tuple(vit.cond_hidden_dims), *((vit.embed_dim,) * depth * 2)), 563 | hiddens_channel_first = (*((True,) * self.num_vit_stages), *((False,) * depth * 2)), 564 | cond_drop_prob = cond_drop_prob, 565 | **conditioner_kwargs 566 | ) 567 | 568 | self.token_learner = TokenLearner( 569 | dim = vit.embed_dim, 570 | ff_mult = token_learner_ff_mult, 571 | num_output_tokens = token_learner_num_output_tokens, 572 | num_layers = token_learner_num_layers 573 | ) 574 | 575 | self.num_learned_tokens = token_learner_num_output_tokens 576 | 577 | self.transformer_depth = depth 578 | 579 | self.transformer = Transformer( 580 | dim = vit.embed_dim, 581 | dim_head = dim_head, 582 | heads = heads, 583 | depth = depth 584 | ) 585 | 586 | self.cond_drop_prob = cond_drop_prob 587 | 588 | self.to_logits = nn.Sequential( 589 | LayerNorm(vit.embed_dim), 590 | nn.Linear(vit.embed_dim, num_actions * action_bins), 591 | Rearrange('... (a b) -> ... a b', b = action_bins) 592 | ) 593 | 594 | @beartype 595 | def embed_texts(self, texts: list[str]): 596 | return self.conditioner.embed_texts(texts) 597 | 598 | @classifier_free_guidance 599 | @beartype 600 | def forward( 601 | self, 602 | video, 603 | texts: list[str] | None = None, 604 | text_embeds: Tensor | None = None, 605 | cond_drop_prob = 0. 
606 | ): 607 | assert exists(texts) ^ exists(text_embeds) 608 | 609 | if exists(texts): 610 | num_texts = len(texts) 611 | elif exists(text_embeds): 612 | num_texts = text_embeds.shape[0] 613 | 614 | assert num_texts == video.shape[0], f'you only passed in {num_texts} strings for guiding the robot actions, but received batch size of {video.shape[0]} videos' 615 | 616 | cond_kwargs = dict(texts = texts, text_embeds = text_embeds) 617 | 618 | depth = self.transformer_depth 619 | cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob) 620 | 621 | frames, device = video.shape[2], video.device 622 | 623 | cond_fns, _ = self.conditioner( 624 | **cond_kwargs, 625 | cond_drop_prob = cond_drop_prob, 626 | repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2)) 627 | ) 628 | 629 | vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):] 630 | 631 | video = rearrange(video, 'b c f h w -> b f c h w') 632 | images, packed_shape = pack_one(video, '* c h w') 633 | 634 | tokens = self.vit( 635 | images, 636 | texts = texts, 637 | cond_fns = vit_cond_fns, 638 | cond_drop_prob = cond_drop_prob, 639 | return_embeddings = True 640 | ) 641 | 642 | tokens = unpack_one(tokens, packed_shape, '* c h w') 643 | learned_tokens = self.token_learner(tokens) 644 | 645 | learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c') 646 | 647 | # causal attention mask 648 | 649 | attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1) 650 | attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens) 651 | 652 | # sinusoidal positional embedding 653 | 654 | pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device) 655 | 656 | learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens) 657 | 658 | # attention 659 | 660 | attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask) 661 | 662 | pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames) 663 | 664 | logits = self.to_logits(pooled) 665 | return logits 666 | -------------------------------------------------------------------------------- /rt1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/robotic-transformer-pytorch/1512c9b460944accb2d874e4f3354a8e196f50e8/rt1.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'robotic-transformer-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.2.3', 7 | license='MIT', 8 | description = 'Robotic Transformer - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/robotic-transformer-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'robotics' 19 | ], 20 | install_requires=[ 21 | 'classifier-free-guidance-pytorch>=0.7.1', 22 | 'einops>=0.8', 23 | 'torch>=2.0', 24 | ], 25 | classifiers=[ 26 | 'Development Status :: 4 - Beta', 27 | 'Intended Audience :: Developers', 28 | 'Topic :: Scientific/Engineering :: 
Artificial Intelligence', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Programming Language :: Python :: 3.6', 31 | ], 32 | ) 33 | --------------------------------------------------------------------------------
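For reference, a minimal sketch of turning the `(batch, frames, actions, bins)` logits from the README example into discrete and continuous actions. This helper is not part of the repository: greedy argmax over the 256 bins and a uniform bin-to-value mapping over an assumed `[-1, 1]` action range are illustrative choices only.

```python
import torch

def decode_actions(logits, low = -1., high = 1.):
    # logits: (batch, frames, num_actions, action_bins), e.g. (2, 6, 11, 256)
    action_bins = logits.shape[-1]
    bin_indices = logits.argmax(dim = -1)       # greedy choice of bin per action dimension
    bin_width = (high - low) / action_bins
    # map each bin index to the center of its bin in the assumed [low, high] range
    return low + (bin_indices.float() + 0.5) * bin_width

eval_logits = torch.randn(2, 6, 11, 256)        # stand-in for the model output above
actions = decode_actions(eval_logits)           # (2, 6, 11), values in (-1, 1)
```

For training, the usual counterpart to such a discretized action head is a cross entropy loss between the per-dimension logits and the ground-truth bin indices.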