├── .DS_Store ├── LICENSE ├── README.md ├── config ├── __pycache__ │ └── locomotion.cpython-38.pyc └── locomotion.py ├── diffuser ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── datasets │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── latendata.cpython-38.pyc │ │ └── utils.cpython-38.pyc │ ├── latendata.py │ └── utils.py ├── models │ ├── GPT2.py │ ├── VLA_model.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── GPT2.cpython-38.pyc │ │ ├── VLA_model.cpython-38.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── diffusion.cpython-37.pyc │ │ ├── diffusion.cpython-38.pyc │ │ ├── encoder.cpython-38.pyc │ │ ├── helpers.cpython-37.pyc │ │ ├── helpers.cpython-38.pyc │ │ ├── soda_model.cpython-38.pyc │ │ ├── temporal.cpython-37.pyc │ │ ├── temporal.cpython-38.pyc │ │ ├── utils.cpython-38.pyc │ │ └── video_model.cpython-38.pyc │ ├── diffusion.py │ ├── encoder.py │ ├── helpers.py │ ├── soda_model.py │ ├── temporal.py │ ├── utils.py │ └── video_model.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── action_tokenizer.cpython-38.pyc │ ├── arrays.cpython-37.pyc │ ├── arrays.cpython-38.pyc │ ├── arrays.cpython-39.pyc │ ├── cloud.cpython-37.pyc │ ├── cloud.cpython-38.pyc │ ├── cloud.cpython-39.pyc │ ├── colab.cpython-37.pyc │ ├── colab.cpython-38.pyc │ ├── config.cpython-37.pyc │ ├── config.cpython-38.pyc │ ├── git_utils.cpython-37.pyc │ ├── git_utils.cpython-38.pyc │ ├── logger.cpython-37.pyc │ ├── logger.cpython-38.pyc │ ├── progress.cpython-37.pyc │ ├── progress.cpython-38.pyc │ ├── progress.cpython-39.pyc │ ├── rendering.cpython-37.pyc │ ├── rendering.cpython-38.pyc │ ├── serialization.cpython-37.pyc │ ├── serialization.cpython-38.pyc │ ├── serialization.cpython-39.pyc │ ├── setup.cpython-37.pyc │ ├── setup.cpython-38.pyc │ ├── setup.cpython-39.pyc │ ├── timer.cpython-37.pyc │ ├── timer.cpython-38.pyc │ ├── timer.cpython-39.pyc │ ├── training.cpython-37.pyc │ ├── training.cpython-38.pyc │ ├── training.cpython-39.pyc │ ├── video.cpython-37.pyc │ └── video.cpython-38.pyc │ ├── action_tokenizer.py │ ├── arrays.py │ ├── cloud.py │ ├── colab.py │ ├── config.py │ ├── distributed │ ├── __pycache__ │ │ ├── distributed.cpython-38.pyc │ │ └── launch.cpython-38.pyc │ ├── distributed.py │ └── launch.py │ ├── git_utils.py │ ├── iql.py │ ├── logger.py │ ├── progress.py │ ├── pybullet_utils.py │ ├── rendering.py │ ├── serialization.py │ ├── setup.py │ ├── timer.py │ ├── training.py │ ├── transformations.py │ └── video.py ├── environment.yml ├── helpers ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── custom_rlbench_env.cpython-38.pyc │ ├── demo_loading_utils.cpython-38.pyc │ ├── network_utils.cpython-38.pyc │ ├── preprocess_agent.cpython-38.pyc │ └── utils.cpython-38.pyc ├── clip │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── __init__.cpython-39.pyc │ └── core │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── clip.cpython-38.pyc │ │ ├── clip.cpython-39.pyc │ │ ├── simple_tokenizer.cpython-38.pyc │ │ └── simple_tokenizer.cpython-39.pyc │ │ ├── attention.py │ │ ├── attention_image_goal.py │ │ ├── bpe_simple_vocab_16e6.txt.gz │ │ ├── clip.py │ │ ├── fusion.py │ │ ├── 
resnet.py │ │ ├── simple_tokenizer.py │ │ ├── transport.py │ │ ├── transport_image_goal.py │ │ └── unet.py ├── custom_rlbench_env.py ├── demo_loading_utils.py ├── network_utils.py ├── optim │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── lamb.cpython-38.pyc │ └── lamb.py ├── preprocess_agent.py └── utils.py ├── scripts ├── compute_bound_box.py ├── compute_fvd.py ├── count.py ├── preprocess │ ├── bair │ │ ├── bair_extract_images.py │ │ ├── bair_image_to_hdf5.py │ │ └── create_bair_dataset.sh │ └── ucf101 │ │ ├── create_ucf_dataset.sh │ │ └── ucf_split_train_test.py ├── pretrain_meta.py ├── pretrain_video_diff.py ├── test.py ├── test_finetune_meta.py └── train_vqvae.py └── videogpt ├── __init__.py ├── attention.py ├── data.py ├── download.py ├── fvd ├── __init__.py ├── convert_tf_pretrained.py ├── fvd.py └── pytorch_i3d.py ├── gpt.py ├── resnet.py ├── utils.py ├── videogpt ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── attention.cpython-38.pyc │ ├── data.cpython-38.pyc │ ├── download.cpython-38.pyc │ ├── gpt.cpython-38.pyc │ ├── resnet.cpython-38.pyc │ ├── utils.cpython-38.pyc │ └── vqvae.cpython-38.pyc ├── attention.py ├── data.py ├── download.py ├── fvd │ ├── __init__.py │ ├── convert_tf_pretrained.py │ ├── fvd.py │ └── pytorch_i3d.py ├── gpt.py ├── resnet.py ├── utils.py └── vqvae.py └── vqvae.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :rocket: VPDD | NeurIPS 2024 2 | 3 | ## Learning an Actionable Discrete Diffusion Policy via Large-Scale Actionless Video Pre-Training 4 | 5 | This is the official code for the paper "Learning an Actionable Discrete Diffusion Policy via Large-Scale Actionless Video Pre-Training". 6 | We introduce a novel framework that leverages a unified discrete diffusion to combine generative pre-training on human videos and policy fine-tuning on a small number of action-labeled robot videos. We aim to incorporate foresight from predicted videos to facilitate efficient policy learning. 7 | 8 | 📝 [Paper](https://arxiv.org/abs/2402.14407) \| [中文blog@知乎](https://zhuanlan.zhihu.com/p/684830185) \| [公众号@量子位](https://mp.weixin.qq.com/s/bFVwWpjFQpTTWkbpaEqYCQ) 9 | ## Environment Configurations 10 | ``` 11 | conda env create -f environment.yml 12 | conda activate VPDD 13 | ``` 14 | ## Dataset 15 | - For experiments on RLBench, you can use [pre-generated dataset](https://drive.google.com/drive/folders/0B2LlLwoO3nfZfkFqMEhXWkxBdjJNNndGYl9uUDQwS1pfNkNHSzFDNGwzd1NnTmlpZXR1bVE?resourcekey=0-jRw5RaXEYRLe2W6aNrNFEQ) provided by [PerAct](https://github.com/peract/peract). 16 | - For experiments on MetaWorld, we use the script policy provided in [MetaWorld](https://github.com/Farama-Foundation/Metaworld) to collect image-based data. You can refer to [generate_metaworld_dataset.py](https://github.com/pairlab/QueST/blob/main/scripts/generate_metaworld_dataset.py) for implementation. 17 | 18 | ## Model 19 | The pre-trained VQ-VAE models and discrete diffusion models are available at https://huggingface.co/haoranhe/VPDD-pretrain. You can download them and change the `path` in the corresponding code. 
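For convenience, the released checkpoints can also be fetched programmatically with `huggingface_hub` (a minimal sketch; the local directory is a placeholder, and you still need to point the `path` variables in the scripts at wherever the files land):

```python
# Minimal sketch: download the released VPDD checkpoints from the Hugging Face Hub.
# Assumes `pip install huggingface_hub`; "./checkpoints" is an arbitrary local directory.
from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download(repo_id="haoranhe/VPDD-pretrain", local_dir="./checkpoints")
print(ckpt_dir)  # set the `path` in the corresponding code to this directory
```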
20 | ## Pre-training 21 | 22 | We first train a VQ-VAE to learn a unified discrete latent codebook: 23 | 24 | `torchrun --master_addr=$MASTER_ADDR --master_port=$MASTER_PORT --nproc_per_node=4 --nnodes=$WORLD_SIZE --node_rank=$RANK scripts/train_vqvae.py --gpus=4 --max_epoch=10 --resolution 96 --sequence_length 8 --batch_size 32` 25 | 26 | We then pre-train VPDD on Meta-World: 27 | 28 | `python scripts/pretrain_meta.py --seed 1 --model models.VideoDiffuserModel --diffusion models.GaussianVideoDiffusion --loss_type video --device cuda:0 --batch_size 10 --loader datasets.MetaDataset --act_classes 48` 29 | 30 | or on RLBench which requires multi-view videos prediction: 31 | 32 | `python scripts/pretrain_video_diff.py --seed 1 --model models.VideoDiffuserModel --diffusion models.MultiviewVideoDiffusion --loss_type video --device cuda:0 --batch_size 3 --loader datasets.MultiViewDataset --act_classes 360 --n_diffusion_steps 100` 33 | ## Fine-Tuning 34 | After pre-training, we fine-tune VPDD with a limited set of robot data: 35 | 36 | `python scripts/pretrain_meta.py --seed 1 --model models.VideoDiffuserModel --diffusion models.GaussianVideoDiffusion --loss_type video --device cuda:0 --batch_size 1 --loader datasets.MetaFinetuneDataset --pretrain False` 37 | 38 | `python scripts/pretrain_video_diff.py --seed 1 --model models.VideoDiffuserModel --diffusion models.MultiviewVideoDiffusion --loss_type video --device cuda:0 --batch_size 10 --loader datasets.MultiviewFinetuneDataset --pretrain False --act_classes 360` 39 | 40 | 41 | ## Acknowledgment 42 | Our code for VPDD is partially based on the following awesome projects: 43 | - MTDiff from https://github.com/tinnerhrhe/MTDiff/ 44 | - UniD3 from https://github.com/mhh0318/UniD3 45 | ## Citation 46 | ```bibtex 47 | @inproceedings{ 48 | he2024learning, 49 | title={Learning an Actionable Discrete Diffusion Policy via Large-Scale Actionless Video Pre-Training}, 50 | author={He, Haoran and Bai, Chenjia and Pan, Ling and Zhang, Weinan and Zhao, Bin and Li, Xuelong}, 51 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, 52 | year={2024}, 53 | } 54 | ``` 55 | 56 | ## Star History 57 | 58 | [![Star History Chart](https://api.star-history.com/svg?repos=tinnerhrhe/VPDD&type=Date)](https://star-history.com/#hpcaitech/Open-Sora&Date) 59 | -------------------------------------------------------------------------------- /config/__pycache__/locomotion.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/config/__pycache__/locomotion.cpython-38.pyc -------------------------------------------------------------------------------- /config/locomotion.py: -------------------------------------------------------------------------------- 1 | import socket 2 | 3 | from diffuser.utils import watch 4 | 5 | #------------------------ base ------------------------# 6 | 7 | ## automatically make experiment names for planning 8 | ## by labelling folders with these args 9 | 10 | args_to_watch = [ 11 | ('prefix', ''), 12 | ('horizon', 'H'), 13 | ('n_diffusion_steps', 'T'), 14 | ## value kwargs 15 | ('discount', 'd'), 16 | ] 17 | 18 | logbase = 'logs' 19 | 20 | base = { 21 | 'diffusion': { 22 | ## model 23 | 'model': 'models.VideoModel', 24 | 'diffusion': 'models.GaussianDiffusion', 25 | 'horizon': 1, 26 | 'n_diffusion_steps': 100, 27 | 'action_weight': 10, 28 | 'loss_weights': None, 29 | 'loss_discount': 1, 30 
| 'predict_epsilon': False, 31 | 'dim_mults': (1, 2, 4, 8), 32 | 'attention': False, 33 | 'renderer': 'utils.MuJoCoRenderer', 34 | ## distributed 35 | 'num_node':1, 36 | 37 | ## dataset 38 | 'loader': 'datasets.VideoDataset', 39 | 'normalizer': 'GaussianNormalizer', 40 | 'data_folder': './data', 41 | 'sequence_length': 16, 42 | 'preprocess_fns': [], 43 | 'clip_denoised': False, 44 | 'use_padding': True, 45 | 'max_path_length': 1000, 46 | 'pretrain': True, 47 | 'concat': False, 48 | 'focal': False, 49 | 'force': False, 50 | 'act_classes': 256, 51 | 'tasks': ['put_all_groceries_in_cupboard', 'set_the_table', 'hang_frame_on_hanger', 'setup_chess', 'turn_tap', 'take_plate_off_colored_dish_rack', 'take_toilet_roll_off_stand', 52 | 'close_jar', 'put_books_at_shelf_location', 'meat_on_grill', 'toilet_seat_down', 'light_bulb_in', 'take_cup_out_from_cabinet', 'wipe_desk', 'tv_on', 'slide_block_to_color_target', 53 | 'open_jar', 'sweep_to_dustpan_of_size', 'screw_nail', 'push_buttons', 'put_groceries_in_cupboard', 'empty_dishwasher', 'put_money_in_safe', 'put_tray_in_oven', 'straighten_rope', 54 | 'solve_puzzle', 'slide_block_to_target', 'place_shape_in_shape_sorter', 'put_item_in_drawer', 'take_shoes_out_of_box', 'lamp_on', 'play_jenga', 'insert_usb_in_computer', 'water_plants', 55 | 'insert_onto_square_peg', 'pour_from_cup_to_cup', 'hit_ball_with_queue', 'take_off_weighing_scales', 'scoop_with_spatula', 'move_hanger', 'unplug_charger', 'reach_and_drag', 'place_wine_at_rack_location', 56 | 'get_ice_from_fridge', 'stack_cups', 'place_cups', 'sweep_to_dustpan', 'meat_off_grill', 'change_clock', 'take_umbrella_out_of_umbrella_stand', 'slide_cabinet_open_and_place_cups', 'put_knife_in_knife_block', 'stack_blocks', 'hockey'], 57 | 'single': False, 58 | 'meta_tasks': ['basketball-v2', 'bin-picking-v2', 'button-press-topdown-v2', 59 | 'button-press-v2', 'button-press-wall-v2', 'coffee-button-v2', 60 | 'coffee-pull-v2', 'coffee-push-v2', 'dial-turn-v2', 'disassemble-v2', 'door-close-v2', 'door-lock-v2', 61 | 'door-open-v2', 'door-unlock-v2', 'hand-insert-v2', 'drawer-close-v2', 'drawer-open-v2', 'faucet-open-v2', 62 | 'faucet-close-v2', 'handle-press-side-v2', 'handle-press-v2', 'handle-pull-side-v2', 'handle-pull-v2', 63 | 'lever-pull-v2', 'peg-insert-side-v2', 'pick-place-wall-v2', 'pick-out-of-hole-v2', 'reach-v2', 'push-back-v2', 64 | 'push-v2', 'pick-place-v2', 'plate-slide-v2', 'plate-slide-side-v2', 'plate-slide-back-v2', 65 | 'plate-slide-back-side-v2', 'soccer-v2', 66 | 'push-wall-v2', 'shelf-place-v2', 'sweep-into-v2', 'sweep-v2', 'window-open-v2', 'window-close-v2','assembly-v2', 67 | 'button-press-topdown-wall-v2','hammer-v2','peg-unplug-side-v2', 68 | 'reach-wall-v2', 'stick-push-v2', 'stick-pull-v2', 'box-close-v2'], 69 | 70 | ## serialization 71 | 'logbase': logbase, 72 | 'prefix': 'diffusion/defaults', 73 | 'exp_name': watch(args_to_watch), 74 | 'num_demos':20, 75 | 76 | ## training 77 | 'n_steps_per_epoch': 50000, 78 | 'loss_type': 'l2', 79 | 'n_train_steps': 1e6, 80 | 'batch_size': 32, 81 | 'learning_rate': 2e-4, 82 | 'gradient_accumulate_every': 2, 83 | 'ema_decay': 0.995, 84 | 'save_freq': 50000, 85 | 'sample_freq': 20000, 86 | 'n_saves': 5, 87 | 'save_parallel': False, 88 | 'n_reference': 8, 89 | 'bucket': None, 90 | 'device': 'cuda', 91 | 'seed': None, 92 | }, 93 | 94 | 'values': { 95 | 'model': 'models.ValueFunction', 96 | 'diffusion': 'models.ValueDiffusion', 97 | 'horizon': 32, 98 | 'n_diffusion_steps': 20, 99 | 'dim_mults': (1, 2, 4, 8), 100 | 'renderer': 'utils.MuJoCoRenderer', 
101 | 102 | ## value-specific kwargs 103 | 'discount': 0.99, 104 | 'termination_penalty': -100, 105 | 'normed': False, 106 | 107 | ## dataset 108 | 'loader': 'datasets.ValueDataset', 109 | 'normalizer': 'GaussianNormalizer', 110 | 'preprocess_fns': [], 111 | 'use_padding': True, 112 | 'max_path_length': 1000, 113 | 114 | ## serialization 115 | 'logbase': logbase, 116 | 'prefix': 'values/defaults', 117 | 'exp_name': watch(args_to_watch), 118 | 119 | ## training 120 | 'n_steps_per_epoch': 10000, 121 | 'loss_type': 'value_l2', 122 | 'n_train_steps': 200e3, 123 | 'batch_size': 32, 124 | 'learning_rate': 2e-4, 125 | 'gradient_accumulate_every': 2, 126 | 'ema_decay': 0.995, 127 | 'save_freq': 1000, 128 | 'sample_freq': 0, 129 | 'n_saves': 5, 130 | 'save_parallel': False, 131 | 'n_reference': 8, 132 | 'bucket': None, 133 | 'device': 'cuda', 134 | 'seed': None, 135 | }, 136 | 137 | 'plan': { 138 | 'guide': 'sampling.nogradGuide', 139 | 'policy': 'sampling.GuidedPolicy', 140 | 'max_episode_length': 1000, 141 | 'batch_size': 64, 142 | 'preprocess_fns': [], 143 | 'device': 'cuda', 144 | 'seed': None, 145 | 146 | ## sample_kwargs 147 | 'n_guide_steps': 2, 148 | 'scale': 0.1, 149 | 't_stopgrad': 2, 150 | 'scale_grad_by_std': True, 151 | 152 | ## serialization 153 | 'loadbase': None, 154 | 'logbase': logbase, 155 | 'prefix': 'plans/', 156 | 'exp_name': watch(args_to_watch), 157 | 'vis_freq': 100, 158 | 'max_render': 8, 159 | 160 | ## diffusion model 161 | 'horizon': 32, 162 | 'n_diffusion_steps': 20, 163 | 164 | ## value function 165 | 'discount': 0.997, 166 | 167 | ## loading 168 | 'diffusion_loadpath': 'f:diffusion/defaults_H{horizon}_T{n_diffusion_steps}', 169 | 'value_loadpath': 'f:values/defaults_H{horizon}_T{n_diffusion_steps}_d{discount}', 170 | 171 | 'diffusion_epoch': 'latest', 172 | 'value_epoch': 'latest', 173 | 174 | 'verbose': True, 175 | 'suffix': '0', 176 | 'meta_task':'basketball-v2' 177 | }, 178 | } 179 | 180 | 181 | #------------------------ overrides ------------------------# 182 | 183 | 184 | hopper_medium_expert_v2 = { 185 | 'plan': { 186 | 'scale': 0.0001, 187 | 't_stopgrad': 4, 188 | }, 189 | } 190 | 191 | 192 | halfcheetah_medium_replay_v2 = halfcheetah_medium_v2 = halfcheetah_medium_expert_v2 = { 193 | 'diffusion': { 194 | 'horizon': 4, 195 | 'dim_mults': (1, 4, 8), 196 | 'attention': True, 197 | }, 198 | 'values': { 199 | 'horizon': 4, 200 | 'dim_mults': (1, 4, 8), 201 | }, 202 | 'plan': { 203 | 'horizon': 4, 204 | 'scale': 0.001, 205 | 't_stopgrad': 4, 206 | }, 207 | } 208 | -------------------------------------------------------------------------------- /diffuser/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import environments -------------------------------------------------------------------------------- /diffuser/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /diffuser/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /diffuser/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /diffuser/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .latendata import * -------------------------------------------------------------------------------- /diffuser/datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /diffuser/datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/datasets/__pycache__/latendata.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/datasets/__pycache__/latendata.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/datasets/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/datasets/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/datasets/utils.py: 
-------------------------------------------------------------------------------- 1 | 2 | task_prompts = [ 3 | 'Dunk the basketball into the basket', 4 | 'Grasp the puck from one bin and place it into another bin', 5 | 'Press a button from the top', 6 | 'Press a button', 7 | 'Bypass a wall and press a button', 8 | 'Push a button on the coffee machine', 9 | 'Pull a mug from a coffee machine', 10 | 'Push a mug under a coffee machine', 11 | 'Rotate a dial 180 degrees', 12 | 'pick a nut out of the a peg', 13 | 'Close a door with a revolving joint', 14 | 'Lock the door by rotating the lock clockwise', 15 | 'Open a door with a revolving joint', 16 | 'Unlock the door by rotating the lock counter-clockwise', 17 | 'Insert the gripper into a hole', 18 | 'Push and close a drawer', 19 | 'Open a drawer', 20 | 'Rotate the faucet counter-clockwise', 21 | 'Rotate the faucet clockwise', 22 | 'Press a handle down sideways', 23 | 'Press a handle down', 24 | 'Pull a handle up sideways', 25 | 'Pull a handle up', 26 | 'Pull a lever down 90 degrees', 27 | 'Insert a peg sideways', 28 | 'Pick a puck, bypass a wall and place the puck', 29 | 'Pick up a puck from a hole', 30 | 'reach a goal position', 31 | 'Pull a puck to a goal', 32 | 'Push the puck to a goal', 33 | 'Pick and place a puck to a goal', 34 | 'Slide a plate into a cabinet', 35 | 'Slide a plate into a cabinet sideways', 36 | 'Get a plate from the cabinet', 37 | 'Get a plate from the cabinet sideways', 38 | 'Kick a soccer into the goal', 39 | 'Bypass a wall and push a puck to a goal', 40 | 'pick and place a puck onto a shelf', 41 | 'Sweep a puck into a hole', 42 | 'Sweep a puck off the table', 43 | 'Push and open a window', 44 | 'Push and close a window', 45 | 'Pick up a nut and place it onto a peg', 46 | 'Bypass a wall and press a button from the top', 47 | 'Hammer a screw on the wall', 48 | 'Unplug a peg sideways', 49 | 'Bypass a wall and reach a goal', 50 | 'Grasp a stick and push a box using the stick', 51 | 'Grasp a stick and pull a box with the stick', 52 | 'Grasp the cover and close the box with it', 53 | ] 54 | -------------------------------------------------------------------------------- /diffuser/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .temporal import VideoModel,PerceiverVideoModel, VQVideoModel 2 | from .diffusion import GaussianVideoDiffusion, VQVideoDiffusion, MultiviewVideoDiffusion, GaussianDiffusion,sodaDiffusion,DTDiffusion 3 | from .VLA_model import VLATransformer 4 | from .video_model import VideoDiffuserModel,R3MModel, Tasksmeta, DT 5 | from .soda_model import pretrainModel, pretrainModel_v1,sodaActModel -------------------------------------------------------------------------------- /diffuser/models/__pycache__/GPT2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/models/__pycache__/GPT2.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/models/__pycache__/VLA_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/models/__pycache__/VLA_model.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/models/__pycache__/__init__.cpython-37.pyc: 
/diffuser/models/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/models/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/models/__pycache__/video_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/models/__pycache__/video_model.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .serialization import * 2 | from .training import * 3 | from .progress import * 4 | from .setup import * 5 | from .config import * 6 | #from .rendering import * 7 | from .arrays import * 8 | from .colab import * 9 | from .logger import * 10 | from .action_tokenizer import * 11 | from gym.envs.registration import register 12 | register( 13 | id='maze2d-1', 14 | entry_point='d4rl.pointmaze:MazeEnv', 15 | max_episode_steps=600, 16 | kwargs={ 17 | 'maze_spec':MAZE_1, 18 | 'reward_type':'sparse', 19 | 'reset_target': False, 20 | 'ref_min_score': 13.13, 21 | 'ref_max_score': 277.39, 22 | } 23 | ) 24 | register( 25 | id='maze2d-2', 26 | entry_point='d4rl.pointmaze:MazeEnv', 27 | max_episode_steps=600, 28 | kwargs={ 29 | 'maze_spec':MAZE_2, 30 | 'reward_type':'sparse', 31 | 'reset_target': False, 32 | 'ref_min_score': 13.13, 33 | 'ref_max_score': 277.39, 34 | } 35 | ) 36 | register( 37 | id='maze2d-3', 38 | entry_point='d4rl.pointmaze:MazeEnv', 39 | max_episode_steps=600, 40 | kwargs={ 41 | 'maze_spec':MAZE_3, 42 | 'reward_type':'sparse', 43 | 'reset_target': False, 44 | 'ref_min_score': 13.13, 45 | 'ref_max_score': 277.39, 46 | } 47 | ) 48 | register( 49 | id='maze2d-4', 50 | entry_point='d4rl.pointmaze:MazeEnv', 51 | max_episode_steps=600, 52 | kwargs={ 53 | 'maze_spec':MAZE_4, 54 | 'reward_type':'sparse', 55 | 'reset_target': False, 56 | 'ref_min_score': 13.13, 57 | 'ref_max_score': 277.39, 58 | } 59 | ) 60 | register( 61 | id='maze2d-5', 62 | entry_point='d4rl.pointmaze:MazeEnv', 63 | max_episode_steps=600, 64 | kwargs={ 65 | 'maze_spec':MAZE_5, 66 | 'reward_type':'sparse', 67 | 'reset_target': False, 68 | 'ref_min_score': 13.13, 69 | 'ref_max_score': 277.39, 70 | } 71 | ) 72 | register( 73 | id='maze2d-6', 74 | entry_point='d4rl.pointmaze:MazeEnv', 75 | max_episode_steps=600, 76 | kwargs={ 77 | 'maze_spec':MAZE_6, 78 | 'reward_type':'sparse', 79 | 'reset_target': False, 80 | 'ref_min_score': 13.13, 81 | 'ref_max_score': 277.39, 82 | } 83 | ) 84 | register( 85 | id='maze2d-7', 86 | entry_point='d4rl.pointmaze:MazeEnv', 87 | max_episode_steps=600, 88 | kwargs={ 89 | 'maze_spec':MAZE_7, 90 | 'reward_type':'sparse', 91 | 'reset_target': False, 92 | 'ref_min_score': 13.13, 93 | 'ref_max_score': 277.39, 94 | } 95 | ) 96 | register( 97 | id='maze2d-8', 98 | entry_point='d4rl.pointmaze:MazeEnv', 99 | max_episode_steps=600, 100 | kwargs={ 101 | 'maze_spec':MAZE_8, 102 | 'reward_type':'sparse', 103 | 'reset_target': False, 104 | 'ref_min_score': 13.13, 105 | 'ref_max_score': 277.39, 106 | } 107 | ) 108 | -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/__init__.cpython-37.pyc: 
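The `register()` calls above expose eight custom point-maze layouts (`maze2d-1` through `maze2d-8`, with the `MAZE_*` specs defined in `arrays.py`) through the standard Gym API. A minimal usage sketch, assuming `d4rl` and its MuJoCo dependencies are installed and the legacy Gym reset/step interface:

```python
# Minimal sketch: the maze2d-* ids become available once diffuser.utils is imported,
# because the register() calls run at import time.
import gym
import diffuser.utils  # noqa: F401  (triggers registration of maze2d-1 ... maze2d-8)

env = gym.make('maze2d-1')                 # d4rl.pointmaze:MazeEnv with the MAZE_1 layout
obs = env.reset()
obs, reward, done, info = env.step(env.action_space.sample())
```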
/diffuser/utils/__pycache__/timer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/timer.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/timer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/timer.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/training.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/training.cpython-37.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/training.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/training.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/training.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/training.cpython-39.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/video.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/video.cpython-37.pyc -------------------------------------------------------------------------------- /diffuser/utils/__pycache__/video.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/__pycache__/video.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/action_tokenizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pdb 4 | 5 | from .arrays import to_np, to_torch 6 | 7 | class QuantileDiscretizer: 8 | 9 | def __init__(self, data, N): 10 | self.data = data #shape: N x dim 11 | self.N = N #number of bins 12 | 13 | n_points_per_bin = int(np.ceil(len(data) / N)) 14 | obs_sorted = np.sort(data, axis=0) 15 | #a wrong example in https://numpy.org/doc/stable/reference/generated/numpy.ndarray.sort.html#numpy.ndarray.sort 16 | thresholds = obs_sorted[::n_points_per_bin, :] 17 | maxs = data.max(axis=0, keepdims=True) 18 | 19 | ## [ (N + 1) x dim ] 20 | self.thresholds = np.concatenate([thresholds, maxs], axis=0) 21 | np.savez(f'./tokenizer_{N}.npz', thresholds=self.thresholds, N=self.N) 22 | # threshold_inds = np.linspace(0, len(data) - 1, N + 1, dtype=int) 23 | # obs_sorted = np.sort(data, axis=0) 24 | 25 | # ## [ (N + 1) x dim ] 26 | # self.thresholds = obs_sorted[threshold_inds, :] 27 | 28 | ## [ N x dim ] 29 | 
self.diffs = self.thresholds[1:] - self.thresholds[:-1] 30 | 31 | ## for sparse reward tasks 32 | # if (self.diffs[:,-1] == 0).any(): 33 | # raise RuntimeError('rebin for sparse reward tasks') 34 | 35 | self._test() 36 | 37 | def __call__(self, x): 38 | indices = self.discretize(x) 39 | recon = self.reconstruct(indices) 40 | error = np.abs(recon - x).max(0) 41 | return indices, recon, error 42 | #def normalize(self, x): 43 | def _test(self): 44 | print('[ utils/discretization ] Testing...', end=' ', flush=True) 45 | inds = np.random.randint(0, len(self.data), size=1000) 46 | X = self.data[inds] 47 | #import pdb;pdb.set_trace() 48 | indices = self.discretize(X) 49 | recon = self.reconstruct(indices) 50 | ## make sure reconstruction error is less than the max allowed per dimension 51 | error = np.abs(X - recon).max(0) 52 | assert (error <= self.diffs.max(axis=0)).all() 53 | ## re-discretize reconstruction and make sure it is the same as original indices 54 | indices_2 = self.discretize(recon) 55 | assert (indices == indices_2).all() 56 | ## reconstruct random indices 57 | ## @TODO: remove duplicate thresholds 58 | # randint = np.random.randint(0, self.N, indices.shape) 59 | # randint_2 = self.discretize(self.reconstruct(randint)) 60 | # assert (randint == randint_2).all() 61 | print('✓') 62 | 63 | def discretize(self, x, subslice=(None, None)): 64 | ''' 65 | x : [ B x observation_dim ] 66 | ''' 67 | 68 | if torch.is_tensor(x): 69 | x = to_np(x) 70 | 71 | ## enforce batch mode 72 | if x.ndim == 1: 73 | x = x[None] 74 | 75 | #gripper_open 76 | #xg = x[:, -1:].astype(int) 77 | ## [ N x B x observation_dim ] 78 | start, end = subslice 79 | thresholds = self.thresholds[:, start:end] 80 | 81 | gt = x[None] >= thresholds[:,None] 82 | indices = largest_nonzero_index(gt, dim=0) 83 | 84 | if indices.min() < 0 or indices.max() >= self.N: 85 | indices = np.clip(indices, 0, self.N - 1) 86 | #indices = np.concatenate([indices[:,:-1], xg], axis=-1) #gripper_open 87 | return indices 88 | 89 | def reconstruct(self, indices, subslice=(None, None)): 90 | 91 | if torch.is_tensor(indices): 92 | indices = to_np(indices) 93 | 94 | ## enforce batch mode 95 | if indices.ndim == 1: 96 | indices = indices[None] 97 | 98 | if indices.min() < 0 or indices.max() >= self.N: 99 | print(f'[ utils/discretization ] indices out of range: ({indices.min()}, {indices.max()}) | N: {self.N}') 100 | indices = np.clip(indices, 0, self.N - 1) 101 | #gripper_open 102 | #xg = indices[:, -1:].astype(int) 103 | start, end = subslice 104 | thresholds = self.thresholds[:, start:end] 105 | 106 | left = np.take_along_axis(thresholds, indices, axis=0) 107 | right = np.take_along_axis(thresholds, indices + 1, axis=0) 108 | recon = (left + right) / 2. 109 | #recon = np.concatenate([recon[:,:-1], xg], axis=-1)#gripper_open 110 | return recon 111 | 112 | #---------------------------- wrappers for planning ----------------------------# 113 | 114 | def expectation(self, probs, subslice): 115 | ''' 116 | probs : [ B x N ] 117 | ''' 118 | 119 | if torch.is_tensor(probs): 120 | probs = to_np(probs) 121 | 122 | ## [ N ] 123 | thresholds = self.thresholds[:, subslice] 124 | ## [ B ] 125 | left = probs @ thresholds[:-1] 126 | right = probs @ thresholds[1:] 127 | 128 | avg = (left + right) / 2. 129 | return avg 130 | 131 | def percentile(self, probs, percentile, subslice): 132 | ''' 133 | percentile `p` : 134 | returns least value `v` s.t. 
cdf up to `v` is >= `p` 135 | e.g., p=0.8 and v=100 indicates that 136 | 100 is in the 80% percentile of values 137 | ''' 138 | ## [ N ] 139 | thresholds = self.thresholds[:, subslice] 140 | ## [ B x N ] 141 | cumulative = np.cumsum(probs, axis=-1) 142 | valid = cumulative > percentile 143 | ## [ B ] 144 | inds = np.argmax(np.arange(self.N, 0, -1) * valid, axis=-1) 145 | left = thresholds[inds-1] 146 | right = thresholds[inds] 147 | avg = (left + right) / 2. 148 | return avg 149 | 150 | #---------------------------- wrappers for planning ----------------------------# 151 | 152 | def value_expectation(self, probs): 153 | ''' 154 | probs : [ B x 2 x ( N + 1 ) ] 155 | extra token comes from termination 156 | ''' 157 | 158 | if torch.is_tensor(probs): 159 | probs = to_np(probs) 160 | return_torch = True 161 | else: 162 | return_torch = False 163 | 164 | probs = probs[:, :, :-1] 165 | assert probs.shape[-1] == self.N 166 | 167 | rewards = self.expectation(probs[:, 0], subslice=-2) 168 | next_values = self.expectation(probs[:, 1], subslice=-1) 169 | 170 | if return_torch: 171 | rewards = to_torch(rewards) 172 | next_values = to_torch(next_values) 173 | 174 | return rewards, next_values 175 | 176 | def value_fn(self, probs, percentile): 177 | if percentile == 'mean': 178 | return self.value_expectation(probs) 179 | else: 180 | ## percentile should be interpretable as float, 181 | ## even if passed in as str because of command-line parser 182 | percentile = float(percentile) 183 | 184 | if torch.is_tensor(probs): 185 | probs = to_np(probs) 186 | return_torch = True 187 | else: 188 | return_torch = False 189 | 190 | probs = probs[:, :, :-1] 191 | assert probs.shape[-1] == self.N 192 | 193 | rewards = self.percentile(probs[:, 0], percentile, subslice=-2) 194 | next_values = self.percentile(probs[:, 1], percentile, subslice=-1) 195 | 196 | if return_torch: 197 | rewards = to_torch(rewards) 198 | next_values = to_torch(next_values) 199 | 200 | return rewards, next_values 201 | 202 | def largest_nonzero_index(x, dim): 203 | N = x.shape[dim] 204 | arange = np.arange(N) + 1 205 | 206 | for i in range(dim): 207 | arange = np.expand_dims(arange, axis=0) 208 | for i in range(dim+1, x.ndim): 209 | arange = np.expand_dims(arange, axis=-1) 210 | 211 | inds = np.argmax(x * arange, axis=0) 212 | ## masks for all `False` or all `True` 213 | lt_mask = (~x).all(axis=0) 214 | gt_mask = (x).all(axis=0) 215 | 216 | inds[lt_mask] = 0 217 | inds[gt_mask] = N 218 | 219 | return inds -------------------------------------------------------------------------------- /diffuser/utils/arrays.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import torch 4 | import pdb 5 | 6 | DTYPE = torch.float 7 | DEVICE = 'cuda' 8 | 9 | #-----------------------------------------------------------------------------# 10 | #------------------------------ numpy <--> torch -----------------------------# 11 | #-----------------------------------------------------------------------------# 12 | def parse_maze(maze_str): 13 | lines = maze_str.strip().split('\\') 14 | width, height = len(lines)-2, len(lines[0])-2 15 | maze_arr = np.zeros((width, height), dtype=np.int32) 16 | for w in range(width): 17 | for h in range(height): 18 | tile = lines[w+1][h+1] 19 | if tile == '#': 20 | maze_arr[w][h] = 1 21 | elif tile == 'G': 22 | maze_arr[w][h] = -1 23 | elif tile == ' ' or tile == 'O' or tile == '0': 24 | maze_arr[w][h] = 0 25 | else: 26 | raise ValueError('Unknown tile 
type: %s' % tile) 27 | return maze_arr.flatten()[:-1] 28 | MAZE_1 = \ 29 | '########\\'+\ 30 | '#OOOOOO#\\'+\ 31 | '#OOOOOO#\\'+\ 32 | '#OOOOOO#\\'+\ 33 | '#OOOOOO#\\'+\ 34 | '#OOOOOO#\\'+\ 35 | '#OOOOOG#\\'+\ 36 | "########" 37 | MAZE_2 = \ 38 | '########\\'+\ 39 | '#OO##OO#\\'+\ 40 | '##OOOOO#\\'+\ 41 | '#OOO####\\'+\ 42 | '#OOO####\\'+\ 43 | '##OOOOO#\\'+\ 44 | '#OO##OG#\\'+\ 45 | "########" 46 | MAZE_3 = \ 47 | '########\\'+\ 48 | '##OOOOO#\\'+\ 49 | '#OOO#OO#\\'+\ 50 | '#OO#OOO#\\'+\ 51 | '#O####O#\\'+\ 52 | '#OOO#OO#\\'+\ 53 | '##OOOOG#\\'+\ 54 | "########" 55 | MAZE_4 = \ 56 | '########\\'+\ 57 | '#OOOOOO#\\'+\ 58 | '##OOOO##\\'+\ 59 | '#OO##OO#\\'+\ 60 | '#OO##OO#\\'+\ 61 | '##OOOO##\\'+\ 62 | '#OOOOOG#\\'+\ 63 | "########" 64 | MAZE_5 = \ 65 | '########\\'+\ 66 | '#OOOOOO#\\'+\ 67 | '#OO#OOO#\\'+\ 68 | '##O##O##\\'+\ 69 | '##O##O##\\'+\ 70 | '#OO#OOO#\\'+\ 71 | '#OOOOOG#\\'+\ 72 | "########" 73 | MAZE_6 = \ 74 | '########\\'+\ 75 | '#OO##OO#\\'+\ 76 | '#OO##OO#\\'+\ 77 | '#OO##OO#\\'+\ 78 | '##OOOO##\\'+\ 79 | '#OO#OOO#\\'+\ 80 | '#OO##OG#\\'+\ 81 | "########" 82 | MAZE_7 = \ 83 | '########\\'+\ 84 | '#OOOOOO#\\'+\ 85 | '#OOOOOO#\\'+\ 86 | '####OOO#\\'+\ 87 | '####OOO#\\'+\ 88 | '#OOOO###\\'+\ 89 | '#OOOOOG#\\'+\ 90 | "########" 91 | MAZE_8 = \ 92 | '########\\'+\ 93 | '#OOOOO##\\'+\ 94 | '#OO#OOO#\\'+\ 95 | '#O####O#\\'+\ 96 | '#OO#OOO#\\'+\ 97 | '#OO#OOO#\\'+\ 98 | '#OOOOOG#\\'+\ 99 | "########" 100 | def to_np(x): 101 | if torch.is_tensor(x): 102 | x = x.detach().cpu().numpy() 103 | return x 104 | 105 | def to_torch(x, dtype=None, device=None): 106 | dtype = dtype or DTYPE 107 | device = device or DEVICE 108 | if type(x) is dict: 109 | return {k: to_torch(v, dtype, device) for k, v in x.items()} 110 | elif torch.is_tensor(x): 111 | return x.to(device).type(dtype) 112 | return torch.tensor(x, dtype=dtype, device=device) 113 | 114 | def to_device(x, device=DEVICE): 115 | if torch.is_tensor(x): 116 | return x.to(device, dtype=torch.float32) 117 | elif type(x) is dict: 118 | return {k: to_device(v, device) for k, v in x.items()} 119 | else: 120 | raise RuntimeError(f'Unrecognized type in `to_device`: {type(x)}') 121 | 122 | def batchify(batch): 123 | ''' 124 | convert a single dataset item to a batch suitable for passing to a model by 125 | 1) converting np arrays to torch tensors and 126 | 2) and ensuring that everything has a batch dimension 127 | ''' 128 | fn = lambda x: to_torch(x[None]) 129 | 130 | batched_vals = [] 131 | for field in batch._fields: 132 | val = getattr(batch, field) 133 | val = apply_dict(fn, val) if type(val) is dict else fn(val) 134 | batched_vals.append(val) 135 | return type(batch)(*batched_vals) 136 | 137 | def apply_dict(fn, d, *args, **kwargs): 138 | return { 139 | k: fn(v, *args, **kwargs) 140 | for k, v in d.items() 141 | } 142 | 143 | def normalize(x): 144 | """ 145 | scales `x` to [0, 1] 146 | """ 147 | x = x - x.min() 148 | x = x / x.max() 149 | return x 150 | 151 | def to_img(x): 152 | normalized = normalize(x) 153 | array = to_np(normalized) 154 | array = np.transpose(array, (1,2,0)) 155 | return (array * 255).astype(np.uint8) 156 | 157 | def set_device(device): 158 | DEVICE = device 159 | if 'cuda' in device: 160 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 161 | 162 | def batch_to_device(batch, device='cuda:0'): 163 | vals = [ 164 | to_device(getattr(batch, field), device) 165 | for field in batch._fields 166 | ] 167 | return type(batch)(*vals) 168 | 169 | def _to_str(num): 170 | if num >= 1e6: 171 | return f'{(num/1e6):.2f} M' 172 | 
else: 173 | return f'{(num/1e3):.2f} k' 174 | 175 | #-----------------------------------------------------------------------------# 176 | #----------------------------- parameter counting ----------------------------# 177 | #-----------------------------------------------------------------------------# 178 | 179 | def param_to_module(param): 180 | module_name = param[::-1].split('.', maxsplit=1)[-1][::-1] 181 | return module_name 182 | 183 | def report_parameters(model, topk=10): 184 | counts = {k: p.numel() for k, p in model.named_parameters()} 185 | n_parameters = sum(counts.values()) 186 | print(f'[ utils/arrays ] Total parameters: {_to_str(n_parameters)}') 187 | 188 | modules = dict(model.named_modules()) 189 | sorted_keys = sorted(counts, key=lambda x: -counts[x]) 190 | max_length = max([len(k) for k in sorted_keys]) 191 | for i in range(topk): 192 | key = sorted_keys[i] 193 | if key == 'position_emb': continue 194 | count = counts[key] 195 | module = param_to_module(key) 196 | print(' '*8, f'{key:10}: {_to_str(count)} | {modules[module]}') 197 | 198 | remaining_parameters = sum([counts[k] for k in sorted_keys[topk:]]) 199 | print(' '*8, f'... and {len(counts)-topk} others accounting for {_to_str(remaining_parameters)} parameters') 200 | return n_parameters 201 | -------------------------------------------------------------------------------- /diffuser/utils/cloud.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | import pdb 4 | 5 | 6 | def sync_logs(logdir, bucket, background=False): 7 | ## remove prefix 'logs' on google cloud 8 | destination = 'logs' + logdir.split('logs')[-1] 9 | upload_blob(logdir, destination, bucket, background) 10 | 11 | def upload_blob(source, destination, bucket, background): 12 | command = f'gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M rsync -r {source} {bucket}/{destination}' 13 | print(f'[ utils/cloud ] Syncing bucket: {command}') 14 | command = shlex.split(command) 15 | 16 | if background: 17 | subprocess.Popen(command) 18 | else: 19 | subprocess.call(command) 20 | -------------------------------------------------------------------------------- /diffuser/utils/colab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import einops 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | 7 | try: 8 | import io 9 | import base64 10 | from IPython.display import HTML 11 | from IPython import display as ipythondisplay 12 | except: 13 | print('[ utils/colab ] Warning: not importing colab dependencies') 14 | 15 | from .serialization import mkdir 16 | from .arrays import to_torch, to_np 17 | from .video import save_video 18 | 19 | 20 | def run_diffusion(model, dataset, obs, n_samples=1, device='cuda:0', **diffusion_kwargs): 21 | ## normalize observation for model 22 | obs = dataset.normalizer.normalize(obs, 'observations') 23 | 24 | ## add a batch dimension and repeat for multiple samples 25 | ## [ observation_dim ] --> [ n_samples x observation_dim ] 26 | obs = obs[None].repeat(n_samples, axis=0) 27 | 28 | ## format `conditions` input for model 29 | conditions = { 30 | 0: to_torch(obs, device=device) 31 | } 32 | 33 | samples, diffusion = model.conditional_sample(conditions, 34 | return_diffusion=True, verbose=False, **diffusion_kwargs) 35 | 36 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x (action_dim + observation_dim)] 37 | diffusion = to_np(diffusion) 38 | 39 | ## extract 
observations 40 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x observation_dim ] 41 | normed_observations = diffusion[:, :, :, dataset.action_dim:] 42 | 43 | ## unnormalize observation samples from model 44 | observations = dataset.normalizer.unnormalize(normed_observations, 'observations') 45 | 46 | ## [ (n_diffusion_steps + 1) x n_samples x horizon x observation_dim ] 47 | observations = einops.rearrange(observations, 48 | 'batch steps horizon dim -> steps batch horizon dim') 49 | 50 | return observations 51 | 52 | 53 | def show_diffusion(renderer, observations, n_repeat=100, substep=1, filename='diffusion.mp4', savebase='/content/videos'): 54 | ''' 55 | observations : [ n_diffusion_steps x batch_size x horizon x observation_dim ] 56 | ''' 57 | mkdir(savebase) 58 | savepath = os.path.join(savebase, filename) 59 | 60 | subsampled = observations[::substep] 61 | 62 | images = [] 63 | for t in tqdm(range(len(subsampled))): 64 | observation = subsampled[t] 65 | 66 | img = renderer.composite(None, observation) 67 | images.append(img) 68 | images = np.stack(images, axis=0) 69 | 70 | ## pause at the end of video 71 | images = np.concatenate([ 72 | images, 73 | images[-1:].repeat(n_repeat, axis=0) 74 | ], axis=0) 75 | 76 | save_video(savepath, images) 77 | show_video(savepath) 78 | 79 | 80 | def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'): 81 | ''' 82 | observations : [ batch_size x horizon x observation_dim ] 83 | ''' 84 | 85 | mkdir(savebase) 86 | savepath = os.path.join(savebase, filename) 87 | 88 | images = [] 89 | for rollout in observations: 90 | ## [ horizon x height x width x channels ] 91 | img = renderer._renders(rollout, partial=True) 92 | images.append(img) 93 | 94 | ## [ horizon x height x (batch_size * width) x channels ] 95 | images = np.concatenate(images, axis=2) 96 | 97 | save_video(savepath, images) 98 | show_video(savepath, height=200) 99 | 100 | 101 | def show_samples(renderer, observations_l, figsize=12): 102 | ''' 103 | observations_l : [ [ n_diffusion_steps x batch_size x horizon x observation_dim ], ... 
] 104 | ''' 105 | 106 | images = [] 107 | for observations in observations_l: 108 | path = observations[-1] 109 | img = renderer.composite(None, path) 110 | images.append(img) 111 | images = np.concatenate(images, axis=0) 112 | 113 | plt.imshow(images) 114 | plt.axis('off') 115 | plt.gcf().set_size_inches(figsize, figsize) 116 | 117 | 118 | def show_video(path, height=400): 119 | video = io.open(path, 'r+b').read() 120 | encoded = base64.b64encode(video) 121 | ipythondisplay.display(HTML(data=''''''.format(height, encoded.decode('ascii')))) 125 | -------------------------------------------------------------------------------- /diffuser/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import importlib 4 | import pickle 5 | 6 | def import_class(_class): 7 | if type(_class) is not str: return _class 8 | ## 'diffusion' on standard installs 9 | repo_name = __name__.split('.')[0] 10 | ## eg, 'utils' 11 | module_name = '.'.join(_class.split('.')[:-1]) 12 | ## eg, 'Renderer' 13 | class_name = _class.split('.')[-1] 14 | ## eg, 'diffusion.utils' 15 | module = importlib.import_module(f'{repo_name}.{module_name}') 16 | ## eg, diffusion.utils.Renderer 17 | _class = getattr(module, class_name) 18 | print(f'[ utils/config ] Imported {repo_name}.{module_name}:{class_name}') 19 | return _class 20 | 21 | class Config(collections.abc.Mapping): 22 | 23 | def __init__(self, _class, verbose=True, savepath=None, device=None, **kwargs): 24 | #print(">>>>>>") 25 | self._class = import_class(_class) 26 | self._device = device 27 | self._dict = {} 28 | 29 | for key, val in kwargs.items(): 30 | #print("config---key:",key," value: ",val) 31 | self._dict[key] = val 32 | 33 | if verbose: 34 | print(self) 35 | 36 | if savepath is not None: 37 | savepath = os.path.join(*savepath) if type(savepath) is tuple else savepath 38 | pickle.dump(self, open(savepath, 'wb')) 39 | print(f'[ utils/config ] Saved config to: {savepath}\n') 40 | def __repr__(self): 41 | string = f'\n[utils/config ] Config: {self._class}\n' 42 | for key in sorted(self._dict.keys()): 43 | val = self._dict[key] 44 | string += f' {key}: {val}\n' 45 | return string 46 | 47 | def __iter__(self): 48 | return iter(self._dict) 49 | 50 | def __getitem__(self, item): 51 | return self._dict[item] 52 | 53 | def __len__(self): 54 | return len(self._dict) 55 | 56 | def __getattr__(self, attr): 57 | if attr == '_dict' and '_dict' not in vars(self): 58 | self._dict = {} 59 | return self._dict 60 | try: 61 | return self._dict[attr] 62 | except KeyError: 63 | raise AttributeError(attr) 64 | 65 | def __call__(self, *args, **kwargs): 66 | instance = self._class(*args, **kwargs, **self._dict) 67 | #self._device = 'cuda:1' 68 | if self._device: 69 | #self._device = 'cuda:1' 70 | instance = instance.to(self._device) 71 | return instance 72 | -------------------------------------------------------------------------------- /diffuser/utils/distributed/__pycache__/distributed.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/distributed/__pycache__/distributed.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/distributed/__pycache__/launch.cpython-38.pyc: -------------------------------------------------------------------------------- 
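A minimal usage sketch for the `Config` helper above (illustrative only: the dotted class path and keyword arguments are hypothetical, and the repository root is assumed to be on `PYTHONPATH`). `Config` stores constructor kwargs, optionally pickles itself to `savepath`, and calling the instance builds the wrapped class and moves it to `device` when one was given:

from diffuser.utils.config import Config

# 'models.temporal.SomeModel' is a made-up dotted path, used only to show how
# import_class resolves it relative to the top-level package (diffuser.models.temporal).
model_config = Config('models.temporal.SomeModel', verbose=False, device='cuda:0', dim=64)
model = model_config()   # roughly SomeModel(dim=64).to('cuda:0')
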
https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/diffuser/utils/distributed/__pycache__/launch.cpython-38.pyc -------------------------------------------------------------------------------- /diffuser/utils/distributed/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils import data 7 | 8 | 9 | LOCAL_PROCESS_GROUP = None 10 | 11 | 12 | def is_primary(): 13 | return get_rank() == 0 14 | 15 | 16 | def get_rank(): 17 | if not dist.is_available(): 18 | return 0 19 | 20 | if not dist.is_initialized(): 21 | return 0 22 | 23 | return dist.get_rank() 24 | 25 | 26 | def get_local_rank(): 27 | if not dist.is_available(): 28 | return 0 29 | 30 | if not dist.is_initialized(): 31 | return 0 32 | 33 | if LOCAL_PROCESS_GROUP is None: 34 | raise ValueError("tensorfn.distributed.LOCAL_PROCESS_GROUP is None") 35 | 36 | return dist.get_rank(group=LOCAL_PROCESS_GROUP) 37 | 38 | 39 | def synchronize(): 40 | if not dist.is_available(): 41 | return 42 | 43 | if not dist.is_initialized(): 44 | return 45 | 46 | world_size = dist.get_world_size() 47 | 48 | if world_size == 1: 49 | return 50 | 51 | dist.barrier() 52 | 53 | 54 | def get_world_size(): 55 | if not dist.is_available(): 56 | return 1 57 | 58 | if not dist.is_initialized(): 59 | return 1 60 | 61 | return dist.get_world_size() 62 | 63 | 64 | def is_distributed(): 65 | raise RuntimeError('Please debug this function!') 66 | return get_world_size() > 1 67 | 68 | 69 | def all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=False): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return tensor 74 | dist.all_reduce(tensor, op=op, async_op=async_op) 75 | 76 | return tensor 77 | 78 | 79 | def all_gather(data): 80 | world_size = get_world_size() 81 | 82 | if world_size == 1: 83 | return [data] 84 | 85 | buffer = pickle.dumps(data) 86 | storage = torch.ByteStorage.from_buffer(buffer) 87 | tensor = torch.ByteTensor(storage).to("cuda") 88 | 89 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 90 | size_list = [torch.IntTensor([1]).to("cuda") for _ in range(world_size)] 91 | dist.all_gather(size_list, local_size) 92 | size_list = [int(size.item()) for size in size_list] 93 | max_size = max(size_list) 94 | 95 | tensor_list = [] 96 | for _ in size_list: 97 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 98 | 99 | if local_size != max_size: 100 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 101 | tensor = torch.cat((tensor, padding), 0) 102 | 103 | dist.all_gather(tensor_list, tensor) 104 | 105 | data_list = [] 106 | 107 | for size, tensor in zip(size_list, tensor_list): 108 | buffer = tensor.cpu().numpy().tobytes()[:size] 109 | data_list.append(pickle.loads(buffer)) 110 | 111 | return data_list 112 | 113 | 114 | def reduce_dict(input_dict, average=True): 115 | world_size = get_world_size() 116 | 117 | if world_size < 2: 118 | return input_dict 119 | 120 | with torch.no_grad(): 121 | keys = [] 122 | values = [] 123 | 124 | for k in sorted(input_dict.keys()): 125 | keys.append(k) 126 | values.append(input_dict[k]) 127 | 128 | values = torch.stack(values, 0) 129 | dist.reduce(values, dst=0) 130 | 131 | if dist.get_rank() == 0 and average: 132 | values /= world_size 133 | 134 | reduced_dict = {k: v for k, v in zip(keys, values)} 135 | 136 | return reduced_dict 137 | 138 | 139 | def data_sampler(dataset, 
shuffle, distributed): 140 | if distributed: 141 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 142 | 143 | if shuffle: 144 | return data.RandomSampler(dataset) 145 | 146 | else: 147 | return data.SequentialSampler(dataset) 148 | -------------------------------------------------------------------------------- /diffuser/utils/distributed/launch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import distributed as dist 5 | from torch import multiprocessing as mp 6 | 7 | # import distributed as dist_fn 8 | import diffuser.utils.distributed.distributed as dist_fn 9 | 10 | def find_free_port(): 11 | import socket 12 | 13 | sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 14 | 15 | sock.bind(("", 0)) 16 | port = sock.getsockname()[1] 17 | sock.close() 18 | 19 | return port 20 | 21 | 22 | def launch(fn, n_gpu_per_machine, n_machine=1, machine_rank=0, dist_url=None, args=()): 23 | world_size = n_machine * n_gpu_per_machine 24 | 25 | if world_size > 1: 26 | # if "OMP_NUM_THREADS" not in os.environ: 27 | # os.environ["OMP_NUM_THREADS"] = "1" 28 | 29 | if dist_url == "auto": 30 | if n_machine != 1: 31 | raise ValueError('dist_url="auto" not supported in multi-machine jobs') 32 | 33 | port = find_free_port() 34 | dist_url = f"tcp://127.0.0.1:{port}" 35 | 36 | if n_machine > 1 and dist_url.startswith("file://"): 37 | raise ValueError( 38 | "file:// is not a reliable init method in multi-machine jobs. Prefer tcp://" 39 | ) 40 | 41 | mp.spawn( 42 | distributed_worker, 43 | nprocs=n_gpu_per_machine, 44 | args=(fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args), 45 | daemon=False, 46 | ) 47 | 48 | else: 49 | local_rank = 0 50 | fn(local_rank, *args) 51 | 52 | 53 | def distributed_worker( 54 | local_rank, fn, world_size, n_gpu_per_machine, machine_rank, dist_url, args 55 | ): 56 | if not torch.cuda.is_available(): 57 | raise OSError("CUDA is not available. 
Please check your environments") 58 | 59 | global_rank = machine_rank * n_gpu_per_machine + local_rank 60 | 61 | try: 62 | dist.init_process_group( 63 | backend="NCCL", 64 | init_method=dist_url, 65 | world_size=world_size, 66 | rank=global_rank, 67 | ) 68 | 69 | except Exception: 70 | raise OSError("failed to initialize NCCL groups") 71 | 72 | dist_fn.synchronize() 73 | 74 | if n_gpu_per_machine > torch.cuda.device_count(): 75 | raise ValueError( 76 | f"specified n_gpu_per_machine larger than available device ({torch.cuda.device_count()})" 77 | ) 78 | 79 | torch.cuda.set_device(local_rank) 80 | 81 | if dist_fn.LOCAL_PROCESS_GROUP is not None: 82 | raise ValueError("torch.distributed.LOCAL_PROCESS_GROUP is not None") 83 | 84 | n_machine = world_size // n_gpu_per_machine 85 | 86 | for i in range(n_machine): 87 | ranks_on_i = list(range(i * n_gpu_per_machine, (i + 1) * n_gpu_per_machine)) 88 | pg = dist.new_group(ranks_on_i) 89 | 90 | if i == machine_rank: 91 | dist_fn.LOCAL_PROCESS_GROUP = pg 92 | 93 | fn(local_rank, *args) 94 | -------------------------------------------------------------------------------- /diffuser/utils/git_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import git 3 | import pdb 4 | 5 | PROJECT_PATH = os.path.dirname( 6 | os.path.realpath(os.path.join(__file__, '..', '..'))) 7 | 8 | def get_repo(path=PROJECT_PATH, search_parent_directories=True): 9 | repo = git.Repo( 10 | path, search_parent_directories=search_parent_directories) 11 | return repo 12 | 13 | def get_git_rev(*args, **kwargs): 14 | try: 15 | repo = get_repo(*args, **kwargs) 16 | if repo.head.is_detached: 17 | git_rev = repo.head.object.name_rev 18 | else: 19 | git_rev = repo.active_branch.commit.name_rev 20 | except: 21 | git_rev = None 22 | 23 | return git_rev 24 | 25 | def git_diff(*args, **kwargs): 26 | repo = get_repo(*args, **kwargs) 27 | diff = repo.git.diff() 28 | return diff 29 | 30 | def save_git_diff(savepath, *args, **kwargs): 31 | diff = git_diff(*args, **kwargs) 32 | with open(savepath, 'w') as f: 33 | f.write(diff) 34 | 35 | if __name__ == '__main__': 36 | 37 | git_rev = get_git_rev() 38 | print(git_rev) 39 | 40 | save_git_diff('diff_test.txt') -------------------------------------------------------------------------------- /diffuser/utils/iql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import jax 4 | import jax.numpy as jnp 5 | import functools 6 | import pdb 7 | 8 | from diffuser.iql.common import Model 9 | from diffuser.iql.value_net import DoubleCritic 10 | 11 | def load_q(env, loadpath, hidden_dims=(256, 256), seed=42): 12 | print(f'[ utils/iql ] Loading Q: {loadpath}') 13 | observations = env.observation_space.sample()[np.newaxis] 14 | actions = env.action_space.sample()[np.newaxis] 15 | 16 | rng = jax.random.PRNGKey(seed) 17 | rng, key = jax.random.split(rng) 18 | 19 | critic_def = DoubleCritic(hidden_dims) 20 | critic = Model.create(critic_def, 21 | inputs=[key, observations, actions]) 22 | 23 | ## allows for relative paths 24 | loadpath = os.path.expanduser(loadpath) 25 | critic = critic.load(loadpath) 26 | return critic 27 | 28 | class JaxWrapper: 29 | 30 | def __init__(self, env, loadpath, *args, **kwargs): 31 | self.model = load_q(env, loadpath) 32 | 33 | @functools.partial(jax.jit, static_argnames=('self'), device=jax.devices('cpu')[0]) 34 | def forward(self, xs): 35 | Qs = self.model(*xs) 36 | Q = jnp.minimum(*Qs) 37 | return Q 38 | 39 | 
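    # Note on the jit decoration of `forward` above: `static_argnames=('self')`
    # treats the bound instance as a compile-time constant, so the critic parameters
    # held on `self.model` are captured as constants in the traced computation, while
    # `device=jax.devices('cpu')[0]` pins the Q-value evaluation to the CPU,
    # presumably to avoid contending with the torch diffusion model for GPU memory.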
def __call__(self, *xs): 40 | Q = self.forward(xs) 41 | return np.array(Q) 42 | -------------------------------------------------------------------------------- /diffuser/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | 5 | class Logger: 6 | 7 | def __init__(self, renderer, logpath, vis_freq=10, max_render=8): 8 | self.renderer = renderer 9 | self.savepath = logpath 10 | self.vis_freq = vis_freq 11 | self.max_render = max_render 12 | 13 | def log(self, t, samples, state, rollout=None): 14 | if t % self.vis_freq != 0: 15 | return 16 | 17 | ## render image of plans 18 | self.renderer.composite( 19 | os.path.join(self.savepath, f'{t}.png'), 20 | samples.observations, 21 | ) 22 | 23 | ## render video of plans 24 | self.renderer.render_plan( 25 | os.path.join(self.savepath, f'{t}_plan.mp4'), 26 | samples.actions[:self.max_render], 27 | samples.observations[:self.max_render], 28 | state, 29 | ) 30 | 31 | if rollout is not None: 32 | ## render video of rollout thus far 33 | self.renderer.render_rollout( 34 | os.path.join(self.savepath, f'rollout.mp4'), 35 | rollout, 36 | fps=80, 37 | ) 38 | 39 | def finish(self, t, score, total_reward, terminal, diffusion_experiment, value_experiment): 40 | json_path = os.path.join(self.savepath, 'rollout.json') 41 | json_data = {'score': score, 'step': t, 'return': total_reward, 'term': terminal, 42 | 'epoch_diffusion': diffusion_experiment.epoch, 'epoch_value': value_experiment.epoch} 43 | json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True) 44 | print(f'[ utils/logger ] Saved log to {json_path}') 45 | -------------------------------------------------------------------------------- /diffuser/utils/progress.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import pdb 4 | 5 | class Progress: 6 | 7 | def __init__(self, total, name = 'Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 8 | self.total = total 9 | self.name = name 10 | self.ncol = ncol 11 | self.max_length = max_length 12 | self.indent = indent 13 | self.line_width = line_width 14 | self._speed_update_freq = speed_update_freq 15 | 16 | self._step = 0 17 | self._prev_line = '\033[F' 18 | self._clear_line = ' ' * self.line_width 19 | 20 | self._pbar_size = self.ncol * self.max_length 21 | self._complete_pbar = '#' * self._pbar_size 22 | self._incomplete_pbar = ' ' * self._pbar_size 23 | 24 | self.lines = [''] 25 | self.fraction = '{} / {}'.format(0, self.total) 26 | 27 | self.resume() 28 | 29 | 30 | def update(self, description, n=1): 31 | self._step += n 32 | if self._step % self._speed_update_freq == 0: 33 | self._time0 = time.time() 34 | self._step0 = self._step 35 | self.set_description(description) 36 | 37 | def resume(self): 38 | self._skip_lines = 1 39 | print('\n', end='') 40 | self._time0 = time.time() 41 | self._step0 = self._step 42 | 43 | def pause(self): 44 | self._clear() 45 | self._skip_lines = 1 46 | 47 | def set_description(self, params=[]): 48 | 49 | if type(params) == dict: 50 | params = sorted([ 51 | (key, val) 52 | for key, val in params.items() 53 | ]) 54 | 55 | ############ 56 | # Position # 57 | ############ 58 | self._clear() 59 | 60 | ########### 61 | # Percent # 62 | ########### 63 | percent, fraction = self._format_percent(self._step, self.total) 64 | self.fraction = fraction 65 | 66 | ######### 67 | # Speed # 68 | ######### 69 | speed = self._format_speed(self._step) 70 | 
71 | ########## 72 | # Params # 73 | ########## 74 | num_params = len(params) 75 | nrow = math.ceil(num_params / self.ncol) 76 | params_split = self._chunk(params, self.ncol) 77 | params_string, lines = self._format(params_split) 78 | self.lines = lines 79 | 80 | 81 | description = '{} | {}{}'.format(percent, speed, params_string) 82 | print(description) 83 | self._skip_lines = nrow + 1 84 | 85 | def append_description(self, descr): 86 | self.lines.append(descr) 87 | 88 | def _clear(self): 89 | position = self._prev_line * self._skip_lines 90 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 91 | print(position, end='') 92 | print(empty) 93 | print(position, end='') 94 | 95 | def _format_percent(self, n, total): 96 | if total: 97 | percent = n / float(total) 98 | 99 | complete_entries = int(percent * self._pbar_size) 100 | incomplete_entries = self._pbar_size - complete_entries 101 | 102 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 103 | fraction = '{} / {}'.format(n, total) 104 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 105 | else: 106 | fraction = '{}'.format(n) 107 | string = '{} iterations'.format(n) 108 | return string, fraction 109 | 110 | def _format_speed(self, n): 111 | num_steps = n - self._step0 112 | t = time.time() - self._time0 113 | speed = num_steps / t 114 | string = '{:.1f} Hz'.format(speed) 115 | if num_steps > 0: 116 | self._speed = string 117 | return string 118 | 119 | def _chunk(self, l, n): 120 | return [l[i:i+n] for i in range(0, len(l), n)] 121 | 122 | def _format(self, chunks): 123 | lines = [self._format_chunk(chunk) for chunk in chunks] 124 | lines.insert(0,'') 125 | padding = '\n' + ' '*self.indent 126 | string = padding.join(lines) 127 | return string, lines 128 | 129 | def _format_chunk(self, chunk): 130 | line = ' | '.join([self._format_param(param) for param in chunk]) 131 | return line 132 | 133 | def _format_param(self, param): 134 | k, v = param 135 | return '{} : {}'.format(k, v)[:self.max_length] 136 | 137 | def stamp(self): 138 | if self.lines != ['']: 139 | params = ' | '.join(self.lines) 140 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 141 | self._clear() 142 | print(string, end='\n') 143 | self._skip_lines = 1 144 | else: 145 | self._clear() 146 | self._skip_lines = 0 147 | 148 | def close(self): 149 | self.pause() 150 | 151 | class Silent: 152 | 153 | def __init__(self, *args, **kwargs): 154 | pass 155 | 156 | def __getattr__(self, attr): 157 | return lambda *args: None 158 | 159 | 160 | if __name__ == '__main__': 161 | silent = Silent() 162 | silent.update() 163 | silent.stamp() 164 | 165 | num_steps = 1000 166 | progress = Progress(num_steps) 167 | for i in range(num_steps): 168 | progress.update() 169 | params = [ 170 | ['A', '{:06d}'.format(i)], 171 | ['B', '{:06d}'.format(i)], 172 | ['C', '{:06d}'.format(i)], 173 | ['D', '{:06d}'.format(i)], 174 | ['E', '{:06d}'.format(i)], 175 | ['F', '{:06d}'.format(i)], 176 | ['G', '{:06d}'.format(i)], 177 | ['H', '{:06d}'.format(i)], 178 | ] 179 | progress.set_description(params) 180 | time.sleep(0.01) 181 | progress.close() 182 | -------------------------------------------------------------------------------- /diffuser/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import glob 4 | import torch 5 | import pdb 6 | 7 | from collections import namedtuple 8 | 9 | DiffusionExperiment 
= namedtuple('Diffusion', 'dataset renderer model diffusion ema trainer epoch') 10 | mtdtExperiment = namedtuple('mtdtExperiment', 'dataset model ema trainer epoch') 11 | def mkdir(savepath): 12 | """ 13 | returns `True` iff `savepath` is created 14 | """ 15 | if not os.path.exists(savepath): 16 | os.makedirs(savepath) 17 | return True 18 | else: 19 | return False 20 | 21 | def get_latest_epoch(loadpath): 22 | states = glob.glob1(os.path.join(*loadpath), 'state_*') 23 | latest_epoch = -1 24 | for state in states: 25 | epoch = int(state.replace('state_', '').replace('.pt', '')) 26 | latest_epoch = max(epoch, latest_epoch) 27 | return latest_epoch 28 | 29 | def load_config(*loadpath): 30 | loadpath = os.path.join(*loadpath) 31 | config = pickle.load(open(loadpath, 'rb')) 32 | print(f'[ utils/serialization ] Loaded config from {loadpath}') 33 | #print(config) 34 | return config 35 | 36 | def load_diffusion(*loadpath, epoch='latest', device='cuda:0', seed=None): 37 | dataset_config = load_config(*loadpath, 'dataset_config.pkl') 38 | #render_config = load_config(*loadpath, 'render_config.pkl') 39 | model_config = load_config(*loadpath, 'model_config.pkl') 40 | #model_config._device = device 41 | diffusion_config = load_config(*loadpath, 'diffusion_config.pkl') 42 | trainer_config = load_config(*loadpath, 'trainer_config.pkl') 43 | 44 | ## remove absolute path for results loaded from azure 45 | ## @TODO : remove results folder from within trainer class 46 | trainer_config._dict['results_folder'] = os.path.join(*loadpath) 47 | 48 | dataset = dataset_config() 49 | #dataset = dataset() 50 | #renderer = render_config() 51 | renderer = None 52 | model = model_config() 53 | diffusion = diffusion_config(model) 54 | #data = { 55 | # 'model': diffusion.inv_model.state_dict(), 56 | #} 57 | #savepath = os.path.join(f'quadruped_inv_model_v4.pt') 58 | #torch.save(data, savepath) 59 | trainer = trainer_config(diffusion, dataset, renderer) 60 | #trainer = None 61 | if epoch == 'latest': 62 | #epoch = '0_4231.278535482819_0.14285714285714285' 63 | epoch = get_latest_epoch(loadpath) 64 | 65 | print(f'\n[ utils/serialization ] Loading model epoch: {epoch}\n') 66 | 67 | trainer.load(epoch) 68 | return DiffusionExperiment(dataset, trainer.dataloader, trainer.model, diffusion, trainer.ema_model, trainer, epoch) 69 | 70 | def load_mtdt(*loadpath, epoch='latest', device='cuda:0', seed=None): 71 | dataset_config = load_config(*loadpath, 'dataset_config.pkl') 72 | #dataset_config._dict['task_list'] = ['basketball-v2', 'bin-picking-v2'] 73 | #render_config = load_config(*loadpath, 'render_config.pkl') 74 | model_config = load_config(*loadpath, 'model_config.pkl') 75 | trainer_config = load_config(*loadpath, 'trainer_config.pkl') 76 | 77 | ## remove absolute path for results loaded from azure 78 | ## @TODO : remove results folder from within trainer class 79 | trainer_config._dict['results_folder'] = os.path.join(*loadpath) 80 | 81 | dataset = dataset_config(seed=seed) 82 | #renderer = render_config() 83 | model = model_config() 84 | #data = { 85 | # 'model': diffusion.inv_model.state_dict(), 86 | #} 87 | #savepath = os.path.join(f'quadruped_inv_model_v4.pt') 88 | #torch.save(data, savepath) 89 | trainer = trainer_config(model, dataset) 90 | if epoch == 'latest': 91 | #epoch = '0_4231.278535482819_0.14285714285714285' 92 | epoch = get_latest_epoch(loadpath) 93 | 94 | print(f'\n[ utils/serialization ] Loading model epoch: {epoch}\n') 95 | 96 | trainer.load(epoch) 97 | return mtdtExperiment(dataset, model, trainer.model, 
trainer, epoch) 98 | def load_model(*loadpath, dataset=None, epoch='latest', device='cuda:0', seed=None): 99 | model_config = load_config(*loadpath, 'model_config.pkl') 100 | trainer_config = load_config(*loadpath, 'trainer_config.pkl') 101 | ## @TODO : remove results folder from within trainer class 102 | trainer_config._dict['results_folder'] = os.path.join(*loadpath) 103 | 104 | model = model_config() 105 | trainer = trainer_config(model, dataset) 106 | if epoch == 'latest': 107 | epoch = get_latest_epoch(loadpath) 108 | 109 | print(f'\n[ utils/serialization ] Loading model epoch: {epoch}\n') 110 | 111 | trainer.load(epoch) 112 | return trainer.model 113 | def check_compatibility(experiment_1, experiment_2): 114 | ''' 115 | returns True if `experiment_1 and `experiment_2` have 116 | the same normalizers and number of diffusion steps 117 | ''' 118 | normalizers_1 = experiment_1.dataset.normalizer.get_field_normalizers() 119 | normalizers_2 = experiment_2.dataset.normalizer.get_field_normalizers() 120 | for key in normalizers_1: 121 | norm_1 = type(normalizers_1[key]) 122 | norm_2 = type(normalizers_2[key]) 123 | assert norm_1 == norm_2, \ 124 | f'Normalizers should be identical, found {norm_1} and {norm_2} for field {key}' 125 | 126 | n_steps_1 = experiment_1.diffusion.n_timesteps 127 | n_steps_2 = experiment_2.diffusion.n_timesteps 128 | assert n_steps_1 == n_steps_2, \ 129 | ('Number of timesteps should match between diffusion experiments, ' 130 | f'found {n_steps_1} and {n_steps_2}') 131 | -------------------------------------------------------------------------------- /diffuser/utils/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import random 4 | import numpy as np 5 | import torch 6 | from tap import Tap 7 | import pdb 8 | from datetime import datetime 9 | from .serialization import mkdir 10 | from .git_utils import ( 11 | get_git_rev, 12 | save_git_diff, 13 | ) 14 | 15 | def set_seed(seed): 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | def watch(args_to_watch): 22 | def _fn(args): 23 | exp_name = [] 24 | for key, label in args_to_watch: 25 | if not hasattr(args, key): 26 | continue 27 | val = getattr(args, key) 28 | if type(val) == dict: 29 | val = '_'.join(f'{k}-{v}' for k, v in val.items()) 30 | exp_name.append(f'{label}{val}') 31 | exp_name.append('-'+datetime.now().strftime('%b%d_%H-%M-%S')) 32 | exp_name = '_'.join(exp_name) 33 | exp_name = exp_name.replace('/_', '/') 34 | exp_name = exp_name.replace('(', '').replace(')', '') 35 | exp_name = exp_name.replace(', ', '-') 36 | return exp_name 37 | return _fn 38 | 39 | def lazy_fstring(template, args): 40 | ## https://stackoverflow.com/a/53671539 41 | return eval(f"f'{template}'") 42 | 43 | class Parser(Tap): 44 | 45 | def save(self): 46 | fullpath = os.path.join(self.savepath, 'args.json') 47 | print(f'[ utils/setup ] Saved args to {fullpath}') 48 | super().save(fullpath, skip_unpicklable=True) 49 | 50 | def parse_args(self, experiment=None): 51 | args = super().parse_args(known_only=True) 52 | ## if not loading from a config script, skip the result of the setup 53 | if not hasattr(args, 'config'): return args 54 | args = self.read_config(args, experiment) 55 | self.add_extras(args) 56 | self.eval_fstrings(args) 57 | self.set_seed(args) 58 | self.get_commit(args) 59 | self.set_loadbase(args) 60 | self.generate_exp_name(args) 61 | self.mkdir(args) 62 | 
self.save_diff(args) 63 | return args 64 | 65 | def read_config(self, args, experiment): 66 | ''' 67 | Load parameters from config file 68 | ''' 69 | dataset = args.dataset.replace('-', '_') 70 | print(f'[ utils/setup ] Reading config: {args.config}:{dataset}') 71 | module = importlib.import_module(args.config) 72 | params = getattr(module, 'base')[experiment] 73 | 74 | if hasattr(module, dataset) and experiment in getattr(module, dataset): 75 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}') 76 | overrides = getattr(module, dataset)[experiment] 77 | params.update(overrides) 78 | else: 79 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}') 80 | 81 | self._dict = {} 82 | for key, val in params.items(): 83 | setattr(args, key, val) 84 | self._dict[key] = val 85 | 86 | return args 87 | 88 | def add_extras(self, args): 89 | ''' 90 | Override config parameters with command-line arguments 91 | ''' 92 | extras = args.extra_args 93 | if not len(extras): 94 | return 95 | 96 | print(f'[ utils/setup ] Found extras: {extras}') 97 | #assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}' 98 | for i in range(0, len(extras), 2): 99 | key = extras[i].replace('--', '') 100 | val = extras[i+1] 101 | assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}' 102 | old_val = getattr(args, key) 103 | old_type = type(old_val) 104 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}') 105 | if val == 'None': 106 | val = None 107 | elif val == 'latest': 108 | val = 'latest' 109 | elif old_type in [bool, type(None)]: 110 | try: 111 | val = eval(val) 112 | except: 113 | print(f'[ utils/setup ] Warning: could not parse {val} (old: {old_val}, {old_type}), using str') 114 | else: 115 | val = old_type(val) 116 | setattr(args, key, val) 117 | self._dict[key] = val 118 | 119 | def eval_fstrings(self, args): 120 | for key, old in self._dict.items(): 121 | if type(old) is str and old[:2] == 'f:': 122 | val = old.replace('{', '{args.').replace('f:', '') 123 | new = lazy_fstring(val, args) 124 | print(f'[ utils/setup ] Lazy fstring | {key} : {old} --> {new}') 125 | setattr(self, key, new) 126 | self._dict[key] = new 127 | 128 | def set_seed(self, args): 129 | if not hasattr(args, 'seed') or args.seed is None: 130 | return 131 | print(f'[ utils/setup ] Setting seed: {args.seed}') 132 | set_seed(args.seed) 133 | 134 | def set_loadbase(self, args): 135 | if hasattr(args, 'loadbase') and args.loadbase is None: 136 | print(f'[ utils/setup ] Setting loadbase: {args.logbase}') 137 | args.loadbase = args.logbase 138 | 139 | def generate_exp_name(self, args): 140 | if not 'exp_name' in dir(args): 141 | return 142 | exp_name = getattr(args, 'exp_name') 143 | if callable(exp_name): 144 | exp_name_string = exp_name(args) 145 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}') 146 | setattr(args, 'exp_name', exp_name_string) 147 | self._dict['exp_name'] = exp_name_string 148 | 149 | def mkdir(self, args): 150 | if 'logbase' in dir(args) and 'dataset' in dir(args) and 'exp_name' in dir(args): 151 | args.savepath = os.path.join(args.logbase, args.dataset, args.exp_name) 152 | self._dict['savepath'] = args.savepath 153 | if 'suffix' in dir(args): 154 | args.savepath = os.path.join(args.savepath, args.suffix) 155 | if mkdir(args.savepath): 156 | print(f'[ utils/setup ] Made savepath: {args.savepath}') 157 | self.save() 158 | 159 | def get_commit(self, args): 160 | 
args.commit = get_git_rev() 161 | 162 | def save_diff(self, args): 163 | try: 164 | save_git_diff(os.path.join(args.savepath, 'diff.txt')) 165 | except: 166 | print('[ utils/setup ] WARNING: did not save git diff') 167 | -------------------------------------------------------------------------------- /diffuser/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | 5 | def __init__(self): 6 | self._start = time.time() 7 | 8 | def __call__(self, reset=True): 9 | now = time.time() 10 | diff = now - self._start 11 | if reset: 12 | self._start = now 13 | return diff -------------------------------------------------------------------------------- /diffuser/utils/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import skvideo.io 4 | 5 | def _make_dir(filename): 6 | folder = os.path.dirname(filename) 7 | if not os.path.exists(folder): 8 | os.makedirs(folder) 9 | 10 | def save_video(filename, video_frames, fps=60, video_format='mp4'): 11 | assert fps == int(fps), fps 12 | _make_dir(filename) 13 | 14 | skvideo.io.vwrite( 15 | filename, 16 | video_frames, 17 | inputdict={ 18 | '-r': str(int(fps)), 19 | }, 20 | outputdict={ 21 | '-f': video_format, 22 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74 23 | } 24 | ) 25 | 26 | def save_videos(filename, *video_frames, axis=1, **kwargs): 27 | ## video_frame : [ N x H x W x C ] 28 | video_frames = np.concatenate(video_frames, axis=axis) 29 | save_video(filename, video_frames, **kwargs) 30 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: VPDD 2 | channels: 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.8 7 | - pip 8 | - patchelf 9 | - pip: 10 | - -f https://download.pytorch.org/whl/torch_stable.html 11 | - numpy 12 | - gym 13 | - mujoco-py==2.0.2.13 14 | - matplotlib==3.3.4 15 | - torch==1.9.1+cu111 16 | - typed-argument-parser 17 | - git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl 18 | - git+https://github.com/Farama-Foundation/Metaworld.git@master#egg=metaworld 19 | - scikit-image==0.17.2 20 | - scikit-video==1.1.11 21 | - gitpython 22 | - einops 23 | - pillow 24 | - tqdm 25 | - pandas 26 | - wandb 27 | - flax >= 0.3.5 28 | - ray==1.9.1 29 | - ml_logger==0.8.69 30 | - jaynes==0.8.11 31 | - params_proto==2.9.6 32 | - transformers 33 | - pytorch_lightning==1.4.2 34 | - diffusers 35 | - imageio==2.9.0 36 | - imageio-ffmpeg==0.4.4 -------------------------------------------------------------------------------- /helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__init__.py -------------------------------------------------------------------------------- /helpers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/custom_rlbench_env.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/custom_rlbench_env.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/demo_loading_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/demo_loading_utils.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/network_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/network_utils.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/preprocess_agent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/preprocess_agent.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/clip/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/__init__.py -------------------------------------------------------------------------------- /helpers/clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/clip/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__init__.py -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/clip.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/clip.cpython-39.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/clip/core/__pycache__/simple_tokenizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/__pycache__/simple_tokenizer.cpython-39.pyc -------------------------------------------------------------------------------- /helpers/clip/core/attention.py: -------------------------------------------------------------------------------- 1 | """Attention module.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import cliport.models as models 9 | from cliport.utils import utils 10 | 11 | 12 | class Attention(nn.Module): 13 | """Attention (a.k.a Pick) module.""" 14 | 15 | def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device): 16 | super().__init__() 17 | self.stream_fcn = stream_fcn 18 | self.n_rotations = n_rotations 19 | self.preprocess = preprocess 20 | self.cfg = cfg 21 | self.device = device 22 | self.batchnorm = self.cfg['train']['batchnorm'] 23 | 24 | self.padding = np.zeros((3, 2), dtype=int) 25 | max_dim = np.max(in_shape[:2]) 26 | pad = (max_dim - np.array(in_shape[:2])) / 2 27 | self.padding[:2] = pad.reshape(2, 1) 28 | 29 | in_shape = np.array(in_shape) 30 | in_shape += np.sum(self.padding, axis=1) 31 | in_shape = tuple(in_shape) 32 | self.in_shape = in_shape 33 | 34 | self.rotator = utils.ImageRotator(self.n_rotations) 35 | 36 | self._build_nets() 37 | 38 | def _build_nets(self): 39 | stream_one_fcn, _ = self.stream_fcn 40 | self.attn_stream = models.names[stream_one_fcn](self.in_shape, 1, self.cfg, self.device) 41 | print(f"Attn FCN: {stream_one_fcn}") 42 | 43 | def attend(self, x): 44 | return self.attn_stream(x) 45 | 46 | def forward(self, inp_img, softmax=True): 47 | """Forward pass.""" 
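        # Overview of the steps below: pad the input to a square so it can be rotated
        # about its centre, replicate it once per discrete rotation, run the attention
        # FCN on every rotated copy, rotate the logits back, crop away the padding,
        # and (optionally) softmax over all pixels and rotations to obtain a
        # pick-location distribution.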
48 | in_data = np.pad(inp_img, self.padding, mode='constant') 49 | in_shape = (1,) + in_data.shape 50 | in_data = in_data.reshape(in_shape) 51 | in_tens = torch.from_numpy(in_data).to(dtype=torch.float, device=self.device) # [B W H 6] 52 | 53 | # Rotation pivot. 54 | pv = np.array(in_data.shape[1:3]) // 2 55 | 56 | # Rotate input. 57 | in_tens = in_tens.permute(0, 3, 1, 2) # [B 6 W H] 58 | in_tens = in_tens.repeat(self.n_rotations, 1, 1, 1) 59 | in_tens = self.rotator(in_tens, pivot=pv) 60 | 61 | # Forward pass. 62 | logits = [] 63 | for x in in_tens: 64 | lgts = self.attend(x) 65 | logits.append(lgts) 66 | logits = torch.cat(logits, dim=0) 67 | 68 | # Rotate back output. 69 | logits = self.rotator(logits, reverse=True, pivot=pv) 70 | logits = torch.cat(logits, dim=0) 71 | c0 = self.padding[:2, 0] 72 | c1 = c0 + inp_img.shape[:2] 73 | logits = logits[:, :, c0[0]:c1[0], c0[1]:c1[1]] 74 | 75 | logits = logits.permute(1, 2, 3, 0) # [B W H 1] 76 | output = logits.reshape(1, np.prod(logits.shape)) 77 | if softmax: 78 | output = F.softmax(output, dim=-1) 79 | output = output.reshape(logits.shape[1:]) 80 | return output -------------------------------------------------------------------------------- /helpers/clip/core/attention_image_goal.py: -------------------------------------------------------------------------------- 1 | """Attention module.""" 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | 8 | from cliport.models.core.attention import Attention 9 | 10 | 11 | class AttentionImageGoal(Attention): 12 | """Attention (a.k.a Pick) with image-goals module.""" 13 | 14 | def __init__(self, stream_fcn, in_shape, n_rotations, preprocess, cfg, device): 15 | super().__init__(stream_fcn, in_shape, n_rotations, preprocess, cfg, device) 16 | 17 | def forward(self, inp_img, goal_img, softmax=True): 18 | """Forward pass.""" 19 | # Input image. 20 | in_data = np.pad(inp_img, self.padding, mode='constant') 21 | in_shape = (1,) + in_data.shape 22 | in_data = in_data.reshape(in_shape) 23 | in_tens = torch.from_numpy(in_data).to(dtype=torch.float, device=self.device) 24 | 25 | goal_tensor = np.pad(goal_img, self.padding, mode='constant') 26 | goal_shape = (1,) + goal_tensor.shape 27 | goal_tensor = goal_tensor.reshape(goal_shape) 28 | goal_tensor = torch.from_numpy(goal_tensor.copy()).to(dtype=torch.float, device=self.device) 29 | in_tens = in_tens * goal_tensor 30 | 31 | # Rotation pivot. 32 | pv = np.array(in_data.shape[1:3]) // 2 33 | 34 | # Rotate input. 35 | in_tens = in_tens.permute(0, 3, 1, 2) 36 | in_tens = in_tens.repeat(self.n_rotations, 1, 1, 1) 37 | in_tens = self.rotator(in_tens, pivot=pv) 38 | 39 | # Forward pass. 40 | logits = [] 41 | for x in in_tens: 42 | logits.append(self.attend(x)) 43 | logits = torch.cat(logits, dim=0) 44 | 45 | # Rotate back output. 
46 | logits = self.rotator(logits, reverse=True, pivot=pv) 47 | logits = torch.cat(logits, dim=0) 48 | c0 = self.padding[:2, 0] 49 | c1 = c0 + inp_img.shape[:2] 50 | logits = logits[:, :, c0[0]:c1[0], c0[1]:c1[1]] 51 | 52 | logits = logits.permute(1, 2, 3, 0) # D H W C 53 | output = logits.reshape(1, np.prod(logits.shape)) 54 | if softmax: 55 | output = F.softmax(output, dim=-1) 56 | output = output.reshape(logits.shape[1:]) 57 | return output -------------------------------------------------------------------------------- /helpers/clip/core/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/clip/core/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /helpers/clip/core/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class IdentityBlock(nn.Module): 7 | def __init__(self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True): 8 | super(IdentityBlock, self).__init__() 9 | self.final_relu = final_relu 10 | self.batchnorm = batchnorm 11 | 12 | filters1, filters2, filters3 = filters 13 | self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) 14 | self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() 15 | self.conv2 = nn.Conv2d(filters1, filters2, kernel_size=kernel_size, dilation=1, 16 | stride=stride, padding=1, bias=False) 17 | self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() 18 | self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) 19 | self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() 20 | 21 | def forward(self, x): 22 | out = F.relu(self.bn1(self.conv1(x))) 23 | out = F.relu(self.bn2(self.conv2(out))) 24 | out = self.bn3(self.conv3(out)) 25 | out += x 26 | if self.final_relu: 27 | out = F.relu(out) 28 | return out 29 | 30 | 31 | class ConvBlock(nn.Module): 32 | def __init__(self, in_planes, filters, kernel_size, stride=1, final_relu=True, batchnorm=True): 33 | super(ConvBlock, self).__init__() 34 | self.final_relu = final_relu 35 | self.batchnorm = batchnorm 36 | 37 | filters1, filters2, filters3 = filters 38 | self.conv1 = nn.Conv2d(in_planes, filters1, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm2d(filters1) if self.batchnorm else nn.Identity() 40 | self.conv2 = nn.Conv2d(filters1, filters2, kernel_size=kernel_size, dilation=1, 41 | stride=stride, padding=1, bias=False) 42 | self.bn2 = nn.BatchNorm2d(filters2) if self.batchnorm else nn.Identity() 43 | self.conv3 = nn.Conv2d(filters2, filters3, kernel_size=1, bias=False) 44 | self.bn3 = nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() 45 | 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, filters3, 48 | kernel_size=1, stride=stride, bias=False), 49 | nn.BatchNorm2d(filters3) if self.batchnorm else nn.Identity() 50 | ) 51 | 52 | def forward(self, x): 53 | out = F.relu(self.bn1(self.conv1(x))) 54 | out = F.relu(self.bn2(self.conv2(out))) 55 | out = self.bn3(self.conv3(out)) 56 | out += self.shortcut(x) 57 | if self.final_relu: 58 | out = F.relu(out) 59 | return out 60 | 61 | 62 | class ResNet43_8s(nn.Module): 63 | def __init__(self, input_shape, output_dim, cfg, device, preprocess): 64 | super(ResNet43_8s, self).__init__() 65 | 
self.input_shape = input_shape 66 | self.input_dim = input_shape[-1] 67 | self.output_dim = output_dim 68 | self.cfg = cfg 69 | self.device = device 70 | self.batchnorm = self.cfg['train']['batchnorm'] 71 | self.preprocess = preprocess 72 | 73 | self.layers = self._make_layers() 74 | 75 | def _make_layers(self): 76 | layers = nn.Sequential( 77 | # conv1 78 | nn.Conv2d(self.input_dim, 64, stride=1, kernel_size=3, padding=1), 79 | nn.BatchNorm2d(64) if self.batchnorm else nn.Identity(), 80 | nn.ReLU(True), 81 | 82 | # fcn 83 | ConvBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm), 84 | IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm), 85 | 86 | ConvBlock(64, [128, 128, 128], kernel_size=3, stride=2, batchnorm=self.batchnorm), 87 | IdentityBlock(128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm), 88 | 89 | ConvBlock(128, [256, 256, 256], kernel_size=3, stride=2, batchnorm=self.batchnorm), 90 | IdentityBlock(256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm), 91 | 92 | ConvBlock(256, [512, 512, 512], kernel_size=3, stride=2, batchnorm=self.batchnorm), 93 | IdentityBlock(512, [512, 512, 512], kernel_size=3, stride=1, batchnorm=self.batchnorm), 94 | 95 | # head 96 | ConvBlock(512, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm), 97 | IdentityBlock(256, [256, 256, 256], kernel_size=3, stride=1, batchnorm=self.batchnorm), 98 | nn.UpsamplingBilinear2d(scale_factor=2), 99 | 100 | ConvBlock(256, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm), 101 | IdentityBlock(128, [128, 128, 128], kernel_size=3, stride=1, batchnorm=self.batchnorm), 102 | nn.UpsamplingBilinear2d(scale_factor=2), 103 | 104 | ConvBlock(128, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm), 105 | IdentityBlock(64, [64, 64, 64], kernel_size=3, stride=1, batchnorm=self.batchnorm), 106 | nn.UpsamplingBilinear2d(scale_factor=2), 107 | 108 | # conv2 109 | ConvBlock(64, [16, 16, self.output_dim], kernel_size=3, stride=1, 110 | final_relu=False, batchnorm=self.batchnorm), 111 | IdentityBlock(self.output_dim, [16, 16, self.output_dim], kernel_size=3, stride=1, 112 | final_relu=False, batchnorm=self.batchnorm), 113 | ) 114 | return layers 115 | 116 | def forward(self, x): 117 | x = self.preprocess(x, dist='transporter') 118 | 119 | out = self.layers(x) 120 | return out -------------------------------------------------------------------------------- /helpers/clip/core/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
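    For example: printable bytes map to themselves, so b'!' (0x21) stays '!', while
    the space byte 0x20 falls outside the printable ranges and is remapped to
    chr(256 + 32) = 'Ġ', the character familiar from GPT-2-style byte-level BPE
    vocabularies. In this tokenizer whitespace is collapsed before encoding, so the
    remapped characters mainly matter for round-tripping arbitrary UTF-8 bytes.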
25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text -------------------------------------------------------------------------------- 
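A minimal usage sketch for the tokenizer above (illustrative only: the actual token ids depend on the bundled `bpe_simple_vocab_16e6.txt.gz`, and the repository root is assumed to be importable). In the upstream CLIP implementation each word's final symbol carries an end-of-word marker, `</w>`, which `decode` turns back into a space:

from helpers.clip.core.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()                    # loads bpe_simple_vocab_16e6.txt.gz from the module directory
ids = tokenizer.encode("Pick up the red block")  # text is lowercased and cleaned, then byte-level BPE is applied
text = tokenizer.decode(ids)                     # roughly 'pick up the red block '
sot = tokenizer.encoder['<|startoftext|>']
eot = tokenizer.encoder['<|endoftext|>']
caption_ids = [sot] + ids + [eot]                # the framing CLIP-style text encoders typically expect
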
/helpers/clip/core/transport.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cliport.models as models 3 | from cliport.utils import utils 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class Transport(nn.Module): 11 | 12 | def __init__(self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device): 13 | """Transport (a.k.a Place) module.""" 14 | super().__init__() 15 | 16 | self.iters = 0 17 | self.stream_fcn = stream_fcn 18 | self.n_rotations = n_rotations 19 | self.crop_size = crop_size # crop size must be N*16 (e.g. 96) 20 | self.preprocess = preprocess 21 | self.cfg = cfg 22 | self.device = device 23 | self.batchnorm = self.cfg['train']['batchnorm'] 24 | 25 | self.pad_size = int(self.crop_size / 2) 26 | self.padding = np.zeros((3, 2), dtype=int) 27 | self.padding[:2, :] = self.pad_size 28 | 29 | in_shape = np.array(in_shape) 30 | in_shape = tuple(in_shape) 31 | self.in_shape = in_shape 32 | 33 | # Crop before network (default from Transporters CoRL 2020). 34 | self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2]) 35 | 36 | if not hasattr(self, 'output_dim'): 37 | self.output_dim = 3 38 | if not hasattr(self, 'kernel_dim'): 39 | self.kernel_dim = 3 40 | 41 | self.rotator = utils.ImageRotator(self.n_rotations) 42 | 43 | self._build_nets() 44 | 45 | def _build_nets(self): 46 | stream_one_fcn, _ = self.stream_fcn 47 | model = models.names[stream_one_fcn] 48 | self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 49 | self.query_resnet = model(self.kernel_shape, self.kernel_dim, self.cfg, self.device) 50 | print(f"Transport FCN: {stream_one_fcn}") 51 | 52 | def correlate(self, in0, in1, softmax): 53 | """Correlate two input tensors.""" 54 | output = F.conv2d(in0, in1, padding=(self.pad_size, self.pad_size)) 55 | output = F.interpolate(output, size=(in0.shape[-2], in0.shape[-1]), mode='bilinear') 56 | output = output[:,:,self.pad_size:-self.pad_size, self.pad_size:-self.pad_size] 57 | if softmax: 58 | output_shape = output.shape 59 | output = output.reshape((1, np.prod(output.shape))) 60 | output = F.softmax(output, dim=-1) 61 | output = output.reshape(output_shape[1:]) 62 | return output 63 | 64 | def transport(self, in_tensor, crop): 65 | logits = self.key_resnet(in_tensor) 66 | kernel = self.query_resnet(crop) 67 | return logits, kernel 68 | 69 | def forward(self, inp_img, p, softmax=True): 70 | """Forward pass.""" 71 | img_unprocessed = np.pad(inp_img, self.padding, mode='constant') 72 | input_data = img_unprocessed 73 | in_shape = (1,) + input_data.shape 74 | input_data = input_data.reshape(in_shape) # [B W H D] 75 | in_tensor = torch.from_numpy(input_data).to(dtype=torch.float, device=self.device) 76 | 77 | # Rotation pivot. 78 | pv = np.array([p[0], p[1]]) + self.pad_size 79 | 80 | # Crop before network (default from Transporters CoRL 2020). 81 | hcrop = self.pad_size 82 | in_tensor = in_tensor.permute(0, 3, 1, 2) # [B D W H] 83 | 84 | crop = in_tensor.repeat(self.n_rotations, 1, 1, 1) 85 | crop = self.rotator(crop, pivot=pv) 86 | crop = torch.cat(crop, dim=0) 87 | crop = crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 88 | 89 | logits, kernel = self.transport(in_tensor, crop) 90 | 91 | # TODO(Mohit): Crop after network. Broken for now. 
92 | # in_tensor = in_tensor.permute(0, 3, 1, 2) 93 | # logits, crop = self.transport(in_tensor) 94 | # crop = crop.repeat(self.n_rotations, 1, 1, 1) 95 | # crop = self.rotator(crop, pivot=pv) 96 | # crop = torch.cat(crop, dim=0) 97 | 98 | # kernel = crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 99 | # kernel = crop[:, :, p[0]:(p[0] + self.crop_size), p[1]:(p[1] + self.crop_size)] 100 | 101 | return self.correlate(logits, kernel, softmax) 102 | 103 | -------------------------------------------------------------------------------- /helpers/clip/core/transport_image_goal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cliport.models as models 3 | from cliport.utils import utils 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class TransportImageGoal(nn.Module): 11 | """Transport module.""" 12 | 13 | def __init__(self, stream_fcn, in_shape, n_rotations, crop_size, preprocess, cfg, device): 14 | """Transport module for placing. 15 | Args: 16 | in_shape: shape of input image. 17 | n_rotations: number of rotations of convolving kernel. 18 | crop_size: crop size around pick argmax used as convolving kernel. 19 | preprocess: function to preprocess input images. 20 | """ 21 | super().__init__() 22 | 23 | self.iters = 0 24 | self.stream_fcn = stream_fcn 25 | self.n_rotations = n_rotations 26 | self.crop_size = crop_size # crop size must be N*16 (e.g. 96) 27 | self.preprocess = preprocess 28 | self.cfg = cfg 29 | self.device = device 30 | self.batchnorm = self.cfg['train']['batchnorm'] 31 | 32 | self.pad_size = int(self.crop_size / 2) 33 | self.padding = np.zeros((3, 2), dtype=int) 34 | self.padding[:2, :] = self.pad_size 35 | 36 | in_shape = np.array(in_shape) 37 | in_shape = tuple(in_shape) 38 | self.in_shape = in_shape 39 | 40 | # Crop before network (default for Transporters CoRL 2020). 41 | self.kernel_shape = (self.crop_size, self.crop_size, self.in_shape[2]) 42 | 43 | if not hasattr(self, 'output_dim'): 44 | self.output_dim = 3 45 | if not hasattr(self, 'kernel_dim'): 46 | self.kernel_dim = 3 47 | 48 | self.rotator = utils.ImageRotator(self.n_rotations) 49 | 50 | self._build_nets() 51 | 52 | def _build_nets(self): 53 | stream_one_fcn, _ = self.stream_fcn 54 | model = models.names[stream_one_fcn] 55 | self.key_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 56 | self.query_resnet = model(self.in_shape, self.kernel_dim, self.cfg, self.device) 57 | self.goal_resnet = model(self.in_shape, self.output_dim, self.cfg, self.device) 58 | print(f"Transport FCN: {stream_one_fcn}") 59 | 60 | def correlate(self, in0, in1, softmax): 61 | """Correlate two input tensors.""" 62 | output = F.conv2d(in0, in1, padding=(self.pad_size, self.pad_size)) 63 | output = F.interpolate(output, size=(in0.shape[-2], in0.shape[-1]), mode='bilinear') 64 | output = output[:,:,self.pad_size:-self.pad_size, self.pad_size:-self.pad_size] 65 | if softmax: 66 | output_shape = output.shape 67 | output = output.reshape((1, np.prod(output.shape))) 68 | output = F.softmax(output, dim=-1) 69 | output = output.reshape(output_shape[1:]) 70 | return output 71 | 72 | def forward(self, inp_img, goal_img, p, softmax=True): 73 | """Forward pass.""" 74 | 75 | # Input image. 
76 | img_unprocessed = np.pad(inp_img, self.padding, mode='constant') 77 | input_data = img_unprocessed 78 | in_shape = (1,) + input_data.shape 79 | input_data = input_data.reshape(in_shape) 80 | in_tensor = torch.from_numpy(input_data.copy()).to(dtype=torch.float, device=self.device) 81 | in_tensor = in_tensor.permute(0, 3, 1, 2) 82 | 83 | # Goal image. 84 | goal_tensor = np.pad(goal_img, self.padding, mode='constant') 85 | goal_shape = (1,) + goal_tensor.shape 86 | goal_tensor = goal_tensor.reshape(goal_shape) 87 | goal_tensor = torch.from_numpy(goal_tensor.copy()).to(dtype=torch.float, device=self.device) 88 | goal_tensor = goal_tensor.permute(0, 3, 1, 2) 89 | 90 | # Rotation pivot. 91 | pv = np.array([p[0], p[1]]) + self.pad_size 92 | hcrop = self.pad_size 93 | 94 | # Cropped input features. 95 | in_crop = in_tensor.repeat(self.n_rotations, 1, 1, 1) 96 | in_crop = self.rotator(in_crop, pivot=pv) 97 | in_crop = torch.cat(in_crop, dim=0) 98 | in_crop = in_crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 99 | 100 | # Cropped goal features. 101 | goal_crop = goal_tensor.repeat(self.n_rotations, 1, 1, 1) 102 | goal_crop = self.rotator(goal_crop, pivot=pv) 103 | goal_crop = torch.cat(goal_crop, dim=0) 104 | goal_crop = goal_crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 105 | 106 | in_logits = self.key_resnet(in_tensor) 107 | goal_logits = self.goal_resnet(goal_tensor) 108 | kernel_crop = self.query_resnet(in_crop) 109 | goal_crop = self.goal_resnet(goal_crop) 110 | 111 | # Fuse Goal and Transport features 112 | goal_x_in_logits = goal_logits + in_logits # Mohit: why doesn't multiply work? :( 113 | goal_x_kernel = goal_crop + kernel_crop 114 | 115 | # TODO(Mohit): Crop after network. Broken for now 116 | # in_logits = self.key_resnet(in_tensor) 117 | # kernel_nocrop_logits = self.query_resnet(in_tensor) 118 | # goal_logits = self.goal_resnet(goal_tensor) 119 | 120 | # goal_x_in_logits = in_logits 121 | # goal_x_kernel_logits = goal_logits * kernel_nocrop_logits 122 | 123 | # goal_crop = goal_x_kernel_logits.repeat(self.n_rotations, 1, 1, 1) 124 | # goal_crop = self.rotator(goal_crop, pivot=pv) 125 | # goal_crop = torch.cat(goal_crop, dim=0) 126 | # goal_crop = goal_crop[:, :, pv[0]-hcrop:pv[0]+hcrop, pv[1]-hcrop:pv[1]+hcrop] 127 | 128 | return self.correlate(goal_x_in_logits, goal_x_kernel, softmax) 129 | 130 | -------------------------------------------------------------------------------- /helpers/clip/core/unet.py: -------------------------------------------------------------------------------- 1 | # Credit: https://github.com/milesial/Pytorch-UNet/ 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class DoubleConv(nn.Module): 9 | """(convolution => [BN] => ReLU) * 2""" 10 | 11 | def __init__(self, in_channels, out_channels, mid_channels=None): 12 | super().__init__() 13 | if not mid_channels: 14 | mid_channels = out_channels 15 | self.double_conv = nn.Sequential( 16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1), 17 | # nn.BatchNorm2d(mid_channels), # (Mohit): argh... forgot to remove this batchnorm 18 | nn.ReLU(inplace=True), 19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1), 20 | # nn.BatchNorm2d(out_channels), # (Mohit): argh... 
forgot to remove this batchnorm 21 | nn.ReLU(inplace=True) 22 | ) 23 | 24 | def forward(self, x): 25 | return self.double_conv(x) 26 | 27 | 28 | class Down(nn.Module): 29 | """Downscaling with maxpool then double conv""" 30 | 31 | def __init__(self, in_channels, out_channels): 32 | super().__init__() 33 | self.maxpool_conv = nn.Sequential( 34 | nn.MaxPool2d(2), 35 | DoubleConv(in_channels, out_channels) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.maxpool_conv(x) 40 | 41 | 42 | class Up(nn.Module): 43 | """Upscaling then double conv""" 44 | 45 | def __init__(self, in_channels, out_channels, bilinear=True): 46 | super().__init__() 47 | 48 | # if bilinear, use the normal convolutions to reduce the number of channels 49 | if bilinear: 50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) 52 | else: 53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2) 54 | self.conv = DoubleConv(in_channels, out_channels) 55 | 56 | 57 | def forward(self, x1, x2): 58 | x1 = self.up(x1) 59 | # input is CHW 60 | diffY = x2.size()[2] - x1.size()[2] 61 | diffX = x2.size()[3] - x1.size()[3] 62 | 63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, 64 | diffY // 2, diffY - diffY // 2]) 65 | # if you have padding issues, see 66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 68 | x = torch.cat([x2, x1], dim=1) 69 | return self.conv(x) 70 | 71 | 72 | class OutConv(nn.Module): 73 | def __init__(self, in_channels, out_channels): 74 | super(OutConv, self).__init__() 75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) 76 | 77 | def forward(self, x): 78 | return self.conv(x) -------------------------------------------------------------------------------- /helpers/demo_loading_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | import numpy as np 5 | from rlbench.demo import Demo 6 | 7 | 8 | def _is_stopped(demo, i, obs, stopped_buffer, delta=0.1): 9 | next_is_not_final = i == (len(demo) - 2) 10 | gripper_state_no_change = ( 11 | i < (len(demo) - 2) and 12 | (obs.gripper_open == demo[i + 1].gripper_open and 13 | obs.gripper_open == demo[i - 1].gripper_open and 14 | demo[i - 2].gripper_open == demo[i - 1].gripper_open)) 15 | small_delta = np.allclose(obs.joint_velocities, 0, atol=delta) 16 | stopped = (stopped_buffer <= 0 and small_delta and 17 | (not next_is_not_final) and gripper_state_no_change) 18 | return stopped 19 | 20 | 21 | def keypoint_discovery(demo: Demo, 22 | stopping_delta=0.1, 23 | method='heuristic') -> List[int]: 24 | episode_keypoints = [] 25 | if method == 'heuristic': 26 | prev_gripper_open = demo[0].gripper_open 27 | stopped_buffer = 0 28 | for i, obs in enumerate(demo): 29 | stopped = _is_stopped(demo, i, obs, stopped_buffer, stopping_delta) 30 | stopped_buffer = 4 if stopped else stopped_buffer - 1 31 | # If change in gripper, or end of episode. 
32 | last = i == (len(demo) - 1) 33 | if i != 0 and (obs.gripper_open != prev_gripper_open or 34 | last or stopped): 35 | episode_keypoints.append(i) 36 | prev_gripper_open = obs.gripper_open 37 | if len(episode_keypoints) > 1 and (episode_keypoints[-1] - 1) == \ 38 | episode_keypoints[-2]: 39 | episode_keypoints.pop(-2) 40 | logging.debug('Found %d keypoints.' % len(episode_keypoints), 41 | episode_keypoints) 42 | return episode_keypoints 43 | 44 | elif method == 'random': 45 | # Randomly select keypoints. 46 | episode_keypoints = np.random.choice( 47 | range(len(demo)), 48 | size=20, 49 | replace=False) 50 | episode_keypoints.sort() 51 | return episode_keypoints 52 | 53 | elif method == 'fixed_interval': 54 | # Fixed interval. 55 | episode_keypoints = [] 56 | segment_length = len(demo) // 20 57 | for i in range(0, len(demo), segment_length): 58 | episode_keypoints.append(i) 59 | return episode_keypoints 60 | 61 | else: 62 | raise NotImplementedError 63 | 64 | 65 | # find minimum difference between any two elements in list 66 | def find_minimum_difference(lst): 67 | minimum = lst[-1] 68 | for i in range(1, len(lst)): 69 | if lst[i] - lst[i - 1] < minimum: 70 | minimum = lst[i] - lst[i - 1] 71 | return minimum -------------------------------------------------------------------------------- /helpers/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/optim/__init__.py -------------------------------------------------------------------------------- /helpers/optim/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/optim/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/optim/__pycache__/lamb.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/helpers/optim/__pycache__/lamb.cpython-38.pyc -------------------------------------------------------------------------------- /helpers/optim/lamb.py: -------------------------------------------------------------------------------- 1 | """Lamb optimizer.""" 2 | 3 | # LAMB optimizer used as is. 4 | # Source: https://github.com/cybertronai/pytorch-lamb 5 | # License: https://github.com/cybertronai/pytorch-lamb/blob/master/LICENSE 6 | 7 | import collections 8 | import math 9 | 10 | import torch 11 | from torch.optim import Optimizer 12 | 13 | 14 | # def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 15 | # """Log a histogram of trust ratio scalars in across layers.""" 16 | # results = collections.defaultdict(list) 17 | # for group in optimizer.param_groups: 18 | # for p in group['params']: 19 | # state = optimizer.state[p] 20 | # for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 21 | # if i in state: 22 | # results[i].append(state[i]) 23 | # 24 | # for k, v in results.items(): 25 | # event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 26 | 27 | class Lamb(Optimizer): 28 | r"""Implements Lamb algorithm. 29 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 
30 | Arguments: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float, optional): learning rate (default: 1e-3) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)) 36 | eps (float, optional): term added to the denominator to improve 37 | numerical stability (default: 1e-8) 38 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 39 | adam (bool, optional): always use trust ratio = 1, which turns this into 40 | Adam. Useful for comparison purposes. 41 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 42 | https://arxiv.org/abs/1904.00962 43 | """ 44 | 45 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 46 | weight_decay=0, adam=False): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, 56 | weight_decay=weight_decay) 57 | self.adam = adam 58 | super(Lamb, self).__init__(params, defaults) 59 | 60 | def step(self, closure=None): 61 | """Performs a single optimization step. 62 | Arguments: 63 | closure (callable, optional): A closure that reevaluates the model 64 | and returns the loss. 65 | """ 66 | loss = None 67 | if closure is not None: 68 | loss = closure() 69 | 70 | for group in self.param_groups: 71 | for p in group['params']: 72 | if p.grad is None: 73 | continue 74 | grad = p.grad.data 75 | if grad.is_sparse: 76 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 77 | 78 | state = self.state[p] 79 | 80 | # State initialization 81 | if len(state) == 0: 82 | state['step'] = 0 83 | # Exponential moving average of gradient values 84 | state['exp_avg'] = torch.zeros_like(p.data) 85 | # Exponential moving average of squared gradient values 86 | state['exp_avg_sq'] = torch.zeros_like(p.data) 87 | 88 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 89 | beta1, beta2 = group['betas'] 90 | 91 | state['step'] += 1 92 | 93 | # Decay the first and second moment running average coefficient 94 | # m_t 95 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 96 | # v_t 97 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 98 | 99 | # Paper v3 does not use debiasing. 100 | # bias_correction1 = 1 - beta1 ** state['step'] 101 | # bias_correction2 = 1 - beta2 ** state['step'] 102 | # Apply bias to lr to avoid broadcast. 
103 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 104 | 105 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 106 | 107 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 108 | if group['weight_decay'] != 0: 109 | adam_step.add_(p.data, alpha=group['weight_decay']) 110 | 111 | adam_norm = adam_step.pow(2).sum().sqrt() 112 | if weight_norm == 0 or adam_norm == 0: 113 | trust_ratio = 1 114 | else: 115 | trust_ratio = weight_norm / adam_norm 116 | state['weight_norm'] = weight_norm 117 | state['adam_norm'] = adam_norm 118 | state['trust_ratio'] = trust_ratio 119 | if self.adam: 120 | trust_ratio = 1 121 | 122 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 123 | 124 | return loss -------------------------------------------------------------------------------- /helpers/preprocess_agent.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from yarr.agents.agent import Agent, Summary, ActResult, \ 6 | ScalarSummary, HistogramSummary, ImageSummary 7 | 8 | 9 | class PreprocessAgent(Agent): 10 | 11 | def __init__(self, 12 | pose_agent: Agent, 13 | norm_rgb: bool = True): 14 | self._pose_agent = pose_agent 15 | self._norm_rgb = norm_rgb 16 | 17 | def build(self, training: bool, device: torch.device = None): 18 | self._pose_agent.build(training, device) 19 | 20 | def _norm_rgb_(self, x): 21 | return (x.float() / 255.0) * 2.0 - 1.0 22 | 23 | def update(self, step: int, replay_sample: dict) -> dict: 24 | # Samples are (B, N, ...) where N is number of buffers/tasks. This is a single task setup, so 0 index. 25 | replay_sample = {k: v[:, 0] if len(v.shape) > 2 else v for k, v in replay_sample.items()} 26 | for k, v in replay_sample.items(): 27 | if self._norm_rgb and 'rgb' in k: 28 | replay_sample[k] = self._norm_rgb_(v) 29 | else: 30 | replay_sample[k] = v.float() 31 | self._replay_sample = replay_sample 32 | return self._pose_agent.update(step, replay_sample) 33 | 34 | def act(self, step: int, observation: dict, 35 | deterministic=False) -> ActResult: 36 | # observation = {k: torch.tensor(v) for k, v in observation.items()} 37 | for k, v in observation.items(): 38 | if self._norm_rgb and 'rgb' in k: 39 | observation[k] = self._norm_rgb_(v) 40 | else: 41 | observation[k] = v.float() 42 | act_res = self._pose_agent.act(step, observation, deterministic) 43 | act_res.replay_elements.update({'demo': False}) 44 | return act_res 45 | 46 | def update_summaries(self) -> List[Summary]: 47 | prefix = 'inputs' 48 | demo_f = self._replay_sample['demo'].float() 49 | demo_proportion = demo_f.mean() 50 | tile = lambda x: torch.squeeze( 51 | torch.cat(x.split(1, dim=1), dim=-1), dim=1) 52 | sums = [ 53 | ScalarSummary('%s/demo_proportion' % prefix, demo_proportion), 54 | HistogramSummary('%s/low_dim_state' % prefix, 55 | self._replay_sample['low_dim_state']), 56 | HistogramSummary('%s/low_dim_state_tp1' % prefix, 57 | self._replay_sample['low_dim_state_tp1']), 58 | ScalarSummary('%s/low_dim_state_mean' % prefix, 59 | self._replay_sample['low_dim_state'].mean()), 60 | ScalarSummary('%s/low_dim_state_min' % prefix, 61 | self._replay_sample['low_dim_state'].min()), 62 | ScalarSummary('%s/low_dim_state_max' % prefix, 63 | self._replay_sample['low_dim_state'].max()), 64 | ScalarSummary('%s/timeouts' % prefix, 65 | self._replay_sample['timeout'].float().mean()), 66 | ] 67 | 68 | for k, v in self._replay_sample.items(): 69 | if 'rgb' in k or 'point_cloud' in k: 70 | if 'rgb' in k: 71 | 
# Convert back to 0 - 1 72 | v = (v + 1.0) / 2.0 73 | sums.append(ImageSummary('%s/%s' % (prefix, k), tile(v))) 74 | 75 | if 'sampling_probabilities' in self._replay_sample: 76 | sums.extend([ 77 | HistogramSummary('replay/priority', 78 | self._replay_sample['sampling_probabilities']), 79 | ]) 80 | sums.extend(self._pose_agent.update_summaries()) 81 | return sums 82 | 83 | def act_summaries(self) -> List[Summary]: 84 | return self._pose_agent.act_summaries() 85 | 86 | def load_weights(self, savedir: str): 87 | self._pose_agent.load_weights(savedir) 88 | 89 | def save_weights(self, savedir: str): 90 | self._pose_agent.save_weights(savedir) 91 | 92 | def reset(self) -> None: 93 | self._pose_agent.reset() 94 | 95 | -------------------------------------------------------------------------------- /scripts/compute_bound_box.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script computes the minimum and maximum gripper locations for 3 | each task in the training set. 4 | """ 5 | 6 | import tap 7 | from typing import List, Tuple, Optional 8 | from pathlib import Path 9 | import torch 10 | import pprint 11 | import json 12 | 13 | 14 | 15 | if __name__ == "__main__": 16 | args = Arguments().parse_args() 17 | 18 | bounds = {task: [] for task in args.tasks} 19 | 20 | for task in args.tasks: 21 | instruction = load_instructions( 22 | args.instructions, tasks=[task], variations=args.variations 23 | ) 24 | 25 | taskvar = [ 26 | (task, var) 27 | for task, var_instr in instruction.items() 28 | for var in var_instr.keys() 29 | ] 30 | max_episode_length = get_max_episode_length([task], args.variations) 31 | 32 | dataset = RLBenchDataset( 33 | root=args.dataset, 34 | instructions=instruction, 35 | taskvar=taskvar, 36 | max_episode_length=max_episode_length, 37 | cache_size=args.cache_size, 38 | max_episodes_per_task=args.max_episodes_per_task, 39 | cameras=args.cameras, # type: ignore 40 | return_low_lvl_trajectory=True, 41 | dense_interpolation=True, 42 | interpolation_length=50, 43 | training=False 44 | ) 45 | 46 | print( 47 | f"Computing gripper location bounds for task {task} " 48 | f"from dataset of length {len(dataset)}" 49 | ) 50 | 51 | for i in range(len(dataset)): 52 | ep = dataset[i] 53 | bounds[task].append(ep["action"][:, :3]) 54 | bounds[task].append(ep["trajectory"][..., :3].reshape([-1, 3])) 55 | 56 | bounds = { 57 | task: [ 58 | torch.cat(gripper_locs, dim=0).min(dim=0).values.tolist(), 59 | torch.cat(gripper_locs, dim=0).max(dim=0).values.tolist() 60 | ] 61 | for task, gripper_locs in bounds.items() 62 | if len(gripper_locs) > 0 63 | } 64 | 65 | pprint.pprint(bounds) 66 | json.dump(bounds, open(args.out_file, "w"), indent=4) -------------------------------------------------------------------------------- /scripts/compute_fvd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | import argparse 4 | from videogpt.download import load_i3d_pretrained 5 | from tqdm import tqdm 6 | import numpy as np 7 | 8 | import torch 9 | import torch.multiprocessing as mp 10 | import torch.distributed as dist 11 | 12 | from videogpt.fvd.fvd import get_fvd_logits, frechet_distance 13 | from videogpt import VideoData, VideoGPT, load_videogpt 14 | 15 | 16 | MAX_BATCH = 32 17 | 18 | 19 | def main(): 20 | assert torch.cuda.is_available() 21 | ngpus = torch.cuda.device_count() 22 | assert 256 % ngpus == 0, f"Must have 256 % n_gpus == 0" 23 | 24 | mp.spawn(main_worker, nprocs=ngpus, args=(ngpus, 
args), join=True) 25 | 26 | 27 | def main_worker(rank, size, args_in): 28 | global args 29 | args = args_in 30 | is_root = rank == 0 31 | dist.init_process_group(backend='nccl', init_method=f'tcp://localhost:{args.port}', 32 | world_size=size, rank=rank) 33 | device = torch.device(f"cuda:{rank}") 34 | torch.cuda.set_device(device) 35 | torch.set_grad_enabled(False) 36 | 37 | n_trials = args.n_trials 38 | 39 | #################### Load VideoGPT ######################################## 40 | if not os.path.exists(args.ckpt): 41 | gpt = load_videogpt(args.ckpt, device=device) 42 | else: 43 | gpt = VideoGPT.load_from_checkpoint(args.ckpt).to(device) 44 | gpt.eval() 45 | args = gpt.hparams['args'] 46 | 47 | args.batch_size = 256 // dist.get_world_size() 48 | loader = VideoData(args).test_dataloader() 49 | 50 | #################### Load I3D ######################################## 51 | i3d = load_i3d_pretrained(device) 52 | 53 | #################### Compute FVD ############################### 54 | fvds = [] 55 | fvds_star = [] 56 | if is_root: 57 | pbar = tqdm(total=n_trials) 58 | for _ in range(n_trials): 59 | fvd, fvd_star = eval_fvd(i3d, gpt, loader, device) 60 | fvds.append(fvd) 61 | fvds_star.append(fvd_star) 62 | 63 | if is_root: 64 | pbar.update(1) 65 | fvd_mean = np.mean(fvds) 66 | fvd_std = np.std(fvds) 67 | 68 | fvd_star_mean = np.mean(fvds_star) 69 | fvd_star_std = np.std(fvds_star) 70 | 71 | pbar.set_description(f"FVD {fvd_mean:.2f} +/- {fvd_std:.2f}, FVD* {fvd_star_mean:.2f} +/0 {fvd_star_std:.2f}") 72 | if is_root: 73 | pbar.close() 74 | print(f"Final FVD {fvd_mean:.2f} +/- {fvd_std:.2f}, FVD* {fvd_star_mean:.2f} +/- {fvd_star_std:.2f}") 75 | 76 | 77 | def all_gather(tensor): 78 | rank, size = dist.get_rank(), dist.get_world_size() 79 | tensor_list = [torch.zeros_like(tensor) for _ in range(size)] 80 | dist.all_gather(tensor_list, tensor) 81 | return torch.cat(tensor_list) 82 | 83 | 84 | def eval_fvd(i3d, videogpt, loader, device): 85 | rank, size = dist.get_rank(), dist.get_world_size() 86 | is_root = rank == 0 87 | 88 | batch = next(iter(loader)) 89 | batch = {k: v.to(device) for k, v in batch.items()} 90 | 91 | fake_embeddings = [] 92 | for i in range(0, batch['video'].shape[0], MAX_BATCH): 93 | fake = videogpt.sample(MAX_BATCH, {k: v[i:i+MAX_BATCH] for k, v in batch.items()}) 94 | fake = fake.permute(0, 2, 3, 4, 1).cpu().numpy() # BCTHW -> BTHWC 95 | fake = (fake * 255).astype('uint8') 96 | fake_embeddings.append(get_fvd_logits(fake, i3d=i3d, device=device)) 97 | fake_embeddings = torch.cat(fake_embeddings) 98 | 99 | real = batch['video'].to(device) 100 | real_recon_embeddings = [] 101 | for i in range(0, batch['video'].shape[0], MAX_BATCH): 102 | real_recon = (videogpt.get_reconstruction(batch['video'][i:i+MAX_BATCH]) + 0.5).clamp(0, 1) 103 | real_recon = real_recon.permute(0, 2, 3, 4, 1).cpu().numpy() 104 | real_recon = (real_recon * 255).astype('uint8') 105 | real_recon_embeddings.append(get_fvd_logits(real_recon, i3d=i3d, device=device)) 106 | real_recon_embeddings = torch.cat(real_recon_embeddings) 107 | 108 | real = real + 0.5 109 | real = real.permute(0, 2, 3, 4, 1).cpu().numpy() # BCTHW -> BTHWC 110 | real = (real * 255).astype('uint8') 111 | real_embeddings = get_fvd_logits(real, i3d=i3d, device=device) 112 | 113 | fake_embeddings = all_gather(fake_embeddings) 114 | real_recon_embeddings = all_gather(real_recon_embeddings) 115 | real_embeddings = all_gather(real_embeddings) 116 | 117 | assert fake_embeddings.shape[0] == real_recon_embeddings.shape[0] == 
real_embeddings.shape[0] == 256 118 | 119 | fvd = frechet_distance(fake_embeddings.clone(), real_embeddings) 120 | fvd_star = frechet_distance(fake_embeddings.clone(), real_recon_embeddings) 121 | return fvd.item(), fvd_star.item() 122 | 123 | 124 | if __name__ == '__main__': 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--ckpt', type=str, default='bair_gpt') 127 | parser.add_argument('--n_trials', type=int, default=1, help="Number of trials to compute mean/std") 128 | parser.add_argument('--port', type=int, default=23452) 129 | args = parser.parse_args() 130 | 131 | main() 132 | -------------------------------------------------------------------------------- /scripts/count.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | def softmax(x): 3 | """ softmax function """ 4 | 5 | # assert(len(x.shape) > 1, "dimension must be larger than 1") 6 | # print(np.max(x, axis = 1, keepdims = True)) # axis = 1, rows 7 | 8 | x -= np.max(x, axis = 0, keepdims = True) # subtract the max element for a numerically stable softmax 9 | 10 | x = np.exp(x) / np.sum(np.exp(x), axis = 0, keepdims = True) 11 | 12 | return x 13 | 14 | robot_datas = [np.load(f'./data/robot_latents_v1_{idr}.npz')['robot'] for idr in range(540)] 15 | cnt_num = np.zeros(2048) 16 | for path in robot_datas: 17 | for item in path: 18 | tmp = item.flatten() 19 | for d in tmp: 20 | cnt_num[d] += 1 21 | import pdb;pdb.set_trace() 22 | cnt_num = cnt_num / cnt_num.max() 23 | print(cnt_num) 24 | np.save('./data/cnt_num.npy', cnt_num) -------------------------------------------------------------------------------- /scripts/preprocess/bair/bair_extract_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | 4 | import numpy as np 5 | from PIL import Image 6 | import tensorflow as tf 7 | 8 | from tensorflow.python.platform import flags 9 | from tensorflow.python.platform import gfile 10 | 11 | import imageio 12 | 13 | import argparse 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--data_dir', default='', help='base directory to save processed data') 16 | opt = parser.parse_args() 17 | 18 | def get_seq(dname): 19 | data_dir = '%s/softmotion30_44k/%s' % (opt.data_dir, dname) 20 | 21 | filenames = gfile.Glob(os.path.join(data_dir, '*')) 22 | if not filenames: 23 | raise RuntimeError('No data files found.') 24 | 25 | for f in filenames: 26 | k=0 27 | for serialized_example in tf.python_io.tf_record_iterator(f): 28 | example = tf.train.Example() 29 | example.ParseFromString(serialized_example) 30 | image_seq, action_seq = [], [] 31 | for i in range(30): 32 | image_name = str(i) + '/image_aux1/encoded' 33 | byte_str = example.features.feature[image_name].bytes_list.value[0] 34 | img = Image.frombytes('RGB', (64, 64), byte_str) 35 | arr = np.array(img.getdata()).reshape(img.size[1], img.size[0], 3) 36 | image_seq.append(arr.reshape(1, 64, 64, 3)) 37 | 38 | action_name = str(i) + '/action' 39 | action = example.features.feature[action_name].float_list.value 40 | action = np.array(action).astype('float32') 41 | action_seq.append(action) 42 | image_seq = np.concatenate(image_seq, axis=0) 43 | action_seq = np.stack(action_seq, axis=0) 44 | k=k+1 45 | yield f, k, image_seq, action_seq 46 | 47 | def convert_data(dname): 48 | seq_generator = get_seq(dname) 49 | n = 0 50 | while True: 51 | n+=1 52 | try: 53 | f, k, seq, actions = next(seq_generator) 54 | seq = seq.astype('uint8') 55 | except StopIteration: 56 | break 57 | f = 
f.split('/')[-1] 58 | os.makedirs('%s/processed_data/%s/%s/%d/' % (opt.data_dir, dname, f[:-10], k), exist_ok=True) 59 | for i in range(len(seq)): 60 | imageio.imwrite('%s/processed_data/%s/%s/%d/%d.png' % (opt.data_dir, dname, f[:-10], k, i), seq[i]) 61 | np.save('%s/processed_data/%s/%s/%d/actions.npy' % (opt.data_dir, dname, f[:-10], k), actions) 62 | 63 | print('%s data: %s (%d) (%d)' % (dname, f, k, n)) 64 | 65 | convert_data('test') 66 | convert_data('train') 67 | -------------------------------------------------------------------------------- /scripts/preprocess/bair/bair_image_to_hdf5.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import argparse 3 | import h5py 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | import os.path as osp 8 | from tqdm import tqdm 9 | import sys 10 | 11 | def convert_data(f, split): 12 | root_dir = args.data_dir 13 | path = osp.join(root_dir, 'processed_data', split) 14 | traj_paths = glob.glob(osp.join(path, '*', '*')) 15 | trajs, actions = [], [] 16 | for traj_path in tqdm(traj_paths): 17 | image_paths = glob.glob(osp.join(traj_path, '*.png')) 18 | image_paths.sort(key=lambda x: int(osp.splitext(osp.basename(x))[0])) 19 | traj = [] 20 | for img_path in image_paths: 21 | img = Image.open(img_path) 22 | arr = np.array(img) # HWC 23 | traj.append(arr) 24 | traj = np.stack(traj, axis=0) # THWC 25 | trajs.append(traj) 26 | 27 | actions.append(np.load(osp.join(traj_path, 'actions.npy'))) 28 | 29 | idxs = np.arange(len(trajs)) * 30 30 | trajs = np.concatenate(trajs, axis=0) # (NT)HWC 31 | actions = np.concatenate(actions, axis=0) # (NT)(act_dim) 32 | 33 | f.create_dataset(f'{split}_data', data=trajs) 34 | f.create_dataset(f'{split}_actions', data=actions) 35 | f.create_dataset(f'{split}_idx', data=idxs) 36 | 37 | print(split) 38 | print(f'\timages: {f[f"{split}_data"].shape}, {f[f"{split}_data"].dtype}') 39 | print(f'\timages: {f[f"{split}_idx"].shape}, {f[f"{split}_idx"].dtype}') 40 | 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument('--data_dir', type=str, required=True) 43 | parser.add_argument('--output_dir', type=str, required=True) 44 | args = parser.parse_args() 45 | 46 | os.makedirs(args.output_dir, exist_ok=True) 47 | f = h5py.File(osp.join(args.output_dir, 'bair.hdf5'), 'a') 48 | convert_data(f, 'train') 49 | convert_data(f, 'test') 50 | f.close() 51 | -------------------------------------------------------------------------------- /scripts/preprocess/bair/create_bair_dataset.sh: -------------------------------------------------------------------------------- 1 | #! /bin/sh 2 | 3 | mkdir -p ${1} 4 | mkdir -p ~/.cache/bair 5 | wget -P ~/.cache/bair/ http://rail.eecs.berkeley.edu/datasets/bair_robot_pushing_dataset_v0.tar 6 | tar -xf ~/.cache/bair/bair_robot_pushing_dataset_v0.tar -C ~/.cache/bair/ 7 | 8 | python scripts/preprocess/bair/bair_extract_images.py --data_dir ~/.cache/bair 9 | python scripts/preprocess/bair/bair_image_to_hdf5.py --data_dir ~/.cache/bair --output_dir ${1} 10 | 11 | rm -r ~/.cache/bair 12 | -------------------------------------------------------------------------------- /scripts/preprocess/ucf101/create_ucf_dataset.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/sh 2 | 3 | mkdir -p ${1} 4 | 5 | # Download UCF-101 video files 6 | wget --no-check-certificate -P ${1} https://www.crcv.ucf.edu/data/UCF101/UCF101.rar 7 | unrar x ${1}/UCF101.rar ${1} 8 | rm ${1}/UCF101.rar 9 | 10 | # Download UCF-101 train/test splits 11 | wget --no-check-certificate -P ${1} https://www.crcv.ucf.edu/data/UCF101/UCF101TrainTestSplits-RecognitionTask.zip 12 | unzip ${1}/UCF101TrainTestSplits-RecognitionTask.zip -d ${1} 13 | rm ${1}/UCF101TrainTestSplits-RecognitionTask.zip 14 | 15 | # Move video files into train / test directories based on train/test split 16 | python scripts/preprocess/ucf101/ucf_split_train_test.py ${1} 1 17 | 18 | # Delete leftover files 19 | rm -r ${1}/UCF-101 20 | -------------------------------------------------------------------------------- /scripts/preprocess/ucf101/ucf_split_train_test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import os.path as osp 4 | import shutil 5 | 6 | root = sys.argv[1] 7 | fold = int(sys.argv[2]) 8 | assert fold in [1, 2, 3] 9 | 10 | def move_files(files, split): 11 | split_dir = osp.join(root, split) 12 | os.makedirs(split_dir, exist_ok=True) 13 | for filename in files: 14 | folder = osp.join(split_dir, osp.dirname(filename)) 15 | os.makedirs(folder, exist_ok=True) 16 | shutil.move(osp.join(root, 'UCF-101', filename), osp.join(split_dir, filename)) 17 | 18 | with open(osp.join(root, 'ucfTrainTestlist', f'trainlist0{fold}.txt'), 'r') as f: 19 | train_files = [p.strip().split(' ')[0] for p in f.readlines()] 20 | move_files(train_files, 'train') 21 | 22 | with open(osp.join(root, 'ucfTrainTestlist', f'testlist0{fold}.txt'), 'r') as f: 23 | test_files = [p.strip() for p in f.readlines()] 24 | move_files(test_files, 'test') 25 | -------------------------------------------------------------------------------- /scripts/pretrain_meta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 4 | parentdir = os.path.dirname(currentdir) 5 | os.sys.path.insert(0, parentdir) 6 | import diffuser.utils as utils 7 | from pathlib import Path 8 | import numpy as np 9 | from numpy.random import randn 10 | import matplotlib as mpl 11 | from scipy import stats 12 | #os.environ['CUDA_VISIBLE_DEVICES'] = '1' 13 | #-----------------------------------------------------------------------------# 14 | #----------------------------------- setup -----------------------------------# 15 | #-----------------------------------------------------------------------------# 16 | 17 | class Parser(utils.Parser): 18 | dataset:str = 'meta' 19 | config: str = 'config.locomotion' 20 | 21 | args = Parser().parse_args('diffusion') 22 | #''' 23 | if args.single: 24 | args.tasks = ['close_jar'] 25 | import torch 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | dataset_config = utils.Config( 30 | args.loader, 31 | tasks = args.meta_tasks, 32 | savepath=(args.savepath, 'dataset_config.pkl'), 33 | data_folder = args.data_folder, 34 | sequence_length = 4, #args.sequence_length, #TODO 35 | devices = args.device, 36 | horizon = args.horizon, 37 | num_demos=args.num_demos, 38 | ) 39 | 40 | 41 | dataset = dataset_config() 42 | 43 | 44 | 45 | model_config = utils.Config( 46 | args.model, 47 | savepath=(args.savepath, 'model_config.pkl'), 48 | horizon=args.horizon, 49 | transition_dim=24, 50 | 
obs_cls=2048, 51 | act_cls=args.act_classes, 52 | cond_dim=512, 53 | num_tasks=50, 54 | dim_mults=args.dim_mults, 55 | attention=args.attention, 56 | device=args.device, 57 | train_device=args.device, 58 | verbose=False, 59 | action_dim=4, 60 | vqvae=dataset.vqvae, 61 | pretrain=args.pretrain, 62 | meta=True, 63 | ) 64 | diffusion_config = utils.Config( 65 | args.diffusion, 66 | savepath=(args.savepath, 'diffusion_config.pkl'), 67 | horizon=args.horizon, #TODO 68 | observation_dim=24*24, 69 | obs_classes=2048, 70 | act_classes=args.act_classes, 71 | action_dim=4, 72 | n_timesteps=args.n_diffusion_steps, 73 | loss_type=args.loss_type, 74 | clip_denoised=args.clip_denoised, 75 | predict_epsilon=args.predict_epsilon, 76 | ## loss weighting 77 | action_weight=args.action_weight, 78 | loss_weights=args.loss_weights, 79 | loss_discount=args.loss_discount, 80 | device=args.device, 81 | pretrain=args.pretrain, 82 | focal=args.focal, 83 | force=args.force, 84 | ) 85 | if args.pretrain and args.concat: 86 | dataset = dataset() 87 | trainer_config = utils.Config( 88 | utils.MetaworldTrainer, 89 | savepath=(args.savepath, 'trainer_config.pkl'), 90 | train_batch_size=args.batch_size, 91 | train_lr=args.learning_rate, 92 | gradient_accumulate_every=args.gradient_accumulate_every, 93 | ema_decay=args.ema_decay, 94 | sample_freq=args.sample_freq, 95 | save_freq=args.save_freq, 96 | label_freq=int(args.n_train_steps // args.n_saves), 97 | save_parallel=args.save_parallel, 98 | results_folder=args.savepath, 99 | bucket=args.bucket, 100 | n_reference=args.n_reference, 101 | trainer_device=args.device, 102 | horizon=args.horizon, 103 | distributed=False, 104 | pretrain=args.pretrain, 105 | ) 106 | 107 | #-----------------------------------------------------------------------------# 108 | #-------------------------------- instantiate --------------------------------# 109 | #-----------------------------------------------------------------------------# 110 | 111 | model = model_config() 112 | 113 | diffusion = diffusion_config(model) 114 | renderer=None 115 | 116 | trainer = trainer_config(diffusion, dataset, renderer) 117 | 118 | 119 | #-----------------------------------------------------------------------------# 120 | #------------------------ test forward & backward pass -----------------------# 121 | #-----------------------------------------------------------------------------# 122 | 123 | utils.report_parameters(model) 124 | # print('Testing forward...', end=' ', flush=True) 125 | # batch = utils.batchify(dataset[0]) 126 | # #loss, _ = diffusion.loss(*batch, device=args.device) 127 | # #loss.backward() 128 | # print('✓') 129 | 130 | #-----------------------------------------------------------------------------# 131 | #--------------------------------- main loop ---------------------------------# 132 | #-----------------------------------------------------------------------------# 133 | #args.n_train_steps = 5e4 134 | #args.n_steps_per_epoch = 1000 135 | n_epochs = int(args.n_train_steps // args.n_steps_per_epoch) 136 | 137 | for i in range(n_epochs): 138 | print(f'Epoch {i} / {n_epochs} | {args.savepath}') 139 | trainer.train(n_train_steps=args.n_steps_per_epoch) 140 | -------------------------------------------------------------------------------- /scripts/pretrain_video_diff.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 4 | parentdir = 
os.path.dirname(currentdir) 5 | os.sys.path.insert(0, parentdir) 6 | import diffuser.utils as utils 7 | from pathlib import Path 8 | import numpy as np 9 | from numpy.random import randn 10 | import matplotlib as mpl 11 | from scipy import stats 12 | #os.environ['CUDA_VISIBLE_DEVICES'] = '1' 13 | #-----------------------------------------------------------------------------# 14 | #----------------------------------- setup -----------------------------------# 15 | #-----------------------------------------------------------------------------# 16 | 17 | class Parser(utils.Parser): 18 | dataset:str = 'test' 19 | config: str = 'config.locomotion' 20 | 21 | args = Parser().parse_args('diffusion') 22 | #''' 23 | if args.single: 24 | args.tasks = ['close_jar'] 25 | import torch 26 | torch.backends.cudnn.deterministic = True 27 | torch.backends.cudnn.benchmark = False 28 | 29 | dataset_config = utils.Config( 30 | args.loader, 31 | tasks = args.tasks, 32 | savepath=(args.savepath, 'dataset_config.pkl'), 33 | data_folder = args.data_folder, 34 | sequence_length = 8, #args.sequence_length, #TODO 35 | devices = args.device, 36 | horizon = args.horizon, 37 | ) 38 | 39 | #render_config = utils.Config( 40 | # args.renderer, 41 | # savepath=(args.savepath, 'render_config.pkl'), 42 | # env=args.dataset, 43 | #) 44 | 45 | dataset = dataset_config() 46 | 47 | #""" 48 | #renderer = render_config() 49 | #-----------------------------------------------------------------------------# 50 | #------------------------------ model & trainer ------------------------------# 51 | #-----------------------------------------------------------------------------# 52 | #Batch_size x 4 x 24 x 24 53 | 54 | model_config = utils.Config( 55 | args.model, 56 | savepath=(args.savepath, 'model_config.pkl'), 57 | horizon=args.horizon, 58 | transition_dim=24, 59 | obs_cls=2048, 60 | act_cls=args.act_classes, 61 | cond_dim=512, 62 | num_tasks=50, 63 | dim_mults=args.dim_mults, 64 | attention=args.attention, 65 | device=args.device, 66 | train_device=args.device, 67 | verbose=False, 68 | action_dim=7, 69 | vqvae=dataset.vqvae, 70 | pretrain=args.pretrain, 71 | multiview=True, 72 | ) 73 | diffusion_config = utils.Config( 74 | args.diffusion, 75 | savepath=(args.savepath, 'diffusion_config.pkl'), 76 | horizon=args.horizon, #TODO 77 | observation_dim=24*24, 78 | obs_classes=2048, 79 | act_classes=args.act_classes, 80 | action_dim=7, 81 | n_timesteps=args.n_diffusion_steps, 82 | loss_type=args.loss_type, 83 | clip_denoised=args.clip_denoised, 84 | predict_epsilon=args.predict_epsilon, 85 | ## loss weighting 86 | action_weight=args.action_weight, 87 | loss_weights=args.loss_weights, 88 | loss_discount=args.loss_discount, 89 | device=args.device, 90 | pretrain=args.pretrain, 91 | focal=args.focal, 92 | force=args.force, 93 | 94 | ) 95 | if args.pretrain and args.concat: 96 | dataset = dataset() 97 | trainer_config = utils.Config( 98 | utils.MultiviewTrainer, 99 | savepath=(args.savepath, 'trainer_config.pkl'), 100 | train_batch_size=args.batch_size, 101 | train_lr=args.learning_rate, 102 | gradient_accumulate_every=args.gradient_accumulate_every, 103 | ema_decay=args.ema_decay, 104 | sample_freq=args.sample_freq, 105 | save_freq=args.save_freq, 106 | label_freq=int(args.n_train_steps // args.n_saves), 107 | save_parallel=args.save_parallel, 108 | results_folder=args.savepath, 109 | bucket=args.bucket, 110 | n_reference=args.n_reference, 111 | trainer_device=args.device, 112 | horizon=args.horizon, 113 | distributed=False, 114 | 
pretrain=args.pretrain, 115 | ) 116 | 117 | #-----------------------------------------------------------------------------# 118 | #-------------------------------- instantiate --------------------------------# 119 | #-----------------------------------------------------------------------------# 120 | 121 | model = model_config() 122 | 123 | diffusion = diffusion_config(model) 124 | renderer=None 125 | 126 | trainer = trainer_config(diffusion, dataset, renderer) 127 | 128 | 129 | #-----------------------------------------------------------------------------# 130 | #------------------------ test forward & backward pass -----------------------# 131 | #-----------------------------------------------------------------------------# 132 | 133 | utils.report_parameters(model) 134 | # print('Testing forward...', end=' ', flush=True) 135 | # batch = utils.batchify(dataset[0]) 136 | # #loss, _ = diffusion.loss(*batch, device=args.device) 137 | # #loss.backward() 138 | # print('✓') 139 | 140 | #-----------------------------------------------------------------------------# 141 | #--------------------------------- main loop ---------------------------------# 142 | #-----------------------------------------------------------------------------# 143 | #args.n_train_steps = 5e4 144 | #args.n_steps_per_epoch = 1000 145 | n_epochs = int(args.n_train_steps // args.n_steps_per_epoch) 146 | 147 | for i in range(n_epochs): 148 | print(f'Epoch {i} / {n_epochs} | {args.savepath}') 149 | trainer.train(n_train_steps=args.n_steps_per_epoch) 150 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 4 | parentdir = os.path.dirname(currentdir) 5 | os.sys.path.insert(0, parentdir) 6 | import diffuser.utils as utils 7 | from pathlib import Path 8 | import numpy as np 9 | from numpy.random import randn 10 | import matplotlib as mpl 11 | from scipy import stats 12 | import torch 13 | import einops 14 | from diffuser.utils.arrays import batch_to_device 15 | from videogpt.utils import save_video_grid 16 | # -----------------------------------------------------------------------------# 17 | # ----------------------------------- setup -----------------------------------# 18 | # -----------------------------------------------------------------------------# 19 | DTYPE = torch.float 20 | DEVICE = 'cuda' 21 | def cycle(dl): 22 | while True: 23 | for data in dl: 24 | yield data 25 | 26 | def to_torch(x, dtype=None, device=None): 27 | dtype = dtype or DTYPE 28 | device = device or DEVICE 29 | if type(x) is dict: 30 | return {k: to_torch(v, dtype, device) for k, v in x.items()} 31 | elif torch.is_tensor(x): 32 | return x.to(device).type(dtype) 33 | # import pdb; pdb.set_trace() 34 | return torch.tensor(x, dtype=dtype, device=device) 35 | class Parser(utils.Parser): 36 | dataset: str = 'meta' 37 | config: str = 'config.locomotion' 38 | 39 | torch.backends.cudnn.deterministic = True 40 | torch.backends.cudnn.benchmark = False 41 | args = Parser().parse_args('plan') 42 | 43 | # -----------------------------------------------------------------------------# 44 | # ---------------------------------- loading ----------------------------------# 45 | # -----------------------------------------------------------------------------# 46 | 47 | # load diffusion model and value function from disk 48 | 
diffusion_experiment = utils.load_diffusion( 49 | args.loadbase, args.dataset, args.diffusion_loadpath, 50 | epoch=args.diffusion_epoch, device=args.device, seed=args.seed, 51 | ) 52 | 53 | diffusion = diffusion_experiment.ema.to(args.device) 54 | diffusion.clip_denoised = True 55 | dataset = diffusion_experiment.dataset 56 | ## initialize value guide 57 | # value_function = value_experiment.ema 58 | def evaluate(): 59 | # batch = next(iter(diffusion_experiment.renderer)) 60 | # batch = batch_to_device(batch,args.device) 61 | # loss, info = diffusion.loss(*batch) 62 | data = dataset[47000] 63 | #print(dataset.indices[46000]) 64 | #import pdb;pdb.set_trace() 65 | #diffuion = diffusion.to('cpu') 66 | #vqvae = diffusion.model.traj_model.vqvae 67 | 68 | vqvae = dataset.vqvae 69 | #import pdb;pdb.set_trace() 70 | hist = to_torch(data.conditions).unsqueeze(0).long() 71 | hist = einops.rearrange(hist, 'i j h b k c -> (i j) (h b) k c') 72 | print(hist.shape[0]) 73 | video_recon = vqvae.decode(hist) 74 | samples = torch.clamp(video_recon, -0.5, 0.5) + 0.5 75 | save_video_grid(samples.detach(), './samples_origin.mp4') 76 | #import pdb;pdb.set_trace() 77 | obs = diffusion.sample_mask(2, task=to_torch(data.task).unsqueeze(0), x_condition=to_torch(data.conditions))#obs, _ = diffusion.sample_mask(2, task=to_torch(data.task).unsqueeze(0), x_condition=to_torch(data.conditions).unsqueeze(0)) 78 | print(obs.shape) 79 | hist = to_torch(data.trajectories['obs']).unsqueeze(0).long() 80 | hist = einops.rearrange(hist, 'i j h b k c -> (i j) (h b) k c') 81 | gt = hist 82 | print("Ground Truth:", gt) 83 | print("Denoised:", obs) 84 | print("Error:", (obs-gt)**2/np.prod(obs.shape)) 85 | video_recon = vqvae.decode(hist) 86 | samples = torch.clamp(video_recon, -0.5, 0.5) + 0.5 87 | save_video_grid(samples.detach(), './samples_origin_traj.mp4') 88 | #import pdb; pdb.set_trace() 89 | video_recon = vqvae.decode(obs) 90 | samples = torch.clamp(video_recon, -0.5, 0.5) + 0.5 91 | save_video_grid(samples.detach(), './samples_pretrain.mp4') 92 | if __name__ == '__main__': 93 | evaluate() 94 | 95 | -------------------------------------------------------------------------------- /scripts/train_vqvae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pytorch_lightning as pl 3 | from pytorch_lightning.callbacks import ModelCheckpoint 4 | from pytorch_lightning.loggers import TensorBoardLogger 5 | import os 6 | import inspect 7 | currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) 8 | parentdir = os.path.dirname(currentdir) 9 | os.sys.path.insert(0, parentdir) 10 | from videogpt import VQVAE, VideoData 11 | import torch 12 | 13 | def main(): 14 | pl.seed_everything(1234) 15 | 16 | parser = argparse.ArgumentParser() 17 | parser = pl.Trainer.add_argparse_args(parser) 18 | parser = VQVAE.add_model_specific_args(parser) 19 | parser.add_argument('--data_path', type=str, default='./data') 20 | parser.add_argument('--n_codes', type=int, default=2048) 21 | parser.add_argument('--ckpt', type=str, default='./lightning_logs/version_83/checkpoints/last.ckpt') 22 | parser.add_argument('--sequence_length', type=int, default=16) 23 | parser.add_argument('--resolution', type=int, default=128) 24 | parser.add_argument('--batch_size', type=int, default=16)#16 25 | parser.add_argument('--num_workers', type=int, default=0) 26 | parser.add_argument("--save_path", default='./results', type=str, help="path to save checkpoints") 27 | 
parser.add_argument("--save_topk", default=5, type=int, help="save topk checkpoint") 28 | args = parser.parse_args() 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | data = VideoData(args) 32 | # pre-make relevant cached files if necessary 33 | data.train_dataloader() 34 | data.test_dataloader() 35 | model = VQVAE(args) 36 | #from videogpt.download import load_vqvae 37 | 38 | #vqvae = VQVAE.load_from_checkpoint(args.ckpt) 39 | model = model.load_from_checkpoint(args.ckpt) 40 | logger = TensorBoardLogger( 41 | save_dir=args.save_path, 42 | name='log' 43 | ) 44 | callbacks = [] 45 | callbacks.append(ModelCheckpoint(filename=os.path.join(args.save_path, '{val/recon_loss:.4f}'), 46 | save_top_k=args.save_topk, 47 | save_last=True,monitor='val/recon_loss', mode='min')) 48 | 49 | kwargs = dict() 50 | if args.gpus > 1: 51 | kwargs = dict(strategy="ddp", accelerator="gpu", gpus=args.gpus, devices=-1) 52 | trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, 53 | max_steps=200000000, **kwargs) 54 | print("Executing training!") 55 | trainer.fit(model, data) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | 61 | -------------------------------------------------------------------------------- /videogpt/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .vqvae import VQVAE 3 | from .gpt import VideoGPT 4 | from .data import VideoData 5 | from .download import load_vqvae, load_videogpt, load_i3d_pretrained, download 6 | 7 | -------------------------------------------------------------------------------- /videogpt/download.py: -------------------------------------------------------------------------------- 1 | import requests 2 | from tqdm import tqdm 3 | import os 4 | import gdown 5 | import torch 6 | 7 | from .vqvae import VQVAE 8 | from .gpt import VideoGPT 9 | 10 | 11 | def download(id, fname, root=os.path.expanduser('~/.cache/videogpt')): 12 | os.makedirs(root, exist_ok=True) 13 | destination = os.path.join(root, fname) 14 | 15 | if os.path.exists(destination): 16 | return destination 17 | 18 | gdown.download(id=id, output=destination, quiet=False) 19 | return destination 20 | 21 | 22 | _VQVAE = { 23 | 'bair_stride4x2x2': '1iIAYJ2Qqrx5Q94s5eIXQYJgAydzvT_8L', # trained on 16 frames of 64 x 64 images 24 | 'ucf101_stride4x4x4': '1uuB_8WzHP_bbBmfuaIV7PK_Itl3DyHY5', # trained on 16 frames of 128 x 128 images 25 | 'kinetics_stride4x4x4': '1DOvOZnFAIQmux6hG7pN_HkyJZy3lXbCB', # trained on 16 frames of 128 x 128 images 26 | 'kinetics_stride2x4x4': '1jvtjjtrtE4cy6pl7DK_zWFEPY3RZt2pB' # trained on 16 frames of 128 x 128 images 27 | } 28 | 29 | def load_vqvae(model_name, device=torch.device('cpu'), root=os.path.expanduser('~/.cache/videogpt')): 30 | assert model_name in _VQVAE, f"Invalid model_name: {model_name}" 31 | filepath = download(_VQVAE[model_name], model_name, root=root) 32 | vqvae = VQVAE.load_from_checkpoint(filepath).to(device) 33 | vqvae.eval() 34 | 35 | return vqvae 36 | 37 | 38 | _VIDEOGPT = { 39 | 'bair_gpt': '1fNTtJAgO6grEtPNrufkpbee1CfGztW-1', # 1-frame conditional, 16 frames of 64 x 64 images 40 | 'ucf101_uncond_gpt': '1QkF_Sb2XVRgSbFT_SxQ6aZUeDFoliPQq', # unconditional, 16 frames of 128 x 128 images 41 | } 42 | 43 | def load_videogpt(model_name, device=torch.device('cpu')): 44 | assert model_name in _VIDEOGPT, f"Invalid model_name: {model_name}" 45 | filepath = download(_VIDEOGPT[model_name], model_name) 46 | gpt = VideoGPT.load_from_checkpoint(filepath).to(device) 47 
| gpt.eval() 48 | 49 | return gpt 50 | 51 | 52 | _I3D_PRETRAINED_ID = '1mQK8KD8G6UWRa5t87SRMm5PVXtlpneJT' 53 | 54 | def load_i3d_pretrained(device=torch.device('cpu')): 55 | from .fvd.pytorch_i3d import InceptionI3d 56 | i3d = InceptionI3d(400, in_channels=3).to(device) 57 | filepath = download(_I3D_PRETRAINED_ID, 'i3d_pretrained_400.pt') 58 | i3d.load_state_dict(torch.load(filepath, map_location=device)) 59 | i3d.eval() 60 | return i3d 61 | -------------------------------------------------------------------------------- /videogpt/fvd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tinnerhrhe/VPDD/4dc21c244420043c31423c1c5161a4728f5ca5c0/videogpt/fvd/__init__.py -------------------------------------------------------------------------------- /videogpt/fvd/convert_tf_pretrained.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from collections import OrderedDict 3 | import tensorflow_hub as hub 4 | import torch 5 | 6 | from src_pytorch.fvd.pytorch_i3d import InceptionI3d 7 | 8 | 9 | def convert_name(name): 10 | mapping = { 11 | 'conv_3d': 'conv3d', 12 | 'batch_norm': 'bn', 13 | 'w:0': 'weight', 14 | 'b:0': 'bias', 15 | 'moving_mean:0': 'running_mean', 16 | 'moving_variance:0': 'running_var', 17 | 'beta:0': 'bias' 18 | } 19 | 20 | segs = name.split('/') 21 | new_segs = [] 22 | i = 0 23 | while i < len(segs): 24 | seg = segs[i] 25 | if 'Mixed' in seg: 26 | new_segs.append(seg) 27 | elif 'Conv' in seg and 'Mixed' not in name: 28 | new_segs.append(seg) 29 | elif 'Branch' in seg: 30 | branch_i = int(seg.split('_')[-1]) 31 | i += 1 32 | seg = segs[i] 33 | 34 | # special case due to typo in original code 35 | if 'Mixed_5b' in name and branch_i == 2: 36 | if '1x1' in seg: 37 | new_segs.append(f'b{branch_i}a') 38 | elif '3x3' in seg: 39 | new_segs.append(f'b{branch_i}b') 40 | else: 41 | raise Exception() 42 | # Either Conv3d_{i}a_... or Conv3d_{i}b_... 
43 | elif 'a' in seg: 44 | if branch_i == 0: 45 | new_segs.append('b0') 46 | else: 47 | new_segs.append(f'b{branch_i}a') 48 | elif 'b' in seg: 49 | new_segs.append(f'b{branch_i}b') 50 | else: 51 | raise Exception 52 | elif seg == 'Logits': 53 | new_segs.append('logits') 54 | i += 1 55 | elif seg in mapping: 56 | new_segs.append(mapping[seg]) 57 | else: 58 | raise Exception(f"No match found for seg {seg} in name {name}") 59 | 60 | i += 1 61 | return '.'.join(new_segs) 62 | 63 | def convert_tensor(tensor): 64 | tensor_dim = len(tensor.shape) 65 | if tensor_dim == 5: # conv or bn 66 | if all([t == 1 for t in tensor.shape[:-1]]): 67 | tensor = tensor.squeeze() 68 | else: 69 | tensor = tensor.permute(4, 3, 0, 1, 2).contiguous() 70 | elif tensor_dim == 1: # conv bias 71 | pass 72 | else: 73 | raise Exception(f"Invalid shape {tensor.shape}") 74 | return tensor 75 | 76 | n_class = int(sys.argv[1]) # 600 or 400 77 | assert n_class in [400, 600] 78 | 79 | # Converts model from https://github.com/google-research/google-research/tree/master/frechet_video_distance 80 | # to pytorch version for loading 81 | model_url = f"https://tfhub.dev/deepmind/i3d-kinetics-{n_class}/1" 82 | i3d = hub.load(model_url) 83 | name_prefix = 'RGB/inception_i3d/' 84 | 85 | print('Creating state_dict...') 86 | all_names = [] 87 | state_dict = OrderedDict() 88 | for var in i3d.variables: 89 | name = var.name[len(name_prefix):] 90 | new_name = convert_name(name) 91 | all_names.append(new_name) 92 | 93 | tensor = torch.FloatTensor(var.value().numpy()) 94 | new_tensor = convert_tensor(tensor) 95 | 96 | state_dict[new_name] = new_tensor 97 | 98 | if 'bn.bias' in new_name: 99 | new_name = new_name[:-4] + 'weight' # bn.weight 100 | new_tensor = torch.ones_like(new_tensor).float() 101 | state_dict[new_name] = new_tensor 102 | 103 | print(f'Complete state_dict with {len(state_dict)} entries') 104 | 105 | s = dict() 106 | for i, n in enumerate(all_names): 107 | s[n] = s.get(n, []) + [i] 108 | 109 | for k, v in s.items(): 110 | if len(v) > 1: 111 | print('dup', k) 112 | for i in v: 113 | print('\t', i3d.variables[i].name) 114 | 115 | print('Testing load_state_dict...') 116 | print('Creating model...') 117 | 118 | i3d = InceptionI3d(n_class, in_channels=3) 119 | 120 | print('Loading state_dict...') 121 | i3d.load_state_dict(state_dict) 122 | 123 | print(f'Saving state_dict as fvd/i3d_pretrained_{n_class}.pt') 124 | torch.save(state_dict, f'fvd/i3d_pretrained_{n_class}.pt') 125 | 126 | print('Done') 127 | 128 | -------------------------------------------------------------------------------- /videogpt/fvd/fvd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ..data import preprocess as preprocess_single 3 | 4 | 5 | def preprocess(videos, target_resolution=224): 6 | # videos in {0, ..., 255} as np.uint8 array 7 | b, t, h, w, c = videos.shape 8 | videos = torch.from_numpy(videos) 9 | videos = torch.stack([preprocess_single(video, target_resolution) for video in videos]) 10 | return videos * 2 # [-0.5, 0.5] -> [-1, 1] 11 | 12 | def get_fvd_logits(videos, i3d, device): 13 | videos = preprocess(videos) 14 | embeddings = get_logits(i3d, videos, device) 15 | return embeddings 16 | 17 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161 18 | def _symmetric_matrix_square_root(mat, eps=1e-10): 19 | u, s, v = torch.svd(mat) 20 | si = torch.where(s < eps, s, torch.sqrt(s)) 21 | return 
torch.matmul(torch.matmul(u, torch.diag(si)), v.t()) 22 | 23 | # https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400 24 | def trace_sqrt_product(sigma, sigma_v): 25 | sqrt_sigma = _symmetric_matrix_square_root(sigma) 26 | sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma)) 27 | return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) 28 | 29 | # https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2 30 | def cov(m, rowvar=False): 31 | '''Estimate a covariance matrix given data. 32 | 33 | Covariance indicates the level to which two variables vary together. 34 | If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`, 35 | then the covariance matrix element `C_{ij}` is the covariance of 36 | `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`. 37 | 38 | Args: 39 | m: A 1-D or 2-D array containing multiple variables and observations. 40 | Each row of `m` represents a variable, and each column a single 41 | observation of all those variables. 42 | rowvar: If `rowvar` is True, then each row represents a 43 | variable, with observations in the columns. Otherwise, the 44 | relationship is transposed: each column represents a variable, 45 | while the rows contain observations. 46 | 47 | Returns: 48 | The covariance matrix of the variables. 49 | ''' 50 | if m.dim() > 2: 51 | raise ValueError('m has more than 2 dimensions') 52 | if m.dim() < 2: 53 | m = m.view(1, -1) 54 | if not rowvar and m.size(0) != 1: 55 | m = m.t() 56 | 57 | fact = 1.0 / (m.size(1) - 1) # unbiased estimate 58 | m -= torch.mean(m, dim=1, keepdim=True) 59 | mt = m.t() # if complex: mt = m.t().conj() 60 | return fact * m.matmul(mt).squeeze() 61 | 62 | 63 | def frechet_distance(x1, x2): 64 | x1 = x1.flatten(start_dim=1) 65 | x2 = x2.flatten(start_dim=1) 66 | m, m_w = x1.mean(dim=0), x2.mean(dim=0) 67 | sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False) 68 | 69 | sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) 70 | trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component 71 | 72 | mean = torch.sum((m - m_w) ** 2) 73 | fd = trace + mean 74 | return fd 75 | 76 | 77 | def get_logits(i3d, videos, device): 78 | assert videos.shape[0] % 16 == 0 79 | with torch.no_grad(): 80 | logits = [] 81 | for i in range(0, videos.shape[0], 16): 82 | batch = videos[i:i + 16].to(device) 83 | logits.append(i3d(batch)) 84 | logits = torch.cat(logits, dim=0) 85 | return logits 86 | -------------------------------------------------------------------------------- /videogpt/gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import itertools 3 | import numpy as np 4 | from tqdm import tqdm 5 | import argparse 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim.lr_scheduler as lr_scheduler 11 | import pytorch_lightning as pl 12 | 13 | from .resnet import resnet34 14 | from .attention import AttentionStack, LayerNorm, AddBroadcastPosEmbed 15 | from .utils import shift_dim 16 | 17 | 18 | class VideoGPT(pl.LightningModule): 19 | def __init__(self, args): 20 | super().__init__() 21 | self.args = args 22 | 23 | # Load VQ-VAE and set all parameters to no grad 24 | from .vqvae import VQVAE 25 | from .download import load_vqvae 26 | if not os.path.exists(args.vqvae): 27 | self.vqvae = load_vqvae(args.vqvae) 28 | else: 29 | self.vqvae = VQVAE.load_from_checkpoint(args.vqvae) 30 | 
for p in self.vqvae.parameters(): 31 | p.requires_grad = False 32 | self.vqvae.codebook._need_init = False 33 | self.vqvae.eval() 34 | 35 | # ResNet34 for frame conditioning 36 | self.use_frame_cond = args.n_cond_frames > 0 37 | if self.use_frame_cond: 38 | frame_cond_shape = (args.n_cond_frames, 39 | args.resolution // 4, 40 | args.resolution // 4, 41 | 240) 42 | self.resnet = resnet34(1, (1, 4, 4), resnet_dim=240) 43 | self.cond_pos_embd = AddBroadcastPosEmbed( 44 | shape=frame_cond_shape[:-1], embd_dim=frame_cond_shape[-1] 45 | ) 46 | else: 47 | frame_cond_shape = None 48 | 49 | # VideoGPT transformer 50 | self.shape = self.vqvae.latent_shape 51 | 52 | self.fc_in = nn.Linear(self.vqvae.embedding_dim, args.hidden_dim, bias=False) 53 | self.fc_in.weight.data.normal_(std=0.02) 54 | 55 | self.attn_stack = AttentionStack( 56 | self.shape, args.hidden_dim, args.heads, args.layers, args.dropout, 57 | args.attn_type, args.attn_dropout, args.class_cond_dim, frame_cond_shape 58 | ) 59 | 60 | self.norm = LayerNorm(args.hidden_dim, args.class_cond_dim) 61 | 62 | self.fc_out = nn.Linear(args.hidden_dim, self.vqvae.n_codes, bias=False) 63 | self.fc_out.weight.data.copy_(torch.zeros(self.vqvae.n_codes, args.hidden_dim)) 64 | 65 | # caches for faster decoding (if necessary) 66 | self.frame_cond_cache = None 67 | 68 | self.save_hyperparameters() 69 | 70 | def get_reconstruction(self, videos): 71 | return self.vqvae.decode(self.vqvae.encode(videos)) 72 | 73 | def sample(self, n, batch=None): 74 | device = self.fc_in.weight.device 75 | 76 | cond = dict() 77 | if self.use_frame_cond or self.args.class_cond: 78 | assert batch is not None 79 | video = batch['video'] 80 | 81 | if self.args.class_cond: 82 | label = batch['label'] 83 | cond['class_cond'] = F.one_hot(label, self.args.class_cond_dim).type_as(video) 84 | if self.use_frame_cond: 85 | cond['frame_cond'] = video[:, :, :self.args.n_cond_frames] 86 | 87 | samples = torch.zeros((n,) + self.shape).long().to(device) 88 | idxs = list(itertools.product(*[range(s) for s in self.shape])) 89 | 90 | with torch.no_grad(): 91 | prev_idx = None 92 | for i, idx in enumerate(tqdm(idxs)): 93 | batch_idx_slice = (slice(None, None), *[slice(i, i + 1) for i in idx]) 94 | batch_idx = (slice(None, None), *idx) 95 | embeddings = self.vqvae.codebook.dictionary_lookup(samples) 96 | 97 | if prev_idx is None: 98 | # set arbitrary input values for the first token 99 | # does not matter what value since it will be shifted anyways 100 | embeddings_slice = embeddings[batch_idx_slice] 101 | samples_slice = samples[batch_idx_slice] 102 | else: 103 | embeddings_slice = embeddings[prev_idx] 104 | samples_slice = samples[prev_idx] 105 | 106 | logits = self(embeddings_slice, samples_slice, cond, 107 | decode_step=i, decode_idx=idx)[1] 108 | # squeeze all possible dim except batch dimension 109 | logits = logits.squeeze().unsqueeze(0) if logits.shape[0] == 1 else logits.squeeze() 110 | probs = F.softmax(logits, dim=-1) 111 | samples[batch_idx] = torch.multinomial(probs, 1).squeeze(-1) 112 | 113 | prev_idx = batch_idx_slice 114 | samples = self.vqvae.decode(samples) 115 | samples = torch.clamp(samples, -0.5, 0.5) + 0.5 116 | 117 | return samples # BCTHW in [0, 1] 118 | 119 | 120 | def forward(self, x, targets, cond, decode_step=None, decode_idx=None): 121 | if self.use_frame_cond: 122 | if decode_step is None: 123 | cond['frame_cond'] = self.cond_pos_embd(self.resnet(cond['frame_cond'])) 124 | elif decode_step == 0: 125 | self.frame_cond_cache = 
self.cond_pos_embd(self.resnet(cond['frame_cond'])) 126 | cond['frame_cond'] = self.frame_cond_cache 127 | else: 128 | cond['frame_cond'] = self.frame_cond_cache 129 | 130 | h = self.fc_in(x) 131 | h = self.attn_stack(h, cond, decode_step, decode_idx) 132 | h = self.norm(h, cond) 133 | logits = self.fc_out(h) 134 | 135 | loss = F.cross_entropy(shift_dim(logits, -1, 1), targets) 136 | 137 | return loss, logits 138 | 139 | def training_step(self, batch, batch_idx): 140 | self.vqvae.eval() 141 | x = batch['video'] 142 | 143 | cond = dict() 144 | if self.args.class_cond: 145 | label = batch['label'] 146 | cond['class_cond'] = F.one_hot(label, self.args.class_cond_dim).type_as(x) 147 | if self.use_frame_cond: 148 | cond['frame_cond'] = x[:, :, :self.args.n_cond_frames] 149 | 150 | with torch.no_grad(): 151 | targets, x = self.vqvae.encode(x, include_embeddings=True) 152 | x = shift_dim(x, 1, -1) 153 | 154 | loss, _ = self(x, targets, cond) 155 | return loss 156 | 157 | def validation_step(self, batch, batch_idx): 158 | loss = self.training_step(batch, batch_idx) 159 | self.log('val/loss', loss, prog_bar=True) 160 | 161 | def configure_optimizers(self): 162 | optimizer = torch.optim.Adam(self.parameters(), lr=3e-4, betas=(0.9, 0.999)) 163 | assert hasattr(self.args, 'max_steps') and self.args.max_steps is not None, f"Must set max_steps argument" 164 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, self.args.max_steps) 165 | return [optimizer], [dict(scheduler=scheduler, interval='step', frequency=1)] 166 | 167 | 168 | @staticmethod 169 | def add_model_specific_args(parent_parser): 170 | parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False) 171 | parser.add_argument('--vqvae', type=str, default='kinetics_stride4x4x4', 172 | help='path to vqvae ckpt, or model name to download pretrained') 173 | parser.add_argument('--n_cond_frames', type=int, default=0) 174 | parser.add_argument('--class_cond', action='store_true') 175 | 176 | # VideoGPT hyperparmeters 177 | parser.add_argument('--hidden_dim', type=int, default=576) 178 | parser.add_argument('--heads', type=int, default=4) 179 | parser.add_argument('--layers', type=int, default=8) 180 | parser.add_argument('--dropout', type=float, default=0.2) 181 | parser.add_argument('--attn_type', type=str, default='full', 182 | choices=['full', 'sparse']) 183 | parser.add_argument('--attn_dropout', type=float, default=0.3) 184 | 185 | return parser 186 | -------------------------------------------------------------------------------- /videogpt/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .utils import shift_dim 9 | 10 | class ChannelLayerNorm(nn.Module): 11 | # layer norm on channels 12 | def __init__(self, in_features): 13 | super().__init__() 14 | self.norm = nn.LayerNorm(in_features) 15 | 16 | def forward(self, x): 17 | x = shift_dim(x, 1, -1) 18 | x = self.norm(x) 19 | x = shift_dim(x, -1, 1) 20 | return x 21 | 22 | 23 | class NormReLU(nn.Module): 24 | 25 | def __init__(self, channels, relu=True, affine=True): 26 | super().__init__() 27 | 28 | self.relu = relu 29 | self.norm = ChannelLayerNorm(channels) 30 | 31 | def forward(self, x): 32 | x_float = x.float() 33 | x_float = self.norm(x_float) 34 | x = x_float.type_as(x) 35 | if self.relu: 36 | x = F.relu(x, inplace=True) 37 | return x 38 | 39 | 40 | class ResidualBlock(nn.Module): 41 | 42 | def 
__init__(self, in_channels, filters, stride, use_projection=False): 43 | super().__init__() 44 | 45 | if use_projection: 46 | self.proj_conv = nn.Conv3d(in_channels, filters, kernel_size=1, 47 | stride=stride, bias=False) 48 | self.proj_bnr = NormReLU(filters, relu=False) 49 | 50 | self.conv1 = nn.Conv3d(in_channels, filters, kernel_size=3, 51 | stride=stride, bias=False, padding=1) 52 | self.bnr1 = NormReLU(filters) 53 | 54 | self.conv2 = nn.Conv3d(filters, filters, kernel_size=3, 55 | stride=1, bias=False, padding=1) 56 | self.bnr2 = NormReLU(filters) 57 | 58 | self.use_projection = use_projection 59 | 60 | def forward(self, x): 61 | shortcut = x 62 | if self.use_projection: 63 | shortcut = self.proj_bnr(self.proj_conv(x)) 64 | x = self.bnr1(self.conv1(x)) 65 | x = self.bnr2(self.conv2(x)) 66 | 67 | return F.relu(x + shortcut, inplace=True) 68 | 69 | class BlockGroup(nn.Module): 70 | 71 | def __init__(self, in_channels, filters, blocks, stride): 72 | super().__init__() 73 | 74 | self.start_block = ResidualBlock(in_channels, filters, stride, use_projection=True) 75 | in_channels = filters 76 | 77 | self.blocks = [] 78 | for _ in range(1, blocks): 79 | self.blocks.append(ResidualBlock(in_channels, filters, 1)) 80 | self.blocks = nn.Sequential(*self.blocks) 81 | 82 | def forward(self, x): 83 | x = self.start_block(x) 84 | x = self.blocks(x) 85 | return x 86 | 87 | 88 | class ResNet(nn.Module): 89 | 90 | def __init__(self, in_channels, layers, width_multiplier, 91 | stride, resnet_dim=240, cifar_stem=True): 92 | super().__init__() 93 | self.width_multiplier = width_multiplier 94 | self.resnet_dim = resnet_dim 95 | 96 | assert all([int(math.log2(d)) == math.log2(d) for d in stride]), stride 97 | n_times_downsample = np.array([int(math.log2(d)) for d in stride]) 98 | 99 | if cifar_stem: 100 | self.stem = nn.Sequential( 101 | nn.Conv3d(in_channels, 64 * width_multiplier, 102 | kernel_size=3, padding=1, bias=False), 103 | NormReLU(64 * width_multiplier) 104 | ) 105 | else: 106 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) 107 | n_times_downsample -= 1 # conv 108 | n_times_downsample[-2:] = n_times_downsample[-2:] - 1 # pooling 109 | self.stem = nn.Sequential( 110 | nn.Conv3d(in_channels, 64 * width_multiplier, 111 | kernel_size=7, stride=stride, bias=False, 112 | padding=3), 113 | NormReLU(64 * width_multiplier), 114 | nn.MaxPool3d(kernel_size=3, stride=(1, 2, 2), padding=1) 115 | ) 116 | 117 | self.group1 = BlockGroup(64 * width_multiplier, 64 * width_multiplier, 118 | blocks=layers[0], stride=1) 119 | 120 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) 121 | n_times_downsample -= 1 122 | self.group2 = BlockGroup(64 * width_multiplier, 128 * width_multiplier, 123 | blocks=layers[1], stride=stride) 124 | 125 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) 126 | n_times_downsample -= 1 127 | self.group3 = BlockGroup(128 * width_multiplier, 256 * width_multiplier, 128 | blocks=layers[2], stride=stride) 129 | 130 | stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) 131 | n_times_downsample -= 1 132 | self.group4 = BlockGroup(256 * width_multiplier, resnet_dim, 133 | blocks=layers[3], stride=stride) 134 | assert all([d <= 0 for d in n_times_downsample]), f'final downsample {n_times_downsample}' 135 | 136 | def forward(self, x): 137 | x = self.stem(x) 138 | x = self.group1(x) 139 | x = self.group2(x) 140 | x = self.group3(x) 141 | x = self.group4(x) 142 | x = shift_dim(x, 1, -1) 143 | 144 | return x 145 | 146 | 147 | def 
resnet34(width_multiplier, stride, cifar_stem=True, resnet_dim=240): 148 | return ResNet(3, [3, 4, 6, 3], width_multiplier, 149 | stride, cifar_stem=cifar_stem, resnet_dim=resnet_dim) 150 | -------------------------------------------------------------------------------- /videogpt/utils.py: -------------------------------------------------------------------------------- 1 | # Shifts src_tf dim to dest dim 2 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) 3 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 4 | n_dims = len(x.shape) 5 | if src_dim < 0: 6 | src_dim = n_dims + src_dim 7 | if dest_dim < 0: 8 | dest_dim = n_dims + dest_dim 9 | 10 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 11 | 12 | dims = list(range(n_dims)) 13 | del dims[src_dim] 14 | 15 | permutation = [] 16 | ctr = 0 17 | for i in range(n_dims): 18 | if i == dest_dim: 19 | permutation.append(src_dim) 20 | else: 21 | permutation.append(dims[ctr]) 22 | ctr += 1 23 | x = x.permute(permutation) 24 | if make_contiguous: 25 | x = x.contiguous() 26 | return x 27 | 28 | # reshapes tensor start from dim i (inclusive) 29 | # to dim j (exclusive) to the desired shape 30 | # e.g. if x.shape = (b, thw, c) then 31 | # view_range(x, 1, 2, (t, h, w)) returns 32 | # x of shape (b, t, h, w, c) 33 | def view_range(x, i, j, shape): 34 | shape = tuple(shape) 35 | 36 | n_dims = len(x.shape) 37 | if i < 0: 38 | i = n_dims + i 39 | 40 | if j is None: 41 | j = n_dims 42 | elif j < 0: 43 | j = n_dims + j 44 | 45 | assert 0 <= i < j <= n_dims 46 | 47 | x_shape = x.shape 48 | target_shape = x_shape[:i] + shape + x_shape[j:] 49 | return x.view(target_shape) 50 | 51 | 52 | def tensor_slice(x, begin, size): 53 | assert all([b >= 0 for b in begin]) 54 | size = [l - b if s == -1 else s 55 | for s, b, l in zip(size, begin, x.shape)] 56 | assert all([s >= 0 for s in size]) 57 | 58 | slices = [slice(b, b + s) for b, s in zip(begin, size)] 59 | return x[slices] 60 | 61 | 62 | import math 63 | import numpy as np 64 | import skvideo.io 65 | def save_video_grid(video, fname, nrow=None): 66 | b, c, t, h, w = video.shape 67 | video = video.permute(0, 2, 3, 4, 1) 68 | video = (video.cpu().numpy() * 255).astype('uint8') 69 | 70 | if nrow is None: 71 | nrow = math.ceil(math.sqrt(b)) 72 | ncol = math.ceil(b / nrow) 73 | padding = 1 74 | video_grid = np.zeros((t, (padding + h) * nrow + padding, 75 | (padding + w) * ncol + padding, c), dtype='uint8') 76 | for i in range(b): 77 | r = i // ncol 78 | c = i % ncol 79 | 80 | start_r = (padding + h) * r 81 | start_c = (padding + w) * c 82 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 83 | 84 | skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) 85 | print('saved videos to', fname) 86 | 87 | 88 |
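
The exports from `videogpt/__init__.py` are enough to round-trip a clip through one of the pretrained VQ-VAEs. Below is a minimal sketch, not part of the repository: the `B x C x T x H x W` layout and the `[-0.5, 0.5]` pixel range are assumptions taken from the clamp in `VideoGPT.sample`, the comment in `fvd.py`'s `preprocess`, and the clip sizes listed in the `_VQVAE` table of `videogpt/download.py`.

```
# Hedged usage sketch (not from this repository): round-trip a dummy clip
# through a pretrained VQ-VAE and write the reconstruction to disk.
import torch
from videogpt import load_vqvae
from videogpt.utils import save_video_grid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 'kinetics_stride4x4x4' was trained on 16 frames of 128 x 128 images (see _VQVAE);
# load_vqvae downloads the checkpoint to ~/.cache/videogpt on first use.
vqvae = load_vqvae('kinetics_stride4x4x4', device=device)

# Assumption: clips are B x C x T x H x W with pixel values in [-0.5, 0.5].
video = torch.rand(1, 3, 16, 128, 128, device=device) - 0.5

with torch.no_grad():
    encodings = vqvae.encode(video)   # discrete latent codes
    recon = vqvae.decode(encodings)   # reconstruction, roughly in [-0.5, 0.5]

# save_video_grid expects B x C x T x H x W in [0, 1]
save_video_grid(torch.clamp(recon, -0.5, 0.5) + 0.5, 'vqvae_recon.mp4')
```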
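In the same spirit, a hedged sketch of drawing frame-conditioned samples from a pretrained VideoGPT. The conditioning format (a `batch` dict with a `'video'` entry in `B x C x T x H x W`, `[-0.5, 0.5]`) is inferred from `VideoGPT.sample` and `training_step` above; the output filename is purely illustrative.

```
# Hedged sampling sketch (not from this repository): draw conditional samples
# from the pretrained 'bair_gpt' model (1-frame conditional, 16 frames of 64 x 64).
import torch
from videogpt import load_videogpt
from videogpt.utils import save_video_grid

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gpt = load_videogpt('bair_gpt', device=device)

# Assumption: conditioning clips follow the B x C x T x H x W, [-0.5, 0.5] convention;
# sample() only reads the first n_cond_frames frames of batch['video'].
batch = {'video': torch.rand(4, 3, 16, 64, 64, device=device) - 0.5}

samples = gpt.sample(4, batch)   # B x C x T x H x W in [0, 1], per the comment in sample()
save_video_grid(samples, 'samples.mp4')
```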
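Finally, the helpers in `videogpt/fvd/fvd.py` can score generated clips against real ones. A minimal sketch, assuming clips are supplied as `uint8` NumPy arrays of shape `(B, T, H, W, C)` in `{0, ..., 255}` (per the comment in `preprocess`) with a batch size divisible by 16 (per the assert in `get_logits`); the random arrays below are stand-ins for real and generated videos.

```
# Hedged evaluation sketch (not from this repository): compute FVD between
# two sets of clips using get_fvd_logits and frechet_distance defined above.
import numpy as np
import torch
from videogpt import load_i3d_pretrained
from videogpt.fvd.fvd import get_fvd_logits, frechet_distance

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
i3d = load_i3d_pretrained(device)  # downloads i3d_pretrained_400.pt on first use

# Dummy stand-ins; in practice these would be real and generated videos.
real_videos = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)
fake_videos = np.random.randint(0, 256, size=(16, 16, 64, 64, 3), dtype=np.uint8)

real_embeddings = get_fvd_logits(real_videos, i3d, device)
fake_embeddings = get_fvd_logits(fake_videos, i3d, device)
print('FVD:', frechet_distance(fake_embeddings, real_embeddings).item())
```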