├── .gitignore ├── LICENSE ├── README.md ├── VILP ├── config │ ├── save_vilp_pushT_state_planning.yaml │ ├── train_vilp_can_agent_planning.yaml │ ├── train_vilp_can_agent_state_planning.yaml │ ├── train_vilp_can_policy.yaml │ ├── train_vilp_can_state_policy.yaml │ ├── train_vilp_can_wrist_planning.yaml │ ├── train_vilp_can_wrist_state_planning.yaml │ ├── train_vilp_pushT_planning.yaml │ ├── train_vilp_pushT_policy.yaml │ ├── train_vilp_pushT_state_planning.yaml │ ├── train_vilp_pushT_state_policy.yaml │ ├── train_vq_can_agent.yaml │ ├── train_vq_can_wrist.yaml │ └── train_vq_pushT.yaml ├── dataset │ ├── push_multi_image_ae_dataset.py │ ├── pusht_image_dataset.py │ ├── robomimic_ae_dataset.py │ └── robomimic_vilp_dataset.py ├── env │ ├── push_multi │ │ └── push_multi_env.py │ └── pusht │ │ ├── __init__.py │ │ ├── pusht_collect_env.py │ │ ├── pusht_env.py │ │ ├── pusht_image_env.py │ │ └── pusht_keypoints_env.py ├── env_runner │ ├── pusht_vilp_runner.py │ ├── robomimic_vilp_runner.py │ └── utils.py ├── flowdiffusion │ ├── datasets.py │ ├── goal_diffusion.py │ ├── guided_diffusion │ │ └── guided_diffusion │ │ │ ├── dist_util.py │ │ │ ├── fp16_util.py │ │ │ ├── gaussian_diffusion.py │ │ │ ├── image_datasets.py │ │ │ ├── imagen.py │ │ │ ├── logger.py │ │ │ ├── losses.py │ │ │ ├── nn.py │ │ │ ├── resample.py │ │ │ ├── respace.py │ │ │ ├── script_util.py │ │ │ ├── train_util.py │ │ │ └── unet.py │ ├── model │ │ ├── __init__.py │ │ ├── attention_processor.py │ │ ├── imagen.py │ │ ├── myunet.py │ │ ├── resnet.py │ │ ├── transformer_temporal.py │ │ ├── unet_3d_blocks.py │ │ └── unet_3d_condition.py │ ├── train_bridge.py │ ├── train_mw.py │ ├── train_thor.py │ ├── unet.py │ └── utils.py ├── model │ ├── conditional_unet3d.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ ├── bsrgan.py │ │ │ ├── bsrgan_light.py │ │ │ ├── utils │ │ │ │ └── test.png │ │ │ └── utils_image.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ ├── normalizer.py │ ├── openaimodel.py │ ├── spatial_mask_generator_3d.py │ └── vision │ │ ├── crop_randomizer.py │ │ ├── model_getter.py │ │ └── multi_image_obs_encoder.py ├── policy │ ├── action_mapping_itp.py │ ├── diffusion_unet_image_policy.py │ ├── latent_video_diffusion.py │ ├── latent_video_diffusion_conditional_encoding.py │ ├── utils.py │ ├── vilp_low_level_policy.py │ ├── vilp_low_level_policy_cond_encode.py │ ├── vilp_planning.py │ └── vilp_planning_conditional_encoding.py ├── taming │ ├── data │ │ ├── ade20k.py │ │ ├── annotated_objects_coco.py │ │ ├── annotated_objects_dataset.py │ │ ├── annotated_objects_open_images.py │ │ ├── base.py │ │ ├── coco.py │ │ ├── conditional_builder │ │ │ ├── objects_bbox.py │ │ │ ├── objects_center_points.py │ │ │ └── utils.py │ │ ├── custom.py │ │ ├── faceshq.py │ │ ├── helper_types.py │ │ ├── image_transforms.py │ │ ├── imagenet.py │ │ ├── open_images_helper.py │ │ ├── sflckr.py │ │ └── utils.py │ ├── lr_scheduler.py │ ├── models │ │ ├── cond_transformer.py │ │ ├── dummy_cond_stage.py │ │ └── vqgan.py │ ├── modules │ │ ├── diffusionmodules │ │ │ └── model.py │ │ ├── discriminator │ │ │ └── model.py │ │ ├── losses │ │ │ ├── lpips.py │ │ │ └── vqperceptual.py │ │ ├── misc │ │ │ └── 
coord.py │ │ ├── transformer │ │ │ ├── mingpt.py │ │ │ └── permuter.py │ │ ├── util.py │ │ └── vqvae │ │ │ └── quantize.py │ └── util.py └── workspace │ ├── save_vilp_planning_workspace.py │ ├── train_vilp_planning_workspace.py │ ├── train_vilp_policy_workspace.py │ ├── train_vqgan_workspace.py │ └── utils.py ├── conda_environment.yaml ├── images_to_replybuffer.py ├── install_custom_packages.sh ├── setup.py ├── teasers └── teaser.gif └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | bin 2 | logs 3 | wandb 4 | outputs 5 | .vscode 6 | _wandb 7 | /data 8 | **/.DS_Store 9 | /outputs 10 | fuse.cfg 11 | images 12 | *.ai 13 | img 14 | img_log 15 | /lightning_logs 16 | /log_images 17 | /outputs 18 | /taming 19 | /vq_models 20 | # Generation results 21 | results/ 22 | 2pi_vis/ 23 | ray/auth.json 24 | *.ckpt 25 | # Byte-compiled / optimized / DLL files 26 | __pycache__/ 27 | *.py[cod] 28 | *$py.class 29 | *.pth 30 | # C extensions 31 | *.so 32 | 33 | # Distribution / packaging 34 | .Python 35 | build/ 36 | develop-eggs/ 37 | dist/ 38 | downloads/ 39 | eggs/ 40 | .eggs/ 41 | lib/ 42 | lib64/ 43 | parts/ 44 | sdist/ 45 | var/ 46 | wheels/ 47 | pip-wheel-metadata/ 48 | share/python-wheels/ 49 | /*.egg-info 50 | *.egg-info/ 51 | VILP.egg-info 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .nox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | *.py,cover 76 | .hypothesis/ 77 | .pytest_cache/ 78 | 79 | # Translations 80 | *.mo 81 | *.pot 82 | 83 | # Django stuff: 84 | *.log 85 | local_settings.py 86 | db.sqlite3 87 | db.sqlite3-journal 88 | 89 | # Flask stuff: 90 | instance/ 91 | .webassets-cache 92 | 93 | # Scrapy stuff: 94 | .scrapy 95 | 96 | # Sphinx documentation 97 | docs/_build/ 98 | 99 | # PyBuilder 100 | target/ 101 | 102 | # Jupyter Notebook 103 | .ipynb_checkpoints 104 | 105 | # IPython 106 | profile_default/ 107 | ipython_config.py 108 | 109 | # pyenv 110 | .python-version 111 | 112 | # pipenv 113 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 114 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 115 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 116 | # install all needed dependencies. 117 | #Pipfile.lock 118 | 119 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Spyder project settings 130 | .spyderproject 131 | .spyproject 132 | 133 | # Rope project settings 134 | .ropeproject 135 | 136 | # mkdocs documentation 137 | /site 138 | 139 | # mypy 140 | .mypy_cache/ 141 | .dmypy.json 142 | dmypy.json 143 | 144 | # Pyre type checker 145 | .pyre/ 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zhengtong Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VILP: Imitation Learning with Latent Video Planning 2 | 3 | _Accepted by IEEE RA-L_ 4 | 5 | [arXiv](https://arxiv.org/abs/2502.01784) | [Summary Video](https://www.youtube.com/watch?v=sfa_AmI0NoI) 6 | 7 | ![teaser](teasers/teaser.gif) 8 | 9 | 10 | ## Installation 11 | 12 | For installation, please run 13 | 14 | ```console 15 | $ cd VILP 16 | $ mamba env create -f conda_environment.yaml && bash install_custom_packages.sh 17 | ``` 18 | 19 | Please note that in the `install_custom_packages.sh` script, the following command is executed 20 | ```console 21 | $ source ~/miniforge3/etc/profile.d/conda.sh 22 | ``` 23 | 24 | This command is generally correct. However, if your Conda environments are not located in the `~/miniforge3` directory, please adjust the command to match the path of your environment. 25 | 26 | ## Example 27 | 28 | Try the [simulation Push-T task](https://diffusion-policy.cs.columbia.edu/) with VILP! 29 | 30 | ### First step: image compression training 31 | 32 | Activate conda environment 33 | ```console 34 | $ conda activate vilpenv 35 | ``` 36 | 37 | Then launch the training by running 38 | ```console 39 | $ python train.py --config-dir=./VILP/config --config-name=train_vq_pushT.yaml 40 | ``` 41 | The pretrained models will be saved in /vq_models 42 | 43 | ### Second step: video planning training 44 | 45 | All logs from training will be uploaded to wandb. 
Log in to [wandb](https://wandb.ai) (if you haven't already) 46 | ```console 47 | $ wandb login 48 | ``` 49 | Then launch the training by running 50 | ```console 51 | $ python train.py --config-dir=./VILP/config --config-name=train_vilp_pushT_state_planning.yaml hydra.run.dir=data/outputs/your_folder_name 52 | ``` 53 | Please note that you need to specify the path to the pretrained VQVAE in the YAML config file. 54 | 55 | After the model is fully trained (this usually takes at least several hours, depending on your GPU), run the following command to export the model from the checkpoint 56 | 57 | ```console 58 | $ python train.py --config-dir=./VILP/config --config-name=save_vilp_pushT_state_planning.yaml hydra.run.dir=data/outputs/the_checkpoint_folder 59 | ``` 60 | 61 | If you train the planning model without low-dimensional observations (use `train_vilp_pushT_planning.yaml`), you should see some generated videos directly on wandb during training! 62 | 63 | 64 | ### Third step: policy training and rollout 65 | 66 | Launch the job by running 67 | 68 | ```console 69 | $ python train.py --config-dir=./VILP/config --config-name=train_vilp_pushT_state_policy.yaml hydra.run.dir=data/outputs/your_folder_name 70 | ``` 71 | 72 | All results will be uploaded to wandb! 73 | 74 | ## BibTeX 75 | 76 | If you find this codebase useful, consider citing: 77 | 78 | ```bibtex 79 | @misc{xu2025vilp, 80 | title={VILP: Imitation Learning with Latent Video Planning}, 81 | author={Zhengtong Xu and Qiang Qiu and Yu She}, 82 | year={2025}, 83 | eprint={2502.01784}, 84 | archivePrefix={arXiv}, 85 | primaryClass={cs.RO}, 86 | url={https://arxiv.org/abs/2502.01784}, 87 | } 88 | ``` -------------------------------------------------------------------------------- /VILP/config/save_vilp_pushT_state_planning.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vilp_planning_workspace.TrainVilpPlanningWorkspace 2 | checkpoint: 3 | save_last_ckpt: true 4 | save_last_snapshot: false 5 | topk: 6 | format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt 7 | k: 5 8 | mode: max 9 | monitor_key: test_mean_score 10 | logging: 11 | group: null 12 | id: null 13 | mode: online 14 | name: VILP_video_pushT_planning 15 | project: VILP_debug 16 | resume: true 17 | tags: 18 | - VILP_video 19 | - pushT_video 20 | - default 21 | multi_run: 22 | run_dir: data/outputs/VILP_video_pushT_planning 23 | wandb_name_base: VILP_video_pushT_planning 24 | 25 | optimizer: 26 | _target_: torch.optim.AdamW 27 | betas: 28 | - 0.95 29 | - 0.999 30 | eps: 1.0e-08 31 | lr: 0.0001 32 | weight_decay: 1.0e-06 33 | past_action_visible: false 34 | 35 | dataset_obs_steps: &dataset_obs_steps 1 36 | batch_size: &batch_size 64 37 | device: &device cuda:0 38 | exp_name: default 39 | horizon: &horizon 23 40 | n_obs_steps: *dataset_obs_steps 41 | name: vilp_video 42 | obs_as_global_cond: true 43 | image_shape: &image_shape [3,96,96] 44 | crop_shape: &crop_shape null 45 | 46 | max_generation_steps: 25 47 | input_key: &input_key img 48 | output_key: &output_key image 49 | # dummy values for not visualizing the generated views 50 | generated_views: 3 51 | use_sim: false 52 | 53 | policy: 54 | _target_: VILP.policy.vilp_planning.VilpPlanning 55 | subgoal_steps: &subgoal_steps 12 56 | subgoal_interval: 2 57 | latent_shape: &latent_shape [3,12,12] 58 | vqgan_config: 59 | ckpt_path: vq_models/2024-08-12T21-42-33/checkpoints/checkpoint-epoch=74.ckpt 60 | embed_dim: 3 61 | n_embed:
1024 62 | ddconfig: 63 | double_z: False 64 | z_channels: 3 65 | resolution: 96 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 70 | num_res_blocks: 2 71 | attn_resolutions: [16] 72 | dropout: 0.0 73 | lossconfig: 74 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 75 | params: 76 | disc_conditional: False 77 | disc_in_channels: 3 78 | disc_start: 10000 79 | disc_weight: 0.8 80 | codebook_weight: 1.0 81 | model_high_level: 82 | _target_: VILP.policy.latent_video_diffusion.LatentVideoDiffusion 83 | latent_shape: *latent_shape 84 | subgoal_steps: *subgoal_steps 85 | device: *device 86 | cond_predict_scale: false 87 | channel_mult: [1,2,4] 88 | num_res_blocks: 2 89 | transformer_depth: 1 90 | attention_resolutions: [4,8] 91 | model_channels: 128 92 | num_head_channels: 8 93 | crop_shape: *crop_shape 94 | eval_fixed_crop: true 95 | horizon: *subgoal_steps 96 | n_obs_steps: 1 97 | noise_scheduler: 98 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 99 | beta_end: 0.0195 100 | beta_schedule: squaredcos_cap_v2 101 | beta_start: 0.0015 102 | clip_sample: true 103 | set_alpha_to_one: True 104 | steps_offset: 0 105 | num_train_timesteps: 300 106 | prediction_type: epsilon 107 | num_inference_steps: 16 108 | obs_as_global_cond: true 109 | obs_encoder_group_norm: true 110 | obs_encoder: 111 | _target_: VILP.model.vision.multi_image_obs_encoder.MultiImageObsEncoder 112 | shape_meta: 113 | obs: 114 | image: 115 | shape: *image_shape 116 | type: rgb 117 | agent_pos: 118 | shape: 119 | - 2 120 | rgb_model: 121 | _target_: diffusion_policy.model.vision.model_getter.get_resnet 122 | name: resnet18 123 | weights: null 124 | resize_shape: null 125 | crop_shape: *crop_shape 126 | # constant center crop 127 | random_crop: True 128 | use_group_norm: True 129 | share_rgb_model: False 130 | imagenet_norm: True 131 | shape_meta: 132 | latent: 133 | shape: *latent_shape 134 | obs: 135 | image: 136 | shape: *image_shape 137 | type: rgb 138 | agent_pos: 139 | shape: 140 | - 2 141 | 142 | 143 | task: 144 | dataset: 145 | _target_: VILP.dataset.pusht_image_dataset.PushtImageDataset 146 | horizon: *horizon 147 | max_train_episodes: null 148 | pad_after: 15 149 | pad_before: 1 150 | seed: 42 151 | val_ratio: 0.007 152 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 153 | image_shape: *image_shape 154 | name: pushT_video 155 | shape_meta: 156 | action: 157 | shape: 158 | - 2 159 | obs: 160 | image: 161 | shape: *image_shape 162 | type: rgb 163 | agent_pos: 164 | shape: 165 | - 2 166 | ema: 167 | _target_: diffusion_policy.model.diffusion.ema_model.EMAModel 168 | inv_gamma: 1.0 169 | max_value: 0.9999 170 | min_value: 0.0 171 | power: 0.75 172 | update_after_step: 0 173 | task_name: pushT_state 174 | training: 175 | checkpoint_every: 1 176 | debug: false 177 | device: *device 178 | gradient_accumulate_every: 1 179 | lr_scheduler: cosine 180 | lr_warmup_steps: 500 181 | max_train_steps: null 182 | max_val_steps: null 183 | num_epochs: 1000 184 | resume: true 185 | rollout_every: 25 186 | sample_every: 25 187 | seed: 42 188 | tqdm_interval_sec: 1.0 189 | use_ema: true 190 | val_every: 50 191 | val_dataloader: 192 | batch_size: *batch_size 193 | num_workers: 8 194 | persistent_workers: false 195 | pin_memory: true 196 | shuffle: false 197 | dataloader: 198 | batch_size: *batch_size 199 | num_workers: 8 200 | persistent_workers: false 201 | pin_memory: true 202 | shuffle: true 
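The `_target_` entries in the config above follow Hydra's instantiation convention, and the README launches everything through `train.py` with `--config-dir`/`--config-name`. For orientation only, here is a minimal sketch of how such a config is typically consumed. The exact wiring of the real `train.py` is not shown in this section; in particular, the workspace constructor taking the whole config object and exposing a `run()` method are assumptions borrowed from the common diffusion_policy workspace pattern.

```python
# Minimal sketch, NOT the repository's train.py: turning a Hydra config whose
# top-level `_target_` names a workspace class into a running training job.
# Assumptions: Workspace(cfg) constructor, workspace.run() entry point, and
# that train.py sits next to the VILP/ package so config_path resolves.
import hydra
from omegaconf import DictConfig, OmegaConf


@hydra.main(config_path="VILP/config", config_name="train_vilp_pushT_state_planning")
def main(cfg: DictConfig) -> None:
    OmegaConf.resolve(cfg)                     # resolve ${...} interpolations, if any
    cls = hydra.utils.get_class(cfg._target_)  # e.g. ...TrainVilpPlanningWorkspace
    workspace = cls(cfg)                       # assumed signature: Workspace(cfg)
    workspace.run()                            # assumed entry method


if __name__ == "__main__":
    main()
```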
-------------------------------------------------------------------------------- /VILP/config/train_vilp_can_agent_planning.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vilp_planning_workspace.TrainVilpPlanningWorkspace 2 | checkpoint: 3 | save_last_ckpt: true 4 | save_last_snapshot: false 5 | topk: 6 | format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt 7 | k: 5 8 | mode: max 9 | monitor_key: test_mean_score 10 | logging: 11 | group: null 12 | id: null 13 | mode: online 14 | name: VILP_video_can_planning 15 | project: VILP_debug 16 | resume: true 17 | tags: 18 | - VILP_video 19 | - can_video_robomimic 20 | - default 21 | multi_run: 22 | run_dir: data/outputs/VILP_video_can_planning 23 | wandb_name_base: VILP_video_can_planning 24 | 25 | optimizer: 26 | _target_: torch.optim.AdamW 27 | betas: 28 | - 0.95 29 | - 0.999 30 | eps: 1.0e-08 31 | lr: 0.0001 32 | weight_decay: 1.0e-06 33 | past_action_visible: false 34 | 35 | dataset_obs_steps: &dataset_obs_steps 1 36 | batch_size: &batch_size 128 37 | device: &device cuda:0 38 | exp_name: default 39 | horizon: &horizon 11 40 | n_obs_steps: *dataset_obs_steps 41 | name: 2pi_video 42 | obs_as_global_cond: true 43 | image_shape: &image_shape [3,96,96] 44 | crop_shape: &crop_shape null 45 | max_generation_steps: 25 46 | input_keys: &input_keys [agentview_image, robot0_eye_in_hand_image] 47 | output_key: &output_key agentview_image 48 | generated_views: 2 49 | use_sim: false 50 | policy: 51 | _target_: VILP.policy.vilp_planning.VilpPlanning 52 | subgoal_steps: &subgoal_steps 6 53 | subgoal_interval: 2 54 | latent_shape: &latent_shape [3,12,12] 55 | output_key: *output_key 56 | vqgan_config: 57 | ckpt_path: vq_models/2024-08-12T00-43-33/checkpoints/checkpoint-epoch=54.ckpt 58 | embed_dim: 3 59 | n_embed: 1024 60 | ddconfig: 61 | double_z: False 62 | z_channels: 3 63 | resolution: 96 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [16] 70 | dropout: 0.0 71 | lossconfig: 72 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 73 | params: 74 | disc_conditional: False 75 | disc_in_channels: 3 76 | disc_start: 10000 77 | disc_weight: 0.8 78 | codebook_weight: 1.0 79 | model_high_level: 80 | _target_: VILP.policy.latent_video_diffusion.LatentVideoDiffusion 81 | latent_shape: *latent_shape 82 | subgoal_steps: *subgoal_steps 83 | device: *device 84 | cond_predict_scale: false 85 | channel_mult: [1,2,4] 86 | num_res_blocks: 2 87 | transformer_depth: 1 88 | attention_resolutions: [4,8] 89 | model_channels: 64 90 | num_head_channels: 8 91 | crop_shape: *crop_shape 92 | eval_fixed_crop: true 93 | n_obs_steps: *dataset_obs_steps 94 | horizon: *subgoal_steps 95 | noise_scheduler: 96 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 97 | beta_end: 0.0195 98 | beta_schedule: squaredcos_cap_v2 99 | beta_start: 0.0015 100 | clip_sample: true 101 | set_alpha_to_one: True 102 | steps_offset: 0 103 | num_train_timesteps: 300 104 | prediction_type: epsilon 105 | num_inference_steps: 16 106 | obs_as_global_cond: true 107 | obs_encoder_group_norm: true 108 | obs_encoder: 109 | _target_: VILP.model.vision.multi_image_obs_encoder.MultiImageObsEncoder 110 | shape_meta: 111 | obs: 112 | agentview_image: 113 | shape: [3, 96, 96] 114 | type: rgb 115 | robot0_eye_in_hand_image: 116 | shape: [3, 96, 96] 117 | type: rgb 118 | rgb_model: 119 | _target_: 
diffusion_policy.model.vision.model_getter.get_resnet 120 | name: resnet18 121 | weights: null 122 | resize_shape: null 123 | crop_shape: *crop_shape 124 | # constant center crop 125 | random_crop: True 126 | use_group_norm: True 127 | share_rgb_model: False 128 | imagenet_norm: True 129 | shape_meta: 130 | latent: 131 | shape: *latent_shape 132 | obs: 133 | agentview_image: 134 | shape: [3, 96, 96] 135 | type: rgb 136 | robot0_eye_in_hand_image: 137 | shape: [3, 96, 96] 138 | type: rgb 139 | 140 | dataset_shape_meta: &dataset_shape_meta 141 | obs: 142 | agentview_image: 143 | shape: [3, 84, 84] 144 | type: rgb 145 | robot0_eye_in_hand_image: 146 | shape: [3, 84, 84] 147 | type: rgb 148 | action: 149 | shape: [10] 150 | 151 | task: 152 | dataset: 153 | _target_: VILP.dataset.robomimic_vilp_dataset.RobomimicVilpDataset 154 | shape_meta: *dataset_shape_meta 155 | dataset_path: data/robomimic/datasets/can/ph/image_abs.hdf5 156 | horizon: *horizon 157 | pad_before: 1 158 | pad_after: 7 159 | use_cache: False 160 | abs_action: True 161 | seed: 242 162 | val_ratio: 0 163 | rotation_rep: 'rotation_6d' 164 | use_legacy_normalizer: False 165 | 166 | 167 | image_shape: *image_shape 168 | name: can_video_robomimic 169 | shape_meta: 170 | obs: 171 | agentview_image: 172 | shape: [3, 96, 96] 173 | type: rgb 174 | robot0_eye_in_hand_image: 175 | shape: [3, 96, 96] 176 | type: rgb 177 | action: 178 | shape: [10] 179 | ema: 180 | _target_: diffusion_policy.model.diffusion.ema_model.EMAModel 181 | inv_gamma: 1.0 182 | max_value: 0.9999 183 | min_value: 0.0 184 | power: 0.75 185 | update_after_step: 0 186 | task_name: can 187 | training: 188 | checkpoint_every: 1 189 | debug: false 190 | device: *device 191 | gradient_accumulate_every: 1 192 | lr_scheduler: cosine 193 | lr_warmup_steps: 500 194 | max_train_steps: null 195 | max_val_steps: null 196 | num_epochs: 400 197 | resume: true 198 | rollout_every: 25 199 | sample_every: 25 200 | seed: 242 201 | tqdm_interval_sec: 1.0 202 | use_ema: true 203 | val_every: 50 204 | val_dataloader: 205 | batch_size: *batch_size 206 | num_workers: 8 207 | persistent_workers: false 208 | pin_memory: true 209 | shuffle: false 210 | dataloader: 211 | batch_size: *batch_size 212 | num_workers: 8 213 | persistent_workers: false 214 | pin_memory: true 215 | shuffle: true -------------------------------------------------------------------------------- /VILP/config/train_vilp_can_wrist_planning.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vilp_planning_workspace.TrainVilpPlanningWorkspace 2 | checkpoint: 3 | save_last_ckpt: true 4 | save_last_snapshot: false 5 | topk: 6 | format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt 7 | k: 5 8 | mode: max 9 | monitor_key: test_mean_score 10 | logging: 11 | group: null 12 | id: null 13 | mode: online 14 | name: VILP_video_can_planning 15 | project: VILP_debug 16 | resume: true 17 | tags: 18 | - VILP_video 19 | - can_video_robomimic 20 | - default 21 | multi_run: 22 | run_dir: data/outputs/VILP_video_can_planning 23 | wandb_name_base: VILP_video_can_planning 24 | 25 | optimizer: 26 | _target_: torch.optim.AdamW 27 | betas: 28 | - 0.95 29 | - 0.999 30 | eps: 1.0e-08 31 | lr: 0.0001 32 | weight_decay: 1.0e-06 33 | past_action_visible: false 34 | 35 | dataset_obs_steps: &dataset_obs_steps 1 36 | batch_size: &batch_size 128 37 | device: &device cuda:0 38 | exp_name: default 39 | horizon: &horizon 11 40 | n_obs_steps: *dataset_obs_steps 41 | name: 
2pi_video 42 | obs_as_global_cond: true 43 | image_shape: &image_shape [3,96,96] 44 | crop_shape: &crop_shape null 45 | max_generation_steps: 25 46 | input_keys: &input_keys [agentview_image, robot0_eye_in_hand_image] 47 | output_key: &output_key robot0_eye_in_hand_image 48 | generated_views: 2 49 | use_sim: false 50 | policy: 51 | _target_: VILP.policy.vilp_planning.VilpPlanning 52 | subgoal_steps: &subgoal_steps 6 53 | subgoal_interval: 2 54 | latent_shape: &latent_shape [3,12,12] 55 | output_key: *output_key 56 | vqgan_config: 57 | ckpt_path: vq_models/2024-08-12T00-44-25/checkpoints/checkpoint-epoch=54.ckpt 58 | embed_dim: 3 59 | n_embed: 1024 60 | ddconfig: 61 | double_z: False 62 | z_channels: 3 63 | resolution: 96 64 | in_channels: 3 65 | out_ch: 3 66 | ch: 128 67 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 68 | num_res_blocks: 2 69 | attn_resolutions: [16] 70 | dropout: 0.0 71 | lossconfig: 72 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 73 | params: 74 | disc_conditional: False 75 | disc_in_channels: 3 76 | disc_start: 10000 77 | disc_weight: 0.8 78 | codebook_weight: 1.0 79 | model_high_level: 80 | _target_: VILP.policy.latent_video_diffusion.LatentVideoDiffusion 81 | latent_shape: *latent_shape 82 | subgoal_steps: *subgoal_steps 83 | device: *device 84 | cond_predict_scale: false 85 | channel_mult: [1,2,4] 86 | num_res_blocks: 2 87 | transformer_depth: 1 88 | attention_resolutions: [4,8] 89 | model_channels: 64 90 | num_head_channels: 8 91 | crop_shape: *crop_shape 92 | eval_fixed_crop: true 93 | n_obs_steps: *dataset_obs_steps 94 | horizon: *subgoal_steps 95 | noise_scheduler: 96 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 97 | beta_end: 0.0195 98 | beta_schedule: squaredcos_cap_v2 99 | beta_start: 0.0015 100 | clip_sample: true 101 | set_alpha_to_one: True 102 | steps_offset: 0 103 | num_train_timesteps: 300 104 | prediction_type: epsilon 105 | num_inference_steps: 16 106 | obs_as_global_cond: true 107 | obs_encoder_group_norm: true 108 | obs_encoder: 109 | _target_: VILP.model.vision.multi_image_obs_encoder.MultiImageObsEncoder 110 | shape_meta: 111 | obs: 112 | agentview_image: 113 | shape: [3, 96, 96] 114 | type: rgb 115 | robot0_eye_in_hand_image: 116 | shape: [3, 96, 96] 117 | type: rgb 118 | rgb_model: 119 | _target_: diffusion_policy.model.vision.model_getter.get_resnet 120 | name: resnet18 121 | weights: null 122 | resize_shape: null 123 | crop_shape: *crop_shape 124 | # constant center crop 125 | random_crop: True 126 | use_group_norm: True 127 | share_rgb_model: False 128 | imagenet_norm: True 129 | shape_meta: 130 | latent: 131 | shape: *latent_shape 132 | obs: 133 | agentview_image: 134 | shape: [3, 96, 96] 135 | type: rgb 136 | robot0_eye_in_hand_image: 137 | shape: [3, 96, 96] 138 | type: rgb 139 | 140 | dataset_shape_meta: &dataset_shape_meta 141 | obs: 142 | agentview_image: 143 | shape: [3, 84, 84] 144 | type: rgb 145 | robot0_eye_in_hand_image: 146 | shape: [3, 84, 84] 147 | type: rgb 148 | action: 149 | shape: [10] 150 | 151 | task: 152 | dataset: 153 | _target_: VILP.dataset.robomimic_vilp_dataset.RobomimicVilpDataset 154 | shape_meta: *dataset_shape_meta 155 | dataset_path: data/robomimic/datasets/can/ph/image_abs.hdf5 156 | horizon: *horizon 157 | pad_before: 1 158 | pad_after: 7 159 | use_cache: False 160 | abs_action: True 161 | seed: 242 162 | val_ratio: 0 163 | rotation_rep: 'rotation_6d' 164 | use_legacy_normalizer: False 165 | 166 | 167 | image_shape: *image_shape 168 | name: can_video_robomimic 
169 | shape_meta: 170 | obs: 171 | agentview_image: 172 | shape: [3, 96, 96] 173 | type: rgb 174 | robot0_eye_in_hand_image: 175 | shape: [3, 96, 96] 176 | type: rgb 177 | action: 178 | shape: [10] 179 | ema: 180 | _target_: diffusion_policy.model.diffusion.ema_model.EMAModel 181 | inv_gamma: 1.0 182 | max_value: 0.9999 183 | min_value: 0.0 184 | power: 0.75 185 | update_after_step: 0 186 | task_name: can 187 | training: 188 | checkpoint_every: 1 189 | debug: false 190 | device: *device 191 | gradient_accumulate_every: 1 192 | lr_scheduler: cosine 193 | lr_warmup_steps: 500 194 | max_train_steps: null 195 | max_val_steps: null 196 | num_epochs: 600 197 | resume: true 198 | rollout_every: 25 199 | sample_every: 25 200 | seed: 242 201 | tqdm_interval_sec: 1.0 202 | use_ema: true 203 | val_every: 50 204 | val_dataloader: 205 | batch_size: *batch_size 206 | num_workers: 8 207 | persistent_workers: false 208 | pin_memory: true 209 | shuffle: false 210 | dataloader: 211 | batch_size: *batch_size 212 | num_workers: 8 213 | persistent_workers: false 214 | pin_memory: true 215 | shuffle: true -------------------------------------------------------------------------------- /VILP/config/train_vilp_pushT_planning.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vilp_planning_workspace.TrainVilpPlanningWorkspace 2 | checkpoint: 3 | save_last_ckpt: true 4 | save_last_snapshot: false 5 | topk: 6 | format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt 7 | k: 5 8 | mode: max 9 | monitor_key: test_mean_score 10 | logging: 11 | group: null 12 | id: null 13 | mode: online 14 | name: VILP_video_pushT_planning 15 | project: VILP_debug 16 | resume: true 17 | tags: 18 | - VILP_video 19 | - pushT_video 20 | - default 21 | multi_run: 22 | run_dir: data/outputs/VILP_video_pushT_planning 23 | wandb_name_base: VILP_video_pushT_planning 24 | 25 | optimizer: 26 | _target_: torch.optim.AdamW 27 | betas: 28 | - 0.95 29 | - 0.999 30 | eps: 1.0e-08 31 | lr: 0.0001 32 | weight_decay: 1.0e-06 33 | past_action_visible: false 34 | 35 | dataset_obs_steps: &dataset_obs_steps 1 36 | batch_size: &batch_size 64 37 | device: &device cuda:0 38 | exp_name: default 39 | horizon: &horizon 23 40 | n_obs_steps: *dataset_obs_steps 41 | name: vilp_video 42 | obs_as_global_cond: true 43 | image_shape: &image_shape [3,96,96] 44 | crop_shape: &crop_shape null 45 | 46 | max_generation_steps: 25 47 | input_key: &input_key img 48 | output_key: &output_key image 49 | generated_views: 1 50 | use_sim: false 51 | 52 | policy: 53 | _target_: VILP.policy.vilp_planning.VilpPlanning 54 | subgoal_steps: &subgoal_steps 12 55 | subgoal_interval: 2 56 | latent_shape: &latent_shape [3,12,12] 57 | vqgan_config: 58 | ckpt_path: vq_models/2024-08-12T21-42-33/checkpoints/checkpoint-epoch=74.ckpt 59 | embed_dim: 3 60 | n_embed: 1024 61 | ddconfig: 62 | double_z: False 63 | z_channels: 3 64 | resolution: 96 65 | in_channels: 3 66 | out_ch: 3 67 | ch: 128 68 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 69 | num_res_blocks: 2 70 | attn_resolutions: [16] 71 | dropout: 0.0 72 | lossconfig: 73 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 74 | params: 75 | disc_conditional: False 76 | disc_in_channels: 3 77 | disc_start: 10000 78 | disc_weight: 0.8 79 | codebook_weight: 1.0 80 | model_high_level: 81 | _target_: VILP.policy.latent_video_diffusion.LatentVideoDiffusion 82 | latent_shape: *latent_shape 83 | subgoal_steps: *subgoal_steps 84 | device: 
*device 85 | cond_predict_scale: false 86 | channel_mult: [1,2,4] 87 | num_res_blocks: 2 88 | transformer_depth: 1 89 | attention_resolutions: [4,8] 90 | model_channels: 128 91 | num_head_channels: 8 92 | crop_shape: *crop_shape 93 | eval_fixed_crop: true 94 | horizon: *subgoal_steps 95 | n_obs_steps: 1 96 | noise_scheduler: 97 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 98 | beta_end: 0.0195 99 | beta_schedule: squaredcos_cap_v2 100 | beta_start: 0.0015 101 | clip_sample: true 102 | set_alpha_to_one: True 103 | steps_offset: 0 104 | num_train_timesteps: 300 105 | prediction_type: epsilon 106 | num_inference_steps: 16 107 | obs_as_global_cond: true 108 | obs_encoder_group_norm: true 109 | obs_encoder: 110 | _target_: VILP.model.vision.multi_image_obs_encoder.MultiImageObsEncoder 111 | shape_meta: 112 | obs: 113 | image: 114 | shape: *image_shape 115 | type: rgb 116 | rgb_model: 117 | _target_: diffusion_policy.model.vision.model_getter.get_resnet 118 | name: resnet18 119 | weights: null 120 | resize_shape: null 121 | crop_shape: *crop_shape 122 | # constant center crop 123 | random_crop: True 124 | use_group_norm: True 125 | share_rgb_model: False 126 | imagenet_norm: True 127 | shape_meta: 128 | latent: 129 | shape: *latent_shape 130 | obs: 131 | image: 132 | shape: *image_shape 133 | type: rgb 134 | 135 | 136 | task: 137 | dataset: 138 | _target_: VILP.dataset.pusht_image_dataset.PushtImageDataset 139 | horizon: *horizon 140 | max_train_episodes: null 141 | pad_after: 15 142 | pad_before: 1 143 | seed: 42 144 | val_ratio: 0.007 145 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 146 | image_shape: *image_shape 147 | name: pushT_video 148 | shape_meta: 149 | action: 150 | shape: 151 | - 2 152 | obs: 153 | image: 154 | shape: *image_shape 155 | type: rgb 156 | ema: 157 | _target_: diffusion_policy.model.diffusion.ema_model.EMAModel 158 | inv_gamma: 1.0 159 | max_value: 0.9999 160 | min_value: 0.0 161 | power: 0.75 162 | update_after_step: 0 163 | task_name: pushT 164 | training: 165 | checkpoint_every: 1 166 | debug: false 167 | device: *device 168 | gradient_accumulate_every: 1 169 | lr_scheduler: cosine 170 | lr_warmup_steps: 500 171 | max_train_steps: null 172 | max_val_steps: null 173 | num_epochs: 1000 174 | resume: true 175 | rollout_every: 25 176 | sample_every: 25 177 | seed: 42 178 | tqdm_interval_sec: 1.0 179 | use_ema: true 180 | val_every: 50 181 | val_dataloader: 182 | batch_size: *batch_size 183 | num_workers: 8 184 | persistent_workers: false 185 | pin_memory: true 186 | shuffle: false 187 | dataloader: 188 | batch_size: *batch_size 189 | num_workers: 8 190 | persistent_workers: false 191 | pin_memory: true 192 | shuffle: true -------------------------------------------------------------------------------- /VILP/config/train_vilp_pushT_state_planning.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vilp_planning_workspace.TrainVilpPlanningWorkspace 2 | checkpoint: 3 | save_last_ckpt: true 4 | save_last_snapshot: false 5 | topk: 6 | format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt 7 | k: 5 8 | mode: max 9 | monitor_key: test_mean_score 10 | logging: 11 | group: null 12 | id: null 13 | mode: online 14 | name: VILP_video_pushT_planning 15 | project: VILP_debug 16 | resume: true 17 | tags: 18 | - VILP_video 19 | - pushT_video 20 | - default 21 | multi_run: 22 | run_dir: data/outputs/VILP_video_pushT_planning 23 | wandb_name_base: VILP_video_pushT_planning 24 | 25 
| optimizer: 26 | _target_: torch.optim.AdamW 27 | betas: 28 | - 0.95 29 | - 0.999 30 | eps: 1.0e-08 31 | lr: 0.0001 32 | weight_decay: 1.0e-06 33 | past_action_visible: false 34 | 35 | dataset_obs_steps: &dataset_obs_steps 1 36 | batch_size: &batch_size 64 37 | device: &device cuda:0 38 | exp_name: default 39 | horizon: &horizon 23 40 | n_obs_steps: *dataset_obs_steps 41 | name: vilp_video 42 | obs_as_global_cond: true 43 | image_shape: &image_shape [3,96,96] 44 | crop_shape: &crop_shape null 45 | 46 | max_generation_steps: 25 47 | input_key: &input_key img 48 | output_key: &output_key image 49 | # dummy values for not visualizing the generated views 50 | generated_views: 3 51 | use_sim: false 52 | 53 | policy: 54 | _target_: VILP.policy.vilp_planning.VilpPlanning 55 | subgoal_steps: &subgoal_steps 12 56 | subgoal_interval: 2 57 | latent_shape: &latent_shape [3,12,12] 58 | vqgan_config: 59 | ckpt_path: vq_models/2024-08-12T21-42-33/checkpoints/checkpoint-epoch=74.ckpt 60 | embed_dim: 3 61 | n_embed: 1024 62 | ddconfig: 63 | double_z: False 64 | z_channels: 3 65 | resolution: 96 66 | in_channels: 3 67 | out_ch: 3 68 | ch: 128 69 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 70 | num_res_blocks: 2 71 | attn_resolutions: [16] 72 | dropout: 0.0 73 | lossconfig: 74 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 75 | params: 76 | disc_conditional: False 77 | disc_in_channels: 3 78 | disc_start: 10000 79 | disc_weight: 0.8 80 | codebook_weight: 1.0 81 | model_high_level: 82 | _target_: VILP.policy.latent_video_diffusion.LatentVideoDiffusion 83 | latent_shape: *latent_shape 84 | subgoal_steps: *subgoal_steps 85 | device: *device 86 | cond_predict_scale: false 87 | channel_mult: [1,2,4] 88 | num_res_blocks: 2 89 | transformer_depth: 1 90 | attention_resolutions: [4,8] 91 | model_channels: 128 92 | num_head_channels: 8 93 | crop_shape: *crop_shape 94 | eval_fixed_crop: true 95 | horizon: *subgoal_steps 96 | n_obs_steps: 1 97 | noise_scheduler: 98 | _target_: diffusers.schedulers.scheduling_ddim.DDIMScheduler 99 | beta_end: 0.0195 100 | beta_schedule: squaredcos_cap_v2 101 | beta_start: 0.0015 102 | clip_sample: true 103 | set_alpha_to_one: True 104 | steps_offset: 0 105 | num_train_timesteps: 300 106 | prediction_type: epsilon 107 | num_inference_steps: 16 108 | obs_as_global_cond: true 109 | obs_encoder_group_norm: true 110 | obs_encoder: 111 | _target_: VILP.model.vision.multi_image_obs_encoder.MultiImageObsEncoder 112 | shape_meta: 113 | obs: 114 | image: 115 | shape: *image_shape 116 | type: rgb 117 | agent_pos: 118 | shape: 119 | - 2 120 | rgb_model: 121 | _target_: diffusion_policy.model.vision.model_getter.get_resnet 122 | name: resnet18 123 | weights: null 124 | resize_shape: null 125 | crop_shape: *crop_shape 126 | # constant center crop 127 | random_crop: True 128 | use_group_norm: True 129 | share_rgb_model: False 130 | imagenet_norm: True 131 | shape_meta: 132 | latent: 133 | shape: *latent_shape 134 | obs: 135 | image: 136 | shape: *image_shape 137 | type: rgb 138 | agent_pos: 139 | shape: 140 | - 2 141 | 142 | 143 | task: 144 | dataset: 145 | _target_: VILP.dataset.pusht_image_dataset.PushtImageDataset 146 | horizon: *horizon 147 | max_train_episodes: null 148 | pad_after: 15 149 | pad_before: 1 150 | seed: 42 151 | val_ratio: 0.007 152 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 153 | image_shape: *image_shape 154 | name: pushT_video 155 | shape_meta: 156 | action: 157 | shape: 158 | - 2 159 | obs: 160 | image: 161 | shape: *image_shape 162 | 
type: rgb 163 | agent_pos: 164 | shape: 165 | - 2 166 | ema: 167 | _target_: diffusion_policy.model.diffusion.ema_model.EMAModel 168 | inv_gamma: 1.0 169 | max_value: 0.9999 170 | min_value: 0.0 171 | power: 0.75 172 | update_after_step: 0 173 | task_name: pushT_state 174 | training: 175 | checkpoint_every: 1 176 | debug: false 177 | device: *device 178 | gradient_accumulate_every: 1 179 | lr_scheduler: cosine 180 | lr_warmup_steps: 500 181 | max_train_steps: null 182 | max_val_steps: null 183 | num_epochs: 1000 184 | resume: true 185 | rollout_every: 25 186 | sample_every: 25 187 | seed: 42 188 | tqdm_interval_sec: 1.0 189 | use_ema: true 190 | val_every: 50 191 | val_dataloader: 192 | batch_size: *batch_size 193 | num_workers: 8 194 | persistent_workers: false 195 | pin_memory: true 196 | shuffle: false 197 | dataloader: 198 | batch_size: *batch_size 199 | num_workers: 8 200 | persistent_workers: false 201 | pin_memory: true 202 | shuffle: true -------------------------------------------------------------------------------- /VILP/config/train_vq_can_agent.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vqgan_workspace.TrainVqganWorkspace 2 | base_learning_rate: 4.5e-6 3 | model: 4 | embed_dim: 3 5 | n_embed: 1024 6 | ddconfig: 7 | double_z: False 8 | z_channels: 3 9 | resolution: 96 10 | in_channels: 3 11 | out_ch: 3 12 | ch: 128 13 | ch_mult: [1,1,2,4] 14 | num_res_blocks: 2 15 | attn_resolutions: [16] 16 | dropout: 0.0 17 | lossconfig: 18 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 19 | params: 20 | disc_conditional: False 21 | disc_in_channels: 3 22 | disc_start: 10000 23 | disc_weight: 0.8 24 | codebook_weight: 1.0 25 | 26 | 27 | image_shape: &image_shape [3, 84, 84] 28 | shape_meta: &shape_meta 29 | obs: 30 | agentview_image: 31 | shape: *image_shape 32 | type: rgb 33 | action: 34 | shape: [10] 35 | 36 | batch_size: &batch_size 64 37 | gpus: '0,' 38 | seed: &seed 42 39 | save_every: 1 40 | max_epochs: 100000 41 | 42 | dataset: 43 | _target_: VILP.dataset.robomimic_ae_dataset.RobomimicAeDataset 44 | shape_meta: *shape_meta 45 | dataset_path: data/robomimic/datasets/can/ph/image_abs.hdf5 46 | horizon: 1 47 | pad_before: 1 48 | pad_after: 1 49 | n_obs_steps: 1 50 | use_cache: False 51 | abs_action: True 52 | seed: *seed 53 | val_ratio: 0.0 54 | rotation_rep: 'rotation_6d' 55 | use_legacy_normalizer: False 56 | key_index: agentview_image 57 | 58 | val_dataloader: 59 | batch_size: *batch_size 60 | num_workers: 8 61 | persistent_workers: true 62 | pin_memory: true 63 | shuffle: true 64 | dataloader: 65 | batch_size: *batch_size 66 | num_workers: 8 67 | persistent_workers: true 68 | pin_memory: true 69 | shuffle: true 70 | 71 | trainer: 72 | name: "vqvae_dubug" 73 | resume: "" 74 | base: [] 75 | no_test: false 76 | project: null 77 | debug: false 78 | seed: *seed 79 | postfix: "" 80 | train: true -------------------------------------------------------------------------------- /VILP/config/train_vq_can_wrist.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vqgan_workspace.TrainVqganWorkspace 2 | base_learning_rate: 4.5e-6 3 | model: 4 | embed_dim: 3 5 | n_embed: 1024 6 | ddconfig: 7 | double_z: False 8 | z_channels: 3 9 | resolution: 96 10 | in_channels: 3 11 | out_ch: 3 12 | ch: 128 13 | ch_mult: [1,1,2,4] 14 | num_res_blocks: 2 15 | attn_resolutions: [16] 16 | dropout: 0.0 17 | lossconfig: 18 | target: 
VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 19 | params: 20 | disc_conditional: False 21 | disc_in_channels: 3 22 | disc_start: 10000 23 | disc_weight: 0.8 24 | codebook_weight: 1.0 25 | 26 | 27 | image_shape: &image_shape [3, 84, 84] 28 | shape_meta: &shape_meta 29 | obs: 30 | robot0_eye_in_hand_image: 31 | shape: *image_shape 32 | type: rgb 33 | action: 34 | shape: [10] 35 | 36 | batch_size: &batch_size 64 37 | gpus: '0,' 38 | seed: &seed 42 39 | save_every: 1 40 | max_epochs: 100000 41 | 42 | dataset: 43 | _target_: VILP.dataset.robomimic_ae_dataset.RobomimicAeDataset 44 | shape_meta: *shape_meta 45 | dataset_path: data/robomimic/datasets/can/ph/image_abs.hdf5 46 | horizon: 1 47 | pad_before: 1 48 | pad_after: 1 49 | n_obs_steps: 1 50 | use_cache: False 51 | abs_action: True 52 | seed: *seed 53 | val_ratio: 0.0 54 | rotation_rep: 'rotation_6d' 55 | use_legacy_normalizer: False 56 | key_index: robot0_eye_in_hand_image 57 | 58 | val_dataloader: 59 | batch_size: *batch_size 60 | num_workers: 8 61 | persistent_workers: true 62 | pin_memory: true 63 | shuffle: true 64 | dataloader: 65 | batch_size: *batch_size 66 | num_workers: 8 67 | persistent_workers: true 68 | pin_memory: true 69 | shuffle: true 70 | 71 | trainer: 72 | name: "vqvae_dubug" 73 | resume: "" 74 | base: [] 75 | no_test: false 76 | project: null 77 | debug: false 78 | seed: *seed 79 | postfix: "" 80 | train: true -------------------------------------------------------------------------------- /VILP/config/train_vq_pushT.yaml: -------------------------------------------------------------------------------- 1 | _target_: VILP.workspace.train_vqgan_workspace.TrainVqganWorkspace 2 | base_learning_rate: 4.5e-6 3 | model: 4 | embed_dim: 3 5 | n_embed: 1024 6 | ddconfig: 7 | double_z: False 8 | z_channels: 3 9 | resolution: 96 10 | in_channels: 3 11 | out_ch: 3 12 | ch: 128 13 | ch_mult: [1,1,2,4] # num_down = len(ch_mult)-1 14 | num_res_blocks: 2 15 | attn_resolutions: [16] 16 | dropout: 0.0 17 | lossconfig: 18 | target: VILP.taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 19 | params: 20 | disc_conditional: False 21 | disc_in_channels: 3 22 | #disc_start: 10000 23 | # for sim push-T, due to its white background, we do not need to use the discriminator 24 | # to help reconstruct image details 25 | disc_start: 1000000 26 | disc_weight: 0.8 27 | codebook_weight: 1.0 28 | dataset: 29 | _target_: VILP.dataset.push_multi_image_ae_dataset.PushMultiImageAeDataset 30 | zarr_path: data/pusht/pusht_cchi_v7_replay.zarr 31 | horizon: 1 32 | max_train_episodes: null 33 | pad_after: 7 34 | pad_before: 1 35 | seed: 42 36 | val_ratio: 0.00 37 | val_dataloader: 38 | batch_size: 64 39 | num_workers: 8 40 | persistent_workers: false 41 | pin_memory: true 42 | shuffle: false 43 | dataloader: 44 | batch_size: 64 45 | num_workers: 8 46 | persistent_workers: false 47 | pin_memory: true 48 | shuffle: true 49 | gpus: '0,' 50 | batch_size: 8 51 | seed: 42 52 | trainer: 53 | name: "vqvae_debug" 54 | resume: "" 55 | base: [] 56 | no_test: false 57 | project: null 58 | debug: false 59 | seed: 42 60 | postfix: "" 61 | train: true 62 | save_every: 1 63 | max_epochs: 100000 -------------------------------------------------------------------------------- /VILP/dataset/push_multi_image_ae_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | import copy 5 | from diffusion_policy.common.pytorch_util import dict_apply 6 | from 
diffusion_policy.common.replay_buffer import ReplayBuffer 7 | from diffusion_policy.common.sampler import ( 8 | SequenceSampler, get_val_mask, downsample_mask) 9 | from diffusion_policy.model.common.normalizer import LinearNormalizer 10 | from diffusion_policy.dataset.base_dataset import BaseImageDataset 11 | from diffusion_policy.common.normalize_util import get_image_range_normalizer 12 | import cv2 13 | class PushMultiImageAeDataset(BaseImageDataset): 14 | def __init__(self, 15 | zarr_path='data/pusht_multi_goals', 16 | horizon=1, 17 | pad_before=0, 18 | pad_after=0, 19 | seed=42, 20 | val_ratio=0.0, 21 | max_train_episodes=None 22 | ): 23 | 24 | super().__init__() 25 | 26 | self.replay_buffer = ReplayBuffer.copy_from_path( 27 | zarr_path, keys=['img']) 28 | val_mask = get_val_mask( 29 | n_episodes=self.replay_buffer.n_episodes, 30 | val_ratio=val_ratio, 31 | seed=seed) 32 | train_mask = ~val_mask 33 | train_mask = downsample_mask( 34 | mask=train_mask, 35 | max_n=max_train_episodes, 36 | seed=seed) 37 | 38 | self.sampler = SequenceSampler( 39 | replay_buffer=self.replay_buffer, 40 | sequence_length=horizon, 41 | pad_before=pad_before, 42 | pad_after=pad_after, 43 | episode_mask=train_mask) 44 | self.train_mask = train_mask 45 | self.horizon = horizon 46 | self.pad_before = pad_before 47 | self.pad_after = pad_after 48 | 49 | def get_validation_dataset(self): 50 | val_set = copy.copy(self) 51 | val_set.sampler = SequenceSampler( 52 | replay_buffer=self.replay_buffer, 53 | sequence_length=self.horizon, 54 | pad_before=self.pad_before, 55 | pad_after=self.pad_after, 56 | episode_mask=~self.train_mask 57 | ) 58 | val_set.train_mask = ~self.train_mask 59 | return val_set 60 | 61 | def get_normalizer(self, mode='limits', **kwargs): 62 | normalizer = LinearNormalizer() 63 | normalizer['image'] = get_image_range_normalizer() 64 | return normalizer 65 | 66 | def __len__(self) -> int: 67 | return len(self.sampler) 68 | 69 | def _sample_to_data(self, sample): 70 | image = np.moveaxis(sample['img'],-1,1)/255 71 | 72 | image = image[0].transpose((1, 2, 0)) 73 | 74 | data = {'image': image} 75 | return data 76 | 77 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 78 | sample = self.sampler.sample_sequence(idx) 79 | data = self._sample_to_data(sample) 80 | torch_data = dict_apply(data, torch.from_numpy) 81 | return torch_data -------------------------------------------------------------------------------- /VILP/dataset/pusht_image_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | import copy 5 | from diffusion_policy.common.pytorch_util import dict_apply 6 | from diffusion_policy.common.replay_buffer import ReplayBuffer 7 | from diffusion_policy.common.sampler import ( 8 | SequenceSampler, get_val_mask, downsample_mask) 9 | from diffusion_policy.model.common.normalizer import LinearNormalizer 10 | from diffusion_policy.dataset.base_dataset import BaseImageDataset 11 | from diffusion_policy.common.normalize_util import get_image_range_normalizer 12 | import cv2 13 | class PushtImageDataset(BaseImageDataset): 14 | def __init__(self, 15 | zarr_path, 16 | horizon=1, 17 | pad_before=0, 18 | pad_after=0, 19 | seed=42, 20 | val_ratio=0.0, 21 | max_train_episodes=None, 22 | ): 23 | 24 | super().__init__() 25 | 26 | self.replay_buffer = ReplayBuffer.copy_from_path( 27 | zarr_path, keys=['img', 'state', 'action']) 28 | val_mask = get_val_mask( 29 | 
n_episodes=self.replay_buffer.n_episodes, 30 | val_ratio=val_ratio, 31 | seed=seed) 32 | train_mask = ~val_mask 33 | train_mask = downsample_mask( 34 | mask=train_mask, 35 | max_n=max_train_episodes, 36 | seed=seed) 37 | 38 | self.sampler = SequenceSampler( 39 | replay_buffer=self.replay_buffer, 40 | sequence_length=horizon, 41 | pad_before=pad_before, 42 | pad_after=pad_after, 43 | episode_mask=train_mask) 44 | self.train_mask = train_mask 45 | 46 | 47 | self.horizon = horizon 48 | self.pad_before = pad_before 49 | self.pad_after = pad_after 50 | self.val_mask = val_mask 51 | def get_validation_dataset(self): 52 | val_set = copy.copy(self) 53 | val_set.sampler = SequenceSampler( 54 | replay_buffer=self.replay_buffer, 55 | sequence_length=self.horizon, 56 | pad_before=self.pad_before, 57 | pad_after=self.pad_after, 58 | episode_mask=~self.train_mask 59 | ) 60 | val_set.train_mask = ~self.train_mask 61 | return val_set 62 | 63 | def get_normalizer(self, mode='limits', **kwargs): 64 | 65 | data = { 66 | 'action': self.replay_buffer['action'], 67 | 'agent_pos': self.replay_buffer['state'][...,:2] 68 | } 69 | 70 | normalizer = LinearNormalizer() 71 | normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs) 72 | normalizer['image'] = get_image_range_normalizer() 73 | normalizer['latent_img'] = get_image_range_normalizer() 74 | return normalizer 75 | 76 | def __len__(self) -> int: 77 | return len(self.sampler) 78 | 79 | def _sample_to_data(self, sample): 80 | agent_pos = sample['state'][:,:2].astype(np.float32) # (agent_posx2, block_posex3) 81 | image = np.moveaxis(sample['img'],-1,1)/255 82 | data = { 83 | 'obs': { 84 | 'image': image, # T, 3, H, W 85 | 'agent_pos': agent_pos, # T, 2 86 | }, 87 | 'action': sample['action'].astype(np.float32) # T, 2 88 | } 89 | return data 90 | 91 | def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: 92 | sample = self.sampler.sample_sequence(idx) 93 | data = self._sample_to_data(sample) 94 | torch_data = dict_apply(data, torch.from_numpy) 95 | return torch_data 96 | 97 | def get_episode_lengths(self): 98 | return len(self.val_mask) 99 | def get_val_episode_start(self,idx, key_index): 100 | episode_ends = self.replay_buffer.episode_ends[:] 101 | if not self.val_mask[idx]: 102 | return False, None 103 | else: 104 | if idx > 0: 105 | image = self.replay_buffer[key_index][episode_ends[idx-1]] 106 | image = image.astype(np.float32) / 255. 107 | image = np.moveaxis(image,-1,0) 108 | return True, image 109 | else: 110 | image = self.replay_buffer[key_index][episode_ends[idx-1]] 111 | image = image.astype(np.float32) / 255. 112 | image = np.moveaxis(image,-1,0) 113 | return True, image 114 | def get_val_episode_full(self,idx, key_index): 115 | episode_ends = self.replay_buffer.episode_ends[:] 116 | if not self.val_mask[idx]: 117 | return False, None 118 | else: 119 | if idx > 0: 120 | one_episode = [] 121 | for i in range((episode_ends[idx]-episode_ends[idx-1])): 122 | image = self.replay_buffer[key_index][episode_ends[idx-1]+i] 123 | image = image.astype(np.float32) / 255. 124 | image = np.moveaxis(image,-1,0) 125 | one_episode.append(image) 126 | return True, one_episode 127 | else: 128 | one_episode = [] 129 | for i in range(episode_ends[1]): 130 | image = self.replay_buffer[key_index][i] 131 | image = image.astype(np.float32) / 255. 
132 | image = np.moveaxis(image,-1,0) 133 | one_episode.append(image) 134 | return True, one_episode 135 | -------------------------------------------------------------------------------- /VILP/env/pusht/__init__.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | import diffusion_policy.env.pusht 3 | 4 | register( 5 | id='pusht-keypoints-v0', 6 | entry_point='envs.pusht.pusht_keypoints_env:PushTKeypointsEnv', 7 | max_episode_steps=200, 8 | reward_threshold=1.0 9 | ) -------------------------------------------------------------------------------- /VILP/env/pusht/pusht_collect_env.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence, Union, Optional 2 | from gym import spaces 3 | from VILP.env.pusht.pusht_env import PushTEnv 4 | from diffusion_policy.env.pusht.pymunk_keypoint_manager import PymunkKeypointManager 5 | import numpy as np 6 | 7 | class PushTCollectEnv(PushTEnv): 8 | def __init__(self, 9 | legacy=False, 10 | block_cog=None, 11 | damping=None, 12 | render_size=96, 13 | keypoint_visible_rate=1.0, 14 | agent_keypoints=False, 15 | draw_keypoints=False, 16 | reset_to_state=None, 17 | render_action=True, 18 | local_keypoint_map: Dict[str, np.ndarray]=None, 19 | color_map: Optional[Dict[str, np.ndarray]]=None): 20 | super().__init__( 21 | legacy=legacy, 22 | block_cog=block_cog, 23 | damping=damping, 24 | render_size=render_size, 25 | reset_to_state=reset_to_state, 26 | render_action=render_action) 27 | ws = self.window_size 28 | 29 | if local_keypoint_map is None: 30 | # create default keypoint definition 31 | kp_kwargs = self.genenerate_keypoint_manager_params() 32 | local_keypoint_map = kp_kwargs['local_keypoint_map'] 33 | color_map = kp_kwargs['color_map'] 34 | 35 | # create observation spaces 36 | Dblockkps = np.prod(local_keypoint_map['block'].shape) 37 | Dagentkps = np.prod(local_keypoint_map['agent'].shape) 38 | Dagentpos = 2 39 | 40 | Do = Dblockkps 41 | if agent_keypoints: 42 | # blockkp + agnet_pos 43 | Do += Dagentkps 44 | else: 45 | # blockkp + agnet_kp 46 | Do += Dagentpos 47 | # obs + obs_mask 48 | Dobs = Do * 2 49 | 50 | low = np.zeros((Dobs,), dtype=np.float64) 51 | high = np.full_like(low, ws) 52 | # mask range 0-1 53 | high[Do:] = 1. 
54 | 55 | # (block_kps+agent_kps, xy+confidence) 56 | self.observation_space = spaces.Box( 57 | low=low, 58 | high=high, 59 | shape=low.shape, 60 | dtype=np.float64 61 | ) 62 | 63 | self.keypoint_visible_rate = keypoint_visible_rate 64 | self.agent_keypoints = agent_keypoints 65 | self.draw_keypoints = draw_keypoints 66 | self.kp_manager = PymunkKeypointManager( 67 | local_keypoint_map=local_keypoint_map, 68 | color_map=color_map) 69 | self.draw_kp_map = None 70 | 71 | @classmethod 72 | def genenerate_keypoint_manager_params(cls): 73 | env = PushTEnv() 74 | kp_manager = PymunkKeypointManager.create_from_pusht_env(env) 75 | kp_kwargs = kp_manager.kwargs 76 | return kp_kwargs 77 | 78 | def _get_obs(self): 79 | # get keypoints 80 | obj_map = { 81 | 'block': self.block 82 | } 83 | if self.agent_keypoints: 84 | obj_map['agent'] = self.agent 85 | 86 | kp_map = self.kp_manager.get_keypoints_global( 87 | pose_map=obj_map, is_obj=True) 88 | # python dict guerentee order of keys and values 89 | kps = np.concatenate(list(kp_map.values()), axis=0) 90 | 91 | # select keypoints to drop 92 | n_kps = kps.shape[0] 93 | visible_kps = self.np_random.random(size=(n_kps,)) < self.keypoint_visible_rate 94 | kps_mask = np.repeat(visible_kps[:,None], 2, axis=1) 95 | 96 | # save keypoints for rendering 97 | vis_kps = kps.copy() 98 | vis_kps[~visible_kps] = 0 99 | draw_kp_map = { 100 | 'block': vis_kps[:len(kp_map['block'])] 101 | } 102 | if self.agent_keypoints: 103 | draw_kp_map['agent'] = vis_kps[len(kp_map['block']):] 104 | self.draw_kp_map = draw_kp_map 105 | 106 | # construct obs 107 | obs = kps.flatten() 108 | obs_mask = kps_mask.flatten() 109 | if not self.agent_keypoints: 110 | # passing agent position when keypoints are not available 111 | agent_pos = np.array(self.agent.position) 112 | obs = np.concatenate([ 113 | obs, agent_pos 114 | ]) 115 | obs_mask = np.concatenate([ 116 | obs_mask, np.ones((2,), dtype=bool) 117 | ]) 118 | 119 | # obs, obs_mask 120 | obs = np.concatenate([ 121 | obs, obs_mask.astype(obs.dtype) 122 | ], axis=0) 123 | return obs 124 | 125 | 126 | def _render_frame(self, mode): 127 | img = super()._render_frame(mode) 128 | if self.draw_keypoints: 129 | self.kp_manager.draw_keypoints( 130 | img, self.draw_kp_map, radius=int(img.shape[0]/96)) 131 | return img -------------------------------------------------------------------------------- /VILP/env/pusht/pusht_image_env.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | from VILP.env.pusht.pusht_env import PushTEnv 3 | import numpy as np 4 | import cv2 5 | 6 | class PushTImageEnv(PushTEnv): 7 | metadata = {"render.modes": ["rgb_array"], "video.frames_per_second": 10} 8 | 9 | def __init__(self, 10 | legacy=False, 11 | block_cog=None, 12 | damping=None, 13 | render_size=96): 14 | super().__init__( 15 | legacy=legacy, 16 | block_cog=block_cog, 17 | damping=damping, 18 | render_size=render_size, 19 | render_action=False) 20 | ws = self.window_size 21 | self.observation_space = spaces.Dict({ 22 | 'image': spaces.Box( 23 | low=0, 24 | high=1, 25 | shape=(3,render_size,render_size), 26 | dtype=np.float32 27 | ), 28 | 'agent_pos': spaces.Box( 29 | low=0, 30 | high=ws, 31 | shape=(2,), 32 | dtype=np.float32 33 | ) 34 | }) 35 | self.render_cache = None 36 | 37 | def _get_obs(self): 38 | img = super()._render_frame(mode='rgb_array') 39 | 40 | agent_pos = np.array(self.agent.position) 41 | img_obs = np.moveaxis(img.astype(np.float32) / 255, -1, 0) 42 | obs = { 43 | 'image': img_obs, 44 
| 'agent_pos': agent_pos 45 | } 46 | 47 | # draw action 48 | if self.latest_action is not None: 49 | action = np.array(self.latest_action) 50 | coord = (action / 512 * 96).astype(np.int32) 51 | marker_size = int(8/96*self.render_size) 52 | thickness = int(1/96*self.render_size) 53 | cv2.drawMarker(img, coord, 54 | color=(255,0,0), markerType=cv2.MARKER_CROSS, 55 | markerSize=marker_size, thickness=thickness) 56 | self.render_cache = img 57 | 58 | return obs 59 | 60 | def render(self, mode): 61 | assert mode == 'rgb_array' 62 | 63 | if self.render_cache is None: 64 | self._get_obs() 65 | 66 | return self.render_cache -------------------------------------------------------------------------------- /VILP/env/pusht/pusht_keypoints_env.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Sequence, Union, Optional 2 | from gym import spaces 3 | from VILP.env.pusht.pusht_env import PushTEnv 4 | from diffusion_policy.env.pusht.pymunk_keypoint_manager import PymunkKeypointManager 5 | import numpy as np 6 | 7 | class PushTKeypointsEnv(PushTEnv): 8 | def __init__(self, 9 | legacy=False, 10 | block_cog=None, 11 | damping=None, 12 | render_size=96, 13 | keypoint_visible_rate=1.0, 14 | agent_keypoints=False, 15 | draw_keypoints=True, 16 | reset_to_state=None, 17 | render_action=True, 18 | local_keypoint_map: Dict[str, np.ndarray]=None, 19 | color_map: Optional[Dict[str, np.ndarray]]=None): 20 | super().__init__( 21 | legacy=legacy, 22 | block_cog=block_cog, 23 | damping=damping, 24 | render_size=render_size, 25 | reset_to_state=reset_to_state, 26 | render_action=render_action) 27 | ws = self.window_size 28 | 29 | if local_keypoint_map is None: 30 | # create default keypoint definition 31 | kp_kwargs = self.genenerate_keypoint_manager_params() 32 | local_keypoint_map = kp_kwargs['local_keypoint_map'] 33 | color_map = kp_kwargs['color_map'] 34 | 35 | # create observation spaces 36 | Dblockkps = np.prod(local_keypoint_map['block'].shape) 37 | Dagentkps = np.prod(local_keypoint_map['agent'].shape) 38 | Dagentpos = 2 39 | 40 | Do = Dblockkps 41 | if agent_keypoints: 42 | # blockkp + agnet_pos 43 | Do += Dagentkps 44 | else: 45 | # blockkp + agnet_kp 46 | Do += Dagentpos 47 | # obs + obs_mask 48 | #Dobs = Do * 2 49 | Do = Do+3 50 | Dobs = (Do)*2 51 | 52 | low = np.zeros((Dobs,), dtype=np.float64) 53 | high = np.full_like(low, ws) 54 | # mask range 0-1 55 | high[Do:] = 1. 
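        # Layout note: the first Do entries hold the state (block keypoints,
        # then agent keypoints or the 2-D agent position when agent_keypoints
        # is False, plus the 3-dim goal pose appended in _get_obs); the
        # remaining Do entries are a 0/1 visibility mask. The three goal-pose
        # dims get their own bounds below (x, y in [-1000, 1000], angle in
        # [-pi/2, pi/2]).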
56 | 57 | high[Do-1] = np.pi/2 58 | low[Do-1] = -np.pi/2 59 | high[Do-3:Do-1] = 1000 60 | low[Do-3:Do-1] = -1000 61 | 62 | 63 | # (block_kps+agent_kps, xy+confidence) 64 | self.observation_space = spaces.Box( 65 | low=low, 66 | high=high, 67 | shape=low.shape, 68 | dtype=np.float64 69 | ) 70 | 71 | self.keypoint_visible_rate = keypoint_visible_rate 72 | self.agent_keypoints = agent_keypoints 73 | self.draw_keypoints = draw_keypoints 74 | self.kp_manager = PymunkKeypointManager( 75 | local_keypoint_map=local_keypoint_map, 76 | color_map=color_map) 77 | self.draw_kp_map = None 78 | 79 | @classmethod 80 | def genenerate_keypoint_manager_params(cls): 81 | env = PushTEnv() 82 | kp_manager = PymunkKeypointManager.create_from_pusht_env(env) 83 | kp_kwargs = kp_manager.kwargs 84 | return kp_kwargs 85 | 86 | def _get_obs(self): 87 | # get keypoints 88 | obj_map = { 89 | 'block': self.block 90 | } 91 | if self.agent_keypoints: 92 | obj_map['agent'] = self.agent 93 | 94 | kp_map = self.kp_manager.get_keypoints_global( 95 | pose_map=obj_map, is_obj=True) 96 | # python dict guerentee order of keys and values 97 | kps = np.concatenate(list(kp_map.values()), axis=0) 98 | 99 | # select keypoints to drop 100 | n_kps = kps.shape[0] 101 | visible_kps = self.np_random.random(size=(n_kps,)) < self.keypoint_visible_rate 102 | kps_mask = np.repeat(visible_kps[:,None], 2, axis=1) 103 | 104 | # save keypoints for rendering 105 | vis_kps = kps.copy() 106 | vis_kps[~visible_kps] = 0 107 | draw_kp_map = { 108 | 'block': vis_kps[:len(kp_map['block'])] 109 | } 110 | if self.agent_keypoints: 111 | draw_kp_map['agent'] = vis_kps[len(kp_map['block']):] 112 | self.draw_kp_map = draw_kp_map 113 | 114 | # construct obs 115 | obs = kps.flatten() 116 | obs_mask = kps_mask.flatten() 117 | if not self.agent_keypoints: 118 | # passing agent position when keypoints are not available 119 | agent_pos = np.array(self.agent.position) 120 | obs = np.concatenate([ 121 | obs, agent_pos 122 | ]) 123 | obs_mask = np.concatenate([ 124 | obs_mask, np.ones((2,), dtype=bool) 125 | ]) 126 | 127 | # TODO: new goal pose is quaternion 128 | goal = super()._get_info()['goal_pose'], 129 | goal = goal[0] 130 | 131 | obs = np.concatenate([ 132 | obs, goal 133 | ]) 134 | obs_mask = np.concatenate([ 135 | obs_mask, np.ones((3,), dtype=bool) 136 | ]) 137 | 138 | 139 | # obs, obs_mask 140 | obs = np.concatenate([ 141 | obs, obs_mask.astype(obs.dtype) 142 | ], axis=0) 143 | 144 | return obs 145 | 146 | 147 | def _render_frame(self, mode): 148 | img = super()._render_frame(mode) 149 | if self.draw_keypoints: 150 | self.kp_manager.draw_keypoints( 151 | img, self.pred_kps, radius=int(img.shape[0]/96)) 152 | return img 153 | -------------------------------------------------------------------------------- /VILP/env_runner/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def resize_images(batch_images, new_height, new_width): 5 | """ 6 | Resize a batch of images to a new height and width. 
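    Frames are resized one at a time with cv2.resize, which takes the target
    size as (width, height) and uses bilinear interpolation by default; e.g.
    an input of shape (2, 4, 3, 96, 96) resized to 64x64 comes back as
    (2, 4, 3, 64, 64).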
7 | 8 | Parameters: 9 | - batch_images: numpy array of shape (batch, steps, 3, H, W) 10 | - new_height: int, the new height of the images 11 | - new_width: int, the new width of the images 12 | 13 | Returns: 14 | - resized_images: numpy array of shape (batch, steps, 3, H2, W2) 15 | """ 16 | batch, steps, channels, _, _ = batch_images.shape 17 | resized_images = np.empty((batch, steps, channels, new_height, new_width), dtype=batch_images.dtype) 18 | 19 | for i in range(batch): 20 | for j in range(steps): 21 | image = batch_images[i, j, :, :, :].transpose(1, 2, 0) # Reshape to (H, W, C) 22 | resized_image = cv2.resize(image, (new_width, new_height)) 23 | resized_images[i, j, :, :, :] = resized_image.transpose(2, 0, 1) # Back to (C, H, W) 24 | 25 | return resized_images -------------------------------------------------------------------------------- /VILP/flowdiffusion/guided_diffusion/guided_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | 14 | # Change this to reflect your cluster layout. 15 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 16 | GPUS_PER_NODE = 8 17 | 18 | SETUP_RETRY_COUNT = 3 19 | 20 | 21 | def setup_dist(): 22 | """ 23 | Setup a distributed process group. 24 | """ 25 | if dist.is_initialized(): 26 | return 27 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}" 28 | 29 | comm = MPI.COMM_WORLD 30 | backend = "gloo" if not th.cuda.is_available() else "nccl" 31 | 32 | if backend == "gloo": 33 | hostname = "localhost" 34 | else: 35 | hostname = socket.gethostbyname(socket.getfqdn()) 36 | os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0) 37 | os.environ["RANK"] = str(comm.rank) 38 | os.environ["WORLD_SIZE"] = str(comm.size) 39 | 40 | port = comm.bcast(_find_free_port(), root=0) 41 | os.environ["MASTER_PORT"] = str(port) 42 | dist.init_process_group(backend=backend, init_method="env://") 43 | 44 | 45 | def dev(): 46 | """ 47 | Get the device to use for torch.distributed. 48 | """ 49 | if th.cuda.is_available(): 50 | return th.device(f"cuda") 51 | return th.device("cpu") 52 | 53 | 54 | def load_state_dict(path, **kwargs): 55 | """ 56 | Load a PyTorch file without redundant fetches across MPI ranks. 57 | """ 58 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 59 | if MPI.COMM_WORLD.Get_rank() == 0: 60 | with bf.BlobFile(path, "rb") as f: 61 | data = f.read() 62 | num_chunks = len(data) // chunk_size 63 | if len(data) % chunk_size: 64 | num_chunks += 1 65 | MPI.COMM_WORLD.bcast(num_chunks) 66 | for i in range(0, len(data), chunk_size): 67 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 68 | else: 69 | num_chunks = MPI.COMM_WORLD.bcast(None) 70 | data = bytes() 71 | for _ in range(num_chunks): 72 | data += MPI.COMM_WORLD.bcast(None) 73 | 74 | return th.load(io.BytesIO(data), **kwargs) 75 | 76 | 77 | def sync_params(params): 78 | """ 79 | Synchronize a sequence of Tensors across ranks from rank 0. 
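
    Each tensor is broadcast in place from rank 0; this is typically called
    once after model construction so that every rank starts from identical
    weights.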
80 | """ 81 | for p in params: 82 | with th.no_grad(): 83 | dist.broadcast(p, 0) 84 | 85 | 86 | def _find_free_port(): 87 | try: 88 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 89 | s.bind(("", 0)) 90 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 91 | return s.getsockname()[1] 92 | finally: 93 | s.close() 94 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/guided_diffusion/guided_diffusion/losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for various likelihood-based losses. These are ported from the original 3 | Ho et al. diffusion models codebase: 4 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py 5 | """ 6 | 7 | import numpy as np 8 | 9 | import torch as th 10 | 11 | 12 | def normal_kl(mean1, logvar1, mean2, logvar2): 13 | """ 14 | Compute the KL divergence between two gaussians. 15 | 16 | Shapes are automatically broadcasted, so batches can be compared to 17 | scalars, among other use cases. 18 | """ 19 | tensor = None 20 | for obj in (mean1, logvar1, mean2, logvar2): 21 | if isinstance(obj, th.Tensor): 22 | tensor = obj 23 | break 24 | assert tensor is not None, "at least one argument must be a Tensor" 25 | 26 | # Force variances to be Tensors. Broadcasting helps convert scalars to 27 | # Tensors, but it does not work for th.exp(). 28 | logvar1, logvar2 = [ 29 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 30 | for x in (logvar1, logvar2) 31 | ] 32 | 33 | return 0.5 * ( 34 | -1.0 35 | + logvar2 36 | - logvar1 37 | + th.exp(logvar1 - logvar2) 38 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 39 | ) 40 | 41 | 42 | def approx_standard_normal_cdf(x): 43 | """ 44 | A fast approximation of the cumulative distribution function of the 45 | standard normal. 46 | """ 47 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 48 | 49 | 50 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 51 | """ 52 | Compute the log-likelihood of a Gaussian distribution discretizing to a 53 | given image. 54 | 55 | :param x: the target images. It is assumed that this was uint8 values, 56 | rescaled to the range [-1, 1]. 57 | :param means: the Gaussian mean Tensor. 58 | :param log_scales: the Gaussian log stddev Tensor. 59 | :return: a tensor like x of log probabilities (in nats). 
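
    Note: the +/- 1.0/255 offsets used below are half a bin width, since
    8-bit pixel values rescaled to [-1, 1] land on a grid with spacing 2/255.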
60 | """ 61 | assert x.shape == means.shape == log_scales.shape 62 | centered_x = x - means 63 | inv_stdv = th.exp(-log_scales) 64 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 65 | cdf_plus = approx_standard_normal_cdf(plus_in) 66 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 67 | cdf_min = approx_standard_normal_cdf(min_in) 68 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 69 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 70 | cdf_delta = cdf_plus - cdf_min 71 | log_probs = th.where( 72 | x < -0.999, 73 | log_cdf_plus, 74 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 75 | ) 76 | assert log_probs.shape == x.shape 77 | return log_probs 78 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/guided_diffusion/guided_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 
79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 113 | 114 | This method directly updates the reweighting without synchronizing 115 | between workers. It is called by update_with_local_losses from all 116 | ranks with identical arguments. Thus, it should have deterministic 117 | behavior to maintain state across workers. 118 | 119 | :param ts: a list of int timesteps. 120 | :param losses: a list of float losses, one per timestep. 121 | """ 122 | 123 | 124 | class LossSecondMomentResampler(LossAwareSampler): 125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 126 | self.diffusion = diffusion 127 | self.history_per_term = history_per_term 128 | self.uniform_prob = uniform_prob 129 | self._loss_history = np.zeros( 130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 131 | ) 132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 133 | 134 | def weights(self): 135 | if not self._warmed_up(): 136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 138 | weights /= np.sum(weights) 139 | weights *= 1 - self.uniform_prob 140 | weights += self.uniform_prob / len(weights) 141 | return weights 142 | 143 | def update_with_all_losses(self, ts, losses): 144 | for t, loss in zip(ts, losses): 145 | if self._loss_counts[t] == self.history_per_term: 146 | # Shift out the oldest loss term. 
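                # (each timestep keeps a rolling buffer of its last
                # history_per_term losses; weights() turns their second
                # moment into sampling weights once the buffer is warm)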
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 148 | self._loss_history[t, -1] = loss 149 | else: 150 | self._loss_history[t, self._loss_counts[t]] = loss 151 | self._loss_counts[t] += 1 152 | 153 | def _warmed_up(self): 154 | return (self._loss_counts == self.history_per_term).all() 155 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/guided_diffusion/guided_diffusion/respace.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | from .gaussian_diffusion import GaussianDiffusion 5 | 6 | 7 | def space_timesteps(num_timesteps, section_counts): 8 | """ 9 | Create a list of timesteps to use from an original diffusion process, 10 | given the number of timesteps we want to take from equally-sized portions 11 | of the original process. 12 | 13 | For example, if there's 300 timesteps and the section counts are [10,15,20] 14 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 15 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 16 | 17 | If the stride is a string starting with "ddim", then the fixed striding 18 | from the DDIM paper is used, and only one section is allowed. 19 | 20 | :param num_timesteps: the number of diffusion steps in the original 21 | process to divide up. 22 | :param section_counts: either a list of numbers, or a string containing 23 | comma-separated numbers, indicating the step count 24 | per section. As a special case, use "ddimN" where N 25 | is a number of steps to use the striding from the 26 | DDIM paper. 27 | :return: a set of diffusion steps from the original process to use. 28 | """ 29 | if isinstance(section_counts, str): 30 | if section_counts.startswith("ddim"): 31 | desired_count = int(section_counts[len("ddim") :]) 32 | for i in range(1, num_timesteps): 33 | if len(range(0, num_timesteps, i)) == desired_count: 34 | return set(range(0, num_timesteps, i)) 35 | raise ValueError( 36 | f"cannot create exactly {num_timesteps} steps with an integer stride" 37 | ) 38 | section_counts = [int(x) for x in section_counts.split(",")] 39 | size_per = num_timesteps // len(section_counts) 40 | extra = num_timesteps % len(section_counts) 41 | start_idx = 0 42 | all_steps = [] 43 | for i, section_count in enumerate(section_counts): 44 | size = size_per + (1 if i < extra else 0) 45 | if size < section_count: 46 | raise ValueError( 47 | f"cannot divide section of {size} steps into {section_count}" 48 | ) 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | 67 | :param use_timesteps: a collection (sequence or set) of timesteps from the 68 | original diffusion process to retain. 69 | :param kwargs: the kwargs to create the base diffusion process. 
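
    Typical construction (sketch; the base kwargs, including betas, come from
    the caller's diffusion config):
        SpacedDiffusion(use_timesteps=space_timesteps(1000, "ddim50"),
                        betas=betas, ...)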
70 | """ 71 | 72 | def __init__(self, use_timesteps, **kwargs): 73 | self.use_timesteps = set(use_timesteps) 74 | self.timestep_map = [] 75 | self.original_num_steps = len(kwargs["betas"]) 76 | 77 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 78 | last_alpha_cumprod = 1.0 79 | new_betas = [] 80 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 81 | if i in self.use_timesteps: 82 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 83 | last_alpha_cumprod = alpha_cumprod 84 | self.timestep_map.append(i) 85 | kwargs["betas"] = np.array(new_betas) 86 | super().__init__(**kwargs) 87 | 88 | def p_mean_variance( 89 | self, model, *args, **kwargs 90 | ): # pylint: disable=signature-differs 91 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 92 | 93 | def training_losses( 94 | self, model, *args, **kwargs 95 | ): # pylint: disable=signature-differs 96 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 97 | 98 | def condition_mean(self, cond_fn, *args, **kwargs): 99 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 100 | 101 | def condition_score(self, cond_fn, *args, **kwargs): 102 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 103 | 104 | def _wrap_model(self, model): 105 | if isinstance(model, _WrappedModel): 106 | return model 107 | return _WrappedModel( 108 | model, self.timestep_map, self.rescale_timesteps, self.original_num_steps 109 | ) 110 | 111 | def _scale_timesteps(self, t): 112 | # Scaling is done by the wrapped model. 113 | return t 114 | 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | if self.rescale_timesteps: 127 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .myunet import Unet3D0409, Unet3D0410 -------------------------------------------------------------------------------- /VILP/flowdiffusion/model/myunet.py: -------------------------------------------------------------------------------- 1 | from .imagen import Unet3D 2 | from pynvml import * 3 | import torch 4 | import torch.nn as nn 5 | 6 | def print_gpu_utilization(): 7 | nvmlInit() 8 | handle = nvmlDeviceGetHandleByIndex(0) 9 | info = nvmlDeviceGetMemoryInfo(handle) 10 | print(f"GPU memory occupied: {info.used//1024**2} MB.") 11 | 12 | class Unet3D0409(nn.Module): 13 | def __init__(self): 14 | super(Unet3D0409, self).__init__() 15 | self.unet = Unet3D( 16 | dim=128, 17 | text_embed_dim = 512, 18 | attn_dim_head = 64, 19 | attn_heads = 8, 20 | ff_mult = 2., 21 | cond_images_channels = 3, 22 | channels = 3, 23 | dim_mults = (1, 2, 4, 8), 24 | ff_time_token_shift = True, # this would do a token shift along time axis, at the hidden layer within feedforwards - from successful use in RWKV (Peng et al), and other token shift video transformer works 25 | lowres_cond = False, # for cascading diffusion - 
https://cascaded-diffusion.github.io/ 26 | layer_attns = False, 27 | layer_attns_depth = 1, 28 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 29 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) 30 | time_rel_pos_bias_depth = 2, 31 | time_causal_attn = True, 32 | layer_cross_attns = True, 33 | use_linear_attn = False, 34 | use_linear_cross_attn = False, 35 | cond_on_text = True, 36 | max_text_len = 32, 37 | memory_efficient = True, 38 | final_conv_kernel_size = 3, 39 | self_cond = False, 40 | combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully 41 | pixel_shuffle_upsample = True, # may address checkboard artifacts 42 | resize_mode = 'nearest' 43 | ) 44 | 45 | def forward(self, x, t, x_cond, text_embed=None, **kwargs): 46 | x = self.unet(x, t, cond_images=x_cond, text_embeds=text_embed, **kwargs) 47 | return x 48 | 49 | class BaseUnet64(Unet3D): 50 | def __init__(self, *args, **kwargs): 51 | default_kwargs = dict( 52 | dim = 160, 53 | dim_mults = (1, 2, 3, 4), 54 | num_resnet_blocks = 3, 55 | layer_attns = (False, True, True, True), 56 | layer_cross_attns = (False, True, True, True), 57 | attn_heads = 4, 58 | ff_mult = 2., 59 | memory_efficient = True, 60 | cond_images_channels=3, 61 | ) 62 | super().__init__(*args, **{**default_kwargs, **kwargs}) 63 | 64 | class Unet3D0410(nn.Module): 65 | def __init__(self): 66 | super(Unet3D0410, self).__init__() 67 | self.unet = BaseUnet64( 68 | time_causal_attn = False, 69 | ) 70 | 71 | def forward(self, x, t, x_cond, text_embed=None, **kwargs): 72 | x = self.unet(x, t, cond_images=x_cond, text_embeds=text_embed, **kwargs) 73 | return x 74 | 75 | def forward_with_cond_scale(self, x, t, x_cond, text_embed=None, cond_scale=1.0, **kwargs): 76 | x = self.unet.forward_with_cond_scale(x, t, cond_images=x_cond, text_embeds=text_embed, cond_scale=cond_scale, **kwargs) 77 | return x 78 | 79 | 80 | 81 | 82 | 83 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/train_bridge.py: -------------------------------------------------------------------------------- 1 | from goal_diffusion import GoalGaussianDiffusion, Trainer 2 | from unet import UnetBridge as Unet 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | import torch 5 | import numpy as np 6 | import os 7 | from torch.utils.data import Dataset, DataLoader, Subset 8 | import torch.nn as nn 9 | import json 10 | from PIL import Image 11 | import tqdm 12 | from accelerate import Accelerator 13 | from datasets import SequentialDatasetNp, SequentialDatasetVal 14 | import argparse 15 | 16 | 17 | def main(args): 18 | valid_n = 1 19 | sample_per_seq = 7 20 | target_size = (48, 64) 21 | 22 | if args.mode == 'inference': 23 | train_set = valid_set = [None] # dummy 24 | else: 25 | train_set = SequentialDatasetNp( 26 | sample_per_seq=sample_per_seq, 27 | path="../datasets/bridge/numpy/bridge_data_v1/berkeley/", 28 | target_size=target_size, 29 | debug=False, 30 | ) 31 | valid_inds = [i for i in range(0, len(train_set), len(train_set)//valid_n)][:valid_n] 32 | valid_set = Subset(train_set, valid_inds) 33 | unet = Unet() 34 | 35 | pretrained_model = "openai/clip-vit-base-patch32" 36 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model) 37 | text_encoder = 
CLIPTextModel.from_pretrained(pretrained_model) 38 | text_encoder.requires_grad_(False) 39 | text_encoder.eval() 40 | 41 | diffusion = GoalGaussianDiffusion( 42 | channels=3*(sample_per_seq-1), 43 | model=unet, 44 | image_size=target_size, 45 | timesteps=100, 46 | sampling_timesteps=args.sample_steps, 47 | loss_type='l2', 48 | objective='pred_v', 49 | beta_schedule = 'cosine', 50 | min_snr_loss_weight = True, 51 | ) 52 | 53 | trainer = Trainer( 54 | diffusion_model=diffusion, 55 | tokenizer=tokenizer, 56 | text_encoder=text_encoder, 57 | train_set=train_set, 58 | valid_set=valid_set, 59 | train_lr=1e-4, 60 | train_num_steps =180000, 61 | save_and_sample_every =4000, 62 | ema_update_every = 10, 63 | ema_decay = 0.999, 64 | train_batch_size =32, 65 | valid_batch_size =valid_n, 66 | gradient_accumulate_every = 1, 67 | num_samples=30, 68 | results_folder ='../results/bridge', 69 | fp16 =True, 70 | amp=True, 71 | ) 72 | 73 | if args.checkpoint_num is not None: 74 | trainer.load(args.checkpoint_num) 75 | 76 | if args.mode == 'train': 77 | trainer.train() 78 | else: 79 | from PIL import Image 80 | from torchvision import transforms 81 | import imageio 82 | import torch 83 | from os.path import splitext 84 | text = args.text 85 | image = Image.open(args.inference_path) 86 | batch_size = 1 87 | transform = transforms.Compose([ 88 | transforms.Resize(target_size), 89 | transforms.ToTensor(), 90 | ]) 91 | image = transform(image) 92 | output = trainer.sample(image.unsqueeze(0), [text], batch_size).cpu() 93 | output = output[0].reshape(-1, 3, *target_size) 94 | output = torch.cat([image.unsqueeze(0), output], dim=0) 95 | root, ext = splitext(args.inference_path) 96 | output_gif = root + '_out.gif' 97 | output = (output.cpu().numpy().transpose(0, 2, 3, 1).clip(0, 1) * 255).astype('uint8') 98 | 99 | ## 231130 resize output image to 240x320 to make it look better 100 | output = [np.array(Image.fromarray(frame).resize((320, 240))) for frame in output] 101 | 102 | imageio.mimsave(output_gif, output, duration=200, loop=1000) 103 | print(f'Generated {output_gif}') 104 | 105 | if __name__ == "__main__": 106 | parser = argparse.ArgumentParser() 107 | parser.add_argument('-m', '--mode', type=str, default='train', choices=['train', 'inference']) # 'train for training, 'inference' for generating samples 108 | parser.add_argument('-c', '--checkpoint_num', type=int, default=None) # checkpoint number to resume training or generate samples 109 | parser.add_argument('-p', '--inference_path', type=str, default=None) # path to input image 110 | parser.add_argument('-t', '--text', type=str, default=None) # task text 111 | parser.add_argument('-g', '--guidance_weight', type=int, default=0) # set to positive to use guidance 112 | args = parser.parse_args() 113 | if args.mode == 'inference': 114 | assert args.checkpoint_num is not None 115 | assert args.inference_path is not None 116 | assert args.text is not None 117 | assert args.sample_steps <= 100 118 | main(args) 119 | 120 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/train_mw.py: -------------------------------------------------------------------------------- 1 | from VILP.flowdiffusion.goal_diffusion import GoalGaussianDiffusion, Trainer 2 | from VILP.flowdiffusion.unet import UnetMW as Unet 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from datasets import SequentialDatasetv2 5 | from torch.utils.data import Subset 6 | import argparse 7 | 8 | 9 | def main(args): 10 | valid_n = 1 11 | 
sample_per_seq = 8 12 | target_size = (128, 128) 13 | 14 | if args.mode == 'inference': 15 | train_set = valid_set = [None] # dummy 16 | else: 17 | train_set = SequentialDatasetv2( 18 | sample_per_seq=sample_per_seq, 19 | path="../datasets/metaworld", 20 | target_size=target_size, 21 | randomcrop=True 22 | ) 23 | valid_inds = [i for i in range(0, len(train_set), len(train_set)//valid_n)][:valid_n] 24 | valid_set = Subset(train_set, valid_inds) 25 | 26 | unet = Unet() 27 | 28 | pretrained_model = "openai/clip-vit-base-patch32" 29 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model) 30 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model) 31 | text_encoder.requires_grad_(False) 32 | text_encoder.eval() 33 | 34 | diffusion = GoalGaussianDiffusion( 35 | channels=3*(sample_per_seq-1), 36 | model=unet, 37 | image_size=target_size, 38 | timesteps=100, 39 | sampling_timesteps=args.sample_steps, 40 | loss_type='l2', 41 | objective='pred_v', 42 | beta_schedule = 'cosine', 43 | min_snr_loss_weight = True, 44 | ) 45 | 46 | trainer = Trainer( 47 | diffusion_model=diffusion, 48 | tokenizer=tokenizer, 49 | text_encoder=text_encoder, 50 | train_set=train_set, 51 | valid_set=valid_set, 52 | train_lr=1e-4, 53 | train_num_steps =60000, 54 | save_and_sample_every =2500, 55 | ema_update_every = 10, 56 | ema_decay = 0.999, 57 | train_batch_size =16, 58 | valid_batch_size =32, 59 | gradient_accumulate_every = 1, 60 | num_samples=valid_n, 61 | results_folder ='../results/mw', 62 | fp16 =True, 63 | amp=True, 64 | ) 65 | 66 | #if args.checkpoint_num is not None: 67 | # trainer.load(args.checkpoint_num) 68 | 69 | if args.mode == 'train': 70 | trainer.train() 71 | else: 72 | from PIL import Image 73 | from torchvision import transforms 74 | import imageio 75 | import torch 76 | from os.path import splitext 77 | text = args.text 78 | guidance_weight = args.guidance_weight 79 | image = Image.open(args.inference_path) 80 | batch_size = 1 81 | ### 231130 fixed center crop issue 82 | transform = transforms.Compose([ 83 | transforms.Resize((240, 320)), 84 | transforms.CenterCrop(target_size), 85 | transforms.ToTensor(), 86 | ]) 87 | image = transform(image) 88 | output = trainer.sample(image.unsqueeze(0), [text], batch_size, guidance_weight).cpu() 89 | 90 | output = output[0].reshape(-1, 3, *target_size) 91 | 92 | output = torch.cat([image.unsqueeze(0), output], dim=0) 93 | 94 | root, ext = splitext(args.inference_path) 95 | output_gif = root + '_out.gif' 96 | output = (output.cpu().numpy().transpose(0, 2, 3, 1).clip(0, 1) * 255).astype('uint8') 97 | imageio.mimsave(output_gif, output, duration=200, loop=1000) 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument('-m', '--mode', type=str, default='train', choices=['train', 'inference']) # set to 'inference' to generate samples 103 | parser.add_argument('-c', '--checkpoint_num', type=int, default=None) # set to checkpoint number to resume training or generate samples 104 | parser.add_argument('-p', '--inference_path', type=str, default=None) # set to path to generate samples 105 | parser.add_argument('-t', '--text', type=str, default=None) # set to text to generate samples 106 | parser.add_argument('-n', '--sample_steps', type=int, default=100) # set to number of steps to sample 107 | parser.add_argument('-g', '--guidance_weight', type=int, default=0) # set to positive to use guidance 108 | args = parser.parse_args() 109 | ''' 110 | if args.mode == 'inference': 111 | assert args.checkpoint_num is 
not None 112 | assert args.inference_path is not None 113 | assert args.text is not None 114 | assert args.sample_steps <= 100 115 | ''' 116 | main(args) -------------------------------------------------------------------------------- /VILP/flowdiffusion/train_thor.py: -------------------------------------------------------------------------------- 1 | from goal_diffusion import GoalGaussianDiffusion, Trainer 2 | from unet import UnetThor as Unet 3 | from transformers import CLIPTextModel, CLIPTokenizer 4 | from datasets import SequentialNavDataset 5 | from torch.utils.data import Subset 6 | import argparse 7 | 8 | 9 | def main(args): 10 | valid_n = 1 11 | sample_per_seq = 8 12 | target_size = (64, 64) 13 | 14 | if args.mode == 'inference': 15 | train_set = valid_set = [None] # dummy 16 | else: 17 | train_set = SequentialNavDataset( 18 | sample_per_seq=sample_per_seq, 19 | path="../datasets/thor", 20 | target_size=target_size, 21 | ) 22 | valid_inds = [i for i in range(0, len(train_set), len(train_set)//valid_n)][:valid_n] 23 | valid_set = Subset(train_set, valid_inds) 24 | unet = Unet() 25 | 26 | pretrained_model = "openai/clip-vit-base-patch32" 27 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model) 28 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model) 29 | text_encoder.requires_grad_(False) 30 | text_encoder.eval() 31 | 32 | diffusion = GoalGaussianDiffusion( 33 | channels=3*(sample_per_seq-1), 34 | model=unet, 35 | image_size=target_size, 36 | timesteps=100, 37 | sampling_timesteps=args.sample_steps, 38 | loss_type='l2', 39 | objective='pred_v', 40 | beta_schedule = 'cosine', 41 | min_snr_loss_weight = True, 42 | ) 43 | 44 | trainer = Trainer( 45 | diffusion_model=diffusion, 46 | tokenizer=tokenizer, 47 | text_encoder=text_encoder, 48 | train_set=train_set, 49 | valid_set=valid_set, 50 | train_lr=1e-4, 51 | train_num_steps =80000, 52 | save_and_sample_every =5000, 53 | ema_update_every = 10, 54 | ema_decay = 0.999, 55 | train_batch_size =32, 56 | valid_batch_size =32, 57 | gradient_accumulate_every = 1, 58 | num_samples=valid_n, 59 | results_folder ='../results/thor', 60 | fp16 =True, 61 | amp=True, 62 | ) 63 | 64 | if args.checkpoint_num is not None: 65 | trainer.load(args.checkpoint_num) 66 | 67 | if args.mode == 'train': 68 | trainer.train() 69 | else: 70 | from PIL import Image 71 | from torchvision import transforms 72 | import imageio 73 | import torch 74 | from os.path import splitext 75 | text = args.text 76 | image = Image.open(args.inference_path) 77 | batch_size = 1 78 | transform = transforms.Compose([ 79 | transforms.Resize(target_size), 80 | transforms.ToTensor(), 81 | ]) 82 | image = transform(image) 83 | output = trainer.sample(image.unsqueeze(0), [text], batch_size).cpu() 84 | output = output[0].reshape(-1, 3, *target_size) 85 | output = torch.cat([image.unsqueeze(0), output], dim=0) 86 | root, ext = splitext(args.inference_path) 87 | output_gif = root + '_out.gif' 88 | output = (output.cpu().numpy().transpose(0, 2, 3, 1).clip(0, 1) * 255).astype('uint8') 89 | imageio.mimsave(output_gif, output, duration=200, loop=1000) 90 | print(f'Generated {output_gif}') 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument('-m', '--mode', type=str, default='train', choices=['train', 'inference']) # set to 'inference' to generate samples 95 | parser.add_argument('-c', '--checkpoint_num', type=int, default=None) # set to checkpoint number to resume training or generate samples 96 | parser.add_argument('-p', 
'--inference_path', type=str, default=None) # set to path to generate samples 97 | parser.add_argument('-t', '--text', type=str, default=None) # set to text to generate samples 98 | parser.add_argumant('-n', '--sample_steps', type=int, default=100) # set to number of steps to sample 99 | parser.add_argument('-g', '--guidance_weight', type=int, default=0) # set to positive to use guidance 100 | args = parser.parse_args() 101 | if args.mode == 'inference': 102 | assert args.checkpoint_num is not None 103 | assert args.inference_path is not None 104 | assert args.text is not None 105 | assert args.sample_steps <= 100 106 | main(args) -------------------------------------------------------------------------------- /VILP/flowdiffusion/unet.py: -------------------------------------------------------------------------------- 1 | from VILP.flowdiffusion.guided_diffusion.guided_diffusion.unet import UNetModel 2 | from torch import nn 3 | import torch 4 | from einops import repeat, rearrange 5 | 6 | 7 | class UnetBridge(nn.Module): 8 | def __init__(self): 9 | super(UnetBridge, self).__init__() 10 | 11 | self.unet = UNetModel( 12 | image_size=(48, 64), 13 | in_channels=6, 14 | model_channels=160, 15 | out_channels=3, 16 | num_res_blocks=3, 17 | attention_resolutions=(4, 8), 18 | dropout=0, 19 | channel_mult=(1, 2, 4), 20 | conv_resample=True, 21 | dims=3, 22 | num_classes=None, 23 | task_tokens=True, 24 | task_token_channels=512, 25 | use_checkpoint=False, 26 | use_fp16=False, 27 | num_head_channels=32, 28 | ) 29 | self.unet.convert_to_fp32() 30 | 31 | def forward(self, x, t, task_embed=None, **kwargs): 32 | f = x.shape[1] // 3 - 1 33 | x_cond = repeat(x[:, -3:], 'b c h w -> b c f h w', f=f) 34 | x = rearrange(x[:, :-3], 'b (f c) h w -> b c f h w', c=3) 35 | x = torch.cat([x, x_cond], dim=1) 36 | out = self.unet(x, t, task_embed, **kwargs) 37 | return rearrange(out, 'b c f h w -> b (f c) h w') 38 | 39 | class UnetMW(nn.Module): 40 | def __init__(self): 41 | super(UnetMW, self).__init__() 42 | self.unet = UNetModel( 43 | image_size=(128, 128), 44 | in_channels=6, 45 | #model_channels=128, 46 | model_channels=64, 47 | out_channels=3, 48 | num_res_blocks=2, 49 | #attention_resolutions=(8, 16), 50 | attention_resolutions=(4, 8), 51 | dropout=0, 52 | #channel_mult=(1, 2, 3, 4, 5), 53 | channel_mult=(1, 2, 4), 54 | conv_resample=True, 55 | dims=3, 56 | num_classes=None, 57 | task_tokens=True, 58 | task_token_channels=512, 59 | use_checkpoint=False, 60 | use_fp16=False, 61 | #num_head_channels=32, 62 | num_head_channels=8, 63 | ) 64 | def forward(self, x, t, task_embed=None, **kwargs): 65 | f = x.shape[1] // 3 - 1 66 | x_cond = repeat(x[:, -3:], 'b c h w -> b c f h w', f=f) 67 | x = rearrange(x[:, :-3], 'b (f c) h w -> b c f h w', c=3) 68 | x = torch.cat([x, x_cond], dim=1) 69 | out = self.unet(x, t, task_embed, **kwargs) 70 | return rearrange(out, 'b c f h w -> b (f c) h w') 71 | 72 | class UnetMW_flow(nn.Module): 73 | def __init__(self): 74 | super(UnetMW_flow, self).__init__() 75 | self.unet = UNetModel( 76 | image_size=(128, 128), 77 | in_channels=5, 78 | model_channels=128, 79 | out_channels=2, 80 | num_res_blocks=2, 81 | attention_resolutions=(8, 16), 82 | dropout=0, 83 | channel_mult=(1, 2, 3, 4, 5), 84 | conv_resample=True, 85 | dims=3, 86 | num_classes=None, 87 | task_tokens=True, 88 | task_token_channels=512, 89 | use_checkpoint=False, 90 | use_fp16=False, 91 | num_head_channels=32, 92 | ) 93 | def forward(self, x, t, task_embed=None, **kwargs): 94 | f = x.shape[1] // 2 - 1 95 | x_cond = 
repeat(x[:, -3:], 'b c h w -> b c f h w', f=f) 96 | x = rearrange(x[:, :-3], 'b (f c) h w -> b c f h w', f=f) 97 | x = torch.cat([x, x_cond], dim=1) 98 | 99 | # For VILP. experiments, set task_embed to None globally 100 | task_embed = None 101 | 102 | out = self.unet(x, t, task_embed, **kwargs) 103 | return rearrange(out, 'b c f h w -> b (f c) h w') 104 | 105 | class UnetThor(nn.Module): 106 | def __init__(self): 107 | super(UnetThor, self).__init__() 108 | 109 | self.unet = UNetModel( 110 | image_size=(64, 64), 111 | in_channels=6, 112 | model_channels=128, 113 | out_channels=3, 114 | num_res_blocks=3, 115 | attention_resolutions=(4, 8), 116 | dropout=0, 117 | channel_mult=(1, 2, 4), 118 | conv_resample=True, 119 | dims=3, 120 | num_classes=None, 121 | task_tokens=True, 122 | task_token_channels=512, 123 | use_checkpoint=False, 124 | use_fp16=False, 125 | num_head_channels=32, 126 | ) 127 | self.unet.convert_to_fp32() 128 | 129 | def forward(self, x, t, task_embed=None, **kwargs): 130 | f = x.shape[1] // 3 - 1 131 | x_cond = repeat(x[:, -3:], 'b c h w -> b c f h w', f=f) 132 | x = rearrange(x[:, :-3], 'b (f c) h w -> b c f h w', c=3) 133 | x = torch.cat([x, x_cond], dim=1) 134 | out = self.unet(x, t, task_embed, **kwargs) 135 | return rearrange(out, 'b c f h w -> b (f c) h w') 136 | 137 | 138 | -------------------------------------------------------------------------------- /VILP/flowdiffusion/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | def get_paths(root="../berkeley"): 5 | f = [] 6 | for dirpath, dirname, filename in os.walk(root): 7 | if "image" in dirpath: 8 | f.append(dirpath) 9 | print(f"Found {len(f)} sequences") 10 | return f 11 | 12 | def get_paths_from_dir(dir_path): 13 | paths = glob.glob(os.path.join(dir_path, 'im*.jpg')) 14 | try: 15 | paths = sorted(paths, key=lambda x: int((x.split('/')[-1].split('.')[0])[3:])) 16 | except: 17 | print(paths) 18 | return paths 19 | 20 | -------------------------------------------------------------------------------- /VILP/model/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengtongXu/VILP/fea07284c61915da258cd600f940544fd0960e5f/VILP/model/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /VILP/model/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengtongXu/VILP/fea07284c61915da258cd600f940544fd0960e5f/VILP/model/modules/distributions/__init__.py -------------------------------------------------------------------------------- /VILP/model/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | 
self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /VILP/model/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
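
        Typical evaluation pattern (sketch):
            ema.store(model.parameters())
            ema.copy_to(model)
            ...  # validate or save checkpoints with EMA weights
            ema.restore(model.parameters())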
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /VILP/model/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengtongXu/VILP/fea07284c61915da258cd600f940544fd0960e5f/VILP/model/modules/encoders/__init__.py -------------------------------------------------------------------------------- /VILP/model/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /VILP/model/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengtongXu/VILP/fea07284c61915da258cd600f940544fd0960e5f/VILP/model/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /VILP/model/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /VILP/model/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from VILP.taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
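# The wildcard import above is relied on for the helper names used below
# (LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss,
# adopt_weight), which vqperceptual.py defines or re-exports.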
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /VILP/model/spatial_mask_generator_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | class MultiDimMaskGenerator(Module): 5 | def __init__(self, 6 | action_dims, # Tuple of (action_dim1, action_dim2, action_dim3) 7 | max_n_obs_steps=2, 8 | fix_obs_steps=True, 9 | action_visible=False, 10 | device='cuda:1'): 11 | super().__init__() 12 | self.action_dims = action_dims 13 | self.max_n_obs_steps = max_n_obs_steps 14 | self.fix_obs_steps = fix_obs_steps 15 | self.action_visible = action_visible 16 | self.device = device 17 | @torch.no_grad() 18 | def forward(self, shape, seed=None): 19 | device = self.device 20 | B, T, D1, D2, D3 = shape 21 | assert (D1, D2, D3) == (self.action_dims[0], 22 | self.action_dims[1], 23 | self.action_dims[2]) 24 | 25 | rng = torch.Generator(device=device) 26 | if seed is not None: 27 | rng.manual_seed(seed) 28 | 29 | dim_mask = torch.zeros(size=(D1, D2, D3), dtype=torch.bool, device=device) 30 | is_action_dim = dim_mask.clone() 31 | 32 | # Assuming action_dims and obs_dims cover different parts of the D1, D2, D3 dimensions entirely 33 | # Update the action mask to True for action dimensions 34 | is_action_dim[:self.action_dims[0], :self.action_dims[1], :self.action_dims[2]] = True 35 | 36 | # The observation dimensions are the complement of the action dimensions 37 | is_obs_dim = ~is_action_dim 38 | 39 | # Extend is_action_dim and is_obs_dim to match the input shape (B, T, D1, D2, D3) 40 | is_action_dim = is_action_dim.unsqueeze(0).unsqueeze(0).expand(B, T, -1, -1, -1) 41 | is_obs_dim = is_obs_dim.unsqueeze(0).unsqueeze(0).expand(B, T, -1, -1, -1) 42 | 43 | # Determine the number of observation steps for each batch 44 | obs_steps = torch.full((B,), fill_value=self.max_n_obs_steps, device=device) 45 | 46 | 47 | steps = torch.arange(0, T, device=device).expand(B, T) 48 | obs_time_mask = (steps.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) < obs_steps.view(-1, 1, 1, 1, 1)) 49 | obs_time_mask = obs_time_mask.expand(-1, -1, D1, D2, D3) 50 | 51 | # Apply observation dimension mask 52 | # Note: is_obs_dim should be defined similarly to is_action_dim in the adjusted 
3D context 53 | obs_mask = obs_time_mask & is_obs_dim 54 | 55 | mask = obs_mask 56 | 57 | return mask -------------------------------------------------------------------------------- /VILP/model/vision/model_getter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | 4 | def get_resnet(name, weights=None, **kwargs): 5 | """ 6 | name: resnet18, resnet34, resnet50 7 | weights: "IMAGENET1K_V1", "r3m" 8 | """ 9 | # load r3m weights 10 | if (weights == "r3m") or (weights == "R3M"): 11 | return get_r3m(name=name, **kwargs) 12 | 13 | func = getattr(torchvision.models, name) 14 | resnet = func(weights=weights, **kwargs) 15 | resnet.fc = torch.nn.Identity() 16 | return resnet 17 | 18 | def get_r3m(name, **kwargs): 19 | """ 20 | name: resnet18, resnet34, resnet50 21 | """ 22 | import r3m 23 | r3m.device = 'cpu' 24 | model = r3m.load_r3m(name) 25 | r3m_model = model.module 26 | resnet_model = r3m_model.convnet 27 | resnet_model = resnet_model.to('cpu') 28 | return resnet_model 29 | -------------------------------------------------------------------------------- /VILP/policy/action_mapping_itp.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange, reduce 6 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler 7 | 8 | from VILP.model.normalizer import LinearNormalizer 9 | from diffusion_policy.policy.base_image_policy import BaseImagePolicy 10 | from diffusion_policy.model.diffusion.conditional_unet1d import ConditionalUnet1D 11 | from diffusion_policy.model.diffusion.mask_generator import LowdimMaskGenerator 12 | from VILP.model.vision.multi_image_obs_encoder import MultiImageObsEncoder 13 | from diffusion_policy.common.pytorch_util import dict_apply 14 | 15 | 16 | class MlpDecoder(nn.Module): 17 | def __init__(self, 18 | input_dim, 19 | output_dim, 20 | hidden_dims=(256,256), 21 | activation=nn.ReLU(), 22 | output_activation=None): 23 | super().__init__() 24 | dims = [input_dim, *hidden_dims, output_dim] 25 | layers = [] 26 | for i in range(len(dims) - 1): 27 | layers.append(nn.Linear(dims[i], dims[i+1])) 28 | if i < len(dims) - 2: 29 | layers.append(activation) 30 | if output_activation is not None: 31 | layers.append(output_activation) 32 | self.layers = nn.Sequential(*layers) 33 | 34 | def forward(self, x): 35 | return self.layers(x) 36 | class ActionMappingItp(BaseImagePolicy): 37 | def __init__(self, 38 | shape_meta: dict, 39 | obs_encoder: MultiImageObsEncoder, 40 | horizon = 1, 41 | # parameters passed to step 42 | **kwargs): 43 | super().__init__() 44 | 45 | # parse shapes 46 | action_shape = shape_meta['action']['shape'] 47 | assert len(action_shape) == 1 48 | action_dim = action_shape[0] 49 | # get feature dim 50 | obs_feature_dim = obs_encoder.output_shape()[0] 51 | 52 | self.obs_encoder = obs_encoder 53 | self.normalizer = LinearNormalizer() 54 | self.horizon = horizon 55 | self.obs_feature_dim = obs_feature_dim 56 | self.action_dim = action_dim 57 | self.n_obs_steps = 2 58 | 59 | self.kwargs = kwargs 60 | 61 | self.decoder = MlpDecoder(obs_feature_dim*2, action_dim*horizon) 62 | 63 | 64 | 65 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 66 | """ 67 | obs_dict: must include "obs" key 68 | result: must include "action" key 69 | """ 70 | assert 'past_action' not in obs_dict # not implemented 
yet 71 | # normalize input 72 | nobs = self.normalizer.normalize(obs_dict) 73 | value = next(iter(nobs.values())) 74 | B, To = value.shape[:2] 75 | T = self.horizon 76 | Da = self.action_dim 77 | Do = self.obs_feature_dim 78 | To = self.n_obs_steps 79 | 80 | # condition through global feature 81 | this_nobs = dict_apply(nobs, lambda x: x[:,:To,...].reshape(-1,*x.shape[2:])) 82 | nobs_features = self.obs_encoder(this_nobs) 83 | # reshape back to B, Do 84 | global_cond = nobs_features.reshape(B, -1) 85 | action = self.decoder(global_cond).reshape(B, T, Da) 86 | result = { 87 | 'action': self.normalizer['action'].unnormalize(action), 88 | } 89 | return result 90 | 91 | # ========= training ============ 92 | def set_normalizer(self, normalizer: LinearNormalizer): 93 | self.normalizer.load_state_dict(normalizer.state_dict()) 94 | 95 | def compute_loss(self, batch): 96 | # normalize input 97 | assert 'valid_mask' not in batch 98 | nobs = self.normalizer.normalize(batch['obs']) 99 | nactions = self.normalizer['action'].normalize(batch['action']) 100 | batch_size = nactions.shape[0] 101 | 102 | this_nobs = dict_apply(nobs, 103 | lambda x: x[:,:self.n_obs_steps,...].reshape(-1,*x.shape[2:])) 104 | nobs_features = self.obs_encoder(this_nobs) 105 | # reshape back to B, Do 106 | global_cond = nobs_features.reshape(batch_size, -1) 107 | 108 | pred = self.decoder(global_cond) 109 | 110 | # naction, from B,1,Da to B,Da 111 | nactions = nactions[:,:self.horizon,:] 112 | nactions = nactions.reshape(-1, self.action_dim*self.horizon) 113 | loss = F.mse_loss(pred, nactions, reduction='none') 114 | loss = reduce(loss, 'b ... -> b (...)', 'mean') 115 | loss = loss.mean() 116 | return loss -------------------------------------------------------------------------------- /VILP/policy/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | def interpolate_tensor(input_tensor, n): 3 | 4 | batch_size, steps, dim = input_tensor.shape 5 | new_steps = steps + (steps - 1) * n 6 | output_tensor = torch.zeros((batch_size, new_steps, dim), device=input_tensor.device, dtype=input_tensor.dtype) 7 | output_tensor[:, ::n+1, :] = input_tensor 8 | 9 | for i in range(n): 10 | alpha = (i + 1) / (n + 1) 11 | output_tensor[:, 1+i::n+1, :] = alpha * input_tensor[:, 1:, :] + (1 - alpha) * input_tensor[:, :-1, :] 12 | 13 | return output_tensor -------------------------------------------------------------------------------- /VILP/policy/vilp_low_level_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from einops import rearrange 4 | from VILP.model.normalizer import LinearNormalizer 5 | from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy 6 | from VILP.taming.models.vqgan import VQModel 7 | from VILP.policy.latent_video_diffusion import LatentVideoDiffusion 8 | from VILP.policy.action_mapping_itp import ActionMappingItp 9 | from VILP.policy.utils import interpolate_tensor 10 | from typing import List 11 | 12 | class VilpLowLevelPolicy(BaseLowdimPolicy): 13 | def __init__(self, 14 | model_low_level: ActionMappingItp, 15 | planner_paths: List[str], 16 | planners: List[LatentVideoDiffusion], 17 | vqgan_configs: List[dict], 18 | keys: List[str], 19 | subgoal_steps=4, 20 | subgoal_interval=10, 21 | obs_steps=2, 22 | latent_shape=[2, 12, 12], 23 | n_action_steps_rollout=5, 24 | n_frames_steps_rollout=4, 25 | with_itp = True, 26 | device = None, 27 | **kwargs): 28 | 
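# One frozen latent-video planner and one frozen VQGAN decoder are loaded per entry in `keys`;
# the low-level action-mapping policy (model_low_level) is the only trainable component.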
super().__init__() 29 | 30 | self.model_low_level = model_low_level 31 | models_high_level = [] 32 | for index, planner in enumerate(planners): 33 | planner.load_state_dict(torch.load(planner_paths[index])) 34 | planner.to(device) 35 | for param in planner.parameters(): 36 | param.requires_grad = False 37 | models_high_level.append(planner) 38 | self.models_high_level = models_high_level 39 | self.vqgans = [VQModel(**config) for config in vqgan_configs] 40 | for vqgan in self.vqgans: 41 | vqgan.to(device) 42 | for param in vqgan.parameters(): 43 | param.requires_grad = False 44 | self.normalizer = LinearNormalizer() 45 | 46 | self.subgoal_steps = subgoal_steps 47 | self.subgoal_interval = subgoal_interval 48 | self.kwargs = kwargs 49 | self.obs_steps = obs_steps 50 | self.latent_shape = latent_shape 51 | self.keys = keys 52 | self.n_action_steps_rollout = n_action_steps_rollout 53 | self.n_frames_steps_rollout = n_frames_steps_rollout 54 | self.with_itp = with_itp 55 | 56 | def predict_action(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 57 | img_preds = [] 58 | for index, model_high_level in enumerate(self.models_high_level): 59 | 60 | obs_pred = model_high_level.predict_latent(obs_dict)['latent_pred'].to(dtype=torch.float32) 61 | obs_pred = obs_pred[:, :self.n_frames_steps_rollout, :, :, :] 62 | obs_pred = obs_pred.reshape(-1, obs_pred.shape[2], obs_pred.shape[3], obs_pred.shape[4]) 63 | quant, emb_loss, info = self.vqgans[index].quantize(obs_pred) 64 | img_pred = self.vqgans[index].decode(quant) 65 | img_pred = img_pred.reshape(-1, self.n_frames_steps_rollout, img_pred.shape[1], img_pred.shape[2], img_pred.shape[3]) 66 | img_pred = torch.clamp(img_pred.to(dtype=torch.float32), -1., 1.) 67 | img_pred[:, 0] = obs_dict[self.keys[index]][:, 0] 68 | img_preds.append(img_pred) 69 | 70 | obs_low_level = {} 71 | for index, img_pred in enumerate(img_preds): 72 | # Concatenate every two adjacent images to a new image sequence 73 | img_pred_start = img_pred[:, :-1, :, :, :].unsqueeze(2) 74 | img_pred_end = img_pred[:, 1:, :, :, :].unsqueeze(2) 75 | img_seq = torch.cat([img_pred_start, img_pred_end], dim=2) 76 | img_seq = img_seq.reshape(-1, img_seq.shape[2], img_seq.shape[3], img_seq.shape[4], img_seq.shape[5]) 77 | obs_low_level[self.keys[index]] = img_seq 78 | 79 | result = self.model_low_level.predict_action(obs_low_level) 80 | if self.with_itp: 81 | sparse_action = rearrange(result['action'][:, 0, :], '(b s) d -> b s d', b=obs_dict[self.keys[0]].shape[0]) 82 | action = interpolate_tensor(sparse_action, 1).to(device=obs_dict[self.keys[0]].device) 83 | inte_action = {} 84 | inte_action['action'] = action[:, :self.n_action_steps_rollout, :] 85 | 86 | return inte_action, img_preds[0] 87 | else: 88 | 89 | action = rearrange(result['action'], '(b s) h d -> b (s h) d', b=obs_dict[self.keys[0]].shape[0]) 90 | result['action'] = action[:, :self.n_action_steps_rollout, :] 91 | 92 | return result, img_preds[0] 93 | 94 | def set_normalizer(self, normalizer: LinearNormalizer): 95 | self.normalizer.load_state_dict(normalizer.state_dict()) 96 | for model in self.models_high_level: 97 | model.set_normalizer(normalizer) 98 | self.model_low_level.set_normalizer(normalizer) 99 | 100 | def train_on_batch(self, batch): 101 | obs = batch['obs'] 102 | 103 | obs_new = {} 104 | 105 | for key in self.keys: 106 | obs_new[key] = torch.cat([obs[key][:,0].unsqueeze(1), obs[key][:,-1].unsqueeze(1)], dim=1) 107 | low_level_batch = { 108 | 'obs':obs_new, 109 | 'action':batch['action'] 110 | } 111 | 
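# The planners and VQGANs above are frozen, so only model_low_level is optimized in this step
# and high_level_loss is reported as 0.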
loss_low_level = self.model_low_level.compute_loss(low_level_batch) 112 | loss_low_level.backward() 113 | 114 | return {'low_level_loss':loss_low_level.item(), 'high_level_loss':0} 115 | 116 | def eval_on_batch(self, batch): 117 | obs = batch['obs'] 118 | 119 | obs_new = {} 120 | 121 | for key in self.keys: 122 | obs_new[key] = torch.cat([obs[key][:,0].unsqueeze(1), obs[key][:,-1].unsqueeze(1)], dim=1) 123 | low_level_batch = { 124 | 'obs':obs_new, 125 | 'action':batch['action'] 126 | } 127 | loss_low_level = self.model_low_level.compute_loss(low_level_batch) 128 | 129 | return {'low_level_loss':loss_low_level.item(), 'high_level_loss':0} -------------------------------------------------------------------------------- /VILP/policy/vilp_planning.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | from VILP.model.normalizer import LinearNormalizer 4 | from diffusion_policy.policy.base_lowdim_policy import BaseLowdimPolicy 5 | from VILP.taming.models.vqgan import VQModel 6 | from VILP.policy.latent_video_diffusion import LatentVideoDiffusion 7 | 8 | class VilpPlanning(BaseLowdimPolicy): 9 | def __init__(self, 10 | model_high_level:LatentVideoDiffusion, 11 | vqgan_config, 12 | subgoal_steps = 4, 13 | subgoal_interval=10, 14 | latent_dim = 18, 15 | latent_shape = [2,12,12], 16 | output_key = 'image', 17 | **kwargs): 18 | super().__init__() 19 | 20 | self.model_high_level = model_high_level 21 | self.vqgan = VQModel(**vqgan_config) 22 | for param in self.vqgan.parameters(): 23 | param.requires_grad = False 24 | 25 | self.normalizer = LinearNormalizer() 26 | 27 | self.subgoal_steps = subgoal_steps 28 | self.latent_dim = latent_dim 29 | self.subgoal_interval = subgoal_interval 30 | self.kwargs = kwargs 31 | self.latent_shape = latent_shape 32 | self.output_key = output_key 33 | 34 | # ========= inference ============ 35 | def predict_image(self, obs_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 36 | 37 | obs_pred = self.model_high_level.predict_latent(obs_dict)['latent_pred'].to(dtype=torch.float32) 38 | # batch*t, dim1, dim2, dim3 39 | obs_pred = obs_pred.reshape(-1, obs_pred.shape[2], obs_pred.shape[3], obs_pred.shape[4]) 40 | 41 | quant, emb_loss, info = self.vqgan.quantize(obs_pred) 42 | rec_output = self.vqgan.decode(quant) 43 | rec_output = rec_output.reshape(-1, self.subgoal_steps, rec_output.shape[1], rec_output.shape[2], rec_output.shape[3]) 44 | 45 | rec_output = torch.clamp(rec_output.to(dtype=torch.float32), -1., 1.) 
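# The first decoded frame corresponds to the current observation that is prepended during training,
# so only the predicted future frames are returned below.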
46 | return rec_output[:,1:,:,:,:] # remove the first image 47 | 48 | 49 | # ========= training ============ 50 | def set_normalizer(self, normalizer: LinearNormalizer): 51 | self.normalizer.load_state_dict(normalizer.state_dict()) 52 | self.model_high_level.set_normalizer(normalizer) 53 | 54 | def train_on_batch(self, batch): 55 | 56 | obs = batch['obs'] 57 | img_pred = None 58 | if self.output_key == 'depth': 59 | img_pred = obs[self.output_key][:, 60 | self.subgoal_interval:self.subgoal_interval*(self.subgoal_steps):self.subgoal_interval, 0:1,:,:] 61 | # cat the first image 62 | img_pred = torch.cat([obs[self.output_key][:, 0, 0:1,:,:].unsqueeze(1), img_pred], dim=1) 63 | else: 64 | img_pred = obs[self.output_key][:, 65 | self.subgoal_interval:self.subgoal_interval*(self.subgoal_steps):self.subgoal_interval, :] 66 | # cat the first image 67 | img_pred = torch.cat([obs[self.output_key][:, 0, :].unsqueeze(1), img_pred], dim=1) 68 | 69 | batch_size, subgoal_steps, channels, height, width = img_pred.shape 70 | img_pred = img_pred.reshape(batch_size*subgoal_steps, channels, height, width) 71 | 72 | img_pred = img_pred.permute(0, 2, 3, 1) 73 | img_batch = {'image':img_pred} 74 | latent = self.vqgan.to_latent(self.vqgan.get_input(img_batch,'image')) 75 | 76 | latent = latent.reshape(batch_size, subgoal_steps, latent.shape[1], latent.shape[2], latent.shape[3]) 77 | 78 | 79 | obs_high_level = {} 80 | for key, value in obs.items(): 81 | obs_high_level[key] = value 82 | 83 | high_level_batch = { 84 | 'obs':obs_high_level, 85 | 'latent':latent 86 | } 87 | 88 | loss_high_level = self.model_high_level.compute_loss(high_level_batch) 89 | loss_high_level.backward() 90 | 91 | return { 'high_level_loss':loss_high_level.item()} 92 | 93 | def eval_on_batch(self, batch): 94 | obs = batch['obs'] 95 | img_pred = None 96 | if self.output_key == 'depth': 97 | img_pred = obs[self.output_key][:, 98 | self.subgoal_interval:self.subgoal_interval*(self.subgoal_steps):self.subgoal_interval, 0:1,:,:] 99 | # cat the first image 100 | img_pred = torch.cat([obs[self.output_key][:, 0, 0:1,:,:].unsqueeze(1), img_pred], dim=1) 101 | else: 102 | img_pred = obs[self.output_key][:, 103 | self.subgoal_interval:self.subgoal_interval*(self.subgoal_steps):self.subgoal_interval, :] 104 | # cat the first image 105 | img_pred = torch.cat([obs[self.output_key][:, 0, :].unsqueeze(1), img_pred], dim=1) 106 | 107 | batch_size, subgoal_steps, channels, height, width = img_pred.shape 108 | img_pred = img_pred.reshape(batch_size*subgoal_steps, channels, height, width) 109 | 110 | img_pred = img_pred.permute(0, 2, 3, 1) 111 | img_batch = {'image':img_pred} 112 | latent = self.vqgan.to_latent(self.vqgan.get_input(img_batch,'image')) 113 | 114 | latent = latent.reshape(batch_size, subgoal_steps, latent.shape[1], latent.shape[2], latent.shape[3]) 115 | obs_high_level = {} 116 | for key, value in obs.items(): 117 | obs_high_level[key] = value 118 | 119 | high_level_batch = { 120 | 'obs':obs_high_level, 121 | 'latent':latent 122 | } 123 | 124 | loss_high_level = self.model_high_level.compute_loss(high_level_batch) 125 | 126 | return { 'high_level_loss':loss_high_level.item()} -------------------------------------------------------------------------------- /VILP/taming/data/ade20k.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | from VILP.taming.data.sflckr 
import SegmentationBase # for examples included in repo 9 | 10 | 11 | class Examples(SegmentationBase): 12 | def __init__(self, size=256, random_crop=False, interpolation="bicubic"): 13 | super().__init__(data_csv="data/ade20k_examples.txt", 14 | data_root="data/ade20k_images", 15 | segmentation_root="data/ade20k_segmentations", 16 | size=size, random_crop=random_crop, 17 | interpolation=interpolation, 18 | n_labels=151, shift_segmentation=False) 19 | 20 | 21 | # With semantic map and scene label 22 | class ADE20kBase(Dataset): 23 | def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): 24 | self.split = self.get_split() 25 | self.n_labels = 151 # unknown + 150 26 | self.data_csv = {"train": "data/ade20k_train.txt", 27 | "validation": "data/ade20k_test.txt"}[self.split] 28 | self.data_root = "data/ade20k_root" 29 | with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: 30 | self.scene_categories = f.read().splitlines() 31 | self.scene_categories = dict(line.split() for line in self.scene_categories) 32 | with open(self.data_csv, "r") as f: 33 | self.image_paths = f.read().splitlines() 34 | self._length = len(self.image_paths) 35 | self.labels = { 36 | "relative_file_path_": [l for l in self.image_paths], 37 | "file_path_": [os.path.join(self.data_root, "images", l) 38 | for l in self.image_paths], 39 | "relative_segmentation_path_": [l.replace(".jpg", ".png") 40 | for l in self.image_paths], 41 | "segmentation_path_": [os.path.join(self.data_root, "annotations", 42 | l.replace(".jpg", ".png")) 43 | for l in self.image_paths], 44 | "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] 45 | for l in self.image_paths], 46 | } 47 | 48 | size = None if size is not None and size<=0 else size 49 | self.size = size 50 | if crop_size is None: 51 | self.crop_size = size if size is not None else None 52 | else: 53 | self.crop_size = crop_size 54 | if self.size is not None: 55 | self.interpolation = interpolation 56 | self.interpolation = { 57 | "nearest": cv2.INTER_NEAREST, 58 | "bilinear": cv2.INTER_LINEAR, 59 | "bicubic": cv2.INTER_CUBIC, 60 | "area": cv2.INTER_AREA, 61 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 62 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 63 | interpolation=self.interpolation) 64 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 65 | interpolation=cv2.INTER_NEAREST) 66 | 67 | if crop_size is not None: 68 | self.center_crop = not random_crop 69 | if self.center_crop: 70 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 71 | else: 72 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 73 | self.preprocessor = self.cropper 74 | 75 | def __len__(self): 76 | return self._length 77 | 78 | def __getitem__(self, i): 79 | example = dict((k, self.labels[k][i]) for k in self.labels) 80 | image = Image.open(example["file_path_"]) 81 | if not image.mode == "RGB": 82 | image = image.convert("RGB") 83 | image = np.array(image).astype(np.uint8) 84 | if self.size is not None: 85 | image = self.image_rescaler(image=image)["image"] 86 | segmentation = Image.open(example["segmentation_path_"]) 87 | segmentation = np.array(segmentation).astype(np.uint8) 88 | if self.size is not None: 89 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 90 | if self.size is not None: 91 | processed = self.preprocessor(image=image, mask=segmentation) 92 | else: 
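# No target size configured: pass the raw image and segmentation mask through unchanged.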
93 | processed = {"image": image, "mask": segmentation} 94 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 95 | segmentation = processed["mask"] 96 | onehot = np.eye(self.n_labels)[segmentation] 97 | example["segmentation"] = onehot 98 | return example 99 | 100 | 101 | class ADE20kTrain(ADE20kBase): 102 | # default to random_crop=True 103 | def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): 104 | super().__init__(config=config, size=size, random_crop=random_crop, 105 | interpolation=interpolation, crop_size=crop_size) 106 | 107 | def get_split(self): 108 | return "train" 109 | 110 | 111 | class ADE20kValidation(ADE20kBase): 112 | def get_split(self): 113 | return "validation" 114 | 115 | 116 | if __name__ == "__main__": 117 | dset = ADE20kValidation() 118 | ex = dset[0] 119 | for k in ["image", "scene_category", "segmentation"]: 120 | print(type(ex[k])) 121 | try: 122 | print(ex[k].shape) 123 | except: 124 | print(ex[k]) 125 | -------------------------------------------------------------------------------- /VILP/taming/data/annotated_objects_coco.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import chain 3 | from pathlib import Path 4 | from typing import Iterable, Dict, List, Callable, Any 5 | from collections import defaultdict 6 | 7 | from tqdm import tqdm 8 | 9 | from VILP.taming.data.annotated_objects_dataset import AnnotatedObjectsDataset 10 | from VILP.taming.data.helper_types import Annotation, ImageDescription, Category 11 | 12 | COCO_PATH_STRUCTURE = { 13 | 'train': { 14 | 'top_level': '', 15 | 'instances_annotations': 'annotations/instances_train2017.json', 16 | 'stuff_annotations': 'annotations/stuff_train2017.json', 17 | 'files': 'train2017' 18 | }, 19 | 'validation': { 20 | 'top_level': '', 21 | 'instances_annotations': 'annotations/instances_val2017.json', 22 | 'stuff_annotations': 'annotations/stuff_val2017.json', 23 | 'files': 'val2017' 24 | } 25 | } 26 | 27 | 28 | def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: 29 | return { 30 | str(img['id']): ImageDescription( 31 | id=img['id'], 32 | license=img.get('license'), 33 | file_name=img['file_name'], 34 | coco_url=img['coco_url'], 35 | original_size=(img['width'], img['height']), 36 | date_captured=img.get('date_captured'), 37 | flickr_url=img.get('flickr_url') 38 | ) 39 | for img in description_json 40 | } 41 | 42 | 43 | def load_categories(category_json: Iterable) -> Dict[str, Category]: 44 | return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) 45 | for cat in category_json if cat['name'] != 'other'} 46 | 47 | 48 | def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], 49 | category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: 50 | annotations = defaultdict(list) 51 | total = sum(len(a) for a in annotations_json) 52 | for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): 53 | image_id = str(ann['image_id']) 54 | if image_id not in image_descriptions: 55 | raise ValueError(f'image_id [{image_id}] has no image description.') 56 | category_id = ann['category_id'] 57 | try: 58 | category_no = category_no_for_id(str(category_id)) 59 | except KeyError: 60 | continue 61 | 62 | width, height = image_descriptions[image_id].original_size 63 | bbox = (ann['bbox'][0] / width, 
ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) 64 | 65 | annotations[image_id].append( 66 | Annotation( 67 | id=ann['id'], 68 | area=bbox[2]*bbox[3], # use bbox area 69 | is_group_of=ann['iscrowd'], 70 | image_id=ann['image_id'], 71 | bbox=bbox, 72 | category_id=str(category_id), 73 | category_no=category_no 74 | ) 75 | ) 76 | return dict(annotations) 77 | 78 | 79 | class AnnotatedObjectsCoco(AnnotatedObjectsDataset): 80 | def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): 81 | """ 82 | @param data_path: is the path to the following folder structure: 83 | coco/ 84 | ├── annotations 85 | │ ├── instances_train2017.json 86 | │ ├── instances_val2017.json 87 | │ ├── stuff_train2017.json 88 | │ └── stuff_val2017.json 89 | ├── train2017 90 | │ ├── 000000000009.jpg 91 | │ ├── 000000000025.jpg 92 | │ └── ... 93 | ├── val2017 94 | │ ├── 000000000139.jpg 95 | │ ├── 000000000285.jpg 96 | │ └── ... 97 | @param: split: one of 'train' or 'validation' 98 | @param: desired image size (give square images) 99 | """ 100 | super().__init__(**kwargs) 101 | self.use_things = use_things 102 | self.use_stuff = use_stuff 103 | 104 | with open(self.paths['instances_annotations']) as f: 105 | inst_data_json = json.load(f) 106 | with open(self.paths['stuff_annotations']) as f: 107 | stuff_data_json = json.load(f) 108 | 109 | category_jsons = [] 110 | annotation_jsons = [] 111 | if self.use_things: 112 | category_jsons.append(inst_data_json['categories']) 113 | annotation_jsons.append(inst_data_json['annotations']) 114 | if self.use_stuff: 115 | category_jsons.append(stuff_data_json['categories']) 116 | annotation_jsons.append(stuff_data_json['annotations']) 117 | 118 | self.categories = load_categories(chain(*category_jsons)) 119 | self.filter_categories() 120 | self.setup_category_id_and_number() 121 | 122 | self.image_descriptions = load_image_descriptions(inst_data_json['images']) 123 | annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) 124 | self.annotations = self.filter_object_number(annotations, self.min_object_area, 125 | self.min_objects_per_image, self.max_objects_per_image) 126 | self.image_ids = list(self.annotations.keys()) 127 | self.clean_up_annotations_and_image_descriptions() 128 | 129 | def get_path_structure(self) -> Dict[str, str]: 130 | if self.split not in COCO_PATH_STRUCTURE: 131 | raise ValueError(f'Split [{self.split} does not exist for COCO data.]') 132 | return COCO_PATH_STRUCTURE[self.split] 133 | 134 | def get_image_path(self, image_id: str) -> Path: 135 | return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) 136 | 137 | def get_image_description(self, image_id: str) -> Dict[str, Any]: 138 | # noinspection PyProtectedMember 139 | return self.image_descriptions[image_id]._asdict() 140 | -------------------------------------------------------------------------------- /VILP/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = 
bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /VILP/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from VILP.taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, Tensor 8 | 9 | from VILP.taming.data.helper_types import BoundingBox, Annotation 10 | from VILP.taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from VILP.taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: 
LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /VILP/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from VILP.taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 
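Both rectangles are converted from (x0, y0, w, h) to corner coordinates, and the overlap
x_overlap * y_overlap is returned (0.0 if the rectangles do not intersect).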
20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /VILP/taming/data/custom.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from VILP.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /VILP/taming/data/faceshq.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from VILP.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class FacesBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | self.keys = None 14 | 15 | def __len__(self): 16 | return len(self.data) 17 | 18 | def __getitem__(self, i): 19 | example = self.data[i] 20 | ex = {} 21 | if self.keys is not None: 22 | for k in self.keys: 23 | ex[k] = example[k] 24 | else: 25 | ex = example 26 | return ex 27 | 28 | 29 | class CelebAHQTrain(FacesBase): 30 | def __init__(self, size, keys=None): 31 | super().__init__() 32 | root = "data/celebahq" 33 | with open("data/celebahqtrain.txt", "r") as f: 34 | relpaths = f.read().splitlines() 35 | paths = [os.path.join(root, relpath) for relpath in relpaths] 36 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 37 | self.keys = keys 38 | 39 | 40 | class CelebAHQValidation(FacesBase): 41 | def __init__(self, size, keys=None): 42 | super().__init__() 43 | root = "data/celebahq" 44 | with open("data/celebahqvalidation.txt", "r") as f: 45 | relpaths = f.read().splitlines() 46 | paths = [os.path.join(root, relpath) for relpath in relpaths] 47 | self.data = NumpyPaths(paths=paths, size=size, random_crop=False) 48 | self.keys = keys 49 | 50 | 51 | class FFHQTrain(FacesBase): 52 | def __init__(self, size, keys=None): 53 | super().__init__() 54 | root = "data/ffhq" 55 | with open("data/ffhqtrain.txt", "r") as f: 56 | relpaths = f.read().splitlines() 57 | paths = [os.path.join(root, relpath) for relpath in relpaths] 58 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 59 | self.keys = keys 60 | 61 | 62 | class FFHQValidation(FacesBase): 63 | def __init__(self, size, keys=None): 64 | super().__init__() 65 | root = "data/ffhq" 66 | with open("data/ffhqvalidation.txt", "r") as f: 67 | relpaths = f.read().splitlines() 68 | paths = [os.path.join(root, relpath) for relpath in relpaths] 69 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 70 | self.keys = keys 71 | 72 | 73 | class FacesHQTrain(Dataset): 74 | # CelebAHQ [0] + FFHQ [1] 75 
| def __init__(self, size, keys=None, crop_size=None, coord=False): 76 | d1 = CelebAHQTrain(size=size, keys=keys) 77 | d2 = FFHQTrain(size=size, keys=keys) 78 | self.data = ConcatDatasetWithIndex([d1, d2]) 79 | self.coord = coord 80 | if crop_size is not None: 81 | self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) 82 | if self.coord: 83 | self.cropper = albumentations.Compose([self.cropper], 84 | additional_targets={"coord": "image"}) 85 | 86 | def __len__(self): 87 | return len(self.data) 88 | 89 | def __getitem__(self, i): 90 | ex, y = self.data[i] 91 | if hasattr(self, "cropper"): 92 | if not self.coord: 93 | out = self.cropper(image=ex["image"]) 94 | ex["image"] = out["image"] 95 | else: 96 | h,w,_ = ex["image"].shape 97 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 98 | out = self.cropper(image=ex["image"], coord=coord) 99 | ex["image"] = out["image"] 100 | ex["coord"] = out["coord"] 101 | ex["class"] = y 102 | return ex 103 | 104 | 105 | class FacesHQValidation(Dataset): 106 | # CelebAHQ [0] + FFHQ [1] 107 | def __init__(self, size, keys=None, crop_size=None, coord=False): 108 | d1 = CelebAHQValidation(size=size, keys=keys) 109 | d2 = FFHQValidation(size=size, keys=keys) 110 | self.data = ConcatDatasetWithIndex([d1, d2]) 111 | self.coord = coord 112 | if crop_size is not None: 113 | self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) 114 | if self.coord: 115 | self.cropper = albumentations.Compose([self.cropper], 116 | additional_targets={"coord": "image"}) 117 | 118 | def __len__(self): 119 | return len(self.data) 120 | 121 | def __getitem__(self, i): 122 | ex, y = self.data[i] 123 | if hasattr(self, "cropper"): 124 | if not self.coord: 125 | out = self.cropper(image=ex["image"]) 126 | ex["image"] = out["image"] 127 | else: 128 | h,w,_ = ex["image"].shape 129 | coord = np.arange(h*w).reshape(h,w,1)/(h*w) 130 | out = self.cropper(image=ex["image"], coord=coord) 131 | ex["image"] = out["image"] 132 | ex["coord"] = out["coord"] 133 | ex["class"] = y 134 | return ex 135 | -------------------------------------------------------------------------------- /VILP/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = 
None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /VILP/taming/data/image_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | from typing import Union 4 | 5 | import torch 6 | from torch import Tensor 7 | from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor 8 | from torchvision.transforms.functional import get_image_size as get_image_size 9 | 10 | from VILP.taming.data.helper_types import BoundingBox, Image 11 | 12 | pil_to_tensor = PILToTensor() 13 | 14 | 15 | def convert_pil_to_tensor(image: Image) -> Tensor: 16 | with warnings.catch_warnings(): 17 | # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 18 | warnings.simplefilter("ignore") 19 | return pil_to_tensor(image) 20 | 21 | 22 | class RandomCrop1dReturnCoordinates(RandomCrop): 23 | def forward(self, img: Image) -> (BoundingBox, Image): 24 | """ 25 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 26 | Args: 27 | img (PIL Image or Tensor): Image to be cropped. 28 | 29 | Returns: 30 | Bounding box: x0, y0, w, h 31 | PIL Image or Tensor: Cropped image. 32 | 33 | Based on: 34 | torchvision.transforms.RandomCrop, torchvision 1.7.0 35 | """ 36 | if self.padding is not None: 37 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 38 | 39 | width, height = get_image_size(img) 40 | # pad the width if needed 41 | if self.pad_if_needed and width < self.size[1]: 42 | padding = [self.size[1] - width, 0] 43 | img = F.pad(img, padding, self.fill, self.padding_mode) 44 | # pad the height if needed 45 | if self.pad_if_needed and height < self.size[0]: 46 | padding = [0, self.size[0] - height] 47 | img = F.pad(img, padding, self.fill, self.padding_mode) 48 | 49 | i, j, h, w = self.get_params(img, self.size) 50 | bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h 51 | return bbox, F.crop(img, i, j, h, w) 52 | 53 | 54 | class Random2dCropReturnCoordinates(torch.nn.Module): 55 | """ 56 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 57 | Args: 58 | img (PIL Image or Tensor): Image to be cropped. 59 | 60 | Returns: 61 | Bounding box: x0, y0, w, h 62 | PIL Image or Tensor: Cropped image. 63 | 64 | Based on: 65 | torchvision.transforms.RandomCrop, torchvision 1.7.0 66 | """ 67 | 68 | def __init__(self, min_size: int): 69 | super().__init__() 70 | self.min_size = min_size 71 | 72 | def forward(self, img: Image) -> (BoundingBox, Image): 73 | width, height = get_image_size(img) 74 | max_size = min(width, height) 75 | if max_size <= self.min_size: 76 | size = max_size 77 | else: 78 | size = random.randint(self.min_size, max_size) 79 | top = random.randint(0, height - size) 80 | left = random.randint(0, width - size) 81 | bbox = left / width, top / height, size / width, size / height 82 | return bbox, F.crop(img, top, left, size, size) 83 | 84 | 85 | class CenterCropReturnCoordinates(CenterCrop): 86 | @staticmethod 87 | def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: 88 | if width > height: 89 | w = height / width 90 | h = 1.0 91 | x0 = 0.5 - w / 2 92 | y0 = 0. 93 | else: 94 | w = 1.0 95 | h = width / height 96 | x0 = 0. 
97 | y0 = 0.5 - h / 2 98 | return x0, y0, w, h 99 | 100 | def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): 101 | """ 102 | Additionally to cropping, returns the relative coordinates of the crop bounding box. 103 | Args: 104 | img (PIL Image or Tensor): Image to be cropped. 105 | 106 | Returns: 107 | Bounding box: x0, y0, w, h 108 | PIL Image or Tensor: Cropped image. 109 | Based on: 110 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 111 | """ 112 | width, height = get_image_size(img) 113 | return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) 114 | 115 | 116 | class RandomHorizontalFlipReturn(RandomHorizontalFlip): 117 | def forward(self, img: Image) -> (bool, Image): 118 | """ 119 | Additionally to flipping, returns a boolean whether it was flipped or not. 120 | Args: 121 | img (PIL Image or Tensor): Image to be flipped. 122 | 123 | Returns: 124 | flipped: whether the image was flipped or not 125 | PIL Image or Tensor: Randomly flipped image. 126 | 127 | Based on: 128 | torchvision.transforms.RandomHorizontalFlip (version 1.7.0) 129 | """ 130 | if torch.rand(1) < self.p: 131 | return True, F.hflip(img) 132 | return False, img 133 | -------------------------------------------------------------------------------- /VILP/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": [os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = 
image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /VILP/taming/data/utils.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import tarfile 4 | import urllib 5 | import zipfile 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from VILP.taming.data.helper_types import Annotation 11 | from torch._six import string_classes 12 | from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format 13 | from tqdm import tqdm 14 | 15 | 16 | def unpack(path): 17 | if path.endswith("tar.gz"): 18 | with tarfile.open(path, "r:gz") as tar: 19 | tar.extractall(path=os.path.split(path)[0]) 20 | elif path.endswith("tar"): 21 | with tarfile.open(path, "r:") as tar: 22 | tar.extractall(path=os.path.split(path)[0]) 23 | elif path.endswith("zip"): 24 | with zipfile.ZipFile(path, "r") as f: 25 | f.extractall(path=os.path.split(path)[0]) 26 | else: 27 | raise NotImplementedError( 28 | "Unknown file extension: {}".format(os.path.splitext(path)[1]) 29 | ) 30 | 31 | 32 | def reporthook(bar): 33 | """tqdm progress bar for downloads.""" 34 | 35 | def hook(b=1, bsize=1, tsize=None): 36 | if tsize is not None: 37 | bar.total = tsize 38 | bar.update(b * bsize - bar.n) 39 | 40 | return hook 41 | 42 | 43 | def get_root(name): 44 | base = "data/" 45 | root = os.path.join(base, name) 46 | os.makedirs(root, exist_ok=True) 47 | return root 48 | 49 | 50 | def is_prepared(root): 51 | return Path(root).joinpath(".ready").exists() 52 | 53 | 54 | def mark_prepared(root): 55 | Path(root).joinpath(".ready").touch() 56 | 57 | 58 | def prompt_download(file_, source, target_dir, content_dir=None): 59 | targetpath = os.path.join(target_dir, file_) 60 | while not os.path.exists(targetpath): 61 | if content_dir is not None and os.path.exists( 62 | os.path.join(target_dir, content_dir) 63 | ): 64 | break 65 | print( 66 | "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) 67 | ) 68 | if content_dir is not None: 69 | print( 70 | "Or place its content into '{}'.".format( 71 | os.path.join(target_dir, content_dir) 72 | ) 73 | ) 74 | 
input("Press Enter when done...") 75 | return targetpath 76 | 77 | 78 | def download_url(file_, url, target_dir): 79 | targetpath = os.path.join(target_dir, file_) 80 | os.makedirs(target_dir, exist_ok=True) 81 | with tqdm( 82 | unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ 83 | ) as bar: 84 | urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) 85 | return targetpath 86 | 87 | 88 | def download_urls(urls, target_dir): 89 | paths = dict() 90 | for fname, url in urls.items(): 91 | outpath = download_url(fname, url, target_dir) 92 | paths[fname] = outpath 93 | return paths 94 | 95 | 96 | def quadratic_crop(x, bbox, alpha=1.0): 97 | """bbox is xmin, ymin, xmax, ymax""" 98 | im_h, im_w = x.shape[:2] 99 | bbox = np.array(bbox, dtype=np.float32) 100 | bbox = np.clip(bbox, 0, max(im_h, im_w)) 101 | center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) 102 | w = bbox[2] - bbox[0] 103 | h = bbox[3] - bbox[1] 104 | l = int(alpha * max(w, h)) 105 | l = max(l, 2) 106 | 107 | required_padding = -1 * min( 108 | center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) 109 | ) 110 | required_padding = int(np.ceil(required_padding)) 111 | if required_padding > 0: 112 | padding = [ 113 | [required_padding, required_padding], 114 | [required_padding, required_padding], 115 | ] 116 | padding += [[0, 0]] * (len(x.shape) - 2) 117 | x = np.pad(x, padding, "reflect") 118 | center = center[0] + required_padding, center[1] + required_padding 119 | xmin = int(center[0] - l / 2) 120 | ymin = int(center[1] - l / 2) 121 | return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) 122 | 123 | 124 | def custom_collate(batch): 125 | r"""source: pytorch 1.9.0, only one modification to original code """ 126 | 127 | elem = batch[0] 128 | elem_type = type(elem) 129 | if isinstance(elem, torch.Tensor): 130 | out = None 131 | if torch.utils.data.get_worker_info() is not None: 132 | # If we're in a background process, concatenate directly into a 133 | # shared memory tensor to avoid an extra copy 134 | numel = sum([x.numel() for x in batch]) 135 | storage = elem.storage()._new_shared(numel) 136 | out = elem.new(storage) 137 | return torch.stack(batch, 0, out=out) 138 | elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ 139 | and elem_type.__name__ != 'string_': 140 | if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': 141 | # array of string classes and object 142 | if np_str_obj_array_pattern.search(elem.dtype.str) is not None: 143 | raise TypeError(default_collate_err_msg_format.format(elem.dtype)) 144 | 145 | return custom_collate([torch.as_tensor(b) for b in batch]) 146 | elif elem.shape == (): # scalars 147 | return torch.as_tensor(batch) 148 | elif isinstance(elem, float): 149 | return torch.tensor(batch, dtype=torch.float64) 150 | elif isinstance(elem, int): 151 | return torch.tensor(batch) 152 | elif isinstance(elem, string_classes): 153 | return batch 154 | elif isinstance(elem, collections.abc.Mapping): 155 | return {key: custom_collate([d[key] for d in batch]) for key in elem} 156 | elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple 157 | return elem_type(*(custom_collate(samples) for samples in zip(*batch))) 158 | if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added 159 | return batch # added 160 | elif isinstance(elem, collections.abc.Sequence): 161 | # check to make sure that the elements in batch have consistent size 162 | it = iter(batch) 163 | 
elem_size = len(next(it)) 164 | if not all(len(elem) == elem_size for elem in it): 165 | raise RuntimeError('each element in list of batch should be of equal size') 166 | transposed = zip(*batch) 167 | return [custom_collate(samples) for samples in transposed] 168 | 169 | raise TypeError(default_collate_err_msg_format.format(elem_type)) 170 | -------------------------------------------------------------------------------- /VILP/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /VILP/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /VILP/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from VILP.taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if 
not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /VILP/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | 8 | from VILP.taming.util import get_ckpt_path 9 | 10 | 11 | class LPIPS(nn.Module): 12 | # Learned perceptual metric 13 | def __init__(self, use_dropout=True): 14 | super().__init__() 15 | self.scaling_layer = ScalingLayer() 16 | self.chns = [64, 128, 256, 512, 512] # vg16 features 17 | self.net = vgg16(pretrained=True, requires_grad=False) 18 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 19 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 20 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 21 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 22 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 23 | self.load_from_pretrained() 24 | for param in self.parameters(): 25 | param.requires_grad = False 26 | 27 | def load_from_pretrained(self, name="vgg_lpips"): 28 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 29 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 30 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 31 | 32 | @classmethod 33 | def from_pretrained(cls, name="vgg_lpips"): 34 | if name != "vgg_lpips": 35 | raise NotImplementedError 36 | model = cls() 37 | ckpt = get_ckpt_path(name) 38 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 39 | return model 40 | 41 | def forward(self, input, target): 42 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 43 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 44 | feats0, feats1, diffs = {}, {}, {} 45 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 46 | for kk in range(len(self.chns)): 47 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 48 
| diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 49 | 50 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 51 | val = res[0] 52 | for l in range(1, len(self.chns)): 53 | val += res[l] 54 | return val 55 | 56 | 57 | class ScalingLayer(nn.Module): 58 | def __init__(self): 59 | super(ScalingLayer, self).__init__() 60 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 61 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 62 | 63 | def forward(self, inp): 64 | return (inp - self.shift) / self.scale 65 | 66 | 67 | class NetLinLayer(nn.Module): 68 | """ A single linear layer which does a 1x1 conv """ 69 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 70 | super(NetLinLayer, self).__init__() 71 | layers = [nn.Dropout(), ] if (use_dropout) else [] 72 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 73 | self.model = nn.Sequential(*layers) 74 | 75 | 76 | class vgg16(torch.nn.Module): 77 | def __init__(self, requires_grad=False, pretrained=True): 78 | super(vgg16, self).__init__() 79 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 80 | self.slice1 = torch.nn.Sequential() 81 | self.slice2 = torch.nn.Sequential() 82 | self.slice3 = torch.nn.Sequential() 83 | self.slice4 = torch.nn.Sequential() 84 | self.slice5 = torch.nn.Sequential() 85 | self.N_slices = 5 86 | for x in range(4): 87 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 88 | for x in range(4, 9): 89 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 90 | for x in range(9, 16): 91 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 92 | for x in range(16, 23): 93 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 94 | for x in range(23, 30): 95 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 96 | if not requires_grad: 97 | for param in self.parameters(): 98 | param.requires_grad = False 99 | 100 | def forward(self, X): 101 | h = self.slice1(X) 102 | h_relu1_2 = h 103 | h = self.slice2(h) 104 | h_relu2_2 = h 105 | h = self.slice3(h) 106 | h_relu3_3 = h 107 | h = self.slice4(h) 108 | h_relu4_3 = h 109 | h = self.slice5(h) 110 | h_relu5_3 = h 111 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 112 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 113 | return out 114 | 115 | 116 | def normalize_tensor(x,eps=1e-10): 117 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 118 | return x/(norm_factor+eps) 119 | 120 | 121 | def spatial_average(x, keepdim=True): 122 | return x.mean([2,3],keepdim=keepdim) 123 | 124 | -------------------------------------------------------------------------------- /VILP/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, 
None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /VILP/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /VILP/taming/util.py: -------------------------------------------------------------------------------- 1 | import os, hashlib 2 | import requests 3 | from tqdm import tqdm 4 | 5 | URL_MAP = { 6 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 7 | } 8 | 9 | CKPT_MAP = { 10 | "vgg_lpips": "vgg.pth" 11 | } 12 | 13 | MD5_MAP = { 14 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 15 | } 16 | 17 | 18 | def download(url, local_path, chunk_size=1024): 19 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 20 | with requests.get(url, stream=True) as r: 21 | total_size = int(r.headers.get("content-length", 0)) 22 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 23 | with open(local_path, "wb") as f: 24 | for data in r.iter_content(chunk_size=chunk_size): 25 | if data: 26 | f.write(data) 27 | pbar.update(chunk_size) 28 | 29 | 30 | def md5_hash(path): 31 | with open(path, "rb") as f: 32 | content = f.read() 33 | return hashlib.md5(content).hexdigest() 34 | 35 | 36 | def get_ckpt_path(name, root, check=False): 37 | assert name in URL_MAP 38 | path = os.path.join(root, CKPT_MAP[name]) 39 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 40 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 41 | download(URL_MAP[name], path) 42 | md5 = md5_hash(path) 43 | assert md5 == MD5_MAP[name], md5 44 | return path 45 | 46 | 47 | class KeyNotFoundError(Exception): 48 | def __init__(self, cause, keys=None, visited=None): 49 | self.cause = cause 50 | self.keys = keys 51 | self.visited = visited 52 | messages = list() 53 | if keys is not None: 54 | messages.append("Key not found: {}".format(keys)) 55 | if visited is not None: 56 | messages.append("Visited: {}".format(visited)) 57 | messages.append("Cause:\n{}".format(cause)) 58 | message = "\n".join(messages) 59 | super().__init__(message) 60 | 61 | 62 | def retrieve( 63 | list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False 64 
| ): 65 | """Given a nested list or dict return the desired value at key expanding 66 | callable nodes if necessary and :attr:`expand` is ``True``. The expansion 67 | is done in-place. 68 | 69 | Parameters 70 | ---------- 71 | list_or_dict : list or dict 72 | Possibly nested list or dictionary. 73 | key : str 74 | key/to/value, path like string describing all keys necessary to 75 | consider to get to the desired value. List indices can also be 76 | passed here. 77 | splitval : str 78 | String that defines the delimiter between keys of the 79 | different depth levels in `key`. 80 | default : obj 81 | Value returned if :attr:`key` is not found. 82 | expand : bool 83 | Whether to expand callable nodes on the path or not. 84 | 85 | Returns 86 | ------- 87 | The desired value or if :attr:`default` is not ``None`` and the 88 | :attr:`key` is not found returns ``default``. 89 | 90 | Raises 91 | ------ 92 | Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is 93 | ``None``. 94 | """ 95 | 96 | keys = key.split(splitval) 97 | 98 | success = True 99 | try: 100 | visited = [] 101 | parent = None 102 | last_key = None 103 | for key in keys: 104 | if callable(list_or_dict): 105 | if not expand: 106 | raise KeyNotFoundError( 107 | ValueError( 108 | "Trying to get past callable node with expand=False." 109 | ), 110 | keys=keys, 111 | visited=visited, 112 | ) 113 | list_or_dict = list_or_dict() 114 | parent[last_key] = list_or_dict 115 | 116 | last_key = key 117 | parent = list_or_dict 118 | 119 | try: 120 | if isinstance(list_or_dict, dict): 121 | list_or_dict = list_or_dict[key] 122 | else: 123 | list_or_dict = list_or_dict[int(key)] 124 | except (KeyError, IndexError, ValueError) as e: 125 | raise KeyNotFoundError(e, keys=keys, visited=visited) 126 | 127 | visited += [key] 128 | # final expansion of retrieved value 129 | if expand and callable(list_or_dict): 130 | list_or_dict = list_or_dict() 131 | parent[last_key] = list_or_dict 132 | except KeyNotFoundError as e: 133 | if default is None: 134 | raise e 135 | else: 136 | list_or_dict = default 137 | success = False 138 | 139 | if not pass_success: 140 | return list_or_dict 141 | else: 142 | return list_or_dict, success 143 | 144 | 145 | if __name__ == "__main__": 146 | config = {"keya": "a", 147 | "keyb": "b", 148 | "keyc": 149 | {"cc1": 1, 150 | "cc2": 2, 151 | } 152 | } 153 | from omegaconf import OmegaConf 154 | config = OmegaConf.create(config) 155 | print(config) 156 | retrieve(config, "keya") 157 | 158 | -------------------------------------------------------------------------------- /VILP/workspace/save_vilp_planning_workspace.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | os.chdir(ROOT_DIR) 9 | 10 | import os 11 | import hydra 12 | import torch 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from torch.utils.data import DataLoader 16 | import copy 17 | import random 18 | import wandb 19 | import tqdm 20 | import numpy as np 21 | import shutil 22 | from PIL import Image as PILImage 23 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 24 | from diffusion_policy.policy.diffusion_unet_image_policy import DiffusionUnetImagePolicy 25 | from diffusion_policy.dataset.base_dataset import BaseImageDataset 26 | from diffusion_policy.env_runner.base_image_runner import BaseImageRunner 27 | 
from diffusion_policy.common.checkpoint_util import TopKCheckpointManager 28 | from diffusion_policy.common.json_logger import JsonLogger 29 | from diffusion_policy.common.pytorch_util import dict_apply, optimizer_to 30 | from diffusion_policy.model.diffusion.ema_model import EMAModel 31 | from diffusion_policy.model.common.lr_scheduler import get_scheduler 32 | from VILP.policy.vilp_planning import VilpPlanning 33 | OmegaConf.register_new_resolver("eval", eval, replace=True) 34 | 35 | 36 | class SaveVilpPlanningWorkspace(BaseWorkspace): 37 | include_keys = ['global_step', 'epoch'] 38 | 39 | def __init__(self, cfg: OmegaConf, output_dir=None): 40 | super().__init__(cfg, output_dir=output_dir) 41 | 42 | # set seed 43 | seed = cfg.training.seed 44 | torch.manual_seed(seed) 45 | np.random.seed(seed) 46 | random.seed(seed) 47 | 48 | # configure model 49 | self.model: VilpPlanning = hydra.utils.instantiate(cfg.policy) 50 | 51 | self.ema_model: VilpPlanning = None 52 | if cfg.training.use_ema: 53 | print("Using EMA model") 54 | self.ema_model = copy.deepcopy(self.model) 55 | # configure training state 56 | self.optimizer = hydra.utils.instantiate( 57 | cfg.optimizer, params=self.model.parameters()) 58 | 59 | # configure training state 60 | self.global_step = 0 61 | self.epoch = 0 62 | self.cfg = cfg 63 | 64 | def run(self): 65 | cfg = copy.deepcopy(self.cfg) 66 | 67 | # resume training 68 | if cfg.training.resume: 69 | lastest_ckpt_path = self.get_checkpoint_path() 70 | if lastest_ckpt_path.is_file(): 71 | print(f"Resuming from checkpoint {lastest_ckpt_path}") 72 | self.load_checkpoint(path=lastest_ckpt_path) 73 | print('saving model') 74 | folder_path = os.path.join('latent_planning', self.cfg.task_name) 75 | if not os.path.exists(folder_path): 76 | os.makedirs(folder_path) 77 | file_name = self.cfg.output_key + '.pth' 78 | file_path = os.path.join(folder_path, file_name) 79 | torch.save(self.ema_model.model_high_level.state_dict(), file_path) 80 | 81 | @hydra.main( 82 | version_base=None, 83 | config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), 84 | config_name=pathlib.Path(__file__).stem) 85 | def main(cfg): 86 | workspace = SaveVilpPlanningWorkspace(cfg) 87 | workspace.run() 88 | 89 | if __name__ == "__main__": 90 | main() -------------------------------------------------------------------------------- /VILP/workspace/train_vqgan_workspace.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | import sys 3 | import os 4 | import pathlib 5 | 6 | ROOT_DIR = str(pathlib.Path(__file__).parent.parent.parent) 7 | sys.path.append(ROOT_DIR) 8 | os.chdir(ROOT_DIR) 9 | 10 | import os 11 | import hydra 12 | import torch 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from torch.utils.data import DataLoader 16 | import copy 17 | import random 18 | import wandb 19 | import numpy as np 20 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 21 | from diffusion_policy.dataset.base_dataset import BaseImageDataset 22 | import argparse, os, sys, datetime, glob, importlib 23 | from omegaconf import OmegaConf 24 | import numpy as np 25 | from PIL import Image 26 | import torch 27 | import torchvision 28 | from torch.utils.data import random_split, DataLoader 29 | import pytorch_lightning as pl 30 | from pytorch_lightning.trainer import Trainer 31 | from pytorch_lightning.utilities import rank_zero_only 32 | from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor 33 
| import wandb 34 | import copy 35 | from VILP.taming.models.vqgan import VQModel 36 | from VILP.dataset.robomimic_ae_dataset import RobomimicAeDataset 37 | from VILP.workspace.utils import config_to_parser, instantiate_from_config, HumanDemonstrationsWrapper 38 | 39 | OmegaConf.register_new_resolver("eval", eval, replace=True) 40 | class TrainVqganWorkspace(BaseWorkspace): 41 | include_keys = ['global_step', 'epoch'] 42 | 43 | def __init__(self, cfg: OmegaConf): 44 | super().__init__(cfg) 45 | 46 | # set seed 47 | seed = cfg.seed 48 | torch.manual_seed(seed) 49 | np.random.seed(seed) 50 | random.seed(seed) 51 | trainer_config = {} 52 | trainer_config["gpus"] = cfg.gpus 53 | # configure model 54 | 55 | model = VQModel(**cfg.model) 56 | 57 | self.model = model 58 | # print model size 59 | print("Model size: {:.2f}M".format(sum(p.numel() for p in model.parameters()) / 1e6)) 60 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 61 | sys.path.append(os.getcwd()) 62 | parser = argparse.ArgumentParser() 63 | parser = Trainer.add_argparse_args(parser) 64 | parser = config_to_parser(cfg.trainer, parser) 65 | opt, unknown = parser.parse_known_args() 66 | 67 | nowname = now+opt.postfix 68 | logdir = os.path.join("vq_models", nowname) 69 | ckptdir = os.path.join(logdir, "checkpoints") 70 | cfgdir = os.path.join(logdir, "configs") 71 | configs = {} 72 | cli = OmegaConf.from_dotlist(unknown) 73 | config = OmegaConf.merge(*configs, cli) 74 | lightning_config = config.pop("lightning", OmegaConf.create()) 75 | trainer_opt = argparse.Namespace(**trainer_config) 76 | lightning_config.trainer = trainer_config 77 | # trainer and callbacks 78 | trainer_kwargs = dict() 79 | 80 | # add callback which sets up log directory 81 | default_callbacks_cfg = { 82 | "setup_callback": { 83 | "target": "VILP.workspace.utils.SetupCallback", 84 | "params": { 85 | "resume": opt.resume, 86 | "now": now, 87 | "logdir": logdir, 88 | "ckptdir": ckptdir, 89 | "cfgdir": cfgdir, 90 | "config": config, 91 | "lightning_config": lightning_config, 92 | } 93 | }, 94 | "image_logger": { 95 | "target": "VILP.workspace.utils.ImageLogger", 96 | "params": { 97 | "batch_frequency": 750, 98 | "max_images": 4, 99 | "clamp": True, 100 | "home_dir": logdir 101 | } 102 | }, 103 | } 104 | 105 | default_callbacks_cfg["periodic_checkpoint"] = { 106 | "target": "VILP.workspace.utils.PeriodicCheckpointCallback", 107 | "params": { 108 | "ckptdir": ckptdir, 109 | "save_interval": 1, 110 | "save_last_n": 3, 111 | "verbose": True, 112 | } 113 | } 114 | 115 | callbacks_cfg = lightning_config.get('callbacks') or OmegaConf.create() 116 | callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg) 117 | trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg] 118 | 119 | trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs) 120 | self.trainer = trainer 121 | self.lightning_config = lightning_config 122 | self.cpktdir = ckptdir 123 | self.opt = opt 124 | def run(self): 125 | cfg = copy.deepcopy(self.cfg) 126 | lightning_config = self.lightning_config 127 | opt = self.opt 128 | human_demonstrations_img: BaseImageDataset = hydra.utils.instantiate(cfg.dataset) 129 | data = HumanDemonstrationsWrapper(human_demonstrations_img,cfg) 130 | bs, base_lr = cfg.batch_size, cfg.base_learning_rate 131 | ngpu = len(lightning_config.trainer.gpus.strip(",").split(',')) 132 | accumulate_grad_batches = lightning_config.trainer.get('accumulate_grad_batches', 1) 133 | self.model.learning_rate = 
accumulate_grad_batches * ngpu * bs * base_lr 134 | print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format( 135 | self.model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr)) 136 | 137 | if opt.train: 138 | try: 139 | self.trainer.fit(self.model, data) 140 | except Exception: 141 | raise 142 | if not opt.no_test and not self.trainer.interrupted: 143 | self.trainer.test(self.model, data) 144 | 145 | @hydra.main( 146 | version_base=None, 147 | config_path=str(pathlib.Path(__file__).parent.parent.joinpath("config")), 148 | config_name=pathlib.Path(__file__).stem) 149 | def main(cfg): 150 | workspace = TrainVqganWorkspace(cfg) 151 | workspace.run() 152 | 153 | if __name__ == "__main__": 154 | main() -------------------------------------------------------------------------------- /conda_environment.yaml: -------------------------------------------------------------------------------- 1 | name: vilpenv 2 | channels: 3 | - pytorch 4 | - pytorch3d 5 | - nvidia 6 | - conda-forge 7 | dependencies: 8 | - python=3.9 9 | - pip=22.2.2 10 | - cudatoolkit=11.6 11 | - pytorch=1.12.1 12 | - torchvision=0.13.1 13 | - pytorch3d=0.7.0 14 | - numpy=1.23.3 15 | - numba==0.56.4 16 | - scipy==1.9.1 17 | - py-opencv=4.6.0 18 | - cffi=1.15.1 19 | - ipykernel=6.16 20 | - matplotlib=3.6.1 21 | - zarr=2.12.0 22 | - numcodecs=0.10.2 23 | - h5py=3.7.0 24 | - hydra-core=1.2.0 25 | - einops=0.4.1 26 | - tqdm=4.64.1 27 | - dill=0.3.5.1 28 | - scikit-video=1.1.11 29 | - scikit-image=0.19.3 30 | - gym=0.21.0 31 | - pymunk=6.2.1 32 | - wandb=0.13.3 33 | - threadpoolctl=3.1.0 34 | - shapely=1.8.4 35 | - cython=0.29.32 36 | - imageio=2.22.0 37 | - imageio-ffmpeg=0.4.7 38 | - termcolor=2.0.1 39 | - tensorboard=2.10.1 40 | - tensorboardx=2.5.1 41 | - psutil=5.9.2 42 | - click=8.0.4 43 | - boto3=1.24.96 44 | - accelerate=0.13.2 45 | - datasets=2.6.1 46 | - diffusers=0.26.1 47 | - av=10.0.0 48 | - cmake=3.24.3 49 | # trick to avoid cpu affinity issue described in https://github.com/pytorch/pytorch/issues/99625 50 | - llvm-openmp=14 51 | # trick to force reinstall imagecodecs via pip 52 | - imagecodecs==2022.8.8 53 | - pip: 54 | - ray[default,tune]==2.2.0 55 | # requires mujoco py dependencies libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf 56 | - free-mujoco-py==2.1.6 57 | - pygame==2.1.2 58 | - pybullet-svl==3.1.6.4 59 | - robosuite @ https://github.com/cheng-chi/robosuite/archive/277ab9588ad7a4f4b55cf75508b44aa67ec171f0.tar.gz 60 | - robomimic==0.2.0 61 | - pytorchvideo==0.1.5 62 | # pip package required for jpeg-xl 63 | - imagecodecs==2022.9.26 64 | - r3m @ https://github.com/facebookresearch/r3m/archive/b2334e726887fa0206962d7984c69c5fb09cceab.tar.gz 65 | - dm-control==1.0.9 66 | - albumentations==0.4.3 67 | - pudb==2019.2 68 | - imageio==2.9.0 69 | - imageio-ffmpeg==0.4.2 70 | - pytorch-lightning==1.0.8 71 | - test-tube>=0.7.5 72 | - streamlit>=0.73.1 73 | - einops==0.3.0 74 | - more-itertools>=8.0.0 75 | - transformers==4.3.1 76 | 77 | -------------------------------------------------------------------------------- /images_to_replybuffer.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | from absl import app 3 | import numpy as np 4 | import zarr 5 | from diffusion_policy.common.replay_buffer import ReplayBuffer 6 | import os 7 | from pathlib import Path 8 | import cv2 9 | 10 | dataset_dir = 'data/handle/train_handle' 11 | 12 | def batch_process(batch): 13 | 
processed_data = [] 14 | for img in batch: 15 | processed_img = cv2.cvtColor(cv2.imread(str(img)), cv2.COLOR_BGR2RGB) 16 | processed_data.append(processed_img) 17 | return np.array(processed_data) 18 | 19 | def load_data(data_type, dataset_dir): 20 | data_by_episode = {} 21 | data_path = Path(dataset_dir) 22 | run_folders = sorted(data_path.glob('run_*')) 23 | 24 | for run_folder in run_folders: 25 | episode_id = int(run_folder.name.split('_')[1]) 26 | image_folder = run_folder / 'images_linear' 27 | image_files = sorted(image_folder.glob('*.jpg')) 28 | images = batch_process(image_files) 29 | 30 | if episode_id not in data_by_episode: 31 | data_by_episode[episode_id] = [] 32 | data_by_episode[episode_id].append(images) 33 | 34 | return data_by_episode 35 | 36 | def add_episode_to_buffer(buffer, episode_data): 37 | episode_length = len(episode_data['image']) 38 | if episode_length == 0: 39 | return # No data to add 40 | 41 | episode_data = {key: np.array(value) for key, value in episode_data.items()} 42 | buffer.add_episode(episode_data, compressors="disk") 43 | 44 | def main(argv): 45 | if len(argv) > 1: 46 | raise app.UsageError('Too many command-line arguments.') 47 | 48 | output_dir = "data/handle_reply_buffer" 49 | os.makedirs(output_dir, exist_ok=True) 50 | output_dir = Path(output_dir) 51 | zarr_path = str(output_dir.joinpath("replay_buffer.zarr").absolute()) 52 | replay_buffer = ReplayBuffer.create_from_path(zarr_path=zarr_path, mode="a") 53 | image_data = load_data('color', dataset_dir) 54 | 55 | episode_ids = set(image_data) 56 | 57 | for episode_id in episode_ids: 58 | print((image_data.get(episode_id,[]))[0].shape) 59 | episode_data = { 60 | 'image': image_data.get(episode_id, [])[0], 61 | } 62 | add_episode_to_buffer(replay_buffer, episode_data) 63 | 64 | if __name__ == '__main__': 65 | app.run(main) 66 | -------------------------------------------------------------------------------- /install_custom_packages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | source ~/miniforge3/etc/profile.d/conda.sh 6 | conda activate vilpenv 7 | 8 | mkdir third_party 9 | cd third_party 10 | git clone https://github.com/real-stanford/diffusion_policy.git 11 | cd diffusion_policy 12 | pip install -e . 13 | cd .. 14 | cd .. 15 | pip install -e . 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 |     name = 'VILP', 5 |     packages = find_packages(), 6 | ) 7 | -------------------------------------------------------------------------------- /teasers/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhengtongXu/VILP/fea07284c61915da258cd600f940544fd0960e5f/teasers/teaser.gif -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | Training: 4 | python train.py --config-name=<config under VILP/config, e.g. train_vilp_pushT_policy> 5 | """ 6 | 7 | import sys 8 | # use line-buffering for both stdout and stderr 9 | sys.stdout = open(sys.stdout.fileno(), mode='w', buffering=1) 10 | sys.stderr = open(sys.stderr.fileno(), mode='w', buffering=1) 11 | 12 | import hydra 13 | from omegaconf import OmegaConf 14 | import pathlib 15 | from diffusion_policy.workspace.base_workspace import BaseWorkspace 16 | 17 | # allows arbitrary python code execution in configs using the ${eval:''} resolver 18 | OmegaConf.register_new_resolver("eval", eval, replace=True) 19 | 20 | @hydra.main( 21 |     version_base=None, 22 |     config_path=str(pathlib.Path(__file__).parent.joinpath( 23 |         'VILP','config')) 24 | ) 25 | def main(cfg: OmegaConf): 26 |     # resolve immediately so all the ${now:} resolvers 27 |     # will use the same time. 28 |     OmegaConf.resolve(cfg) 29 | 30 |     cls = hydra.utils.get_class(cfg._target_) 31 |     workspace: BaseWorkspace = cls(cfg) 32 |     workspace.run() 33 | 34 | if __name__ == "__main__": 35 |     main() 36 | --------------------------------------------------------------------------------
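train.py is a thin Hydra entry point: it resolves the selected config from VILP/config, instantiates the workspace class named by that config's _target_ field, and calls its run() method. A minimal launch sketch, assuming the conda environment defined in conda_environment.yaml and using train_vilp_pushT_policy purely as an illustrative config name (any other VILP/config/*.yaml can be substituted):

    # create and activate the environment, then install diffusion_policy and VILP in editable mode
    conda env create -f conda_environment.yaml
    conda activate vilpenv
    bash install_custom_packages.sh
    # launch a training run; Hydra loads VILP/config/train_vilp_pushT_policy.yaml
    # and instantiates the workspace class given by its _target_ field
    python train.py --config-name=train_vilp_pushT_policy

Hydra overrides (key=value pairs) can be appended to the same command line when individual config fields need to change for a particular run.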