├── .gitignore
├── LICENSE
├── README.md
├── di
│   ├── frame.py
│   ├── quantizer.py
│   └── transformer.py
├── dreamer_model.py
├── iris.py
├── model.py
├── ruff.toml
├── test_compare_delta_iris.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.ckpt
2 | *_jaxpr*
3 | *.pt
4 | *.pyc
5 | delta-iris
6 | outputs
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024, George Hotz
2 | 
3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
4 | 
5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
6 | 
7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DreamerV3
2 | 
3 | https://arxiv.org/abs/2301.04104
4 | 
5 | Want to do well on Atari100k (`pip install gym[atari] autorom[accept-rom-license]`), though BSuite (`pip install bsuite`) looks interesting too.
6 | 
7 | This is designed to run on a tinybox, either red or green, with just `./train.py`.
8 | 
9 | ## Process
10 | 
11 | 1. Run https://github.com/danijar/dreamerv3 to train a model that plays Pong
12 | 2. Get that model loaded into tinygrad and running, both the policy model and decoder
13 | 3. Get fine tuning working
14 | 4. Get full training working
15 | 
16 | ## delta-iris
17 | 
18 | It might be a better choice; the repo is a lot easier to read. https://github.com/vmicheli/delta-iris
19 | 
20 | Three models:
21 | * actor_critic (two copies, model and target_model)
22 | * world_model
23 |   * transformer takes in (frames_emb x1, act_tokens_emb x1, latents_emb x4) x many
24 |   * frame_cnn (FrameEncoder), output 4 channels
25 | * tokenizer
26 |   * frame_cnn (FrameEncoder), output 16 channels
27 |   * encoder takes 7 input channels: 3 for prev_frame, 1 for action, and 3 for frame (FrameEncoder), output 64 channels for the quantizer
28 |   * decoder takes 84 input channels: 16 for prev_frame, 4 for action, and 64 for latents. It outputs an image (FrameDecoder)
29 |   * quantizer
30 | 
31 | Training:
32 | * Happens in three distinct phases
33 | * First, the tokenizer is trained. It outputs 4 tokens per delta image (from a vocab of 1024, with codebook dim 64)
34 |   * q = encoder(img_0, encoder_act_emb(a), img_1)
35 |   * decoder(frame_cnn(img_0), decoder_act_emb(a), q)
36 | * Then, the world model is trained
37 |   * transformer([frame_cnn(img_0), act_emb(a), latents_emb(tokens_from_encoder), ...])
38 | * Last, the actor critic is trained (inside the world model)
39 | 
40 | Our training strategy is to reproduce each one in reverse order.
41 | 
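42 | The tokenizer round trip can already be exercised with the pieces in this repo. A minimal sketch (assuming a delta-iris checkpoint saved as `last.pt`, which is what `model.py` expects):
43 | 
44 | ```python
45 | import gymnasium as gym
46 | from tinygrad import Tensor, nn
47 | from model import Model, MaxAndSkipEnv, preprocess
48 | 
49 | env = MaxAndSkipEnv(gym.make('PongNoFrameskip-v4'))
50 | obs, info = env.reset()
51 | img_0 = preprocess(obs)            # (1, 1, 3, 64, 64), floats in [0, 1]
52 | act = Tensor([[0]])                # NOOP
53 | obs, reward, terminated, truncated, info = env.step(act.item())
54 | img_1 = preprocess(obs)
55 | 
56 | model = Model()
57 | nn.state.load_state_dict(model, nn.state.torch_load("last.pt"))
58 | 
59 | q_out = model.tokenizer(img_0, act, img_1)  # QuantizerOutput
60 | print(q_out.tokens.shape)                   # (1, 1, 4): 4 tokens for the delta image
61 | recon = model.tokenizer.encode_decode(img_0, act, img_1)  # (1, 1, 3, 64, 64)
62 | ```
63 | 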
--------------------------------------------------------------------------------
/di/frame.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from tinygrad import Tensor, nn
3 | 
4 | class Downsample:
5 |   def __init__(self, num_channels: int) -> None:
6 |     self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=2, stride=2, padding=0)
7 |   def __call__(self, x: Tensor) -> Tensor: return self.conv(x)
8 | 
9 | class Upsample:
10 |   def __init__(self, num_channels: int) -> None:
11 |     self.conv = nn.Conv2d(num_channels, num_channels, kernel_size=3, stride=1, padding=1)
12 |   def __call__(self, x: Tensor) -> Tensor:
13 |     # TODO: is this fast? AssertionError: only supports linear interpolate
14 |     #x = x.interpolate([s*2 for s in x.size()], mode="nearest")
15 |     # TODO: repeat_interleave should support a tuple for dim
16 |     x = x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3)
17 |     return self.conv(x)
18 | 
19 | class ResidualBlock:
20 |   def __init__(self, in_channels: int, out_channels: int, num_groups_norm: int = 32) -> None:
21 |     self.f = [
22 |       nn.GroupNorm(num_groups_norm, in_channels),
23 |       Tensor.silu,
24 |       nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
25 |       nn.GroupNorm(num_groups_norm, out_channels),
26 |       Tensor.silu,
27 |       nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
28 |     ]
29 |     self.skip_projection = (lambda x: x) if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1)
30 |   def __call__(self, x: Tensor) -> Tensor: return self.skip_projection(x) + x.sequential(self.f)
31 | 
32 | class FrameDecoder:
33 |   def __init__(self):
34 |     self.decoder = [
35 |       nn.Conv2d(84, 256, kernel_size=3, stride=1, padding=1),
36 |       ResidualBlock(256, 128), Upsample(128),
37 |       ResidualBlock(128, 128), Upsample(128),
38 |       ResidualBlock(128, 64),
39 |       ResidualBlock(64, 64), Upsample(64),
40 |       ResidualBlock(64, 64),
41 |       nn.GroupNorm(num_groups=32, num_channels=64),
42 |       Tensor.silu,
43 |       nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
44 |     ]
45 | 
46 |   def __call__(self, x:Tensor) -> Tensor:
47 |     b, t, _, _, _ = x.size()
48 |     x = x.rearrange('b t c h w -> (b t) c h w')
49 |     x = x.sequential(self.decoder)
50 |     x = x.rearrange('(b t) c h w -> b t c h w', b=b, t=t)
51 |     return x
52 | 
53 | class FrameEncoder:
54 |   def __init__(self, channels: List[int]):
55 |     self.encoder = [
56 |       nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=1, padding=1),
57 |       ResidualBlock(channels[1], channels[1]), Downsample(channels[1]),
58 |       ResidualBlock(channels[1], channels[1]),
59 |       ResidualBlock(channels[1], channels[2]), Downsample(channels[2]),
60 |       ResidualBlock(channels[2], channels[2]), Downsample(channels[2]),
61 |       ResidualBlock(channels[2], channels[3]),
62 |       nn.GroupNorm(num_groups=32, num_channels=channels[3]),
63 |       Tensor.silu,
64 |       nn.Conv2d(channels[3], channels[4], kernel_size=3, stride=1, padding=1),
65 |     ]
66 | 
67 |   def __call__(self, x: Tensor) -> Tensor:
68 |     b, t, _, _, _ = x.size()
69 |     x = x.rearrange('b t c h w -> (b t) c h w')
70 |     x = x.sequential(self.encoder)
71 |     x = x.rearrange('(b t) c h w -> b t c h w', b=b, t=t)
72 |     return x
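73 | 
74 | if __name__ == "__main__":
75 |   # quick shape check, a sketch: [3,32,64,128,16] is the tokenizer frame_cnn config from model.py.
76 |   # the three Downsamples take a 64x64 frame to 8x8
77 |   enc = FrameEncoder([3, 32, 64, 128, 16])
78 |   x = Tensor.randn(1, 1, 3, 64, 64)  # (b, t, c, h, w)
79 |   print(enc(x).shape)  # (1, 1, 16, 8, 8)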
--------------------------------------------------------------------------------
/di/quantizer.py:
--------------------------------------------------------------------------------
1 | from typing import List, Dict, Optional, Tuple
2 | import math
3 | from tinygrad import Tensor, nn
4 | from dataclasses import dataclass
5 | 
6 | @dataclass
7 | class QuantizerOutput:
8 |   q: Tensor
9 |   tokens: Tensor
10 |   loss: Dict[str, Tensor]
11 |   metrics: Dict[str, float]
12 | 
13 | class Quantizer:
14 |   def __init__(self, codebook_size: int, codebook_dim: int, input_dim: int):
15 |     assert math.log2(codebook_size).is_integer()
16 |     self.pre_quant_proj = nn.Linear(input_dim, codebook_dim)
17 |     self.post_quant_proj = nn.Linear(codebook_dim, input_dim)
18 |     self.codebook = Tensor.uniform(codebook_size, codebook_dim, low=-1.0 / codebook_size, high=1.0 / codebook_size)
19 | 
20 |   def __call__(self, z:Tensor) -> QuantizerOutput:
21 |     z = self.pre_quant_proj(z)
22 |     b, k = z.size(0), z.size(2)
23 |     z = z.rearrange('b t k e -> (b t k) e')
24 | 
25 |     cosine_similarity = Tensor.einsum('n e, c e -> n c', z, self.codebook)  # NOTE: unnormalized dot product, despite the name
26 |     tokens = cosine_similarity.argmax(axis=-1)  # TODO: support both axis and dim
27 |     q = self.codebook[tokens]
28 | 
29 |     losses = {'commitment_loss': 0.02 * (z - q.detach()).pow(2).mean()}
30 |     metrics = {}
31 | 
32 |     q = z + (q - z).detach()  # straight-through estimator: gradients flow to z
33 |     q = self.post_quant_proj(q)
34 | 
35 |     q = q.rearrange('(b t k) e -> b t k e', b=b, k=k)
36 |     tokens = tokens.rearrange('(b t k) -> b t k', b=b, k=k)
37 |     return QuantizerOutput(q, tokens, losses, metrics)
38 | 
39 |   def embed_tokens(self, tokens: Tensor) -> Tensor:
40 |     return self.post_quant_proj(self.codebook[tokens])
41 | 
--------------------------------------------------------------------------------
/di/transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from tinygrad import Tensor, nn, Variable
3 | 
4 | MAX_CONTEXT = 156
5 | EMBED_DIM = 256
6 | 
7 | class MLPLayer:
8 |   def __init__(self):
9 |     self.ln = nn.LayerNorm(EMBED_DIM)
10 |     self.mlp = [nn.Linear(EMBED_DIM, 4*EMBED_DIM), Tensor.gelu, nn.Linear(4*EMBED_DIM, EMBED_DIM)]
11 |   def __call__(self, x:Tensor): return x + self.ln(x).sequential(self.mlp)
12 | 
13 | class Attention:
14 |   def __init__(self):
15 |     self.proj = nn.Linear(EMBED_DIM, EMBED_DIM)
16 |     self.num_heads = 4
17 | 
18 |   def __call__(self, q:Tensor, k:Tensor, v:Tensor, start_pos:int) -> Tensor:
19 |     # update the cache
20 |     bsz, seqlen, _ = q.shape
21 |     if not hasattr(self, 'cache_kv'): self.cache_kv = Tensor.zeros(2, bsz, MAX_CONTEXT, EMBED_DIM, dtype=k.dtype).contiguous().realize()
22 |     assert k.dtype == v.dtype == self.cache_kv.dtype, f"{k.dtype=}, {v.dtype=}, {self.cache_kv.dtype=}"
23 |     # TODO: figure out how to remove this realize
24 |     self.cache_kv.shrink((None, None, (start_pos, start_pos+seqlen), None)).assign(Tensor.stack(k, v)).realize()
25 |     k = self.cache_kv[0].shrink((None, (0, start_pos+seqlen), None))
26 |     v = self.cache_kv[1].shrink((None, (0, start_pos+seqlen), None))
27 | 
28 |     # TODO: this should be smartly folded in
29 |     mask = Tensor.full((1, 1, seqlen, start_pos+seqlen), float("-inf"), dtype=q.dtype, device=q.device).triu(start_pos+1)
30 |     q: Tensor = q.rearrange('b q (h e) -> b h q e', h=self.num_heads)
31 |     k = k.rearrange('b k (h e) -> b h k e', h=self.num_heads)
32 |     v = v.rearrange('b k (h d) -> b h k d', 
h=self.num_heads) 33 | y = q.scaled_dot_product_attention(k, v, mask).rearrange('b h q d -> b q (h d)') 34 | return self.proj(y) 35 | 36 | class SelfAttentionLayer: 37 | def __init__(self): 38 | self.ln = nn.LayerNorm(EMBED_DIM) 39 | self.query = nn.Linear(EMBED_DIM, EMBED_DIM) 40 | self.key = nn.Linear(EMBED_DIM, EMBED_DIM) 41 | self.value = nn.Linear(EMBED_DIM, EMBED_DIM) 42 | self.attention = Attention() 43 | def __call__(self, inputs:Tensor, start_pos:int) -> Tensor: 44 | x = self.ln(inputs) 45 | q = self.query(x) 46 | k = self.key(x) 47 | v = self.value(x) 48 | return inputs + self.attention(q, k, v, start_pos) 49 | 50 | class EncoderLayer: 51 | def __init__(self): 52 | self.sa = SelfAttentionLayer() 53 | self.mlp = MLPLayer() 54 | def __call__(self, x:Tensor, start_pos:int) -> Tensor: return self.mlp(self.sa(x, start_pos)) 55 | 56 | class TransformerEncoder: 57 | def __init__(self): 58 | self.pos_emb = nn.Embedding(MAX_CONTEXT, EMBED_DIM) 59 | self.ln = nn.LayerNorm(EMBED_DIM) 60 | self.blocks = [EncoderLayer() for _ in range(3)] 61 | self.start_pos = 0 62 | 63 | def __call__(self, x:Tensor) -> Tensor: 64 | assert x.ndim == 3 and x.size(2) == EMBED_DIM # (B, TK, E) 65 | y = x + self.pos_emb(Tensor.arange(self.start_pos, self.start_pos+x.size(1))) 66 | for b in self.blocks: y = b(y, self.start_pos) 67 | self.start_pos += x.size(1) 68 | return self.ln(y) 69 | 70 | class Head: 71 | def __init__(self, output_dim): 72 | self.head_module = [ 73 | nn.Linear(EMBED_DIM, EMBED_DIM), Tensor.relu, 74 | nn.Linear(EMBED_DIM, output_dim)] 75 | def __call__(self, outputs:Tensor): 76 | return outputs.sequential(self.head_module) 77 | -------------------------------------------------------------------------------- /dreamer_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import gymnasium as gym 3 | import pickle, math 4 | import numpy as np 5 | from tinygrad.helpers import prod 6 | from tinygrad import Tensor, nn 7 | from PIL import Image 8 | 9 | class NormConv2d: 10 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, transp=False): 11 | self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) 12 | self.stride, self.padding, self.transp = stride, padding, transp 13 | self.scale = Tensor.ones(out_channels) 14 | self.eps = 1e-6 15 | scale = 1 / math.sqrt(in_channels * prod(self.kernel_size)) 16 | if transp: self.weight = Tensor.uniform(in_channels, out_channels, *self.kernel_size, low=-scale, high=scale) 17 | else: self.weight = Tensor.uniform(out_channels, in_channels, *self.kernel_size, low=-scale, high=scale) 18 | self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) 19 | 20 | def __call__(self, x:Tensor) -> Tensor: 21 | if self.transp: x = x.conv_transpose2d(self.weight, self.bias, padding=self.padding, stride=self.stride, output_padding=1) 22 | else: x = x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride) 23 | # TODO: RMSNorm should work on given channel, not just -1 24 | x = x * (x.square().mean(1, keepdim=True) + self.eps).rsqrt() 25 | return x * self.scale.reshape(1, -1, 1, 1) 26 | 27 | class RSSM: 28 | def __init__(self): 29 | pass 30 | 31 | class Actor: 32 | def __init__(self): 33 | self.mlp = [ 34 | nn.Linear(10240, 1024), nn.RMSNorm(1024), Tensor.silu, 35 | nn.Linear(1024, 1024), nn.RMSNorm(1024), Tensor.silu, 36 | nn.Linear(1024, 1024), nn.RMSNorm(1024), Tensor.silu, 37 | nn.Linear(1024, 18)] 38 | def 
__call__(self, x:Tensor) -> Tensor: return x.sequential(self.mlp) 39 | 40 | class Encoder: 41 | def __init__(self): 42 | # TODO: i want padding to support "same" 43 | self.conv = [ 44 | NormConv2d(1, 64, 5, padding=(2,2)), 45 | NormConv2d(64, 128, 5, 2, padding=(2,2)), 46 | NormConv2d(128, 192, 5, 2, padding=(2,2)), 47 | NormConv2d(192, 256, 5, 2, padding=(2,2)), 48 | NormConv2d(256, 256, 5, 2, padding=(2,2))] 49 | 50 | def __call__(self, x:Tensor) -> Tensor: 51 | for c in self.conv: x = c(x).silu() 52 | return x 53 | 54 | class Decoder: 55 | def __init__(self): 56 | self.conv = [ 57 | NormConv2d(128, 64, 5, 2, padding=(2,2), transp=True), 58 | NormConv2d(192, 128, 5, 2, padding=(2,2), transp=True), 59 | NormConv2d(256, 192, 5, 2, padding=(2,2), transp=True), 60 | NormConv2d(256, 256, 5, 2, padding=(2,2), transp=True)] 61 | self.imgout = nn.ConvTranspose2d(64, 1, 5, padding=(2,2)) 62 | 63 | def __call__(self, x:Tensor) -> Tensor: 64 | for c in self.conv[::-1]: x = c(x).silu() 65 | return self.imgout(x) 66 | 67 | def preprocess(obs, size=(64, 64)): 68 | image = Image.fromarray(obs).resize(size, Image.NEAREST) 69 | weights = [0.299, 0.587, 1 - (0.299 + 0.587)] 70 | image = np.tensordot(image, weights, (-1, 0)) 71 | return Tensor(image, dtype='float32').unsqueeze(0) 72 | 73 | if __name__ == "__main__": 74 | env = gym.make('ALE/Pong-v5') 75 | obs, info = env.reset() 76 | 77 | actor = Actor() 78 | encoder = Encoder() 79 | decoder = Decoder() 80 | dyn = RSSM() 81 | 82 | # TODO: confirm that assigning to transpose like this is correct 83 | assigns = { 84 | "agent/actor/h0/kernel": actor.mlp[0].weight.T, 85 | "agent/actor/h0/bias": actor.mlp[0].bias, 86 | "agent/actor/h0/norm/scale": actor.mlp[1].weight, 87 | "agent/actor/h1/kernel": actor.mlp[3].weight.T, 88 | "agent/actor/h1/bias": actor.mlp[3].bias, 89 | "agent/actor/h1/norm/scale": actor.mlp[4].weight, 90 | "agent/actor/h2/kernel": actor.mlp[6].weight.T, 91 | "agent/actor/h2/bias": actor.mlp[6].bias, 92 | "agent/actor/h2/norm/scale": actor.mlp[7].weight, 93 | "agent/actor/action/out/kernel": actor.mlp[9].weight.T, 94 | "agent/actor/action/out/bias": actor.mlp[9].bias, 95 | } 96 | 97 | dat = pickle.load(open("checkpoint.ckpt", "rb")) 98 | for k, v in dat['agent'].items(): 99 | print(k, v.shape if hasattr(v, 'shape') else None) 100 | for s,m,e in [('agent/enc/conv', encoder, True), ('agent/dec/conv', decoder, False)]: 101 | if k.startswith(s): 102 | if k.endswith('kernel'): m.conv[int(k.split(s)[1].split("/")[0])].weight.assign(v.transpose(3,2,0,1) if e else v.transpose(2,3,0,1)) 103 | if k.endswith('bias'): m.conv[int(k.split(s)[1].split("/")[0])].bias.assign(v) 104 | if k.endswith('scale'): m.conv[int(k.split(s)[1].split("/")[0])].scale.assign(v) 105 | if k == 'agent/dec/imgout/kernel': decoder.imgout.weight.assign(v.transpose(2,3,0,1)) 106 | if k == 'agent/dec/imgout/bias': decoder.imgout.bias.assign(v) 107 | if k in assigns: assigns[k].assign(v) 108 | 109 | out = encoder(preprocess(obs, size=(96,96)).unsqueeze(0)) 110 | print(out.shape) 111 | ret = decoder(out) 112 | print(ret.shape) 113 | exit(0) 114 | 115 | import matplotlib.pyplot as plt 116 | plt.imshow(ret[0, 0].numpy()) 117 | plt.show() 118 | -------------------------------------------------------------------------------- /iris.py: -------------------------------------------------------------------------------- 1 | # scp t18:/home/tiny/build/iris/outputs/2024-08-14/01-36-07/checkpoints/last.pt last_iris.pt 2 | from tinygrad import Tensor, nn 3 | 4 | if __name__ == "__main__": 5 | dat = 
nn.state.torch_load("last_iris.pt") 6 | for k,v in dat.items(): print(k, v.shape) 7 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from typing import List, Dict, Optional, Tuple 3 | import argparse 4 | import math 5 | import pygame 6 | from dataclasses import dataclass 7 | from tinygrad import Tensor, nn, TinyJit, dtypes, GlobalCounters 8 | import gymnasium as gym 9 | from PIL import Image 10 | import numpy as np 11 | 12 | # copied from delta-iris 13 | class MaxAndSkipEnv(gym.Wrapper): 14 | def __init__(self, env, skip=4): 15 | """Return only every `skip`-th frame""" 16 | gym.Wrapper.__init__(self, env) 17 | assert skip > 0 18 | # most recent raw observations (for max pooling across time steps) 19 | self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8) 20 | self._skip = skip 21 | self.max_frame = np.zeros(env.observation_space.shape, dtype=np.uint8) 22 | 23 | def step(self, action): 24 | """Repeat action, sum reward, and max over last observations.""" 25 | total_reward = 0.0 26 | for i in range(self._skip): 27 | obs, reward, terminated, truncated, info = self.env.step(action) 28 | if i == self._skip - 2: 29 | self._obs_buffer[0] = obs 30 | if i == self._skip - 1: 31 | self._obs_buffer[1] = obs 32 | total_reward += reward 33 | if terminated or truncated: break 34 | # Note that the observation on the done=True frame 35 | # doesn't matter 36 | self.max_frame = self._obs_buffer.max(axis=0) 37 | 38 | return self.max_frame, total_reward, terminated, truncated, info 39 | 40 | def reset(self, **kwargs): return self.env.reset(**kwargs) 41 | 42 | 43 | from di.frame import FrameDecoder, FrameEncoder 44 | from di.transformer import EMBED_DIM, TransformerEncoder, Head 45 | from di.quantizer import Quantizer, QuantizerOutput 46 | 47 | # TODO: i like torches tensors that include dtype in the type 48 | 49 | class Tokenizer: 50 | def __init__(self): 51 | self.encoder_act_emb = nn.Embedding(6, 4096) 52 | self.decoder_act_emb = nn.Embedding(6, 256) 53 | self.frame_cnn = FrameEncoder([3,32,64,128,16]) 54 | self.encoder = FrameEncoder([7,64,128,256,64]) 55 | self.decoder = FrameDecoder() 56 | self.quantizer = Quantizer(codebook_size=1024, codebook_dim=64, input_dim=1024) 57 | 58 | # guessed to make dims match 59 | self.token_res = 4 60 | self.tokens_grid_res = 2 61 | 62 | def __call__(self, x1: Tensor, a: Tensor, x2: Tensor) -> QuantizerOutput: 63 | z = self.encode(x1, a, x2) 64 | z = z.rearrange('b t c (h k) (w l) -> b t (h w) (k l c)', h=self.tokens_grid_res, w=self.tokens_grid_res) 65 | return self.quantizer(z) 66 | 67 | # need typing 68 | def encode(self, x1: Tensor, a: Tensor, x2: Tensor) -> Tensor: 69 | a_emb = self.encoder_act_emb(a).rearrange('b t (h w) -> b t 1 h w', h=x1.size(3)) 70 | encoder_input = Tensor.cat(x1, a_emb, x2, dim=2) 71 | return self.encoder(encoder_input) 72 | 73 | def decode(self, x1: Tensor, a: Tensor, q2: Tensor, should_clamp: bool = False) -> Tensor: 74 | x1_emb = self.frame_cnn(x1) 75 | a_emb = self.decoder_act_emb(a).rearrange('b t (c h w) -> b t c h w', c=4, h=x1_emb.size(3)) 76 | decoder_input = Tensor.cat(x1_emb, a_emb, q2, dim=2) 77 | r = self.decoder(decoder_input) 78 | r = r.clamp(0, 1).mul(255).round().div(255) if should_clamp else r 79 | return r 80 | 81 | @TinyJit 82 | def encode_decode(self, x1: Tensor, a: Tensor, x2: Tensor) -> Tensor: 83 | z = self.encode(x1, a, x2) 84 | z = 
z.rearrange('b t c (h k) (w l) -> b t (h w) (k l c)', k=self.token_res, l=self.token_res) 85 | q = self.quantizer(z).q.rearrange('b t (h w) (k l e) -> b t e (h k) (w l)', h=self.tokens_grid_res, k=self.token_res, l=self.token_res) 86 | r = self.decode(x1, a, q, should_clamp=True) 87 | return r 88 | 89 | #def __call__(self, x: Tensor) -> Tensor: 90 | # return self.frame_cnn(x) 91 | 92 | @dataclass 93 | class WorldModelOutput: 94 | output_sequence: Tensor #[f"b t {EMBED_DIM}", dtypes.float] 95 | logits_latents: Tensor 96 | logits_rewards: Tensor #["b t 3", "float"] 97 | logits_ends: Tensor #["b t 2", "float"] 98 | 99 | class WorldModel: 100 | def __init__(self): 101 | self.frame_cnn = [FrameEncoder([3,32,64,128,4]), lambda x: x.rearrange('b t c h w -> b t 1 (h w c)'), nn.LayerNorm(EMBED_DIM)] 102 | self.act_emb = nn.Embedding(6, EMBED_DIM) # embed the action 103 | self.latents_emb = nn.Embedding(1024, EMBED_DIM) 104 | 105 | self.transformer = TransformerEncoder() 106 | self.head_latents = Head(1024) 107 | self.head_rewards = Head(3) 108 | self.head_ends = Head(2) 109 | 110 | def __call__(self, sequence:Tensor) -> WorldModelOutput: 111 | outputs = self.transformer(sequence) 112 | 113 | # TODO: this should probably be sliced 114 | logits_latents = self.head_latents(outputs) 115 | logits_rewards = self.head_rewards(outputs) 116 | logits_ends = self.head_ends(outputs) 117 | 118 | return WorldModelOutput(outputs, logits_latents, logits_rewards, logits_ends) 119 | 120 | @dataclass 121 | class ActorCriticOutput: 122 | logits_actions: Tensor 123 | logits_values: Tensor 124 | 125 | class CnnLstmActorCritic: 126 | def __init__(self, num_actions): 127 | self.lstm_dim = 512 128 | self.cnn = [FrameEncoder([3,32,64,128,16]), lambda x: x.rearrange('b t c h w -> (b t) (h w c)')] 129 | self.actor_linear = nn.Linear(self.lstm_dim, num_actions) 130 | self.critic_linear = nn.Linear(self.lstm_dim, 1) 131 | self.lstm = nn.LSTMCell(1024, self.lstm_dim) 132 | self.reset() 133 | 134 | def reset(self): self.hx, self.cx = None, None 135 | 136 | def __call__(self, x:Tensor) -> ActorCriticOutput: 137 | x = x.sequential(self.cnn) 138 | hx, cx = self.lstm(x, (self.hx, self.cx) if self.hx is not None else None) 139 | if self.hx is None: 140 | self.hx, self.cx = hx, cx 141 | else: 142 | # are these assigns needed? 
143 | self.hx.assign(hx) 144 | self.cx.assign(cx) 145 | logits_actions = self.actor_linear(self.hx).rearrange('b a -> b 1 a') 146 | logits_values = self.critic_linear(self.hx).rearrange('b c -> b 1 c') 147 | return ActorCriticOutput(logits_actions, logits_values) 148 | 149 | class Model: 150 | def __init__(self): 151 | self.world_model = WorldModel() 152 | self.tokenizer = Tokenizer() 153 | self.actor_critic = {"model": CnnLstmActorCritic(6), "target_model": CnnLstmActorCritic(6)} 154 | 155 | def preprocess(obs: Tensor) -> Tensor: 156 | # TODO: tinygrad has a known bug in uint8 interpolate linear 157 | #return Tensor(obs).permute(2,0,1).interpolate((64, 64), mode='linear').float().reshape(1,1,3,64,64) / 255.0 158 | image = Image.fromarray(obs).resize((64, 64), Image.BILINEAR) 159 | return Tensor(np.array(image)).permute(2,0,1).float().reshape(1,1,3,64,64) / 255.0 160 | 161 | if __name__ == "__main__": 162 | parser = argparse.ArgumentParser() 163 | parser.add_argument('--action', choices=['model', 'user', 'random'], default='model', 164 | help='Choose the action to perform (default: model)') 165 | parser.add_argument('--render', choices=['none', 'worldmodel', 'tokenizer', 'tokenizer_free'], default='tokenizer', 166 | help='Choose the rendering option (default: tokenizer)') 167 | args = parser.parse_args() 168 | 169 | env = MaxAndSkipEnv(gym.make('PongNoFrameskip-v4')) #, render_mode="human")) 170 | obs, info = env.reset() 171 | 172 | model = Model() 173 | 174 | # scp t18:~/build/delta-iris/outputs/2024-08-13/20-34-53/checkpoints/last.pt . 175 | dat = nn.state.torch_load("last.pt") 176 | nn.state.load_state_dict(model, dat) 177 | 178 | model_state = nn.state.get_state_dict(model) 179 | for k,v in dat.items(): 180 | if k not in model_state: print("DIDN'T LOAD", k, v.shape) 181 | 182 | pygame.init() 183 | screen = pygame.display.set_mode((64*8*2, 64*8)) 184 | 185 | def draw(x:Tensor): 186 | img = x[0, 0].permute(2,1,0) 187 | surf = pygame.surfarray.make_surface((img*256).cast('uint8').repeat_interleave(8, 0).repeat_interleave(8, 1).numpy()) 188 | screen.blit(surf, (0, 0)) 189 | pygame.display.flip() 190 | 191 | def getkey(): 192 | pygame.event.clear() 193 | while True: 194 | event = pygame.event.wait() 195 | if event.type == pygame.QUIT: 196 | pygame.quit() 197 | elif event.type == pygame.KEYDOWN: 198 | if event.key == pygame.K_q: return 0 199 | if event.key == pygame.K_w: return 2 200 | if event.key == pygame.K_s: return 5 201 | 202 | # TODO: is this correct with the LSTM and TinyJIT 203 | @TinyJit 204 | def get_action(img_0:Tensor) -> Tensor: 205 | x = model.actor_critic['model'](img_0) 206 | return x.logits_actions.exp().softmax().flatten().multinomial() 207 | 208 | # roll out down 209 | transformer_tokens = None 210 | img_0 = None 211 | while 1: 212 | GlobalCounters.reset() 213 | cur_img = preprocess(obs) 214 | if img_0 is None: img_0 = cur_img 215 | if args.action == "model": 216 | act = get_action(cur_img).item() 217 | elif args.action == "user": 218 | act = getkey() 219 | elif args.action == "random": 220 | act = env.action_space.sample() 221 | obs, reward, terminated, truncated, info = env.step(act) 222 | if terminated or truncated: break 223 | 224 | draw(Tensor.cat(img_0, cur_img, dim=4)) 225 | 226 | if args.render == "worldmodel": 227 | # resync every 20 frames 228 | if transformer_tokens is None or transformer_tokens.shape[1] > 25*6: 229 | img_0 = cur_img 230 | transformer_tokens = Tensor.zeros(1, 0, EMBED_DIM) 231 | 232 | frames_emb = 
img_0.sequential(model.world_model.frame_cnn)[:, 0] 233 | act_tokens_emb = model.world_model.act_emb(Tensor([[act]])) 234 | print(transformer_tokens.shape, frames_emb.shape, act_tokens_emb.shape) 235 | transformer_tokens = transformer_tokens.cat(frames_emb, act_tokens_emb, dim=1).contiguous() 236 | latents = [] 237 | for i in range(4): 238 | out = model.world_model.transformer(transformer_tokens) 239 | logits_latents = model.world_model.head_latents(out[:, -1:])[0] 240 | latent = logits_latents.exp().softmax().multinomial().flatten() 241 | latents.append(latent) 242 | transformer_tokens = transformer_tokens.cat(model.world_model.latents_emb(latent), dim=1) 243 | latents = model.tokenizer.quantizer.embed_tokens(Tensor.cat(*latents)).reshape((1, 1, 4, 1024)) 244 | qq = latents.rearrange('b t (h w) (k l e) -> b t e (h k) (w l)', 245 | h=model.tokenizer.tokens_grid_res, k=model.tokenizer.token_res, l=model.tokenizer.token_res) 246 | img_0 = model.tokenizer.decode(img_0, Tensor([[act]]), qq, should_clamp=True) 247 | elif args.render == "tokenizer": 248 | img_0 = model.tokenizer.encode_decode(cur_img, Tensor([[act]]), preprocess(obs)) 249 | elif args.render == "tokenizer_free": 250 | img_0 = model.tokenizer.encode_decode(img_0, Tensor([[act]]), preprocess(obs)) 251 | elif args.render == "none": 252 | img_0 = preprocess(obs) 253 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | indent-width = 2 2 | preview = true 3 | target-version = "py38" 4 | 5 | lint.select = [ 6 | "E702", # multiple-statements-on-one-line-semicolon 7 | ] 8 | 9 | -------------------------------------------------------------------------------- /test_compare_delta_iris.py: -------------------------------------------------------------------------------- 1 | from model import Model, preprocess, MaxAndSkipEnv 2 | from tinygrad import Tensor, nn 3 | import gymnasium as gym 4 | import numpy as np 5 | 6 | import sys 7 | sys.path.append("delta-iris/src") 8 | 9 | import hydra 10 | from pathlib import Path 11 | from hydra.utils import instantiate 12 | from omegaconf import DictConfig, OmegaConf 13 | from agent import Agent 14 | 15 | import torch 16 | torch.set_grad_enabled(False) 17 | 18 | from models.tokenizer import Tokenizer 19 | from models.world_model import WorldModel 20 | from models.actor_critic import ActorCritic 21 | 22 | OmegaConf.register_new_resolver("eval", eval) 23 | 24 | @hydra.main(config_path="delta-iris/config", config_name="params/atari") 25 | def main(cfg: DictConfig) -> None: 26 | cfg.params.tokenizer.num_actions = cfg.params.world_model.num_actions = cfg.params.actor_critic.model.num_actions = 6 27 | agent = Agent(Tokenizer(instantiate(cfg.params.tokenizer)), WorldModel(instantiate(cfg.params.world_model)), ActorCritic(instantiate(cfg.params.actor_critic))) 28 | agent.load(model_path := Path(__file__).parent / 'last.pt', "cpu", strict=False) 29 | 30 | model = Model() 31 | dat = nn.state.torch_load(model_path) 32 | nn.state.load_state_dict(model, dat) 33 | 34 | env = MaxAndSkipEnv(gym.make('PongNoFrameskip-v4')) #, render_mode="human")) 35 | act = Tensor([[0]]) 36 | 37 | obs, info = env.reset() 38 | img_0 = preprocess(obs) 39 | obs, reward, terminated, truncated, info = env.step(act.item()) 40 | img_1 = preprocess(obs) 41 | 42 | print("testing worldmodel frame_cnn") 43 | test_x_emb = img_0.sequential(model.world_model.frame_cnn) 44 | real_x_emb = 
agent.world_model.frame_cnn(torch.Tensor(img_0.numpy())) 45 | assert test_x_emb.shape == real_x_emb.shape 46 | np.testing.assert_allclose(test_x_emb.numpy(), real_x_emb.numpy(), atol=1e-4) 47 | print("PASS") 48 | 49 | print("testing worldmodel act_emb") 50 | test_a_emb = model.world_model.act_emb(act) 51 | real_a_emb = agent.world_model.act_emb(torch.Tensor(act.numpy()).long()) 52 | np.testing.assert_allclose(test_a_emb.numpy(), real_a_emb.numpy(), atol=1e-6) 53 | print("PASS") 54 | 55 | print("testing transformer") 56 | transformer_in = test_x_emb[0].cat(test_a_emb, dim=1) 57 | test_tout = model.world_model.transformer(transformer_in) 58 | real_tout = agent.world_model.transformer(torch.Tensor(transformer_in.numpy())) 59 | np.testing.assert_allclose(test_tout.numpy(), real_tout.numpy(), atol=1e-2) # this atol might not be okay 60 | print("PASS") 61 | 62 | print("testing tokenizer") 63 | test_token = model.tokenizer(img_0, act, img_1) 64 | real_token = agent.tokenizer(torch.Tensor(img_0.numpy()), torch.Tensor(act.numpy()).long(), torch.Tensor(img_1.numpy())) 65 | np.testing.assert_allclose(test_token.q.numpy(), real_token.q.numpy(), atol=1e-6) 66 | np.testing.assert_allclose(test_token.tokens.numpy(), real_token.tokens.numpy(), atol=1e-6) 67 | print("PASS") 68 | 69 | print("testing tokenizer encode/decode") 70 | test_image = model.tokenizer.encode_decode(img_0, act, img_1) 71 | real_image = agent.tokenizer.encode_decode(torch.Tensor(img_0.numpy()), torch.Tensor(act.numpy()).long(), torch.Tensor(img_1.numpy())) 72 | np.testing.assert_allclose(test_image.numpy(), real_image.numpy(), atol=1e-6) # one is a tiny bit off 73 | print("PASS") 74 | 75 | print("testing actor critic") 76 | model.actor_critic['model'](img_0) 77 | test_hx, test_cx = model.actor_critic['model'].hx, model.actor_critic['model'].cx 78 | model.actor_critic['model'](img_1) 79 | test_hx_2, test_cx_2 = model.actor_critic['model'].hx, model.actor_critic['model'].cx 80 | x = agent.actor_critic.model.cnn(torch.Tensor(img_0.numpy())) 81 | real_hx, real_cx = agent.actor_critic.model.lstm(x) 82 | x = agent.actor_critic.model.cnn(torch.Tensor(img_1.numpy())) 83 | real_hx_2, real_cx_2 = agent.actor_critic.model.lstm(x, (real_hx, real_cx)) 84 | # NOTE: this is wrong in new torch? 
85 | np.testing.assert_allclose(test_hx.numpy(), real_hx.numpy(), atol=1e-5) 86 | np.testing.assert_allclose(test_cx.numpy(), real_cx.numpy(), atol=1e-5) 87 | np.testing.assert_allclose(test_hx_2.numpy(), real_hx_2.numpy(), atol=1e-5) 88 | np.testing.assert_allclose(test_cx_2.numpy(), real_cx_2.numpy(), atol=1e-5) 89 | print("PASS") 90 | 91 | if __name__ == "__main__": 92 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import gymnasium as gym 3 | from model import Model, CnnLstmActorCritic, MaxAndSkipEnv, preprocess, EMBED_DIM 4 | from tinygrad import nn, Tensor, GlobalCounters 5 | 6 | BS = 16 7 | 8 | import pygame 9 | 10 | screen = None 11 | SCALE = 4 12 | def draw(x:Tensor): 13 | global screen 14 | if screen is None: 15 | pygame.init() 16 | screen = pygame.display.set_mode((x.shape[1]*SCALE, x.shape[2]*SCALE)) 17 | img = x.permute(2,1,0) 18 | surf = pygame.surfarray.make_surface((img*256).cast('uint8').repeat_interleave(SCALE, 0).repeat_interleave(SCALE, 1).numpy()) 19 | screen.blit(surf, (0, 0)) 20 | pygame.display.flip() 21 | 22 | if __name__ == "__main__": 23 | model = Model() 24 | nn.state.load_state_dict(model, nn.state.torch_load("last.pt")) 25 | ac = CnnLstmActorCritic(6) 26 | 27 | env = MaxAndSkipEnv(gym.make('PongNoFrameskip-v4')) 28 | obs, info = env.reset() 29 | img_0 = preprocess(obs).expand(BS, -1, -1, -1, -1) 30 | 31 | model.world_model.transformer.start_pos = 0 32 | ac.reset() 33 | for j in range(25): 34 | GlobalCounters.reset() 35 | print(img_0.shape) 36 | draw(img_0.rearrange("(bw bh) 1 c w h -> c (bw w) (bh h)", bw=4)) 37 | 38 | ac_out = ac(img_0) 39 | sampled_actions = ac_out.logits_actions.exp().softmax().squeeze(1).multinomial() 40 | 41 | frame_emb = img_0.sequential(model.world_model.frame_cnn)[:, 0] 42 | act_emb = model.world_model.act_emb(sampled_actions) 43 | out = model.world_model.transformer(frame_emb.cat(act_emb, dim=1)) 44 | latents = [] 45 | for i in range(4): 46 | logits_latents = model.world_model.head_latents(out[:, -1:]) 47 | latent = logits_latents.exp().softmax().squeeze(1).multinomial() 48 | latents.append(latent) 49 | latent_emb = model.world_model.latents_emb(latent) 50 | out = model.world_model.transformer(latent_emb) 51 | 52 | latents = model.tokenizer.quantizer.embed_tokens(Tensor.cat(*latents, dim=1)).unsqueeze(1) 53 | qq = latents.rearrange('b t (h w) (k l e) -> b t e (h k) (w l)', 54 | h=model.tokenizer.tokens_grid_res, k=model.tokenizer.token_res, l=model.tokenizer.token_res) 55 | img_0 = model.tokenizer.decode(img_0, sampled_actions, qq, should_clamp=True) 56 | --------------------------------------------------------------------------------
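None of the three training phases described in the README is implemented yet; train.py stops at rolling the world model forward. For reference, a tokenizer training step could look roughly like the sketch below. This is not delta-iris's exact recipe: the plain L1 reconstruction term, the learning rate, and the batch plumbing are assumptions; the commitment term is the one di/quantizer.py already computes, and the rearrange mirrors `Tokenizer.encode_decode` in model.py.

```python
from tinygrad import Tensor, nn
from model import Model

model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model.tokenizer), lr=1e-4)

def tokenizer_train_step(img_0: Tensor, act: Tensor, img_1: Tensor) -> Tensor:
  # img_0, img_1: (b, t, 3, 64, 64) frame pairs, act: (b, t) integer actions
  with Tensor.train():
    q_out = model.tokenizer(img_0, act, img_1)   # quantized delta tokens (QuantizerOutput)
    q = q_out.q.rearrange('b t (h w) (k l e) -> b t e (h k) (w l)',
                          h=model.tokenizer.tokens_grid_res,
                          k=model.tokenizer.token_res, l=model.tokenizer.token_res)
    rec = model.tokenizer.decode(img_0, act, q)  # no clamp while training
    loss = (rec - img_1).abs().mean() + q_out.loss['commitment_loss']
    opt.zero_grad()
    loss.backward()
    opt.step()
  return loss
```

Sampling the (img_0, act, img_1) pairs from replayed episodes, and any codebook-update term beyond the straight-through path, are left out here.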