├── LICENSE ├── README.md ├── __init__.py ├── configs ├── 3D_ENV_BENCHMARK │ ├── dmlab │ │ ├── dmlab_S5.yaml │ │ ├── dmlab_S5_eval.yaml │ │ ├── dmlab_convs5.yaml │ │ ├── dmlab_convs5_eval.yaml │ │ ├── dmlab_teco_S5.yaml │ │ ├── dmlab_teco_S5_eval.yaml │ │ ├── dmlab_teco_convS5.yaml │ │ ├── dmlab_teco_convS5_eval.yaml │ │ ├── dmlab_teco_transformer.yaml │ │ ├── dmlab_teco_transformer_eval.yaml │ │ ├── dmlab_transformer.yaml │ │ └── dmlab_transformer_eval.yaml │ ├── habitat │ │ ├── habitat_convS5.yaml │ │ └── habitat_teco_convS5.yaml │ └── minecraft │ │ ├── minecraft_convS5.yaml │ │ └── minecraft_teco_convS5.yaml └── Moving-MNIST │ ├── 300_train_len │ ├── mnist_convLSTM_novq.yaml │ ├── mnist_convLSTM_novq_eval.yaml │ ├── mnist_convS5_novq.yaml │ ├── mnist_convS5_novq_eval.yaml │ ├── mnist_noVQ_transformer.yaml │ └── mnist_noVQ_transformer_eval.yaml │ └── 600_train_len │ ├── mnist_convLSTM_novq.yaml │ ├── mnist_convLSTM_novq_eval.yaml │ ├── mnist_convS5_novq.yaml │ ├── mnist_convS5_novq_eval.yaml │ ├── mnist_noVQ_transformer.yaml │ └── mnist_noVQ_transformer_eval.yaml ├── data └── moving-mnist-pytorch │ └── moving_mnist.py ├── figs └── convssm_50.png ├── requirements.txt ├── scripts ├── __init__.py ├── compute_fvd.py ├── compute_fvd_mnist.py ├── compute_metrics.py ├── download │ ├── dmlab.sh │ ├── dmlab_encoded.sh │ ├── habitat.sh │ ├── habitat_encoded.sh │ ├── kinetics600_encoded.sh │ ├── minecraft.sh │ └── minecraft_encoded.sh ├── eval.py └── train.py ├── setup.py └── src ├── __init__.py ├── data.py ├── fvd.py ├── fvd_mnist.py ├── metrics.py ├── models ├── S5 │ ├── __init__.py │ ├── diagonal_scans.py │ ├── diagonal_ssm.py │ └── layers.py ├── __init__.py ├── base.py ├── convLSTM │ ├── __init__.py │ ├── conv_ops.py │ ├── layers.py │ ├── scans.py │ └── ssm.py ├── convS5 │ ├── __init__.py │ ├── conv_ops.py │ ├── diagonal_scans.py │ ├── diagonal_ssm.py │ └── layers.py ├── sampling │ ├── __init__.py │ ├── sample_convSSM.py │ ├── sample_convSSM_noVQ.py │ ├── sample_transformer.py │ └── sample_transformer_noVQ.py ├── sequence_models │ ├── VQ │ │ ├── S5.py │ │ ├── __init__.py │ │ ├── convS5.py │ │ ├── teco_S5.py │ │ ├── teco_convS5.py │ │ ├── teco_transformer.py │ │ └── transformer.py │ ├── __init__.py │ └── noVQ │ │ ├── .ipynb_checkpoints │ │ └── transformer-checkpoint.py │ │ ├── __init__.py │ │ ├── convLSTM.py │ │ ├── convS5.py │ │ └── transformer.py ├── transformer │ ├── __init__.py │ ├── favor.py │ ├── maskgit.py │ └── transformer.py ├── vae.py └── vqgan.py ├── runtime_metrics.py ├── train_utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | 3 | NVIDIA Source Code License for Convolutional State Space Models for Long-Range Spatiotemporal Modeling 4 | 5 | ======================================================================= 6 | 7 | 1. Definitions 8 | 9 | “Licensor” means any person or entity that distributes its Work. 10 | 11 | “Work” means (a) the original work of authorship made available under 12 | this license, which may include software, documentation, or other files, 13 | and (b) any additions to or derivative works thereof that are made 14 | available under this license. 15 | 16 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” 17 | have the meaning as provided under U.S. 
copyright law; provided, however, 18 | that for the purposes of this license, derivative works shall not include works 19 | that remain separable from, or merely link (or bind by name) to the 20 | interfaces of, the Work. 21 | 22 | Works are “made available” under this license by including in or with the Work 23 | either (a) a copyright notice referencing the applicability of 24 | this license to the Work, or (b) a copy of this license. 25 | 26 | 2. License Grant 27 | 28 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each 29 | Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, 30 | copyright license to use, reproduce, prepare derivative works of, publicly display, 31 | publicly perform, sublicense and distribute its Work and any resulting derivative 32 | works in any form. 33 | 34 | 3. Limitations 35 | 36 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under 37 | this license, (b) you include a complete copy of this license with your distribution, 38 | and (c) you retain without modification any copyright, patent, trademark, or 39 | attribution notices that are present in the Work. 40 | 41 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, 42 | reproduction, and distribution of your derivative works of the Work (“Your Terms”) only 43 | if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative 44 | works, and (b) you identify the specific derivative works that are subject to Your Terms. 45 | Notwithstanding Your Terms, this license (including the redistribution requirements in 46 | Section 3.1) will continue to apply to the Work itself. 47 | 48 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or 49 | intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation 50 | and its affiliates may use the Work and any derivative works commercially. 51 | As used herein, “non-commercially” means for research or evaluation purposes only. 52 | 53 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor 54 | (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that 55 | you allege are infringed by any Work, then your rights under this license from 56 | such Licensor (including the grant in Section 2.1) will terminate immediately. 57 | 58 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its 59 | affiliates’ names, logos, or trademarks, except as necessary to reproduce 60 | the notices described in this license. 61 | 62 | 3.6 Termination. If you violate any term of this license, then your rights under 63 | this license (including the grant in Section 2.1) will terminate immediately. 64 | 65 | 4. Disclaimer of Warranty. 66 | 67 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, 68 | EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 69 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. 70 | YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 71 | 72 | 5. Limitation of Liability. 
 73 | 
 74 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY,
 75 | WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR
 76 | BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL,
 77 | OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR
 78 | INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS
 79 | INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY
 80 | OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE
 81 | POSSIBILITY OF SUCH DAMAGES.
 82 | 
 83 | =======================================================================
 84 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Convolutional State Space Models for Long-Range Spatiotemporal Modeling
  2 | 
  3 | This repository provides the official JAX implementation for the
  4 | paper:
  5 | 
  6 | **Convolutional State Space Models for Long-Range Spatiotemporal Modeling** [[arXiv]](https://arxiv.org/abs/2310.19694)
  7 | 
  8 | [Jimmy T.H. Smith](https://jimmysmith1919.github.io/),
  9 | [Shalini De Mello](https://research.nvidia.com/person/shalini-de-mello),
 10 | [Jan Kautz](https://jankautz.com),
 11 | [Scott Linderman](https://web.stanford.edu/~swl1/),
 12 | [Wonmin Byeon](https://wonmin-byeon.github.io/),
 13 | NeurIPS 2023.
 14 | 
 15 | 
 16 | For business inquiries, please visit the NVIDIA website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
 17 | 
 18 | ---
 19 | 
 20 | We introduce an efficient long-range spatiotemporal sequence modeling method, **ConvSSM**. It is parallelizable and overcomes major limitations of traditional ConvRNNs (e.g., vanishing/exploding gradients) while providing an unbounded context and faster autoregressive generation than Transformers. It performs similarly to or better than Transformers/ConvLSTM on long-horizon video prediction tasks, trains up to 3× faster than ConvLSTM, and generates samples up to 400× faster than Transformers. We provide results for the long-horizon Moving-MNIST generation task and the long-range 3D environment benchmarks (DMLab, Minecraft, and Habitat).
 21 | 
 22 | ![teaser](figs/convssm_50.png)
 23 | 
 24 | The repository builds on the training pipeline from [TECO](https://github.com/wilson1yan/teco).
 25 | 
 26 | ---
 27 | 
 28 | ### Installation
 29 | You will need to install JAX following the instructions [here](https://jax.readthedocs.io/en/latest/installation.html).
 30 | We used JAX version 0.3.21.
 31 | ```commandline
 32 | pip install --upgrade jax[cuda]==0.3.21 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
 33 | ```
 34 | 
 35 | Then install the rest of the dependencies with:
 36 | ```commandline
 37 | sudo apt-get update && sudo apt-get install -y ffmpeg
 38 | pip install -r requirements.txt
 39 | pip install -e .
 40 | ```
 41 | 
 42 | 
 43 | ---
 44 | 
 45 | ### Datasets
 46 | For `Moving-MNIST`:
 47 | 
 48 | 1) Download the MNIST binary file.
 49 | ```commandline
 50 | wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -O data/moving-mnist-pytorch/train-images-idx3-ubyte.gz
 51 | ```
 52 | 2) Use the script in `data/moving-mnist-pytorch` to generate the Moving MNIST data.
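The idea behind that script is the standard Moving MNIST recipe: paste MNIST digits onto a 64×64 canvas and let them bounce off the edges for the desired number of frames. Below is a minimal NumPy sketch of that recipe; it is not the repo's `moving_mnist.py` script, and the digit count, velocity range, sequence length, and output file name are illustrative assumptions only.

```python
# Minimal Moving MNIST-style generation sketch (illustrative; not the repo's moving_mnist.py).
import gzip
import numpy as np


def load_mnist(path="data/moving-mnist-pytorch/train-images-idx3-ubyte.gz"):
    # idx3-ubyte layout: 16-byte header, then 28x28 uint8 images.
    with gzip.open(path, "rb") as f:
        data = np.frombuffer(f.read(), dtype=np.uint8, offset=16)
    return data.reshape(-1, 28, 28)


def make_sequence(digits, seq_len=300, canvas=64, num_digits=2, seed=0):
    rng = np.random.default_rng(seed)
    frames = np.zeros((seq_len, canvas, canvas), dtype=np.uint8)
    lim = canvas - 28  # digits are 28x28
    for _ in range(num_digits):
        img = digits[rng.integers(len(digits))]
        x, y = rng.uniform(0, lim, size=2)
        vx, vy = rng.uniform(-3, 3, size=2)
        for t in range(seq_len):
            # Reflect the velocity when the digit would leave the canvas.
            if not 0 <= x + vx <= lim:
                vx = -vx
            if not 0 <= y + vy <= lim:
                vy = -vy
            x, y = x + vx, y + vy
            xi, yi = int(x), int(y)
            # Composite the digit onto the frame with a pixel-wise max.
            frames[t, yi:yi + 28, xi:xi + 28] = np.maximum(
                frames[t, yi:yi + 28, xi:xi + 28], img)
    return frames  # (seq_len, canvas, canvas) uint8 video


if __name__ == "__main__":
    digits = load_mnist()
    video = make_sequence(digits)
    np.savez_compressed("moving_mnist_example.npz", video=video)
```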
 53 | 
 54 | For 3D Environment tasks:
 55 | 
 56 | We used the scripts from the [TECO](https://github.com/wilson1yan/teco) repository to download the datasets: [`DMLab`](https://github.com/wilson1yan/teco/blob/master/scripts/download/dmlab.sh) and
 57 | [`Habitat`](https://github.com/wilson1yan/teco/blob/master/scripts/download/habitat.sh). Check the TECO repository for the details of the datasets.
 58 | 
 59 | 
 60 | The data should be split into 'train' and 'test' folders.
 61 | 
 62 | ---
 63 | 
 64 | ### Pretrained VQ-GANs:
 65 | Pretrained VQ-GAN checkpoints for each dataset can be found [here](https://drive.google.com/drive/folders/10hAqVjoxte9OxYc7WIih_5OtwbdOxKoi). Note that these are also from [TECO](https://github.com/wilson1yan/teco).
 66 | 
 67 | ---
 68 | 
 69 | ### Pretrained ConvS5 checkpoints:
 70 | Pretrained ConvS5 checkpoints for each dataset can be found [here](https://www.dropbox.com/scl/fo/h3omm0bc3dau9uh9cgrq0/AICA1umpuN1LRG_MRwUyPWU?rlkey=s9w4d3ncsfz39n2r390dpbsk2&st=v722uk6x&dl=0). Download the checkpoints into the corresponding checkpoint directories.
 71 | Default checkpoint directory: `logs/<output_dir>/checkpoints/`
 72 | 
 73 | | dataset | checkpoint | config |
 74 | |:---:|:---:|:---:|
 75 | | Moving-MNIST 300 | [link](https://www.dropbox.com/scl/fo/wg6f4cazhlw5cs3fjfapf/AERdhoK8HARlwwlGvu8ZpLY?rlkey=spq6umv7m2scywntwgqswxvys&st=hpbemnp2&dl=0) | `Moving-MNIST/300_train_len/mnist_convS5_novq.yaml` |
 76 | | Moving-MNIST 600 | [link](https://www.dropbox.com/scl/fo/1vog37ntlr67084o6qbpm/AD8B3ZZIek9pxDvb80rhd4k?rlkey=krp36u6zbu8nac4foml9f3bq1&st=hmuymguf&dl=0) | `Moving-MNIST/600_train_len/mnist_convS5_novq.yaml` |
 77 | | DMLab | [link](https://www.dropbox.com/scl/fo/dcy9nhw0umbowang36po1/AKPtYSxP2ynJnUDvoZsdxqc?rlkey=cw7bas02w2mw7ldyephu3w9r7&st=20wwyamh&dl=0) | `3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml` |
 78 | | Habitat | [link](https://www.dropbox.com/scl/fo/6k6tchauqaguilkr8rb7c/ACyEPN_X00f1xWM_RQFyDF8?rlkey=gx5o11o9n5npfj09gxac8hq2p&st=9k23lyfk&dl=0) | `3D_ENV_BENCHMARK/habitat/habitat_teco_convS5.yaml` |
 79 | | Minecraft | [link](https://www.dropbox.com/scl/fo/c4g2ol85hbt58kveoklek/AJvYilfaNdaap1V89u5Z5Oo?rlkey=1pv457c6bal7t2pisqx20s51i&st=8ft9r7d0&dl=0) | `3D_ENV_BENCHMARK/minecraft/minecraft_teco_convS5.yaml` |
 80 | 
 81 | ---
 82 | 
 83 | ### Training
 84 | Before training, you will need to update the paths in the corresponding config files to point to your dataset and VQ-GAN directories.
 85 | 
 86 | To train, run:
 87 | `python scripts/train.py -d <data_dir> -o <output_dir> -c <config_file>`
 88 | 
 89 | Example for training ConvS5 on DMLab:
 90 | ```commandline
 91 | python scripts/train.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml
 92 | ```
 93 | 
 94 | Note: we used only data-parallel training for our experiments. Model-parallel training would require implementing JAX [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) or [pjit/jit](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html). See [this](https://github.com/wilson1yan/teco/tree/master/teco/models/xmap) folder in the TECO repo for an example using xmap.
 95 | 
 96 | Our runs were performed in a multi-node NVIDIA V100 32GB GPU environment.
 97 | 
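To make the data-parallel note concrete, here is a minimal, generic sketch of a data-parallel training step built with `jax.pmap`. It is not the trainer in `scripts/train.py`: the `model.apply` call, the squared-error loss, the use of `optax`, and the batch layout are placeholder assumptions for illustration only.

```python
# Generic jax.pmap data-parallel training step (illustrative sketch, not the repo's trainer).
import functools
import jax
import jax.numpy as jnp
import optax  # assumed optimizer library; the repo's actual choice may differ


def make_train_step(model, optimizer):
    def loss_fn(params, batch):
        pred = model.apply(params, batch["video"])      # placeholder forward pass
        return jnp.mean((pred - batch["video"]) ** 2)   # placeholder loss

    @functools.partial(jax.pmap, axis_name="batch")
    def train_step(params, opt_state, batch):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        # Average the loss and gradients across devices: this is the data-parallel step.
        loss = jax.lax.pmean(loss, axis_name="batch")
        grads = jax.lax.pmean(grads, axis_name="batch")
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss

    return train_step


# Usage sketch: replicate params/opt_state across local devices and reshape each batch
# to [num_devices, per_device_batch, ...] before calling train_step, e.g.
# params = jax.device_put_replicated(params, jax.local_devices())
```

Model-parallel training (via `xmap` or `pjit`) would instead shard the model parameters and computation themselves across devices, rather than replicating them.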
 98 | ---
 99 | 
100 | ### Evaluation
101 | To evaluate, run:
102 | `python scripts/eval.py -d <data_dir> -o <output_dir> -c <eval_config_file>`
103 | 
104 | Example for evaluating ConvS5 on DMLab:
105 | ```commandline
106 | python scripts/eval.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5_eval.yaml
107 | ```
108 | 
109 | This will perform the sampling required for computing the different evaluation metrics. The sampled videos will be saved as `.npz` files.
110 | 
111 | For FVD evaluations, run: `python scripts/compute_fvd.py <samples_dir>`
112 | 
113 | Example for ConvS5 on DMLab:
114 | ```commandline
115 | python scripts/compute_fvd.py logs/dmlab_convs5/samples_36
116 | ```
117 | 
118 | For PSNR, SSIM, and LPIPS, run: `python scripts/compute_metrics.py <samples_dir>`
119 | 
120 | Example for ConvS5 on DMLab:
121 | ```commandline
122 | python scripts/compute_metrics.py logs/dmlab_convs5/samples_action_144
123 | ```
124 | 
125 | ---
126 | 
127 | ### Citation
128 | Please use the following when citing our work:
129 | 
130 | ```BiBTeX
131 | @inproceedings{
132 | smith2023convolutional,
133 | title={Convolutional State Space Models for Long-Range Spatiotemporal Modeling},
134 | author={Jimmy T.H. Smith and Shalini De Mello and Jan Kautz and Scott Linderman and Wonmin Byeon},
135 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
136 | year={2023},
137 | url={https://openreview.net/forum?id=1ZvEtnrHS1}
138 | }
139 | ```
140 | 
141 | ---
142 | 
143 | ### License
144 | Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
145 | See LICENSE file for details.
146 | 
147 | 
148 | Please reach out if you have any questions.
149 | 
150 | -- The ConvS5 authors.
151 | 
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/__init__.py
--------------------------------------------------------------------------------
/configs/3D_ENV_BENCHMARK/dmlab/dmlab_S5.yaml:
--------------------------------------------------------------------------------
  1 | seed: 1
  2 | cache: false # caching available only for encoded datasets
  3 | 
  4 | # Training
  5 | multinode: True
  6 | batch_size: 16
  7 | eval_size: 512
  8 | num_workers: 4
  9 | lr: 0.001
 10 | lr_schedule: "cosine"
 11 | weight_decay: 0.00001
 12 | total_steps: 250000
 13 | warmup_steps: 5000
 14 | save_interval: 2000
 15 | viz_interval: 2000
 16 | log_interval: 100
 17 | 
 18 | # Data
 19 | data_path: "/raid/dmlab"
 20 | eval_seq_len: 300
 21 | seq_len: 300
 22 | image_size: 64
 23 | channels: 3
 24 | 
 25 | num_shards: null
 26 | rng_keys: ["dropout", "sample"]
 27 | batch_keys: ["video", "actions"]
 28 | 
 29 | # Model
 30 | model: "S5"
 31 | vqvae_ckpt: "./dmlab_vqgan"
 32 | 
 33 | 
 34 | encoder: # encoder / decoder are mirrored, with decoder depths reversed
 35 |   depths: [256] # 16x16
 36 |   blocks: 1 #blocks doesn't do anything here
 37 | 
 38 | decoder: # encoder / decoder are mirrored, with decoder depths reversed
 39 |   depths: [512] # 16x16
 40 |   blocks: 4
 41 | 
 42 | z_ds: 16 # 16x16 -> 1x1
 43 | 
 44 | d_model: 1024
 45 | 
 46 | # Sequence Model
 47 | seq_model:
 48 |   n_layers: 8
 49 |   dropout: 0.0
 50 |   use_norm: True
 51 |   prenorm: False
 52 |   per_layer_skip: True
 53 |   skip_connections: False
 54 | 
 55 | # SSM
 56 | ssm:
 57 |   ssm_size: 1024
 58 |   blocks: 1
 59 |   clip_eigs: False
 60 |   dt_min: 0.001
 61 |   dt_max: 0.1
 62 | 
 63 | embedding_dim: 512
 64 | n_cond: 1
 65 | 
 66 | # Causal Masking
 67 | causal_masking: True
 68 | frame_mask_id: -2
 69 | 
 70 | # Actions
 71 | use_actions: true
 72 | action_dim: 6
 73 | action_embed_dim: 16
 74 | dropout_actions: true
 75 | action_dropout_rate: 0.5
 76 | action_mask_id: -1
 77 | 
 78 | # Sampling
 79 | open_loop_ctx: 36
 80 | 
 81 | open_loop_ctx_1: 144
 82 | action_conditioned_1: True
 83 | open_loop_ctx_2: 36
 84 | action_conditioned_2: False
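The `dmlab_S5.yaml` config above illustrates the schema shared by the other configs: training hyperparameters, the `data_path` and `vqvae_ckpt` locations, and the model/SSM settings. As noted in the Training section of the README, `data_path` and `vqvae_ckpt` must point to your own dataset and VQ-GAN directories. A minimal PyYAML sketch of doing this programmatically is shown below; the local paths and the `_local.yaml` output name are placeholders, and editing the YAML by hand works just as well.

```python
# Point a config at local dataset / VQ-GAN directories (sketch; paths are placeholders).
import yaml

cfg_path = "configs/3D_ENV_BENCHMARK/dmlab/dmlab_S5.yaml"
with open(cfg_path) as f:
    cfg = yaml.safe_load(f)

cfg["data_path"] = "/my/datasets/dmlab"            # directory containing train/ and test/
cfg["vqvae_ckpt"] = "/my/checkpoints/dmlab_vqgan"  # pretrained VQ-GAN checkpoint directory

with open(cfg_path.replace(".yaml", "_local.yaml"), "w") as f:
    yaml.safe_dump(cfg, f, sort_keys=False)
```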
-------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_S5_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 250000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "S5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | 34 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 35 | depths: [256] # 16x16 36 | blocks: 1 #blocks doesn't do anything here 37 | 38 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 39 | depths: [512] # 16x16 40 | blocks: 4 41 | 42 | z_ds: 16 # 16x16 -> 1x1 43 | 44 | d_model: 1024 45 | 46 | # Sequence Model 47 | seq_model: 48 | n_layers: 8 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 1024 58 | blocks: 1 59 | clip_eigs: False 60 | dt_min: 0.001 61 | dt_max: 0.1 62 | 63 | embedding_dim: 512 64 | n_cond: 1 65 | 66 | # Causal Masking 67 | causal_masking: True 68 | frame_mask_id: -2 69 | 70 | # Actions 71 | use_actions: true 72 | action_dim: 6 73 | action_embed_dim: 16 74 | dropout_actions: true 75 | action_dropout_rate: 0.5 76 | action_mask_id: -1 77 | 78 | # Sampling 79 | open_loop_ctx: 36 80 | 81 | open_loop_ctx_1: 144 82 | action_conditioned_1: True 83 | open_loop_ctx_2: 36 84 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 250000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "convS5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256] # 16x16 35 | blocks: 1 #blocks doesn't do anything here 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [512] # 16x16 39 | blocks: 4 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 8 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 512 58 | blocks: 1 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 
64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | 68 | latent_height: 16 69 | latent_width: 16 70 | 71 | n_cond: 1 72 | drop_loss_rate: 0.0 73 | 74 | # Causal Masking 75 | causal_masking: True 76 | frame_mask_id: -2 77 | 78 | # Actions 79 | use_actions: true 80 | action_dim: 6 81 | action_embed_dim: 16 82 | dropout_actions: true 83 | action_dropout_rate: 0.5 84 | action_mask_id: -1 85 | 86 | # Sampling 87 | open_loop_ctx: 36 88 | 89 | open_loop_ctx_1: 144 90 | action_conditioned_1: True 91 | open_loop_ctx_2: 36 92 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 250000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "convS5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256] # 16x16 35 | blocks: 1 #blocks doesn't do anything here 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [512] # 16x16 39 | blocks: 4 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 8 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 512 58 | blocks: 1 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | 68 | latent_height: 16 69 | latent_width: 16 70 | 71 | n_cond: 1 72 | drop_loss_rate: 0.0 73 | 74 | # Causal Masking 75 | causal_masking: True 76 | frame_mask_id: -2 77 | 78 | # Actions 79 | use_actions: true 80 | action_dim: 6 81 | action_embed_dim: 16 82 | dropout_actions: true 83 | action_dropout_rate: 0.5 84 | action_mask_id: -1 85 | 86 | # Sampling 87 | open_loop_ctx: 36 88 | 89 | open_loop_ctx_1: 144 90 | action_conditioned_1: True 91 | open_loop_ctx_2: 36 92 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_S5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_S5" 31 | 
vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 8 # 8x8 -> 1x1 42 | 43 | d_model: 2048 44 | # Sequence Model 45 | seq_model: 46 | n_layers: 8 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | skip_connections: False 52 | 53 | # SSM 54 | ssm: 55 | ssm_size: 2048 56 | blocks: 1 57 | clip_eigs: False 58 | dt_min: 0.001 59 | dt_max: 0.1 60 | 61 | z_git: 62 | vocab_dim: 256 63 | mask_schedule: "cosine" 64 | tfm_kwargs: 65 | embed_dim: 512 66 | mlp_dim: 2048 67 | num_heads: 8 68 | num_layers: 8 69 | dropout: 0. 70 | attention_dropout: 0. 71 | 72 | embedding_dim: 64 73 | codebook: 74 | n_codes: 1024 75 | proj_dim: 32 76 | 77 | n_cond: 1 78 | drop_loss_rate: 0.9 79 | 80 | # Causal Masking 81 | causal_masking: True 82 | frame_mask_id: -2 83 | 84 | # Actions 85 | use_actions: true 86 | action_dim: 6 87 | action_embed_dim: 16 88 | dropout_actions: true 89 | action_dropout_rate: 0.5 90 | action_mask_id: -1 91 | 92 | # Sampling 93 | T_draft: 8 94 | T_revise: 8 95 | M: 2 96 | open_loop_ctx: 36 97 | 98 | open_loop_ctx_1: 144 99 | action_conditioned_1: True 100 | open_loop_ctx_2: 36 101 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_S5_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_S5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 8 # 8x8 -> 1x1 42 | 43 | d_model: 2048 44 | # Sequence Model 45 | seq_model: 46 | n_layers: 8 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | skip_connections: False 52 | 53 | # SSM 54 | ssm: 55 | ssm_size: 2048 56 | blocks: 1 57 | clip_eigs: False 58 | dt_min: 0.001 59 | dt_max: 0.1 60 | 61 | z_git: 62 | vocab_dim: 256 63 | mask_schedule: "cosine" 64 | tfm_kwargs: 65 | embed_dim: 512 66 | mlp_dim: 2048 67 | num_heads: 8 68 | num_layers: 8 69 | dropout: 0. 70 | attention_dropout: 0. 
71 | 72 | embedding_dim: 64 73 | codebook: 74 | n_codes: 1024 75 | proj_dim: 32 76 | 77 | n_cond: 1 78 | drop_loss_rate: 0.9 79 | 80 | # Causal Masking 81 | causal_masking: True 82 | frame_mask_id: -2 83 | 84 | # Actions 85 | use_actions: true 86 | action_dim: 6 87 | action_embed_dim: 16 88 | dropout_actions: true 89 | action_dropout_rate: 0.5 90 | action_mask_id: -1 91 | 92 | # Sampling 93 | T_draft: 8 94 | T_revise: 8 95 | M: 2 96 | open_loop_ctx: 36 97 | 98 | open_loop_ctx_1: 144 99 | action_conditioned_1: True 100 | open_loop_ctx_2: 36 101 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_convS5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_convS5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 8 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 1024 58 | blocks: 32 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | #prior 68 | z_git: 69 | vocab_dim: 256 70 | mask_schedule: "cosine" 71 | tfm_kwargs: 72 | embed_dim: 512 73 | mlp_dim: 2048 74 | num_heads: 8 75 | num_layers: 8 76 | dropout: 0. 77 | attention_dropout: 0. 
78 | 79 | embedding_dim: 64 80 | codebook: 81 | n_codes: 1024 82 | proj_dim: 32 83 | 84 | latent_height: 8 85 | latent_width: 8 86 | 87 | n_cond: 1 88 | drop_loss_rate: 0.9 89 | 90 | # Causal Masking 91 | causal_masking: True 92 | frame_mask_id: -2 93 | 94 | # Actions 95 | use_actions: true 96 | action_dim: 6 97 | action_embed_dim: 16 98 | dropout_actions: true 99 | action_dropout_rate: 0.5 100 | action_mask_id: -1 101 | 102 | # Sampling 103 | T_draft: 8 104 | T_revise: 8 105 | M: 2 106 | 107 | open_loop_ctx: 36 108 | 109 | open_loop_ctx_1: 144 110 | action_conditioned_1: True 111 | open_loop_ctx_2: 36 112 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_convS5_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_convS5" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 8 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 1024 58 | blocks: 32 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | #prior 68 | z_git: 69 | vocab_dim: 256 70 | mask_schedule: "cosine" 71 | tfm_kwargs: 72 | embed_dim: 512 73 | mlp_dim: 2048 74 | num_heads: 8 75 | num_layers: 8 76 | dropout: 0. 77 | attention_dropout: 0. 
78 | 79 | embedding_dim: 64 80 | codebook: 81 | n_codes: 1024 82 | proj_dim: 32 83 | 84 | latent_height: 8 85 | latent_width: 8 86 | 87 | n_cond: 1 88 | drop_loss_rate: 0.9 89 | 90 | # Causal Masking 91 | causal_masking: True 92 | frame_mask_id: -2 93 | 94 | # Actions 95 | use_actions: true 96 | action_dim: 6 97 | action_embed_dim: 16 98 | dropout_actions: true 99 | action_dropout_rate: 0.5 100 | action_mask_id: -1 101 | 102 | # Sampling 103 | T_draft: 8 104 | T_revise: 8 105 | M: 2 106 | 107 | open_loop_ctx: 36 108 | 109 | open_loop_ctx_1: 144 110 | action_conditioned_1: True 111 | open_loop_ctx_2: 36 112 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_transformer.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_transformer" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 8 # 8x8 -> 1x1 42 | z_tfm_kwargs: 43 | embed_dim: 1024 44 | mlp_dim: 4096 45 | num_heads: 16 46 | num_layers: 8 47 | dropout: 0. 48 | attention_dropout: 0. 49 | 50 | z_git: 51 | vocab_dim: 256 52 | mask_schedule: "cosine" 53 | tfm_kwargs: 54 | embed_dim: 512 55 | mlp_dim: 2048 56 | num_heads: 8 57 | num_layers: 8 58 | dropout: 0. 59 | attention_dropout: 0. 
60 | 61 | embedding_dim: 64 62 | codebook: 63 | n_codes: 1024 64 | proj_dim: 32 65 | 66 | n_cond: 1 67 | drop_loss_rate: 0.9 68 | 69 | # Causal Masking 70 | causal_masking: True 71 | frame_mask_id: -2 72 | 73 | # Actions 74 | use_actions: true 75 | action_dim: 6 76 | action_embed_dim: 16 77 | dropout_actions: true 78 | action_dropout_rate: 0.5 79 | action_mask_id: -1 80 | 81 | # Sampling 82 | T_draft: 8 83 | T_revise: 8 84 | M: 2 85 | open_loop_ctx: 36 86 | 87 | open_loop_ctx_1: 144 88 | action_conditioned_1: True 89 | open_loop_ctx_2: 36 90 | action_conditioned_2: False 91 | -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_teco_transformer_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_transformer" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 2 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 8 # 8x8 -> 1x1 42 | z_tfm_kwargs: 43 | embed_dim: 1024 44 | mlp_dim: 4096 45 | num_heads: 16 46 | num_layers: 8 47 | dropout: 0. 48 | attention_dropout: 0. 49 | 50 | z_git: 51 | vocab_dim: 256 52 | mask_schedule: "cosine" 53 | tfm_kwargs: 54 | embed_dim: 512 55 | mlp_dim: 2048 56 | num_heads: 8 57 | num_layers: 8 58 | dropout: 0. 59 | attention_dropout: 0. 
60 | 61 | embedding_dim: 64 62 | codebook: 63 | n_codes: 1024 64 | proj_dim: 32 65 | 66 | n_cond: 1 67 | drop_loss_rate: 0.9 68 | 69 | # Causal Masking 70 | causal_masking: True 71 | frame_mask_id: -2 72 | 73 | # Actions 74 | use_actions: true 75 | action_dim: 6 76 | action_embed_dim: 16 77 | dropout_actions: true 78 | action_dropout_rate: 0.5 79 | action_mask_id: -1 80 | 81 | # Sampling 82 | T_draft: 8 83 | T_revise: 8 84 | M: 2 85 | open_loop_ctx: 36 86 | 87 | open_loop_ctx_1: 144 88 | action_conditioned_1: True 89 | open_loop_ctx_2: 36 90 | action_conditioned_2: False 91 | -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_transformer.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 250000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "transformer" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256] # 16x16 -> 8x8 35 | blocks: 1 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 16 # 16x16 -> 1x1 42 | z_tfm_kwargs: 43 | embed_dim: 512 44 | mlp_dim: 2048 45 | num_heads: 16 46 | num_layers: 8 47 | dropout: 0. 48 | attention_dropout: 0. 
49 | 50 | 51 | embedding_dim: 512 52 | n_cond: 1 53 | 54 | # Causal Masking 55 | causal_masking: True 56 | frame_mask_id: -2 57 | 58 | # Actions 59 | use_actions: true 60 | action_dim: 6 61 | action_embed_dim: 16 62 | dropout_actions: true 63 | action_dropout_rate: 0.5 64 | action_mask_id: -1 65 | 66 | # Sampling 67 | open_loop_ctx: 36 68 | 69 | open_loop_ctx_1: 144 70 | action_conditioned_1: True 71 | open_loop_ctx_2: 36 72 | action_conditioned_2: False 73 | -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/dmlab/dmlab_transformer_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 512 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 250000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/dmlab" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 64 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "transformer" 31 | vqvae_ckpt: "./dmlab_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256] # 16x16 -> 8x8 35 | blocks: 1 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [512] # 16x16 -> 8x8 39 | blocks: 4 40 | 41 | z_ds: 16 # 16x16 -> 1x1 42 | z_tfm_kwargs: 43 | embed_dim: 512 44 | mlp_dim: 2048 45 | num_heads: 16 46 | num_layers: 8 47 | dropout: 0. 48 | attention_dropout: 0. 
49 | 50 | 51 | embedding_dim: 512 52 | n_cond: 1 53 | 54 | # Causal Masking 55 | causal_masking: True 56 | frame_mask_id: -2 57 | 58 | # Actions 59 | use_actions: true 60 | action_dim: 6 61 | action_embed_dim: 16 62 | dropout_actions: true 63 | action_dropout_rate: 0.5 64 | action_mask_id: -1 65 | 66 | # Sampling 67 | open_loop_ctx: 36 68 | 69 | open_loop_ctx_1: 144 70 | action_conditioned_1: True 71 | open_loop_ctx_2: 36 72 | action_conditioned_2: False 73 | -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/habitat/habitat_convS5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/habitat" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 128 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "convS5" 31 | vqvae_ckpt: "./habitat_vqgan" 32 | 33 | use_encoder: True 34 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 35 | depths: [256] # 16x16 36 | blocks: 1 #blocks doesn't do anything here 37 | 38 | use_decoder: True 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [512] # 16x16 41 | blocks: 4 42 | 43 | d_model: 512 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 8 48 | layer_activation: "gelu" 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | num_groups: 32 54 | squeeze_excite: False 55 | skip_connections: False 56 | 57 | # SSM 58 | ssm: 59 | ssm_size: 512 60 | blocks: 1 61 | clip_eigs: False 62 | B_kernel_size: 3 63 | C_kernel_size: 3 64 | D_kernel_size: 3 65 | dt_min: 0.001 66 | dt_max: 0.1 67 | C_D_config: "resnet" 68 | 69 | 70 | latent_height: 16 71 | latent_width: 16 72 | 73 | n_cond: 1 74 | drop_loss_rate: 0.9 75 | 76 | # Causal Masking 77 | causal_masking: False 78 | frame_mask_id: None 79 | 80 | 81 | # Actions 82 | use_actions: true 83 | action_dim: 6 84 | action_embed_dim: 16 85 | dropout_actions: true 86 | action_dropout_rate: 0.5 87 | action_mask_id: -1 88 | 89 | # Sampling 90 | open_loop_ctx: 36 91 | 92 | open_loop_ctx_1: 144 93 | action_conditioned_1: True 94 | open_loop_ctx_2: 36 95 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/habitat/habitat_teco_convS5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 1000000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/habitat" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 128 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_convS5" 31 | vqvae_ckpt: 
"./habitat_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 4 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 8 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 8 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 512 58 | blocks: 1 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | #prior 68 | z_git: 69 | vocab_dim: 256 70 | mask_schedule: "cosine" 71 | tfm_kwargs: 72 | embed_dim: 1024 73 | mlp_dim: 4096 74 | num_heads: 16 75 | num_layers: 16 76 | dropout: 0. 77 | attention_dropout: 0. 78 | 79 | embedding_dim: 256 80 | codebook: 81 | n_codes: 2048 82 | proj_dim: 32 83 | 84 | latent_height: 8 85 | latent_width: 8 86 | 87 | n_cond: 1 88 | drop_loss_rate: 0.9 89 | 90 | # Causal Masking 91 | causal_masking: False 92 | frame_mask_id: None 93 | 94 | # Actions 95 | use_actions: true 96 | action_dim: 6 97 | action_embed_dim: 16 98 | dropout_actions: true 99 | action_dropout_rate: 0.5 100 | action_mask_id: -1 101 | 102 | # Sampling 103 | T_draft: 8 104 | T_revise: 8 105 | M: 2 106 | 107 | open_loop_ctx: 36 108 | 109 | open_loop_ctx_1: 144 110 | action_conditioned_1: True 111 | open_loop_ctx_2: 36 112 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/minecraft/minecraft_convS5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 500000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/minecraft" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 128 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "convS5" 31 | vqvae_ckpt: "./minecraft_vqgan" 32 | 33 | use_encoder: True 34 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 35 | depths: [256] # 16x16 36 | blocks: 1 #blocks doesn't do anything here 37 | 38 | use_decoder: True 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [512] # 16x16 41 | blocks: 4 42 | 43 | d_model: 512 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 12 48 | layer_activation: "gelu" 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | num_groups: 32 54 | squeeze_excite: False 55 | skip_connections: False 56 | 57 | # SSM 58 | ssm: 59 | ssm_size: 512 60 | blocks: 1 61 | clip_eigs: False 62 | B_kernel_size: 3 63 | C_kernel_size: 3 64 | D_kernel_size: 3 65 | dt_min: 0.001 66 | dt_max: 0.1 67 | C_D_config: "resnet" 68 | 69 | 70 | latent_height: 16 71 | latent_width: 16 72 | 73 | n_cond: 1 74 | drop_loss_rate: 0.9 75 | 76 | # Causal Masking 77 | causal_masking: False 78 | frame_mask_id: None 79 | 80 | # Actions 81 | 
use_actions: true 82 | action_dim: 6 83 | action_embed_dim: 16 84 | dropout_actions: false 85 | action_dropout_rate: 0.0 86 | action_mask_id: None 87 | 88 | # Sampling 89 | open_loop_ctx: 36 90 | 91 | open_loop_ctx_1: 144 92 | action_conditioned_1: True 93 | open_loop_ctx_2: 36 94 | action_conditioned_2: True -------------------------------------------------------------------------------- /configs/3D_ENV_BENCHMARK/minecraft/minecraft_teco_convS5.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: True 6 | batch_size: 16 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 1000000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/minecraft" 20 | eval_seq_len: 300 21 | seq_len: 300 22 | image_size: 128 23 | channels: 3 24 | 25 | num_shards: null 26 | rng_keys: ["dropout", "sample"] 27 | batch_keys: ["video", "actions"] 28 | 29 | # Model 30 | model: "teco_convS5" 31 | vqvae_ckpt: "./minecraft_vqgan" 32 | 33 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 34 | depths: [256, 512] # 16x16 -> 8x8 35 | blocks: 4 36 | 37 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 38 | depths: [256, 512] # 16x16 -> 8x8 39 | blocks: 6 40 | 41 | d_model: 512 42 | 43 | # Sequence Model 44 | seq_model: 45 | n_layers: 12 46 | layer_activation: "gelu" 47 | dropout: 0.0 48 | use_norm: True 49 | prenorm: False 50 | per_layer_skip: True 51 | num_groups: 32 52 | squeeze_excite: False 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 512 58 | blocks: 1 59 | clip_eigs: False 60 | B_kernel_size: 3 61 | C_kernel_size: 3 62 | D_kernel_size: 3 63 | dt_min: 0.001 64 | dt_max: 0.1 65 | C_D_config: "resnet" 66 | 67 | #prior 68 | z_git: 69 | vocab_dim: 256 70 | mask_schedule: "cosine" 71 | tfm_kwargs: 72 | embed_dim: 768 73 | mlp_dim: 3072 74 | num_heads: 12 75 | num_layers: 6 76 | dropout: 0. 77 | attention_dropout: 0. 
78 | 79 | embedding_dim: 128 80 | codebook: 81 | n_codes: 1024 82 | proj_dim: 32 83 | 84 | latent_height: 8 85 | latent_width: 8 86 | 87 | n_cond: 1 88 | drop_loss_rate: 0.9 89 | 90 | # Causal Masking 91 | causal_masking: False 92 | frame_mask_id: None 93 | 94 | # Actions 95 | use_actions: true 96 | action_dim: 6 97 | action_embed_dim: 16 98 | dropout_actions: False 99 | action_dropout_rate: 0.0 100 | action_mask_id: None 101 | 102 | # Sampling 103 | T_draft: 8 104 | T_revise: 8 105 | M: 2 106 | 107 | open_loop_ctx: 36 108 | 109 | open_loop_ctx_1: 144 110 | action_conditioned_1: True 111 | open_loop_ctx_2: 36 112 | action_conditioned_2: True -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_convLSTM_novq.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_1" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 1000 22 | seq_len: 300 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "convLSTM_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256] # 64x64 to 16x16 37 | blocks: 1 #blocks doesn't do anything here 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | d_model: 256 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 8 48 | dropout: 0.0 49 | use_norm: True 50 | prenorm: False 51 | per_layer_skip: True 52 | skip_connections: False 53 | 54 | # SSM 55 | ssm: 56 | ssm_size: 256 57 | kernel_size: 3 58 | 59 | 60 | latent_height: 16 61 | latent_width: 16 62 | 63 | n_cond: 1 64 | drop_loss_rate: 0.0 65 | 66 | causal_masking: False 67 | 68 | # Actions 69 | use_actions: False 70 | action_dim: 1 71 | action_embed_dim: 1 72 | dropout_actions: False 73 | action_dropout_rate: 0.0 74 | action_mask_id: -1 75 | 76 | # Sampling 77 | open_loop_ctx: 100 78 | 79 | open_loop_ctx_1: 100 80 | action_conditioned_1: False 81 | open_loop_ctx_2: 100 82 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_convLSTM_novq_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 300 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: 
"convLSTM_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 256] # 64x64 to 16x16 38 | blocks: 1 #blocks doesn't do anything here 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | d_model: 256 45 | 46 | # Sequence Model 47 | seq_model: 48 | n_layers: 8 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 256 58 | kernel_size: 3 59 | 60 | 61 | latent_height: 16 62 | latent_width: 16 63 | 64 | n_cond: 1 65 | drop_loss_rate: 0.0 66 | 67 | causal_masking: False 68 | 69 | # Actions 70 | use_actions: False 71 | action_dim: 1 72 | action_embed_dim: 1 73 | dropout_actions: False 74 | action_dropout_rate: 0.0 75 | action_mask_id: -1 76 | 77 | # Sampling 78 | open_loop_ctx: 100 79 | 80 | open_loop_ctx_1: 100 81 | action_conditioned_1: False 82 | open_loop_ctx_2: 100 83 | action_conditioned_2: False 84 | open_loop_ctx_3: 100 85 | action_conditioned_3: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_convS5_novq.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_1" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 1000 22 | seq_len: 300 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "convS5_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256] # 64x64 to 16x16 37 | blocks: 1 #blocks doesn't do anything here 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | d_model: 256 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 8 48 | layer_activation: "gelu" 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | num_groups: 32 54 | squeeze_excite: False 55 | skip_connections: False 56 | 57 | # SSM 58 | ssm: 59 | ssm_size: 256 60 | blocks: 1 61 | clip_eigs: True 62 | B_kernel_size: 3 63 | C_kernel_size: 3 64 | D_kernel_size: 3 65 | dt_min: 0.001 66 | dt_max: 0.1 67 | C_D_config: "resnet" 68 | 69 | 70 | latent_height: 16 71 | latent_width: 16 72 | 73 | n_cond: 1 74 | drop_loss_rate: 0.0 75 | 76 | causal_masking: False 77 | 78 | # Actions 79 | use_actions: False 80 | action_dim: 1 81 | action_embed_dim: 1 82 | dropout_actions: False 83 | action_dropout_rate: 0.0 84 | action_mask_id: -1 85 | 86 | # Sampling 87 | open_loop_ctx: 100 88 | 89 | open_loop_ctx_1: 100 90 | action_conditioned_1: False 91 | open_loop_ctx_2: 100 92 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_convS5_novq_eval.yaml: 
-------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 2000 15 | viz_interval: 2000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 300 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: "convS5_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 256] # 64x64 to 16x16 38 | blocks: 1 #blocks doesn't do anything here 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | d_model: 256 45 | 46 | # Sequence Model 47 | seq_model: 48 | n_layers: 8 49 | layer_activation: "gelu" 50 | dropout: 0.0 51 | use_norm: True 52 | prenorm: False 53 | per_layer_skip: True 54 | num_groups: 32 55 | squeeze_excite: False 56 | skip_connections: False 57 | 58 | # SSM 59 | ssm: 60 | ssm_size: 256 61 | blocks: 1 62 | clip_eigs: True 63 | B_kernel_size: 3 64 | C_kernel_size: 3 65 | D_kernel_size: 3 66 | dt_min: 0.001 67 | dt_max: 0.1 68 | C_D_config: "resnet" 69 | 70 | 71 | latent_height: 16 72 | latent_width: 16 73 | 74 | n_cond: 1 75 | drop_loss_rate: 0.0 76 | 77 | causal_masking: False 78 | 79 | # Actions 80 | use_actions: False 81 | action_dim: 1 82 | action_embed_dim: 1 83 | dropout_actions: False 84 | action_dropout_rate: 0.0 85 | action_mask_id: -1 86 | 87 | # Sampling 88 | open_loop_ctx: 100 89 | 90 | open_loop_ctx_1: 100 91 | action_conditioned_1: False 92 | open_loop_ctx_2: 100 93 | action_conditioned_2: False 94 | open_loop_ctx_3: 100 95 | action_conditioned_3: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_noVQ_transformer.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 50000 15 | viz_interval: 50000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_1" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 1000 22 | seq_len: 300 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "transformer_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256] # 64x64 to 16x16 37 | blocks: 1 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | latent_shape: [16, 16] 44 | z_ds: 16 # 16x16 -> 1x1 45 | z_tfm_kwargs: 46 | embed_dim: 1024 47 | mlp_dim: 4096 48 | num_heads: 16 49 | num_layers: 8 50 | dropout: 0. 
51 | attention_dropout: 0. 52 | 53 | embedding_dim: 256 54 | n_cond: 1 55 | drop_loss_rate: 0.0 56 | 57 | causal_masking: False 58 | 59 | # Actions 60 | use_actions: False 61 | action_dim: 1 62 | action_embed_dim: 1 63 | dropout_actions: False 64 | action_dropout_rate: 0.0 65 | action_mask_id: -1 66 | 67 | # Sampling 68 | open_loop_ctx: 100 69 | 70 | open_loop_ctx_1: 100 71 | action_conditioned_1: False 72 | open_loop_ctx_2: 100 73 | action_conditioned_2: False 74 | -------------------------------------------------------------------------------- /configs/Moving-MNIST/300_train_len/mnist_noVQ_transformer_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 50000 15 | viz_interval: 50000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 1500 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: "transformer_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 256] # 64x64 to 16x16 38 | blocks: 1 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | latent_shape: [16, 16] 45 | z_ds: 16 # 16x16 -> 1x1 46 | z_tfm_kwargs: 47 | embed_dim: 1024 48 | mlp_dim: 4096 49 | num_heads: 16 50 | num_layers: 8 51 | dropout: 0. 52 | attention_dropout: 0. 
53 | 54 | embedding_dim: 256 55 | 56 | n_cond: 1 57 | drop_loss_rate: 0.0 58 | 59 | causal_masking: False 60 | 61 | # Actions 62 | use_actions: False 63 | action_dim: 1 64 | action_embed_dim: 1 65 | dropout_actions: False 66 | action_dropout_rate: 0.0 67 | action_mask_id: -1 68 | 69 | # Sampling 70 | open_loop_ctx: 100 71 | 72 | open_loop_ctx_1: 100 73 | action_conditioned_1: False 74 | open_loop_ctx_2: 100 75 | action_conditioned_2: False 76 | open_loop_ctx_3: 100 77 | action_conditioned_3: False 78 | -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_convLSTM_novq.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 4000 15 | viz_interval: 4000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_2" 20 | eval_seq_len_1: 800 21 | eval_seq_len_2: 1000 22 | seq_len: 600 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "convLSTM_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256] # 64x64 to 16x16 37 | blocks: 1 #blocks doesn't do anything here 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | d_model: 128 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 8 48 | dropout: 0.0 49 | use_norm: True 50 | prenorm: False 51 | per_layer_skip: True 52 | skip_connections: False 53 | 54 | # SSM 55 | ssm: 56 | ssm_size: 128 57 | kernel_size: 3 58 | 59 | 60 | latent_height: 16 61 | latent_width: 16 62 | 63 | n_cond: 1 64 | drop_loss_rate: 0.0 65 | 66 | causal_masking: False 67 | 68 | # Actions 69 | use_actions: False 70 | action_dim: 1 71 | action_embed_dim: 1 72 | dropout_actions: False 73 | action_dropout_rate: 0.0 74 | action_mask_id: -1 75 | 76 | # Sampling 77 | open_loop_ctx: 100 78 | 79 | open_loop_ctx_1: 100 80 | action_conditioned_1: False 81 | open_loop_ctx_2: 100 82 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_convLSTM_novq_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 4000 15 | viz_interval: 4000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 300 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: "convLSTM_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 
256] # 64x64 to 16x16 38 | blocks: 1 #blocks doesn't do anything here 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | d_model: 128 45 | 46 | # Sequence Model 47 | seq_model: 48 | n_layers: 8 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | skip_connections: False 54 | 55 | # SSM 56 | ssm: 57 | ssm_size: 128 58 | kernel_size: 3 59 | 60 | 61 | latent_height: 16 62 | latent_width: 16 63 | 64 | n_cond: 1 65 | drop_loss_rate: 0.0 66 | 67 | causal_masking: False 68 | 69 | # Actions 70 | use_actions: False 71 | action_dim: 1 72 | action_embed_dim: 1 73 | dropout_actions: False 74 | action_dropout_rate: 0.0 75 | action_mask_id: -1 76 | 77 | # Sampling 78 | open_loop_ctx: 100 79 | 80 | open_loop_ctx_1: 100 81 | action_conditioned_1: False 82 | open_loop_ctx_2: 100 83 | action_conditioned_2: False 84 | open_loop_ctx_3: 100 85 | action_conditioned_3: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_convS5_novq.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 4000 15 | viz_interval: 4000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_2" 20 | eval_seq_len_1: 800 21 | eval_seq_len_2: 1000 22 | seq_len: 600 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "convS5_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256] # 64x64 to 16x16 37 | blocks: 1 #blocks doesn't do anything here 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | d_model: 128 44 | 45 | # Sequence Model 46 | seq_model: 47 | n_layers: 8 48 | layer_activation: "gelu" 49 | dropout: 0.0 50 | use_norm: True 51 | prenorm: False 52 | per_layer_skip: True 53 | num_groups: 32 54 | squeeze_excite: False 55 | skip_connections: False 56 | 57 | # SSM 58 | ssm: 59 | ssm_size: 128 60 | blocks: 1 61 | clip_eigs: True 62 | B_kernel_size: 3 63 | C_kernel_size: 3 64 | D_kernel_size: 3 65 | dt_min: 0.001 66 | dt_max: 0.1 67 | C_D_config: "resnet" 68 | 69 | 70 | latent_height: 16 71 | latent_width: 16 72 | 73 | n_cond: 1 74 | drop_loss_rate: 0.0 75 | 76 | causal_masking: False 77 | 78 | # Actions 79 | use_actions: False 80 | action_dim: 1 81 | action_embed_dim: 1 82 | dropout_actions: False 83 | action_dropout_rate: 0.0 84 | action_mask_id: -1 85 | 86 | # Sampling 87 | open_loop_ctx: 100 88 | 89 | open_loop_ctx_1: 100 90 | action_conditioned_1: False 91 | open_loop_ctx_2: 100 92 | action_conditioned_2: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_convS5_novq_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 
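# These Moving-MNIST configs are plain YAML. A minimal sketch of turning one
# into an attribute-style config object, assuming pyyaml (listed in
# requirements.txt); the repo's own training/eval entry points may load them
# differently:
#
#   import argparse, yaml
#   with open("configs/Moving-MNIST/600_train_len/mnist_convS5_novq.yaml") as f:
#       cfg = argparse.Namespace(**yaml.safe_load(f))
#   cfg.seq_len, cfg.eval_seq_len_1, cfg.ssm["ssm_size"]   # -> 600, 800, 128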
7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0005 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 4000 15 | viz_interval: 4000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 300 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: "convS5_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 256] # 64x64 to 16x16 38 | blocks: 1 #blocks doesn't do anything here 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | d_model: 128 45 | 46 | # Sequence Model 47 | seq_model: 48 | n_layers: 8 49 | layer_activation: "gelu" 50 | dropout: 0.0 51 | use_norm: True 52 | prenorm: False 53 | per_layer_skip: True 54 | num_groups: 32 55 | squeeze_excite: False 56 | skip_connections: False 57 | 58 | # SSM 59 | ssm: 60 | ssm_size: 128 61 | blocks: 1 62 | clip_eigs: True 63 | B_kernel_size: 3 64 | C_kernel_size: 3 65 | D_kernel_size: 3 66 | dt_min: 0.001 67 | dt_max: 0.1 68 | C_D_config: "resnet" 69 | 70 | 71 | latent_height: 16 72 | latent_width: 16 73 | 74 | n_cond: 1 75 | drop_loss_rate: 0.0 76 | 77 | causal_masking: False 78 | 79 | # Actions 80 | use_actions: False 81 | action_dim: 1 82 | action_embed_dim: 1 83 | dropout_actions: False 84 | action_dropout_rate: 0.0 85 | action_mask_id: -1 86 | 87 | # Sampling 88 | open_loop_ctx: 100 89 | 90 | open_loop_ctx_1: 100 91 | action_conditioned_1: False 92 | open_loop_ctx_2: 100 93 | action_conditioned_2: False 94 | open_loop_ctx_3: 100 95 | action_conditioned_3: False -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_noVQ_transformer.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 50000 15 | viz_interval: 50000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_long_2" 20 | eval_seq_len_1: 800 21 | eval_seq_len_2: 1000 22 | seq_len: 600 23 | image_size: 64 24 | channels: 1 25 | 26 | num_shards: null 27 | rng_keys: ["dropout", "sample"] 28 | batch_keys: ["video", "actions"] 29 | 30 | # Model 31 | model: "transformer_noVQ" 32 | 33 | loss_weight: 0.5 34 | 35 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 36 | depths: [64, 128, 256, 512] # 64x64 to 16x16 37 | blocks: 1 38 | 39 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 40 | depths: [64, 128, 256, 512] # 16x16 to 64x64 41 | blocks: 1 42 | 43 | latent_shape: [8, 8] 44 | z_ds: 8 # 16x16 -> 1x1 45 | z_tfm_kwargs: 46 | embed_dim: 1024 47 | mlp_dim: 4096 48 | num_heads: 16 49 | num_layers: 8 50 | dropout: 0. 51 | attention_dropout: 0. 
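# Reading of the fields above (not behaviour quoted from the model code): this
# transformer baseline pools each frame's latent grid down to a single token
# before temporal attention -- z_ds: 8 reduces the latent_shape [8, 8] grid to
# 1x1 (the inline "16x16 -> 1x1" note appears carried over from the
# 300_train_len config, where latent_shape is [16, 16] and z_ds is 16) -- so
# attention runs over roughly seq_len = 600 tokens of width embed_dim = 1024,
# i.e. 600 * (8 // 8)**2 = 600 positions. The convS5/convLSTM configs instead
# keep a 16x16 spatial latent (256 positions) at every step, per their
# latent_height/latent_width fields.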
52 | 53 | embedding_dim: 256 54 | 55 | n_cond: 1 56 | drop_loss_rate: 0.0 57 | 58 | causal_masking: False 59 | 60 | # Actions 61 | use_actions: False 62 | action_dim: 1 63 | action_embed_dim: 1 64 | dropout_actions: False 65 | action_dropout_rate: 0.0 66 | action_mask_id: -1 67 | 68 | # Sampling 69 | open_loop_ctx: 300 70 | 71 | open_loop_ctx_1: 300 72 | action_conditioned_1: False 73 | open_loop_ctx_2: 100 74 | action_conditioned_2: False 75 | -------------------------------------------------------------------------------- /configs/Moving-MNIST/600_train_len/mnist_noVQ_transformer_eval.yaml: -------------------------------------------------------------------------------- 1 | seed: 1 2 | cache: false # caching available only for encoded datasets 3 | 4 | # Training 5 | multinode: False 6 | batch_size: 8 7 | eval_size: 1024 8 | num_workers: 4 9 | lr: 0.0001 10 | lr_schedule: "cosine" 11 | weight_decay: 0.00001 12 | total_steps: 300000 13 | warmup_steps: 5000 14 | save_interval: 50000 15 | viz_interval: 50000 16 | log_interval: 100 17 | 18 | # Data 19 | data_path: "/raid/moving_mnist_longer_eval" 20 | eval_seq_len_1: 500 21 | eval_seq_len_2: 900 22 | eval_seq_len_3: 1300 23 | seq_len: 1500 24 | image_size: 64 25 | channels: 1 26 | 27 | num_shards: null 28 | rng_keys: ["dropout", "sample"] 29 | batch_keys: ["video", "actions"] 30 | 31 | # Model 32 | model: "transformer_noVQ" 33 | 34 | loss_weight: 0.5 35 | 36 | encoder: # encoder / decoder are mirrored, with decoder depths reversed 37 | depths: [64, 128, 256, 512] # 64x64 to 16x16 38 | blocks: 1 39 | 40 | decoder: # encoder / decoder are mirrored, with decoder depths reversed 41 | depths: [64, 128, 256, 512] # 16x16 to 64x64 42 | blocks: 1 43 | 44 | latent_shape: [8, 8] 45 | z_ds: 8 # 16x16 -> 1x1 46 | z_tfm_kwargs: 47 | embed_dim: 1024 48 | mlp_dim: 4096 49 | num_heads: 16 50 | num_layers: 8 51 | dropout: 0. 52 | attention_dropout: 0. 53 | 54 | embedding_dim: 256 55 | 56 | n_cond: 1 57 | drop_loss_rate: 0.0 58 | 59 | causal_masking: False 60 | 61 | # Actions 62 | use_actions: False 63 | action_dim: 1 64 | action_embed_dim: 1 65 | dropout_actions: False 66 | action_dropout_rate: 0.0 67 | action_mask_id: -1 68 | 69 | # Sampling 70 | open_loop_ctx: 100 71 | 72 | open_loop_ctx_1: 100 73 | action_conditioned_1: False 74 | open_loop_ctx_2: 100 75 | action_conditioned_2: False 76 | open_loop_ctx_3: 100 77 | action_conditioned_3: False 78 | n_conditioned_3: False 79 | -------------------------------------------------------------------------------- /data/moving-mnist-pytorch/moving_mnist.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 
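# A note on the *_eval configs above: each eval_seq_len_N is paired with an
# open_loop_ctx_N. Judging by scripts/compute_metrics.py (which drops the first
# open_loop_ctx frames of both the real and generated videos before scoring),
# the first open_loop_ctx_N frames act as conditioning context and only the
# remaining frames are evaluated -- e.g. eval_seq_len_3: 1300 with
# open_loop_ctx_3: 100 leaves 1200 scored frames. This is an inference from the
# evaluation script, not a statement taken from the training code.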
5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | from PIL import Image 13 | import imageio 14 | import sys 15 | import os 16 | import math 17 | import numpy as np 18 | import random 19 | import scipy.misc 20 | 21 | ########################################################################################### 22 | # script to generate moving mnist video dataset (frame by frame) as described in 23 | # [1] arXiv:1502.04681 - Unsupervised Learning of Video Representations Using LSTMs 24 | # Srivastava et al 25 | # by Tencia Lee 26 | # saves in hdf5, npz, or jpg (individual frames) format 27 | # usage: python3 moving_mnist.py --dest ~/dataset/moving-mnist/moving-mnist-val-new1.npz --filetype npz --seq_len 20 --n_seq 3000 --nums_per_image 2 28 | ########################################################################################### 29 | 30 | # image_size = 64 31 | # digit_size = 28 32 | step_length = 0.15 33 | 34 | # helper functions 35 | def arr_from_img(im,shift=0): 36 | w,h=im.size 37 | arr=im.getdata() 38 | c = np.int(np.product(arr.size) / (w*h)) 39 | return np.asarray(arr, dtype=np.float32).reshape((h,w,c)).transpose(2,1,0) / 255. - shift 40 | 41 | def get_picture_array(X, index, shift=0): 42 | ch, w, h = X.shape[1], X.shape[2], X.shape[3] 43 | ret = ((X[index]+shift)*255.).reshape(ch,w,h).transpose(2,1,0).clip(0,255).astype(np.uint8) 44 | if ch == 1: 45 | ret=ret.reshape(h,w) 46 | return ret 47 | 48 | 49 | def load_dataset(): 50 | # Load MNIST dataset for generating training data. 51 | import gzip 52 | # path = os.path.join(root, 'train-images-idx3-ubyte.gz') 53 | filename = 'train-images-idx3-ubyte.gz' 54 | with gzip.open(filename, 'rb') as f: 55 | mnist = np.frombuffer(f.read(), np.uint8, offset=16) 56 | mnist = mnist.reshape(-1, 28, 28) 57 | return mnist 58 | 59 | def get_random_trajectory(seq_length=30, image_size=64, digit_size=28): 60 | 61 | ''' Generate a random sequence of a MNIST digit ''' 62 | canvas_size = image_size - digit_size 63 | x = random.random() 64 | y = random.random() 65 | theta = random.random() * 2 * np.pi 66 | v_y = np.sin(theta) 67 | v_x = np.cos(theta) 68 | 69 | start_y = np.zeros(seq_length) 70 | start_x = np.zeros(seq_length) 71 | for i in range(seq_length): 72 | # Take a step along velocity. 73 | y += v_y * step_length 74 | x += v_x * step_length 75 | 76 | # Bounce off edges. 77 | if x <= 0: 78 | x = 0 79 | v_x = -v_x 80 | if x >= 1.0: 81 | x = 1.0 82 | v_x = -v_x 83 | if y <= 0: 84 | y = 0 85 | v_y = -v_y 86 | if y >= 1.0: 87 | y = 1.0 88 | v_y = -v_y 89 | start_y[i] = y 90 | start_x[i] = x 91 | 92 | # Scale to the size of the canvas. 93 | start_y = (canvas_size * start_y).astype(np.int32) 94 | start_x = (canvas_size * start_x).astype(np.int32) 95 | return start_y, start_x 96 | 97 | def generate_moving_mnist(num_digits=2, n_frames_total=30, n_seq=10000, image_size=64, digit_size=28): 98 | ''' 99 | Get random trajectories for the digits and generate a video. 
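    Illustrative usage (the exact parameters used to build the datasets the
    configs point to, e.g. /raid/moving_mnist_long_1, are not recorded here):

        data = generate_moving_mnist(num_digits=2, n_frames_total=300,
                                     n_seq=100, image_size=64, digit_size=28)
        # float32 array of shape (100, 300, 64, 64, 1), values in [0, 255];
        # needs train-images-idx3-ubyte.gz in the working directory and an
        # existing tmp/ folder for the preview frames of the first sequence.

    Files written via --filetype npz (np.savez(dest, data=dat)) match what
    load_mnist_npz in src/data.py expects: <data_path>/<train|test>/<subdir>/<name>.npz
    with the video array stored under the 'data' key.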
100 | ''' 101 | mnist = load_dataset() 102 | data = np.zeros((n_seq, n_frames_total, image_size, image_size), dtype=np.float32) 103 | for seq_idx in range(n_seq): 104 | canvas = np.zeros((n_frames_total, image_size, image_size), dtype=np.float32) 105 | for n in range(num_digits): 106 | # Trajectory 107 | start_y, start_x = get_random_trajectory(n_frames_total, image_size, digit_size) 108 | ind = random.randint(0, mnist.shape[0] - 1) 109 | digit_image = mnist[ind] 110 | if digit_image.shape[0] != digit_size: 111 | digit_image = np.resize(digit_image, (digit_size, digit_size)) 112 | print("digit_image shape", digit_image.shape, digit_image.max(), digit_image.min()) 113 | for frame_idx in range(n_frames_total): 114 | top = start_y[frame_idx] 115 | left = start_x[frame_idx] 116 | bottom = top + digit_size 117 | right = left + digit_size 118 | # Draw digit 119 | canvas[frame_idx, top:bottom, left:right] = np.maximum(canvas[frame_idx, top:bottom, left:right], digit_image) 120 | if seq_idx == 0: 121 | for frame_idx in range(n_frames_total): 122 | imageio.imwrite('tmp/out_%d.jpg'%(frame_idx), canvas[frame_idx]) 123 | #scipy.misc.imsave('tmp/out_%d.jpg'%(frame_idx), canvas[frame_idx]) 124 | data[seq_idx] = canvas 125 | print(seq_idx, data[seq_idx].max()) 126 | # for frame_idx in range(n_frames_total): 127 | # print(seq_idx, frame_idx, data[seq_idx, frame_idx].shape, data[seq_idx, frame_idx].max()) 128 | # if frame_idx == 0: 129 | # save_img = data[seq_idx, frame_idx] 130 | # else: 131 | # save_img = np.concatenate([save_img, data[seq_idx, frame_idx]], axis=1) 132 | # print(save_img.shape, save_img.max()) 133 | # img = Image.fromarray(save_img[:,:].astype(np.int8), 'L') 134 | # img.save('temp/%d.png'%(seq_idx)) 135 | 136 | data = data[..., np.newaxis]#.astype(np.int8) 137 | print(data.shape) 138 | return data 139 | 140 | def main(dest, filetype='npz', seq_len=30, n_seq=100, nums_per_image=2, 141 | image_size=64, digit_size=28): 142 | dat = generate_moving_mnist(num_digits=nums_per_image, n_frames_total=seq_len, 143 | n_seq=n_seq, image_size=image_size, digit_size=digit_size) 144 | if filetype == 'hdf5': 145 | n = n_seq * seq_len 146 | import h5py 147 | from fuel.datasets.hdf5 import H5PYDataset 148 | def save_hd5py(dataset, destfile, indices_dict): 149 | f = h5py.File(destfile, mode='w') 150 | images = f.create_dataset('images', dataset.shape, dtype='uint8') 151 | images[...] 
= dataset 152 | split_dict = dict((k, {'images':v}) for k,v in indices_dict.iteritems()) 153 | f.attrs['split'] = H5PYDataset.create_split_array(split_dict) 154 | f.flush() 155 | f.close() 156 | indices_dict = {'train': (0, n*9/10), 'test': (n*9/10, n)} 157 | save_hd5py(dat, dest, indices_dict) 158 | elif filetype == 'npz': 159 | np.savez(dest, data=dat) 160 | print(dest) 161 | elif filetype == 'jpg': 162 | for i in range(dat.shape[0]): 163 | Image.fromarray(get_picture_array(dat, i, shift=0)).save(os.path.join(dest, '{}.jpg'.format(i))) 164 | 165 | if __name__ == '__main__': 166 | import argparse 167 | parser = argparse.ArgumentParser(description='Command line options') 168 | parser.add_argument('--image_size', type=int, dest='image_size') 169 | parser.add_argument('--dest', type=str, dest='dest') 170 | parser.add_argument('--filetype', type=str, dest='filetype') 171 | # parser.add_argument('--frame_size', type=int, dest='frame_size') 172 | parser.add_argument('--seq_len', type=int, dest='seq_len') # length of each sequence 173 | parser.add_argument('--n_seq', type=int, dest='n_seq') # number of sequences to generate 174 | parser.add_argument('--digit_size', type=int, dest='digit_size') # size of mnist digit within frame 175 | parser.add_argument('--nums_per_image', type=int, dest='nums_per_image') # number of digits in each frame 176 | # parser.add_argument('--step_length', type=float, default=0.1, dest='step_length') 177 | args = parser.parse_args(sys.argv[1:]) 178 | main(**{k:v for (k,v) in vars(args).items() if v is not None}) 179 | -------------------------------------------------------------------------------- /figs/convssm_50.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/figs/convssm_50.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | wandb 2 | tqdm 3 | ipdb 4 | pyyaml 5 | moviepy==1.0.3 6 | wandb==0.14.0 7 | av==10.0.0 8 | flax==0.6.1 9 | chex==0.1.6 10 | gcsfs==2022.7.1 11 | tensorflow_cpu==2.11.0 12 | tensorflow_datasets==4.8.3 13 | tensorflow_io==0.31.0 14 | tensorflow_gan==2.1.0 15 | tensorflow_probability==0.19.0 16 | tensorflow_hub==0.13.0 17 | lpips_jax==0.1.0 18 | numpy 19 | internetarchive 20 | gdown 21 | tensorstore==0.1.33 22 | gin-config -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/compute_fvd.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | import os.path as osp 9 | import glob 10 | import sys 11 | import numpy as np 12 | from tqdm import tqdm 13 | from src.fvd import fvd 14 | 15 | import tensorflow as tf 16 | 17 | BATCH_SIZE = 256 18 | 19 | path = sys.argv[1] 20 | files = 
glob.glob(osp.join(path, '*.npz')) 21 | files.sort(key=lambda x: int(osp.basename(x).split('_')[-1].split('.')[0])) 22 | print(f'Found {len(files)} file:', files) 23 | 24 | SIZE = np.load(files[0])['real'].shape[0] 25 | 26 | 27 | def convert(video): 28 | video = tf.convert_to_tensor(video, dtype=tf.uint8) 29 | video = tf.cast(video, tf.float32) / 255. 30 | return video.numpy() 31 | 32 | 33 | def read(files): 34 | data = [np.load(f) for f in files] 35 | data = [(convert(d['real']), convert(d['fake'])) for d in tqdm(data)] 36 | real, fake = list(zip(*data)) 37 | return real, fake 38 | 39 | 40 | fvds = [] 41 | total = len(files) * SIZE 42 | pbar = tqdm(total=total) 43 | for j in range(0, len(files), BATCH_SIZE // SIZE): 44 | r, f = read(files[j:j + BATCH_SIZE // SIZE]) 45 | fvds.append(fvd(r, f)) 46 | pbar.update(BATCH_SIZE) 47 | del r 48 | del f 49 | print(f'FVD: {np.mean(fvds)} +/- {np.std(fvds)}') 50 | -------------------------------------------------------------------------------- /scripts/compute_fvd_mnist.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | import os.path as osp 9 | import glob 10 | import sys 11 | import numpy as np 12 | from tqdm import tqdm 13 | from src.fvd_mnist import fvd 14 | 15 | import tensorflow as tf 16 | 17 | BATCH_SIZE = 256 18 | 19 | path = sys.argv[1] 20 | files = glob.glob(osp.join(path, '*.npz')) 21 | files.sort(key=lambda x: int(osp.basename(x).split('_')[-1].split('.')[0])) 22 | print(f'Found {len(files)} file:', files) 23 | 24 | SIZE = np.load(files[0])['real'].shape[0] 25 | 26 | 27 | def convert(video): 28 | video = tf.convert_to_tensor(video, dtype=tf.uint8) 29 | video = tf.cast(video, tf.float32) / 255. 
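    # Both compute_fvd scripts walk a directory of files named like
    # <prefix>_<index>.npz, each holding uint8 arrays 'real' and 'fake' of
    # shape (batch, time, height, width, channels); this MNIST variant calls
    # src/fvd_mnist.py, which tiles the single channel to 3 before the I3D
    # embedding. A sketch of writing a compatible file (shapes illustrative):
    #
    #   real = (np.random.rand(16, 300, 64, 64, 1) * 255).astype(np.uint8)
    #   fake = (np.random.rand(16, 300, 64, 64, 1) * 255).astype(np.uint8)
    #   np.savez('samples_0.npz', real=real, fake=fake)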
30 | return video.numpy() 31 | 32 | 33 | def read(files): 34 | data = [np.load(f) for f in files] 35 | data = [(convert(d['real']), convert(d['fake'])) for d in tqdm(data)] 36 | real, fake = list(zip(*data)) 37 | return real, fake 38 | 39 | 40 | fvds = [] 41 | total = len(files) * SIZE 42 | pbar = tqdm(total=total) 43 | for j in range(0, len(files), BATCH_SIZE // SIZE): 44 | r, f = read(files[j:j + BATCH_SIZE // SIZE]) 45 | fvds.append(fvd(r, f)) 46 | pbar.update(BATCH_SIZE) 47 | del r 48 | del f 49 | print(f'FVD: {np.mean(fvds)} +/- {np.std(fvds)}') 50 | -------------------------------------------------------------------------------- /scripts/compute_metrics.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | import sys 9 | import numpy as np 10 | from tqdm import tqdm 11 | import glob 12 | import os.path as osp 13 | 14 | from src.metrics import get_ssim, get_psnr, get_lpips 15 | 16 | BATCH_SIZE = 256 17 | 18 | path = sys.argv[1] 19 | if path.endswith('/'): 20 | path = path[:-1] 21 | open_loop_ctx = int(osp.basename(path).split('_')[-1]) 22 | 23 | files = glob.glob(osp.join(path, '*.npz')) 24 | files.sort(key=lambda x: int(osp.basename(x).split('_')[-1].split('.')[0])) 25 | print(f'Found {len(files)} file:', files) 26 | 27 | SIZE = np.load(files[0])['real'].shape[0] 28 | 29 | 30 | def read(files): 31 | scale = np.array(255., dtype=np.float32) 32 | data = [np.load(f) for f in files] 33 | data = [(d['real'][:, open_loop_ctx:] / scale, d['fake'][:, open_loop_ctx:] / scale) for d in data] 34 | return data 35 | 36 | 37 | ssim_fn = get_ssim() 38 | psnr_fn = get_psnr() 39 | lpips_fn = get_lpips() 40 | 41 | ssims, psnrs, lpips = [], [], [] 42 | total = len(files) * SIZE 43 | pbar = tqdm(total=total) 44 | for j in range(0, len(files), BATCH_SIZE // SIZE): 45 | data = read(files[j:j + BATCH_SIZE // SIZE]) 46 | ps, ss, ls = [], [], [] 47 | for r_i, f_i in data: 48 | ps.append(psnr_fn(r_i, f_i).mean()) 49 | ss.append(ssim_fn(r_i, f_i).mean()) 50 | ls.append(lpips_fn(r_i, f_i).mean()) 51 | pbar.update(r_i.shape[0]) 52 | psnrs.append(np.mean(ps)) 53 | ssims.append(np.mean(ss)) 54 | lpips.append(np.mean(ls)) 55 | 56 | print(f'PSNR: {np.mean(psnrs)} +/- {np.std(psnrs)}') 57 | print(f'SSIM: {np.mean(ssims)} +/- {np.std(ssims)}') 58 | print(f'LPIPS: {np.mean(lpips)} +/- {np.std(lpips)}') 59 | -------------------------------------------------------------------------------- /scripts/download/dmlab.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | for i in aa ab ac 7 | do 8 | ia download dmlab_dataset_$i dmlab.tar.part$i 9 | mv dmlab_dataset_$i/dmlab.tar.part$i . 10 | rmdir dmlab_dataset_$i 11 | done 12 | 13 | cat dmlab.tar.part* | tar x 14 | 15 | rm dmlab.tar.part* 16 | -------------------------------------------------------------------------------- /scripts/download/dmlab_encoded.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | ia download dmlab_encoded dmlab_encoded.tar 7 | mv dmlab_encoded/dmlab_encoded.tar . 
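# The `ia` command used throughout these download scripts is the CLI that ships
# with the internetarchive package listed in requirements.txt; each script
# takes the target download directory as its only argument, e.g.
#   sh scripts/download/dmlab_encoded.sh /path/to/datasets/dmlab_encoded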
8 | rmdir dmlab_encoded 9 | tar -xf dmlab_encoded.tar 10 | rm dmlab_encoded.tar -------------------------------------------------------------------------------- /scripts/download/habitat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | for i in aa ab ac ad 7 | do 8 | ia download habitat_dataset_$i habitat.tar.part$i 9 | mv habitat_dataset_$i/habitat.tar.part$i . 10 | rmdir habitat_dataset_$i 11 | done 12 | 13 | cat habitat.tar.part* | tar x 14 | 15 | rm habitat.tar.part* 16 | -------------------------------------------------------------------------------- /scripts/download/habitat_encoded.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | ia download habitat_encoded habitat_encoded.tar 7 | mv habitat_encoded/habitat_encoded.tar . 8 | rmdir habitat_encoded 9 | tar -xf habitat_encoded.tar 10 | rm habitat_encoded.tar -------------------------------------------------------------------------------- /scripts/download/kinetics600_encoded.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | ia download kinetics600_encoded kinetics600_encoded.tar 7 | mv kinetics600_encoded/kinetics600_encoded.tar . 8 | rmdir kinetics600_encoded 9 | tar -xf kinetics600_encoded.tar 10 | rm kinetics600_encoded.tar -------------------------------------------------------------------------------- /scripts/download/minecraft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | for i in aa ab ac ad ae af ag ah ai aj ak 7 | do 8 | ia download minecraft_marsh_dataset_$i minecraft.tar.part$i 9 | mv minecraft_marsh_dataset_$i/minecraft.tar.part$i . 10 | rmdir minecraft_marsh_dataset_$i 11 | done 12 | 13 | cat minecraft.tar.part* | tar x 14 | 15 | rm minecraft.tar.part* 16 | -------------------------------------------------------------------------------- /scripts/download/minecraft_encoded.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p $1 4 | cd $1 5 | 6 | ia download minecraft_marsh_encoded minecraft_encoded.tar 7 | mv minecraft_marsh_encoded/minecraft_encoded.tar . 
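# Note: the directory created by `ia download` above is minecraft_marsh_encoded,
# so the rmdir on the next line likely needs to target that name for the
# now-empty download directory to actually be removed.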
8 | rmdir minecraft_encoded 9 | tar -xf minecraft_encoded.tar 10 | rm minecraft_encoded.tar -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | from setuptools import find_packages 3 | 4 | setup( 5 | name='convssm', 6 | version='0.0.1', 7 | packages=find_packages() 8 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/__init__.py -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | import glob 9 | import os.path as osp 10 | import numpy as np 11 | from flax import jax_utils 12 | import jax 13 | import tensorflow as tf 14 | import tensorflow_datasets as tfds 15 | import tensorflow_io as tfio 16 | from tensorflow.python.lib.io import file_io 17 | import io 18 | 19 | 20 | def is_tfds_folder(path): 21 | path = osp.join(path, '1.0.0') 22 | if path.startswith('gs://'): 23 | return tf.io.gfile.exists(path) 24 | else: 25 | return osp.exists(path) 26 | 27 | 28 | def load_npz(config, split, num_ds_shards, ds_shard_id): 29 | folder = osp.join(config.data_path, split, '*', '*.npz') 30 | if folder.startswith('gs://'): 31 | fns = tf.io.gfile.glob(folder) 32 | else: 33 | fns = list(glob.glob(folder)) 34 | fns = np.array_split(fns, num_ds_shards)[ds_shard_id].tolist() 35 | 36 | def read(path): 37 | path = path.decode('utf-8') 38 | if path.startswith('gs://'): 39 | path = io.BytesIO(file_io.FileIO(path, 'rb').read()) 40 | data = np.load(path) 41 | video, actions = data['video'].astype(np.float32), data['actions'].astype(np.int32) 42 | video = 2 * (video / 255.) - 1 43 | return video, actions 44 | 45 | dataset = tf.data.Dataset.from_tensor_slices(fns) 46 | dataset = dataset.map( 47 | lambda item: tf.numpy_function( 48 | read, 49 | [item], 50 | [tf.float32, tf.int32] 51 | ), 52 | num_parallel_calls=tf.data.experimental.AUTOTUNE 53 | ) 54 | dataset = dataset.map( 55 | lambda video, actions: dict(video=video, actions=actions), 56 | num_parallel_calls=tf.data.experimental.AUTOTUNE 57 | ) 58 | 59 | return dataset 60 | 61 | 62 | def load_mnist_npz(config, split, num_ds_shards, ds_shard_id): 63 | folder = osp.join(config.data_path, split, '*', '*.npz') 64 | if folder.startswith('gs://'): 65 | fns = tf.io.gfile.glob(folder) 66 | else: 67 | fns = list(glob.glob(folder)) 68 | fns = np.array_split(fns, num_ds_shards)[ds_shard_id].tolist() 69 | 70 | def read(path): 71 | path = path.decode('utf-8') 72 | if path.startswith('gs://'): 73 | path = io.BytesIO(file_io.FileIO(path, 'rb').read()) 74 | data = np.load(path) 75 | video, actions = data['data'].astype(np.float32), None 76 | video = 2 * (video / 255.) 
- 1 77 | return video 78 | 79 | dataset = tf.data.Dataset.from_tensor_slices(fns) 80 | dataset = dataset.map( 81 | lambda item: tf.numpy_function( 82 | read, 83 | [item], 84 | [tf.float32] 85 | ), 86 | num_parallel_calls=tf.data.experimental.AUTOTUNE 87 | ) 88 | dataset = dataset.map( 89 | lambda video: dict(video=video, actions=None), 90 | num_parallel_calls=tf.data.experimental.AUTOTUNE 91 | ) 92 | 93 | return dataset 94 | 95 | 96 | def load_video(config, split, num_ds_shards, ds_shard_id): 97 | folder = osp.join(config.data_path, split, '*', '*.mp4') 98 | if folder.startswith('gs://'): 99 | fns = tf.io.gfile.glob(folder) 100 | else: 101 | fns = list(glob.glob(folder)) 102 | fns = np.array_split(fns, num_ds_shards)[ds_shard_id].tolist() 103 | 104 | # TODO resizing video 105 | def read(path): 106 | path = path.decode('utf-8') 107 | 108 | video = tfio.experimental.ffmpeg.decode_video(tf.io.read_file(path)).numpy() 109 | start_idx = np.random.randint(0, video.shape[0] - config.seq_len + 1) 110 | video = video[start_idx:start_idx + config.seq_len] 111 | video = 2 * (video / np.array(255., dtype=np.float32)) - 1 112 | 113 | np_path = path[:-3] + 'npz' 114 | if tf.io.gfile.exists(np_path): 115 | if path.startswith('gs://'): 116 | np_path = io.BytesIO(file_io.FileIO(np_path, 'rb').read()) 117 | np_data = np.load(np_path) 118 | actions = np_data['actions'].astype(np.int32) 119 | actions = actions[start_idx:start_idx + config.seq_len] 120 | else: 121 | actions = np.zeros((video.shape[0],), dtype=np.int32) 122 | 123 | return video, actions 124 | 125 | dataset = tf.data.Dataset.from_tensor_slices(fns) 126 | dataset = dataset.map( 127 | lambda item: tf.numpy_function( 128 | read, 129 | [item], 130 | [tf.float32, tf.int32] 131 | ), 132 | num_parallel_calls=tf.data.experimental.AUTOTUNE 133 | ) 134 | dataset = dataset.map( 135 | lambda video, actions: dict(video=video, actions=actions), 136 | num_parallel_calls=tf.data.experimental.AUTOTUNE 137 | ) 138 | 139 | return dataset 140 | 141 | 142 | class Data: 143 | def __init__(self, config, xmap=False): 144 | self.config = config 145 | self.xmap = xmap 146 | print('Dataset:', config.data_path) 147 | 148 | @property 149 | def train_itr_per_epoch(self): 150 | return self.train_size // self.config.batch_size 151 | 152 | @property 153 | def test_itr_per_epoch(self): 154 | return self.test_size // self.config.batch_size 155 | 156 | def create_iterator(self, train, repeat=True, prefetch=True): 157 | if self.xmap: 158 | num_data = jax.device_count() // self.config.num_shards 159 | num_data_local = max(1, jax.local_device_count() // self.config.num_shards) 160 | if num_data >= jax.process_count(): 161 | num_ds_shards = jax.process_count() 162 | ds_shard_id = jax.process_index() 163 | else: 164 | num_ds_shards = num_data 165 | n_proc_per_shard = jax.process_count() // num_data 166 | ds_shard_id = jax.process_index() // n_proc_per_shard 167 | else: 168 | num_data_local = jax.local_device_count() 169 | num_ds_shards = jax.process_count() 170 | ds_shard_id = jax.process_index() 171 | 172 | batch_size = self.config.batch_size // num_ds_shards 173 | split_name = 'train' if train else 'test' 174 | 175 | if not is_tfds_folder(self.config.data_path): 176 | if 'dmlab' in self.config.data_path: 177 | dataset = load_npz(self.config, split_name, num_ds_shards, ds_shard_id) 178 | elif 'mnist' in self.config.data_path: 179 | dataset = load_mnist_npz(self.config, split_name, num_ds_shards, ds_shard_id) 180 | else: 181 | dataset = load_video(self.config, split_name, 
num_ds_shards, ds_shard_id) 182 | else: 183 | seq_len = self.config.seq_len 184 | 185 | def process(features): 186 | video = tf.cast(features['video'], tf.int32) 187 | T = tf.shape(video)[0] 188 | start_idx = tf.random.uniform((), 0, T - seq_len + 1, dtype=tf.int32) 189 | video = tf.identity(video[start_idx:start_idx + seq_len]) 190 | actions = tf.cast(features['actions'], tf.int32) 191 | actions = tf.identity(actions[start_idx:start_idx + seq_len]) 192 | return dict(video=video, actions=actions) 193 | 194 | split = tfds.split_for_jax_process(split_name, process_index=ds_shard_id, 195 | process_count=num_ds_shards) 196 | dataset = tfds.load(osp.basename(self.config.data_path), split=split, 197 | data_dir=osp.dirname(self.config.data_path)) 198 | 199 | # caching only for pre-encoded since raw video will probably 200 | # run OOM on RAM 201 | if self.config.cache: 202 | dataset = dataset.cache() 203 | 204 | options = tf.data.Options() 205 | options.threading.private_threadpool_size = 48 206 | options.threading.max_intra_op_parallelism = 1 207 | dataset = dataset.with_options(options) 208 | dataset = dataset.map(process) 209 | 210 | if repeat: 211 | dataset = dataset.repeat() 212 | if train: 213 | dataset = dataset.shuffle(batch_size * 32, seed=self.config.seed) 214 | 215 | dataset = dataset.batch(batch_size, drop_remainder=True) 216 | dataset = dataset.prefetch(batch_size) 217 | 218 | def prepare_tf_data(xs): 219 | def _prepare(x): 220 | x = x._numpy() 221 | x = x.reshape((num_data_local, -1) + x.shape[1:]) 222 | return x 223 | xs = jax.tree_util.tree_map(_prepare, xs) 224 | return xs 225 | 226 | iterator = map(prepare_tf_data, dataset) 227 | 228 | if prefetch: 229 | iterator = jax_utils.prefetch_to_device(iterator, 2) 230 | 231 | return iterator 232 | -------------------------------------------------------------------------------- /src/fvd.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | from tqdm import tqdm 9 | import numpy as np 10 | import tensorflow.compat.v2 as tf 11 | import tensorflow_gan as tfgan 12 | import tensorflow_hub as hub 13 | 14 | i3d_model = None 15 | 16 | 17 | def fvd_preprocess(videos, target_resolution): 18 | videos = tf.convert_to_tensor(videos * 255.0, dtype=tf.float32) 19 | videos_shape = videos.shape.as_list() 20 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 21 | resized_videos = tf.image.resize(all_frames, size=target_resolution) 22 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 23 | output_videos = tf.reshape(resized_videos, target_shape) 24 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. 
- 1 25 | return scaled_videos 26 | 27 | 28 | def create_id3_embedding(videos): 29 | """Get id3 embeddings.""" 30 | global i3d_model 31 | module_spec = 'https://tfhub.dev/deepmind/i3d-kinetics-400/1' 32 | 33 | if not i3d_model: 34 | base_model = hub.load(module_spec) 35 | input_tensor = base_model.graph.get_tensor_by_name('input_frames:0') 36 | i3d_model = base_model.prune(input_tensor, 'RGB/inception_i3d/Mean:0') 37 | 38 | output = i3d_model(videos) 39 | return output 40 | 41 | 42 | def calculate_fvd(real_activations, generated_activations): 43 | return tfgan.eval.frechet_classifier_distance_from_activations( 44 | real_activations, generated_activations) 45 | 46 | 47 | def embed(videos): 48 | pbar = tqdm(total=sum([v.shape[0] for v in videos])) 49 | embs = [] 50 | for video in videos: 51 | for v in video: 52 | v = v[None] 53 | v = fvd_preprocess(v, (224, 224)).numpy() 54 | emb = create_id3_embedding(tf.convert_to_tensor(v, dtype=tf.float32)) 55 | embs.append(emb) 56 | pbar.update(1) 57 | embs = np.concatenate(embs) 58 | return embs 59 | 60 | 61 | def fvd(video_1, video_2): 62 | if not isinstance(video_1, (tuple, list)): 63 | video_1, video_2 = [video_1], [video_2] 64 | 65 | embed_1 = embed(video_1) 66 | embed_2 = embed(video_2) 67 | result = calculate_fvd(embed_1, embed_2) 68 | return result.numpy() 69 | -------------------------------------------------------------------------------- /src/fvd_mnist.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | from tqdm import tqdm 20 | import numpy as np 21 | import tensorflow.compat.v2 as tf 22 | import tensorflow_gan as tfgan 23 | import tensorflow_hub as hub 24 | 25 | i3d_model = None 26 | 27 | 28 | def fvd_preprocess(videos, target_resolution): 29 | videos = tf.convert_to_tensor(videos * 255.0, dtype=tf.float32) 30 | videos_shape = videos.shape.as_list() 31 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 32 | resized_videos = tf.image.resize(all_frames, size=target_resolution) 33 | resized_videos = tf.repeat(resized_videos, 3, -1) 34 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 35 | output_videos = tf.reshape(resized_videos, target_shape) 36 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. 
- 1 37 | return scaled_videos 38 | 39 | 40 | def create_id3_embedding(videos): 41 | """Get id3 embeddings.""" 42 | global i3d_model 43 | module_spec = 'https://tfhub.dev/deepmind/i3d-kinetics-400/1' 44 | 45 | if not i3d_model: 46 | base_model = hub.load(module_spec) 47 | input_tensor = base_model.graph.get_tensor_by_name('input_frames:0') 48 | i3d_model = base_model.prune(input_tensor, 'RGB/inception_i3d/Mean:0') 49 | 50 | output = i3d_model(videos) 51 | return output 52 | 53 | 54 | def calculate_fvd(real_activations, generated_activations): 55 | return tfgan.eval.frechet_classifier_distance_from_activations( 56 | real_activations, generated_activations) 57 | 58 | 59 | def embed(videos): 60 | pbar = tqdm(total=sum([v.shape[0] for v in videos])) 61 | embs = [] 62 | for video in videos: 63 | for v in video: 64 | v = v[None] 65 | v = fvd_preprocess(v, (224, 224)).numpy() 66 | emb = create_id3_embedding(tf.convert_to_tensor(v, dtype=tf.float32)) 67 | embs.append(emb) 68 | pbar.update(1) 69 | embs = np.concatenate(embs) 70 | return embs 71 | 72 | 73 | def fvd(video_1, video_2): 74 | if not isinstance(video_1, (tuple, list)): 75 | video_1, video_2 = [video_1], [video_2] 76 | 77 | embed_1 = embed(video_1) 78 | embed_2 = embed(video_2) 79 | result = calculate_fvd(embed_1, embed_2) 80 | return result.numpy() 81 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | from tqdm import tqdm 8 | import numpy as np 9 | import jax 10 | import jax.numpy as jnp 11 | import lpips_jax 12 | 13 | 14 | lpips_eval = None 15 | 16 | 17 | def compute_metric(prediction, ground_truth, metric_fn, replicate=True, average_dim=1): 18 | # BTHWC in [0, 1] 19 | assert prediction.shape == ground_truth.shape 20 | B, T = prediction.shape[0], prediction.shape[1] 21 | prediction = prediction.reshape(-1, *prediction.shape[2:]) 22 | ground_truth = ground_truth.reshape(-1, *ground_truth.shape[2:]) 23 | 24 | if replicate: 25 | prediction = np.reshape(prediction, (jax.local_device_count(), -1, *prediction.shape[-3:])) 26 | ground_truth = np.reshape(ground_truth, (jax.local_device_count(), -1, *ground_truth.shape[-3:])) 27 | 28 | metrics = metric_fn(prediction, ground_truth) 29 | metrics = np.reshape(metrics, (B, T)) 30 | 31 | metrics = metrics.mean(axis=average_dim) # B or T depending on dim 32 | 33 | return metrics 34 | 35 | 36 | # all methods below take as input pairs of images 37 | # of shape BCHW. 
They DO NOT reduce batch dimension 38 | # NOTE: Assumes that images are in [0, 1] 39 | 40 | def get_ssim(replicate=True, average_dim=1): 41 | def fn(imgs1, imgs2): 42 | ssim_fn = jax.pmap(ssim) if replicate else ssim 43 | ssim_val = ssim_fn(imgs1, imgs2) 44 | return jax.device_get(ssim_val) 45 | return lambda imgs1, imgs2: compute_metric(imgs1, imgs2, fn, replicate=replicate, average_dim=average_dim) 46 | 47 | 48 | def get_psnr(replicate=True, average_dim=1): 49 | def fn(imgs1, imgs2): 50 | psnr_fn = jax.pmap(psnr) if replicate else psnr 51 | psnr_val = psnr_fn(imgs1, imgs2) 52 | return jax.device_get(psnr_val) 53 | return lambda imgs1, imgs2: compute_metric(imgs1, imgs2, fn, replicate=replicate, average_dim=average_dim) 54 | 55 | 56 | def psnr(a, b, max_val=1.0): 57 | mse = jnp.mean((a - b) ** 2, axis=[-3, -2, -1]) 58 | val = 20 * jnp.log(max_val) / jnp.log(10.0) - np.float32(10 / np.log(10)) * jnp.log(mse) 59 | return val 60 | 61 | 62 | def get_lpips(replicate=True, average_dim=1): 63 | global lpips_eval 64 | if lpips_eval is None: 65 | lpips_eval = lpips_jax.LPIPSEvaluator(net='alexnet', replicate=replicate) 66 | 67 | def fn(imgs1, imgs2): 68 | imgs1 = 2 * imgs1 - 1 69 | imgs2 = 2 * imgs2 - 1 70 | 71 | lpips = lpips_eval(imgs1, imgs2) 72 | lpips = np.reshape(lpips, (-1,)) 73 | return jax.device_get(lpips) 74 | return lambda imgs1, imgs2: compute_metric(imgs1, imgs2, fn, replicate=replicate, average_dim=average_dim) 75 | 76 | 77 | def ssim(img1, img2, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): 78 | ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size, filter_sigma, k1, k2) 79 | return jnp.mean(ssim_per_channel, axis=-1) 80 | 81 | 82 | def _ssim_per_channel(img1, img2, max_val, filter_size, filter_sigma, k1, k2): 83 | kernel = _fspecial_gauss(filter_size, filter_sigma) 84 | kernel = jnp.tile(kernel, [1, 1, img1.shape[-1], 1]) 85 | kernel = jnp.transpose(kernel, [2, 3, 0, 1]) 86 | 87 | compensation = 1.0 88 | 89 | def reducer(x): 90 | x_shape = x.shape 91 | x = jnp.reshape(x, (-1, *x.shape[-3:])) 92 | x = jnp.transpose(x, [0, 3, 1, 2]) 93 | y = jax.lax.conv_general_dilated(x, kernel, [1, 1], 94 | 'VALID', feature_group_count=x.shape[1]) 95 | 96 | y = jnp.reshape(y, [*x_shape[:-3], *y.shape[1:]]) 97 | return y 98 | 99 | luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2) 100 | ssim_val = jnp.mean(luminance * cs, axis=[-3, -2]) 101 | cs = jnp.mean(cs, axis=[-3, -2]) 102 | return ssim_val, cs 103 | 104 | 105 | def _ssim_helper(x, y, reducer, max_val, compensation=1.0, k1=0.01, k2=0.03): 106 | c1 = (k1 * max_val) ** 2 107 | c2 = (k2 * max_val) ** 2 108 | 109 | mean0 = reducer(x) 110 | mean1 = reducer(y) 111 | 112 | num0 = mean0 * mean1 * 2.0 113 | den0 = jnp.square(mean0) + jnp.square(mean1) 114 | luminance = (num0 + c1) / (den0 + c1) 115 | 116 | num1 = reducer(x * y) * 2.0 117 | den1 = reducer(jnp.square(x) + jnp.square(y)) 118 | c2 *= compensation 119 | cs = (num1 - num0 + c2) / (den1 - den0 + c2) 120 | 121 | return luminance, cs 122 | 123 | 124 | def _fspecial_gauss(size, sigma): 125 | coords = jnp.arange(size, dtype=jnp.float32) 126 | coords -= (size - 1.0) / 2.0 127 | 128 | g = jnp.square(coords) 129 | g *= -0.5 / jnp.square(sigma) 130 | 131 | g = jnp.reshape(g, [1, -1]) + jnp.reshape(g, [-1, 1]) 132 | g = jnp.reshape(g, [1, -1]) 133 | g = jax.nn.softmax(g, axis=-1) 134 | return jnp.reshape(g, [size, size, 1, 1]) 135 | 136 | 137 | import tensorflow.compat.v2 as tf 138 | import tensorflow_gan as tfgan 139 | import 
tensorflow_hub as hub 140 | 141 | i3d_model = None 142 | 143 | 144 | # FVD 145 | def fvd_preprocess(videos, target_resolution): 146 | # videos: BTHWC in [0, 1] 147 | videos = tf.convert_to_tensor(videos * 255., dtype=tf.float32) 148 | videos_shape = videos.shape.as_list() 149 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 150 | resized_videos = tf.image.resize(all_frames, size=target_resolution) 151 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 152 | output_videos = tf.reshape(resized_videos, target_shape) 153 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 154 | return scaled_videos 155 | 156 | 157 | def create_id3_embedding(videos): 158 | global i3d_model 159 | module_spec = 'https://tfhub.dev/deepmind/i3d-kinetics-400/1' 160 | 161 | if not i3d_model: 162 | base_model = hub.load(module_spec) 163 | input_tensor = base_model.graph.get_tensor_by_name('input_frames:0') 164 | i3d_model = base_model.prune(input_tensor, 'RGB/inception_i3d/Mean:0') 165 | 166 | output = i3d_model(videos) 167 | return output 168 | 169 | 170 | def calculate_fd(real_activations, generated_activations): 171 | return tfgan.eval.frechet_classifier_distance_from_activations( 172 | real_activations, generated_activations 173 | ).numpy() 174 | 175 | 176 | def fvd(video_1, video_2): 177 | video_1 = fvd_preprocess(video_1, (224, 224)) 178 | video_2 = fvd_preprocess(video_2, (224, 224)) 179 | x = create_id3_embedding(video_1) 180 | y = create_id3_embedding(video_2) 181 | result = calculate_fd(x, y) 182 | return result 183 | 184 | 185 | video_model, video_state = None, None 186 | 187 | 188 | def compute_feats(state, videos, rng): 189 | rng, new_rng = jax.random.split(rng) 190 | variables = {'params': state.params, **state.model_state} 191 | feats = video_model.apply(variables, videos, return_features=True, rngs={'rng': rng}) 192 | return feats, new_rng 193 | 194 | 195 | def create_video_embedding(videos): 196 | BATCH_SIZE = 32 197 | global video_model, video_state 198 | rngs = jax.random.PRNGKey(0) 199 | rngs = jax.random.split(rngs, jax.local_device_count()) 200 | 201 | if video_model is None: 202 | from .models import load_ckpt 203 | path = '/home/TODO/logs/hier_video/dl_maze_video_contr_1657861689.9321504' 204 | video_model, video_state = load_ckpt(path, data_path='dummy') 205 | 206 | pbar = tqdm(total=videos.shape[0] // BATCH_SIZE) 207 | feats = [] 208 | for i in range(0, videos.shape[0], BATCH_SIZE): 209 | inp = videos[i:i + BATCH_SIZE] 210 | inp = np.reshape(inp, (jax.local_device_count(), -1, *inp.shape[1:])) 211 | f, rngs = jax.pmap(compute_feats)(video_state, inp, rngs) 212 | f = jax.device_get(f) 213 | f = np.reshape(f, (-1, *f.shape[2:])) 214 | feats.append(f) 215 | pbar.update(1) 216 | feats = np.concatenate(feats) 217 | return feats 218 | -------------------------------------------------------------------------------- /src/models/S5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/S5/__init__.py -------------------------------------------------------------------------------- /src/models/S5/diagonal_scans.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 Linderman Lab 4 | # To view a copy of this license, visit 5 | # 
https://github.com/lindermanlab/S5/blob/main/LICENSE 6 | # ------------------------------------------------------------------------------ 7 | 8 | import jax 9 | from jax import lax, numpy as np 10 | 11 | 12 | @jax.vmap 13 | def binary_operator(q_i, q_j): 14 | """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. 15 | Args: 16 | q_i: tuple containing A_i and Bu_i at position i (P,), (P,) 17 | q_j: tuple containing A_j and Bu_j at position j (P,), (P,) 18 | Returns: 19 | new element ( A_out, Bu_out ) 20 | """ 21 | A_i, b_i = q_i 22 | A_j, b_j = q_j 23 | return A_j * A_i, A_j * b_i + b_j 24 | 25 | 26 | def apply_ssm_parallel(Lambda_bar, B_bar, C_tilde, input_sequence, x0): 27 | """ Compute the LxH output of discretized SSM given an LxH input. 28 | Args: 29 | Lambda_bar (complex64): discretized diagonal state matrix (P,) 30 | B_bar (complex64): discretized input matrix (P, H) 31 | C_tilde (complex64): output matrix (H, P) 32 | input_sequence (float32): input sequence of features (L, H) 33 | x0 (complex64): initial state (P,) 34 | Returns: 35 | ys (float32): the SSM outputs (S5 layer preactivations) (L, H) 36 | """ 37 | Lambda_elements = Lambda_bar * np.ones((input_sequence.shape[0], 38 | Lambda_bar.shape[0])) 39 | Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence) 40 | Bu_elements = Bu_elements.at[0].add(Lambda_bar * x0) 41 | 42 | _, xs = jax.lax.associative_scan(binary_operator, (Lambda_elements, Bu_elements)) 43 | ys = jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs) 44 | return xs[-1], ys 45 | 46 | 47 | def apply_ssm_sequential(Lambda_bar, B_bar, C_tilde, input_sequence, x0): 48 | """ Compute the LxH output of discretized SSM given an LxH input. 49 | Args: 50 | Lambda_bar (complex64): discretized diagonal state matrix (P,) 51 | B_bar (complex64): discretized input matrix (P, H) 52 | C_tilde (complex64): output matrix (H, P) 53 | input_sequence (float32): input sequence of features (L, H) 54 | x0 (complex64): initial state (P,) 55 | Returns: 56 | ys (float32): the SSM outputs (S5 layer preactivations) (L, H) 57 | """ 58 | def step(x_k_1, u_k): 59 | Bu = B_bar @ u_k 60 | x_k = Lambda_bar * x_k_1 + Bu 61 | y_k = 2*(C_tilde @ x_k).real 62 | return x_k, y_k 63 | 64 | return lax.scan(step, x0, input_sequence) 65 | -------------------------------------------------------------------------------- /src/models/S5/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 Linderman Lab 4 | # To view a copy of this license, visit 5 | # https://github.com/lindermanlab/S5/blob/main/LICENSE 6 | # ------------------------------------------------------------------------------ 7 | 8 | from flax import linen as nn 9 | 10 | 11 | class SequenceLayer(nn.Module): 12 | """Defines a single layer with activation, 13 | layer/batch norm, pre/postnorm, dropout, etc""" 14 | ssm: nn.Module 15 | training: bool 16 | parallel: bool 17 | dropout: float = 0.0 18 | use_norm: bool = True 19 | prenorm: bool = False 20 | per_layer_skip: bool = True 21 | 22 | def setup(self): 23 | self.seq = self.ssm(parallel=self.parallel) 24 | 25 | if self.use_norm: 26 | self.norm = nn.LayerNorm() 27 | 28 | self.drop = nn.Dropout( 29 | self.dropout, 30 | broadcast_dims=[0], 31 | deterministic=not self.training, 32 | ) 33 | 34 | def __call__(self, u, x0): 35 | if self.per_layer_skip: 36 | skip = u 37 | else: 38 | skip = 0 39 | # Apply pre-norm if 
necessary 40 | if self.use_norm: 41 | if self.prenorm: 42 | u = self.norm(u) 43 | x_L, u = self.seq(u, x0) 44 | u = self.drop(u) 45 | u = skip + u 46 | if self.use_norm: 47 | if not self.prenorm: 48 | u = self.norm(u) 49 | return x_L, u 50 | 51 | 52 | class StackedLayers(nn.Module): 53 | """Stacks S5 layers 54 | output: outputs LxbszxH_uxW_uxU sequence of outputs and 55 | a list containing the last state of each layer""" 56 | ssm: nn.Module 57 | n_layers: int 58 | training: bool 59 | parallel: bool 60 | dropout: float = 0.0 61 | use_norm: bool = False 62 | prenorm: bool = False 63 | skip_connections: bool = False 64 | per_layer_skip: bool = True 65 | 66 | def setup(self): 67 | 68 | self.layers = [ 69 | SequenceLayer( 70 | ssm=self.ssm, 71 | dropout=self.dropout, 72 | training=self.training, 73 | parallel=self.parallel, 74 | use_norm=self.use_norm, 75 | prenorm=self.prenorm, 76 | per_layer_skip=self.per_layer_skip 77 | ) 78 | for _ in range(self.n_layers) 79 | ] 80 | 81 | def __call__(self, u, initial_states): 82 | # u is shape (L, bsz, d_in, im_H, im_W) 83 | # x0s is a list of initial arrays each of shape (bsz, d_model, im_H, im_W) 84 | last_states = [] 85 | for i in range(len(self.layers)): 86 | if self.skip_connections: 87 | if i == 3: 88 | layer9_in = u 89 | elif i == 6: 90 | layer12_in = u 91 | 92 | if i == 8: 93 | u = u + layer9_in 94 | elif i == 11: 95 | u = u + layer12_in 96 | 97 | x_L, u = self.layers[i](u, initial_states[i]) 98 | last_states.append(x_L) # keep last state of each layer 99 | return last_states, u 100 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 
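# ----------------------------------------------------------------------------
# Illustrative check (a sketch, not one of this repository's files): the two
# scan functions in src/models/S5/diagonal_scans.py above implement the same
# diagonal recurrence, x_k = Lambda_bar * x_{k-1} + B_bar @ u_k with output
# y_k = 2*Re(C_tilde @ x_k). The toy shapes and random values below are
# assumptions made only for this example.
import jax
import jax.numpy as jnp
from src.models.S5.diagonal_scans import apply_ssm_parallel, apply_ssm_sequential

def check_s5_scan_equivalence(L=16, H=4, P=8, seed=0):
    k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(seed), 5)
    Lambda_bar = jnp.exp(-jax.random.uniform(k1, (P,))) + 0j   # diagonal A with 0 < lambda <= 1
    B_bar = jax.random.normal(k2, (P, H)) + 0j                 # input matrix
    C_tilde = jax.random.normal(k3, (H, P)) + 0j               # output matrix
    us = jax.random.normal(k4, (L, H))                         # input sequence
    x0 = jax.random.normal(k5, (P,)) + 0j                      # initial state
    _, ys_par = apply_ssm_parallel(Lambda_bar, B_bar, C_tilde, us, x0)
    _, ys_seq = apply_ssm_sequential(Lambda_bar, B_bar, C_tilde, us, x0)
    assert jnp.allclose(ys_par, ys_seq, rtol=1e-3, atol=1e-3)
# ----------------------------------------------------------------------------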
12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | import jax.numpy as jnp 20 | 21 | from src.models.sequence_models.VQ.S5 import S5 22 | from src.models.sequence_models.noVQ.convLSTM import CONVLSTM_NOVQ 23 | from src.models.sequence_models.noVQ.convS5 import CONVS5_NOVQ 24 | from src.models.sequence_models.VQ.teco_S5 import TECO_S5 25 | from src.models.sequence_models.VQ.teco_convS5 import TECO_CONVS5 26 | from src.models.sequence_models.VQ.convS5 import CONVS5 27 | from src.models.sequence_models.VQ.teco_transformer import TECO_TRANSFORMER 28 | from src.models.sequence_models.VQ.transformer import TRANSFORMER 29 | from src.models.sequence_models.noVQ.transformer import TRANSFORMER_NOVQ 30 | from .vqgan import VQGAN 31 | from .vae import VAE 32 | from functools import partial 33 | 34 | 35 | def load_vqvae(ckpt_path, need_encode=True): 36 | import jax 37 | import argparse 38 | 39 | model, state = load_ckpt(ckpt_path, training=False, replicate=False) 40 | 41 | def wrap_apply(fn): 42 | variables = {'params': state.params, **state.model_state} 43 | return lambda *args: model.apply(variables, *args, method=fn) 44 | 45 | def no_encode(encodings): 46 | variables = {'params': state.params, **state.model_state} 47 | embeddings = model.apply(variables, encodings, method=model.codebook_lookup) 48 | return embeddings, encodings 49 | 50 | video_encode = jax.jit(wrap_apply(model.encode)) if need_encode else jax.jit(no_encode) 51 | video_decode = jax.jit(wrap_apply(model.decode)) 52 | codebook_lookup = jax.jit(wrap_apply(model.codebook_lookup)) 53 | 54 | return dict(encode=video_encode, decode=video_decode, lookup=codebook_lookup), argparse.Namespace(latent_shape=model.latent_shape, embedding_dim=model.embedding_dim, n_codes=model.n_codes) 55 | 56 | 57 | def load_ckpt(ckpt_path, replicate=True, return_config=False, 58 | default_if_none=dict(), need_encode=None, **kwargs): 59 | import os.path as osp 60 | import pickle 61 | from flax import jax_utils 62 | from flax.training import checkpoints 63 | from ..train_utils import TrainState 64 | 65 | config = pickle.load(open(osp.join(ckpt_path, 'args'), 'rb')) 66 | for k, v in kwargs.items(): 67 | setattr(config, k, v) 68 | for k, v in default_if_none.items(): 69 | if not hasattr(config, k): 70 | print('did not find', k, 'setting default to', v) 71 | setattr(config, k, v) 72 | 73 | model = get_model(config, need_encode=need_encode) 74 | state = checkpoints.restore_checkpoint(osp.join(ckpt_path, 'checkpoints'), None) 75 | if config.model in ['teco_convS5', 'convS5', 'teco_S5', 'S5', 'convS5_noVQ', 'convLSTM_noVQ']: 76 | state = TrainState( 77 | step=state['step'], 78 | params=state['params'], 79 | opt_state=state['opt_state'], 80 | model_state=state['model_state'], 81 | apply_fn=model(parallel=True, training=True).apply, 82 | tx=None 83 | ) 84 | else: 85 | state = TrainState( 86 | step=state['step'], 87 | params=state['params'], 88 | opt_state=state['opt_state'], 89 | model_state=state['model_state'], 90 | apply_fn=model.apply, 91 | tx=None 92 | ) 93 | 94 | assert state is not None, f'No checkpoint found in {ckpt_path}' 95 | 96 | if replicate: 97 | state = jax_utils.replicate(state) 98 | 99 | if return_config: 100 | return model, state, config 101 | else: 102 | return model, state 103 | 104 | 105 | def get_model(config, need_encode=None, xmap=False, **kwargs): 106 | if 
config.model in ['teco_transformer', 'transformer', 107 | 'teco_convS5', 'convS5', 108 | 'S5', 'teco_S5']: 109 | if need_encode is None: 110 | need_encode = not 'encoded' in config.data_path 111 | vq_fns, vqvae = load_vqvae(config.vqvae_ckpt, need_encode) 112 | kwargs.update(vq_fns=vq_fns, vqvae=vqvae) 113 | 114 | kwargs['dtype'] = jnp.float32 115 | 116 | if config.model == 'vqgan': 117 | model = VQGAN(config, **kwargs) 118 | elif config.model == 'autoencoder': 119 | model = VAE(config, **kwargs) 120 | elif config.model == 'transformer': 121 | model = TRANSFORMER(config, **kwargs) 122 | elif config.model == 'teco_transformer': 123 | model = TECO_TRANSFORMER(config, **kwargs) 124 | elif config.model == 'transformer_noVQ': 125 | model = TRANSFORMER_NOVQ(config, **kwargs) 126 | elif config.model == 'convS5': 127 | model = partial(CONVS5, 128 | config=config, **kwargs) 129 | elif config.model == 'convS5_noVQ': 130 | model = partial(CONVS5_NOVQ, 131 | config=config, **kwargs) 132 | elif config.model == 'teco_convS5': 133 | model = partial(TECO_CONVS5, 134 | config=config, **kwargs) 135 | elif config.model == 'S5': 136 | model = partial(S5, 137 | config=config, **kwargs) 138 | elif config.model == 'teco_S5': 139 | model = partial(TECO_S5, 140 | config=config, **kwargs) 141 | elif config.model == 'convLSTM_noVQ': 142 | model = partial(CONVLSTM_NOVQ, 143 | config=config, **kwargs) 144 | else: 145 | raise ValueError(f'Invalid model: {config.model}') 146 | 147 | return model 148 | -------------------------------------------------------------------------------- /src/models/base.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | from typing import Any, Tuple, Optional 9 | from flax import linen as nn 10 | import jax 11 | import jax.numpy as jnp 12 | import optax 13 | 14 | 15 | def constant(value, dtype=jnp.float32): 16 | def init(key, shape, dtype=dtype): 17 | dtype = jax.dtypes.canonicalize_dtype(dtype) 18 | return jnp.full(shape, value, dtype=dtype) 19 | return init 20 | 21 | 22 | class ResNetEncoder(nn.Module): 23 | depths: Tuple 24 | blocks: int 25 | dtype: Optional[Any] = jnp.float32 26 | 27 | @nn.compact 28 | def __call__(self, x): 29 | x = nn.Conv(self.depths[0], [3, 3], dtype=self.dtype)(x) 30 | x = ResNetBlock(self.depths[0], dtype=self.dtype)(x) 31 | for i in range(1, len(self.depths)): 32 | x = nn.avg_pool(x, (2, 2), strides=(2, 2)) 33 | for _ in range(self.blocks): 34 | x = ResNetBlock(self.depths[i], dtype=self.dtype)(x) 35 | return x 36 | 37 | 38 | class ResNetDecoder(nn.Module): 39 | image_size: int 40 | depths: Tuple 41 | blocks: int 42 | out_dim: int 43 | dtype: Optional[Any] = jnp.float32 44 | 45 | @nn.compact 46 | def __call__(self, deter, embeddings=None): 47 | depths = list(reversed(self.depths)) 48 | x = deter 49 | if embeddings is not None: 50 | x = jnp.concatenate([x, embeddings], axis=-1) 51 | 52 | x = nn.Conv(self.depths[0], [3, 3], dtype=self.dtype)(x) 53 | 54 | for i in range(len(depths) - 1): 55 | for _ in range(self.blocks): 56 | x = ResNetBlock(depths[i], dtype=self.dtype)(x) 57 | x = jax.image.resize(x, (x.shape[0], 2 * x.shape[1], 2 * x.shape[2], x.shape[3]), 58 | 
jax.image.ResizeMethod.NEAREST) 59 | for _ in range(self.blocks): 60 | x = ResNetBlock(depths[-1], dtype=self.dtype)(x) 61 | x = nn.LayerNorm(dtype=self.dtype)(x) 62 | 63 | x = nn.Dense(self.out_dim, dtype=self.dtype)(x) 64 | return x 65 | 66 | 67 | class ResNetBlock(nn.Module): 68 | depth: int 69 | dtype: Optional[Any] = jnp.float32 70 | 71 | @nn.compact 72 | def __call__(self, x): 73 | skip = x 74 | if skip.shape[-1] != self.depth: 75 | skip = nn.Conv(self.depth, [1, 1], use_bias=False, 76 | dtype=self.dtype, name='skip')(skip) 77 | 78 | x = nn.elu(nn.GroupNorm(dtype=self.dtype)(x)) 79 | x = nn.Conv(self.depth, [3, 3], dtype=self.dtype)(x) 80 | x = nn.elu(nn.GroupNorm(dtype=self.dtype)(x)) 81 | x = nn.Conv(self.depth, [3, 3], dtype=self.dtype, use_bias=False)(x) 82 | x = AddBias(dtype=self.dtype)(x) 83 | return skip + 0.1 * x 84 | 85 | 86 | class Codebook(nn.Module): 87 | n_codes: int 88 | proj_dim: int 89 | embedding_dim: int 90 | dtype: Optional[Any] = jnp.float32 91 | 92 | @nn.compact 93 | def __call__(self, z, encoding_indices=None): 94 | z = jnp.asarray(z, jnp.float32) 95 | 96 | # z: B...D 97 | codebook = self.param('codebook', nn.initializers.normal(stddev=0.02), 98 | [self.n_codes, self.proj_dim]) 99 | codebook = normalize(codebook) 100 | 101 | embedding_dim = self.embedding_dim 102 | proj_in = nn.Dense(self.proj_dim, use_bias=False) 103 | proj_out = nn.Dense(embedding_dim, use_bias=False) 104 | 105 | if encoding_indices is not None: 106 | z = codebook[(encoding_indices,)] 107 | z = proj_out(z) 108 | return z 109 | 110 | z_proj = normalize(proj_in(z)) 111 | flat_inputs = jnp.reshape(z_proj, (-1, self.proj_dim)) 112 | distances = 2 - 2 * flat_inputs @ codebook.T 113 | 114 | encoding_indices = jnp.argmin(distances, axis=1) 115 | encode_onehot = jax.nn.one_hot(encoding_indices, self.n_codes, dtype=flat_inputs.dtype) 116 | encoding_indices = jnp.reshape(encoding_indices, z.shape[:-1]) 117 | 118 | quantized = codebook[(encoding_indices,)] 119 | 120 | commitment_loss = 0.25 * optax.l2_loss(z_proj, jax.lax.stop_gradient(quantized)).mean() 121 | codebook_loss = optax.l2_loss(jax.lax.stop_gradient(z_proj), quantized).mean() 122 | 123 | quantized_st = jax.lax.stop_gradient(quantized - z_proj) + z_proj 124 | quantized_st = proj_out(quantized_st) 125 | 126 | avg_probs = jnp.mean(encode_onehot, axis=0) 127 | perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10))) 128 | 129 | quantized_st = jnp.asarray(quantized_st, self.dtype) 130 | 131 | return dict(embeddings=quantized_st, encodings=encoding_indices, 132 | commitment_loss=commitment_loss, codebook_loss=codebook_loss, 133 | perplexity=perplexity) 134 | 135 | 136 | class AddBias(nn.Module): 137 | dtype: Any = jnp.float32 138 | param_dtype: Any = jnp.float32 139 | 140 | @nn.compact 141 | def __call__(self, x): 142 | bias = self.param('bias', nn.initializers.zeros, (x.shape[-1],), self.param_dtype) 143 | x += bias 144 | return x 145 | 146 | 147 | def normalize(x): 148 | x = x / jnp.clip(jnp.linalg.norm(x, axis=-1, keepdims=True), a_min=1e-6, a_max=None) 149 | return x 150 | -------------------------------------------------------------------------------- /src/models/convLSTM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/convLSTM/__init__.py -------------------------------------------------------------------------------- /src/models/convLSTM/layers.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | from flax import linen as nn 12 | 13 | 14 | class SequenceLayer(nn.Module): 15 | """Defines a single layer with activation, 16 | layer/batch norm, pre/postnorm, dropout, etc""" 17 | ssm: nn.Module 18 | training: bool 19 | parallel: bool 20 | activation_fn: str = None 21 | dropout: float = 0.0 22 | use_norm: bool = True 23 | prenorm: bool = False 24 | per_layer_skip: bool = True 25 | num_groups: int = 32 26 | squeeze_excite: bool = False 27 | 28 | def setup(self): 29 | self.seq = self.ssm(parallel=self.parallel) 30 | 31 | if self.use_norm: 32 | self.norm = nn.LayerNorm() 33 | 34 | # TODO: Need to figure out dropout strategy, maybe drop whole channels? 35 | self.drop = nn.Dropout( 36 | self.dropout, 37 | broadcast_dims=[0], 38 | deterministic=not self.training, 39 | ) 40 | 41 | def __call__(self, u, x0): 42 | if self.per_layer_skip: 43 | skip = u 44 | else: 45 | skip = 0 46 | # Apply pre-norm if necessary 47 | if self.use_norm: 48 | if self.prenorm: 49 | u = self.norm(u) 50 | x_L, u = self.seq(u, x0) 51 | u = self.drop(u) 52 | u = skip + u 53 | if self.use_norm: 54 | if not self.prenorm: 55 | u = self.norm(u) 56 | return x_L, u 57 | 58 | 59 | class StackedLayers(nn.Module): 60 | """Stacks S5 layers 61 | output: outputs LxbszxH_uxW_uxU sequence of outputs and 62 | a list containing the last state of each layer""" 63 | ssm: nn.Module 64 | n_layers: int 65 | training: bool 66 | parallel: bool 67 | layer_activation: str = "gelu" 68 | dropout: float = 0.0 69 | use_norm: bool = False 70 | prenorm: bool = False 71 | skip_connections: bool = False 72 | per_layer_skip: bool = True 73 | num_groups: int = 32 74 | squeeze_excite: bool = False 75 | 76 | def setup(self): 77 | 78 | self.layers = [ 79 | SequenceLayer( 80 | ssm=self.ssm, 81 | activation_fn=self.layer_activation, 82 | dropout=self.dropout, 83 | training=self.training, 84 | parallel=self.parallel, 85 | use_norm=self.use_norm, 86 | prenorm=self.prenorm, 87 | per_layer_skip=self.per_layer_skip, 88 | num_groups=self.num_groups, 89 | squeeze_excite=self.squeeze_excite 90 | ) 91 | for _ in range(self.n_layers) 92 | ] 93 | 94 | def __call__(self, u, initial_states): 95 | # u is shape (L, bsz, d_in, im_H, im_W) 96 | # x0s is a list of initial arrays each of shape (bsz, d_model, im_H, im_W) 97 | last_states = [] 98 | for i in range(len(self.layers)): 99 | if self.skip_connections: 100 | if i == 3: 101 | layer9_in = u 102 | elif i == 6: 103 | layer12_in = u 104 | 105 | if i == 8: 106 | u = u + layer9_in 107 | elif i == 11: 108 | u = u + layer12_in 109 | 110 | x_L, u = self.layers[i](u, initial_states[i]) 111 | last_states.append(x_L) # keep last state of each layer 112 | return last_states, u 113 | -------------------------------------------------------------------------------- /src/models/convLSTM/scans.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | from jax import lax, numpy as np 12 | from jax.nn import sigmoid 13 | 14 | 15 | # Scan functions 16 | def apply_convLSTM(A, us, x0): 17 | """Compute the output sequence of the convolutional LSTM 18 | given the input sequence sequentially. For testing purposes. 19 | Args: 20 | A (float32): Conv kernel A (k_a,k_a, U+P, 4*P) 21 | us (float32): input sequence of features (L,bsz,H, W, U) 22 | x0 (float32): initial state (bsz, H, W, P) 23 | Returns: 24 | x_L (float32): the last state of the SSM (bsz, H, W, P) 25 | ys (float32): the conv LSTM states (L,bsz, H, W, U) 26 | """ 27 | 28 | def step(x_k_1, u_k): 29 | c_k_1, h_k_1 = x_k_1 30 | 31 | combo = np.concatenate((u_k, h_k_1), axis=-1) # concat along channel dim 32 | 33 | combo_conv = lax.conv_general_dilated(combo, A, (1, 1), 34 | 'SAME', 35 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 36 | # combo_conv = lax.conv(combo, A, (1, 1), 'SAME') 37 | cc_i, cc_f, cc_o, cc_g = np.split(combo_conv, 4, axis=-1) 38 | 39 | i = sigmoid(cc_i) 40 | f = sigmoid(cc_f) 41 | o = sigmoid(cc_o) 42 | g = np.tanh(cc_g) 43 | 44 | c_k = f * c_k_1 + i * g 45 | h_k = o * np.tanh(c_k) 46 | return (c_k, h_k), h_k 47 | return lax.scan(step, x0, us) 48 | -------------------------------------------------------------------------------- /src/models/convLSTM/ssm.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | from functools import partial 13 | from flax import linen as nn 14 | from jax.nn.initializers import he_normal 15 | 16 | from . import scans 17 | 18 | 19 | def initialize_kernel(key, shape): 20 | """For general kernels, e.g. 
C,D, encoding/decoding""" 21 | out_dim, in_dim, k = shape 22 | fan_in = in_dim*(k**2) 23 | 24 | # Note in_axes should be the first by default: 25 | # https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.initializers.variance_scaling.html#jax.nn.initializers.variance_scaling 26 | return he_normal()(key, 27 | (fan_in, out_dim)).reshape(k, 28 | k, 29 | in_dim, 30 | out_dim) 31 | 32 | 33 | class ConvLSTM(nn.Module): 34 | U: int # Number of SSM input and output features 35 | P: int # Number of state features of SSM 36 | k_A: int # A kernel width/height 37 | parallel: bool = False # Cannot compute convLSTM in parallel 38 | # but include this attribute for consistency 39 | # in layers.py 40 | 41 | def setup(self): 42 | # Initialize state to state (A) transition kernel 43 | self.A = self.param("A", 44 | initialize_kernel, 45 | (4 * self.P, self.U+self.P, self.k_A)) 46 | 47 | def __call__(self, input_sequence, x0): 48 | """ 49 | input sequence is shape (L, bsz, U, H, W) 50 | x0 is (bsz, U, H, W) 51 | Returns: 52 | x_L (float32): the last state of the SSM (bsz, P, H, W) 53 | hs (float32): the conv LSTM states (L,bsz, U, H, W) 54 | """ 55 | # For sequential generation (e.g. autoregressive decoding) 56 | return scans.apply_convLSTM(self.A, 57 | input_sequence, 58 | x0) 59 | 60 | 61 | def init_ConvLSTM(U, 62 | P, 63 | k_A): 64 | return partial(ConvLSTM, 65 | U=U, 66 | P=P, 67 | k_A=k_A) 68 | -------------------------------------------------------------------------------- /src/models/convS5/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/convS5/__init__.py -------------------------------------------------------------------------------- /src/models/convS5/diagonal_scans.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | import jax 12 | from jax import lax, numpy as np 13 | from .conv_ops import vmap_conv 14 | 15 | 16 | # Scan functions 17 | @jax.vmap 18 | def conv_binary_operator(q_i, q_j): 19 | """Assumes 1x1 kernels 20 | :inputs q_i an q_j are tuples containing (A_i, BU_i) and (A_j, BU_j) 21 | :inputs A_i and A_j are (P,) 22 | :inputs BU_i and BU_j are bszxH_UxW_UxP 23 | :returns tuple where first entry AA is (P,) 24 | and second entry is bszxH_UxW_UxP""" 25 | 26 | A_i, BU_i = q_i 27 | A_j, BU_j = q_j 28 | 29 | # AA = convolve_1x1_kernels(A_j, A_i) 30 | AA = A_j * A_i 31 | A_jBU_i = np.expand_dims(A_j, (0, 1, 2)) * BU_i 32 | 33 | return AA, A_jBU_i + BU_j 34 | 35 | 36 | def apply_convSSM_parallel(A, B, C, us, x0): 37 | """Compute the output sequence of the convolutional SSM 38 | given the input sequence using a parallel scan. 39 | Computes x_k = A * x_{k-1} + B * u_k 40 | y_k = C * x_k + D * U_k 41 | where * is a convolution operator. 
42 | Args: 43 | A (complex64): Conv kernel A (P,) 44 | B (complex64): input-to-state conv kernel (k_B,k_B,U,P) 45 | C (complex64): state-to-output conv kernel (k_c,k_c, P, U) 46 | us (float32): input sequence of features (L,bsz,H, W, U) 47 | x0 (complex64): initial state (bsz, H, W, P) 48 | Returns: 49 | x_L (complex64): the last state of the SSM (bsz, H, W, P) 50 | ys (float32): the conv SSM outputs (L,bsz, H, W, U) 51 | """ 52 | L = us.shape[0] 53 | As = A * np.ones((L,)+A.shape) 54 | Bus = vmap_conv(B, np.complex64(us)) 55 | Bus = Bus.at[0].add(np.expand_dims(A, (0, 1, 2)) * x0) 56 | 57 | _, xs = lax.associative_scan(conv_binary_operator, (As, Bus)) 58 | 59 | ys = 2 * vmap_conv(C, xs).real 60 | 61 | return xs[-1], ys 62 | 63 | 64 | def apply_convSSM_sequential(A, B, C, us, x0): 65 | """Compute the output sequence of the convolutional SSM 66 | given the input sequence sequentially. For testing purposes. 67 | Args: 68 | A (complex64): Conv kernel A (P,) 69 | B (complex64): input-to-state conv kernel (k_B,k_B,U,P) 70 | C (complex64): state-to-output conv kernel (k_c,k_c, P, U) 71 | us (float32): input sequence of features (L,bsz,H, W, U) 72 | x0 (complex64): initial state (bsz, H, W, P) 73 | Returns: 74 | x_L (complex64): the last state of the SSM (bsz, H, W, P) 75 | ys (float32): the conv SSM outputs (L,bsz, H, W, U) 76 | """ 77 | def step(x_k_1, u_k): 78 | Bu = lax.conv_general_dilated(np.complex64(u_k), B, (1, 1), 79 | 'SAME', 80 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')) 81 | x_k = np.expand_dims(A, (0, 1, 2)) * x_k_1 + Bu 82 | y_k = 2 * lax.conv_general_dilated(x_k, C, (1, 1), 83 | 'SAME', 84 | dimension_numbers=('NHWC', 'HWIO', 'NHWC')).real 85 | return x_k, y_k 86 | return lax.scan(step, np.complex64(x0), us) 87 | -------------------------------------------------------------------------------- /src/models/convS5/layers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | from flax import linen as nn 13 | 14 | 15 | class SequenceLayer(nn.Module): 16 | """Defines a single layer with activation, 17 | layer/batch norm, pre/postnorm, dropout, etc""" 18 | ssm: nn.Module 19 | training: bool 20 | parallel: bool 21 | activation_fn: str = "gelu" 22 | dropout: float = 0.0 23 | use_norm: bool = True 24 | prenorm: bool = False 25 | per_layer_skip: bool = True 26 | num_groups: int = 32 27 | squeeze_excite: bool = False 28 | 29 | def setup(self): 30 | if self.activation_fn in ["relu"]: 31 | self.activation = nn.relu 32 | elif self.activation_fn in ["gelu"]: 33 | self.activation = nn.gelu 34 | elif self.activation_fn in ["swish"]: 35 | self.activation = nn.swish 36 | elif self.activation_fn in ["elu"]: 37 | self.activation = nn.elu 38 | 39 | self.seq = self.ssm(parallel=self.parallel, 40 | activation=self.activation, 41 | num_groups=self.num_groups, 42 | squeeze_excite=self.squeeze_excite) 43 | 44 | if self.use_norm: 45 | self.norm = nn.LayerNorm() 46 | 47 | # TODO: Need to figure out dropout strategy, maybe drop whole channels? 
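# Note on the Dropout call below: `broadcast_dims=[0]` samples a single
# dropout mask and broadcasts it along the leading sequence axis (axis 0 is
# the time dimension L here), so the same activations are dropped at every
# timestep, and `deterministic=not self.training` disables dropout at eval.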
48 | self.drop = nn.Dropout( 49 | self.dropout, 50 | broadcast_dims=[0], 51 | deterministic=not self.training, 52 | ) 53 | 54 | def __call__(self, u, x0): 55 | if self.per_layer_skip: 56 | skip = u 57 | else: 58 | skip = 0 59 | # Apply pre-norm if necessary 60 | if self.use_norm: 61 | if self.prenorm: 62 | u = self.norm(u) 63 | x_L, u = self.seq(u, x0) 64 | u = self.drop(u) 65 | u = skip + u 66 | if self.use_norm: 67 | if not self.prenorm: 68 | u = self.norm(u) 69 | return x_L, u 70 | 71 | 72 | class StackedLayers(nn.Module): 73 | """Stacks S5 layers 74 | output: outputs LxbszxH_uxW_uxU sequence of outputs and 75 | a list containing the last state of each layer""" 76 | ssm: nn.Module 77 | n_layers: int 78 | training: bool 79 | parallel: bool 80 | layer_activation: str = "gelu" 81 | dropout: float = 0.0 82 | use_norm: bool = False 83 | prenorm: bool = False 84 | skip_connections: bool = False 85 | per_layer_skip: bool = True 86 | num_groups: int = 32 87 | squeeze_excite: bool = False 88 | 89 | def setup(self): 90 | 91 | self.layers = [ 92 | SequenceLayer( 93 | ssm=self.ssm, 94 | activation_fn=self.layer_activation, 95 | dropout=self.dropout, 96 | training=self.training, 97 | parallel=self.parallel, 98 | use_norm=self.use_norm, 99 | prenorm=self.prenorm, 100 | per_layer_skip=self.per_layer_skip, 101 | num_groups=self.num_groups, 102 | squeeze_excite=self.squeeze_excite 103 | ) 104 | for _ in range(self.n_layers) 105 | ] 106 | 107 | def __call__(self, u, initial_states): 108 | # u is shape (L, bsz, d_in, im_H, im_W) 109 | # x0s is a list of initial arrays each of shape (bsz, d_model, im_H, im_W) 110 | last_states = [] 111 | for i in range(len(self.layers)): 112 | if self.skip_connections: 113 | if i == 3: 114 | layer9_in = u 115 | elif i == 6: 116 | layer12_in = u 117 | 118 | if i == 8: 119 | u = u + layer9_in 120 | elif i == 11: 121 | u = u + layer12_in 122 | 123 | x_L, u = self.layers[i](u, initial_states[i]) 124 | last_states.append(x_L) # keep last state of each layer 125 | return last_states, u 126 | -------------------------------------------------------------------------------- /src/models/sampling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/sampling/__init__.py -------------------------------------------------------------------------------- /src/models/sampling/sample_convSSM.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 
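# ----------------------------------------------------------------------------
# Illustrative check (a sketch, not one of this repository's files): with a
# pointwise (1x1) diagonal state kernel A, the ConvSSM recurrence
# x_k = A * x_{k-1} + B * u_k with output y_k = 2*Re(C * x_k), where * is a
# convolution, can be computed either with the associative parallel scan or
# step by step, as in src/models/convS5/diagonal_scans.py above. Toy shapes
# and random values below are assumptions made only for this example.
import jax
import jax.numpy as jnp
from src.models.convS5.diagonal_scans import (apply_convSSM_parallel,
                                              apply_convSSM_sequential)

def check_convssm_scan_equivalence(L=8, bsz=2, H=8, W=8, U=3, P=4, k=3, seed=0):
    ks = jax.random.split(jax.random.PRNGKey(seed), 5)
    A = jnp.exp(-jax.random.uniform(ks[0], (P,))) + 0j      # 1x1 diagonal state kernel
    B = 0.1 * jax.random.normal(ks[1], (k, k, U, P)) + 0j   # input-to-state conv kernel
    C = 0.1 * jax.random.normal(ks[2], (k, k, P, U)) + 0j   # state-to-output conv kernel
    us = jax.random.normal(ks[3], (L, bsz, H, W, U))        # input frames
    x0 = jax.random.normal(ks[4], (bsz, H, W, P)) + 0j      # initial state
    _, ys_par = apply_convSSM_parallel(A, B, C, us, x0)
    _, ys_seq = apply_convSSM_sequential(A, B, C, us, x0)
    assert jnp.allclose(ys_par, ys_seq, rtol=1e-3, atol=1e-3)
# ----------------------------------------------------------------------------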
5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | import jax 15 | import jax.numpy as jnp 16 | 17 | 18 | def _observe(state, video, actions, model_par): 19 | variables = {'params': state.params, **state.model_state} 20 | outs = model_par.apply(variables, 21 | video, actions, 22 | method=model_par.condition) 23 | 24 | _, z_embeddings, _, _, last_states = outs 25 | z_T = z_embeddings[:, -1:] 26 | return z_T, last_states 27 | 28 | 29 | def _imagine(state, z_embedding, initial_states, action, rng, model_seq, causal_masking): 30 | variables = {'params': state.params, **state.model_state} 31 | rng, new_rng = jax.random.split(rng) 32 | out, _ = model_seq.apply(variables, 33 | z_embedding, initial_states, action, causal_masking, 34 | method=model_seq.sample_timestep, 35 | rngs={'sample': rng}, 36 | mutable=["prime"]) 37 | 38 | z_t, _, recon, last_states = out 39 | return recon, z_t, last_states, new_rng 40 | 41 | 42 | def _decode(x, model_par): 43 | return model_par.vq_fns['decode'](x[:, None])[:, 0] 44 | 45 | 46 | def _encode(x, model_par): 47 | return model_par.vq_fns['encode'](x) 48 | 49 | 50 | def sample(model_par, state, video, actions, action_conditioned, open_loop_ctx, 51 | p_observe, p_imagine, p_encode, p_decode, 52 | seed=0, state_spec=None): 53 | 54 | use_xmap = state_spec is not None 55 | 56 | if use_xmap: 57 | num_local_data = max(1, jax.local_device_count() // model_par.config.num_shards) 58 | else: 59 | num_local_data = jax.local_device_count() 60 | rngs = jax.random.PRNGKey(seed) 61 | rngs = jax.random.split(rngs, num_local_data) 62 | 63 | assert video.shape[0] == num_local_data, f'{video.shape}, {num_local_data}' 64 | assert model_par.config.n_cond <= model_par.config.open_loop_ctx 65 | 66 | if not model_par.config.use_actions: 67 | if actions is None: 68 | actions = jnp.zeros(video.shape[:3], dtype=jnp.int32) 69 | else: 70 | actions = jnp.zeros_like(actions) 71 | else: 72 | if not action_conditioned: 73 | actions = model_par.config.action_mask_id * np.ones(actions.shape, dtype=jnp.int32) 74 | 75 | if video.shape[0] < jax.local_device_count(): 76 | devices = jax.local_devices()[:video.shape[0]] 77 | else: 78 | devices = None 79 | 80 | num_input_frames = open_loop_ctx 81 | _, encodings = p_encode(video[:, :, :num_input_frames]) 82 | 83 | z, last_states = p_observe(state, encodings, actions[:, :, :num_input_frames]) 84 | 85 | recon = [encodings[:, :, i] for i in range(num_input_frames)] 86 | dummy_encoding = jnp.zeros_like(recon[0]) 87 | itr = list(range(num_input_frames, model_par.config.eval_seq_len)) 88 | for i in tqdm(itr): 89 | if i >= model_par.config.seq_len: 90 | # TODO 91 | pass 92 | else: 93 | act = actions[:, :, i:i+1] 94 | 95 | r, z, last_states, rngs = p_imagine(state, z, last_states, act, rngs) 96 | z = jnp.expand_dims(z, 2) 97 | recon.append(r) 98 | encodings = jnp.stack(recon, axis=2) 99 | 100 | def decode(samples): 101 | # samples: NBTHW 102 | N, B, T = samples.shape[:3] 103 | if N < jax.local_device_count(): 104 | devices = jax.local_devices()[:N] 105 | else: 106 | devices = None 107 | 108 | samples = jax.device_get(samples) 109 | samples = np.reshape(samples, (-1, *samples.shape[3:])) 110 | 111 | recons = [] 112 | for i in list(range(0, N * B * T, 64)): 113 | inp = samples[i:i + 64] 114 | inp = np.reshape(inp, (N, -1, 
*inp.shape[1:])) 115 | # recon = jax.pmap(_decode, devices=devices)(inp) 116 | recon = p_decode(inp) 117 | recon = jax.device_get(recon) 118 | recon = np.reshape(recon, (-1, *recon.shape[2:])) 119 | recons.append(recon) 120 | recons = np.concatenate(recons, axis=0) 121 | recons = np.reshape(recons, (N, B, T, *recons.shape[1:])) 122 | recons = np.clip(recons, -1, 1) 123 | return recons # BTHWC 124 | samples = decode(encodings) 125 | 126 | if video.shape[3] == 16: 127 | video = decode(video) 128 | 129 | return samples, video 130 | -------------------------------------------------------------------------------- /src/models/sampling/sample_convSSM_noVQ.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | import numpy as np 12 | from tqdm import tqdm 13 | import jax 14 | import jax.numpy as jnp 15 | 16 | 17 | def _observe(state, video, actions, model_par): 18 | variables = {'params': state.params, **state.model_state} 19 | outs = model_par.apply(variables, 20 | video, actions, 21 | method=model_par.condition) 22 | 23 | _, z_embeddings, _, _, last_states = outs 24 | z_T = z_embeddings[:, -1:] 25 | return z_T, last_states 26 | 27 | 28 | def _imagine(state, z_embedding, initial_states, action, rng, model_seq): 29 | variables = {'params': state.params, **state.model_state} 30 | rng, new_rng = jax.random.split(rng) 31 | out, _ = model_seq.apply(variables, 32 | z_embedding, initial_states, action, 33 | method=model_seq.sample_timestep, 34 | mutable=["prime"]) 35 | 36 | z_t, _, recon, last_states = out 37 | return recon, z_t, last_states, new_rng 38 | 39 | 40 | def sample(model_par, state, video, actions, action_conditioned, open_loop_ctx, 41 | p_observe, p_imagine, p_encode, p_decode, eval_seq_len, 42 | seed=0, state_spec=None): 43 | 44 | use_xmap = state_spec is not None 45 | 46 | if use_xmap: 47 | num_local_data = max(1, jax.local_device_count() // model_par.config.num_shards) 48 | else: 49 | num_local_data = jax.local_device_count() 50 | rngs = jax.random.PRNGKey(seed) 51 | rngs = jax.random.split(rngs, num_local_data) 52 | 53 | assert video.shape[0] == num_local_data, f'{video.shape}, {num_local_data}' 54 | assert model_par.config.n_cond <= model_par.config.open_loop_ctx 55 | 56 | if not model_par.config.use_actions: 57 | if actions is None: 58 | actions = jnp.zeros(video.shape[:3], dtype=jnp.int32) 59 | else: 60 | actions = jnp.zeros_like(actions) 61 | else: 62 | if not action_conditioned: 63 | actions = model_par.config.action_mask_id * np.ones(actions.shape, dtype=jnp.int32) 64 | 65 | if video.shape[0] < jax.local_device_count(): 66 | devices = jax.local_devices()[:video.shape[0]] 67 | else: 68 | devices = None 69 | 70 | num_input_frames = open_loop_ctx 71 | encodings = video[:, :, :num_input_frames] 72 | 73 | z, last_states = p_observe(state, encodings, actions[:, :, :num_input_frames]) 74 | 75 | recon = [encodings[:, :, i] for i in range(num_input_frames)] 76 | dummy_encoding = jnp.zeros_like(recon[0]) 77 | itr = list(range(num_input_frames, eval_seq_len)) 78 | for i in tqdm(itr): 79 | if i 
>= model_par.config.seq_len: 80 | # TODO 81 | act = actions[:, :, i:i+1] 82 | else: 83 | act = actions[:, :, i:i+1] 84 | 85 | r, z, last_states, rngs = p_imagine(state, z, last_states, act, rngs) 86 | recon.append(r[:, :, 0]) 87 | encodings = jnp.stack(recon, axis=2) 88 | 89 | def decode(samples): 90 | # samples: NBTHW 91 | N, B, T = samples.shape[:3] 92 | if N < jax.local_device_count(): 93 | devices = jax.local_devices()[:N] 94 | else: 95 | devices = None 96 | 97 | samples = jax.device_get(samples) 98 | samples = np.reshape(samples, (-1, *samples.shape[3:])) 99 | 100 | recons = [] 101 | for i in list(range(0, N * B * T, 64)): 102 | inp = samples[i:i + 64] 103 | inp = np.reshape(inp, (N, -1, *inp.shape[1:])) 104 | # recon = jax.pmap(_decode, devices=devices)(inp) 105 | recon = inp 106 | recon = jax.device_get(recon) 107 | recon = np.reshape(recon, (-1, *recon.shape[2:])) 108 | recons.append(recon) 109 | recons = np.concatenate(recons, axis=0) 110 | recons = np.reshape(recons, (N, B, T, *recons.shape[1:])) 111 | recons = np.clip(recons, -1, 1) 112 | return recons # BTHWC 113 | samples = decode(encodings) 114 | 115 | if video.shape[3] == 16: 116 | video = decode(video) 117 | 118 | return samples, video[:, :, :eval_seq_len] 119 | -------------------------------------------------------------------------------- /src/models/sampling/sample_transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 
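# ----------------------------------------------------------------------------
# The p_observe / p_imagine / p_encode / p_decode callables passed into the
# sample() functions above are constructed by the evaluation code, which is
# not part of this listing. One plausible wiring, shown purely as an
# assumption (make_sampling_fns is a hypothetical helper name), binds the
# Flax module objects with functools.partial and maps over the device axis:
from functools import partial
import jax
from src.models.sampling.sample_convSSM import _observe, _imagine, _encode, _decode

def make_sampling_fns(model_par, model_seq, causal_masking=False):
    # Every remaining argument of the wrapped functions carries a leading
    # device axis, so jax.pmap's default in_axes=0 applies throughout.
    p_observe = jax.pmap(partial(_observe, model_par=model_par))
    p_imagine = jax.pmap(partial(_imagine, model_seq=model_seq,
                                 causal_masking=causal_masking))
    p_encode = jax.pmap(partial(_encode, model_par=model_par))
    p_decode = jax.pmap(partial(_decode, model_par=model_par))
    return p_observe, p_imagine, p_encode, p_decode
# ----------------------------------------------------------------------------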
12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | import numpy as np 19 | from tqdm import tqdm 20 | import jax 21 | import jax.numpy as jnp 22 | 23 | 24 | def _observe(state, encodings, model): 25 | variables = {'params': state.params, **state.model_state} 26 | cond, out = model.apply(variables, encodings, 27 | method=model.encode) 28 | return cond, out['embeddings'] 29 | 30 | 31 | def _imagine(state, z_embeddings, actions, cond, t, rng, model, causal_masking): 32 | variables = {'params': state.params, **state.model_state} 33 | rng, new_rng = jax.random.split(rng) 34 | z, recon = model.apply(variables, z_embeddings, actions, cond, t, causal_masking, 35 | method=model.sample_timestep, 36 | rngs={'sample': rng}) 37 | return recon, z, new_rng 38 | 39 | 40 | def _decode(x, model): 41 | return model.vq_fns['decode'](x[:, None])[:, 0] 42 | 43 | 44 | def _encode(x, model): 45 | return model.vq_fns['encode'](x) 46 | 47 | 48 | def sample(model, state, video, actions, action_conditioned, open_loop_ctx, 49 | p_observe, p_imagine, p_encode, p_decode, seed=0, state_spec=None): 50 | 51 | use_xmap = state_spec is not None 52 | 53 | if use_xmap: 54 | num_local_data = max(1, jax.local_device_count() // model.config.num_shards) 55 | else: 56 | num_local_data = jax.local_device_count() 57 | rngs = jax.random.PRNGKey(seed) 58 | rngs = jax.random.split(rngs, num_local_data) 59 | 60 | assert video.shape[0] == num_local_data, f'{video.shape}, {num_local_data}' 61 | assert model.config.n_cond <= open_loop_ctx 62 | 63 | if not model.config.use_actions: 64 | if actions is None: 65 | actions = jnp.zeros(video.shape[:3], dtype=jnp.int32) 66 | else: 67 | actions = jnp.zeros_like(actions) 68 | else: 69 | if not action_conditioned: 70 | actions = model.config.action_mask_id * np.ones(actions.shape, dtype=jnp.int32) 71 | 72 | if video.shape[0] < jax.local_device_count(): 73 | devices = jax.local_devices()[:video.shape[0]] 74 | else: 75 | devices = None 76 | # _, encodings = jax.pmap(model.vq_fns['encode'], devices=devices)(video) 77 | _, encodings = p_encode(video) 78 | 79 | # if use_xmap: 80 | # p_observe = xmap(_observe, in_axes=(state_spec, ('data', ...)), 81 | # out_axes=('data', ...), 82 | # axis_resources={'data': 'dp', 'model': 'mp'}) 83 | # p_imagine = xmap(_imagine, in_axes=(state_spec, ('data', ...), ('data', ...), 84 | # ('data', ...), (...,), ('data', ...)), 85 | # out_axes=('data', ...), 86 | # axis_resources={'data': 'dp', 'model': 'mp'}) 87 | # else: 88 | # p_observe = jax.pmap(_observe) 89 | # p_imagine = jax.pmap(_imagine, in_axes=(0, 0, 0, 0, None, 0)) 90 | 91 | cond, zs = p_observe(state, encodings) 92 | 93 | if model.config.model in ["teco_transformer"]: 94 | sub = model.config.n_cond 95 | elif model.config.model in ["transformer", "performer"]: 96 | sub = 0 97 | else: 98 | raise NotImplementedError("The model type is not supported.") 99 | 100 | zs = zs[:, :, :model.config.seq_len - sub] 101 | 102 | recon = [encodings[:, :, i] for i in range(open_loop_ctx)] 103 | dummy_encoding = jnp.zeros_like(recon[0]) 104 | itr = list(range(open_loop_ctx, model.config.eval_seq_len)) 105 | for i in tqdm(itr): 106 | if i >= model.config.seq_len: 107 | encodings = jnp.stack([*recon[-model.config.seq_len + 1:], dummy_encoding], axis=2) 108 | cond, zs = p_observe(state, encodings) 109 | act = actions[:, :, i - model.config.seq_len + 
1:i + 1] 110 | i = model.config.seq_len - 1 111 | else: 112 | act = actions[:, :, :model.config.seq_len] 113 | 114 | r, z, rngs = p_imagine(state, zs, act, cond, i, rngs) 115 | zs = zs.at[:, :, i - sub].set(z) 116 | recon.append(r) 117 | encodings = jnp.stack(recon, axis=2) 118 | 119 | def decode(samples): 120 | # samples: NBTHW 121 | N, B, T = samples.shape[:3] 122 | if N < jax.local_device_count(): 123 | devices = jax.local_devices()[:N] 124 | else: 125 | devices = None 126 | 127 | samples = jax.device_get(samples) 128 | samples = np.reshape(samples, (-1, *samples.shape[3:])) 129 | 130 | recons = [] 131 | for i in list(range(0, N * B * T, 64)): 132 | inp = samples[i:i + 64] 133 | inp = np.reshape(inp, (N, -1, *inp.shape[1:])) 134 | # recon = jax.pmap(_decode, devices=devices)(inp) 135 | recon = p_decode(inp) 136 | recon = jax.device_get(recon) 137 | recon = np.reshape(recon, (-1, *recon.shape[2:])) 138 | recons.append(recon) 139 | recons = np.concatenate(recons, axis=0) 140 | recons = np.reshape(recons, (N, B, T, *recons.shape[1:])) 141 | recons = np.clip(recons, -1, 1) 142 | return recons # BTHWC 143 | samples = decode(encodings) 144 | 145 | if video.shape[3] == 16: 146 | video = decode(video) 147 | 148 | return samples, video 149 | -------------------------------------------------------------------------------- /src/models/sampling/sample_transformer_noVQ.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | import numpy as np 13 | from tqdm import tqdm 14 | import jax 15 | import jax.numpy as jnp 16 | 17 | 18 | def _observe(state, encodings, model): 19 | variables = {'params': state.params, **state.model_state} 20 | cond, out = model.apply(variables, encodings, 21 | method=model.encode) 22 | return cond, out['embeddings'] 23 | 24 | 25 | def _imagine(state, z_embeddings, actions, cond, t, rng, model): 26 | variables = {'params': state.params, **state.model_state} 27 | rng, new_rng = jax.random.split(rng) 28 | z, recon = model.apply(variables, z_embeddings, actions, cond, t, 29 | method=model.sample_timestep, 30 | rngs={'sample': rng}) 31 | return recon, z, new_rng 32 | 33 | 34 | def sample(model, state, video, actions, action_conditioned, open_loop_ctx, 35 | p_observe, p_imagine, p_encode, p_decode, eval_seq_len, seed=0, state_spec=None): 36 | 37 | use_xmap = state_spec is not None 38 | 39 | if use_xmap: 40 | num_local_data = max(1, jax.local_device_count() // model.config.num_shards) 41 | else: 42 | num_local_data = jax.local_device_count() 43 | rngs = jax.random.PRNGKey(seed) 44 | rngs = jax.random.split(rngs, num_local_data) 45 | 46 | assert video.shape[0] == num_local_data, f'{video.shape}, {num_local_data}' 47 | assert model.config.n_cond <= open_loop_ctx 48 | 49 | if not model.config.use_actions: 50 | if actions is None: 51 | actions = jnp.zeros(video.shape[:3], dtype=jnp.int32) 52 | else: 53 | actions = jnp.zeros_like(actions) 54 | else: 55 | if not action_conditioned: 56 | actions = model.config.action_mask_id * np.ones(actions.shape, dtype=jnp.int32) 57 | 58 | if 
video.shape[0] < jax.local_device_count(): 59 | devices = jax.local_devices()[:video.shape[0]] 60 | else: 61 | devices = None 62 | # _, encodings = jax.pmap(model.vq_fns['encode'], devices=devices)(video) 63 | encodings = video 64 | 65 | # if use_xmap: 66 | # p_observe = xmap(_observe, in_axes=(state_spec, ('data', ...)), 67 | # out_axes=('data', ...), 68 | # axis_resources={'data': 'dp', 'model': 'mp'}) 69 | # p_imagine = xmap(_imagine, in_axes=(state_spec, ('data', ...), ('data', ...), 70 | # ('data', ...), (...,), ('data', ...)), 71 | # out_axes=('data', ...), 72 | # axis_resources={'data': 'dp', 'model': 'mp'}) 73 | # else: 74 | # p_observe = jax.pmap(_observe) 75 | # p_imagine = jax.pmap(_imagine, in_axes=(0, 0, 0, 0, None, 0)) 76 | 77 | cond, zs = p_observe(state, encodings) 78 | sub = 0 79 | 80 | zs = zs[:, :, :model.config.seq_len - sub] 81 | 82 | recon = [encodings[:, :, i] for i in range(open_loop_ctx)] 83 | dummy_encoding = jnp.zeros_like(recon[0]) 84 | itr = list(range(open_loop_ctx, eval_seq_len)) 85 | for i in tqdm(itr): 86 | if i >= model.config.seq_len: 87 | encodings = jnp.stack([*recon[-model.config.seq_len + 1:], dummy_encoding], axis=2) 88 | cond, zs = p_observe(state, encodings) 89 | act = actions[:, :, i - model.config.seq_len + 1:i + 1] 90 | i = model.config.seq_len - 1 91 | else: 92 | act = actions[:, :, :model.config.seq_len] 93 | 94 | r, z, rngs = p_imagine(state, zs, act, cond, i, rngs) 95 | zs = zs.at[:, :, i - sub].set(z) 96 | recon.append(r) 97 | encodings = jnp.stack(recon, axis=2) 98 | 99 | def decode(samples): 100 | # samples: NBTHW 101 | N, B, T = samples.shape[:3] 102 | if N < jax.local_device_count(): 103 | devices = jax.local_devices()[:N] 104 | else: 105 | devices = None 106 | 107 | samples = jax.device_get(samples) 108 | samples = np.reshape(samples, (-1, *samples.shape[3:])) 109 | 110 | recons = [] 111 | for i in list(range(0, N * B * T, 64)): 112 | inp = samples[i:i + 64] 113 | inp = np.reshape(inp, (N, -1, *inp.shape[1:])) 114 | # recon = jax.pmap(_decode, devices=devices)(inp) 115 | recon = inp 116 | recon = jax.device_get(recon) 117 | recon = np.reshape(recon, (-1, *recon.shape[2:])) 118 | recons.append(recon) 119 | recons = np.concatenate(recons, axis=0) 120 | recons = np.reshape(recons, (N, B, T, *recons.shape[1:])) 121 | recons = np.clip(recons, -1, 1) 122 | return recons # BTHWC 123 | samples = decode(encodings) 124 | 125 | if video.shape[3] == 16: 126 | video = decode(video) 127 | 128 | return samples, video[:, :, :eval_seq_len] 129 | -------------------------------------------------------------------------------- /src/models/sequence_models/VQ/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/sequence_models/VQ/__init__.py -------------------------------------------------------------------------------- /src/models/sequence_models/VQ/transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | from typing import Optional, Any, Dict, Callable 20 | import optax 21 | import numpy as np 22 | import flax.linen as nn 23 | import jax 24 | import jax.numpy as jnp 25 | from jax import random 26 | 27 | from src.models.transformer.transformer import Transformer 28 | from src.models.base import ResNetEncoder, ResNetDecoder 29 | 30 | 31 | class TRANSFORMER(nn.Module): 32 | config: Any 33 | vq_fns: Dict[str, Callable] 34 | vqvae: Any 35 | dtype: Optional[Any] = jnp.float32 36 | 37 | @property 38 | def metrics(self): 39 | metrics = ['loss'] 40 | return metrics 41 | 42 | def setup(self): 43 | config = self.config 44 | 45 | self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype) 46 | 47 | # Posterior 48 | self.sos_post = self.param('sos_post', nn.initializers.normal(stddev=0.02), 49 | (*self.vqvae.latent_shape, self.vqvae.embedding_dim), jnp.float32) 50 | self.encoder = ResNetEncoder(**config.encoder, dtype=self.dtype) 51 | ds = 2 ** (len(config.encoder['depths']) - 1) 52 | self.z_shape = tuple([d // ds for d in self.vqvae.latent_shape]) 53 | 54 | # Temporal Transformer 55 | z_kernel = [config.z_ds, config.z_ds] 56 | self.z_tfm_shape = tuple([d // config.z_ds for d in self.z_shape]) 57 | self.z_proj = nn.Conv(config.z_tfm_kwargs['embed_dim'], z_kernel, 58 | strides=z_kernel, use_bias=False, padding='VALID', dtype=self.dtype) 59 | 60 | self.sos = self.param('sos', nn.initializers.normal(stddev=0.02), 61 | (*self.z_tfm_shape, config.z_tfm_kwargs['embed_dim'],), jnp.float32) 62 | self.z_tfm = Transformer( 63 | **config.z_tfm_kwargs, pos_embed_type='sinusoidal', 64 | shape=(config.seq_len, *self.z_tfm_shape), 65 | dtype=self.dtype 66 | ) 67 | self.z_unproj = nn.ConvTranspose(config.embedding_dim, z_kernel, strides=z_kernel, 68 | padding='VALID', use_bias=False, dtype=self.dtype) 69 | 70 | # Decoder 71 | out_dim = self.vqvae.n_codes 72 | self.decoder = ResNetDecoder(**config.decoder, image_size=self.vqvae.latent_shape[0], 73 | out_dim=out_dim, dtype=self.dtype) 74 | 75 | def sample_timestep(self, z_embeddings, actions, cond, t, causal_masking): 76 | t -= self.config.n_cond 77 | actions = self.action_embeds(actions) 78 | 79 | if causal_masking: 80 | deter = self.temporal_transformer( 81 | z_embeddings, actions, cond, deterministic=True, 82 | mask_frames=True, num_input_frames=self.config.open_loop_ctx_1 83 | ) 84 | else: 85 | deter = self.temporal_transformer( 86 | z_embeddings, actions, cond, deterministic=True 87 | ) 88 | 89 | deter = deter[:, t] 90 | 91 | key = self.make_rng('sample') 92 | recon_logits = self.decoder(deter) 93 | recon = random.categorical(key, recon_logits) 94 | inp = self.vq_fns['lookup'](recon) 95 | z_t = self.encoder(inp) 96 | return z_t, recon 97 | 98 | def _init_mask(self): 99 | n_per = np.prod(self.z_tfm_shape) 100 | mask = jnp.tril(jnp.ones((self.config.seq_len, self.config.seq_len), dtype=bool)) 101 | mask = mask.repeat(n_per, axis=0).repeat(n_per, axis=1) 102 | return mask 103 | 104 | def encode(self, encodings): 105 | embeddings = self.vq_fns['lookup'](encodings) 106 | inp = embeddings 107 | out = jax.vmap(self.encoder, 1, 1)(inp) 108 | return None, {'embeddings': out} 
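# ----------------------------------------------------------------------------
# A small standalone illustration (toy sizes, not part of this class) of the
# block-causal attention mask produced by _init_mask above: every spatial
# token of frame t may attend to all spatial tokens of frames <= t and to
# none of the later frames.
import numpy as np
import jax.numpy as jnp

def block_causal_mask(seq_len, z_tfm_shape):
    n_per = int(np.prod(z_tfm_shape))        # spatial tokens per frame
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    return mask.repeat(n_per, axis=0).repeat(n_per, axis=1)

# block_causal_mask(3, (1, 2)) is a 6x6 boolean array made of 2x2 all-True
# blocks on and below the frame diagonal and all-False blocks above it.
# ----------------------------------------------------------------------------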
109 | 110 | def temporal_transformer(self, z_embeddings, actions, cond, deterministic=False, 111 | mask_frames=False, drop_inds=None, num_input_frames=None): 112 | 113 | inp = z_embeddings 114 | 115 | if mask_frames: 116 | inp_embed = inp[:, :num_input_frames] 117 | mask_embed_shape = inp[:, num_input_frames:].shape 118 | if drop_inds is not None: 119 | inp = jnp.where(drop_inds[:, None, None, None, None], 120 | inp, 121 | jnp.concatenate((inp_embed, 122 | self.config.frame_mask_id * jnp.ones(mask_embed_shape)), 123 | axis=1)) 124 | else: 125 | inp = jnp.concatenate((inp_embed, 126 | self.config.frame_mask_id * jnp.ones(mask_embed_shape)), axis=1) 127 | 128 | actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 129 | inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1) 130 | inp = jax.vmap(self.z_proj, 1, 1)(inp) 131 | 132 | sos = jnp.tile(self.sos[None, None], (z_embeddings.shape[0], 1, 1, 1, 1)) 133 | sos = jnp.asarray(sos, self.dtype) 134 | 135 | inp = jnp.concatenate([sos, inp], axis=1) 136 | deter = self.z_tfm(inp, mask=self._init_mask(), deterministic=deterministic) 137 | deter = deter[:, self.config.n_cond:] 138 | 139 | deter = jax.vmap(self.z_unproj, 1, 1)(deter) 140 | 141 | return deter 142 | 143 | def __call__(self, video, actions, deterministic=False): 144 | # video: BTCHW, actions: BT 145 | if not self.config.use_actions: 146 | if actions is None: 147 | actions = jnp.zeros(video.shape[:2], dtype=jnp.int32) 148 | else: 149 | actions = jnp.zeros_like(actions) 150 | 151 | if self.config.dropout_actions: 152 | dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=self.config.action_dropout_rate, 153 | shape=(video.shape[0],)) # B 154 | actions = jnp.where(dropout_actions[:, None], self.config.action_mask_id, actions) 155 | else: 156 | dropout_actions = None 157 | 158 | actions = self.action_embeds(actions) 159 | _, encodings = self.vq_fns['encode'](video) 160 | 161 | cond, vq_output = self.encode(encodings) 162 | z_embeddings = vq_output['embeddings'] 163 | 164 | if self.config.causal_masking: 165 | deter = self.temporal_transformer( 166 | z_embeddings, actions, cond, deterministic=deterministic, 167 | mask_frames=True, 168 | drop_inds=dropout_actions, 169 | num_input_frames=self.config.open_loop_ctx_1 170 | ) 171 | else: 172 | deter = self.temporal_transformer( 173 | z_embeddings, actions, cond, deterministic=deterministic 174 | ) 175 | 176 | encodings = encodings[:, self.config.n_cond:] 177 | labels = jax.nn.one_hot(encodings, num_classes=self.vqvae.n_codes) 178 | labels = labels * 0.99 + 0.01 / self.vqvae.n_codes # Label smoothing 179 | 180 | if self.config.causal_masking: 181 | # Currently no support for droploss with causal masking 182 | recon_logits = jax.vmap(self.decoder, 1, 1)(deter) 183 | recon_logits_out = recon_logits[:, self.config.open_loop_ctx_1 - 1:] 184 | labels_out = labels[:, self.config.open_loop_ctx_1 - 1:] 185 | 186 | recon_logits_1 = jnp.where(dropout_actions[:, None, None, None, None], 187 | labels_out, 188 | recon_logits_out) 189 | 190 | loss_1 = optax.softmax_cross_entropy(recon_logits_1, labels_out) 191 | loss_1 = loss_1.sum(axis=(-2, -1)) 192 | loss_1 = loss_1.mean() 193 | 194 | recon_logits_2 = jnp.where(dropout_actions[:, None, None, None, None], 195 | recon_logits, 196 | labels) 197 | 198 | loss_2 = optax.softmax_cross_entropy(recon_logits_2, labels) 199 | loss_2 = loss_2.sum(axis=(-2, -1)) 200 | loss_2 = loss_2.mean() 201 | 202 | loss = loss_1 + loss_2 203 | 204 | else: 205 | if self.config.drop_loss_rate is 
not None and self.config.drop_loss_rate > 0.0: 206 | n_sample = int((1 - self.config.drop_loss_rate) * deter.shape[1]) 207 | n_sample = max(1, n_sample) 208 | idxs = jax.random.randint(self.make_rng('sample'), 209 | [n_sample], 210 | 0, video.shape[1], dtype=jnp.int32) 211 | else: 212 | idxs = jnp.arange(deter.shape[1], dtype=jnp.int32) 213 | 214 | deter = deter[:, idxs] 215 | labels = labels[:, idxs] 216 | 217 | recon_logits = jax.vmap(self.decoder, 1, 1)(deter) 218 | recon_loss = optax.softmax_cross_entropy(recon_logits, labels) 219 | recon_loss = recon_loss.sum(axis=(-2, -1)) 220 | recon_loss = recon_loss.mean() 221 | loss = recon_loss 222 | 223 | out = dict(loss=loss) 224 | return out 225 | 226 | 227 | -------------------------------------------------------------------------------- /src/models/sequence_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/sequence_models/__init__.py -------------------------------------------------------------------------------- /src/models/sequence_models/noVQ/.ipynb_checkpoints/transformer-checkpoint.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Any 2 | import optax 3 | import numpy as np 4 | import flax.linen as nn 5 | import jax 6 | import jax.numpy as jnp 7 | 8 | from src.models.transformer.transformer import Transformer 9 | from src.models.base import ResNetEncoder, ResNetDecoder 10 | 11 | 12 | class TECO_NOVQ(nn.Module): 13 | config: Any 14 | dtype: Optional[Any] = jnp.float32 15 | 16 | @property 17 | def metrics(self): 18 | metrics = ['loss', 'mse_loss', 'l1_loss'] 19 | return metrics 20 | 21 | def setup(self): 22 | config = self.config 23 | 24 | self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype) 25 | 26 | # Posterior 27 | self.encoder = ResNetEncoder(**config.encoder, dtype=self.dtype) 28 | ds = 1 29 | self.z_shape = tuple([d // ds for d in self.config.latent_shape]) 30 | # Temporal Transformer 31 | z_kernel = [config.z_ds, config.z_ds] 32 | self.z_tfm_shape = tuple([d // config.z_ds for d in self.z_shape]) 33 | self.z_proj = nn.Conv(config.z_tfm_kwargs['embed_dim'], z_kernel, 34 | strides=z_kernel, use_bias=False, padding='VALID', dtype=self.dtype) 35 | 36 | self.sos = self.param('sos', nn.initializers.normal(stddev=0.02), 37 | (*self.z_tfm_shape, config.z_tfm_kwargs['embed_dim'],), jnp.float32) 38 | self.z_tfm = Transformer( 39 | **config.z_tfm_kwargs, pos_embed_type='sinusoidal', 40 | shape=(config.seq_len, *self.z_tfm_shape), 41 | dtype=self.dtype 42 | ) 43 | self.z_unproj = nn.ConvTranspose(config.embedding_dim, z_kernel, strides=z_kernel, 44 | padding='VALID', use_bias=False, dtype=self.dtype) 45 | 46 | # Decoder 47 | out_dim = self.config.channels 48 | self.decoder = ResNetDecoder(**config.decoder, image_size=0, 49 | out_dim=out_dim, dtype=self.dtype) 50 | 51 | def sample_timestep(self, z_embeddings, actions, cond, t): 52 | t -= self.config.n_cond 53 | actions = self.action_embeds(actions) 54 | deter = self.temporal_transformer( 55 | z_embeddings, actions, cond, deterministic=True 56 | ) 57 | deter = deter[:, t] 58 | 59 | key = self.make_rng('sample') 60 | recon_logits = nn.tanh(self.decoder(deter)) 61 | recon = recon_logits 62 | z_t = self.encoder(recon) 63 | return z_t, recon 64 | 65 | def _init_mask(self): 66 | n_per = np.prod(self.z_tfm_shape) 67 | mask = 
jnp.tril(jnp.ones((self.config.seq_len, self.config.seq_len), dtype=bool)) 68 | mask = mask.repeat(n_per, axis=0).repeat(n_per, axis=1) 69 | return mask 70 | 71 | def encode(self, encodings): 72 | inp = encodings 73 | out = jax.vmap(self.encoder, 1, 1)(inp) 74 | return None, {'embeddings': out} 75 | 76 | def temporal_transformer(self, z_embeddings, actions, cond, deterministic=False): 77 | 78 | inp = z_embeddings 79 | 80 | actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 81 | inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1) 82 | inp = jax.vmap(self.z_proj, 1, 1)(inp) 83 | 84 | sos = jnp.tile(self.sos[None, None], (z_embeddings.shape[0], 1, 1, 1, 1)) 85 | sos = jnp.asarray(sos, self.dtype) 86 | 87 | inp = jnp.concatenate([sos, inp], axis=1) 88 | deter = self.z_tfm(inp, mask=self._init_mask(), deterministic=deterministic) 89 | deter = deter[:, self.config.n_cond:] 90 | 91 | deter = jax.vmap(self.z_unproj, 1, 1)(deter) 92 | 93 | return deter 94 | 95 | def __call__(self, video, actions, deterministic=False): 96 | # video: BTCHW, actions: BT 97 | if not self.config.use_actions: 98 | if actions is None: 99 | actions = jnp.zeros(video.shape[:2], dtype=jnp.int32) 100 | else: 101 | actions = jnp.zeros_like(actions) 102 | 103 | if self.config.dropout_actions: 104 | dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=self.config.action_dropout_rate, 105 | shape=(video.shape[0],)) # B 106 | if "minecraft" not in self.config.data_path: 107 | # Don't drop actions for minecraft 108 | actions = jnp.where(dropout_actions[:, None], self.config.action_mask_id, actions) 109 | else: 110 | dropout_actions = None 111 | 112 | actions = self.action_embeds(actions) 113 | encodings = video 114 | 115 | cond, vq_output = self.encode(encodings) 116 | z_embeddings = vq_output['embeddings'] 117 | 118 | deter = self.temporal_transformer( 119 | z_embeddings, actions, cond, deterministic=deterministic 120 | ) 121 | 122 | labels = video[:, self.config.n_cond:] 123 | 124 | if self.config.drop_loss_rate is not None and self.config.drop_loss_rate > 0.0: 125 | n_sample = int((1 - self.config.drop_loss_rate) * deter.shape[1]) 126 | n_sample = max(1, n_sample) 127 | idxs = jax.random.randint(self.make_rng('sample'), 128 | [n_sample], 129 | 0, video.shape[1], dtype=jnp.int32) 130 | else: 131 | idxs = jnp.arange(deter.shape[1], dtype=jnp.int32) 132 | 133 | deter = deter[:, idxs] 134 | labels = labels[:, idxs] 135 | 136 | # Decoder loss 137 | recon_logits = nn.tanh(jax.vmap(self.decoder, 1, 1)(deter)) 138 | 139 | mse_loss = 2*optax.l2_loss(recon_logits, labels) #optax puts a 0.5 in front automatically 140 | mse_loss = mse_loss.sum(axis=(-2, -1)) 141 | mse_loss = mse_loss.mean() 142 | 143 | l1_loss = jnp.abs(recon_logits-labels) 144 | l1_loss = l1_loss.sum(axis=(-2, -1)) 145 | l1_loss = l1_loss.mean() 146 | 147 | loss = self.config.loss_weight * mse_loss + (1-self.config.loss_weight) * l1_loss 148 | 149 | out = dict(loss=loss, mse_loss=mse_loss, l1_loss=l1_loss) 150 | return out 151 | 152 | 153 | -------------------------------------------------------------------------------- /src/models/sequence_models/noVQ/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/sequence_models/noVQ/__init__.py -------------------------------------------------------------------------------- /src/models/sequence_models/noVQ/convLSTM.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | 12 | from typing import Optional, Any 13 | import optax 14 | import flax.linen as nn 15 | import jax 16 | import jax.numpy as jnp 17 | import numpy as np 18 | 19 | from src.models.convS5.conv_ops import VmapBasicConv 20 | from src.models.convLSTM.ssm import init_ConvLSTM 21 | from src.models.base import ResNetEncoder, ResNetDecoder 22 | from src.models.convLSTM.layers import StackedLayers 23 | 24 | 25 | def reshape_data(frames): 26 | # Make(seq_len, dev_bsz, H, W, in_dim) 27 | frames = frames.transpose(1, 0, 2, 3, 4) 28 | return frames 29 | 30 | 31 | class CONVLSTM_NOVQ(nn.Module): 32 | config: Any 33 | training: bool 34 | parallel: bool 35 | dtype: Optional[Any] = jnp.float32 36 | 37 | @property 38 | def metrics(self): 39 | metrics = ['loss', 'mse_loss', 'l1_loss'] 40 | return metrics 41 | 42 | def setup(self): 43 | config = self.config 44 | 45 | # Sequence Model 46 | self.ssm = init_ConvLSTM(config.d_model, 47 | config.ssm['ssm_size'], 48 | config.ssm['kernel_size']) 49 | self.sequence_model = StackedLayers(**self.config.seq_model, 50 | ssm=self.ssm, 51 | training=self.training, 52 | parallel=self.parallel) 53 | 54 | initial_states = [] 55 | bsz_device, _ = divmod(config.batch_size, jax.device_count()) 56 | for i in range(config.seq_model['n_layers']): 57 | initial_states.append( 58 | (np.zeros((bsz_device, 59 | config.latent_height, 60 | config.latent_width, 61 | config.ssm['ssm_size'])), 62 | np.zeros((bsz_device, 63 | config.latent_height, 64 | config.latent_width, 65 | config.ssm['ssm_size'])) 66 | ) 67 | ) 68 | 69 | self.initial_states = initial_states 70 | 71 | self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype) 72 | self.action_conv = VmapBasicConv(k_size=1, 73 | out_channels=config.d_model) 74 | 75 | # Encoder 76 | self.encoder = ResNetEncoder(**config.encoder, dtype=self.dtype) 77 | 78 | # Decoder 79 | out_dim = self.config.channels 80 | self.decoder = ResNetDecoder(**config.decoder, image_size=0, 81 | out_dim=out_dim, dtype=self.dtype) 82 | 83 | def sample_timestep(self, encoding, initial_states, action): 84 | inp = self.encode(encoding) 85 | 86 | action = self.action_embeds(action) 87 | action = jnp.tile(action[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 88 | inp = jnp.concatenate([inp, action], axis=-1) 89 | 90 | # inp is BTHWC, convS5 model needs TBHWC 91 | inp = reshape_data(inp) 92 | inp = self.action_conv(inp) 93 | last_states, deter = self.sequence_model(inp, initial_states) 94 | deter = reshape_data(deter) # Now BTHWC 95 | 96 | recon_logits, recon = self.reconstruct(deter) 97 | return recon, recon_logits, recon, last_states 98 | 99 | def encode(self, encodings): 100 | out = jax.vmap(self.encoder, 1, 1)(encodings) 101 | 102 | return out 103 | 104 | def condition(self, encodings, actions, initial_states=None): 105 | if initial_states is None: 106 | initial_states = self.initial_states 107 | 108 | # video: BTCHW, actions: BT 109 | inp = self.encode(encodings) 110 | 111 | # 
Combine inputs and actions 112 | actions = self.action_embeds(actions) 113 | actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 114 | inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1) 115 | 116 | # inp is BTHWC, convS5 model needs TBHWC 117 | inp = reshape_data(inp) 118 | inp = self.action_conv(inp) 119 | last_states, deter = self.sequence_model(inp, initial_states) 120 | deter = reshape_data(deter) # swap back to BTHWC 121 | 122 | return None, encodings, None, deter, last_states 123 | 124 | def reconstruct(self, deter): 125 | recon_logits = jax.vmap(self.decoder, 1, 1)(deter) 126 | recon = nn.tanh(recon_logits) 127 | return recon, recon 128 | 129 | def __call__(self, video, actions, deterministic=False): 130 | # video: BTHWC, actions: BT 131 | if not self.config.use_actions: 132 | if actions is None: 133 | actions = jnp.zeros(video.shape[:2], dtype=jnp.int32) 134 | else: 135 | actions = jnp.zeros_like(actions) 136 | 137 | if self.config.dropout_actions: 138 | dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=self.config.action_dropout_rate, 139 | shape=(video.shape[0],)) # B 140 | actions = jnp.where(dropout_actions[:, None], self.config.action_mask_id, actions) 141 | 142 | encodings = video 143 | 144 | _, _, _, deter, _ = self.condition(encodings, actions) 145 | 146 | labels = video[:, self.config.n_cond:] 147 | 148 | if self.config.drop_loss_rate is not None and self.config.drop_loss_rate > 0.0: 149 | n_sample = int((1 - self.config.drop_loss_rate) * deter.shape[1]) 150 | n_sample = max(1, n_sample) 151 | idxs = jax.random.randint(self.make_rng('sample'), 152 | [n_sample], 153 | 0, video.shape[1], dtype=jnp.int32) 154 | else: 155 | idxs = jnp.arange(deter.shape[1], dtype=jnp.int32) 156 | 157 | deter = deter[:, idxs] 158 | labels = labels[:, idxs] 159 | 160 | # Decoder loss 161 | recon_logits, _ = self.reconstruct(deter) 162 | 163 | mse_loss = 2*optax.l2_loss(recon_logits, labels) # optax puts a 0.5 in front automatically 164 | mse_loss = mse_loss.sum(axis=(-2, -1)) 165 | mse_loss = mse_loss.mean() 166 | 167 | l1_loss = jnp.abs(recon_logits-labels) 168 | l1_loss = l1_loss.sum(axis=(-2, -1)) 169 | l1_loss = l1_loss.mean() 170 | 171 | loss = self.config.loss_weight * mse_loss + (1-self.config.loss_weight) * l1_loss 172 | 173 | out = dict(loss=loss, mse_loss=mse_loss, l1_loss=l1_loss) 174 | return out 175 | -------------------------------------------------------------------------------- /src/models/sequence_models/noVQ/convS5.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # 4 | # This work is made available under the Nvidia Source Code License. 
5 | # To view a copy of this license, visit 6 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 7 | # 8 | # Written by Jimmy Smith 9 | # ------------------------------------------------------------------------------ 10 | 11 | from typing import Optional, Any 12 | import optax 13 | import flax.linen as nn 14 | import jax 15 | import jax.numpy as jnp 16 | 17 | from src.models.convS5.conv_ops import VmapBasicConv 18 | from src.models.convS5.diagonal_ssm import init_ConvS5SSM 19 | from src.models.base import ResNetEncoder, ResNetDecoder 20 | from src.models.convS5.layers import StackedLayers 21 | 22 | 23 | def reshape_data(frames): 24 | # Make(seq_len, dev_bsz, H, W, in_dim) 25 | frames = frames.transpose(1, 0, 2, 3, 4) 26 | return frames 27 | 28 | 29 | class CONVS5_NOVQ(nn.Module): 30 | config: Any 31 | training: bool 32 | parallel: bool 33 | dtype: Optional[Any] = jnp.float32 34 | 35 | @property 36 | def metrics(self): 37 | metrics = ['loss', 'mse_loss', 'l1_loss'] 38 | return metrics 39 | 40 | def setup(self): 41 | config = self.config 42 | 43 | # Sequence Model 44 | self.ssm = init_ConvS5SSM(config.ssm['ssm_size'], 45 | config.ssm['blocks'], 46 | config.ssm['clip_eigs'], 47 | config.d_model, 48 | config.ssm['B_kernel_size'], 49 | config.ssm['C_kernel_size'], 50 | config.ssm['D_kernel_size'], 51 | config.ssm['dt_min'], 52 | config.ssm['dt_max'], 53 | config.ssm['C_D_config']) 54 | self.sequence_model = StackedLayers(**self.config.seq_model, 55 | ssm=self.ssm, 56 | training=self.training, 57 | parallel=self.parallel) 58 | 59 | initial_states = [] 60 | bsz_device, _ = divmod(config.batch_size, jax.device_count()) 61 | for i in range(config.seq_model['n_layers']): 62 | initial_states.append(jnp.zeros((bsz_device, 63 | config.latent_height, 64 | config.latent_width, 65 | config.ssm['ssm_size']//2)) 66 | ) 67 | 68 | self.initial_states = initial_states 69 | 70 | self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype) 71 | self.action_conv = VmapBasicConv(k_size=1, 72 | out_channels=config.d_model) 73 | 74 | # Encoder 75 | self.encoder = ResNetEncoder(**config.encoder, dtype=self.dtype) 76 | 77 | # Decoder 78 | out_dim = self.config.channels 79 | self.decoder = ResNetDecoder(**config.decoder, image_size=0, 80 | out_dim=out_dim, dtype=self.dtype) 81 | 82 | def sample_timestep(self, encoding, initial_states, action): 83 | inp = self.encode(encoding) 84 | 85 | action = self.action_embeds(action) 86 | action = jnp.tile(action[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 87 | inp = jnp.concatenate([inp, action], axis=-1) 88 | 89 | # inp is BTHWC, convS5 model needs TBHWC 90 | inp = reshape_data(inp) 91 | inp = self.action_conv(inp) 92 | last_states, deter = self.sequence_model(inp, initial_states) 93 | deter = reshape_data(deter) # Now BTHWC 94 | 95 | recon_logits, recon = self.reconstruct(deter) 96 | return recon, recon_logits, recon, last_states 97 | 98 | def encode(self, encodings): 99 | out = jax.vmap(self.encoder, 1, 1)(encodings) 100 | 101 | return out 102 | 103 | def condition(self, encodings, actions, initial_states=None): 104 | if initial_states is None: 105 | initial_states = self.initial_states 106 | 107 | # video: BTCHW, actions: BT 108 | inp = self.encode(encodings) 109 | 110 | # Combine inputs and actions 111 | actions = self.action_embeds(actions) 112 | actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 113 | inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1) 114 | 115 | # inp is BTHWC, convS5 model needs 
TBHWC 116 | inp = reshape_data(inp) 117 | inp = self.action_conv(inp) 118 | last_states, deter = self.sequence_model(inp, initial_states) 119 | deter = reshape_data(deter) # swap back to BTHWC 120 | 121 | return None, encodings, None, deter, last_states 122 | 123 | def reconstruct(self, deter): 124 | recon_logits = jax.vmap(self.decoder, 1, 1)(deter) 125 | recon = nn.tanh(recon_logits) 126 | return recon, recon 127 | 128 | def __call__(self, video, actions, deterministic=False): 129 | # video: BTHWC, actions: BT 130 | if not self.config.use_actions: 131 | if actions is None: 132 | actions = jnp.zeros(video.shape[:2], dtype=jnp.int32) 133 | else: 134 | actions = jnp.zeros_like(actions) 135 | 136 | if self.config.dropout_actions: 137 | dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=self.config.action_dropout_rate, 138 | shape=(video.shape[0],)) # B 139 | actions = jnp.where(dropout_actions[:, None], self.config.action_mask_id, actions) 140 | 141 | encodings = video 142 | 143 | _, _, _, deter, _ = self.condition(encodings, actions) 144 | 145 | labels = video[:, self.config.n_cond:] 146 | 147 | if self.config.drop_loss_rate is not None and self.config.drop_loss_rate > 0.0: 148 | n_sample = int((1 - self.config.drop_loss_rate) * deter.shape[1]) 149 | n_sample = max(1, n_sample) 150 | idxs = jax.random.randint(self.make_rng('sample'), 151 | [n_sample], 152 | 0, video.shape[1], dtype=jnp.int32) 153 | else: 154 | idxs = jnp.arange(deter.shape[1], dtype=jnp.int32) 155 | 156 | deter = deter[:, idxs] 157 | labels = labels[:, idxs] 158 | 159 | # Decoder loss 160 | recon_logits, _ = self.reconstruct(deter) 161 | 162 | mse_loss = 2*optax.l2_loss(recon_logits, labels) # optax puts a 0.5 in front automatically 163 | mse_loss = mse_loss.sum(axis=(-2, -1)) 164 | mse_loss = mse_loss.mean() 165 | 166 | l1_loss = jnp.abs(recon_logits-labels) 167 | l1_loss = l1_loss.sum(axis=(-2, -1)) 168 | l1_loss = l1_loss.mean() 169 | 170 | loss = self.config.loss_weight * mse_loss + (1-self.config.loss_weight) * l1_loss 171 | 172 | out = dict(loss=loss, mse_loss=mse_loss, l1_loss=l1_loss) 173 | return out 174 | 175 | 176 | 177 | 178 | 179 | -------------------------------------------------------------------------------- /src/models/sequence_models/noVQ/transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 
12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | 19 | from typing import Optional, Any 20 | import optax 21 | import numpy as np 22 | import flax.linen as nn 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | from src.models.transformer.transformer import Transformer 27 | from src.models.base import ResNetEncoder, ResNetDecoder 28 | 29 | 30 | class TRANSFORMER_NOVQ(nn.Module): 31 | config: Any 32 | dtype: Optional[Any] = jnp.float32 33 | 34 | @property 35 | def metrics(self): 36 | metrics = ['loss', 'mse_loss', 'l1_loss'] 37 | return metrics 38 | 39 | def setup(self): 40 | config = self.config 41 | 42 | self.action_embeds = nn.Embed(config.action_dim + 1, config.action_embed_dim, dtype=self.dtype) 43 | 44 | # Posterior 45 | self.encoder = ResNetEncoder(**config.encoder, dtype=self.dtype) 46 | ds = 1 47 | self.z_shape = tuple([d // ds for d in self.config.latent_shape]) 48 | # Temporal Transformer 49 | z_kernel = [config.z_ds, config.z_ds] 50 | self.z_tfm_shape = tuple([d // config.z_ds for d in self.z_shape]) 51 | self.z_proj = nn.Conv(config.z_tfm_kwargs['embed_dim'], z_kernel, 52 | strides=z_kernel, use_bias=False, padding='VALID', dtype=self.dtype) 53 | 54 | self.sos = self.param('sos', nn.initializers.normal(stddev=0.02), 55 | (*self.z_tfm_shape, config.z_tfm_kwargs['embed_dim'],), jnp.float32) 56 | self.z_tfm = Transformer( 57 | **config.z_tfm_kwargs, pos_embed_type='sinusoidal', 58 | shape=(config.seq_len, *self.z_tfm_shape), 59 | dtype=self.dtype 60 | ) 61 | self.z_unproj = nn.ConvTranspose(config.embedding_dim, z_kernel, strides=z_kernel, 62 | padding='VALID', use_bias=False, dtype=self.dtype) 63 | 64 | # Decoder 65 | out_dim = self.config.channels 66 | self.decoder = ResNetDecoder(**config.decoder, image_size=0, 67 | out_dim=out_dim, dtype=self.dtype) 68 | 69 | def sample_timestep(self, z_embeddings, actions, cond, t): 70 | t -= self.config.n_cond 71 | actions = self.action_embeds(actions) 72 | deter = self.temporal_transformer( 73 | z_embeddings, actions, cond, deterministic=True 74 | ) 75 | deter = deter[:, t] 76 | 77 | recon_logits = nn.tanh(self.decoder(deter)) 78 | recon = recon_logits 79 | z_t = self.encoder(recon) 80 | return z_t, recon 81 | 82 | def _init_mask(self): 83 | n_per = np.prod(self.z_tfm_shape) 84 | mask = jnp.tril(jnp.ones((self.config.seq_len, self.config.seq_len), dtype=bool)) 85 | mask = mask.repeat(n_per, axis=0).repeat(n_per, axis=1) 86 | return mask 87 | 88 | def encode(self, encodings): 89 | inp = encodings 90 | out = jax.vmap(self.encoder, 1, 1)(inp) 91 | return None, {'embeddings': out} 92 | 93 | def temporal_transformer(self, z_embeddings, actions, cond, deterministic=False): 94 | 95 | inp = z_embeddings 96 | 97 | actions = jnp.tile(actions[:, :, None, None], (1, 1, *inp.shape[2:4], 1)) 98 | inp = jnp.concatenate([inp[:, :-1], actions[:, 1:]], axis=-1) 99 | inp = jax.vmap(self.z_proj, 1, 1)(inp) 100 | 101 | sos = jnp.tile(self.sos[None, None], (z_embeddings.shape[0], 1, 1, 1, 1)) 102 | sos = jnp.asarray(sos, self.dtype) 103 | 104 | inp = jnp.concatenate([sos, inp], axis=1) 105 | deter = self.z_tfm(inp, mask=self._init_mask(), deterministic=deterministic) 106 | deter = deter[:, self.config.n_cond:] 107 | 108 | deter = jax.vmap(self.z_unproj, 1, 1)(deter) 109 | 110 | return deter 111 | 112 | def __call__(self, video, actions, deterministic=False): 113 | # video: 
BTCHW, actions: BT 114 | if not self.config.use_actions: 115 | if actions is None: 116 | actions = jnp.zeros(video.shape[:2], dtype=jnp.int32) 117 | else: 118 | actions = jnp.zeros_like(actions) 119 | 120 | if self.config.dropout_actions: 121 | dropout_actions = jax.random.bernoulli(self.make_rng('sample'), p=self.config.action_dropout_rate, 122 | shape=(video.shape[0],)) # B 123 | if "minecraft" not in self.config.data_path: 124 | # Don't drop actions for minecraft 125 | actions = jnp.where(dropout_actions[:, None], self.config.action_mask_id, actions) 126 | else: 127 | dropout_actions = None 128 | 129 | actions = self.action_embeds(actions) 130 | encodings = video 131 | 132 | cond, vq_output = self.encode(encodings) 133 | z_embeddings = vq_output['embeddings'] 134 | 135 | deter = self.temporal_transformer( 136 | z_embeddings, actions, cond, deterministic=deterministic 137 | ) 138 | 139 | labels = video[:, self.config.n_cond:] 140 | 141 | if self.config.drop_loss_rate is not None and self.config.drop_loss_rate > 0.0: 142 | n_sample = int((1 - self.config.drop_loss_rate) * deter.shape[1]) 143 | n_sample = max(1, n_sample) 144 | idxs = jax.random.randint(self.make_rng('sample'), 145 | [n_sample], 146 | 0, video.shape[1], dtype=jnp.int32) 147 | else: 148 | idxs = jnp.arange(deter.shape[1], dtype=jnp.int32) 149 | 150 | deter = deter[:, idxs] 151 | labels = labels[:, idxs] 152 | 153 | # Decoder loss 154 | recon_logits = nn.tanh(jax.vmap(self.decoder, 1, 1)(deter)) 155 | 156 | mse_loss = 2*optax.l2_loss(recon_logits, labels) # optax puts a 0.5 in front automatically 157 | mse_loss = mse_loss.sum(axis=(-2, -1)) 158 | mse_loss = mse_loss.mean() 159 | 160 | l1_loss = jnp.abs(recon_logits-labels) 161 | l1_loss = l1_loss.sum(axis=(-2, -1)) 162 | l1_loss = l1_loss.mean() 163 | 164 | loss = self.config.loss_weight * mse_loss + (1-self.config.loss_weight) * l1_loss 165 | 166 | out = dict(loss=loss, mse_loss=mse_loss, l1_loss=l1_loss) 167 | return out 168 | 169 | 170 | -------------------------------------------------------------------------------- /src/models/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/ConvSSM/cb08bb3d5eef4f7438a4d76f4c24007beecd66db/src/models/transformer/__init__.py -------------------------------------------------------------------------------- /src/models/transformer/maskgit.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | from typing import Optional, Any, Tuple, Dict, Callable 9 | import math 10 | import numpy as np 11 | import jax 12 | import flax.linen as nn 13 | import jax.numpy as jnp 14 | 15 | from src.utils import topk_sample 16 | from src.models.transformer.transformer import Transformer, LayerNorm 17 | from src.models.base import AddBias 18 | 19 | Array = Any 20 | Dtype = Any 21 | 22 | 23 | MASK_ID = -1 24 | 25 | 26 | def schedule(ratio, total_unknown, method='cosine'): 27 | if method == 'uniform': 28 | mask_ratio = 1. - ratio 29 | elif 'pow' in method: 30 | exponent = float(method.replace('pow', '')) 31 | mask_ratio = 1. 
- ratio ** exponent 32 | elif method == 'cosine': 33 | mask_ratio = jax.lax.cos(math.pi / 2. * ratio) 34 | elif method == 'log': 35 | mask_ratio = -jnp.log2(ratio) / jnp.log2(total_unknown) 36 | elif method == 'exp': 37 | mask_ratio = 1 - jnp.exp2(-jnp.log2(total_unknown) * (1 - ratio)) 38 | mask_ratio = jnp.clip(mask_ratio, 1e-6, 1.) 39 | return mask_ratio 40 | 41 | 42 | def mask_by_random_topk(rng, mask_len, probs, temperature=1.0): 43 | confidence = jnp.log(probs) + temperature * jax.random.gumbel(rng, probs.shape) 44 | sorted_confidence = jnp.sort(confidence, axis=-1) 45 | cut_off = jnp.take_along_axis(sorted_confidence, mask_len, axis=-1) 46 | masking = (confidence < cut_off) 47 | return masking 48 | 49 | 50 | def sample_mask(Z, T, rng): 51 | N = np.prod(Z) 52 | idxs = jnp.arange(N, dtype=jnp.int32) 53 | idxs = jax.random.permutation(rng, idxs) 54 | chunks = jnp.array_split(idxs, T) 55 | 56 | masks = [] 57 | for t in range(T): 58 | mask = jax.nn.one_hot(chunks[t], N).sum(axis=0).astype(bool) 59 | mask = jnp.reshape(mask, Z) 60 | masks.append(mask) 61 | return masks 62 | 63 | 64 | class MaskGit(nn.Module): 65 | shape: Tuple[int] 66 | vocab_size: int 67 | vocab_dim: int 68 | mask_schedule: str 69 | tfm_kwargs: Dict[str, Any] 70 | dtype: Optional[Any] = jnp.float32 71 | 72 | def setup(self): 73 | self.token_embed = self.param('token_embed', nn.initializers.normal(stddev=0.02), 74 | [self.vocab_size + 1, self.vocab_dim], 75 | jnp.float32) 76 | 77 | self.net = Transformer( 78 | **self.tfm_kwargs, 79 | shape=self.shape, 80 | pos_embed_type='broadcast', 81 | dtype=self.dtype 82 | ) 83 | self.mlm = MlmLayer(self.vocab_dim, self.dtype) 84 | 85 | def _step(self, x, cond=None, deterministic=False): 86 | token_embed = jnp.asarray(self.token_embed, self.dtype) 87 | x = token_embed[(x,)] 88 | 89 | x = self.net(x, cond=cond, deterministic=deterministic) 90 | logits = self.mlm(x, self.token_embed[:self.vocab_size]) 91 | return logits 92 | 93 | def sample(self, n, T_draft, T_revise, M, cond=None): 94 | sample = jnp.full((n, *self.shape), MASK_ID, dtype=jnp.int32) 95 | 96 | def _update(samples, masks): 97 | for mask in masks: 98 | samples = jnp.where(mask, MASK_ID, samples) 99 | logits = self._step(samples, cond=cond, deterministic=True) 100 | s = topk_sample(self.make_rng('sample'), logits) 101 | samples = jnp.where(mask, s, samples) 102 | return samples 103 | 104 | # Draft 105 | masks = sample_mask(self.shape, T_draft, self.make_rng('sample')) 106 | sample = _update(sample, masks) 107 | 108 | # Revise 109 | for _ in range(M): 110 | masks = sample_mask(self.shape, T_revise, self.make_rng('sample')) 111 | sample = _update(sample, masks) 112 | 113 | return sample 114 | 115 | def __call__(self, x, cond=None, deterministic=False): 116 | # x: B..., cond: B...D 117 | B, L = x.shape[0], np.prod(x.shape[1:]) 118 | 119 | ratio = jax.random.uniform(self.make_rng('sample'), shape=(B,), dtype=self.dtype) 120 | ratio = schedule(ratio, L, method=self.mask_schedule) 121 | ratio = jnp.maximum(1, jnp.floor(ratio * L)) 122 | 123 | sample = jnp.arange(L)[None, :].repeat(B, axis=0) 124 | sample = jax.random.permutation(self.make_rng('sample'), sample, axis=-1, independent=True) 125 | mask = sample < ratio[:, None] 126 | mask = mask.reshape(x.shape) 127 | 128 | masked_x = jnp.where(mask, MASK_ID, x) 129 | logits = self._step(masked_x, cond=cond, deterministic=deterministic) 130 | labels = jax.nn.one_hot(x, num_classes=self.vocab_size) 131 | return logits, labels, mask 132 | 133 | 134 | class MlmLayer(nn.Module): 135 | 
vocab_dim: int 136 | dtype: Optional[Any] = jnp.float32 137 | 138 | @nn.compact 139 | def __call__(self, x, embeddings): 140 | x = nn.Dense(self.vocab_dim, dtype=self.dtype, 141 | kernel_init=nn.initializers.normal(stddev=0.02))(x) 142 | x = nn.gelu(x) 143 | x = LayerNorm(dtype=self.dtype)(x) 144 | 145 | output_weights = jnp.transpose(embeddings) 146 | logits = jnp.matmul(x, output_weights) 147 | logits = AddBias(self.dtype)(logits) 148 | return logits 149 | -------------------------------------------------------------------------------- /src/runtime_metrics.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | # 8 | # ------------------------------------------------------------------------------ 9 | # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 10 | # 11 | # This work is made available under the Nvidia Source Code License. 12 | # To view a copy of this license, visit 13 | # https://github.com/NVlabs/ConvSSM/blob/main/LICENSE 14 | # 15 | # Written by Jimmy Smith 16 | # ------------------------------------------------------------------------------ 17 | 18 | import jax 19 | import jax.numpy as jnp 20 | import lpips_jax 21 | 22 | # Metrics below adapted from https://github.com/wilson1yan/teco/blob/bf56c2956515751bfd2b90355d52f7e362e288a1/teco/metrics.py 23 | 24 | 25 | def compute_metric(prediction, ground_truth, metric_fn, average_dim=1): 26 | # BTHWC in [0, 1] 27 | assert prediction.shape == ground_truth.shape 28 | B, T = prediction.shape[0], prediction.shape[1] 29 | prediction = prediction.reshape(-1, *prediction.shape[2:]) 30 | ground_truth = ground_truth.reshape(-1, *ground_truth.shape[2:]) 31 | 32 | metrics = metric_fn(prediction, ground_truth) 33 | metrics = jnp.reshape(metrics, (B, T)) 34 | metrics = metrics.mean(axis=average_dim) # B or T depending on dim 35 | return metrics 36 | 37 | 38 | # all methods below take as input pairs of images 39 | # of shape BHWC. 
They DO NOT reduce batch dimension 40 | # NOTE: Assumes that images are in [0, 1] 41 | def get_ssim(pred, truth, average_dim=1): 42 | # output is shape bsz 43 | 44 | def fn(imgs1, imgs2): 45 | ssim_fn = ssim 46 | ssim_val = ssim_fn(imgs1, imgs2) 47 | return ssim_val 48 | 49 | return compute_metric(pred, truth, fn, average_dim=average_dim) 50 | 51 | 52 | def ssim(img1, img2, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): 53 | ssim_per_channel, _ = _ssim_per_channel(img1, img2, max_val, filter_size, filter_sigma, k1, k2) 54 | return jnp.mean(ssim_per_channel, axis=-1) 55 | 56 | 57 | def _ssim_per_channel(img1, img2, max_val, filter_size, filter_sigma, k1, k2): 58 | kernel = _fspecial_gauss(filter_size, filter_sigma) 59 | kernel = jnp.tile(kernel, [1, 1, img1.shape[-1], 1]) 60 | kernel = jnp.transpose(kernel, [2, 3, 0, 1]) 61 | 62 | compensation = 1.0 63 | 64 | def reducer(x): 65 | x_shape = x.shape 66 | x = jnp.reshape(x, (-1, *x.shape[-3:])) 67 | x = jnp.transpose(x, [0, 3, 1, 2]) 68 | y = jax.lax.conv_general_dilated(x, kernel, [1, 1], 69 | 'VALID', feature_group_count=x.shape[1]) 70 | 71 | y = jnp.reshape(y, [*x_shape[:-3], *y.shape[1:]]) 72 | return y 73 | 74 | luminance, cs = _ssim_helper(img1, img2, reducer, max_val, compensation, k1, k2) 75 | ssim_val = jnp.mean(luminance * cs, axis=[-3, -2]) 76 | cs = jnp.mean(cs, axis=[-3, -2]) 77 | return ssim_val, cs 78 | 79 | 80 | def _ssim_helper(x, y, reducer, max_val, compensation=1.0, k1=0.01, k2=0.03): 81 | c1 = (k1 * max_val) ** 2 82 | c2 = (k2 * max_val) ** 2 83 | 84 | mean0 = reducer(x) 85 | mean1 = reducer(y) 86 | 87 | num0 = mean0 * mean1 * 2.0 88 | den0 = jnp.square(mean0) + jnp.square(mean1) 89 | luminance = (num0 + c1) / (den0 + c1) 90 | 91 | num1 = reducer(x * y) * 2.0 92 | den1 = reducer(jnp.square(x) + jnp.square(y)) 93 | c2 *= compensation 94 | cs = (num1 - num0 + c2) / (den1 - den0 + c2) 95 | 96 | return luminance, cs 97 | 98 | 99 | def _fspecial_gauss(size, sigma): 100 | coords = jnp.arange(size, dtype=jnp.float32) 101 | coords -= (size - 1.0) / 2.0 102 | 103 | g = jnp.square(coords) 104 | g *= -0.5 / jnp.square(sigma) 105 | 106 | g = jnp.reshape(g, [1, -1]) + jnp.reshape(g, [-1, 1]) 107 | g = jnp.reshape(g, [1, -1]) 108 | g = jax.nn.softmax(g, axis=-1) 109 | return jnp.reshape(g, [size, size, 1, 1]) 110 | 111 | 112 | def get_psnr(pred, truth, average_dim=1): 113 | def fn(imgs1, imgs2): 114 | psnr_fn = psnr 115 | psnr_val = psnr_fn(imgs1, imgs2) 116 | return psnr_val 117 | return compute_metric(pred, truth, fn, average_dim=average_dim) 118 | 119 | 120 | def psnr(a, b, max_val=1.0): 121 | mse = jnp.mean((a - b) ** 2, axis=[-3, -2, -1]) 122 | val = 20 * jnp.log(max_val) / jnp.log(10.0) - jnp.float32(10 / jnp.log(10)) * jnp.log(mse) 123 | return val 124 | 125 | 126 | def get_lpips(pred, truth, net='alexnet', average_dim=1): 127 | """net: ['alexnet', 'vgg16']""" 128 | lpips_eval = lpips_jax.LPIPSEvaluator(net=net, replicate=False) 129 | 130 | def fn(imgs1, imgs2): 131 | imgs1 = 2 * imgs1 - 1 132 | imgs2 = 2 * imgs2 - 1 133 | 134 | lpips = lpips_eval(imgs1, imgs2) 135 | lpips = jnp.reshape(lpips, (-1,)) 136 | return lpips 137 | return compute_metric(pred, truth, fn, average_dim=average_dim) 138 | -------------------------------------------------------------------------------- /src/train_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN 
RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | from typing import Any 9 | from collections import OrderedDict 10 | import random 11 | import numpy as np 12 | import jax 13 | from flax.training import train_state 14 | from flax.core.frozen_dict import freeze 15 | import optax 16 | 17 | 18 | class TrainState(train_state.TrainState): 19 | model_state: Any 20 | 21 | 22 | def seed_all(seed): 23 | random.seed(seed) 24 | np.random.seed(seed) 25 | 26 | 27 | def get_first_device(x): 28 | x = jax.tree_util.tree_map(lambda a: a[0], x) 29 | return jax.device_get(x) 30 | 31 | 32 | def print_model_size(params, name=''): 33 | model_params_size = jax.tree_util.tree_map(lambda x: x.size, params) 34 | total_params_size = sum(jax.tree_util.tree_flatten(model_params_size)[0]) 35 | print('model parameter count:', total_params_size) 36 | 37 | 38 | def get_learning_rate_fn(config): 39 | if config.lr_schedule == 'cosine': 40 | learning_rate_fn = optax.warmup_cosine_decay_schedule( 41 | init_value=0., 42 | peak_value=config.lr, 43 | warmup_steps=config.warmup_steps, 44 | decay_steps=config.total_steps - config.warmup_steps 45 | ) 46 | elif config.lr_schedule == 'constant': 47 | learning_rate_fn = optax.join_schedules([ 48 | optax.linear_schedule( 49 | init_value=0., 50 | end_value=config.lr, 51 | transition_steps=config.warmup_steps 52 | ), 53 | optax.constant_schedule(config.lr) 54 | ], [config.warmup_steps]) 55 | else: 56 | raise ValueError(f'Unknown schedule: {config.lr_schedule}') 57 | 58 | return learning_rate_fn 59 | 60 | 61 | def get_optimizer(config): 62 | learning_rate_fn = get_learning_rate_fn(config) 63 | tx = optax.adamw(learning_rate=learning_rate_fn, b1=0.9, b2=0.95, 64 | weight_decay=config.weight_decay) 65 | return tx, learning_rate_fn 66 | 67 | 68 | def init_model_state(rng_key, model, sample, config): 69 | variables = model.init( 70 | rngs={k: rng_key for k in ['params', *config.rng_keys]}, 71 | **{k: sample[k] for k in config.batch_keys} 72 | ).unfreeze() 73 | params = freeze(variables.pop('params')) 74 | model_state = variables 75 | print_model_size(params) 76 | 77 | tx, learning_rate_fn = get_optimizer(config) 78 | 79 | return TrainState.create( 80 | apply_fn=model.apply, 81 | params=params, 82 | tx=tx, 83 | model_state=model_state 84 | ), learning_rate_fn 85 | 86 | 87 | class AverageMeter(object): 88 | """Computes and stores the average and current value""" 89 | def __init__(self, name, fmt=':f'): 90 | self.name = name 91 | self.fmt = fmt 92 | self.reset() 93 | 94 | def reset(self): 95 | self.val = 0 96 | self.avg = 0 97 | self.sum = 0 98 | self.count = 0 99 | 100 | def update(self, val, n=1): 101 | self.val = val 102 | self.sum += val * n 103 | self.count += n 104 | self.avg = self.sum / self.count 105 | 106 | def __str__(self): 107 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 108 | return fmtstr.format(**self.__dict__) 109 | 110 | 111 | class ProgressMeter(object): 112 | def __init__(self, total_iters, meter_names, prefix=""): 113 | self.iter_fmtstr = self._get_iter_fmtstr(total_iters) 114 | self.meters = OrderedDict({mn: AverageMeter(mn, ':6.3f') 115 | for mn in meter_names}) 116 | self.prefix = prefix 117 | 118 | def update(self, n=1, **kwargs): 119 | for k, v in kwargs.items(): 120 | self.meters[k].update(v, n=n) 121 | 122 | def display(self, iteration): 123 | entries = [self.prefix + 
self.iter_fmtstr.format(iteration)] 124 | entries += [str(meter) for meter in self.meters.values()] 125 | print('\t'.join(entries)) 126 | 127 | def _get_iter_fmtstr(self, total_iters): 128 | num_digits = len(str(total_iters // 1)) 129 | fmt = '{:' + str(num_digits) + 'd}' 130 | return '[' + fmt + '/' + fmt.format(total_iters) + ']' 131 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # MIT License 3 | # Copyright (c) 2022 BAIR OPEN RESEARCH COMMONS REPOSITORY 4 | # To view a copy of this license, visit 5 | # https://github.com/wilson1yan/teco/tree/master 6 | # ------------------------------------------------------------------------------ 7 | 8 | import math 9 | from moviepy.editor import ImageSequenceClip 10 | import numpy as np 11 | 12 | import jax 13 | import jax.numpy as jnp 14 | 15 | 16 | def topk_sample(rng, logits, top_k=None): 17 | if top_k is not None: 18 | top_k = min(top_k, logits.shape[-1]) 19 | indices_to_remove = logits < jax.lax.top_k(logits, top_k)[0][..., -1, None] 20 | logits = jnp.where(indices_to_remove, jnp.finfo(logits.dtype).min, logits) 21 | 22 | samples = jax.random.categorical(rng, logits, axis=-1) 23 | return samples 24 | 25 | 26 | def add_border(video, color, width=0.025): 27 | # video: BTHWC in [0, 1] 28 | S = math.ceil(int(video.shape[3] * width)) 29 | 30 | # top 31 | video[:, :, :S, :, 0] = color[0] 32 | video[:, :, :S, :, 1] = color[1] 33 | video[:, :, :S, :, 2] = color[2] 34 | 35 | # bottom 36 | video[:, :, -S:, :, 0] = color[0] 37 | video[:, :, -S:, :, 1] = color[1] 38 | video[:, :, -S:, :, 2] = color[2] 39 | 40 | # left 41 | video[:, :, :, :S, 0] = color[0] 42 | video[:, :, :, :S, 1] = color[1] 43 | video[:, :, :, :S, 2] = color[2] 44 | 45 | # right 46 | video[:, :, :, -S:, 0] = color[0] 47 | video[:, :, :, -S:, 1] = color[1] 48 | video[:, :, :, -S:, 2] = color[2] 49 | 50 | 51 | def add_border_mnist(video, color, width=0.025): 52 | # video: BTHWC in [0, 1] 53 | S = math.ceil(int(video.shape[3] * width)) 54 | 55 | # top 56 | video[:, :, :S, :, 0] = color[1] 57 | 58 | # bottom 59 | video[:, :, -S:, :, 0] = color[1] 60 | 61 | # left 62 | video[:, :, :, :S, 0] = color[1] 63 | 64 | # right 65 | video[:, :, :, -S:, 0] = color[1] 66 | 67 | 68 | def flatten(x, start=0, end=None): 69 | i, j = start, end 70 | n_dims = len(x.shape) 71 | if i < 0: 72 | i = n_dims + i 73 | 74 | if j is None: 75 | j = n_dims 76 | elif j < 0: 77 | j = n_dims + j 78 | 79 | return reshape_range(x, i, j, (np.prod(x.shape[i:j]),)) 80 | 81 | 82 | def reshape_range(x, i, j, shape): 83 | shape = tuple(shape) 84 | 85 | n_dims = len(x.shape) 86 | if i < 0: 87 | i = n_dims + i 88 | 89 | if j is None: 90 | j = n_dims 91 | elif j < 0: 92 | j = n_dims + j 93 | 94 | assert 0 <= i < j <= n_dims 95 | 96 | x_shape = x.shape 97 | target_shape = x_shape[:i] + shape + x_shape[j:] 98 | return jnp.reshape(x, target_shape) 99 | 100 | 101 | def save_video_grid(video, fname=None, nrow=None, fps=10): 102 | b, t, h, w, c = video.shape 103 | video = (video * 255).astype('uint8') 104 | 105 | if nrow is None: 106 | nrow = math.ceil(math.sqrt(b)) 107 | ncol = math.ceil(b / nrow) 108 | padding = 1 109 | video_grid = np.zeros((t, (padding + h) * ncol + padding, 110 | (padding + w) * nrow + padding, c), dtype='uint8') 111 | for i in range(b): 112 | r = i // nrow 113 | c = i % nrow 114 | 115 | start_r = 
(padding + h) * r 116 | start_c = (padding + w) * c 117 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 118 | 119 | if fname is not None: 120 | clip = ImageSequenceClip(list(video_grid), fps=fps) 121 | clip.write_gif(fname, fps=fps) 122 | print('saved videos to', fname) 123 | 124 | return video_grid # THWC, uint8 125 | --------------------------------------------------------------------------------
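Appendix (editor's sketches, not repository files):

The three noVQ sequence models above (CONVLSTM_NOVQ, CONVS5_NOVQ, TRANSFORMER_NOVQ) share the same weighted MSE + L1 reconstruction loss. The sketch below restates that computation as a standalone function for clarity; it is illustrative only — the array shapes and the loss_weight value are placeholders, not values taken from the repository configs.

import jax
import jax.numpy as jnp
import optax


def novq_recon_loss(recon, labels, loss_weight=0.9):
    # recon, labels: BTHWC arrays; the decoders above pass their output through tanh,
    # so recon lies in [-1, 1].
    # optax.l2_loss includes a factor of 0.5, so multiply by 2 to recover a plain
    # squared error, matching the comment in the model code.
    mse_loss = 2 * optax.l2_loss(recon, labels)
    mse_loss = mse_loss.sum(axis=(-2, -1)).mean()

    l1_loss = jnp.abs(recon - labels)
    l1_loss = l1_loss.sum(axis=(-2, -1)).mean()

    # Convex combination controlled by config.loss_weight in the models above.
    return loss_weight * mse_loss + (1 - loss_weight) * l1_loss


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
recon = jax.random.uniform(key1, (2, 4, 16, 16, 3), minval=-1.0, maxval=1.0)
labels = jax.random.uniform(key2, (2, 4, 16, 16, 3), minval=-1.0, maxval=1.0)
print(novq_recon_loss(recon, labels))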
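A minimal usage sketch for the evaluation helpers in src/runtime_metrics.py. It assumes the repository is on the Python path and, as the comments in that file require, that both videos are BTHWC arrays scaled to [0, 1]; the shapes below are placeholders, not dataset values.

import jax
import jax.numpy as jnp

from src.runtime_metrics import get_ssim, get_psnr, get_lpips

B, T, H, W, C = 2, 8, 64, 64, 3
key_pred, key_true = jax.random.split(jax.random.PRNGKey(0))
pred = jax.random.uniform(key_pred, (B, T, H, W, C))   # stand-in for sampled frames
truth = jax.random.uniform(key_true, (B, T, H, W, C))  # stand-in for ground truth

# average_dim=1 averages over time, giving one score per video (shape [B]);
# average_dim=0 averages over the batch, giving a per-frame curve (shape [T]).
ssim_per_video = get_ssim(pred, truth, average_dim=1)
psnr_per_frame = get_psnr(pred, truth, average_dim=0)

# LPIPS relies on lpips_jax fetching pretrained AlexNet (or VGG16) weights.
lpips_per_video = get_lpips(pred, truth, net='alexnet', average_dim=1)

print(ssim_per_video.shape, psnr_per_frame.shape, lpips_per_video.shape)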
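A similar hedged sketch for two helpers in src/utils.py: topk_sample for drawing token indices from logits, and save_video_grid for writing a batch of videos to a GIF. The logits shape, top_k value, and output file name are illustrative assumptions.

import jax
import jax.numpy as jnp
import numpy as np

from src.utils import topk_sample, save_video_grid

# Sample one token per row, restricted to the 8 highest-scoring entries.
logits = jax.random.normal(jax.random.PRNGKey(0), (4, 1024))
tokens = topk_sample(jax.random.PRNGKey(1), logits, top_k=8)   # shape (4,)

# Tile a batch of BTHWC videos in [0, 1] into a grid and save it as a GIF;
# the function also returns the grid as a THWC uint8 array.
video = np.random.rand(4, 16, 64, 64, 3)
grid = save_video_grid(video, fname='samples.gif', fps=10)
print(tokens.shape, grid.shape)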