├── .gitignore ├── LICENSE ├── README.md ├── config └── tokenize.yaml ├── genie.py ├── genie ├── __init__.py ├── action.py ├── dataset.py ├── dynamics.py ├── genie.py ├── module │ ├── __init__.py │ ├── attention.py │ ├── data.py │ ├── discriminator.py │ ├── image.py │ ├── loss.py │ ├── misc.py │ ├── norm.py │ ├── quantization.py │ └── video.py ├── tokenizer.py └── utils.py ├── requirements.txt ├── res └── Genie.png ├── sample.py ├── setup.py ├── test ├── test_action.py ├── test_attention.py ├── test_dataset.py ├── test_discriminator.py ├── test_dynamics.py ├── test_image.py ├── test_loss.py ├── test_quantization.py ├── test_tokenizer.py └── test_video.py └── tokenizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.vs* 2 | *.pytest* 3 | *.pyc 4 | *.ipynb 5 | 6 | *DS_Store* 7 | 8 | *egg* 9 | *log* 10 | *.local* 11 | test/.* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open-Genie: Generative Interactive Environments in PyTorch 2 | 3 | This repo contains the _unofficial_ implementation of _Genie: Generative Interactive Environments_ [Bruce et al. (2024)](https://arxiv.org/abs/2402.15391v1?curius=4125) as introduced by Google DeepMind. 4 | 5 | The goal of the model is to introduce "[...] The first generative interactive environment trained in an unsupervised manner from unlabelled Internet videos". 6 | 7 | ![Genie Banner](res/Genie.png) 8 | 9 | # Usage 10 | 11 | We provide a `LightningCLI` interface to easily train the several components of the `Genie` model. In particular, to train the `VideoTokenizer` one should run the following 12 | 13 | ```bash 14 | python tokenizer.py train -config 15 | ``` 16 | 17 | To train both the `LatentAction` and `Dynamics` model (use in turn would leverage a fully-trained `VideoTokenizer`), one can again simply run: 18 | 19 | ```bash 20 | python genie.py train -config 21 | ``` 22 | 23 | We provide example configuration files in the `📂 config` folder. 
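
For reference, pointing the CLI at the example tokenizer configuration shipped in the `📂 config` folder would look like the sketch below. Note this is an illustration, not the repo's documented command: a stock `LightningCLI` usually exposes the sub-command as `fit` and the flag as `--config`, so adjust the invocation to however the entry point is configured locally.

```bash
# e.g., train the video tokenizer with the example config provided in this repo
python tokenizer.py train -config config/tokenize.yaml
```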

In the following sections we provide example code for the core building blocks that together form the overall Genie module.

## VideoTokenizer

Genie relies on a `VideoTokenizer` which digests input videos and, via its `encode`-`quantize` abilities, converts them into discrete tokens. These tokens are what the `Dynamics` module uses to manipulate the _latent_ video space. The `VideoTokenizer` module accepts several parameters for extensive customization; here is an example of typical use:

```python
import torch

from genie import VideoTokenizer

# Pre-assembled description of MagViT2
# encoder & decoder architecture
from genie import MAGVIT2_ENC_DESC
from genie import MAGVIT2_DEC_DESC

tokenizer = VideoTokenizer(
  # We can pass an arbitrary description of the
  # encoder architecture, see genie.tokenizer.get_module
  # to see which modules are supported
  enc_desc = (
    ('causal', { # A CausalConv3d layer
      'in_channels': 3,
      'out_channels': 64,
      'kernel_size': 3,
    }),
    ('residual', { # Residual Block
      'in_channels': 64,
      'kernel_size': 3,
      'downsample': (1, 2), # Optional down-scaling (time, space)
      'use_causal': True, # Using causal padding
      'use_blur': True, # Using blur-pooling
    }),
    ('residual', {
      'in_channels': 64,
      'out_channels': 128, # Output channels can be different
    }),
    ('residual', {
      'n_rep': 2, # We can repeat this block N-times
      'in_channels': 128,
    }),
    ('residual', {
      'in_channels': 128,
      'out_channels': 256, # We can mix different output channels...
      'kernel_size': 3,
      'downsample': 2, # ...with down-sampling (here time=space=2)
      'use_causal': True,
    }),
    ('proj_out', { # Output projection to the quantization module
      'in_channels': 256,
      'out_channels': 18,
      'num_groups': 8,
      'kernel_size': 3,
    }),
  ),

  # Save time, use a pre-made configuration!
  dec_desc = MAGVIT2_DEC_DESC,

  # Description of GAN discriminator
  disc_kwargs=dict(
    # Discriminator parameters
    inp_size = (64, 64), # Size of input frames
    model_dim = 64,
    dim_mults = (1, 2, 4), # Channel multipliers
    down_step = (None, 2, 2), # Down-sampling steps
    inp_channels = 3,
    kernel_size = 3,
    num_groups = 8,
    act_fn = 'leaky', # Use LeakyReLU as activation function
    use_blur = True, # Use BlurPooling for down-sampling
    use_attn = True, # Discriminator can have spatial attention
    num_heads = 4, # Number of (spatial) attention heads
    dim_head = 32, # Dimension of each spatial attention head
  ),

  # Keywords for the LFQ module
  d_codebook = 18, # Codebook dimension, should match encoder output channels
  n_codebook = 1, # Support for multiple codebooks
  lfq_bias = True,
  lfq_frac_sample = 1.,
  lfq_commit_weight = 0.25,
  lfq_entropy_weight = 0.1,
  lfq_diversity_weight = 1.,
  # Keywords for the different losses
  perceptual_model = 'vgg16', # We pick VGG-16 for the perceptual loss
  # Which layers should we record perceptual features from
  perc_feat_layers = ('features.6', 'features.13', 'features.18', 'features.25'),
  gan_discriminate='frames', # GAN discriminator looks at individual frames
  gan_frames_per_batch = 4, # How many frames to extract from each video for the GAN
  gan_loss_weight = 1.,
  perc_loss_weight = 1.,
  quant_loss_weight = 1.,
)

batch_size = 4
num_channels = 3
num_frames = 16
img_h, img_w = 64, 64

# Example video tensor
mock_video = torch.randn(
  batch_size,
  num_channels,
  num_frames,
  img_h,
  img_w
)

# Tokenize the input video
tokens, idxs = tokenizer.tokenize(mock_video)

# The tokenized video has shape:
# (batch_size, d_codebook, num_frames // down_time, H // down_space, W // down_space)

# To decode the video from tokens use:
rec_video = tokenizer.decode(tokens)

# To train the tokenizer (repeat many times!)
loss, aux_losses = tokenizer(mock_video)
loss.backward()
```

## Latent Action Model

Genie implements a `LatentAction` model whose sole task is to formalize a (discrete) codebook of latent actions. This codebook is small by design to encourage _interpretable_ actions (such as `MOVE_RIGHT`). To train this codebook, the `LatentAction` model is built as a `VQ-VAE`, where the encoder ingests the video (pixel) frames and produces (quantized) actions as latents. The decoder then ingests the previous frame history and the current action to predict the next frame. Both the encoder and decoder are discarded at inference time, as the actions are provided by the user.

The `LatentAction` model follows a design similar to the `VideoTokenizer`: the encoder/decoder architectures can be specified via a `Blueprint`. Here is an example highlighting the core components:

```python
import torch

from genie import LatentAction
from genie import LATENT_ACT_ENC

model = LatentAction(
  # Use a pre-made configuration...
  enc_desc=LATENT_ACT_ENC,
  # ...Or specify a brand-new one
  dec_desc=(
    # Latent Action uses a space-time transformer
    ('space-time_attn', {
      'n_rep' : 2,
      'n_embd' : 256,
      'n_head' : 4,
      'd_head' : 16,
      'has_ext' : True,
      # Decoder uses the latent action as external
      # conditioning for decoding!
      'time_attn_kw' : {'key_dim' : 8},
    }),
    # But we can also down/up-sample to manage resources
    # NOTE: Encoder & Decoder should work nicely together
    # so that down/up-samples cancel out
    ('spacetime_upsample', {
      'in_channels' : 256,
      'kernel_size' : 3,
      'time_factor' : 1,
      'space_factor' : 2,
    }),
    ('space-time_attn', {
      'n_rep' : 2,
      'n_embd' : 256,
      'n_head' : 4,
      'd_head' : 16,
      'has_ext' : True,
      'time_attn_kw' : {'key_dim' : 8},
    }),
  ),
  d_codebook=8, # Small codebook to incentivize interpretability
  inp_channels=3, # Input video channels
  inp_shape=(64, 64), # Spatial frame dimensions
  n_embd=256, # Hidden model dimension
  # [...] Other kwargs for controlling the LFQ module behavior
)

# Create a mock input video
batch_size = 2
video_len = 16
frame_dim = 64, 64

video = torch.randn(batch_size, 3, video_len, *frame_dim)

# Encode the video to extract the (quantized) latent actions,
# their codebook indices and the encoded video features
(actions, act_idxs, encoded), quant_loss = model.encode(video)

# Run a full forward pass: the video is reconstructed internally
# and scored against the input with an MSE loss; the model returns
# the action indices, the total loss and the auxiliary loss terms
act_idxs, loss, aux_losses = model(video)

# Train the model
loss.backward()
```

## Dynamics Model

The `DynamicsModel` is tasked with predicting the next video tokens from the history of past video tokens and latent actions. The architecture is based on the `MaskGIT` model from [Chang et al. (2022)](https://arxiv.org/abs/2202.04200).
Here is an example highlighting the core components:

```python
import torch

from genie import DynamicsModel

blueprint = (
  # Describe a Space-Time Transformer
  ('space-time_attn', {
    'n_rep' : 4, # Number of layers
    'n_embd' : 256, # Hidden dimension
    'n_head' : 4, # Number of attention heads
    'd_head' : 16, # Dimension of each attention head
    'transpose' : False,
  }),
)

# Create the model
tok_codebook = 16 # Size of the video tokenizer vocabulary
act_codebook = 4 # Size of the latent action vocabulary
dynamics = DynamicsModel(
  desc=blueprint,
  tok_vocab=tok_codebook,
  act_vocab=act_codebook,
  embed_dim=256, # Hidden dimension of the model
)

batch_size = 2
num_frames = 16
img_size = 32

# Create mock token and latent action inputs
mock_tokens = torch.randint(0, tok_codebook, (batch_size, num_frames, img_size, img_size))
mock_act_id = torch.randint(0, act_codebook, (batch_size, num_frames))

# Compute the reconstruction loss based on Bernoulli
# masking of the input tokens
loss = dynamics.compute_loss(
  mock_tokens,
  mock_act_id,
)

# Generate the tokens of the next video frame
new_tokens = dynamics.generate(
  mock_tokens,
  mock_act_id,
  steps=5, # Number of MaskGIT sampling steps
)

assert new_tokens.shape == (batch_size, num_frames + 1, img_size, img_size)
```

# Roadmap

- [x] Implement the video tokenizer. Use the MagViT-2 tokenizer as described in [Yu et al. (2023)](https://magvit.cs.cmu.edu/v2/).
- [x] Implement the Latent Action Model, a Vector-Quantized ST-Transformer that predicts game actions from past video frames.
- [x] Implement the Dynamics Model, which takes past frames and actions and produces the new video frame.
- [ ] Add a functioning training script (Lightning).
- [ ] Show some results.

# Requirements

The code was tested with Python 3.11+ and requires `torch 2.0+` (for its fast flash-attention implementation). To install the required dependencies, simply run `pip install -r requirements.txt`.

# Citations

This repo builds upon the beautiful MagViT implementation by [lucidrains](https://github.com/lucidrains/magvit2-pytorch/tree/main) and the MaskGIT implementation from [valeoai](https://github.com/valeoai/Maskgit-pytorch/tree/main).
283 | 284 | ```bibtex 285 | @article{bruce2024genie, 286 | title={Genie: Generative Interactive Environments}, 287 | author={Bruce, Jake and Dennis, Michael and Edwards, Ashley and Parker-Holder, Jack and Shi, Yuge and Hughes, Edward and Lai, Matthew and Mavalankar, Aditi and Steigerwald, Richie and Apps, Chris and others}, 288 | journal={arXiv preprint arXiv:2402.15391}, 289 | year={2024} 290 | } 291 | ``` 292 | 293 | ```bibtex 294 | @article{yu2023language, 295 | title={Language Model Beats Diffusion--Tokenizer is Key to Visual Generation}, 296 | author={Yu, Lijun and Lezama, Jos{\'e} and Gundavarapu, Nitesh B and Versari, Luca and Sohn, Kihyuk and Minnen, David and Cheng, Yong and Gupta, Agrim and Gu, Xiuye and Hauptmann, Alexander G and others}, 297 | journal={arXiv preprint arXiv:2310.05737}, 298 | year={2023} 299 | } 300 | ``` 301 | 302 | ```bibtex 303 | @inproceedings{chang2022maskgit, 304 | title={Maskgit: Masked generative image transformer}, 305 | author={Chang, Huiwen and Zhang, Han and Jiang, Lu and Liu, Ce and Freeman, William T}, 306 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 307 | pages={11315--11325}, 308 | year={2022} 309 | } 310 | ``` 311 | -------------------------------------------------------------------------------- /config/tokenize.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 31415 2 | 3 | model: 4 | enc_desc: 5 | - - spacetime_downsample 6 | - in_channels : 3 7 | kernel_size : 3 8 | out_channels : 64 9 | time_factor : 1 10 | space_factor : 2 11 | - - space-time_attn 12 | - n_rep: 8 13 | n_head: 8 14 | d_head: 64 15 | dec_desc: 16 | - - space-time_attn 17 | - n_rep: 8 18 | n_head: 8 19 | d_head: 64 20 | - - depth2spacetime_upsample 21 | - in_channels : 64 22 | kernel_size : 3 23 | out_channels : 3 24 | time_factor : 1 25 | space_factor : 2 26 | disc_kwargs: 27 | inp_size: [64, 64] # Size of input frames 28 | model_dim: 64 # Dimension of the model 29 | dim_mults: [1, 2, 4] # Channel multipliers 30 | down_step: [null, 2, 2] # Down-sampling steps 31 | inp_channels: 3 32 | kernel_size: 3 33 | num_groups: 8 34 | act_fn: leaky # Use LeakyReLU as activation function 35 | use_blur: True # Use BlurPooling for down-sampling 36 | use_attn: True # Discriminator can have spatial attention 37 | num_heads: 4 # Number of (spatial) attention heads 38 | dim_head: 32 # Dimension of each spatial attention heads 39 | # 40 | d_codebook: 10 41 | n_codebook: 1 42 | # 43 | lfq_bias: True 44 | lfq_frac_sample: 1 45 | lfq_commit_weight: 0.25 46 | lfq_entropy_weight: 0.01 47 | lfq_diversity_weight: 1. 48 | # 49 | optimizer: 50 | class_path: torch.optim.AdamW 51 | init_args: 52 | lr: 1e-3 53 | weight_decay: 0.01 54 | # 55 | perceptual_model: vgg16 56 | perc_feat_layers: [features.6, features.13, features.18, features.25] 57 | gan_discriminate: frames 58 | gan_frames_per_batch: 4 59 | gan_loss_weight: 1. 60 | perc_loss_weight: 1. 61 | quant_loss_weight: 1. 
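  # NOTE (editor's assumption): d_codebook is the LFQ latent dimension; with
  # lookup-free quantization the effective vocabulary has 2**d_codebook entries.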
62 | 63 | data: 64 | root: path/to/data 65 | env_name: Coinrun 66 | padding: none 67 | randomize: true 68 | transform: null 69 | num_frames: 64 70 | batch_size: 32 71 | output_format: c t h w 72 | 73 | trainer: 74 | max_epochs: 40 75 | accelerator: gpu 76 | devices: 1 77 | strategy: ddp_find_unused_parameters_false 78 | precision: 16-mixed 79 | log_every_n_steps: 16 80 | limit_val_batches: 32 81 | val_check_interval: 32 82 | callbacks: 83 | - class_path: lightning.pytorch.callbacks.ModelCheckpoint 84 | init_args: 85 | monitor: val_loss 86 | save_last: true 87 | logger: 88 | - class_path: lightning.pytorch.loggers.TensorBoardLogger 89 | init_args: 90 | save_dir: path/to/log 91 | name: genie-tokenizer 92 | version: null 93 | -------------------------------------------------------------------------------- /genie.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.cli import LightningCLI 2 | 3 | from genie import Genie 4 | from genie.dataset import LightningPlatformer2D 5 | 6 | def cli_main(): 7 | ''' 8 | Main function for the training script. 9 | ''' 10 | 11 | # That's all it takes for LightningCLI to work! 12 | # No need to call .fit() or .test() or anything like that. 13 | cli = LightningCLI( 14 | Genie, 15 | LightningPlatformer2D, 16 | ) 17 | 18 | if __name__ == '__main__': 19 | cli_main() -------------------------------------------------------------------------------- /genie/__init__.py: -------------------------------------------------------------------------------- 1 | from genie.tokenizer import VideoTokenizer 2 | from genie.tokenizer import MAGVIT2_ENC_DESC 3 | from genie.tokenizer import MAGVIT2_DEC_DESC 4 | 5 | from genie.action import LatentAction 6 | from genie.dynamics import DynamicsModel 7 | 8 | from genie.genie import Genie 9 | 10 | LATENT_ACT_ENC = ( 11 | ('space-time_attn', { 12 | 'n_rep' : 2, 13 | 'n_embd' : 256, 14 | 'n_head' : 4, 15 | 'd_head' : 16, 16 | }), 17 | ('spacetime_downsample', { 18 | 'in_channels' : 256, 19 | 'kernel_size' : 3, 20 | 'time_factor' : 1, 21 | 'space_factor' : 2, 22 | }), 23 | ('space-time_attn', { 24 | 'n_rep' : 2, 25 | 'n_embd' : 256, 26 | 'n_head' : 4, 27 | 'd_head' : 16, 28 | }), 29 | ) 30 | 31 | LATENT_ACT_DEC = ( 32 | ('space-time_attn', { 33 | 'n_rep' : 2, 34 | 'n_embd' : 256, 35 | 'n_head' : 4, 36 | 'd_head' : 16, 37 | 'has_ext' : True, 38 | 'time_attn_kw' : {'key_dim' : 8}, 39 | }), 40 | ('spacetime_upsample', { 41 | 'in_channels' : 256, 42 | 'kernel_size' : 3, 43 | 'time_factor' : 1, 44 | 'space_factor' : 2, 45 | }), 46 | ('space-time_attn', { 47 | 'n_rep' : 2, 48 | 'n_embd' : 256, 49 | 'n_head' : 4, 50 | 'd_head' : 16, 51 | 'has_ext' : True, 52 | 'time_attn_kw' : {'key_dim' : 8}, 53 | }), 54 | ) -------------------------------------------------------------------------------- /genie/action.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | from math import prod 6 | from torch.nn.functional import mse_loss 7 | 8 | from einops.layers.torch import Rearrange 9 | 10 | from genie.module import parse_blueprint 11 | from genie.module.quantization import LookupFreeQuantization 12 | from genie.module.video import CausalConv3d, Downsample, Upsample 13 | from genie.utils import Blueprint 14 | 15 | REPR_ACT_ENC = ( 16 | ('space-time_attn', { 17 | 'n_repr' : 8, 18 | 'n_heads': 8, 19 | 'd_head': 64, 20 | }), 21 | ) 22 | 23 | REPR_ACT_DEC = ( 24 | ('space-time_attn', { 25 
| 'n_repr' : 8, 26 | 'n_heads': 8, 27 | 'd_head': 64, 28 | }), 29 | ) 30 | 31 | class LatentAction(nn.Module): 32 | '''Latent Action Model (LAM) used to distill latent actions 33 | from history of past video frames. The LAM model employs a 34 | VQ-VAE model to encode video frames into discrete latents. 35 | Both the encoder and decoder are based on spatial-temporal 36 | transformers. 37 | ''' 38 | 39 | def __init__( 40 | self, 41 | enc_desc: Blueprint, 42 | dec_desc: Blueprint, 43 | d_codebook: int, 44 | inp_channels: int = 3, 45 | inp_shape : int | Tuple[int, int] = (64, 64), 46 | ker_size : int | Tuple[int, int] = 3, 47 | n_embd: int = 256, 48 | n_codebook: int = 1, 49 | lfq_bias : bool = True, 50 | lfq_frac_sample : float = 1., 51 | lfq_commit_weight : float = 0.25, 52 | lfq_entropy_weight : float = 0.1, 53 | lfq_diversity_weight : float = 1., 54 | quant_loss_weight : float = 1., 55 | ) -> None: 56 | super().__init__() 57 | 58 | if isinstance(inp_shape, int): inp_shape = (inp_shape, inp_shape) 59 | 60 | self.proj_in = CausalConv3d( 61 | inp_channels, 62 | out_channels=n_embd, 63 | kernel_size=ker_size 64 | ) 65 | 66 | self.proj_out = CausalConv3d( 67 | n_embd, 68 | out_channels=inp_channels, 69 | kernel_size=ker_size 70 | ) 71 | 72 | # Build the encoder and decoder based on the blueprint 73 | self.enc_layers, self.enc_ext = parse_blueprint(enc_desc) 74 | self.dec_layers, self.dec_ext = parse_blueprint(dec_desc) 75 | 76 | # Keep track of space-time up/down factors 77 | enc_fact = prod(enc.factor for enc in self.enc_layers if isinstance(enc, (Downsample, Upsample))) 78 | dec_fact = prod(dec.factor for dec in self.dec_layers if isinstance(dec, (Downsample, Upsample))) 79 | 80 | assert enc_fact * dec_fact == 1, 'The product of the space-time up/down factors must be 1.' 81 | 82 | # Add the projections to the action space 83 | self.to_act = nn.Sequential( 84 | Rearrange('b c t ... 
-> b t (c ...)'), 85 | nn.Linear( 86 | int(n_embd * enc_fact * prod(inp_shape)), 87 | d_codebook, 88 | bias=False, 89 | ) 90 | ) 91 | 92 | # Build the quantization module 93 | self.quant = LookupFreeQuantization( 94 | codebook_dim = d_codebook, 95 | num_codebook = n_codebook, 96 | use_bias = lfq_bias, 97 | frac_sample = lfq_frac_sample, 98 | commit_weight = lfq_commit_weight, 99 | entropy_weight = lfq_entropy_weight, 100 | diversity_weight = lfq_diversity_weight, 101 | ) 102 | 103 | self.d_codebook = d_codebook 104 | self.n_codebook = n_codebook 105 | self.quant_loss_weight = quant_loss_weight 106 | 107 | def sample(self, idxs : Tensor) -> Tensor: 108 | '''Sample the action codebook values based on the indices.''' 109 | return self.quant.codebook[idxs] 110 | 111 | def encode( 112 | self, 113 | video: Tensor, 114 | mask : Tensor | None = None, 115 | transpose : bool = False, 116 | ) -> Tuple[Tuple[Tensor, Tensor], Tensor]: 117 | video = self.proj_in(video) 118 | 119 | # Encode the video frames into latent actions 120 | for enc in self.enc_layers: 121 | video = enc(video, mask=mask) 122 | 123 | # Project to latent action space 124 | act : Tensor = self.to_act(video) 125 | 126 | # Quantize the latent actions 127 | (act, idxs), q_loss = self.quant(act, transpose=transpose) 128 | 129 | return (act, idxs, video), q_loss 130 | 131 | def decode( 132 | self, 133 | video : Tensor, 134 | q_act : Tensor, 135 | ) -> Tensor: 136 | # Decode the video frames based on past history and 137 | # the quantized latent actions 138 | for dec, has_ext in zip(self.dec_layers, self.dec_ext): 139 | video = dec( 140 | video, 141 | cond=( 142 | None, # No space condition 143 | q_act if has_ext else None, 144 | ) 145 | ) 146 | 147 | recon = self.proj_out(video) 148 | 149 | return recon 150 | 151 | def forward( 152 | self, 153 | video: Tensor, 154 | mask : Tensor | None = None, 155 | ) -> Tuple[Tensor, Tensor]: 156 | 157 | # Encode the video frames into latent actions 158 | (act, idxs, enc_video), q_loss = self.encode(video, mask=mask) 159 | 160 | # Decode the last video frame based on all the previous 161 | # frames and the quantized latent actions 162 | recon = self.decode(enc_video, act) 163 | 164 | # Compute the reconstruction loss 165 | # Reconstruction loss 166 | rec_loss = mse_loss(recon, video) 167 | 168 | # Compute the total loss by combining the individual 169 | # losses, weighted by the corresponding loss weights 170 | loss = rec_loss\ 171 | + q_loss * self.quant_loss_weight 172 | 173 | return idxs, loss, ( 174 | rec_loss, 175 | q_loss, 176 | ) -------------------------------------------------------------------------------- /genie/dataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from torchvision.datasets import Kinetics 4 | 5 | from typing import Callable, Tuple 6 | 7 | from genie.module.data import LightningDataset, Platformer2D 8 | 9 | class LightningKinetics(LightningDataset): 10 | '''Lightning Dataset class for the Kinetics dataset. 11 | ''' 12 | 13 | def __init__( 14 | self, 15 | root: str | Path, 16 | frames_per_clip: int, 17 | num_classes: str = '400', 18 | frame_rate: int | None = None, 19 | step_between_clips: int = 1, 20 | transform: Callable | None = None, 21 | extensions: Tuple[str, ...] 
= ('avi', 'mp4'), 22 | download: bool = False, 23 | num_download_workers: int = 1, 24 | num_workers: int = 1, 25 | output_format: str = 'CTHW', 26 | **kwargs, 27 | ) -> None: 28 | super().__init__(**kwargs) 29 | 30 | self.root = root 31 | 32 | self.download = download 33 | self.transform = transform 34 | self.extensions = extensions 35 | self.frame_rate = frame_rate 36 | self.num_classes = num_classes 37 | self.num_workers = num_workers 38 | self.output_format = output_format 39 | self.frames_per_clip = frames_per_clip 40 | self.step_between_clips = step_between_clips 41 | self.num_download_workers = num_download_workers 42 | 43 | self.save_hyperparameters() 44 | 45 | def setup(self, stage: str) -> None: 46 | 47 | match stage: 48 | case 'fit': 49 | self.train_dataset = Kinetics( 50 | root=self.root, 51 | split = 'train', 52 | download = self.download, 53 | transform = self.transform, 54 | extensions = self.extensions, 55 | frame_rate = self.frame_rate, 56 | num_classes = self.num_classes, 57 | num_workers = self.num_workers, 58 | output_format = self.output_format, 59 | frames_per_clip = self.frames_per_clip, 60 | step_between_clips = self.step_between_clips, 61 | num_download_workers = self.num_download_workers, 62 | ) 63 | self.valid_dataset = Kinetics( 64 | root=self.root, 65 | split = 'val', 66 | download = self.download, 67 | transform = self.transform, 68 | extensions = self.extensions, 69 | frame_rate = self.frame_rate, 70 | num_classes = self.num_classes, 71 | num_workers = self.num_workers, 72 | output_format = self.output_format, 73 | frames_per_clip = self.frames_per_clip, 74 | step_between_clips = self.step_between_clips, 75 | num_download_workers = self.num_download_workers, 76 | ) 77 | case 'test': 78 | self.test__dataset = Kinetics( 79 | root=self.root, 80 | split = 'test', 81 | download = self.download, 82 | transform = self.transform, 83 | extensions = self.extensions, 84 | frame_rate = self.frame_rate, 85 | num_classes = self.num_classes, 86 | num_workers = self.num_workers, 87 | output_format = self.output_format, 88 | frames_per_clip = self.frames_per_clip, 89 | step_between_clips = self.step_between_clips, 90 | num_download_workers = self.num_download_workers, 91 | ) 92 | case _: 93 | raise ValueError(f'Invalid stage: {stage}') 94 | 95 | class LightningPlatformer2D(LightningDataset): 96 | '''Lightning Dataset class for the Platformer2D Dataset. 97 | This dataset samples video recorded using a random agent 98 | playing the gym environments defined in the Procgen Benchmark, 99 | see Cobbe et al., ICML (2020). 
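    Each sample is expected to be a clip of `num_frames` frames arranged according to `output_format`.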
100 | ''' 101 | 102 | def __init__( 103 | self, 104 | root: str | Path, 105 | env_name : str = 'Coinrun', 106 | padding : str = 'none', 107 | randomize : bool = False, 108 | transform : Callable | None = None, 109 | num_frames : int = 16, 110 | output_format: str = 't c h w', 111 | **kwargs, 112 | ) -> None: 113 | super().__init__(**kwargs) 114 | 115 | self.root = root 116 | 117 | self.padding = padding 118 | self.env_name = env_name 119 | self.transform = transform 120 | self.randomize = randomize 121 | self.num_frames = num_frames 122 | self.output_format = output_format 123 | 124 | self.save_hyperparameters() 125 | 126 | def setup(self, stage: str) -> None: 127 | 128 | match stage: 129 | case 'fit': 130 | self.train_dataset = Platformer2D( 131 | root=self.root, 132 | split = 'train', 133 | padding = self.padding, 134 | env_name = self.env_name, 135 | transform = self.transform, 136 | randomize = self.randomize, 137 | num_frames = self.num_frames, 138 | output_format= self.output_format, 139 | ) 140 | self.valid_dataset = Platformer2D( 141 | root=self.root, 142 | split = 'val', 143 | padding = self.padding, 144 | env_name = self.env_name, 145 | transform = self.transform, 146 | randomize = self.randomize, 147 | num_frames = self.num_frames, 148 | output_format= self.output_format, 149 | ) 150 | case 'test': 151 | self.test__dataset = Platformer2D( 152 | root=self.root, 153 | split = 'test', 154 | padding = self.padding, 155 | env_name = self.env_name, 156 | transform = self.transform, 157 | randomize = self.randomize, 158 | num_frames = self.num_frames, 159 | output_format= self.output_format, 160 | ) 161 | case _: 162 | raise ValueError(f'Invalid stage: {stage}') -------------------------------------------------------------------------------- /genie/dynamics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import Literal 4 | from math import inf, pi, prod 5 | from torch import Tensor, softmax 6 | from torch.nn.functional import cross_entropy 7 | 8 | from einops import pack, rearrange, unpack 9 | from einops.layers.torch import Rearrange 10 | 11 | from genie.utils import Blueprint, default 12 | from genie.module import parse_blueprint 13 | 14 | class DynamicsModel(nn.Module): 15 | '''Dynamics Model (DM) used to predict future video frames 16 | given the history of past video frames and the latent actions. 17 | The DM model employs the Mask-GIT architecture as introduced 18 | in Chang et al. (2022). 
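    During training a random Bernoulli mask hides a subset of the input tokens and the model is asked to recover them (see `compute_loss`). At inference time `generate` starts from a fully-masked new frame and un-masks it over several steps, keeping only the most confident predictions at each step.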
19 | ''' 20 | 21 | def __init__( 22 | self, 23 | desc: Blueprint, 24 | tok_vocab: int, 25 | act_vocab: int, 26 | embed_dim: int, 27 | ) -> None: 28 | super().__init__() 29 | 30 | self.dec_layers, self.ext_kw = parse_blueprint(desc) 31 | 32 | self.head = nn.Linear(embed_dim, tok_vocab) 33 | 34 | self.tok_emb = nn.Embedding(tok_vocab, embed_dim) 35 | self.act_emb = nn.Sequential( 36 | nn.Embedding(act_vocab, embed_dim), 37 | Rearrange('b t d -> b t 1 1 d'), 38 | ) 39 | 40 | self.tok_vocab = tok_vocab 41 | self.act_vocab = act_vocab 42 | self.embed_dim = embed_dim 43 | 44 | def forward( 45 | self, 46 | tokens : Tensor, 47 | act_id : Tensor, 48 | ) -> Tensor: 49 | ''' 50 | Predicts the next video token based on the previous tokens 51 | ''' 52 | 53 | # Actions are quantized, use them as additive embeddings to tokens 54 | # Token have shape (batch, seq_len, token_dim) 55 | tokens = self.tok_emb(tokens) + self.act_emb(act_id) 56 | 57 | # Predict the next video token based on previous tokens and actions 58 | for dec, has_ext in zip(self.dec_layers, self.ext_kw): 59 | tokens = dec(tokens) 60 | 61 | # Compute the next token probability 62 | logits = self.head(tokens) 63 | 64 | return logits, logits[:, -1] 65 | 66 | def compute_loss( 67 | self, 68 | tokens : Tensor, 69 | act_id : Tensor, 70 | mask : Tensor | None = None, 71 | fill : float = 0., 72 | ) -> Tensor: 73 | 74 | b, t, h, w = tokens.shape 75 | 76 | # Create Bernoulli mask if not provided 77 | mask = default(mask, torch.distributions.Bernoulli( 78 | torch.empty(1).uniform_(0.5, 1).item() # Random rate in [0.5, 1] 79 | ).sample((b, t, h, w)).bool() 80 | ) 81 | 82 | # Mask tokens based on external mask as training signal 83 | tokens = torch.masked_fill(tokens, mask, fill) 84 | 85 | # Compute the model prediction for the next token 86 | logits, _ = self(tokens, act_id.detach()) 87 | 88 | # Only compute loss on the tokens that were masked 89 | logits = logits[mask.squeeze()] 90 | tokens = tokens[mask.squeeze()] 91 | 92 | # Rearrange tokens to have shape (batch * seq_len, vocab_size) 93 | logits = rearrange(logits, '... d -> (...) d') 94 | target = rearrange(tokens, '... -> (...)') 95 | 96 | # Compute the cross-entropy loss between the predicted and actual tokens 97 | loss = cross_entropy(logits, target) 98 | 99 | return loss 100 | 101 | @torch.no_grad() 102 | def generate( 103 | self, 104 | tokens : Tensor, 105 | act_id : Tensor, 106 | steps : int = 10, 107 | which : Literal['linear', 'cosine', 'arccos'] = 'linear', 108 | temp : float = 1., 109 | topk : int = 50, 110 | masked_tok : int = 0, 111 | ) -> Tensor: 112 | ''' 113 | Given past token and action history, predicts the next token 114 | via the Mask-GIT sampling technique. 115 | ''' 116 | b, t, h, w = tokens.shape 117 | 118 | # Get the sampling schedule 119 | schedule = self.get_schedule(steps, shape=(h, w), which=which) 120 | 121 | # Initialize a fully active mask to signal that all the tokens 122 | # must receive a prediction. The mask will be updated at each 123 | # step based on the sampling schedule. 
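        # `code` holds the tokens of the frame being generated and is initialised
        # to `masked_tok`, while `mock` is a placeholder action id appended to the
        # action history so that token and action sequences keep the same length.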
124 | mask = torch.ones(b, h, w, dtype=bool, device=tokens.device) 125 | code = torch.full((b, h, w), masked_tok, device=tokens.device) 126 | mock = torch.zeros(b, dtype=int, device=tokens.device) 127 | 128 | tok_id, ps = pack([tokens, code], 'b * h w') 129 | act_id, _ = pack([act_id, mock], 'b *') 130 | 131 | for num_tokens in schedule: 132 | # If no more tokens to predict, return 133 | if mask.sum() == 0: break 134 | 135 | # Get prediction for the next tokens 136 | _, logits = self(tok_id, act_id) 137 | 138 | # Refine the mask based on the sampling schedule 139 | prob = softmax(logits / temp, dim=-1) 140 | prob, ps = pack([prob], '* d') 141 | pred = torch.multinomial(prob, num_samples=1) 142 | conf = torch.gather(prob, -1, pred) 143 | conf = unpack(conf, ps, '* d')[0].squeeze() 144 | 145 | # We paint the k-tokens with highest confidence, excluding the 146 | # already predicted tokens from the mask 147 | conf[~mask.bool()] = -inf 148 | idxs = torch.topk(conf.view(b, -1), k=num_tokens, dim=-1).indices 149 | 150 | code, cps = pack([code], 'b *') 151 | mask, mps = pack([mask], 'b *') 152 | pred = pred.view(b, -1) 153 | 154 | # Fill the code with sampled tokens & update mask 155 | vals = torch.gather(pred, -1, idxs).to(code.dtype) 156 | code.scatter_(1, idxs, vals) 157 | mask.scatter_(1, idxs, False) 158 | 159 | code = unpack(code, cps, 'b *')[0] 160 | mask = unpack(mask, mps, 'b *')[0] 161 | 162 | pred_tok, ps = pack([tokens, code], 'b * h w') 163 | 164 | assert mask.sum() == 0, f'Not all tokens were predicted. {mask.sum()} tokens left.' 165 | return pred_tok 166 | 167 | def get_schedule( 168 | self, 169 | steps: int, 170 | shape: tuple[int, int], 171 | which: Literal['linear', 'cosine', 'arccos'] = 'linear', 172 | ) -> Tensor: 173 | n = prod(shape) 174 | t = torch.linspace(1, 0, steps) 175 | 176 | 177 | match which: 178 | case 'linear': 179 | s = 1 - t 180 | case 'cosine': 181 | s = torch.cos(t * pi * .5) 182 | case 'arccos': 183 | s = torch.acos(t) / (pi * .5) 184 | case _: 185 | raise ValueError(f'Unknown schedule type: {which}') 186 | 187 | # Fill the schedule with the ratio of tokens to predict 188 | schedule = (s / s.sum()) * n 189 | schedule = schedule.round().int().clamp(min=1) 190 | 191 | # Make sure that the total number of tokens to predict is 192 | # equal to the vocab size 193 | schedule[-1] += n - schedule.sum() 194 | 195 | return schedule -------------------------------------------------------------------------------- /genie/genie.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from lightning import LightningModule 3 | from torch import Tensor 4 | import torch 5 | from torch.optim import AdamW 6 | from torch.optim import Optimizer 7 | 8 | from genie.action import LatentAction 9 | from genie.dynamics import DynamicsModel 10 | from genie.tokenizer import VideoTokenizer 11 | 12 | from typing import Callable, Iterable 13 | 14 | from genie.utils import default 15 | 16 | OptimizerCallable = Callable[[Iterable], Optimizer] 17 | 18 | class Genie(LightningModule): 19 | ''' 20 | Generative Interactive Environment model from Bruce et al. (2024). 21 | The model is composed of: 22 | - A (pre-trained) video tokenizer based on the MaskVit-2 architecture. 23 | - A Latent Action model that build a (quantized) dictionary of latent actions 24 | - A Dynamics Model that predicts the next frame given the current frame and the latent action. 
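    At inference time (see `forward`) the tokenizer encodes an image prompt into discrete tokens, the dynamics model generates the tokens of each new frame conditioned on the user-provided latent actions, and the tokenizer decoder maps the resulting token video back to pixel space.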
25 | ''' 26 | def __init__( 27 | self, 28 | tokenizer : VideoTokenizer, 29 | optimizer : OptimizerCallable = AdamW, 30 | img_prompt : Tensor | None = None, 31 | ): 32 | super().__init__() 33 | 34 | # Pre-trained video tokenizer 35 | self.tokenizer = tokenizer 36 | 37 | self.latent_action = LatentAction( 38 | self.enc_desc, 39 | self.dec_desc, 40 | d_codebook=self.d_codebook, 41 | inp_channels=self.inp_channels, 42 | inp_shape=self.inp_shape, 43 | ker_size=self.ker_size, 44 | n_embd=self.n_embd, 45 | n_codebook=self.n_codebook, 46 | lfq_bias=self.lfq_bias, 47 | lfq_frac_sample=self.lfq_frac_sample, 48 | lfq_commit_weight=self.lfq_commit_weight, 49 | lfq_entropy_weight=self.lfq_entropy_weight, 50 | lfq_diversity_weight=self.lfq_diversity_weight, 51 | ) 52 | 53 | self.dynamics_model = DynamicsModel( 54 | desc=TEST_DESC, 55 | tok_vocab=self.tok_codebook, 56 | act_vocab=self.act_codebook, 57 | embed_dim=self.embed_dim, 58 | ) 59 | 60 | self.optimizer = optimizer 61 | self.img_prompt = img_prompt 62 | 63 | self.save_hyperparameters() 64 | 65 | @torch.no_grad() 66 | def forward( 67 | self, 68 | prompt : Tensor, 69 | actions : Tensor, 70 | num_frames : int | None = None, 71 | steps_per_frame : int = 25, 72 | ) -> Tensor: 73 | ''' 74 | Inference mode for the model. Generate videos from an initial 75 | image prompt and a sequence of latent actions. 76 | ''' 77 | num_frames = default(num_frames, actions.shape[1]) 78 | 79 | # Make sure prompt has correct shape for video 80 | match prompt.dim(): 81 | case 3: pattern = 'b h w -> b 1 1 h w' 82 | case 4: pattern = 'b c h w -> b c 1 h w' 83 | case 5: pattern = 'b c t h w -> b c t h w' 84 | case _: raise ValueError('Prompt must have 3, 4 or 5 dimensions') 85 | 86 | prompt = rearrange(prompt, pattern) 87 | 88 | # Tokenize the input prompt 89 | tokens = self.tokenizer.tokenize(prompt) 90 | 91 | for t in range(num_frames): 92 | # Predict the next frame based on the previous frame and the action 93 | new_tok = self.dynamics_model.generate( 94 | tokens, 95 | actions[:, :t], 96 | steps=steps_per_frame, 97 | ) 98 | 99 | # Add the new frame to the video 100 | tokens = torch.stack((tokens, new_tok), dim=2) 101 | 102 | # Return the generated video 103 | video = self.tokenizer.decode(tokens) 104 | 105 | return video 106 | 107 | def compute_loss(self, video : Tensor) -> Tensor: 108 | # Tokenize the input video 109 | tokens = self.tokenizer.tokenize(video) 110 | 111 | # Extract latent actions from the video 112 | act_id, act_loss, (act_rec_loss, act_q_loss) = self.latent_action(video) 113 | 114 | # Compute the next-frame prediction loss via the dynamics model 115 | dyn_loss = self.dynamics_model.compute_loss(tokens, act_id) 116 | 117 | # Combine both latent action and dynamics model losses 118 | loss = act_loss + dyn_loss 119 | 120 | return loss, ( 121 | ('act_loss', act_loss), 122 | ('dyn_loss', dyn_loss), 123 | ('act_rec_loss', act_rec_loss), 124 | ('act_q_loss', act_q_loss), 125 | ) 126 | 127 | def training_step(self, batch : Tensor, batch_idx : int) -> Tensor: 128 | # Compute the training loss 129 | loss, aux_losses = self.compute_loss(batch) 130 | 131 | # Log the training loss 132 | self.log_dict( 133 | {**{'train_loss' : loss}, **{f'train/{k}': v for k, v in aux_losses}}, 134 | logger=True, 135 | on_step=True, 136 | sync_dist=True, 137 | ) 138 | 139 | return loss 140 | 141 | def validation_step(self, batch : Tensor, batch_idx : int) -> Tensor: 142 | # Compute the validation loss 143 | loss, aux_losses = self.compute_loss(batch) 144 | 145 | # Log the training loss 
146 | self.log_dict( 147 | {**{'val_loss' : loss}, **{f'val/{k}': v for k, v in aux_losses}}, 148 | logger=True, 149 | on_step=True, 150 | sync_dist=True, 151 | ) 152 | 153 | return loss 154 | 155 | def on_validation_end(self) -> None: 156 | '''Generate sample videos at the end of the validation loop''' 157 | 158 | # Generate a sample video from a given image prompt and random actions 159 | num_frames = 16 160 | prompt = default(self.img_prompt, torch.randn(1, 3, 64, 64)) 161 | actions = torch.randint(0, self.latent_action.d_codebook, size=(num_frames,)) 162 | 163 | video = self( 164 | prompt, 165 | actions, 166 | num_frames=num_frames, 167 | steps_per_frame=25 168 | ) 169 | 170 | self.logger.experiment.add_video( 171 | f'Generated Video #1', 172 | video, 173 | global_step=self.global_step, 174 | ) 175 | 176 | def configure_optimizers(self) -> Optimizer: 177 | optim = self.optimizer( 178 | self.parameters(), 179 | ) 180 | 181 | return optim -------------------------------------------------------------------------------- /genie/module/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | import torch.nn as nn 3 | 4 | from .attention import SpaceTimeAttention 5 | from .attention import SpatialAttention 6 | from .attention import TemporalAttention 7 | 8 | from genie.utils import Blueprint, default, exists 9 | from .image import BlurPooling2d 10 | from .image import SpaceDownsample 11 | from .image import ImageResidualBlock 12 | 13 | from .norm import AdaptiveGroupNorm 14 | 15 | from .video import CausalConv3d 16 | from .video import VideoResidualBlock 17 | from .video import CausalConvTranspose3d 18 | from .video import DepthToSpaceTimeUpsample 19 | from .video import DepthToSpaceUpsample 20 | from .video import DepthToTimeUpsample 21 | from .video import SpaceTimeDownsample 22 | 23 | def get_module(name : str) -> nn.Module: 24 | match name: 25 | # * Attention modules 26 | case 'space_attn': 27 | return SpatialAttention 28 | case 'time_attn': 29 | return TemporalAttention 30 | case 'space-time_attn': 31 | return SpaceTimeAttention 32 | # * Image modules 33 | case 'blur_pool': 34 | return BlurPooling2d 35 | case 'space_downsample': 36 | return SpaceDownsample 37 | case 'image-residual': 38 | return ImageResidualBlock 39 | # * Video modules 40 | case 'video-residual': 41 | return VideoResidualBlock 42 | case 'causal-conv3d': 43 | return CausalConv3d 44 | case 'causal-conv3d-transpose': 45 | return CausalConvTranspose3d 46 | case 'depth2space_upsample': 47 | return DepthToSpaceUpsample 48 | case 'depth2time_upsample': 49 | return DepthToTimeUpsample 50 | case 'depth2spacetime_upsample': 51 | return DepthToSpaceTimeUpsample 52 | case 'spacetime_downsample': 53 | return SpaceTimeDownsample 54 | # * Norm modules 55 | case 'group_norm': 56 | return nn.GroupNorm 57 | case 'adaptive_group_norm': 58 | return AdaptiveGroupNorm 59 | # * Activation modules 60 | case 'gelu': 61 | return nn.GELU 62 | case 'relu': 63 | return nn.ReLU 64 | case 'leaky_relu': 65 | return nn.LeakyReLU 66 | case 'silu': 67 | return nn.SiLU 68 | case _: 69 | raise ValueError(f'Unknown module name: {name}') 70 | 71 | def parse_blueprint( 72 | blueprint : Blueprint, 73 | ) -> Tuple[nn.ModuleList, List[bool]]: 74 | # Parse the blueprint 75 | layers = [] 76 | ext_kw = [] 77 | 78 | for desc in blueprint: 79 | if isinstance(desc, str): desc = (desc, {}) 80 | 81 | name, kwargs = default(desc, (None, {})) 82 | ext_kw.extend( 83 | [kwargs.pop('has_ext', False)] 
* kwargs.get('n_rep', 1) 84 | ) 85 | layers.extend( 86 | [ 87 | get_module(name)(**kwargs) 88 | for _ in range(kwargs.pop('n_rep', 1)) 89 | if exists(name) and exists(kwargs) 90 | ] 91 | ) 92 | 93 | return nn.ModuleList(layers), ext_kw -------------------------------------------------------------------------------- /genie/module/attention.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Tuple 2 | import torch 3 | import torch.nn as nn 4 | from math import pi 5 | from torch import Tensor 6 | from torch.nn.functional import scaled_dot_product_attention 7 | 8 | from einops import einsum, rearrange, repeat 9 | from einops import pack, unpack 10 | from einops.layers.torch import Rearrange 11 | 12 | from genie.module.misc import ForwardBlock 13 | from genie.utils import default, exists 14 | 15 | # Adapted from lucidrains/rotary-embedding-torch at: 16 | # https://github.com/lucidrains/rotary-embedding-torch/ 17 | class RotaryEmbedding(nn.Module): 18 | def __init__( 19 | self, 20 | dim : int, 21 | kind: Literal['1d', '2d', 'const'] = '1d', 22 | theta = 10000, 23 | max_freq = 10, 24 | num_freq = 1, 25 | learned_freq = False, 26 | interpolate_factor = 1., 27 | theta_rescale_factor = 1., 28 | ) -> None: 29 | super().__init__() 30 | 31 | theta *= theta_rescale_factor ** (dim / (dim - 2)) 32 | 33 | match kind: 34 | case '1d': 35 | freq = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 36 | case '2d': 37 | freq = torch.linspace(1., max_freq / 2, dim // 2) * pi 38 | case 'const': 39 | freq = torch.ones(num_freq).float() 40 | 41 | self.freq = nn.Parameter(freq, requires_grad=learned_freq) 42 | 43 | assert interpolate_factor >= 1. 44 | self.interpolate_factor = interpolate_factor 45 | 46 | self.default_seq_dim = -2 47 | 48 | def forward( 49 | self, 50 | seq : Tensor, 51 | seq_dim : int | None = None, 52 | offset = 0, 53 | ) -> Tensor: 54 | seq_dim = default(seq_dim, self.default_seq_dim) 55 | seq_len = seq.shape[seq_dim] 56 | 57 | freq = self.freq 58 | 59 | # Get sequence position 60 | pos = (torch.arange(seq_len, device=freq.device) + offset) / self.interpolate_factor 61 | 62 | freq = einsum(pos, freq, '..., f -> ... f') 63 | freq = repeat(freq, '... n -> ... (n r)', r = 2) 64 | 65 | if seq_dim == -3: freq = rearrange(freq, 'n d -> n 1 d') 66 | 67 | # Apply rotary embedding 68 | return self.apply(freq, seq, seq_dim = seq_dim) 69 | 70 | def apply( 71 | self, 72 | freq : Tensor, 73 | seq : Tensor, 74 | start_index : int = 0, 75 | scale : float = 1., 76 | seq_dim : int = -2 77 | ) -> Tensor: 78 | dtype = seq.dtype 79 | 80 | if seq.ndim == 3: 81 | seq_len = seq.shape[seq_dim] 82 | freq = freq[-seq_len:] 83 | 84 | rot_dim = freq.shape[-1] 85 | end_index = start_index + rot_dim 86 | 87 | assert rot_dim <= seq.shape[-1], f'feature dimension {seq.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 88 | 89 | t_left, seq, t_right = seq[..., :start_index], seq[..., start_index:end_index], seq[..., end_index:] 90 | 91 | seq = (seq * freq.cos() * scale) + (self.rotate_half(seq) * freq.sin() * scale) 92 | out = torch.cat((t_left, seq, t_right), dim = -1) 93 | 94 | return out.type(dtype) 95 | 96 | def rotate_half(self, inp : Tensor) -> Tensor: 97 | inp = rearrange(inp, '... (d r) -> ... d r', r = 2) 98 | x1, x2 = inp.unbind(dim = -1) 99 | inp = torch.stack((-x2, x1), dim = -1) 100 | return rearrange(inp, '... d r -> ... 
(d r)') 101 | 102 | def get_seq_pos(self, seq_len, device, dtype, offset = 0): 103 | return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor 104 | 105 | class Adapter(nn.Module): 106 | def __init__( 107 | self, 108 | qry_dim : int, 109 | n_head : int, 110 | d_head : int, 111 | key_dim : int | None = None, 112 | val_dim : int | None = None, 113 | block : nn.Module | Tuple[nn.Module, ...] = nn.Linear, 114 | qry_kwargs : dict = {}, 115 | key_kwargs : dict = {}, 116 | val_kwargs : dict = {}, 117 | bias : bool = False, 118 | ) -> None: 119 | super().__init__() 120 | 121 | key_dim = default(key_dim, qry_dim) 122 | val_dim = default(val_dim, key_dim) 123 | 124 | if issubclass(block, nn.Module): 125 | block = (block, block, block) 126 | 127 | self.to_q = block[0](qry_dim, n_head * d_head, bias=bias, **qry_kwargs) if qry_dim != n_head * d_head else nn.Identity() 128 | self.to_k = block[1](key_dim, n_head * d_head, bias=bias, **key_kwargs) if key_dim != n_head * d_head else nn.Identity() 129 | self.to_v = block[2](val_dim, n_head * d_head, bias=bias, **val_kwargs) if val_dim != n_head * d_head else nn.Identity() 130 | 131 | self.n_head = n_head 132 | 133 | def forward( 134 | self, 135 | qry : Tensor, 136 | key : Tensor | None = None, 137 | val : Tensor | None = None, 138 | ) -> Tuple[Tensor, Tensor, Tensor]: 139 | key = default(key, qry) 140 | val = default(val, key) 141 | 142 | q = self.to_q(qry) 143 | k = self.to_k(key) 144 | v = self.to_v(val) 145 | 146 | qkv, ps = pack([q, k, v], '* n d') 147 | qkv = rearrange(qkv, 'qkv n (h d) -> qkv h n d', h=self.n_head) 148 | 149 | return unpack(qkv, ps, '* n h d') 150 | 151 | # Inspired by the two cool repos implementations at: 152 | # https://github.com/karpathy/nanoGPT/blob/master/model.py#L29 153 | # https://github.com/lucidrains/magvit2-pytorch/blob/main/magvit2_pytorch/magvit2_pytorch.py#L255 154 | class Attention(nn.Module): 155 | ''' 156 | Standard self-attention module as originally introduced 157 | in the paper "Attention is All You Need". This module 158 | uses the flash-attention implementation offered by 159 | PyTorch >= 2.0. 160 | ''' 161 | 162 | def __init__( 163 | self, 164 | n_head : int, 165 | d_head : int, 166 | d_inp : int | None = None, 167 | d_out : int | None = None, 168 | bias : bool = False, 169 | scale : float | None = None, 170 | causal : bool = False, 171 | dropout : float = 0.0, 172 | **kwargs, 173 | ) -> None: 174 | super().__init__() 175 | 176 | self.d_inp = default(d_inp, n_head * d_head) 177 | self.d_out = default(d_out, self.d_inp) 178 | 179 | self.norm = nn.LayerNorm(n_head * d_head) 180 | self.embed = nn.Identity() 181 | 182 | self.to_qkv = Adapter( 183 | qry_dim=self.d_inp, 184 | n_head=n_head, 185 | d_head=d_head, 186 | bias=bias, 187 | **kwargs, 188 | ) 189 | 190 | self.to_out = nn.Sequential( 191 | Rearrange('b h n d -> b n (h d)'), 192 | nn.Linear(n_head * d_head, self.d_out, bias=bias) if self.d_out != n_head * d_head else nn.Identity(), 193 | ) 194 | 195 | self.scale = default(scale, n_head * d_head ** -0.5) 196 | self.causal = causal 197 | self.dropout = dropout 198 | 199 | def forward( 200 | self, 201 | qry : Tensor, 202 | key : Tensor | None = None, 203 | val : Tensor | None = None, 204 | mask : Tensor | None = None, 205 | ) -> Tensor: 206 | ''' 207 | Apply self-attention mechanism to the input sequence. 208 | 209 | Args: 210 | qry (Tensor): Input sequence tensor of shape (batch_size, sequence_length, embedding_size). 
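            key (Tensor, optional): Key sequence for cross-attention; when omitted, `qry` is used (plain self-attention). Defaults to None.
            val (Tensor, optional): Value sequence for cross-attention; when omitted, `key` is used. Defaults to None.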
211 | mask (Tensor, optional): Mask tensor of shape (batch_size, sequence_length) indicating which 212 | elements in the sequence should be masked. Defaults to None. 213 | 214 | Returns: 215 | Tensor: Output tensor after applying self-attention mechanism of shape 216 | (batch_size, sequence_length, embedding_size). 217 | ''' 218 | 219 | qry = self.embed(qry) 220 | qry = self.norm(qry) 221 | 222 | key = default(key, qry) 223 | val = default(val, key) 224 | 225 | # Project the input sequence into query, key, and value 226 | q, k, v = self.to_qkv(qry, key, val) 227 | 228 | # Compute the self-attention using fast flash-attention 229 | attn = scaled_dot_product_attention(q, k, v, 230 | attn_mask=mask, 231 | is_causal=self.causal, 232 | dropout_p=self.dropout, 233 | scale=self.scale, 234 | ) 235 | 236 | # Project the output back to the original embedding dimension 237 | out = self.to_out(attn) 238 | 239 | return out 240 | 241 | class SpatialAttention(Attention): 242 | ''' 243 | Attention module that applies self-attention across the 244 | spatial dimensions of the input tensor, expected to be 245 | either an image (4D tensor) or a video (5D tensor). 246 | ''' 247 | 248 | def __init__( 249 | self, 250 | n_head : int, 251 | d_head : int, 252 | d_inp : int | None = None, 253 | d_out : int | None = None, 254 | bias : bool = False, 255 | embed : bool = True, 256 | scale : float | None = None, 257 | causal : bool = False, 258 | dropout : float = 0.0, 259 | transpose : bool = False, 260 | **kwargs, 261 | ) -> None: 262 | super().__init__( 263 | n_head, 264 | d_head, 265 | d_inp, 266 | d_out, 267 | bias, 268 | scale, 269 | causal, 270 | dropout, 271 | **kwargs, 272 | ) 273 | 274 | # Use 2d-rotary embedding for spatial attention 275 | self.embed = RotaryEmbedding(self.d_inp, kind='2d') if embed else nn.Identity() 276 | 277 | self.transpose = transpose 278 | 279 | def forward( 280 | self, 281 | video : Tensor, 282 | cond : Tensor | None = None, 283 | mask: Tensor | None = None, 284 | transpose: bool | None = None, 285 | ) -> Tensor: 286 | transpose = default(transpose, self.transpose) 287 | 288 | pattern = 'b c ... h w' if transpose else 'b ... h w c' 289 | inp = rearrange(video, f'{pattern} -> b ... h w c') 290 | b, *t, h, w, c = video.shape 291 | 292 | inp, t_ps = pack([inp], '* h w c') 293 | inp, s_ps = pack([inp], 'b * c') 294 | 295 | # We expect the condition to be space-wise, i.e. of shape (batch, h * w, feat) 296 | cond = repeat(cond, 'b hw c -> (b t) hw c', t=t if exists(t) else 1) if exists(cond) else None 297 | 298 | out = super().forward( 299 | inp, 300 | key=cond, 301 | mask = mask, 302 | ) 303 | 304 | out = unpack(out, s_ps, 'b * c')[0] 305 | out = unpack(out, t_ps, '* h w c')[0] 306 | 307 | return rearrange(out, f'b ... h w c -> {pattern}') 308 | 309 | class TemporalAttention(Attention): 310 | ''' 311 | Attention module that applies self-attention across the 312 | temporal dimension of the input tensor, expected to be 313 | a 5D tensor of shape (batch, feat, time, height, width). 
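    Note that this layout corresponds to `transpose=True`; with the default `transpose=False` the expected layout is (batch, time, height, width, feat).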
314 | ''' 315 | 316 | def __init__( 317 | self, 318 | n_head : int, 319 | d_head : int, 320 | d_inp : int | None = None, 321 | d_out : int | None = None, 322 | bias : bool = False, 323 | embed : bool = True, 324 | scale : float | None = None, 325 | causal : bool = False, 326 | dropout : float = 0.0, 327 | transpose : bool = False, 328 | **kwargs, 329 | ) -> None: 330 | super().__init__( 331 | n_head, 332 | d_head, 333 | d_inp, 334 | d_out, 335 | bias, 336 | scale, 337 | causal, 338 | dropout, 339 | **kwargs, 340 | ) 341 | 342 | # Use 1d-rotary embedding for temporal attention 343 | self.embed = RotaryEmbedding(self.d_inp, kind='1d') if embed else nn.Identity() 344 | 345 | self.transpose = transpose 346 | 347 | def forward( 348 | self, 349 | video : Tensor, 350 | cond : Tensor | None = None, 351 | mask : Tensor | None = None, 352 | transpose : bool | None = None, 353 | ) -> Tensor: 354 | transpose = default(transpose, self.transpose) 355 | 356 | pattern = 'b c t h w' if transpose else 'b t h w c' 357 | inp = rearrange(video, f'{pattern} -> b h w t c') 358 | b, h, w, *_ = inp.shape 359 | inp, ps = pack([inp], '* t c') 360 | 361 | # We expect the condition to be time-wise, i.e. of shape (batch, time, feat) 362 | cond = repeat(cond, 'b t c -> (b h w) t c', h=h, w=w) if exists(cond) else None 363 | 364 | out = super().forward( 365 | inp, 366 | key=cond, 367 | mask=mask, 368 | ) 369 | 370 | out = unpack(out, ps, '* t c')[0] 371 | return rearrange(out, f'b h w t c -> {pattern}') 372 | 373 | class SpaceTimeAttention(nn.Module): 374 | 375 | def __init__( 376 | self, 377 | n_head : int | Tuple[int, int], 378 | d_head : int | Tuple[int, int], 379 | d_inp : int | None = None, 380 | d_out : int | None = None, 381 | hid_dim : int | Tuple[int, int] | None = None, 382 | bias : bool = False, 383 | embed : bool | Tuple[bool, bool] = True, 384 | scale : float | None = None, 385 | dropout : float = 0.0, 386 | kernel_size : int = 3, 387 | transpose : bool = False, 388 | time_attn_kw : dict = {}, 389 | space_attn_kw : dict = {}, 390 | ) -> None: 391 | super().__init__() 392 | 393 | if isinstance(n_head, int): 394 | n_head = (n_head, n_head) 395 | if isinstance(d_head, int): 396 | d_head = (d_head, d_head) 397 | if isinstance(embed, bool): 398 | embed = (embed, embed) 399 | 400 | self.space_attn = SpatialAttention( 401 | n_head=n_head[0], 402 | d_head=d_head[0], 403 | d_inp=d_inp, 404 | d_out=None, 405 | bias=bias, 406 | scale=scale, 407 | embed=embed[0], 408 | causal=False, 409 | dropout=dropout, 410 | transpose=transpose, 411 | **space_attn_kw, 412 | ) 413 | 414 | self.temp_attn = TemporalAttention( 415 | n_head=n_head[1], 416 | d_head=d_head[1], 417 | d_inp=None, 418 | d_out=None, 419 | bias=bias, 420 | scale=scale, 421 | embed=embed[1], 422 | # * Causal attention for temporal attention 423 | causal=True, 424 | dropout=dropout, 425 | transpose=transpose, 426 | **time_attn_kw, 427 | ) 428 | 429 | self.ffn = ForwardBlock( 430 | n_head[1] * d_head[1], 431 | out_dim=d_out, 432 | hid_dim=hid_dim, 433 | num_groups=n_head[1], 434 | bias=bias, 435 | block=nn.Conv3d, 436 | kernel_size=kernel_size, 437 | padding=(kernel_size - 1) // 2, 438 | ) 439 | 440 | pattern = 'b c t h w' if transpose else 'b t h w c' 441 | self.ffn = nn.Sequential( 442 | Rearrange(f'{pattern} -> b c t h w'), 443 | self.ffn, 444 | Rearrange(f'b c t h w -> {pattern}'), 445 | ) 446 | 447 | self.in_channels = default(d_inp, n_head[0] * d_head[0]) 448 | self.out_channels = default(d_out, n_head[1] * d_head[1]) 449 | 450 | space_hid = d_head[0] * 
n_head[0] 451 | time_hid = d_head[1] * n_head[1] 452 | self.time_skip = nn.Identity() # Don't need this at the moment 453 | self.space_skip = nn.Conv3d(d_inp, space_hid, 1) if exists(d_inp) and d_inp != space_hid else nn.Identity() 454 | self.ffn_skip = nn.Conv3d(time_hid, d_out, 1) if exists(d_out) and time_hid != d_out else nn.Identity() 455 | 456 | def forward( 457 | self, 458 | video : Tensor, 459 | cond : Tuple[Tensor, Tensor] | Tensor | None = None, 460 | mask : Tensor | None = None, 461 | ) -> Tensor: 462 | if not isinstance(cond, tuple): 463 | cond = (cond, cond) 464 | 465 | space_cond, time_cond = cond 466 | 467 | # We feed the video first through the spatial attention 468 | # and then through the temporal attention mechanism. 469 | # NOTE: Positional embeddings are added within the attention 470 | video = self.space_attn(video, cond=space_cond, mask=mask) + self.space_skip(video) 471 | video = self.temp_attn (video, cond=time_cond , mask=mask) + self.time_skip (video) 472 | video = self.ffn(video) + self.ffn_skip(video) 473 | 474 | return video -------------------------------------------------------------------------------- /genie/module/data.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | import yaml 3 | import torch 4 | from cv2 import VideoCapture 5 | from cv2 import cvtColor 6 | from cv2 import COLOR_BGR2RGB 7 | from cv2 import CAP_PROP_POS_FRAMES 8 | from cv2 import CAP_PROP_FRAME_COUNT 9 | 10 | from os import listdir, path 11 | from abc import abstractmethod 12 | from random import randint 13 | 14 | from torch import Tensor 15 | from torch.utils.data import Dataset 16 | from torch.utils.data import DataLoader 17 | from torch.utils.data import IterableDataset 18 | 19 | from lightning import LightningDataModule 20 | 21 | from typing import Callable 22 | 23 | from genie.utils import default, exists 24 | from genie.utils import default_iterdata_worker_init 25 | 26 | class LightningDataset(LightningDataModule): 27 | ''' 28 | Abstract Lightning Data Module that represents a dataset we 29 | can train a Lightning module on. 30 | ''' 31 | 32 | @classmethod 33 | def from_config(cls, conf_path : str, *args, key : str = 'dataset') -> 'LightningDataset': 34 | ''' 35 | Construct a Lightning DataModule from a configuration file. 
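`from_config` reads a YAML file and forwards the entries found under `key` to the constructor, so any concrete subclass can be instantiated directly from a config. A hedged sketch (the subclass, file path, and `dataset` section contents below are hypothetical, mirroring the constructor arguments shown next):

```python
from genie.module.data import LightningDataset

class MyVideoData(LightningDataset):
    def setup(self, stage: str) -> None:
        # assign self.train_dataset / self.valid_dataset / self.test__dataset here
        ...

# Expects a YAML file containing a `dataset:` mapping, e.g.
#   dataset: {batch_size: 16, num_workers: 4}
data = MyVideoData.from_config('config/my_data.yaml', key='dataset')
```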
36 | ''' 37 | 38 | with open(conf_path, 'r') as f: 39 | conf = yaml.safe_load(f) 40 | 41 | data_conf = conf[key] 42 | 43 | return cls( 44 | *args, 45 | **data_conf, 46 | ) 47 | 48 | def __init__( 49 | self, 50 | *args, 51 | batch_size : int = 16, 52 | num_workers : int = 0, 53 | train_shuffle : bool | None = None, 54 | val_shuffle : bool | None = None, 55 | val_batch_size : None | int = None, 56 | worker_init_fn : None | Callable = None, 57 | collate_fn : None | Callable = None, 58 | train_sampler : None | Callable = None, 59 | val_sampler : None | Callable = None, 60 | test_sampler : None | Callable = None, 61 | ) -> None: 62 | super().__init__() 63 | 64 | self.train_dataset = None 65 | self.valid_dataset = None 66 | self.test__dataset = None 67 | 68 | val_batch_size = default(val_batch_size, batch_size) 69 | 70 | self.num_workers = num_workers 71 | self.batch_size = batch_size 72 | self.train_shuffle = train_shuffle 73 | self.val_shuffle = val_shuffle 74 | self.train_sampler = train_sampler 75 | self.valid_sampler = val_sampler 76 | self.test__sampler = test_sampler 77 | self.collate_fn = collate_fn 78 | self.worker_init_fn = worker_init_fn 79 | self.val_batch_size = val_batch_size 80 | 81 | @abstractmethod 82 | def setup(self, stage: str) -> None: 83 | msg = \ 84 | ''' 85 | This is an abstract datamodule class. You should use one of 86 | the concrete subclasses that represents an actual dataset. 87 | ''' 88 | 89 | raise NotImplementedError(msg) 90 | 91 | def train_dataloader(self) -> DataLoader: 92 | if isinstance(self.train_dataset, IterableDataset): 93 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 94 | else: 95 | worker_init_fn = self.worker_init_fn 96 | 97 | return DataLoader( 98 | self.train_dataset, # type: ignore 99 | sampler = self.train_sampler, # type: ignore 100 | batch_size = self.batch_size, 101 | shuffle = self.train_shuffle, 102 | collate_fn = self.collate_fn, 103 | num_workers = self.num_workers, 104 | worker_init_fn = worker_init_fn, 105 | ) 106 | 107 | def val_dataloader(self) -> DataLoader: 108 | if isinstance(self.train_dataset, IterableDataset): 109 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 110 | else: 111 | worker_init_fn = self.worker_init_fn 112 | 113 | return DataLoader( 114 | self.valid_dataset, # type: ignore 115 | sampler = self.valid_sampler, # type: ignore 116 | batch_size = self.val_batch_size, 117 | shuffle = self.val_shuffle, 118 | collate_fn = self.collate_fn, 119 | num_workers = self.num_workers, 120 | worker_init_fn = worker_init_fn, 121 | ) 122 | 123 | def test_dataloader(self) -> DataLoader: 124 | if isinstance(self.train_dataset, IterableDataset): 125 | worker_init_fn = default(self.worker_init_fn, default_iterdata_worker_init) 126 | else: 127 | worker_init_fn = self.worker_init_fn 128 | 129 | return DataLoader( 130 | self.test__dataset, # type: ignore 131 | sampler = self.test__sampler, # type: ignore 132 | batch_size = self.val_batch_size, 133 | shuffle = self.val_shuffle, 134 | collate_fn = self.collate_fn, 135 | num_workers = self.num_workers, 136 | worker_init_fn = worker_init_fn, 137 | ) 138 | 139 | class Platformer2D(Dataset): 140 | 141 | def __init__( 142 | self, 143 | root : str, 144 | split : str = 'train', 145 | env_name : str = 'Coinrun', 146 | padding : str = 'none', 147 | randomize : bool = False, 148 | transform : Callable | None = None, 149 | num_frames : int = 16, 150 | output_format: str = 't c h w', 151 | ) -> None: 152 | super().__init__() 153 | 154 | self.root = 
path.join(root, env_name, split) 155 | self.split = split 156 | self.padding = padding 157 | self.randomize = randomize 158 | self.num_frames = num_frames 159 | self.output_format = output_format 160 | self.transform = transform if exists(transform) else lambda x: x 161 | 162 | # Get all the file path based on the split 163 | self.file_names = [ 164 | path.join(self.root, f) 165 | for f in listdir(self.root) 166 | ] 167 | 168 | def __len__(self) -> int: 169 | return len(self.file_names) 170 | 171 | def __getitem__(self, idx: int) -> Tensor: 172 | video_path = self.file_names[idx] 173 | 174 | video = self.load_video_slice( 175 | video_path, 176 | self.num_frames, 177 | None if self.randomize else 0, 178 | ) 179 | 180 | return video 181 | 182 | def load_video_slice( 183 | self, 184 | video_path : str, 185 | num_frames : int, 186 | start_frame : int | None = None 187 | ) -> Tensor: 188 | cap = VideoCapture(video_path) 189 | total_frames = int(cap.get(CAP_PROP_FRAME_COUNT)) 190 | 191 | # If video is shorted than the requested number of frames 192 | # we just return the whole video 193 | num_frames = min(num_frames, total_frames) 194 | 195 | start_frame = start_frame if exists(start_frame) else randint(0, total_frames - num_frames) 196 | cap.set(CAP_PROP_POS_FRAMES, start_frame) 197 | 198 | frames = [] 199 | for _ in range(num_frames): 200 | ret, frame = cap.read() 201 | if ret: 202 | # *Frame was successfully read, parse it 203 | frame = cvtColor(frame, COLOR_BGR2RGB) 204 | frame = torch.from_numpy(frame) 205 | frames.append(frame) 206 | 207 | else: 208 | # * We reached the end of video 209 | # Deal with padding and return 210 | match self.padding: 211 | case 'none': pass 212 | case 'repeat': 213 | frames.extend([frames[-1]] * (num_frames - len(frames))) 214 | case 'zero': 215 | frames.extend([ 216 | torch.zeros_like(frames[-1]) 217 | ] * (num_frames - len(frames)) 218 | ) 219 | case 'random': 220 | frames.extend([ 221 | torch.rand_like(frames[-1]) 222 | ] * (num_frames - len(frames)) 223 | ) 224 | case _: 225 | raise ValueError(f'Invalid padding type: {self.padding}') 226 | break 227 | 228 | cap.release() 229 | video = torch.stack(frames) / 255. 230 | video = rearrange(video, f't h w c -> {self.output_format}') 231 | 232 | video = self.transform(video) 233 | 234 | return video -------------------------------------------------------------------------------- /genie/module/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch import Tensor 3 | from typing import Tuple 4 | from itertools import pairwise 5 | from einops.layers.torch import Rearrange 6 | 7 | from math import prod 8 | 9 | from genie.module.misc import ForwardBlock 10 | from genie.module.video import CausalConv3d 11 | from genie.module.image import ImageResidualBlock 12 | from genie.module.video import VideoResidualBlock 13 | 14 | from genie.module.attention import SpatialAttention 15 | from genie.utils import default 16 | 17 | class FrameDiscriminator(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | inp_size : int | Tuple[int, int], 22 | model_dim : int = 64, 23 | dim_mults : Tuple[int, ...] = (1, 2, 4), 24 | down_step : Tuple[int | None, ...] 
= (None, 2, 2), 25 | inp_channels : int = 3, 26 | kernel_size : int | Tuple[int, int] = 3, 27 | num_groups : int = 1, 28 | num_heads : int = 4, 29 | dim_head : int = 32, 30 | use_attn : bool = False, 31 | use_blur : bool = True, 32 | act_fn : str = 'leaky', 33 | ) -> None: 34 | super().__init__() 35 | 36 | if isinstance(inp_size, int): 37 | inp_size = (inp_size, inp_size) 38 | 39 | # Assemble model core based on dimension schematics 40 | dims = [model_dim * mult for mult in dim_mults] 41 | 42 | assert len(dims) == len(down_step), "Dimension and downsample steps must match." 43 | 44 | self.proj_in = nn.Conv2d( 45 | inp_channels, 46 | model_dim, 47 | kernel_size=3, 48 | padding=1, 49 | ) 50 | 51 | self.core = nn.ModuleList([]) 52 | 53 | for (inp_dim, out_dim), down in zip(pairwise(dims), down_step): 54 | res_block = ImageResidualBlock( 55 | inp_dim, 56 | out_dim, 57 | downsample=down, 58 | num_groups=num_groups, 59 | kernel_size=kernel_size, 60 | ) 61 | 62 | attn_block = nn.ModuleList([ 63 | SpatialAttention( 64 | n_head=num_heads, 65 | d_head=dim_head, 66 | d_inp=out_dim, 67 | ), 68 | ForwardBlock( 69 | in_dim=out_dim, 70 | hid_dim=4 * out_dim, 71 | block=nn.Conv2d, 72 | kernel_size=1, 73 | ) 74 | ]) if use_attn else nn.ModuleList([ 75 | nn.Identity(), 76 | nn.Identity(), 77 | ]) 78 | 79 | self.core.append(nn.ModuleList( 80 | [ 81 | res_block, 82 | attn_block 83 | ] 84 | )) 85 | 86 | inp_size = tuple(map(lambda x: x // (down or 1), inp_size)) 87 | 88 | # Compute latent dimension as the product of the last dimension and the frame size 89 | latent_dim = out_dim * prod(inp_size) 90 | 91 | self.to_logits = nn.Sequential( 92 | nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1), 93 | nn.LeakyReLU(), 94 | Rearrange('b ... -> b (...)'), 95 | nn.Linear(latent_dim, 1), 96 | Rearrange('b 1 -> b') 97 | ) 98 | 99 | def forward( 100 | self, 101 | image : Tensor, 102 | ) -> Tensor: 103 | 104 | out = self.proj_in(image) 105 | 106 | for res, (attn, ff) in self.core: 107 | # Apply residual block 108 | out = res(out) 109 | 110 | # Apply attention block 111 | out = attn(out) + out 112 | out = ff(out) + out 113 | 114 | return self.to_logits(out) 115 | 116 | class VideoDiscriminator(nn.Module): 117 | 118 | def __init__( 119 | self, 120 | inp_size : Tuple[int, int] | Tuple[int, int, int], 121 | model_dim : int = 64, 122 | dim_mults : Tuple[int, ...] = (1, 2, 4), 123 | down_step : Tuple[int | Tuple[int, int] | None, ...] = (None, 2, 2), 124 | inp_channels : int = 3, 125 | kernel_size : int | Tuple[int, int] = 3, 126 | num_groups : int = 1, 127 | num_heads : int = 4, 128 | dim_head : int = 32, 129 | act_fn : str = 'leaky', 130 | use_attn : bool = False, 131 | use_blur : bool = True, 132 | use_causal : bool = False, 133 | ) -> None: 134 | super().__init__() 135 | 136 | if len(inp_size) == 2: 137 | inp_size = (inp_size[0], inp_size[1], inp_size[1]) 138 | 139 | Conv3d = CausalConv3d if use_causal else nn.Conv3d 140 | 141 | # Assemble model core based on dimension schematics 142 | dims = [model_dim * mult for mult in dim_mults] 143 | 144 | assert len(dims) == len(down_step), "Dimension and downsample steps must match." 
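Both discriminators follow the same skeleton: an input projection, a stack of residual stages with optional attention, and a linear head producing one realness logit per example. A sketch for the frame variant, with assumed sizes and the default settings (attention disabled):

```python
import torch
from genie.module.discriminator import FrameDiscriminator

disc = FrameDiscriminator(inp_size=64)   # 64x64 RGB frames assumed

frames = torch.randn(8, 3, 64, 64)       # (batch, channels, height, width)
logits = disc(frames)                    # -> shape (8,), one score per frame
```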
145 | 146 | self.proj_in = Conv3d( 147 | inp_channels, 148 | model_dim, 149 | kernel_size=kernel_size, 150 | padding=1, 151 | ) 152 | 153 | self.core = nn.ModuleList([]) 154 | 155 | for (inp_dim, out_dim), down in zip(pairwise(dims), down_step): 156 | res_block = VideoResidualBlock( 157 | inp_dim, 158 | out_dim, 159 | downsample=down, 160 | num_groups=num_groups, 161 | kernel_size=kernel_size, 162 | act_fn=act_fn, 163 | use_blur=use_blur, 164 | use_causal=use_causal, 165 | ) 166 | 167 | attn_block = nn.ModuleList([ 168 | SpatialAttention( 169 | out_dim, 170 | n_head=num_heads, 171 | d_head=dim_head, 172 | ), 173 | ForwardBlock( 174 | in_dim=out_dim, 175 | hid_dim=4 * out_dim, 176 | block=Conv3d, 177 | kernel_size=1, 178 | ) 179 | ]) if use_attn else nn.ModuleList([ 180 | nn.Identity(), 181 | nn.Identity(), 182 | ]) 183 | 184 | self.core.append(nn.ModuleList( 185 | [ 186 | res_block, 187 | attn_block 188 | ] 189 | )) 190 | 191 | down = default(down, (1, 1, 1)) 192 | if isinstance(down, int): down = (down, down, down) 193 | if len(down) == 2: down = (down[0], down[1], down[1]) 194 | inp_size = tuple(map(lambda x, y: x // y, inp_size, down)) 195 | 196 | # Compute latent dimension as the product of the last dimension and the frame size 197 | latent_dim = out_dim * prod(inp_size) 198 | 199 | self.to_logits = nn.Sequential( 200 | nn.Conv3d(out_dim, out_dim, kernel_size=3, padding=1), 201 | nn.LeakyReLU(), 202 | Rearrange('b ... -> b (...)'), 203 | nn.Linear(latent_dim, 1), 204 | Rearrange('b 1 -> b') 205 | ) 206 | 207 | def forward( 208 | self, 209 | image : Tensor, 210 | ) -> Tensor: 211 | 212 | out = self.proj_in(image) 213 | 214 | for res, (attn, ff) in self.core: 215 | # Apply residual block 216 | out = res(out) 217 | 218 | # Apply attention block 219 | out = attn(out) + out 220 | out = ff(out) + out 221 | 222 | return self.to_logits(out) -------------------------------------------------------------------------------- /genie/module/image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from typing import Tuple 5 | from math import comb 6 | 7 | from torch.types import Device 8 | 9 | from torch.nn.functional import conv2d 10 | 11 | from einops import repeat 12 | from einops.layers.torch import Rearrange 13 | 14 | from genie.utils import exists 15 | from genie.utils import default 16 | 17 | def get_blur_kernel( 18 | kernel_size : int | Tuple[int, int], 19 | device : Device = None, 20 | dtype : torch.dtype | None = None, 21 | norm : bool = True 22 | ) -> Tensor: 23 | if isinstance(kernel_size, int): 24 | kernel_size = (kernel_size, kernel_size) 25 | 26 | # Construct the 1d pascal blur kernel 27 | ker_a_1d = torch.tensor( 28 | [comb(kernel_size[0] - 1, i) for i in range(kernel_size[0])], 29 | device=device, 30 | dtype=dtype, 31 | ).unsqueeze(-1) 32 | ker_b_1d = torch.tensor( 33 | [comb(kernel_size[1] - 1, i) for i in range(kernel_size[0])], 34 | device=device, 35 | dtype=dtype, 36 | ).unsqueeze(0) 37 | 38 | 39 | ker_2d = ker_a_1d @ ker_b_1d 40 | 41 | return ker_2d / ker_2d.sum() if norm else ker_2d 42 | 43 | # Inspired by the (very cool) kornia library, see the original implementation here: 44 | # https://github.com/kornia/kornia/blob/e461f92ff9ee035d2de2513859bee4069356bc25/kornia/filters/blur_pool.py#L21 45 | class BlurPooling2d(nn.Module): 46 | def __init__( 47 | self, 48 | kernel_size : int | Tuple[int, int], 49 | # Expected kwargs are the same as the one accepted by Conv2d 50 | 
stride : int | Tuple[int, int] = 2, 51 | num_groups : int = 1, 52 | **kwargs, 53 | ) -> None: 54 | super().__init__() 55 | 56 | # Register the blurring kernel buffer 57 | self.register_buffer('blur', get_blur_kernel(kernel_size)) 58 | 59 | self.stride = stride 60 | self.kwargs = kwargs 61 | self.num_groups = num_groups 62 | 63 | str_h, str_w = stride if isinstance(stride, tuple) else (stride, stride) 64 | ker_h, ker_w = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size) 65 | self.padding = (ker_h - 1) // str_h, (ker_w - 1) // str_w 66 | 67 | def forward( 68 | self, 69 | inp : Tensor, 70 | ) -> Tensor: 71 | b, c, h, w = inp.shape 72 | 73 | # Repeat spatial kernel for each channel of input image 74 | ker = repeat(self.blur, 'i j -> c g i j', c=c, g=c // self.num_groups) 75 | 76 | # Compute the blur as 2d convolution 77 | return conv2d( 78 | inp, ker, 79 | stride=self.stride, 80 | padding=self.padding, 81 | groups=self.num_groups, 82 | **self.kwargs 83 | ) 84 | 85 | class SpaceDownsample(nn.Module): 86 | def __init__( 87 | self, 88 | in_dim : int, 89 | factor : int = 2, 90 | ) -> None: 91 | super().__init__() 92 | 93 | self.go_up = nn.Sequential( 94 | Rearrange('b c (h p) (w q) -> b (c p q) h w', p=factor, q=factor), 95 | nn.Conv2d(in_dim * factor ** 2, in_dim, kernel_size=1), 96 | ) 97 | 98 | def forward( 99 | self, 100 | inp : Tensor, 101 | ) -> Tensor: 102 | return self.go_up(inp) 103 | 104 | class ImageResidualBlock(nn.Module): 105 | 106 | def __init__( 107 | self, 108 | inp_channel : int, 109 | out_channel : int | None = None, 110 | kernel_size : int | Tuple[int, int] = 3, 111 | padding : int | Tuple[int, int] = 1, 112 | num_groups : int = 1, 113 | downsample : int | None = None, 114 | ) -> None: 115 | super().__init__() 116 | 117 | self.res = nn.Conv2d( 118 | inp_channel, 119 | out_channel, 120 | kernel_size=1, 121 | stride=default(downsample, 1), 122 | ) if exists(out_channel) else nn.Identity() 123 | 124 | out_channel = default(out_channel, inp_channel) 125 | 126 | self.main = nn.Sequential( 127 | nn.GroupNorm(num_groups, inp_channel), 128 | nn.LeakyReLU(), 129 | nn.Conv2d( 130 | inp_channel, 131 | out_channel, 132 | kernel_size=kernel_size, 133 | padding=padding, 134 | ), 135 | nn.GroupNorm(num_groups, out_channel), 136 | nn.LeakyReLU(), 137 | nn.Conv2d( 138 | out_channel, 139 | out_channel, 140 | kernel_size=kernel_size, 141 | padding=padding, 142 | ), 143 | *( 144 | [SpaceDownsample(out_channel, downsample)] 145 | if exists(downsample) and downsample 146 | else [] 147 | ) 148 | ) 149 | 150 | def forward( 151 | self, 152 | inp : Tensor, 153 | ) -> Tensor: 154 | """ 155 | Forward pass of the residual block. 156 | 157 | Args: 158 | inp (Tensor): The input tensor. 159 | 160 | Returns: 161 | Tensor: The output tensor after applying the residual block operations. 
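A usage sketch for `ImageResidualBlock` above (channel widths and spatial size are illustrative assumptions): when `downsample` is set, the strided 1x1 shortcut and the trailing `SpaceDownsample` keep the two branches shape-compatible.

```python
import torch
from genie.module.image import ImageResidualBlock

block = ImageResidualBlock(64, 128, downsample=2)

img = torch.randn(1, 64, 32, 32)   # (batch, channels, height, width)
out = block(img)                   # -> (1, 128, 16, 16)
```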
162 | """ 163 | return self.main(inp) + self.res(inp) -------------------------------------------------------------------------------- /genie/module/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torchvision.models import get_model 5 | 6 | from torch.nn.functional import relu 7 | from torch.nn.functional import mse_loss 8 | from torch.nn.modules.loss import _Loss 9 | from typing import Iterable, Tuple 10 | 11 | from genie.module.misc import NamingProbe 12 | from genie.module.misc import RecordingProbe 13 | from genie.module.discriminator import FrameDiscriminator, VideoDiscriminator 14 | from genie.utils import pick_frames 15 | 16 | VGG16_RELU_LAYERS = [ 17 | 'features.1', 18 | 'features.3', 19 | 'features.6', 20 | 'features.8', 21 | 'features.11', 22 | 'features.13', 23 | 'features.15', 24 | 'features.18', 25 | 'features.20', 26 | 'features.22', 27 | 'features.25', 28 | 'features.27', 29 | 'features.29', 30 | 'classifier.1', 31 | 'classifier.4', 32 | ] 33 | 34 | class PerceptualLoss(_Loss): 35 | 36 | def __init__( 37 | self, 38 | model_name : str = 'vgg16', 39 | model_weights : str | None = 'DEFAULT', 40 | num_frames : int = 4, 41 | feat_layers : str | Iterable[str] = ('features.6', 'features.13', 'features.18', 'features.25'), 42 | ) -> None: 43 | super().__init__() 44 | 45 | self.num_frames = num_frames 46 | self.percept_model = get_model(model_name, weights=model_weights) 47 | 48 | # Freeze the perceptual model 49 | self.percept_model.eval() 50 | for param in self.percept_model.parameters(): 51 | param.requires_grad = False 52 | 53 | # Attach the naming probe to make sure every layer 54 | # in the percept model has a unique identifier 55 | self.namer = NamingProbe() 56 | handles = [ 57 | module.register_forward_hook(self.namer) 58 | for name, module in self.percept_model.named_modules() 59 | ] 60 | 61 | # Fake forward pass to the model to trigger the probe 62 | with torch.no_grad(): 63 | _ = self.percept_model(torch.randn(1, 3, 224, 224)) 64 | for handle in handles: handle.remove() 65 | 66 | # Attach hooks to the model at desired locations 67 | self.probe = RecordingProbe() 68 | self.hook_handles = [ 69 | module.register_forward_hook(self.probe) 70 | for name, module in self.percept_model.named_modules() 71 | if name in feat_layers 72 | ] 73 | 74 | assert len(self.hook_handles) > 0, 'No valid layers found in the perceptual model.' 
75 | 76 | def forward(self, rec_video : Tensor, inp_video : Tensor) -> Tensor: 77 | b, c, t, h, w = inp_video.shape 78 | 79 | # Extract a set of random frames from the input video 80 | 81 | frames_idxs = torch.cat([ 82 | torch.randperm(t, device=inp_video.device)[:self.num_frames] 83 | for _ in range(b)] 84 | ) 85 | 86 | fake_frames = pick_frames(rec_video, frames_idxs=frames_idxs) 87 | real_frames = pick_frames(inp_video, frames_idxs=frames_idxs) 88 | 89 | # Get the perceptual features for the input 90 | _ = self.percept_model(fake_frames) 91 | fake_feat = self.probe.features 92 | self.probe.clean() 93 | 94 | # Get the perceptual features for the target 95 | _ = self.percept_model(real_frames) 96 | real_feat = self.probe.features 97 | self.probe.clean() 98 | 99 | # Perceptual loss is the average MSE between the features 100 | return torch.stack([ 101 | mse_loss(fake_feat[k], real_feat[k]) 102 | for k in fake_feat.keys() 103 | ]).mean() 104 | 105 | def __del__(self) -> None: 106 | for handle in self.hook_handles: 107 | handle.remove() 108 | 109 | class GANLoss(_Loss): 110 | 111 | def __init__( 112 | self, 113 | discriminate : str = 'frames', 114 | num_frames : int = 4, 115 | **kwargs, 116 | ) -> None: 117 | super().__init__() 118 | 119 | assert discriminate in ('frames', 'video'), 'Invalid discriminator type. Must be either "frames" or "video".' 120 | 121 | self.disc = FrameDiscriminator(**kwargs) if discriminate == 'frames' else VideoDiscriminator(**kwargs) 122 | 123 | self.num_frames = num_frames 124 | self.discriminate = discriminate 125 | 126 | def get_examples( 127 | self, 128 | rec_video : Tensor, 129 | inp_video : Tensor, 130 | ) -> Tuple[Tensor, Tensor]: 131 | b, c, t, h, w = inp_video.shape 132 | 133 | if self.discriminate == 'video': 134 | return rec_video, inp_video 135 | 136 | # Extract a set of random frames from the input video 137 | frame_idxs = torch.cat([ 138 | torch.randperm(t, device=inp_video.device)[:self.num_frames] 139 | for _ in range(b)] 140 | ) 141 | fake = pick_frames(rec_video, frame_idxs) 142 | real = pick_frames(inp_video, frame_idxs) 143 | 144 | return fake, real 145 | 146 | def forward( 147 | self, 148 | rec_video : Tensor, 149 | inp_video : Tensor, 150 | train_gen : bool, 151 | ) -> Tensor: 152 | b, c, t, h, w = inp_video.shape 153 | 154 | # Extract a set of random frames from the input video 155 | fake, real = self.get_examples(rec_video, inp_video) 156 | 157 | # Compute discriminator opinions for real and fake frames 158 | fake_score : Tensor = self.disc(fake) if train_gen else self.disc(fake.detach()) 159 | real_score : Tensor = self.disc(real) if not train_gen else None 160 | 161 | # Compute hinge loss for the discriminator 162 | gan_loss = -fake_score.mean() if train_gen else (relu(1 + fake_score) + relu(1 - real_score)).mean() 163 | 164 | return gan_loss -------------------------------------------------------------------------------- /genie/module/misc.py: -------------------------------------------------------------------------------- 1 | from itertools import pairwise 2 | from uuid import uuid4 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from typing import Dict, List, Tuple 7 | 8 | from einops import rearrange 9 | from collections import defaultdict 10 | 11 | from genie.utils import default 12 | 13 | class NamingProbe: 14 | 15 | def __init__(self, name_attr : str = 'name') -> None: 16 | super().__init__() 17 | 18 | self.depth = -1 19 | self.name_attr = name_attr 20 | 21 | def __call__( 22 | self, 23 | module : 
nn.Module, 24 | inp : Tuple[Tensor, ...], 25 | out : Tensor 26 | ) -> None: 27 | '''Custom torch hook designed to record hidden activations. 28 | NOTE: This function should be called (implicitly) by the 29 | forward hook registered on the desired module. 30 | ''' 31 | 32 | self.depth += 1 33 | 34 | # Build unique module name as identifier 35 | name = f'{module._get_name().lower()}_{self.depth}_{uuid4().hex[:6]}' 36 | 37 | setattr(module, self.name_attr, name) 38 | 39 | class RecordingProbe: 40 | 41 | def __init__(self) -> None: 42 | self._data : Dict[str, List[Tensor]] = defaultdict(list) 43 | 44 | @property 45 | def features(self) -> Dict[str, Tensor]: 46 | return {k: torch.cat(v) for k, v in self._data.items()} 47 | 48 | def __call__( 49 | self, 50 | module : nn.Module, 51 | inp : Tuple[Tensor, ...], 52 | out : Tensor 53 | ) -> None: 54 | '''Custom torch hook designed to record hidden activations. 55 | NOTE: This function should be called (implicitly) by the 56 | forward hook registered on the desired module. 57 | ''' 58 | 59 | # Get the name of the module 60 | name = module.name if hasattr(module, 'name') else module._get_name().lower() 61 | 62 | feat = out.clone().detach() 63 | feat = rearrange(feat, 'b ... -> b (...)').contiguous() 64 | 65 | self._data[name].append(feat) 66 | 67 | def clean(self) -> None: 68 | '''Clear the recorded data.''' 69 | self._data.clear() 70 | 71 | class ForwardBlock(nn.Module): 72 | 73 | def __init__( 74 | self, 75 | in_dim : int, 76 | out_dim : int | None = None, 77 | hid_dim : int | Tuple[int, ...] | None = 256, 78 | block : nn.Module = nn.Linear, 79 | act_fn : nn.Module = nn.GELU, 80 | num_groups : int = 1, 81 | last_act : bool = False, 82 | **kwargs, 83 | ) -> None: 84 | super().__init__() 85 | 86 | out_dim = default(out_dim, in_dim) 87 | if isinstance(hid_dim, int): hid_dim = (hid_dim,) 88 | hid_dim = default(hid_dim, ()) 89 | 90 | dims = (in_dim,) + hid_dim + (out_dim,) 91 | 92 | self.net = nn.Sequential( 93 | nn.GroupNorm(num_groups, in_dim), 94 | *[nn.Sequential( 95 | block(inp_dim, out_dim, **kwargs), 96 | act_fn() if l < len(dims) - 2 or last_act else nn.Identity() 97 | ) for l, (inp_dim, out_dim) in enumerate(pairwise(dims))], 98 | ) 99 | 100 | def forward( 101 | self, 102 | inp : Tensor 103 | ) -> Tensor: 104 | return self.net(inp) -------------------------------------------------------------------------------- /genie/module/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch.nn.functional import group_norm 5 | 6 | from einops import pack, rearrange, unpack 7 | 8 | class AdaptiveGroupNorm(nn.Module): 9 | def __init__( 10 | self, 11 | dim_cond : int, 12 | num_groups: int, 13 | num_channels: int, 14 | cond_bias : bool = True, 15 | affine : bool = True, 16 | eps : float = 1e-5, 17 | device : str | None = None, 18 | dtype : str | None = None, 19 | ) -> None: 20 | super().__init__() 21 | 22 | if num_channels % num_groups != 0: 23 | raise ValueError('num_channels must be divisible by num_groups') 24 | 25 | self.num_groups = num_groups 26 | self.num_channels = num_channels 27 | self.eps = eps 28 | self.affine = affine 29 | 30 | factory_kwargs = {'device': device, 'dtype': dtype} 31 | if self.affine: 32 | self.weight = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 33 | self.bias = nn.Parameter(torch.empty(num_channels, **factory_kwargs)) 34 | else: 35 | self.register_parameter('weight', None) 36 | 
self.register_parameter('bias', None) 37 | 38 | self.std = nn.Linear(dim_cond, self.num_channels) 39 | self.avg = nn.Linear(dim_cond, self.num_channels) if cond_bias else None 40 | 41 | self.reset_parameters() 42 | 43 | def reset_parameters(self) -> None: 44 | if self.affine: 45 | nn.init.ones_(self.weight) 46 | nn.init.zeros_(self.bias) 47 | 48 | nn.init.ones_ (self.std.bias) 49 | nn.init.zeros_(self.std.weight) 50 | 51 | if self.avg is not None: 52 | nn.init.zeros_(self.avg.bias) 53 | nn.init.zeros_(self.avg.weight) 54 | 55 | def forward(self, inp : Tensor, cond : Tensor) -> Tensor: 56 | # Apply the standard group norm to the input. 57 | # Expected shape: [B, G, *] 58 | norm = group_norm(inp, self.num_groups, self.weight, self.bias, self.eps) 59 | norm, ps = pack([norm], 'b g *') 60 | 61 | # Condition is expected to have shape b d ... 62 | cond = rearrange(cond, 'b d ... -> b d (...)').mean(-1) 63 | 64 | # Rescale the normalized input to match the conditional statistics 65 | std = self.std(cond).unsqueeze(-1) 66 | avg = self.avg(cond).unsqueeze(-1) if self.avg is not None else 0 67 | 68 | out = norm * std + avg 69 | return unpack(out, ps, 'b g *')[0] -------------------------------------------------------------------------------- /genie/module/quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import log 4 | 5 | from torch import Tensor 6 | from einops import reduce 7 | from einops import einsum 8 | from einops import rearrange 9 | from einops import pack, unpack 10 | 11 | from torch.nn.functional import mse_loss 12 | 13 | from typing import Tuple 14 | 15 | from genie.utils import default 16 | 17 | def entropy(p : Tensor, eps : float = 1e-6) -> Tensor: 18 | '''Calculates the entropy of a probability distribution. 19 | 20 | Args: 21 | p (Tensor): The probability distribution. 22 | eps (float, optional): A small value to avoid taking the logarithm of zero. 23 | Defaults to 1e-6. 24 | 25 | Returns: 26 | Tensor: The entropy of the probability distribution. 27 | ''' 28 | return - (p * log(p.clamp(min=eps))).sum(dim=-1) 29 | 30 | # Simplified version of the lucidrains implementation at: 31 | # https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py#L49 32 | class LookupFreeQuantization(nn.Module): 33 | ''' 34 | Lookup-Free Quantization module as originally introduced 35 | in the paper "Language Model Beats Diffusion: Tokenizer 36 | is key to visual generation" Yu et al. (2024). 
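The quantizer's constructor and forward pass follow below. As a hedged usage sketch (the code width and feature size here are assumed, not repository defaults), the module snaps continuous features to their signs and packs the resulting bits into integer token indices:

```python
import torch
from genie.module.quantization import LookupFreeQuantization

lfq = LookupFreeQuantization(codebook_dim=8, input_dim=256)
lfq.eval()   # in eval mode no auxiliary loss is returned

feats = torch.randn(2, 64, 256)      # (batch, tokens, features)
(quant, idxs), loss = lfq(feats)     # quant: (2, 64, 256), idxs: (2, 64), loss: None
```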
37 | ''' 38 | 39 | def __init__( 40 | self, 41 | codebook_dim : int, 42 | num_codebook : int = 1, 43 | input_dim : int | None = None, 44 | use_bias : bool = True, 45 | frac_sample : float = 1., 46 | commit_weight : float = 0.25, 47 | entropy_weight : float = 0.1, 48 | diversity_weight : float = 1., 49 | ) -> None: 50 | super().__init__() 51 | 52 | codebook_size = (2 ** codebook_dim) * num_codebook 53 | input_dim = default(input_dim, codebook_size) 54 | 55 | project = input_dim != codebook_dim * num_codebook 56 | 57 | self.proj_inp = nn.Linear(input_dim, codebook_dim * num_codebook, bias=use_bias) if project else nn.Identity() 58 | self.proj_out = nn.Linear(codebook_dim * num_codebook, input_dim, bias=use_bias) if project else nn.Identity() 59 | 60 | self.frac_sample = frac_sample 61 | self.codebook_dim = codebook_dim 62 | self.num_codebooks = num_codebook 63 | self.codebook_size = codebook_size 64 | self.commit_weight = commit_weight 65 | self.entropy_weight = entropy_weight 66 | self.diversity_weight = diversity_weight 67 | 68 | # * Initialize the codebook 69 | # Use the bit_mask to generate the bit-codes for all the codebook entries 70 | # and then convert them to the actual codebook values {-1, 1}. Resulting 71 | # codebook will have shape (codebook_size, d_codebook). 72 | self.register_buffer('bit_mask', 2 ** torch.arange(codebook_dim - 1, -1, -1)) 73 | 74 | codes = torch.arange(codebook_size, dtype=int)[:, None] & self.bit_mask 75 | self.register_buffer('codebook', 2 * (codes != 0).float() - 1, persistent=False) 76 | 77 | def forward( 78 | self, 79 | inp : Tensor, 80 | beta : float = 100., 81 | transpose : bool = False 82 | ) -> Tuple[Tuple[Tensor, Tensor], Tensor | None]: 83 | 84 | # Standardize the input tensor to have shape (batch_size, seq_len, inp_dim) 85 | inp = rearrange(inp, 'b d ... -> b ... d') if transpose else inp 86 | inp, ps = pack([inp], 'b * d') 87 | 88 | inp = self.proj_inp(inp) 89 | 90 | # Split into n_codebook parts 91 | inp = rearrange(inp, 'b n (c d) -> b n c d', c=self.num_codebooks) 92 | 93 | # Quantize by simply assigning {-1, 1} to the input tensor depending on the sign 94 | # of the input tensor values. This is the lookup-free quantization step. 95 | # See Eq. (3) in the original paper. To obtain the quantized-code indices 96 | # we simply sum the bit-codes representation of the quantized values. 97 | quant = inp.sign() 98 | idxs = reduce((inp > 0).int() * self.bit_mask.int(), 'b n c d -> b n c', 'sum') 99 | 100 | # Use straight-through estimator to back-propagate through the quantization step 101 | code = (inp + (quant - inp).detach()) if self.training else quant 102 | code = rearrange(code, 'b n c d -> b n (c d)') 103 | 104 | # Reconstruct the input tensor from the quantized values 105 | out = self.proj_out(code) 106 | out = unpack(out, ps, 'b * d')[0] 107 | out = rearrange(out, 'b ... d -> b d ...') if transpose else out 108 | 109 | # NOTE: Squeeze to remove the n_codebook dimension 110 | idxs = unpack(idxs, ps, 'b * d')[0].squeeze() 111 | 112 | # No need to compute the loss if we are not training 113 | if not self.training: return (out, idxs), None 114 | 115 | # Compute the entropy loss 116 | inp_prob = 2 * einsum(inp, self.codebook, '... i d, j d -> ... i j') 117 | inp_prob = (inp_prob * beta).softmax(dim=-1) 118 | inp_prob = rearrange(inp_prob, 'b n ... -> (b n) ...') 119 | 120 | avg_prob = reduce(inp_prob, '... 
c d -> c d', 'mean') 121 | 122 | inp_ent = entropy(inp_prob).mean() 123 | avg_ent = entropy(avg_prob).mean() 124 | 125 | entropy_loss = inp_ent + self.diversity_weight * avg_ent 126 | 127 | # Compute commitment loss 128 | commit_loss = mse_loss(inp, quant.detach(), reduction = 'mean') 129 | 130 | # Compute the complete final loss 131 | loss = entropy_loss * self.entropy_weight + commit_loss * self.commit_weight 132 | 133 | return (out, idxs), loss -------------------------------------------------------------------------------- /genie/module/video.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from abc import ABC 4 | from torch import Tensor 5 | from torch.nn.functional import pad 6 | from torch.nn.functional import conv3d 7 | from einops.layers.torch import Rearrange 8 | 9 | from math import comb 10 | from torch.types import Device 11 | 12 | from typing import Tuple 13 | from functools import partial 14 | from einops import pack 15 | from einops import unpack 16 | from einops import repeat 17 | from einops import einsum 18 | from einops import rearrange 19 | 20 | from genie.utils import default, exists 21 | 22 | def get_blur_kernel( 23 | kernel_size : int | Tuple[int, int], 24 | device : Device = None, 25 | dtype : torch.dtype | None = None, 26 | norm : bool = True 27 | ) -> Tensor: 28 | if isinstance(kernel_size, int): 29 | kernel_size = (kernel_size, kernel_size) 30 | 31 | # Construct the 1d pascal blur kernel 32 | ker_t_1d = torch.tensor( 33 | [comb(kernel_size[0] - 1, i) for i in range(kernel_size[0])], 34 | device=device, 35 | dtype=dtype, 36 | ) 37 | ker_h_1d = rearrange( 38 | torch.tensor( 39 | [comb(kernel_size[0] - 1, i) for i in range(kernel_size[0])], 40 | device=device, 41 | dtype=dtype, 42 | ), 43 | 'h -> h 1' 44 | ) 45 | ker_w_1d = rearrange( 46 | torch.tensor( 47 | [comb(kernel_size[1] - 1, i) for i in range(kernel_size[0])], 48 | device=device, 49 | dtype=dtype, 50 | ), 51 | 'w -> 1 w' 52 | ) 53 | 54 | ker_3d = einsum(ker_t_1d, ker_h_1d @ ker_w_1d, 't, h w -> t h w') 55 | 56 | return ker_3d / ker_3d.sum() if norm else ker_3d 57 | 58 | class Upsample(nn.Module, ABC): 59 | def __init__( 60 | self, 61 | time_factor : int = 1, 62 | space_factor : int = 1, 63 | ) -> None: 64 | super().__init__() 65 | 66 | self.time_factor = time_factor 67 | self.space_factor = space_factor 68 | 69 | self.go_up = None 70 | 71 | @property 72 | def factor(self) -> int: 73 | return self.time_factor * (self.space_factor ** 2) 74 | 75 | def forward( 76 | self, 77 | inp : Tensor, 78 | **kwargs, 79 | ) -> Tensor: 80 | return self.go_up(inp) 81 | 82 | class Downsample(nn.Module, ABC): 83 | def __init__( 84 | self, 85 | time_factor : int = 1, 86 | space_factor : int = 1, 87 | ) -> None: 88 | super().__init__() 89 | 90 | self.time_factor = time_factor 91 | self.space_factor = space_factor 92 | 93 | self.go_down = None 94 | 95 | @property 96 | def factor(self) -> int: 97 | return self.time_factor * (self.space_factor ** 2) 98 | 99 | def forward( 100 | self, 101 | inp : Tensor, 102 | **kwargs, 103 | ) -> Tensor: 104 | return self.go_down(inp) 105 | 106 | class CausalConv3d(nn.Module): 107 | """ 108 | 3D Causal Convolutional Layer. 109 | 110 | Args: 111 | in_channels (int): Number of input channels. 112 | out_channels (int): Number of output channels. 113 | kernel_size (int or Tuple[int, int, int]): Size of the convolutional kernel. 114 | stride (int or Tuple[int, int, int], optional): Stride of the convolution. 
Defaults to (1, 1, 1). 115 | dilation (int or Tuple[int, int, int], optional): Dilation rate of the convolution. Defaults to (1, 1, 1). 116 | pad_mode (str, optional): Padding mode. Defaults to 'constant'. 117 | **kwargs: Additional keyword arguments to be passed to the nn.Conv3d constructor. 118 | 119 | Attributes: 120 | causal_pad (partial): Partial function for applying causal padding. 121 | conv3d (nn.Conv3d): 3D convolutional layer. 122 | 123 | """ 124 | 125 | def __init__( 126 | self, 127 | in_channels: int, 128 | out_channels: int, 129 | kernel_size: int | Tuple[int, int, int], 130 | stride: int | Tuple[int, int, int] = (1, 1, 1), 131 | dilation: int | Tuple[int, int, int] = (1, 1, 1), 132 | padding : int | Tuple[int, int] | None = None, 133 | pad_mode: str = 'constant', 134 | **kwargs 135 | ): 136 | super().__init__() 137 | 138 | if isinstance(stride, int): 139 | stride = (stride, stride, stride) 140 | if isinstance(dilation, int): 141 | dilation = (dilation, dilation, dilation) 142 | if isinstance(kernel_size, int): 143 | kernel_size = (kernel_size, kernel_size, kernel_size) 144 | if isinstance(padding, int | None): 145 | padding = (padding, padding) 146 | 147 | t_stride, *s_stride = stride 148 | t_dilation, *s_dilation = dilation 149 | 150 | # Compute the appropriate causal padding 151 | if isinstance(padding, int | None): 152 | padding = (padding, padding) 153 | 154 | time_ker, height_ker, width_ker = kernel_size 155 | time_pad = (time_ker - 1) * t_dilation + (1 - t_stride) 156 | height_pad = default(padding[0], (height_ker - 1) // 2) 157 | width_pad = default(padding[1], (width_ker - 1) // 2) 158 | 159 | # Causal padding pads time only to the left to ensure causality 160 | self.causal_pad = partial( 161 | pad, 162 | pad=(width_pad, width_pad, height_pad, height_pad, time_pad, 0), 163 | mode=pad_mode 164 | ) 165 | 166 | self.conv3d = nn.Conv3d( 167 | in_channels, 168 | out_channels, 169 | kernel_size, 170 | stride=(t_stride, *s_stride), 171 | dilation=(t_dilation, *s_dilation), 172 | **kwargs 173 | ) 174 | 175 | self.in_channels = in_channels 176 | self.out_channels = out_channels 177 | 178 | def forward(self, inp: Tensor) -> Tensor: 179 | """ 180 | Forward pass of the CausalConv3d layer. 181 | 182 | Args: 183 | inp (Tensor): Input tensor. 184 | 185 | Returns: 186 | Tensor: Output tensor after applying the CausalConv3d layer. 187 | 188 | """ 189 | # Insert causal padding 190 | inp = self.causal_pad(inp) 191 | 192 | return self.conv3d(inp) 193 | 194 | @property 195 | def inp_dim(self) -> int: 196 | return self.in_channels 197 | 198 | @property 199 | def out_dim(self) -> int: 200 | return self.out_channels 201 | 202 | class CausalConvTranspose3d(nn.ConvTranspose3d): 203 | """ 204 | 3D Causal Convolutional Transpose layer. 205 | 206 | Args: 207 | in_channels (int): Number of input channels. 208 | out_channels (int): Number of output channels. 209 | kernel_size (int or Tuple[int, int, int]): Size of the convolutional kernel. 210 | stride (int or Tuple[int, int, int], optional): Stride of the convolution. Default is (1, 1, 1). 211 | dilation (int or Tuple[int, int, int], optional): Dilation rate of the convolution. Default is (1, 1, 1). 212 | **kwargs: Additional keyword arguments to be passed to the parent class. 213 | 214 | Attributes: 215 | Same as the parent class `nn.ConvTranspose3d`. 
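A quick shape check for the `CausalConv3d` layer defined above (sizes are illustrative): time is padded only on the left, so the temporal length is preserved and no output frame depends on future frames.

```python
import torch
from genie.module.video import CausalConv3d

conv = CausalConv3d(in_channels=3, out_channels=64, kernel_size=3)

video = torch.randn(1, 3, 8, 32, 32)   # (batch, channels, time, height, width)
out = conv(video)                      # -> (1, 64, 8, 32, 32), causal along time
```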
216 | 217 | """ 218 | 219 | def __init__( 220 | self, 221 | in_channels: int, 222 | out_channels: int, 223 | kernel_size: int | Tuple[int, int, int], 224 | stride : int | Tuple[int, int, int] = (1, 1, 1), 225 | dilation: int | Tuple[int, int, int] = (1, 1, 1), 226 | space_pad : int | Tuple[int, int] | None = None, 227 | **kwargs, 228 | ) -> None: 229 | if isinstance(stride, int): 230 | stride = (stride, stride, stride) 231 | if isinstance(dilation, int): 232 | dilation = (dilation, dilation, dilation) 233 | if isinstance(kernel_size, int): 234 | kernel_size = (kernel_size, kernel_size, kernel_size) 235 | if isinstance(space_pad, int | None): 236 | space_pad = (space_pad, space_pad) 237 | _, height_ker, width_ker = kernel_size 238 | 239 | height_pad = default(space_pad[0], height_ker // 2) 240 | width_pad = default(space_pad[1], width_ker // 2) 241 | 242 | super(CausalConvTranspose3d, self).__init__( 243 | in_channels, 244 | out_channels, 245 | kernel_size, 246 | stride=stride, 247 | dilation=dilation, 248 | padding=(0, height_pad, width_pad), 249 | **kwargs, 250 | ) 251 | 252 | self.in_channels = in_channels 253 | self.out_channels = out_channels 254 | 255 | def forward(self, inp: Tensor) -> Tensor: 256 | """ 257 | Forward pass of the CausalConvTranspose3d layer. 258 | 259 | Args: 260 | inp (Tensor): Input tensor of shape (batch_size, in_channels, t, h, w). 261 | 262 | Returns: 263 | Tensor: Output tensor of shape (batch_size, out_channels, t', h', w'). 264 | 265 | """ 266 | *_, t, h, w = inp.shape 267 | T, H, W = self.stride 268 | 269 | return super().forward(inp)[..., :t * T, :h * H, :w * W] 270 | 271 | @property 272 | def inp_dim(self) -> int: 273 | return self.in_channels 274 | 275 | @property 276 | def out_dim(self) -> int: 277 | return self.out_channels 278 | 279 | class DepthToSpaceUpsample(Upsample): 280 | '''Depth to Space Upsampling module. 281 | ''' 282 | 283 | def __init__( 284 | self, 285 | in_channels : int, 286 | out_channels : int | None = None, 287 | factor : int = 2, 288 | ) -> None: 289 | super().__init__( 290 | space_factor=factor, 291 | ) 292 | 293 | out_channels = default(out_channels, in_channels) 294 | 295 | self.go_up = nn.Sequential( 296 | nn.Conv2d(in_channels, out_channels * factor ** 2, kernel_size=1), 297 | Rearrange('b (c p q) h w -> b c (h p) (w q)', p=factor, q=factor), 298 | ) 299 | 300 | self.in_channels = in_channels 301 | self.out_channels = out_channels 302 | 303 | def forward( 304 | self, 305 | inp : Tensor, 306 | **kwargs, 307 | ) -> Tensor: 308 | # Input is expected to be a video, rearrange it to have 309 | # shape suitable for a Conv2d layer to operate on 310 | inp = rearrange(inp, 'b c t h w -> b t c h w') 311 | inp, ps = pack([inp], '* c h w') 312 | 313 | out = self.go_up(inp) 314 | 315 | # Restore video format 316 | out, *_ = unpack(out, ps, '* c h w') 317 | out = rearrange(out, 'b t c h w -> b c t h w') 318 | 319 | return out 320 | 321 | @property 322 | def inp_dim(self) -> int: 323 | return self.in_channels 324 | 325 | @property 326 | def out_dim(self) -> int: 327 | return self.out_channels 328 | 329 | class DepthToTimeUpsample(Upsample): 330 | '''Depth to Time Upsampling module. 
331 | ''' 332 | 333 | def __init__( 334 | self, 335 | in_channels : int, 336 | out_channels : int | None = None, 337 | factor : int = 2, 338 | ) -> None: 339 | super().__init__( 340 | time_factor=factor, 341 | ) 342 | 343 | out_channels = default(out_channels, in_channels) 344 | 345 | self.go_up = nn.Sequential( 346 | nn.Conv1d(in_channels, out_channels * factor, kernel_size=1), 347 | Rearrange('b (c f) t -> b c (t f)', f=factor), 348 | ) 349 | 350 | self.in_channels = in_channels 351 | self.out_channels = out_channels 352 | 353 | def forward( 354 | self, 355 | inp : Tensor, 356 | **kwargs, 357 | ) -> Tensor: 358 | # Input is expected to be a video, rearrange it to have 359 | # shape suitable for a Conv2d layer to operate on 360 | inp = rearrange(inp, 'b c t h w -> b h w c t') 361 | inp, ps = pack([inp], '* c t') 362 | 363 | out = self.go_up(inp) 364 | 365 | # Restore video format 366 | out, *_ = unpack(out, ps, '* c t') 367 | out = rearrange(out, 'b h w c t -> b c t h w') 368 | 369 | return out 370 | 371 | @property 372 | def inp_dim(self) -> int: 373 | return self.in_channels 374 | 375 | @property 376 | def out_dim(self) -> int: 377 | return self.out_channels 378 | 379 | class DepthToSpaceTimeUpsample(Upsample): 380 | '''Depth to Space-Time Upsample 381 | ''' 382 | def __init__( 383 | self, 384 | in_channels : int, 385 | out_channels : int | None = None, 386 | time_factor : int = 2, 387 | space_factor : int = 2, 388 | kernel_size : int | Tuple[int, int, int] = 1, 389 | ) -> None: 390 | super().__init__( 391 | time_factor=time_factor, 392 | space_factor=space_factor, 393 | ) 394 | 395 | out_channels = default(out_channels, in_channels) 396 | 397 | self.go_up = nn.Sequential( 398 | CausalConv3d( 399 | in_channels, 400 | out_channels * time_factor * space_factor ** 2, 401 | kernel_size=kernel_size, 402 | ), 403 | Rearrange( 404 | 'b (c p q r) t h w -> b c (t p) (h q) (w r)', 405 | p=time_factor, 406 | q=space_factor, 407 | r=space_factor 408 | ), 409 | ) 410 | 411 | self.in_channels = in_channels 412 | self.out_channels = out_channels 413 | 414 | def forward( 415 | self, 416 | inp : Tensor, 417 | **kwargs, 418 | ) -> Tensor: 419 | # Input is expected to be a video 420 | out = self.go_up(inp) 421 | 422 | return out 423 | 424 | @property 425 | def inp_dim(self) -> int: 426 | return self.in_channels 427 | 428 | @property 429 | def out_dim(self) -> int: 430 | return self.out_channels 431 | 432 | class SpaceTimeUpsample(Upsample): 433 | '''Space-Time Upsample module. 434 | ''' 435 | 436 | def __init__( 437 | self, 438 | in_dim : int, 439 | out_dim : int, 440 | time_factor : int = 2, 441 | space_factor : int = 2, 442 | **kwargs 443 | ) -> None: 444 | super().__init__( 445 | time_factor=time_factor, 446 | space_factor=space_factor, 447 | ) 448 | 449 | self.go_up = nn.ConvTranspose3d( 450 | in_dim, 451 | out_dim, 452 | kernel_size=(time_factor, space_factor, space_factor), 453 | stride=(time_factor, space_factor, space_factor), 454 | **kwargs, 455 | ) 456 | 457 | class SpaceTimeDownsample(Downsample): 458 | '''Space-Time Downsample module. 
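For the depth-to-space-time upsampler defined above, a shape sketch with assumed sizes: channels are folded into the time and space axes, doubling each by default.

```python
import torch
from genie.module.video import DepthToSpaceTimeUpsample

up = DepthToSpaceTimeUpsample(in_channels=64)   # time_factor=2, space_factor=2

video = torch.randn(1, 64, 4, 8, 8)   # (batch, channels, time, height, width)
out = up(video)                       # -> (1, 64, 8, 16, 16)
```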
459 | ''' 460 | 461 | def __init__( 462 | self, 463 | in_channels : int, 464 | kernel_size : int | Tuple[int, int, int], 465 | out_channels : int | None = None, 466 | time_factor : int = 2, 467 | space_factor : int = 2, 468 | **kwargs 469 | ) -> None: 470 | super().__init__( 471 | time_factor=1 / time_factor, 472 | space_factor=1 / space_factor, 473 | ) 474 | if isinstance(kernel_size, int): 475 | kernel_size = (kernel_size, kernel_size, kernel_size) 476 | 477 | self.go_down = CausalConv3d( 478 | in_channels, 479 | default(out_channels, in_channels), 480 | kernel_size = kernel_size, 481 | stride = (time_factor, space_factor, space_factor), 482 | **kwargs, 483 | ) 484 | 485 | # Inspired by the (very cool) kornia library, see the original implementation here: 486 | # https://github.com/kornia/kornia/blob/e461f92ff9ee035d2de2513859bee4069356bc25/kornia/filters/blur_pool.py#L21 487 | class BlurPooling3d(nn.Module): 488 | def __init__( 489 | self, 490 | in_channels : int, # Needed only for compatibility 491 | kernel_size : int | Tuple[int, int, int], 492 | out_channels : int | None = None, 493 | time_factor : int = 2, 494 | space_factor : int | Tuple[int, int] = 2, 495 | num_groups : int = 1, 496 | **kwargs, 497 | ) -> None: 498 | super().__init__() 499 | 500 | if isinstance(kernel_size, int): 501 | kernel_size = (kernel_size, kernel_size, kernel_size) 502 | if isinstance(space_factor, int): 503 | space_factor = (space_factor, space_factor) 504 | 505 | # Register the blurring kernel buffer 506 | self.register_buffer('blur', get_blur_kernel(kernel_size)) 507 | 508 | self.stride = (time_factor, *space_factor) 509 | self.kwargs = kwargs 510 | self.num_groups = num_groups 511 | self.out_channels = out_channels 512 | 513 | ker_t, ker_h, ker_w = kernel_size 514 | self.padding = (ker_t - 1) // 2, (ker_h - 1) // 2, (ker_w - 1) // 2 515 | 516 | def forward( 517 | self, 518 | inp : Tensor, 519 | ) -> Tensor: 520 | b, c, t, h, w = inp.shape 521 | 522 | o = default(self.out_channels, c) 523 | 524 | # Repeat spatial kernel for each channel of input image 525 | ker = repeat(self.blur, 'i j k -> o g i j k', o=o, g=c // self.num_groups) 526 | 527 | # Compute the blur as 2d convolution 528 | return conv3d( 529 | inp, ker, 530 | stride=self.stride, 531 | padding=self.padding, 532 | groups=self.num_groups, 533 | **self.kwargs 534 | ) 535 | 536 | def __repr__(self): 537 | return f'BlurPooling3d({self.out_channels}, kernel_size={tuple(self.blur.shape)}, stride={self.stride}, padding={self.padding})' 538 | 539 | class VideoResidualBlock(nn.Module): 540 | """ 541 | A residual block module that performs residual connections and applies 542 | convolutional operations, with flexible options for normalization and 543 | down-sampling of input. 544 | 545 | Args: 546 | inp_channel (int): The number of input channels. 547 | out_channel (int | None, optional): The number of output channels. If None, it defaults to inp_channel. 548 | kernel_size (int | Tuple[int, int, int], optional): The size of the convolutional kernel. Defaults to 3. 549 | num_groups (int, optional): The number of groups to separate the channels into for group normalization. Defaults to 32. 550 | pad_mode (str, optional): The padding mode for convolution. Defaults to 'constant'. 551 | downsample (int | Tuple[int, int] | None, optional): The factor by which to downsample the input. Defaults to None. 552 | causal (bool, optional): Whether to use a causal convolution. Defaults to False. 553 | use_norm (bool, optional): Whether to use normalization. 
Defaults to True. 554 | use_blur (bool, optional): Whether to use blur pooling. Defaults to True. 555 | act_fn (str, optional): The activation function to use. Defaults to 'swish'. 556 | """ 557 | 558 | def __init__( 559 | self, 560 | in_channels : int, 561 | out_channels : int | None = None, 562 | kernel_size : int | Tuple[int, int, int] = 3, 563 | num_groups : int = 1, 564 | pad_mode : str = 'constant', 565 | downsample : int | Tuple[int, int] | None = None, 566 | use_causal : bool = False, 567 | use_norm : bool = True, 568 | use_blur : bool = True, 569 | act_fn : str = 'swish', 570 | ) -> None: 571 | super().__init__() 572 | 573 | if isinstance(downsample, int): 574 | downsample = (downsample, downsample) 575 | if isinstance(kernel_size, int): 576 | kernel_size = (kernel_size, kernel_size, kernel_size) 577 | 578 | Norm = nn.GroupNorm if use_norm else nn.Identity 579 | Down = BlurPooling3d if use_blur else SpaceTimeDownsample 580 | Conv = partial(CausalConv3d, pad_mode=pad_mode) if use_causal else nn.Conv3d 581 | 582 | match act_fn: 583 | case 'relu': Act = nn.ReLU 584 | case 'gelu': Act = nn.GELU 585 | case 'leaky': Act = nn.LeakyReLU 586 | case 'swish' | 'silu': Act = nn.SiLU 587 | 588 | out_channels = default(out_channels, in_channels) 589 | time_factor, space_factor = downsample if exists(downsample) else (None, None) 590 | 591 | self.res = nn.Sequential( 592 | Down( 593 | in_channels, 594 | kernel_size, 595 | time_factor=time_factor, 596 | space_factor=space_factor, 597 | num_groups=num_groups, 598 | ) if exists(downsample) else nn.Identity(), 599 | Conv( 600 | in_channels, 601 | kernel_size=1, 602 | out_channels=out_channels, 603 | ) if exists(out_channels) else nn.Identity() 604 | ) 605 | 606 | self.main = nn.Sequential( 607 | Norm(num_groups, in_channels), 608 | Act(), 609 | Conv( 610 | in_channels, 611 | out_channels=out_channels, 612 | kernel_size=kernel_size, 613 | padding=tuple(map(lambda k : (k - 1) // 2, kernel_size)), 614 | ), 615 | Down( 616 | out_channels, 617 | kernel_size, 618 | time_factor=time_factor, 619 | space_factor=space_factor, 620 | num_groups=num_groups, 621 | ) if exists(downsample) else nn.Identity(), 622 | Norm(num_groups, out_channels), 623 | Act(), 624 | Conv( 625 | out_channels, 626 | out_channels, 627 | kernel_size=kernel_size, 628 | padding=tuple(map(lambda k : (k - 1) // 2, kernel_size)), 629 | ), 630 | ) 631 | 632 | self.inp_channels = in_channels 633 | self.out_channels = out_channels 634 | 635 | def forward( 636 | self, 637 | inp : Tensor 638 | ) -> Tensor: 639 | """ 640 | Forward pass of the residual block. 641 | 642 | Args: 643 | inp (Tensor): The input tensor. 644 | 645 | Returns: 646 | Tensor: The output tensor after applying the residual block operations. 
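A usage sketch with assumed sizes for the block above: with no down-sampling and `use_causal=False`, it simply widens the channel dimension while preserving the spatio-temporal shape.

```python
import torch
from genie.module.video import VideoResidualBlock

block = VideoResidualBlock(in_channels=64, out_channels=128)

video = torch.randn(1, 64, 4, 16, 16)   # (batch, channels, time, height, width)
out = block(video)                      # -> (1, 128, 4, 16, 16)
```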
647 | """ 648 | return self.main(inp) + self.res(inp) 649 | 650 | @property 651 | def inp_dim(self) -> int: 652 | return self.inp_channels 653 | 654 | @property 655 | def out_dim(self) -> int: 656 | return self.out_channels -------------------------------------------------------------------------------- /genie/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from torch.optim import AdamW 5 | from torch.optim import Optimizer 6 | from torchvision.models import get_model 7 | from torch.nn.functional import mse_loss 8 | 9 | from typing import Any, Tuple 10 | from typing import Dict, Callable, Iterable 11 | from itertools import zip_longest 12 | 13 | from lightning import LightningModule 14 | 15 | from genie.module.loss import GANLoss 16 | from genie.module.loss import PerceptualLoss 17 | from genie.module.quantization import LookupFreeQuantization 18 | from genie.utils import Blueprint, default, exists 19 | 20 | from genie.module import parse_blueprint 21 | 22 | OptimizerCallable = Callable[[Iterable], Optimizer] 23 | 24 | MAGVIT2_ENC_DESC = ( 25 | ('causal-conv3d', { 26 | 'in_channels': 3, 27 | 'out_channels': 128, 28 | 'kernel_size': 3, 29 | }), 30 | ('video-residual', { 31 | 'n_rep': 4, 32 | 'in_channels': 128, 33 | }), 34 | ('spacetime_downsample', { 35 | 'in_channels': 128, 36 | 'out_channels': 128, 37 | 'kernel_size': 3, 38 | 'time_factor': 1, 39 | 'space_factor': 2, 40 | }), 41 | ('video-residual', { 42 | 'in_channels': 128, 43 | 'out_channels': 256, 44 | }), 45 | ('video-residual', { 46 | 'n_rep': 3, 47 | 'in_channels': 256, 48 | }), 49 | ('spacetime_downsample', { 50 | 'in_channels': 256, 51 | 'out_channels': 256, 52 | 'kernel_size': 3, 53 | 'time_factor': 2, 54 | 'space_factor': 2, 55 | }), 56 | ('video-residual', { 57 | 'n_rep': 4, 58 | 'in_channels': 256, 59 | }), 60 | ('spacetime_downsample', { 61 | 'in_channels': 256, 62 | 'out_channels': 256, 63 | 'kernel_size': 3, 64 | 'time_factor': 2, 65 | 'space_factor': 2, 66 | }), 67 | ('video-residual', { 68 | 'in_channels': 256, 69 | 'out_channels': 512, 70 | }), 71 | ('video-residual', { 72 | 'n_rep': 7, 73 | 'in_channels': 512, 74 | }), 75 | ('group_norm', { 76 | 'num_groups': 8, 77 | 'num_channels': 512, 78 | }), 79 | ('silu', {}), 80 | ('causal-conv3d', { 81 | 'in_channels': 512, 82 | 'out_channels': 18, 83 | 'kernel_size': 1, 84 | }) 85 | ) 86 | 87 | MAGVIT2_DEC_DESC = ( 88 | ('causal-conv3d', { 89 | 'in_channels': 18, 90 | 'out_channels': 512, 91 | 'kernel_size': 3, 92 | }), 93 | ('video-residual', { 94 | 'n_rep': 4, 95 | 'in_channels': 512, 96 | }), 97 | ('adaptive_group_norm', { 98 | 'dim_cond' : 18, 99 | 'num_groups': 8, 100 | 'num_channels': 512, 101 | 'has_ext' : True, 102 | }), 103 | ('video-residual', { 104 | 'n_rep': 4, 105 | 'in_channels': 512, 106 | }), 107 | ('depth2spacetime_upsample', { 108 | 'in_channels': 512, 109 | 'kernel_size': 3, 110 | 'time_factor': 2, 111 | 'space_factor': 2, 112 | }), 113 | ('adaptive_group_norm', { 114 | 'dim_cond' : 18, 115 | 'num_groups': 8, 116 | 'num_channels': 512, 117 | 'has_ext' : True, 118 | }), 119 | ('video-residual', { 120 | 'in_channels': 512, 121 | 'out_channels': 256, 122 | }), 123 | ('video-residual', { 124 | 'n_rep': 3, 125 | 'in_channels': 256, 126 | }), 127 | ('depth2spacetime_upsample', { 128 | 'in_channels': 256, 129 | 'kernel_size': 3, 130 | 'time_factor': 2, 131 | 'space_factor': 2, 132 | }), 133 | ('adaptive_group_norm', { 134 | 'dim_cond' : 18, 135 
| 'num_groups': 8, 136 | 'num_channels': 256, 137 | 'has_ext' : True, 138 | }), 139 | ('video-residual', { 140 | 'n_rep' : 4, 141 | 'in_channels': 256, 142 | }), 143 | ('depth2spacetime_upsample', { 144 | 'in_channels': 256, 145 | 'kernel_size': 3, 146 | 'time_factor': 1, 147 | 'space_factor': 2, 148 | }), 149 | ('adaptive_group_norm', { 150 | 'dim_cond' : 18, 151 | 'num_groups': 8, 152 | 'num_channels': 256, 153 | 'has_ext' : True, 154 | }), 155 | ('video-residual', { 156 | 'in_channels': 256, 157 | 'out_channels': 128, 158 | }), 159 | ('video-residual', { 160 | 'n_rep' : 3, 161 | 'in_channels': 128, 162 | }), 163 | ('group_norm', { 164 | 'num_groups': 8, 165 | 'num_channels': 128, 166 | }), 167 | ('silu', {}), 168 | ('causal-conv3d', { 169 | 'in_channels': 128, 170 | 'out_channels': 3, 171 | 'kernel_size': 3, 172 | }) 173 | ) 174 | 175 | REPR_TOK_ENC = ( 176 | ('spacetime_downsample', { 177 | 'in_channels' : 3, 178 | 'kernel_size' : 3, 179 | 'out_channels' : 512, 180 | 'time_factor' : 1, 181 | 'space_factor' : 4, 182 | }), 183 | ('space-time_attn', { 184 | 'n_rep' : 8, 185 | 'n_head': 8, 186 | 'd_head': 64, 187 | 'transpose' : True, 188 | }), 189 | ) 190 | 191 | REPR_TOK_DEC = ( 192 | ('space-time_attn', { 193 | 'n_rep' : 8, 194 | 'n_head': 8, 195 | 'd_head': 64, 196 | 'transpose' : True, 197 | }), 198 | ('depth2spacetime_upsample', { 199 | 'in_channels' : 512, 200 | 'kernel_size' : 3, 201 | 'out_channels' : 3, 202 | 'time_factor' : 1, 203 | 'space_factor' : 4, 204 | }) 205 | ) 206 | 207 | def get_enc(name : str) -> Blueprint: 208 | match name: 209 | case 'magvit2': 210 | return MAGVIT2_ENC_DESC 211 | case 'repr_tok': 212 | return REPR_TOK_ENC 213 | case _: 214 | raise ValueError(f'Unknown encoder: {name}') 215 | 216 | def get_dec(name : str) -> Blueprint: 217 | match name: 218 | case 'magvit2': 219 | return MAGVIT2_DEC_DESC 220 | case 'repr_tok': 221 | return REPR_TOK_DEC 222 | case _: 223 | raise ValueError(f'Unknown decoder: {name}') 224 | 225 | class VideoTokenizer(LightningModule): 226 | ''' 227 | Video Tokenizer based on the MagViT-2 paper: 228 | "Language Model Beats Diffusion: Tokenizer is 229 | key to visual generation", Yu et al. (2024). 230 | 231 | This tokenizer employs a stack of causal 232 | convolutions to process the input video sequence. 
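Example (illustrative sketch only; the argument values mirror the
    MAGVIT2_* presets defined above in this file and the tensor shapes
    exercised in test/test_tokenizer.py):

        >>> tokenizer = VideoTokenizer(
        ...     enc_desc = MAGVIT2_ENC_DESC,
        ...     dec_desc = MAGVIT2_DEC_DESC,
        ...     d_codebook = 18,
        ...     gan_loss_weight = 0., # skip the GAN discriminator in this sketch
        ... )
        >>> video = torch.randn(2, 3, 8, 64, 64) # (batch, channels, time, height, width)
        >>> tokens, idxs = tokenizer.tokenize(video)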
233 | ''' 234 | 235 | def __init__( 236 | self, 237 | enc_desc : Blueprint, 238 | dec_desc : Blueprint, 239 | disc_kwargs : Dict[str, Any] = {}, 240 | # Lookup-Free Quantization parameters 241 | d_codebook : int = 18, 242 | n_codebook : int = 1, 243 | # lfq_input_dim : int | None = None, 244 | lfq_bias : bool = True, 245 | lfq_frac_sample : float = 1., 246 | lfq_commit_weight : float = 0.25, 247 | lfq_entropy_weight : float = 0.1, 248 | lfq_diversity_weight : float = 1., 249 | # Misc parameters 250 | optimizer : OptimizerCallable = AdamW, 251 | perceptual_model : str = 'vgg16', 252 | perc_feat_layers : str | Iterable[str] = ('features.6', 'features.13', 'features.18', 'features.25'), 253 | gan_discriminate : str = 'frames', 254 | gan_frames_per_batch : int = 4, 255 | gan_loss_weight : float = 1., 256 | perc_loss_weight : float = 1., 257 | quant_loss_weight : float = 1., 258 | ) -> None: 259 | super().__init__() 260 | 261 | self.optimizer = optimizer 262 | 263 | # Scan the blueprint to build the tokenizer 264 | self.enc_layers, self.enc_ext = parse_blueprint(enc_desc) 265 | self.dec_layers, self.dec_ext = parse_blueprint(dec_desc) 266 | 267 | # Check consistency between last encoder dimension, first 268 | # decoder dimension and the codebook dimension 269 | # last_enc_dim = list(self.enc_layers.modules())[-1].out_channels 270 | last_enc_dim = [m.out_channels for m in self.enc_layers.modules() if hasattr(m, 'out_channels')][-1] 271 | first_dec_dim = self.dec_layers[0].in_channels 272 | assert last_enc_dim == first_dec_dim, 'Inconsistent encoder/decoder dimensions' 273 | # assert last_enc_dim == d_codebook , 'Codebook dimension mismatch with encoder/decoder' 274 | 275 | # Build the quantization module 276 | self.quant = LookupFreeQuantization( 277 | codebook_dim = d_codebook, 278 | num_codebook = n_codebook, 279 | input_dim = last_enc_dim, 280 | use_bias = lfq_bias, 281 | frac_sample = lfq_frac_sample, 282 | commit_weight = lfq_commit_weight, 283 | entropy_weight = lfq_entropy_weight, 284 | diversity_weight = lfq_diversity_weight, 285 | ) 286 | 287 | # If the perceptual loss is enabled, load the perceptual model 288 | self.perc_crit = PerceptualLoss( 289 | model_name=perceptual_model, 290 | feat_layers=perc_feat_layers, 291 | num_frames=gan_frames_per_batch, 292 | ) if perc_loss_weight > 0 else nn.Identity() 293 | 294 | # If the GAN loss is enabled, load the Discriminator model 295 | self.gan_crit = GANLoss( 296 | discriminate=gan_discriminate, 297 | num_frames=gan_frames_per_batch, 298 | **disc_kwargs, 299 | ) if gan_loss_weight > 0 else nn.Identity() 300 | 301 | 302 | self.gan_loss_weight = gan_loss_weight 303 | self.perc_loss_weight = perc_loss_weight 304 | self.quant_loss_weight = quant_loss_weight 305 | self.save_hyperparameters() 306 | 307 | def encode( 308 | self, 309 | video : Tensor, 310 | cond : Tensor | None = None, 311 | ) -> Tensor: 312 | enc_video = video 313 | 314 | for layer, has_ext in zip(self.enc_layers, self.enc_ext): 315 | enc_video = layer(enc_video, cond) if has_ext else layer(enc_video) 316 | 317 | return enc_video 318 | 319 | def decode( 320 | self, 321 | quant : Tensor, 322 | cond : Tensor | None = None, 323 | ) -> Tensor: 324 | cond = default(cond, quant) 325 | 326 | rec_video = quant 327 | for layer, has_ext in zip(self.dec_layers, self.dec_ext): 328 | rec_video = layer(rec_video, cond) if has_ext else layer(rec_video) 329 | 330 | return rec_video 331 | 332 | @torch.no_grad() 333 | def tokenize( 334 | self, 335 | video : Tensor, 336 | beta : float = 100., 337 | 
transpose : bool = True, 338 | ) -> Tuple[Tensor, Tensor]: 339 | self.eval() 340 | 341 | enc_video = self.encode(video) 342 | (quant_video, idxs), _ = self.quant( 343 | enc_video, 344 | beta=beta, 345 | transpose=transpose 346 | ) 347 | 348 | self.train() 349 | 350 | return quant_video, idxs 351 | 352 | def forward( 353 | self, 354 | video : Tensor, 355 | beta : float = 100., 356 | transpose : bool = True, 357 | ) -> Tuple[Tensor, Tuple[Tensor, ...]]: 358 | enc_video = self.encode(video) 359 | (quant_video, idxs), quant_loss = self.quant(enc_video, beta=beta, transpose=transpose) 360 | rec_video = self.decode(quant_video) 361 | 362 | # * Compute the tokenizer loss 363 | # Reconstruction loss 364 | rec_loss = mse_loss(rec_video, video) 365 | 366 | # GAN loss (if available) 367 | gen_loss = self.gan_crit(rec_video, video, train_gen=True) 368 | dis_loss = self.gan_crit(rec_video, video, train_gen=False) 369 | 370 | # Perceptual loss (if available) 371 | perc_loss = self.perc_crit(rec_video, video) 372 | 373 | # Compute the total loss by combining the individual 374 | # losses, weighted by the corresponding loss weights 375 | loss = rec_loss \ 376 | + gen_loss * self.gan_loss_weight \ 377 | + dis_loss * self.gan_loss_weight \ 378 | + perc_loss * self.perc_loss_weight \ 379 | + (quant_loss * self.quant_loss_weight if exists(quant_loss) else 0) 380 | 381 | return loss, ( 382 | rec_loss, 383 | gen_loss if self.gan_loss_weight > 0 else 0, 384 | dis_loss if self.gan_loss_weight > 0 else 0, 385 | perc_loss if self.perc_loss_weight > 0 else 0, 386 | quant_loss if exists(quant_loss) and self.quant_loss_weight > 0 else 0, 387 | ) 388 | 389 | # * Lightning core functions 390 | 391 | def training_step(self, batch : Tensor, batch_idx : int) -> Tensor: 392 | # Compute the training loss 393 | loss, aux_losses = self(batch) 394 | 395 | # Log the training loss 396 | self.log_dict( 397 | { 398 | 'train_loss': loss, 399 | 'train_rec_loss' : aux_losses[0], 400 | 'train_gen_loss' : aux_losses[1], 401 | 'train_dis_loss' : aux_losses[2], 402 | 'train_perc_loss' : aux_losses[3], 403 | 'train_quant_loss': aux_losses[4], 404 | }, 405 | logger=True, 406 | on_step=True, 407 | sync_dist=True 408 | ) 409 | 410 | return loss 411 | 412 | def validation_step(self, batch : Tensor, batch_idx : int) -> Tensor: 413 | # Compute the validation loss 414 | loss, aux_losses = self(batch) 415 | 416 | # Log the validation loss 417 | self.log_dict( 418 | { 419 | 'val_loss': loss, 420 | 'val_rec_loss' : aux_losses[0], 421 | 'val_gen_loss' : aux_losses[1], 422 | 'val_dis_loss' : aux_losses[2], 423 | 'val_perc_loss' : aux_losses[3], 424 | 'val_quant_loss': aux_losses[4], 425 | }, 426 | on_step=True, 427 | logger=True, 428 | sync_dist=True 429 | ) 430 | 431 | return loss 432 | 433 | def on_validation_end(self) -> None: 434 | # Maybe put here example of video reconstructions? 435 | pass 436 | 437 | def configure_optimizers(self) -> Optimizer: 438 | optim = self.optimizer( 439 | self.parameters(), 440 | ) 441 | 442 | return optim -------------------------------------------------------------------------------- /genie/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from einops import rearrange 4 | 5 | from torch import Tensor 6 | from torch.utils.data import get_worker_info 7 | 8 | from typing import TypeVar, Tuple 9 | 10 | T = TypeVar('T') 11 | D = TypeVar('D') 12 | 13 | Blueprint = Tuple[str | Tuple[str, dict], ...] 
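# Illustrative sketch of the `Blueprint` format defined above: this is the
# structure consumed by `genie.module.parse_blueprint` and used by the presets
# in genie/tokenizer.py (e.g. MAGVIT2_ENC_DESC) -- a tuple whose entries are
# either bare layer names or (layer-name, kwargs) pairs. The entries below are
# copied from those presets; the name `_EXAMPLE_BLUEPRINT` itself is made up
# for illustration only.
#
# _EXAMPLE_BLUEPRINT : Blueprint = (
#     ('causal-conv3d', {'in_channels': 3, 'out_channels': 128, 'kernel_size': 3}),
#     ('video-residual', {'n_rep': 4, 'in_channels': 128}),
#     ('silu', {}),
# )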
14 | 15 | def exists(var : T | None) -> bool: 16 | return var is not None 17 | 18 | def default(var : T | None, val : D) -> T | D: 19 | return var if exists(var) else val 20 | 21 | def enlarge_as(src : Tensor, other : Tensor) -> Tensor: 22 | ''' 23 | Add sufficient number of singleton dimensions 24 | to tensor a **to the right** so to match the 25 | shape of tensor b. NOTE that simple broadcasting 26 | works in the opposite direction. 27 | ''' 28 | return rearrange(src, f'... -> ...{" 1" * (other.dim() - src.dim())}').contiguous() 29 | 30 | def pick_frames( 31 | video : Tensor, 32 | frames_idxs : Tensor | None = None, 33 | frames_per_batch : int | None = None, 34 | ) -> Tensor: 35 | ''' 36 | Randomly pick a subset of frames from the input video 37 | tensor. The number of frames to pick is determined by 38 | the `frames_per_batch` parameter. 39 | ''' 40 | assert exists(frames_idxs) ^ exists(frames_per_batch), 'Either `frames_idxs` or `frames_per_batch` must be provided.' 41 | 42 | b, c, t, h, w = video.shape 43 | 44 | # Randomly sample the indices of the frames to pick 45 | frame_idxs = default(frames_idxs, torch.cat([ 46 | torch.randperm(t, device=video.device)[:frames_per_batch] 47 | for _ in range(b)] 48 | ) 49 | ) 50 | 51 | batch_idxs = torch.repeat_interleave( 52 | torch.arange(b, device=video.device), 53 | default(frames_per_batch, frame_idxs.numel() // b) 54 | ) 55 | 56 | return video[batch_idxs, :, frame_idxs, ...] 57 | 58 | def enc2dec_name(name : str) -> str: 59 | return name.replace('downsample', 'upsample') 60 | 61 | def default_iterdata_worker_init(worker_id : int) -> None: 62 | torch.manual_seed(torch.initial_seed() + worker_id) 63 | worker_info = get_worker_info() 64 | 65 | if worker_info is None: return 66 | 67 | dataset = worker_info.dataset 68 | glob_start = dataset._start # type: ignore 69 | glob_end = dataset._end # type: ignore 70 | 71 | per_worker = int((glob_end - glob_start) / worker_info.num_workers) 72 | worker_id = worker_info.id 73 | 74 | dataset._start = glob_start + worker_id * per_worker # type: ignore 75 | dataset._end = min(dataset._start + per_worker, glob_end) # type: ignore -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | lightning==2.2.4 3 | setuptools==69.5.1 4 | torch==2.3.0 5 | -------------------------------------------------------------------------------- /res/Genie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myscience/open-genie/732b9f9b746f18fff1a0fb22f83638224f2f7cc6/res/Genie.png -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gym 3 | from os import path 4 | from os import makedirs 5 | 6 | import argparse 7 | from tqdm.auto import trange 8 | 9 | ROOT='path/to/data' 10 | 11 | def save_frames_to_video(frames, output_file, fps=30): 12 | # Get the shape of the frame to set the video width and height 13 | height, width, layers = frames[0].shape 14 | size = (width, height) 15 | 16 | # Define the codec and create VideoWriter object 17 | # You can use different codecs, here 'mp4v' is used for .mp4 files 18 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 19 | out = cv2.VideoWriter(output_file, fourcc, fps, size) 20 | 21 | for frame in frames: 22 | out.write(frame) 23 | 24 | # Release 
the VideoWriter object 25 | out.release() 26 | 27 | def main(args): 28 | env_name = args.env_name 29 | num_envs = args.num_envs 30 | timeout = args.timeout 31 | 32 | for seed in trange(num_envs, desc=f'Generating {env_name} videos'): 33 | env = gym.make( 34 | f'procgen:procgen-{env_name.lower()}-v0', 35 | distribution_mode="hard", 36 | render_mode='rgb_array', 37 | start_level=seed, 38 | num_levels=1, 39 | use_sequential_levels=True, 40 | ) 41 | 42 | frames = [env.reset()] 43 | frames.extend([ 44 | env.step(env.action_space.sample())[0] 45 | for _ in range(timeout - 1) 46 | ]) 47 | 48 | env.close() 49 | 50 | savepath = path.join(args.root, env_name, f'{str(seed).zfill(4)}.mp4') 51 | makedirs(path.dirname(savepath), exist_ok=True) 52 | 53 | save_frames_to_video(frames, savepath) 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='Generate videos of a gym environment') 57 | parser.add_argument('--env_name', type=str, default='Coinrun', help='Name of the environment') 58 | parser.add_argument('--num_envs', type=int, default=1, help='Number of samples to generate') 59 | parser.add_argument('--timeout', type=int, default=1000, help='Timeout for generating samples') 60 | parser.add_argument('--root', type=str, default=ROOT, help='Root folder where to save the videos') 61 | 62 | args = parser.parse_args() 63 | 64 | main(args) -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='genie', 5 | version='0.1', 6 | packages=find_packages(), 7 | ) -------------------------------------------------------------------------------- /test/test_action.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from genie.action import LatentAction 4 | 5 | ENC_BLUEPRINT = ( 6 | ('space-time_attn', { 7 | 'n_rep' : 2, 8 | 'n_embd' : 256, 9 | 'n_head' : 4, 10 | 'd_head' : 16, 11 | }), 12 | ('spacetime_downsample', { 13 | 'in_channels' : 256, 14 | 'kernel_size' : 3, 15 | 'time_factor' : 1, 16 | 'space_factor' : 2, 17 | }), 18 | ('space-time_attn', { 19 | 'n_rep' : 2, 20 | 'n_embd' : 256, 21 | 'n_head' : 4, 22 | 'd_head' : 16, 23 | }), 24 | ) 25 | 26 | DEC_BLUEPRINT = ( 27 | ('space-time_attn', { 28 | 'n_rep' : 2, 29 | 'n_embd' : 256, 30 | 'n_head' : 4, 31 | 'd_head' : 16, 32 | 'has_ext' : True, 33 | 'time_attn_kw' : {'key_dim' : 8}, 34 | }), 35 | ('spacetime_upsample', { 36 | 'in_channels' : 256, 37 | 'kernel_size' : 3, 38 | 'time_factor' : 1, 39 | 'space_factor' : 2, 40 | }), 41 | ('space-time_attn', { 42 | 'n_rep' : 2, 43 | 'n_embd' : 256, 44 | 'n_head' : 4, 45 | 'd_head' : 16, 46 | 'has_ext' : True, 47 | 'time_attn_kw' : {'key_dim' : 8}, 48 | }), 49 | ) 50 | 51 | class TestLatentAction(unittest.TestCase): 52 | def setUp(self): 53 | self.enc_desc = ENC_BLUEPRINT 54 | self.dec_desc = DEC_BLUEPRINT 55 | self.d_codebook = 8 56 | self.inp_channels = 3 57 | self.ker_size = 3 58 | self.n_embd = 256 59 | self.n_codebook = 1 60 | self.lfq_bias = True 61 | self.lfq_frac_sample = 1.0 62 | self.lfq_commit_weight = 0.25 63 | self.lfq_entropy_weight = 0.1 64 | self.lfq_diversity_weight = 1.0 65 | 66 | self.inp_shape = 64, 64 67 | self.batch_size = 2 68 | 69 | def test_encode(self): 70 | model = LatentAction( 71 | self.enc_desc, 72 | self.dec_desc, 73 | d_codebook=self.d_codebook, 74 | inp_channels=self.inp_channels, 75 | 
inp_shape=self.inp_shape, 76 | ker_size=self.ker_size, 77 | n_embd=self.n_embd, 78 | n_codebook=self.n_codebook, 79 | lfq_bias=self.lfq_bias, 80 | lfq_frac_sample=self.lfq_frac_sample, 81 | lfq_commit_weight=self.lfq_commit_weight, 82 | lfq_entropy_weight=self.lfq_entropy_weight, 83 | lfq_diversity_weight=self.lfq_diversity_weight, 84 | ) 85 | 86 | video = torch.randn(self.batch_size, self.inp_channels, 16, *self.inp_shape) 87 | act, q_loss = model.encode(video) 88 | 89 | self.assertEqual(act.shape, (self.batch_size, 16, self.d_codebook)) 90 | self.assertEqual(q_loss.shape, ()) 91 | self.assertTrue(q_loss >= 0) 92 | 93 | def test_decode(self): 94 | model = LatentAction( 95 | self.enc_desc, 96 | self.dec_desc, 97 | d_codebook=self.d_codebook, 98 | inp_channels=self.inp_channels, 99 | ker_size=self.ker_size, 100 | n_embd=self.n_embd, 101 | n_codebook=self.n_codebook, 102 | lfq_bias=self.lfq_bias, 103 | lfq_frac_sample=self.lfq_frac_sample, 104 | lfq_commit_weight=self.lfq_commit_weight, 105 | lfq_entropy_weight=self.lfq_entropy_weight, 106 | lfq_diversity_weight=self.lfq_diversity_weight, 107 | ) 108 | 109 | h, w = self.inp_shape[0] // 2, self.inp_shape[1] // 2 110 | video = torch.randn(self.batch_size, self.n_embd, 16, h, w) 111 | q_act = torch.randint(0, self.d_codebook, (self.batch_size, 16, self.d_codebook), dtype=torch.float) 112 | recon = model.decode(video, q_act=q_act) 113 | 114 | self.assertEqual(recon.shape, (self.batch_size, self.inp_channels, 16, *self.inp_shape)) 115 | 116 | def test_forward(self): 117 | model = LatentAction( 118 | self.enc_desc, 119 | self.dec_desc, 120 | d_codebook=self.d_codebook, 121 | inp_channels=self.inp_channels, 122 | ker_size=self.ker_size, 123 | n_embd=self.n_embd, 124 | n_codebook=self.n_codebook, 125 | lfq_bias=self.lfq_bias, 126 | lfq_frac_sample=self.lfq_frac_sample, 127 | lfq_commit_weight=self.lfq_commit_weight, 128 | lfq_entropy_weight=self.lfq_entropy_weight, 129 | lfq_diversity_weight=self.lfq_diversity_weight, 130 | ) 131 | 132 | video = torch.randn(1, self.inp_channels, 16, *self.inp_shape) 133 | idxs, loss, aux_losses = model(video) 134 | 135 | self.assertEqual(loss.shape, ()) 136 | self.assertTrue(loss >= 0) 137 | 138 | if __name__ == '__main__': 139 | unittest.main() -------------------------------------------------------------------------------- /test/test_attention.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | 5 | from genie.module.attention import Attention 6 | from genie.module.attention import SpatialAttention 7 | from genie.module.attention import TemporalAttention 8 | from genie.module.attention import SpaceTimeAttention 9 | 10 | class TestAttentionModule(unittest.TestCase): 11 | def setUp(self) -> None: 12 | self.n_embd = 16 13 | self.n_head = 4 14 | self.d_head = 32 15 | self.bias = True 16 | self.cond_dim = 8 17 | 18 | def test_self_attention(self): 19 | # Test SelfAttention class 20 | attn = Attention( 21 | n_embd = self.n_embd, 22 | n_head = self.n_head, 23 | d_head = self.d_head, 24 | bias = self.bias, 25 | causal = True, 26 | ) 27 | 28 | mock_seq = torch.randn(2, 16, self.n_embd) 29 | 30 | attn_out = attn(mock_seq) 31 | 32 | self.assertEqual(attn_out.shape, mock_seq.shape) 33 | 34 | def test_cross_attention(self): 35 | # Test SelfAttention class 36 | attn = Attention( 37 | n_embd = self.n_embd, 38 | n_head = self.n_head, 39 | d_head = self.d_head, 40 | bias = self.bias, 41 | causal = True, 42 | key_dim = self.cond_dim, 43 | ) 44 | 45 | 
mock_seq = torch.randn(2, 16, self.n_embd) 46 | mock_cond = torch.randn(2, 16, self.cond_dim) 47 | 48 | attn_out = attn(mock_seq, mock_cond) 49 | 50 | self.assertEqual(attn_out.shape, mock_seq.shape) 51 | 52 | def test_spatial_attention_image(self): 53 | # Test SpatialAttention class 54 | attn = SpatialAttention( 55 | d_inp = self.n_embd, 56 | n_head = self.n_head, 57 | d_head = self.d_head, 58 | bias = self.bias, 59 | causal = True, 60 | ) 61 | 62 | mock_img = torch.randn(2, self.n_embd, 32, 32) 63 | 64 | attn_out = attn(mock_img) 65 | 66 | self.assertEqual(attn_out.shape, mock_img.shape) 67 | 68 | def test_spatial_attention_video(self): 69 | # Test SpatialAttention class 70 | attn = SpatialAttention( 71 | d_inp = self.n_embd, 72 | n_head = self.n_head, 73 | d_head = self.d_head, 74 | bias = self.bias, 75 | causal = True, 76 | ) 77 | 78 | mock_img = torch.randn(2, self.n_embd, 16, 32, 32) 79 | 80 | attn_out = attn(mock_img) 81 | 82 | self.assertEqual(attn_out.shape, mock_img.shape) 83 | 84 | def test_temporal_attention_video(self): 85 | # Test SpatialAttention class 86 | attn = TemporalAttention( 87 | n_embd = self.n_embd, 88 | n_head = self.n_head, 89 | d_head = self.d_head, 90 | bias = self.bias, 91 | causal = True, 92 | ) 93 | 94 | mock_img = torch.randn(2, self.n_embd, 16, 32, 32) 95 | 96 | attn_out = attn(mock_img) 97 | 98 | self.assertEqual(attn_out.shape, mock_img.shape) 99 | 100 | class TestSpaceTimeAttention(unittest.TestCase): 101 | def setUp(self): 102 | self.n_embed = 256 103 | 104 | self.action_block = SpaceTimeAttention( 105 | n_embd=self.n_embed, 106 | n_head=(4, 4), 107 | d_head=(32, 32), 108 | hid_dim=(512, 512), 109 | bias=True, 110 | embed=True, 111 | scale=0.5, 112 | dropout=0.1, 113 | ) 114 | 115 | def test_forward(self): 116 | inp_video = torch.randn(1, self.n_embed, 16, 32, 32) 117 | out_video = self.action_block(inp_video) 118 | 119 | self.assertEqual(out_video.shape, (1, self.n_embed, 16, 32, 32)) 120 | 121 | if __name__ == '__main__': 122 | unittest.main() -------------------------------------------------------------------------------- /test/test_dataset.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import yaml 4 | from os import path 5 | 6 | from genie.dataset import LightningKinetics 7 | 8 | # Loading `local_settings.json` for custom local settings 9 | test_folder = path.dirname(path.abspath(__file__)) 10 | local_settings = path.join(test_folder, '.local.yaml') 11 | 12 | with open(local_settings, 'r') as f: 13 | local_settings = yaml.safe_load(f) 14 | 15 | class TestKineticsDataset(unittest.TestCase): 16 | def setUp(self): 17 | self.batch_size = 16 18 | self.output_format = 'CTHW' 19 | 20 | self.dataset = LightningKinetics( 21 | root=local_settings['kinetics_remote_root'], 22 | frames_per_clip=16, 23 | num_classes=local_settings['num_classes'], 24 | frame_rate=None, 25 | step_between_clips=1, 26 | transform=None, 27 | extensions=('avi', 'mp4'), 28 | download=local_settings['download'], 29 | num_download_workers=4, 30 | num_workers=4, 31 | output_format=self.output_format, 32 | batch_size=self.batch_size, 33 | ) 34 | 35 | def test_setup_fit(self): 36 | self.dataset.setup('fit') 37 | self.assertIsNotNone(self.dataset.train_dataset) 38 | self.assertIsNotNone(self.dataset.valid_dataset) 39 | self.assertIsNone (self.dataset.test__dataset) 40 | 41 | def test_setup_test(self): 42 | self.dataset.setup('test') 43 | self.assertIsNone(self.dataset.train_dataset) 44 | 
self.assertIsNone(self.dataset.valid_dataset) 45 | self.assertIsNotNone(self.dataset.test__dataset) 46 | 47 | def test_setup_invalid_stage(self): 48 | with self.assertRaises(ValueError): 49 | self.dataset.setup('invalid_stage') 50 | 51 | def test_output_format(self): 52 | self.assertEqual(self.dataset.output_format, self.output_format) 53 | 54 | self.dataset.setup('fit') 55 | video, lbl = self.dataset.train_dataset[0] 56 | 57 | print(video.shape) 58 | print(lbl) 59 | 60 | if __name__ == '__main__': 61 | unittest.main() -------------------------------------------------------------------------------- /test/test_discriminator.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from genie.module.discriminator import FrameDiscriminator, VideoDiscriminator 4 | 5 | class TestDiscriminator(unittest.TestCase): 6 | def setUp(self): 7 | self.time_size = 16 8 | self.frame_size = (64, 64) 9 | self.model_dim = 64 10 | self.dim_mults = (1, 2, 4) 11 | self.down_step = (None, 2, 2) 12 | self.inp_channels = 3 13 | self.kernel_size = 3 14 | self.num_groups = 1 15 | self.num_heads = 4 16 | self.dim_head = 32 17 | 18 | self.batch_size = 2 19 | 20 | def test_frame_discriminator(self): 21 | discriminator = FrameDiscriminator( 22 | inp_size=self.frame_size, 23 | model_dim=self.model_dim, 24 | dim_mults=self.dim_mults, 25 | down_step=self.down_step, 26 | inp_channels=self.inp_channels, 27 | kernel_size=self.kernel_size, 28 | num_groups=self.num_groups, 29 | num_heads=self.num_heads, 30 | dim_head=self.dim_head 31 | ) 32 | input_tensor = torch.randn(self.batch_size, self.inp_channels, self.frame_size[0], self.frame_size[1]) 33 | output_tensor = discriminator(input_tensor) 34 | self.assertEqual(output_tensor.shape, (self.batch_size, )) 35 | 36 | def test_frame_discriminator_attn(self): 37 | discriminator = FrameDiscriminator( 38 | inp_size=self.frame_size, 39 | model_dim=self.model_dim, 40 | dim_mults=self.dim_mults, 41 | down_step=self.down_step, 42 | inp_channels=self.inp_channels, 43 | kernel_size=self.kernel_size, 44 | num_groups=self.num_groups, 45 | num_heads=self.num_heads, 46 | dim_head=self.dim_head, 47 | use_attn=True, 48 | ) 49 | frame_tensor = torch.randn(self.batch_size, self.inp_channels, self.frame_size[0], self.frame_size[1]) 50 | output_tensor = discriminator(frame_tensor) 51 | self.assertEqual(output_tensor.shape, (self.batch_size, )) 52 | 53 | def test_video_discriminator(self): 54 | discriminator = VideoDiscriminator( 55 | inp_size=(self.time_size, *self.frame_size), 56 | model_dim=self.model_dim, 57 | dim_mults=self.dim_mults, 58 | down_step=self.down_step, 59 | inp_channels=self.inp_channels, 60 | kernel_size=self.kernel_size, 61 | num_groups=self.num_groups, 62 | num_heads=self.num_heads, 63 | dim_head=self.dim_head 64 | ) 65 | 66 | video_tensor = torch.randn(self.batch_size, self.inp_channels, self.time_size, self.frame_size[0], self.frame_size[1]) 67 | output_tensor = discriminator(video_tensor) 68 | self.assertEqual(output_tensor.shape, (self.batch_size, )) 69 | 70 | def test_video_discriminator_attn(self): 71 | discriminator = VideoDiscriminator( 72 | inp_size=(self.time_size, *self.frame_size), 73 | model_dim=self.model_dim, 74 | dim_mults=self.dim_mults, 75 | down_step=self.down_step, 76 | inp_channels=self.inp_channels, 77 | kernel_size=self.kernel_size, 78 | num_groups=self.num_groups, 79 | num_heads=self.num_heads, 80 | dim_head=self.dim_head, 81 | use_attn=True, 82 | ) 83 | 84 | video_tensor = 
torch.randn(self.batch_size, self.inp_channels, self.time_size, self.frame_size[0], self.frame_size[1]) 85 | output_tensor = discriminator(video_tensor) 86 | self.assertEqual(output_tensor.shape, (self.batch_size, )) 87 | 88 | if __name__ == '__main__': 89 | unittest.main() -------------------------------------------------------------------------------- /test/test_dynamics.py: -------------------------------------------------------------------------------- 1 | from math import prod 2 | import unittest 3 | 4 | import torch 5 | from genie.dynamics import DynamicsModel 6 | 7 | class DynamicsModelTestCase(unittest.TestCase): 8 | def setUp(self): 9 | 10 | self.batch_size = 2 11 | self.video_len = 10 12 | self.tok_codebook = 16 13 | self.act_codebook = 4 14 | self.embed_dim = 64 15 | self.img_size = 16 16 | 17 | TEST_DESC = ( 18 | ('space-time_attn', { 19 | 'n_rep' : 4, 20 | 'n_embd' : self.embed_dim, 21 | 'n_head' : 4, 22 | 'd_head' : 16, 23 | 'transpose' : False, 24 | }), 25 | ) 26 | 27 | self.model = DynamicsModel( 28 | desc=TEST_DESC, 29 | tok_vocab=self.tok_codebook, 30 | act_vocab=self.act_codebook, 31 | embed_dim=self.embed_dim, 32 | ) 33 | 34 | self.mock_tokens = torch.randint(0, self.tok_codebook, (self.batch_size, self.video_len, self.img_size, self.img_size)) 35 | self.mock_act_id = torch.randint(0, self.act_codebook, (self.batch_size, self.video_len)) 36 | 37 | def test_forward(self): 38 | # Test the forward method of the DynamicsModel 39 | logits, last = self.model.forward(self.mock_tokens, self.mock_act_id) 40 | 41 | self.assertIsInstance(logits, torch.Tensor) 42 | self.assertEqual(logits.shape, ( 43 | self.batch_size, 44 | self.video_len, 45 | self.img_size, 46 | self.img_size, 47 | self.tok_codebook, 48 | )) 49 | self.assertEqual(last.shape, ( 50 | self.batch_size, 51 | self.img_size, 52 | self.img_size, 53 | self.tok_codebook, 54 | )) 55 | 56 | def test_compute_loss(self): 57 | # Test the compute_loss method of the DynamicsModel 58 | loss = self.model.compute_loss( 59 | self.mock_tokens, 60 | self.mock_act_id, 61 | ) 62 | 63 | self.assertIsInstance(loss, torch.Tensor) 64 | self.assertEqual(loss.shape, ()) 65 | self.assertTrue(loss >= 0) 66 | 67 | def test_generate(self): 68 | # Test the generate method of the DynamicsModel 69 | generated_frames = self.model.generate( 70 | self.mock_tokens, 71 | self.mock_act_id, 72 | steps=5, 73 | ) 74 | 75 | self.assertIsInstance(generated_frames, torch.Tensor) 76 | self.assertEqual(generated_frames.shape, ( 77 | self.batch_size, 78 | self.video_len + 1, 79 | self.img_size, 80 | self.img_size, 81 | )) 82 | 83 | def test_get_schedule(self): 84 | # Test the get_schedule method of the DynamicsModel 85 | steps = 10 86 | schedule = self.model.get_schedule(steps, (self.img_size, self.img_size)) 87 | 88 | print(schedule) 89 | 90 | self.assertIsInstance(schedule, torch.Tensor) 91 | self.assertEqual(schedule.shape, (steps,)) 92 | self.assertEqual(torch.sum(schedule), prod((self.img_size, self.img_size))) 93 | 94 | def tearDown(self): 95 | # Clean up any resources used for testing 96 | pass 97 | 98 | if __name__ == '__main__': 99 | unittest.main() -------------------------------------------------------------------------------- /test/test_image.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from genie.module.image import BlurPooling2d 4 | from genie.module.image import ImageResidualBlock 5 | from genie.module.image import SpaceDownsample 6 | 7 | class 
TestImageModule(unittest.TestCase): 8 | 9 | def test_space_downsample(self): 10 | in_dim = 64 11 | factor = 2 12 | batch_size = 4 13 | downsample = SpaceDownsample(in_dim, factor) 14 | 15 | inp = torch.randn(batch_size, in_dim, 32, 32) 16 | out = downsample(inp) 17 | 18 | self.assertEqual(out.shape, (batch_size, in_dim, 16, 16)) 19 | 20 | def test_residual_block_no_downsample(self): 21 | inp_channel = 64 22 | out_channel = 128 23 | kernel_size = 3 24 | num_groups = 1 25 | downsample = None 26 | batch_size = 4 27 | residual_block = ImageResidualBlock( 28 | inp_channel, 29 | out_channel=out_channel, 30 | kernel_size=kernel_size, 31 | num_groups=num_groups, 32 | downsample=downsample, 33 | ) 34 | 35 | inp = torch.randn(batch_size, inp_channel, 32, 32) 36 | out = residual_block(inp) 37 | 38 | self.assertEqual(out.shape, (batch_size, out_channel, 32, 32)) 39 | 40 | def test_residual_block_yes_downsample(self): 41 | inp_channel = 64 42 | out_channel = 128 43 | kernel_size = 3 44 | num_groups = 1 45 | downsample = 2 46 | batch_size = 4 47 | residual_block = ImageResidualBlock( 48 | inp_channel, 49 | out_channel=out_channel, 50 | kernel_size=kernel_size, 51 | num_groups=num_groups, 52 | downsample=downsample, 53 | ) 54 | 55 | img_h, img_w = 64, 64 56 | 57 | inp = torch.randn(batch_size, inp_channel, img_h, img_w) 58 | out = residual_block(inp) 59 | 60 | self.assertEqual(out.shape, (batch_size, out_channel, img_h // downsample, img_w // downsample)) 61 | 62 | def test_blur_pooling(self): 63 | kernel_size = 3 64 | batch_size = 4 65 | inp_channel = 64 66 | stride = 2 67 | img_h, img_w = 32, 32 68 | 69 | blur_pooling = BlurPooling2d( 70 | kernel_size, 71 | stride=stride, 72 | ) 73 | 74 | inp = torch.randn(batch_size, inp_channel, img_h, img_w) 75 | out = blur_pooling(inp) 76 | 77 | self.assertEqual(out.shape, (batch_size, inp_channel, img_h // stride, img_w // stride)) 78 | 79 | if __name__ == '__main__': 80 | unittest.main() -------------------------------------------------------------------------------- /test/test_loss.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from genie.module.loss import PerceptualLoss, GANLoss 4 | 5 | class TestLossModule(unittest.TestCase): 6 | def setUp(self) -> None: 7 | self.batch_size = 2 8 | self.num_channels = 3 9 | self.num_frames = 4 10 | self.img_h, self.img_w = 64, 64 11 | self.inp_video = torch.randn( 12 | self.batch_size, 13 | self.num_channels, 14 | self.num_frames, 15 | self.img_h, 16 | self.img_w 17 | ) # Mock real input video tensor 18 | 19 | self.rec_video = torch.randn( 20 | self.batch_size, 21 | self.num_channels, 22 | self.num_frames, 23 | self.img_h, 24 | self.img_w 25 | ) # Mock reconstructed video tensor 26 | 27 | def test_perceptual_loss(self): 28 | model = PerceptualLoss( 29 | model_name='vgg16', 30 | num_frames=2, 31 | feat_layers=('features.6', 'features.13', 'features.18', 'features.25'), 32 | ) 33 | 34 | loss = model(self.rec_video, self.inp_video) 35 | 36 | self.assertEqual(loss.shape, torch.Size([])) # Check the output shape 37 | self.assertTrue(loss >= 0) 38 | 39 | def test_gan_loss_frames(self): 40 | 41 | model = GANLoss( 42 | discriminate='frames', 43 | num_frames=2, 44 | 45 | # Discriminator parameters 46 | inp_size = (self.img_h, self.img_w), 47 | model_dim = 64, 48 | dim_mults = (1, 2, 4), 49 | down_step = (None, 2, 2), 50 | inp_channels = self.num_channels, 51 | kernel_size = 3, 52 | num_groups = 8, 53 | num_heads = 4, 54 | dim_head = 32, 55 | ) 56 | 57 | 58 
| loss_gen = model(self.rec_video, self.inp_video, train_gen = True) 59 | loss_dis = model(self.rec_video, self.inp_video, train_gen = False) 60 | 61 | self.assertEqual(loss_gen.shape, torch.Size([])) # Check the output shape 62 | self.assertEqual(loss_dis.shape, torch.Size([])) # Check the output shape 63 | 64 | self.assertTrue(loss_dis >= 0) 65 | 66 | def test_gan_loss_video(self): 67 | 68 | model = GANLoss( 69 | discriminate='video', 70 | num_frames=2, 71 | 72 | # Discriminator parameters 73 | inp_size = (self.num_frames, self.img_h, self.img_w), 74 | model_dim = 64, 75 | dim_mults = (1, 2, 4), 76 | down_step = (None, 2, 2), 77 | inp_channels = self.num_channels, 78 | kernel_size = 3, 79 | num_groups = 8, 80 | num_heads = 4, 81 | dim_head = 32, 82 | ) 83 | 84 | loss_gen = model(self.rec_video, self.inp_video, train_gen = True) 85 | loss_dis = model(self.rec_video, self.inp_video, train_gen = False) 86 | 87 | self.assertEqual(loss_gen.shape, torch.Size([])) # Check the output shape 88 | self.assertEqual(loss_dis.shape, torch.Size([])) # Check the output shape 89 | 90 | self.assertTrue(loss_dis >= 0) 91 | 92 | if __name__ == '__main__': 93 | unittest.main() -------------------------------------------------------------------------------- /test/test_quantization.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import torch 4 | from genie.module.quantization import LookupFreeQuantization 5 | 6 | class TestLookupFreeQuantization(unittest.TestCase): 7 | def setUp(self) -> None: 8 | self.d_codebook = 8 # Codebook quantization, i.e. codebook size 2 ** d_codebook 9 | self.input_dim = 256 # Expected input dimension 10 | self.use_bias = True 11 | self.frac_sample = .8 12 | self.commit_weight = .25 13 | self.entropy_weight = .1 14 | self.diversity_weight = 1. 
15 | self.batch_size = 4 16 | 17 | # Create mock input tensor 18 | self.seq_len = 16 19 | self.input_tensor = torch.rand((self.batch_size, self.seq_len, self.input_dim)) 20 | 21 | def test_eval_quantize_single_codebook(self): 22 | num_codebooks = 1 23 | lfq = LookupFreeQuantization( 24 | codebook_dim = self.d_codebook, 25 | num_codebook = num_codebooks, 26 | input_dim = self.input_dim, 27 | use_bias = self.use_bias, 28 | frac_sample = self.frac_sample, 29 | commit_weight = self.commit_weight, 30 | entropy_weight = self.entropy_weight, 31 | diversity_weight = self.diversity_weight 32 | ) 33 | 34 | self.assertEqual(lfq.num_codebooks, 1) 35 | self.assertEqual(lfq.frac_sample, self.frac_sample) 36 | self.assertEqual(lfq.commit_weight, self.commit_weight) 37 | self.assertEqual(lfq.entropy_weight, self.entropy_weight) 38 | self.assertEqual(lfq.diversity_weight, self.diversity_weight) 39 | 40 | # Quantize the input tensor 41 | lfq.eval() # Only the test quantization 42 | (quant, idxs), _ = lfq(self.input_tensor) 43 | 44 | # Check the shape of the quantized tensor 45 | self.assertEqual(quant.shape, (self.batch_size, self.seq_len, self.input_dim)) 46 | self.assertEqual( idxs.shape, (self.batch_size, self.seq_len)) # NOTE: No num_codebooks dimension 47 | 48 | if self.input_dim == self.d_codebook: 49 | # If not output projection, check that tokens have values in {-1, +1} 50 | self.assertTrue(torch.allclose(quant, torch.sign(quant))) 51 | 52 | def test_train_quantize_single_codebook(self): 53 | num_codebooks = 1 54 | lfq = LookupFreeQuantization( 55 | codebook_dim = self.d_codebook, 56 | num_codebook = num_codebooks, 57 | input_dim = self.input_dim, 58 | use_bias = self.use_bias, 59 | frac_sample = self.frac_sample, 60 | commit_weight = self.commit_weight, 61 | entropy_weight = self.entropy_weight, 62 | diversity_weight = self.diversity_weight 63 | ) 64 | 65 | self.assertEqual(lfq.num_codebooks, 1) 66 | self.assertEqual(lfq.frac_sample, self.frac_sample) 67 | self.assertEqual(lfq.commit_weight, self.commit_weight) 68 | self.assertEqual(lfq.entropy_weight, self.entropy_weight) 69 | self.assertEqual(lfq.diversity_weight, self.diversity_weight) 70 | 71 | # Quantize the input tensor 72 | lfq.train() # Only the test quantization 73 | (quant, idxs), loss = lfq(self.input_tensor) 74 | 75 | # Check the shape of the quantized tensor 76 | self.assertEqual(quant.shape, (self.batch_size, self.seq_len, self.input_dim)) 77 | self.assertEqual( idxs.shape, (self.batch_size, self.seq_len)) # NOTE: No num_codebooks dimension 78 | 79 | self.assertGreater(loss, 0.) 
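# Illustrative sketch of the lookup-free quantization idea exercised by these
# tests (based on the MagViT-2 scheme; the exact bit ordering is an
# implementation detail): each of the d_codebook latent dimensions is
# binarized to +/-1 and the token index is read off the resulting bit
# pattern, e.g. with d_codebook = 3:
#     z = [0.7, -0.2, 1.3]  ->  signs (1, 0, 1)  ->  index 0b101 = 5
# which is why the implied codebook size is 2 ** d_codebook, as noted in setUp.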
80 | 81 | def test_train_quantize_multi_codebook(self): 82 | num_codebooks = 3 83 | lfq = LookupFreeQuantization( 84 | codebook_dim = self.d_codebook, 85 | num_codebook = num_codebooks, 86 | input_dim = self.input_dim, 87 | use_bias = self.use_bias, 88 | frac_sample = self.frac_sample, 89 | commit_weight = self.commit_weight, 90 | entropy_weight = self.entropy_weight, 91 | diversity_weight = self.diversity_weight 92 | ) 93 | 94 | self.assertEqual(lfq.num_codebooks, num_codebooks) 95 | self.assertEqual(lfq.frac_sample, self.frac_sample) 96 | self.assertEqual(lfq.commit_weight, self.commit_weight) 97 | self.assertEqual(lfq.entropy_weight, self.entropy_weight) 98 | self.assertEqual(lfq.diversity_weight, self.diversity_weight) 99 | 100 | # Quantize the input tensor 101 | lfq.train() # Only the test quantization 102 | (quant, idxs), loss = lfq(self.input_tensor) 103 | 104 | # Check the shape of the quantized tensor 105 | self.assertEqual(quant.shape, (self.batch_size, self.seq_len, self.input_dim)) 106 | self.assertEqual( idxs.shape, (self.batch_size, self.seq_len, num_codebooks)) 107 | 108 | self.assertGreater(loss, 0.) 109 | 110 | if __name__ == '__main__': 111 | unittest.main() -------------------------------------------------------------------------------- /test/test_tokenizer.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | import unittest 3 | 4 | import torch 5 | import yaml 6 | 7 | from genie import VideoTokenizer 8 | from genie.dataset import LightningPlatformer2D 9 | from genie.tokenizer import REPR_TOK_ENC 10 | from genie.tokenizer import REPR_TOK_DEC 11 | 12 | from genie.tokenizer import MAGVIT2_ENC_DESC 13 | from genie.tokenizer import MAGVIT2_DEC_DESC 14 | 15 | TEST_ENC_DESC = ( 16 | ('causal-conv3d', { 17 | 'in_channels': 3, 18 | 'out_channels': 64, 19 | 'kernel_size': 3, 20 | }), 21 | ('video-residual', { 22 | 'in_channels': 64, 23 | 'kernel_size': 3, 24 | 'downsample': (1, 2), 25 | 'use_causal': True, 26 | 'use_blur': True, 27 | }), 28 | ('video-residual', { 29 | 'in_channels': 64, 30 | 'out_channels': 128, 31 | }), 32 | ('video-residual', { 33 | 'n_rep': 2, 34 | 'in_channels': 128, 35 | }), 36 | ('video-residual', { 37 | 'in_channels': 128, 38 | 'out_channels': 256, 39 | 'kernel_size': 3, 40 | 'downsample': 2, 41 | 'use_causal': True, 42 | }), 43 | ('causal-conv3d', { 44 | 'in_channels': 256, 45 | 'out_channels': 18, 46 | 'kernel_size': 3, 47 | }) 48 | ) 49 | 50 | TEST_DEC_DESC = ( 51 | ('causal-conv3d', { 52 | 'in_channels': 18, 53 | 'out_channels': 128, 54 | 'kernel_size': 3, 55 | }), 56 | ('video-residual', { 57 | 'n_rep': 2, 58 | 'in_channels': 128, 59 | }), 60 | ('adaptive_group_norm', { 61 | 'num_groups': 8, 62 | 'num_channels': 128, 63 | 'has_ext' : True, 64 | 'dim_cond' : 18, 65 | }), 66 | ('video-residual', { 67 | 'n_rep': 2, 68 | 'in_channels': 128, 69 | }), 70 | ('depth2spacetime_upsample', { 71 | 'in_channels': 128, 72 | 'kernel_size': 3, 73 | 'time_factor': 2, 74 | 'space_factor': 2, 75 | }), 76 | ('adaptive_group_norm', { 77 | 'num_groups': 8, 78 | 'num_channels': 128, 79 | 'has_ext' : True, 80 | 'dim_cond' : 18, 81 | }), 82 | ('video-residual', { 83 | 'in_channels': 128, 84 | 'out_channels': 64, 85 | }), 86 | ('video-residual', { 87 | 'n_rep': 2, 88 | 'in_channels': 64, 89 | }), 90 | ('depth2spacetime_upsample', { 91 | 'in_channels': 64, 92 | 'kernel_size': 3, 93 | 'time_factor': 1, 94 | 'space_factor': 2, 95 | }), 96 | ('adaptive_group_norm', { 97 | 'num_groups': 8, 98 | 'num_channels': 64, 99 | 
'has_ext' : True, 100 | 'dim_cond' : 18, 101 | }), 102 | ('causal-conv3d', { 103 | 'in_channels': 64, 104 | 'out_channels': 3, 105 | 'kernel_size': 3, 106 | }) 107 | ) 108 | 109 | # Loading `local_settings.json` for custom local settings 110 | test_folder = path.dirname(path.abspath(__file__)) 111 | local_settings = path.join(test_folder, '.local.yaml') 112 | 113 | with open(local_settings, 'r') as f: 114 | local_settings = yaml.safe_load(f) 115 | 116 | class TestVideoTokenizer(unittest.TestCase): 117 | def setUp(self): 118 | 119 | self.d_codebook = 18 120 | self.n_codebook = 1 121 | 122 | self.batch_size = 2 123 | self.num_frames = 8 124 | self.num_channels = 3 125 | self.img_h, self.img_w = 64, 64 126 | 127 | # Number of channels after the encoding by the MAGVIT2 128 | self.hid_channels = 18 129 | 130 | self.time_down = 4 # This parameters are determined by MAGVIT2 131 | self.space_down = 8 # This parameters are determined by MAGVIT2 132 | 133 | self.tokenizer = VideoTokenizer( 134 | # enc_desc = REPR_TOK_ENC, 135 | # dec_desc = REPR_TOK_DEC, 136 | enc_desc = MAGVIT2_ENC_DESC, 137 | dec_desc = MAGVIT2_DEC_DESC, 138 | 139 | disc_kwargs=dict( 140 | # Discriminator parameters 141 | inp_size = (self.img_h, self.img_w), 142 | model_dim = 64, 143 | dim_mults = (1, 2, 4), 144 | down_step = (None, 2, 2), 145 | inp_channels = self.num_channels, 146 | kernel_size = 3, 147 | use_attn = False, 148 | use_blur = True, 149 | num_groups = 8, 150 | num_heads = 4, 151 | dim_head = 32, 152 | ), 153 | 154 | d_codebook = self.d_codebook, 155 | n_codebook = self.n_codebook, 156 | # 157 | lfq_bias = True, 158 | lfq_frac_sample = 1., 159 | lfq_commit_weight = 0.25, 160 | lfq_entropy_weight = 0.1, 161 | lfq_diversity_weight = 1., 162 | # 163 | perceptual_model = 'vgg16', 164 | perc_feat_layers = ('features.6', 'features.13', 'features.18', 'features.25'), 165 | gan_discriminate='frames', 166 | gan_frames_per_batch = 4, 167 | gan_loss_weight = 1., 168 | perc_loss_weight = 1., 169 | quant_loss_weight = 1., 170 | ) 171 | 172 | # Example video tensor 173 | self.video = torch.randn( 174 | self.batch_size, 175 | self.num_channels, 176 | self.num_frames, 177 | self.img_h, 178 | self.img_w 179 | ) 180 | 181 | def test_encode(self): 182 | encoded = self.tokenizer.encode(self.video) 183 | self.assertEqual(encoded.shape, ( 184 | self.batch_size, 185 | self.hid_channels, 186 | self.num_frames // self.time_down, 187 | self.img_h // self.space_down, 188 | self.img_w // self.space_down 189 | )) # Check output shape 190 | 191 | def test_decode(self): 192 | quantized = torch.randn( 193 | self.batch_size, 194 | self.hid_channels, 195 | self.num_frames // self.time_down, 196 | self.img_h // self.space_down, 197 | self.img_w // self.space_down, 198 | ) # Example quantized tensor 199 | decoded = self.tokenizer.decode(quantized) 200 | self.assertEqual(decoded.shape, ( 201 | self.batch_size, 202 | self.num_channels, 203 | self.num_frames, 204 | self.img_h, 205 | self.img_w 206 | )) # Check output shape 207 | 208 | def test_tokenize(self): 209 | tokens, idxs = self.tokenizer.tokenize(self.video) 210 | self.assertEqual(tokens.shape, ( 211 | self.batch_size, 212 | self.hid_channels, 213 | self.num_frames // self.time_down, 214 | self.img_h // self.space_down, 215 | self.img_w // self.space_down, 216 | )) # Check output shape 217 | 218 | if self.hid_channels == 2 ** self.d_codebook: 219 | # If not output projection, check that tokens have values in {-1, +1} 220 | self.assertTrue(torch.allclose(tokens, torch.sign(tokens))) 221 | 222 | 
self.assertEqual(idxs.shape, ( 223 | self.batch_size, 224 | self.num_frames // self.time_down, 225 | self.img_h // self.space_down, 226 | self.img_w // self.space_down, 227 | )) 228 | 229 | def test_forward(self): 230 | loss, aux_losses = self.tokenizer(self.video) 231 | 232 | self.assertTrue(loss >= 0) 233 | for loss in aux_losses: 234 | self.assertEqual(loss.shape, torch.Size([])) # Check the output shape 235 | 236 | print(aux_losses) 237 | self.assertTrue(aux_losses[0] >= 0) 238 | self.assertTrue(aux_losses[2] >= 0) 239 | self.assertTrue(aux_losses[3] >= 0) 240 | self.assertTrue(aux_losses[4] >= 0) 241 | 242 | def test_forward_platformer_2d(self): 243 | dataset = LightningPlatformer2D( 244 | root=local_settings['platformer_remote_root'], 245 | output_format='c t h w', 246 | transform=None, 247 | randomize=True, 248 | batch_size=self.batch_size, 249 | num_frames=self.num_frames, 250 | num_workers=4, 251 | ) 252 | 253 | dataset.setup('fit') 254 | loader = dataset.train_dataloader() 255 | 256 | video = next(iter(loader)) 257 | 258 | loss, aux_losses = self.tokenizer(video) 259 | 260 | self.assertTrue(loss >= 0) 261 | for loss in aux_losses: 262 | self.assertEqual(loss.shape, torch.Size([])) # Check the output shape 263 | 264 | print(aux_losses) 265 | self.assertTrue(aux_losses[0] >= 0) 266 | self.assertTrue(aux_losses[2] >= 0) 267 | self.assertTrue(aux_losses[3] >= 0) 268 | self.assertTrue(aux_losses[4] >= 0) 269 | 270 | if __name__ == '__main__': 271 | unittest.main() -------------------------------------------------------------------------------- /test/test_video.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import torch 3 | from genie.module.video import CausalConv3d 4 | from genie.module.video import CausalConvTranspose3d 5 | from genie.module.video import DepthToSpaceUpsample 6 | from genie.module.video import DepthToTimeUpsample 7 | from genie.module.video import SpaceTimeDownsample 8 | from genie.module.video import SpaceTimeUpsample 9 | from genie.module.video import VideoResidualBlock 10 | 11 | class TestVideoModule(unittest.TestCase): 12 | def test_causal_conv3d(self): 13 | # Create a CausalConv3d instance 14 | conv = CausalConv3d(3, 64, kernel_size=3) 15 | 16 | # Create a random input tensor 17 | inp = torch.randn(1, 3, 16, 16, 16) 18 | 19 | # Perform forward pass 20 | out = conv(inp) 21 | 22 | # Check output shape 23 | self.assertEqual(out.shape, (1, 64, 16, 16, 16)) 24 | 25 | def test_causal_conv_transpose3d(self): 26 | # Create a CausalConvTranspose3d instance 27 | conv_transpose = CausalConvTranspose3d(64, 3, kernel_size=3) 28 | 29 | # Create a random input tensor 30 | inp = torch.randn(1, 64, 16, 16, 16) 31 | 32 | # Perform forward pass 33 | out = conv_transpose(inp) 34 | 35 | # Check output shape 36 | self.assertEqual(out.shape, (1, 3, 16, 16, 16)) 37 | 38 | def test_space_upsample(self): 39 | # Create a SpaceUpsample instance 40 | upsample = DepthToSpaceUpsample(64, factor=2) 41 | 42 | # Create a random input tensor 43 | inp = torch.randn(1, 64, 8, 16, 16) 44 | 45 | # Perform forward pass 46 | out = upsample(inp) 47 | 48 | # Check output shape 49 | self.assertEqual(out.shape, (1, 64, 8, 32, 32)) 50 | 51 | def test_time_upsample(self): 52 | # Create a TimeUpsample instance 53 | upsample = DepthToTimeUpsample(64, factor=2) 54 | 55 | # Create a random input tensor 56 | inp = torch.randn(1, 64, 8, 16, 16) 57 | 58 | # Perform forward pass 59 | out = upsample(inp) 60 | 61 | # Check output shape 62 | 
self.assertEqual(out.shape, (1, 64, 16, 16, 16)) 63 | 64 | def test_space_time_downsample(self): 65 | # Create a SpaceTimeDownsample instance 66 | downsample = SpaceTimeDownsample( 67 | in_channels=64, 68 | kernel_size=3, 69 | out_channels=128, 70 | time_factor=2, 71 | space_factor=2 72 | ) 73 | 74 | # Create a random input tensor 75 | inp = torch.randn(1, 64, 16, 28, 28) 76 | 77 | # Perform forward pass 78 | out = downsample(inp) 79 | 80 | # Check output shape 81 | self.assertEqual(out.shape, (1, 128, 8, 14, 14)) 82 | 83 | def test_space_time_upsample(self): 84 | # Create a SpaceTimeUpsample instance 85 | upsample = SpaceTimeUpsample(128, 64, time_factor=2, space_factor=2) 86 | 87 | # Create a random input tensor 88 | inp = torch.randn(1, 128, 8, 7, 7) 89 | 90 | # Perform forward pass 91 | out = upsample(inp) 92 | 93 | # Check output shape 94 | self.assertEqual(out.shape, (1, 64, 16, 14, 14)) 95 | 96 | def test_residual_block(self): 97 | # Create a ResidualBlock instance 98 | block = VideoResidualBlock( 99 | in_channels=64, 100 | out_channels=128, 101 | ) 102 | 103 | # Create a random input tensor 104 | inp = torch.randn(1, 64, 8, 16, 16) 105 | 106 | # Perform forward pass 107 | out = block(inp) 108 | 109 | # Check output shape 110 | self.assertEqual(out.shape, (1, 128, 8, 16, 16)) 111 | 112 | def test_residual_block_causal(self): 113 | # Create a ResidualBlock instance 114 | block = VideoResidualBlock( 115 | in_channels=64, 116 | out_channels=128, 117 | num_groups=2, 118 | use_causal=True, 119 | ) 120 | 121 | # Create a random input tensor 122 | inp = torch.randn(1, 64, 8, 16, 16) 123 | 124 | # Perform forward pass 125 | out = block(inp) 126 | 127 | # Check output shape 128 | self.assertEqual(out.shape, (1, 128, 8, 16, 16)) 129 | 130 | def test_residual_block_downsample(self): 131 | # Create a ResidualBlock instance 132 | block = VideoResidualBlock( 133 | in_channels=64, 134 | out_channels=128, 135 | downsample=(2, 4), 136 | act_fn='leaky', 137 | ) 138 | 139 | # Create a random input tensor 140 | inp = torch.randn(1, 64, 8, 16, 16) 141 | 142 | # Perform forward pass 143 | out = block(inp) 144 | 145 | # Check output shape 146 | self.assertEqual(out.shape, (1, 128, 4, 4, 4)) 147 | 148 | def test_residual_block_causal_downsample(self): 149 | # Create a ResidualBlock instance 150 | block = VideoResidualBlock( 151 | in_channels=64, 152 | out_channels=128, 153 | num_groups=2, 154 | use_causal=True, 155 | act_fn='leaky', 156 | downsample=(2, 4), 157 | ) 158 | 159 | # Create a random input tensor 160 | inp = torch.randn(1, 64, 8, 16, 16) 161 | 162 | # Perform forward pass 163 | out = block(inp) 164 | 165 | # Check output shape 166 | self.assertEqual(out.shape, (1, 128, 4, 4, 4)) 167 | 168 | if __name__ == '__main__': 169 | unittest.main() -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | from lightning.pytorch.cli import LightningCLI 2 | 3 | from genie import VideoTokenizer 4 | from genie.dataset import LightningPlatformer2D 5 | 6 | def cli_main(): 7 | ''' 8 | Main function for the training script. 9 | ''' 10 | 11 | # That's all it takes for LightningCLI to work! 12 | # No need to call .fit() or .test() or anything like that. 13 | cli = LightningCLI( 14 | VideoTokenizer, 15 | LightningPlatformer2D, 16 | ) 17 | 18 | if __name__ == '__main__': 19 | cli_main() --------------------------------------------------------------------------------