├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── robotic_transformer_pytorch
│   ├── __init__.py
│   └── robotic_transformer_pytorch.py
├── rt1.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Robotic Transformer - Pytorch
4 |
5 | Implementation of RT1 (Robotic Transformer), from the Robotics at Google team, in Pytorch
6 |
7 | ## Install
8 |
9 | ```bash
10 | $ pip install robotic-transformer-pytorch
11 | ```
12 |
13 | ## Usage
14 |
15 | ```python
16 | import torch
17 | from robotic_transformer_pytorch import MaxViT, RT1
18 |
19 | vit = MaxViT(
20 | num_classes = 1000,
21 | dim_conv_stem = 64,
22 | dim = 96,
23 | dim_head = 32,
24 | depth = (2, 2, 5, 2),
25 | window_size = 7,
26 | mbconv_expansion_rate = 4,
27 | mbconv_shrinkage_rate = 0.25,
28 | dropout = 0.1
29 | )
30 |
31 | model = RT1(
32 | vit = vit,
33 | num_actions = 11,
34 | depth = 6,
35 | heads = 8,
36 | dim_head = 64,
37 | cond_drop_prob = 0.2
38 | )
39 |
40 | video = torch.randn(2, 3, 6, 224, 224) # (batch, channels, frames, height, width)
41 |
42 | instructions = [
43 | 'bring me that apple sitting on the table',
44 | 'please pass the butter'
45 | ]
46 |
47 | train_logits = model(video, instructions) # (2, 6, 11, 256) # (batch, frames, actions, bins)
48 |
49 | # after much training
50 |
51 | model.eval()
52 | eval_logits = model(video, instructions, cond_scale = 3.) # classifier free guidance with conditional scale of 3
53 |
54 | ```
55 |
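As a rough editorial sketch (not part of the original README), the `(batch, frames, actions, bins)` logits can be trained with a cross entropy loss over the 256 action bins and decoded back to discrete bin indices at inference; the `targets` tensor below is a hypothetical placeholder:

```python
import torch
import torch.nn.functional as F

# hypothetical discretized action targets: one bin index in [0, 256) per action dimension
targets = torch.randint(0, 256, (2, 6, 11))   # (batch, frames, actions)

loss = F.cross_entropy(
    train_logits.permute(0, 3, 1, 2),         # move bins into the class dimension: (batch, bins, frames, actions)
    targets
)
loss.backward()

# at inference, take the most likely bin per action dimension
actions = eval_logits.argmax(dim = -1)        # (batch, frames, actions)
```
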
56 | ## Appreciation
57 |
58 | - Stability.ai for the generous sponsorship to work on and open source cutting-edge artificial intelligence research
59 |
60 |
61 | ## Todo
62 |
63 | - [x] add classifier free guidance option
64 | - [x] add cross attention based conditioning
65 |
66 | ## Citations
67 |
68 | ```bibtex
69 | @inproceedings{rt12022arxiv,
70 | title = {RT-1: Robotics Transformer for Real-World Control at Scale},
71 | author = {Anthony Brohan and Noah Brown and Justice Carbajal and Yevgen Chebotar and Joseph Dabis and Chelsea Finn and Keerthana Gopalakrishnan and Karol Hausman and Alex Herzog and Jasmine Hsu and Julian Ibarz and Brian Ichter and Alex Irpan and Tomas Jackson and Sally Jesmonth and Nikhil Joshi and Ryan Julian and Dmitry Kalashnikov and Yuheng Kuang and Isabel Leal and Kuang-Huei Lee and Sergey Levine and Yao Lu and Utsav Malla and Deeksha Manjunath and Igor Mordatch and Ofir Nachum and Carolina Parada and Jodilyn Peralta and Emily Perez and Karl Pertsch and Jornell Quiambao and Kanishka Rao and Michael Ryoo and Grecia Salazar and Pannag Sanketi and Kevin Sayed and Jaspiar Singh and Sumedh Sontakke and Austin Stone and Clayton Tan and Huong Tran and Vincent Vanhoucke and Steve Vega and Quan Vuong and Fei Xia and Ted Xiao and Peng Xu and Sichun Xu and Tianhe Yu and Brianna Zitkovich},
72 |     booktitle = {arXiv preprint arXiv:2212.06817},
73 | year = {2022}
74 | }
75 | ```
76 |
77 | ```bibtex
78 | @inproceedings{Tu2022MaxViTMV,
79 | title = {MaxViT: Multi-Axis Vision Transformer},
80 | author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
81 | year = {2022}
82 | }
83 | ```
84 |
85 | ```bibtex
86 | @misc{peebles2022scalable,
87 | title = {Scalable Diffusion Models with Transformers},
88 | author = {William Peebles and Saining Xie},
89 | year = {2022},
90 | eprint = {2212.09748},
91 | archivePrefix = {arXiv},
92 | primaryClass = {cs.CV}
93 | }
94 | ```
95 |
--------------------------------------------------------------------------------
/robotic_transformer_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from robotic_transformer_pytorch.robotic_transformer_pytorch import RT1, TokenLearner, MaxViT
2 |
--------------------------------------------------------------------------------
/robotic_transformer_pytorch/robotic_transformer_pytorch.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import torch
4 | from torch.nn import Module, ModuleList
5 | import torch.nn.functional as F
6 | from torch import nn, einsum, Tensor
7 |
8 | from typing import Callable
9 | from beartype import beartype
10 |
11 | from einops import pack, unpack, repeat, reduce, rearrange
12 | from einops.layers.torch import Rearrange, Reduce
13 |
14 | from functools import partial
15 |
16 | from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance
17 |
18 | # helpers
19 |
20 | def exists(val):
21 | return val is not None
22 |
23 | def default(val, d):
24 | return val if exists(val) else d
25 |
26 | def cast_tuple(val, length = 1):
27 | return val if isinstance(val, tuple) else ((val,) * length)
28 |
29 | def pack_one(x, pattern):
30 | return pack([x], pattern)
31 |
32 | def unpack_one(x, ps, pattern):
33 | return unpack(x, ps, pattern)[0]
34 |
35 | # sinusoidal positions
36 |
37 | def posemb_sincos_1d(seq, dim, temperature = 10000, device = None, dtype = torch.float32):
38 | n = torch.arange(seq, device = device)
39 | omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
40 | omega = 1. / (temperature ** omega)
41 |
42 | n = n[:, None] * omega[None, :]
43 | pos_emb = torch.cat((n.sin(), n.cos()), dim = 1)
44 | return pos_emb.type(dtype)
45 |
46 | # helper classes
47 |
48 | class Residual(Module):
49 | def __init__(self, fn):
50 | super().__init__()
51 | self.fn = fn
52 |
53 | def forward(self, x):
54 | return self.fn(x) + x
55 |
56 | class LayerNorm(Module):
57 | def __init__(self, dim):
58 | super().__init__()
59 | self.gamma = nn.Parameter(torch.ones(dim))
60 | self.register_buffer("beta", torch.zeros(dim))
61 |
62 | def forward(self, x):
63 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
64 |
65 | class FeedForward(Module):
66 | def __init__(self, dim, mult = 4, dropout = 0.):
67 | super().__init__()
68 | inner_dim = int(dim * mult)
69 | self.norm = LayerNorm(dim)
70 |
71 | self.net = nn.Sequential(
72 | nn.Linear(dim, inner_dim),
73 | nn.GELU(),
74 | nn.Dropout(dropout),
75 | nn.Linear(inner_dim, dim),
76 | nn.Dropout(dropout)
77 | )
78 | def forward(self, x, cond_fn = None):
79 | x = self.norm(x)
80 |
81 | if exists(cond_fn):
82 | # adaptive layernorm
83 | x = cond_fn(x)
84 |
85 | return self.net(x)
86 |
87 | # MBConv
88 |
89 | class SqueezeExcitation(Module):
90 | def __init__(self, dim, shrinkage_rate = 0.25):
91 | super().__init__()
92 | hidden_dim = int(dim * shrinkage_rate)
93 |
94 | self.gate = nn.Sequential(
95 | Reduce('b c h w -> b c', 'mean'),
96 | nn.Linear(dim, hidden_dim, bias = False),
97 | nn.SiLU(),
98 | nn.Linear(hidden_dim, dim, bias = False),
99 | nn.Sigmoid(),
100 | Rearrange('b c -> b c 1 1')
101 | )
102 |
103 | def forward(self, x):
104 | return x * self.gate(x)
105 |
106 |
107 | class MBConvResidual(Module):
108 | def __init__(self, fn, dropout = 0.):
109 | super().__init__()
110 | self.fn = fn
111 | self.dropsample = Dropsample(dropout)
112 |
113 | def forward(self, x):
114 | out = self.fn(x)
115 | out = self.dropsample(out)
116 | return out + x
117 |
118 | class Dropsample(Module):
119 | def __init__(self, prob = 0):
120 | super().__init__()
121 | self.prob = prob
122 |
123 | def forward(self, x):
124 | device = x.device
125 |
126 | if self.prob == 0. or (not self.training):
127 | return x
128 |
129 |         keep_mask = torch.empty((x.shape[0], 1, 1, 1), device = device).uniform_() > self.prob
130 | return x * keep_mask / (1 - self.prob)
131 |
132 | def MBConv(
133 | dim_in,
134 | dim_out,
135 | *,
136 | downsample,
137 | expansion_rate = 4,
138 | shrinkage_rate = 0.25,
139 | dropout = 0.
140 | ):
141 | hidden_dim = int(expansion_rate * dim_out)
142 | stride = 2 if downsample else 1
143 |
144 | net = nn.Sequential(
145 | nn.Conv2d(dim_in, hidden_dim, 1),
146 | nn.BatchNorm2d(hidden_dim),
147 | nn.GELU(),
148 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = hidden_dim),
149 | nn.BatchNorm2d(hidden_dim),
150 | nn.GELU(),
151 | SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
152 | nn.Conv2d(hidden_dim, dim_out, 1),
153 | nn.BatchNorm2d(dim_out)
154 | )
155 |
156 | if dim_in == dim_out and not downsample:
157 | net = MBConvResidual(net, dropout = dropout)
158 |
159 | return net
160 |
161 | # attention related classes
162 |
163 | class Attention(Module):
164 | def __init__(
165 | self,
166 | dim,
167 | dim_head = 32,
168 | dropout = 0.,
169 | window_size = 7,
170 | num_mem_kv = 4
171 | ):
172 | super().__init__()
173 | assert (dim % dim_head) == 0, 'dimension should be divisible by dimension per head'
174 |
175 | self.norm = LayerNorm(dim)
176 |
177 | self.heads = dim // dim_head
178 | self.scale = dim_head ** -0.5
179 |
180 | self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
181 |
182 | self.mem_kv = nn.Parameter(torch.randn(2, self.heads, num_mem_kv, dim_head))
183 |
184 | self.attend = nn.Sequential(
185 | nn.Softmax(dim = -1),
186 | nn.Dropout(dropout)
187 | )
188 |
189 | self.to_out = nn.Sequential(
190 | nn.Linear(dim, dim, bias = False),
191 | nn.Dropout(dropout)
192 | )
193 |
194 | # relative positional bias
195 |
196 | self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
197 |
198 | pos = torch.arange(window_size)
199 | grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
200 | grid = rearrange(grid, 'c i j -> (i j) c')
201 | rel_pos = rearrange(grid, 'i ... -> i 1 ...') - rearrange(grid, 'j ... -> 1 j ...')
202 | rel_pos += window_size - 1
203 | rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(dim = -1)
204 |
205 | self.register_buffer('rel_pos_indices', rel_pos_indices, persistent = False)
206 |
207 | def forward(self, x):
208 | batch, height, width, window_height, window_width, _, device, h = *x.shape, x.device, self.heads
209 |
210 | x = self.norm(x)
211 |
212 | # flatten
213 |
214 | x = rearrange(x, 'b x y w1 w2 d -> (b x y) (w1 w2) d')
215 |
216 | # project for queries, keys, values
217 |
218 | q, k, v = self.to_qkv(x).chunk(3, dim = -1)
219 |
220 | # split heads
221 |
222 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
223 |
224 | # scale
225 |
226 | q = q * self.scale
227 |
228 | # null / memory / register kv
229 |
230 | mk, mv = map(lambda t: repeat(t, 'h n d -> b h n d', b = q.shape[0]), self.mem_kv)
231 | num_mem = mk.shape[-2]
232 |
233 | k = torch.cat((mk, k), dim = -2)
234 | v = torch.cat((mv, v), dim = -2)
235 |
236 | # sim
237 |
238 | sim = einsum('b h i d, b h j d -> b h i j', q, k)
239 |
240 | # add positional bias
241 |
242 | bias = self.rel_pos_bias(self.rel_pos_indices)
243 |
244 | bias = F.pad(bias, (0, 0, num_mem, 0), value = 0.)
245 |
246 | sim = sim + rearrange(bias, 'i j h -> h i j')
247 |
248 | # attention
249 |
250 | attn = self.attend(sim)
251 |
252 | # aggregate
253 |
254 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
255 |
256 | # merge heads
257 |
258 | out = rearrange(out, 'b h (w1 w2) d -> b w1 w2 (h d)', w1 = window_height, w2 = window_width)
259 |
260 | # combine heads out
261 |
262 | out = self.to_out(out)
263 | return rearrange(out, '(b x y) ... -> b x y ...', x = height, y = width)
264 |
265 | class MaxViT(Module):
266 | def __init__(
267 | self,
268 | *,
269 | num_classes,
270 | dim,
271 | depth,
272 | dim_head = 32,
273 | dim_conv_stem = None,
274 | window_size = 7,
275 | mbconv_expansion_rate = 4,
276 | mbconv_shrinkage_rate = 0.25,
277 | dropout = 0.1,
278 | channels = 3
279 | ):
280 | super().__init__()
281 |         assert isinstance(depth, tuple), 'depth needs to be a tuple of integers indicating the number of transformer blocks at each stage'
282 |
283 | # convolutional stem
284 |
285 | dim_conv_stem = default(dim_conv_stem, dim)
286 |
287 | self.conv_stem = nn.Sequential(
288 | nn.Conv2d(channels, dim_conv_stem, 3, stride = 2, padding = 1),
289 | nn.Conv2d(dim_conv_stem, dim_conv_stem, 3, padding = 1)
290 | )
291 |
292 | # variables
293 |
294 | num_stages = len(depth)
295 |
296 | dims = tuple(map(lambda i: (2 ** i) * dim, range(num_stages)))
297 | dims = (dim_conv_stem, *dims)
298 | dim_pairs = tuple(zip(dims[:-1], dims[1:]))
299 |
300 | self.layers = ModuleList([])
301 |
302 | # shorthand for window size for efficient block - grid like attention
303 |
304 | w = window_size
305 |
306 | # iterate through stages
307 |
308 | cond_hidden_dims = []
309 |
310 | for ind, ((layer_dim_in, layer_dim), layer_depth) in enumerate(zip(dim_pairs, depth)):
311 | for stage_ind in range(layer_depth):
312 | is_first = stage_ind == 0
313 | stage_dim_in = layer_dim_in if is_first else layer_dim
314 |
315 | cond_hidden_dims.append(stage_dim_in)
316 |
317 | block = nn.Sequential(
318 | MBConv(
319 | stage_dim_in,
320 | layer_dim,
321 | downsample = is_first,
322 | expansion_rate = mbconv_expansion_rate,
323 | shrinkage_rate = mbconv_shrinkage_rate
324 | ),
325 | Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w), # block-like attention
326 | Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
327 | Residual(FeedForward(dim = layer_dim, dropout = dropout)),
328 | Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
329 |
330 | Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w), # grid-like attention
331 | Residual(Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
332 | Residual(FeedForward(dim = layer_dim, dropout = dropout)),
333 | Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
334 | )
335 |
336 | self.layers.append(block)
337 |
338 | embed_dim = dims[-1]
339 | self.embed_dim = dims[-1]
340 |
341 | self.cond_hidden_dims = cond_hidden_dims
342 |
343 | # mlp head out
344 |
345 | self.mlp_head = nn.Sequential(
346 | Reduce('b d h w -> b d', 'mean'),
347 | LayerNorm(embed_dim),
348 | nn.Linear(embed_dim, num_classes)
349 | )
350 |
351 | @beartype
352 | def forward(
353 | self,
354 | x,
355 | texts: list[str] | None = None,
356 | cond_fns: tuple[Callable, ...] | None = None,
357 | cond_drop_prob = 0.,
358 | return_embeddings = False
359 | ):
360 | x = self.conv_stem(x)
361 |
362 | cond_fns = iter(default(cond_fns, []))
363 |
364 | for stage in self.layers:
365 | cond_fn = next(cond_fns, None)
366 |
367 | if exists(cond_fn):
368 | x = cond_fn(x)
369 |
370 | x = stage(x)
371 |
372 | if return_embeddings:
373 | return x
374 |
375 | return self.mlp_head(x)
376 |
377 | # attention
378 |
379 | class TransformerAttention(Module):
380 | def __init__(
381 | self,
382 | dim,
383 | causal = False,
384 | dim_head = 64,
385 | dim_context = None,
386 | heads = 8,
387 | norm_context = False,
388 | dropout = 0.1
389 | ):
390 | super().__init__()
391 | self.heads = heads
392 | self.scale = dim_head ** -0.5
393 | self.causal = causal
394 | inner_dim = dim_head * heads
395 |
396 | dim_context = default(dim_context, dim)
397 |
398 | self.norm = LayerNorm(dim)
399 | self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()
400 |
401 | self.attn_dropout = nn.Dropout(dropout)
402 |
403 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
404 | self.to_kv = nn.Linear(dim_context, dim_head * 2, bias = False)
405 | self.to_out = nn.Sequential(
406 | nn.Linear(inner_dim, dim, bias = False),
407 | nn.Dropout(dropout)
408 | )
409 |
410 | def forward(
411 | self,
412 | x,
413 | context = None,
414 | mask = None,
415 | attn_bias = None,
416 | attn_mask = None,
417 | cond_fn: Callable | None = None
418 | ):
419 | b = x.shape[0]
420 |
421 | if exists(context):
422 | context = self.context_norm(context)
423 |
424 | kv_input = default(context, x)
425 |
426 | x = self.norm(x)
427 |
428 | if exists(cond_fn):
429 | # adaptive layer-norm
430 | x = cond_fn(x)
431 |
432 | q, k, v = self.to_q(x), *self.to_kv(kv_input).chunk(2, dim = -1)
433 |
434 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
435 |
436 | q = q * self.scale
437 |
438 | sim = einsum('b h i d, b j d -> b h i j', q, k)
439 |
440 | if exists(attn_bias):
441 | sim = sim + attn_bias
442 |
443 | if exists(attn_mask):
444 | sim = sim.masked_fill(~attn_mask, -torch.finfo(sim.dtype).max)
445 |
446 | if exists(mask):
447 | mask = rearrange(mask, 'b j -> b 1 1 j')
448 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
449 |
450 | if self.causal:
451 | i, j = sim.shape[-2:]
452 | causal_mask = torch.ones((i, j), dtype = torch.bool, device = x.device).triu(j - i + 1)
453 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
454 |
455 | attn = sim.softmax(dim = -1)
456 | attn = self.attn_dropout(attn)
457 |
458 | out = einsum('b h i j, b j d -> b h i d', attn, v)
459 |
460 | out = rearrange(out, 'b h n d -> b n (h d)')
461 | return self.to_out(out)
462 |
463 | class Transformer(Module):
464 | @beartype
465 | def __init__(
466 | self,
467 | dim,
468 | dim_head = 64,
469 | heads = 8,
470 | depth = 6,
471 | attn_dropout = 0.,
472 | ff_dropout = 0.
473 | ):
474 | super().__init__()
475 | self.layers = ModuleList([])
476 | for _ in range(depth):
477 | self.layers.append(ModuleList([
478 | TransformerAttention(dim = dim, heads = heads, dropout = attn_dropout),
479 | FeedForward(dim = dim, dropout = ff_dropout)
480 | ]))
481 |
482 | @beartype
483 | def forward(
484 | self,
485 | x,
486 | cond_fns: tuple[Callable, ...] | None = None,
487 | attn_mask = None
488 | ):
489 | cond_fns = iter(default(cond_fns, []))
490 |
491 | for attn, ff in self.layers:
492 | x = attn(x, attn_mask = attn_mask, cond_fn = next(cond_fns, None)) + x
493 | x = ff(x, cond_fn = next(cond_fns, None)) + x
494 | return x
495 |
496 | # token learner module
497 |
498 | class TokenLearner(Module):
499 | """
500 | https://arxiv.org/abs/2106.11297
501 |     using the v1.1 variant with the MLP (2 dense layers with GELU) for generating the attention maps
502 | """
503 |
504 | def __init__(
505 | self,
506 | *,
507 | dim,
508 | ff_mult = 2,
509 | num_output_tokens = 8,
510 | num_layers = 2
511 | ):
512 | super().__init__()
513 | inner_dim = dim * ff_mult * num_output_tokens
514 |
515 | self.num_output_tokens = num_output_tokens
516 | self.net = nn.Sequential(
517 | nn.Conv2d(dim * num_output_tokens, inner_dim, 1, groups = num_output_tokens),
518 | nn.GELU(),
519 | nn.Conv2d(inner_dim, num_output_tokens, 1, groups = num_output_tokens),
520 | )
521 |
522 | def forward(self, x):
523 | x, ps = pack_one(x, '* c h w')
524 | x = repeat(x, 'b c h w -> b (g c) h w', g = self.num_output_tokens)
525 | attn = self.net(x)
526 |
527 | attn = rearrange(attn, 'b g h w -> b 1 g h w')
528 | x = rearrange(x, 'b (g c) h w -> b c g h w', g = self.num_output_tokens)
529 |
530 | x = reduce(x * attn, 'b c g h w -> b c g', 'mean')
531 | x = unpack_one(x, ps, '* c n')
532 | return x
533 |
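# shape sketch (editorial note, not part of the original code):
#
#   tokenizer = TokenLearner(dim = 768, num_output_tokens = 8)
#   tokens = tokenizer(torch.randn(2, 6, 768, 7, 7))   # -> (2, 6, 768, 8)
#
# i.e. each 7x7 per-frame feature map coming out of the ViT is distilled into
# 8 learned tokens before being passed to the temporal transformer in RT1 below
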
534 | # Robotic Transformer
535 |
536 | class RT1(Module):
537 | @beartype
538 | def __init__(
539 | self,
540 | *,
541 | vit: MaxViT,
542 | num_actions = 11,
543 | action_bins = 256,
544 | depth = 6,
545 | heads = 8,
546 | dim_head = 64,
547 | token_learner_ff_mult = 2,
548 | token_learner_num_layers = 2,
549 | token_learner_num_output_tokens = 8,
550 | cond_drop_prob = 0.2,
551 | use_attn_conditioner = False,
552 | conditioner_kwargs: dict = dict()
553 | ):
554 | super().__init__()
555 | self.vit = vit
556 |
557 | self.num_vit_stages = len(vit.cond_hidden_dims)
558 |
559 | conditioner_klass = AttentionTextConditioner if use_attn_conditioner else TextConditioner
560 |
561 | self.conditioner = conditioner_klass(
562 | hidden_dims = (*tuple(vit.cond_hidden_dims), *((vit.embed_dim,) * depth * 2)),
563 | hiddens_channel_first = (*((True,) * self.num_vit_stages), *((False,) * depth * 2)),
564 | cond_drop_prob = cond_drop_prob,
565 | **conditioner_kwargs
566 | )
567 |
568 | self.token_learner = TokenLearner(
569 | dim = vit.embed_dim,
570 | ff_mult = token_learner_ff_mult,
571 | num_output_tokens = token_learner_num_output_tokens,
572 | num_layers = token_learner_num_layers
573 | )
574 |
575 | self.num_learned_tokens = token_learner_num_output_tokens
576 |
577 | self.transformer_depth = depth
578 |
579 | self.transformer = Transformer(
580 | dim = vit.embed_dim,
581 | dim_head = dim_head,
582 | heads = heads,
583 | depth = depth
584 | )
585 |
586 | self.cond_drop_prob = cond_drop_prob
587 |
588 | self.to_logits = nn.Sequential(
589 | LayerNorm(vit.embed_dim),
590 | nn.Linear(vit.embed_dim, num_actions * action_bins),
591 | Rearrange('... (a b) -> ... a b', b = action_bins)
592 | )
593 |
594 | @beartype
595 | def embed_texts(self, texts: list[str]):
596 | return self.conditioner.embed_texts(texts)
597 |
598 | @classifier_free_guidance
599 | @beartype
600 | def forward(
601 | self,
602 | video,
603 | texts: list[str] | None = None,
604 | text_embeds: Tensor | None = None,
605 | cond_drop_prob = 0.
606 | ):
607 | assert exists(texts) ^ exists(text_embeds)
608 |
609 | if exists(texts):
610 | num_texts = len(texts)
611 | elif exists(text_embeds):
612 | num_texts = text_embeds.shape[0]
613 |
614 | assert num_texts == video.shape[0], f'you only passed in {num_texts} strings for guiding the robot actions, but received batch size of {video.shape[0]} videos'
615 |
616 | cond_kwargs = dict(texts = texts, text_embeds = text_embeds)
617 |
618 | depth = self.transformer_depth
619 | cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
620 |
621 | frames, device = video.shape[2], video.device
622 |
623 | cond_fns, _ = self.conditioner(
624 | **cond_kwargs,
625 | cond_drop_prob = cond_drop_prob,
626 | repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
627 | )
628 |
629 | vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):]
630 |
631 | video = rearrange(video, 'b c f h w -> b f c h w')
632 | images, packed_shape = pack_one(video, '* c h w')
633 |
634 | tokens = self.vit(
635 | images,
636 | texts = texts,
637 | cond_fns = vit_cond_fns,
638 | cond_drop_prob = cond_drop_prob,
639 | return_embeddings = True
640 | )
641 |
642 | tokens = unpack_one(tokens, packed_shape, '* c h w')
643 | learned_tokens = self.token_learner(tokens)
644 |
645 | learned_tokens = rearrange(learned_tokens, 'b f c n -> b (f n) c')
646 |
647 | # causal attention mask
648 |
649 | attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
650 | attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)
651 |
652 | # sinusoidal positional embedding
653 |
654 | pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
655 |
656 | learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)
657 |
658 | # attention
659 |
660 | attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)
661 |
662 | pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
663 |
664 | logits = self.to_logits(pooled)
665 | return logits
666 |
--------------------------------------------------------------------------------
/rt1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/robotic-transformer-pytorch/1512c9b460944accb2d874e4f3354a8e196f50e8/rt1.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'robotic-transformer-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.2.3',
7 | license='MIT',
8 | description = 'Robotic Transformer - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/robotic-transformer-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'robotics'
19 | ],
20 | install_requires=[
21 | 'classifier-free-guidance-pytorch>=0.7.1',
22 | 'einops>=0.8',
23 | 'torch>=2.0',
24 | ],
25 | classifiers=[
26 | 'Development Status :: 4 - Beta',
27 | 'Intended Audience :: Developers',
28 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
29 | 'License :: OSI Approved :: MIT License',
30 | 'Programming Language :: Python :: 3.6',
31 | ],
32 | )
33 |
--------------------------------------------------------------------------------