├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── vdd_beso_blockpush.yml ├── vdd_beso_kitchen.yml └── vdd_toytask2d.yml ├── install.sh ├── pylintrc ├── pyproject.toml ├── scripts ├── __init__.py ├── evaluate.py └── train.py ├── setup.py ├── static └── image │ ├── overview.png │ └── toy_task_ov.png └── vdd ├── __init__.py ├── agents ├── __init__.py ├── em.py ├── imc.py └── vdd.py ├── networks ├── __init__.py ├── gating.py ├── gaussian.py ├── gpt.py ├── mlp.py ├── moes.py └── network_utils.py ├── score_functions ├── __init__.py ├── beso_score.py ├── ddpm_score.py ├── gmm_score.py ├── score_base.py └── score_utils.py ├── utils.py └── workspaces ├── __init__.py ├── base_manager.py ├── block_push_manager.py ├── d3il_manager.py ├── kitchen_manager.py ├── manager_factory.py ├── toytask1d_manager.py └── toytask2d_manager.py /.gitignore: -------------------------------------------------------------------------------- 1 | # File created using '.gitignore Generator' for Visual Studio Code: https://bit.ly/vscode-gig 2 | # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,jupyternotebooks,python 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,linux,jupyternotebooks,python 4 | 5 | ### JupyterNotebooks ### 6 | # gitignore template for Jupyter Notebooks 7 | # website: http://jupyter.org/ 8 | 9 | .ipynb_checkpoints 10 | */.ipynb_checkpoints/* 11 | 12 | # IPython 13 | profile_default/ 14 | ipython_config.py 15 | 16 | # Remove previous ipynb_checkpoints 17 | # git rm -r .ipynb_checkpoints/ 18 | 19 | ### Linux ### 20 | *~ 21 | 22 | # temporary files which can be created if a process still has a handle open of a deleted file 23 | .fuse_hidden* 24 | 25 | # KDE directory preferences 26 | .directory 27 | 28 | # Linux trash folder which might appear on any partition or disk 29 | .Trash-* 30 | 31 | # .nfs files are created when an open file is removed but is still being accessed 32 | .nfs* 33 | 34 | ### Python ### 35 | # Byte-compiled / optimized / DLL files 36 | __pycache__/ 37 | *.py[cod] 38 | *$py.class 39 | 40 | # C extensions 41 | *.so 42 | 43 | # Distribution / packaging 44 | .Python 45 | build/ 46 | develop-eggs/ 47 | dist/ 48 | downloads/ 49 | eggs/ 50 | .eggs/ 51 | lib/ 52 | lib64/ 53 | parts/ 54 | sdist/ 55 | var/ 56 | wheels/ 57 | share/python-wheels/ 58 | *.egg-info/ 59 | .installed.cfg 60 | *.egg 61 | MANIFEST 62 | 63 | # PyInstaller 64 | # Usually these files are written by a python script from a template 65 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
66 | *.manifest 67 | *.spec 68 | 69 | # Installer logs 70 | pip-log.txt 71 | pip-delete-this-directory.txt 72 | 73 | # Unit test / coverage reports 74 | htmlcov/ 75 | .tox/ 76 | .nox/ 77 | .coverage 78 | .coverage.* 79 | .cache 80 | nosetests.xml 81 | coverage.xml 82 | *.cover 83 | *.py,cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | cover/ 87 | 88 | # Translations 89 | *.mo 90 | *.pot 91 | 92 | # Django stuff: 93 | *.log 94 | local_settings.py 95 | db.sqlite3 96 | db.sqlite3-journal 97 | 98 | # Flask stuff: 99 | instance/ 100 | .webassets-cache 101 | 102 | # Scrapy stuff: 103 | .scrapy 104 | 105 | # Sphinx documentation 106 | docs/_build/ 107 | 108 | # PyBuilder 109 | .pybuilder/ 110 | target/ 111 | 112 | # Jupyter Notebook 113 | 114 | # IPython 115 | 116 | # pyenv 117 | # For a library or package, you might want to ignore these files since the code is 118 | # intended to run in multiple environments; otherwise, check them in: 119 | # .python-version 120 | 121 | # pipenv 122 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 123 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 124 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 125 | # install all needed dependencies. 126 | #Pipfile.lock 127 | 128 | # poetry 129 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 130 | # This is especially recommended for binary packages to ensure reproducibility, and is more 131 | # commonly ignored for libraries. 132 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 133 | #poetry.lock 134 | 135 | # pdm 136 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 137 | #pdm.lock 138 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 139 | # in version control. 140 | # https://pdm.fming.dev/#use-with-ide 141 | .pdm.toml 142 | 143 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 144 | __pypackages__/ 145 | 146 | # Celery stuff 147 | celerybeat-schedule 148 | celerybeat.pid 149 | 150 | # SageMath parsed files 151 | *.sage.py 152 | 153 | # Environments 154 | .env 155 | .venv 156 | env/ 157 | venv/ 158 | ENV/ 159 | env.bak/ 160 | venv.bak/ 161 | 162 | # Spyder project settings 163 | .spyderproject 164 | .spyproject 165 | 166 | # Rope project settings 167 | .ropeproject 168 | 169 | # mkdocs documentation 170 | /site 171 | 172 | # mypy 173 | .mypy_cache/ 174 | .dmypy.json 175 | dmypy.json 176 | 177 | # Pyre type checker 178 | .pyre/ 179 | 180 | # pytype static type analyzer 181 | .pytype/ 182 | 183 | # Cython debug symbols 184 | cython_debug/ 185 | 186 | # PyCharm 187 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 188 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 189 | # and can be added to the global gitignore or merged into this file. For a more nuclear 190 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
191 | #.idea/ 192 | 193 | ### VisualStudioCode ### 194 | .vscode/* 195 | 196 | # Local History for Visual Studio Code 197 | .history/ 198 | 199 | # Built Visual Studio Code Extensions 200 | *.vsix 201 | 202 | ### VisualStudioCode Patch ### 203 | # Ignore all local history of files 204 | .history 205 | .ionide 206 | .idea/ 207 | *cw2_results/ 208 | *beso/ 209 | *relay-policy-learning/ 210 | 211 | # End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,linux,jupyternotebooks,python 212 | 213 | # Custom rules (everything added below won't be overriden by 'Generate .gitignore File' if you use 'Update' option) 214 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.3.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | #- id: check-added-large-files 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 22.10.0 14 | hooks: 15 | - id: black 16 | - id: black-jupyter 17 | 18 | - repo: https://github.com/pycqa/isort 19 | rev: 5.12.0 20 | hooks: 21 | - id: isort 22 | args: ["--profile", "black"] 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 intuitive-robots 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [NeurIPS 2024] Official code for "Variational Distillation of Diffusion Policies into Mixture of Experts"
2 | # (Under construction)
3 |
4 | ![Overview](./static/image/overview.png)
5 |
6 | ## Installation Guide
7 |
8 | First, create a conda environment using the following command:
9 |
10 | ```bash
11 | sh install.sh
12 | ```
13 |
14 | During this process, two additional packages will be cloned and installed:
15 |
16 | - [Relay Policy Learning](https://github.com/google-research/relay-policy-learning)
17 | - [BESO](https://github.com/intuitive-robots/beso)
18 |
19 | To add the relay_kitchen environment to the PYTHONPATH, run the following commands:
20 |
21 | ```
22 | conda develop [PATH to relay-policy-learning]
23 | conda develop [PATH to relay-policy-learning]/adept_envs
24 | conda develop [PATH to relay-policy-learning]/adept_envs/adept_envs
25 | ```
26 |
27 | **Dataset**
28 |
29 | To download the datasets for the Relay Kitchen and Block Push environments and adjust the data paths in the ```franka_kitchen_main_config.yaml``` and ```block_push_main_config.yaml``` files, follow these steps:
30 |
31 | 1. Download the dataset: go to the [link](https://osf.io/q3dx2/) from [play-to-policy](https://github.com/jeffacce/play-to-policy) and download the dataset for the Relay Kitchen and Block Push environments.
32 |
33 | 2. Unzip the dataset: after downloading, unzip the dataset file and point the data paths in the two config files above to its location.
34 |
35 | 3. Adjust the model paths in the configuration files:
36 |
37 | For example, for Franka Kitchen, open ```./configs/vdd_beso_kitchen.yml``` and set the model_path argument to ```[Path to Beso]/beso/trained_models/kitchen/c_beso_1```.
38 |
39 | ---
40 | ## Run Experiment
41 | ```
42 | python scripts/train.py configs/[CONFIG_NAME].yml
43 | ```
44 | For example, to run the 2D toy experiment, run the following command:
45 | ```
46 | python scripts/train.py configs/vdd_toytask2d.yml
47 | ```
48 | To run the experiment on the Franka Kitchen environment, run the following command:
49 | ```
50 | python scripts/train.py configs/vdd_beso_kitchen.yml
51 | ```
52 |
53 | ---
54 | ### Acknowledgements
55 |
56 | This repo relies on the following existing codebases:
57 | - The BESO implementation is based on [BESO](https://github.com/intuitive-robots/beso).
58 | - The goal-conditioned variants of the environments are based on [play-to-policy](https://github.com/jeffacce/play-to-policy).
59 | - The ```score_gpt``` class is adapted from [minGPT](https://github.com/karpathy/minGPT).
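---
### Loading a Trained Model

During training, `scripts/train.py` stores the best-performing policy as a complete module via `torch.save` under `<rep_log_path>/model/best_model.pt` and later restores it with `torch.load` for the final rollouts. The same mechanism can be used for standalone inference. A minimal sketch, assuming a finished training run; the results path below is hypothetical and depends on your cw2 log directory:

```python
import torch

# train.py saves the full GaussianMoE policy object (not just a state dict),
# so torch.load returns a ready-to-use module.
# Hypothetical run directory -- substitute your own cw2 log path.
model_path = "cw2_results/vdd/toytask2d/model/best_model.pt"
agent = torch.load(model_path)
agent.eval()
```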
60 | --- 61 | 62 | ## Citation 63 | 64 | ```bibtex 65 | @article{zhou2024variational, 66 | title={Variational Distillation of Diffusion Policies into Mixture of Experts}, 67 | author={Zhou, Hongyi and Blessing, Denis and Li, Ge and Celik, Onur and Jia, Xiaogang and Neumann, Gerhard and Lioutikov, Rudolf}, 68 | journal={arXiv preprint arXiv:2406.12538}, 69 | year={2024} 70 | } 71 | ``` 72 | 73 | --- -------------------------------------------------------------------------------- /configs/vdd_beso_blockpush.yml: -------------------------------------------------------------------------------- 1 | # cw2 config 2 | repetitions: 1 3 | reps_per_job: 1 4 | reps_in_parallel: 1 5 | iterations: &iterations 20001 6 | num_checkpoints: 2 7 | 8 | # Global config 9 | exp_path: &exp_path "./cw2_results/vdd/blockpush" 10 | exp_name: &exp_name "vdd_beso_block_push" 11 | 12 | 13 | # cw2 config 14 | name: *exp_name 15 | path: *exp_path 16 | device: &device "cuda" 17 | dtype: &dtype "float32" 18 | seed: &seed 0 19 | enable_wandb: false 20 | 21 | # wandb 22 | wandb: 23 | project: "VDD_BlockPush" 24 | group: *exp_name 25 | entity: [YOUR_WANDB_ENTITY] 26 | log_interval: 10 27 | log_model: false 28 | model_name: model 29 | 30 | params: 31 | gpu_id: 0 32 | policy_params: 33 | moe_params: 34 | obs_dim: 10 35 | act_dim: 2 36 | goal_dim: 10 37 | goal_conditional: true 38 | num_components: 4 39 | cmp_cov_type: "full" 40 | cmp_mean_hidden_dims: 512 41 | cmp_mean_hidden_layers: 2 42 | cmp_cov_hidden_dims: 512 43 | cmp_cov_hidden_layers: 2 44 | bias_init_bound: 1.0 45 | cmp_activation: "mish" 46 | cmp_init: 'orthogonal' 47 | cmp_init_std: 1.0 48 | cmp_minimal_std: 0.0001 49 | prior_type: "uniform" 50 | learn_gating: true 51 | gating_hidden_layers: 2 52 | gating_hidden_dims: 64 53 | greedy_predict: false 54 | #### Transformer 55 | backbone_params: 56 | use_transformer: true 57 | n_layers: 4 58 | window_size: &window_size 5 59 | goal_seq_len: 1 60 | n_heads: 12 61 | embed_dim: 240 62 | embed_pdrop: 0.0 63 | atten_pdrop: 0.05 64 | resid_pdrop: 0.05 65 | #### vision encoder params 66 | vision_task: false 67 | 68 | 69 | optimizer_params: 70 | optimizer_type: "adam" 71 | cmps_lr: 0.0001 72 | cmps_lr_schedule: "linear" 73 | cmps_weight_decay: 0.0 74 | gating_lr: 0.0001 75 | gating_lr_schedule: "linear" 76 | gating_weight_decay: 0.0 77 | 78 | train_params: 79 | max_train_iters: *iterations 80 | cmp_steps: 2 81 | gating_steps: 1 82 | fix_gating_after_iters: *iterations 83 | vi_batch_size: 2 84 | train_batch_size: &train_batch_size 1024 85 | test_batch_size: &test_batch_size 1024 86 | # num_workers: &num_workers 4 87 | test_interval: 10 88 | env_rollout_interval: 4000 89 | num_rollouts: 40 90 | num_contexts: 10 91 | final_num_contexts: 1 92 | final_num_rollouts: 200 93 | device: *device 94 | dtype: *dtype 95 | 96 | experiment_params: 97 | experiment_name: "block_push" 98 | model_path: "[PATH to beso]/beso/trained_models/block_push/c_beso_1" 99 | sv_name: "model_state_dict.pth" 100 | model_select_metric: "avrg_result" 101 | datasets_config: 102 | train_batch_size: *train_batch_size 103 | test_batch_size: *test_batch_size 104 | window_size: *window_size 105 | train_fraction: 0.95 106 | num_workers: 1 107 | score_fn_params: 108 | noise_level_type: "uniform" 109 | weights_type: "stable" 110 | sigma_min: 0.2 111 | sigma_max: 0.5 112 | seed: *seed 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /configs/vdd_beso_kitchen.yml: 
-------------------------------------------------------------------------------- 1 | # cw2 config 2 | repetitions: 1 3 | reps_per_job: 1 4 | reps_in_parallel: 1 5 | iterations: &iterations 20001 6 | num_checkpoints: 2 7 | 8 | # Global config 9 | exp_path: &exp_path "./cw2_results/vdd/kitchen" 10 | exp_name: &exp_name "vdd_beso_kitchen" 11 | 12 | 13 | # cw2 config 14 | name: *exp_name 15 | path: *exp_path 16 | device: &device "cuda" 17 | dtype: &dtype "float32" 18 | seed: &seed 0 19 | enable_wandb: false 20 | 21 | # wandb 22 | wandb: 23 | project: "VDD_Kitchen" 24 | group: *exp_name 25 | entity: zhouhongyi 26 | log_interval: 10 27 | log_model: false 28 | model_name: model 29 | 30 | params: 31 | gpu_id: 0 32 | policy_params: 33 | moe_params: 34 | obs_dim: 30 35 | act_dim: 9 36 | goal_dim: 30 37 | goal_conditional: true 38 | num_components: 4 39 | cmp_cov_type: "full" 40 | cmp_mean_hidden_dims: 256 41 | cmp_mean_hidden_layers: 2 42 | cmp_cov_hidden_dims: 256 43 | cmp_cov_hidden_layers: 2 44 | bias_init_bound: 1.0 45 | cmp_activation: "mish" 46 | cmp_init: 'orthogonal' 47 | cmp_init_std: 1.0 48 | cmp_minimal_std: 0.0001 49 | prior_type: "uniform" 50 | learn_gating: true 51 | gating_hidden_layers: 2 52 | gating_hidden_dims: 64 53 | greedy_predict: false 54 | #### Transformer 55 | backbone_params: 56 | use_transformer: true 57 | n_layers: 6 58 | window_size: &window_size 4 59 | goal_seq_len: 2 60 | n_heads: 12 61 | embed_dim: 240 62 | embed_pdrop: 0.0 63 | atten_pdrop: 0.1 64 | resid_pdrop: 0.1 65 | #### vision encoder params 66 | vision_task: false 67 | 68 | 69 | optimizer_params: 70 | optimizer_type: "adam" 71 | cmps_lr: 0.0001 72 | cmps_lr_schedule: "linear" 73 | cmps_weight_decay: 0.0 74 | gating_lr: 0.0001 75 | gating_lr_schedule: "linear" 76 | gating_weight_decay: 0.0 77 | 78 | train_params: 79 | max_train_iters: *iterations 80 | cmp_steps: 2 81 | gating_steps: 1 82 | fix_gating_after_iters: *iterations 83 | vi_batch_size: 2 84 | train_batch_size: &train_batch_size 1024 85 | test_batch_size: &test_batch_size 1024 86 | # num_workers: &num_workers 4 87 | test_interval: 10 88 | env_rollout_interval: 2000 89 | num_rollouts: 40 90 | num_contexts: 10 91 | final_num_contexts: 1 92 | final_num_rollouts: 200 93 | device: *device 94 | dtype: *dtype 95 | 96 | experiment_params: 97 | experiment_name: "kitchen" 98 | model_path: "[Path to Beso]/beso/trained_models/kitchen/c_beso_1" 99 | sv_name: "model_state_dict.pth" 100 | model_select_metric: "avrg_result" 101 | datasets_config: 102 | train_batch_size: *train_batch_size 103 | test_batch_size: *test_batch_size 104 | window_size: *window_size 105 | train_fraction: 0.95 106 | num_workers: 1 107 | score_fn_params: 108 | noise_level_type: "uniform" 109 | weights_type: "stable" 110 | sigma_min: 0.1 111 | sigma_max: 0.1 112 | seed: *seed 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /configs/vdd_toytask2d.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: "SLURM" 3 | partition: "accelerated" 4 | job-name: "toytask2d" 5 | num_parallel_jobs: 120 6 | ntasks: 1 7 | cpus-per-task: 152 8 | time: 960 9 | #gpus_per_rep: 0.5 10 | #scheduler: horeka 11 | sbatch_args: 12 | gres: "gpu:4" 13 | account: "hk-project-sustainebot" 14 | --- 15 | # cw2 config 16 | repetitions: 1 17 | reps_per_job: 1 18 | reps_in_parallel: 1 19 | iterations: &iterations 600 20 | num_checkpoints: 2 21 | 22 | # Global config 23 | exp_path: &exp_path 
"./cw2_results/vdd/toytask2d" 24 | exp_name: &exp_name "toytask2d" 25 | 26 | 27 | # cw2 config 28 | name: *exp_name 29 | path: *exp_path 30 | device: &device "cuda" 31 | dtype: &dtype "float32" 32 | seed: &seed 0 33 | enable_wandb: false 34 | 35 | # wandb 36 | wandb: 37 | project: "neurips_toytask2d" 38 | group: *exp_name 39 | entity: zhouhongyi 40 | log_interval: 10 41 | log_model: false 42 | model_name: model 43 | 44 | params: 45 | gpu_id: 0 46 | policy_params: 47 | moe_params: 48 | obs_dim: &obs_dim 2 49 | act_dim: 2 50 | goal_dim: &goal_dim 0 51 | goal_conditional: &goal_conditional false 52 | num_components: 8 53 | cmp_cov_type: "diag" 54 | moe_network_type: "residual" # "mlp" or "residual" 55 | cmp_hidden_dims: 64 56 | cmp_hidden_layers: 2 57 | cmp_cov_hidden_dims: 64 58 | cmp_cov_hidden_layers: 2 59 | bias_init_bound: 0.5 60 | cmp_activation: "mish" 61 | cmp_init: 'orthogonal' 62 | cmp_init_std: 1.0 63 | cmp_minimal_std: 0.0001 64 | prior_type: "uniform" 65 | learn_gating: false 66 | gating_hidden_layers: 2 67 | gating_hidden_dims: 64 68 | greedy_predict: false 69 | #### Transformer 70 | backbone_params: 71 | use_transformer: false 72 | n_layers: 1 73 | window_size: &window_size 1 74 | goal_seq_len: &goal_seq_len 0 75 | n_heads: 2 76 | embed_dim: 16 77 | embed_pdrop: 0.0 78 | atten_pdrop: 0.0 79 | resid_pdrop: 0.0 80 | #### vision encoder params 81 | vision_task: false 82 | 83 | 84 | optimizer_params: 85 | optimizer_type: "adam" 86 | cmps_lr: 0.0005 87 | cmps_lr_schedule: "linear" 88 | cmps_weight_decay: 0.0 89 | gating_lr: 0.0001 90 | gating_lr_schedule: "linear" 91 | gating_weight_decay: 0.0 92 | 93 | train_params: 94 | max_train_iters: *iterations 95 | cmp_steps: 1 96 | gating_steps: 1 97 | fix_gating_after_iters: *iterations 98 | vi_batch_size: 64 99 | train_batch_size: &train_batch_size 32 100 | test_batch_size: &test_batch_size 32 101 | # num_workers: &num_workers 4 102 | test_interval: 10 103 | env_rollout_interval: 20 104 | num_rollouts: 4 105 | num_contexts: 10 106 | final_num_contexts: 60 107 | final_num_rollouts: 8 108 | device: *device 109 | dtype: *dtype 110 | 111 | experiment_params: 112 | experiment_name: "toytask2d" 113 | model_select_metric: "iter" 114 | num_datapoints: 5000 115 | datasets_config: 116 | batch_size: *train_batch_size 117 | score_fn_params: 118 | num_components: 8 119 | r: 2.0 120 | std: 0.5 121 | seed: *seed 122 | 123 | 124 | 125 | 126 | 127 | 128 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ############ GENERAL ENV SETUP ############ 4 | echo New Environment Name: 5 | read envname 6 | 7 | echo Creating new conda environment $envname 8 | conda create -n $envname python=3.8 -y -q 9 | 10 | eval "$(conda shell.bash hook)" 11 | conda activate $envname 12 | 13 | echo 14 | echo Activating $envname 15 | if [[ "$CONDA_DEFAULT_ENV" != "$envname" ]] 16 | then 17 | echo Failed to activate conda environment. 18 | exit 1 19 | fi 20 | 21 | 22 | ############ PYTHON ############ 23 | echo Install mamba 24 | conda install mamba -c conda-forge -y -q 25 | 26 | 27 | ############ REQUIRED DEPENDENCIES (PYBULLET) ############ 28 | echo Installing dependencies... 
29 |
30 | mamba install -c conda-forge hydra-core -y -q
31 | # Mujoco System Dependencies
32 | mamba install -c conda-forge glew patchelf -y -q
33 | mamba install conda-build -y -q
34 | # Set Conda Env Variables
35 | conda env config vars set LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.mujoco/mujoco210/bin:/usr/lib/nvidia
36 | conda env config vars set LD_PRELOAD=$LD_PRELOAD:$CONDA_PREFIX/lib/libGLEW.so
37 | # Activate Mujoco Py Env Variables
38 | conda activate $envname
39 |
40 | # Install MujocoPy
41 | pip install mujoco-py
42 |
43 | # Install other PIP Dependencies
44 | pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu116
45 | pip install setuptools==65.5.0
46 | pip install wheel==0.38.4
47 | pip install dm-control==0.0.403778684
48 | pip install gym==0.21.0
49 | pip install termcolor
50 | pip install wandb
51 | pip install tikzplotlib
52 | pip install einops
53 | pip install torchdiffeq
54 | pip install gin-config
55 | pip install pybullet
56 | pip install -U scikit-learn
57 | pip install torchsde
58 |
59 |
60 | ### Install Dependencies for D3IL-Sim
61 | #conda install -c conda-forge opencv pinocchio -y -q
62 | #conda install -c open3d-admin open3d -y -q
63 | #pip install mujoco
64 |
65 | echo Cloning Relay Policy Learning
66 | git clone https://github.com/google-research/relay-policy-learning
67 |
68 | echo Cloning BESO
69 | git clone https://github.com/intuitive-robots/beso.git
70 |
71 | echo Done installing all necessary packages. Please follow the next steps mentioned in the README.
72 |
73 | pip install -e .
74 |
75 | cd beso
76 |
77 | pip install -e .
78 |
79 | echo
80 | echo
81 | echo Successfully installed.
82 | echo
83 | echo To activate your environment call:
84 | echo conda activate $envname
85 | exit 0
86 |
--------------------------------------------------------------------------------
/pylintrc:
--------------------------------------------------------------------------------
1 | # This Pylint rcfile contains a best-effort configuration to uphold the
2 | # best-practices and style described in the Google Python style guide:
3 | # https://google.github.io/styleguide/pyguide.html
4 | #
5 | # Its canonical open-source location is:
6 | # https://google.github.io/styleguide/pylintrc
7 |
8 | [MASTER]
9 |
10 | # Files or directories to be skipped. They should be base names, not paths.
11 | ignore=third_party
12 |
13 | # Files or directories matching the regex patterns are skipped. The regex
14 | # matches against base names, not paths.
15 | ignore-patterns=
16 |
17 | # Pickle collected data for later comparisons.
18 | persistent=no
19 |
20 | # List of plugins (as comma separated values of python module names) to load,
21 | # usually to register additional checkers.
22 | load-plugins=
23 |
24 | # Use multiple processes to speed up Pylint.
25 | jobs=4
26 |
27 | # Allow loading of arbitrary C extensions. Extensions are imported into the
28 | # active Python interpreter and may run arbitrary code.
29 | unsafe-load-any-extension=no
30 |
31 |
32 | [MESSAGES CONTROL]
33 |
34 | # Only show warnings with the listed confidence levels. Leave empty to show
35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED
36 | confidence=
37 |
38 | # Enable the message, report, category or checker with the given id(s). You can
39 | # either give multiple identifiers separated by comma (,) or put this option
40 | # multiple times (only on the command line, not in the configuration file where
41 | # it should appear only once).
See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | cmp-builtin, 64 | cmp-method, 65 | coerce-builtin, 66 | coerce-method, 67 | delslice-method, 68 | div-method, 69 | duplicate-code, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | inconsistent-return-statements, 84 | input-builtin, 85 | intern-builtin, 86 | invalid-str-codec, 87 | locally-disabled, 88 | long-builtin, 89 | long-suffix, 90 | map-builtin-not-iterating, 91 | misplaced-comparison-constant, 92 | missing-function-docstring, 93 | missing-class-docstring, 94 | missing-module-docstring, 95 | metaclass-assignment, 96 | next-method-called, 97 | next-method-defined, 98 | no-absolute-import, 99 | no-else-break, 100 | no-else-continue, 101 | no-else-raise, 102 | no-else-return, 103 | no-init, # added 104 | no-member, 105 | no-name-in-module, 106 | no-self-use, 107 | nonzero-method, 108 | oct-method, 109 | old-division, 110 | old-ne-operator, 111 | old-octal-literal, 112 | old-raise-syntax, 113 | parameter-unpacking, 114 | print-statement, 115 | raising-string, 116 | range-builtin-not-iterating, 117 | raw_input-builtin, 118 | rdiv-method, 119 | reduce-builtin, 120 | relative-import, 121 | reload-builtin, 122 | round-builtin, 123 | setslice-method, 124 | signature-differs, 125 | standarderror-builtin, 126 | suppressed-message, 127 | sys-max-int, 128 | too-few-public-methods, 129 | too-many-ancestors, 130 | too-many-arguments, 131 | too-many-boolean-expressions, 132 | too-many-branches, 133 | too-many-instance-attributes, 134 | too-many-locals, 135 | too-many-nested-blocks, 136 | too-many-public-methods, 137 | too-many-return-statements, 138 | too-many-statements, 139 | trailing-newlines, 140 | unichr-builtin, 141 | unicode-builtin, 142 | unnecessary-pass, 143 | unpacking-in-except, 144 | useless-else-on-loop, 145 | useless-object-inheritance, 146 | useless-suppression, 147 | using-cmp-argument, 148 | wrong-import-order, 149 | xrange-builtin, 150 | zip-builtin-not-iterating, 151 | 152 | 153 | [REPORTS] 154 | 155 | # Set the output format. Available formats are text, parseable, colorized, msvs 156 | # (visual studio) and html. You can also give a reporter class, eg 157 | # mypackage.mymodule.MyReporterClass. 
158 | output-format=text 159 | 160 | # Tells whether to display a full report or only the messages 161 | reports=no 162 | 163 | # Python expression which should return a note less than 10 (10 is the highest 164 | # note). You have access to the variables errors warning, statement which 165 | # respectively contain the number of errors / warnings messages and the total 166 | # number of statements analyzed. This is used by the global evaluation report 167 | # (RP0004). 168 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 169 | 170 | # Template used to display messages. This is a python new-style format string 171 | # used to format the message information. See doc for all details 172 | #msg-template= 173 | 174 | 175 | [BASIC] 176 | 177 | # Good variable names which should always be accepted, separated by a comma 178 | good-names=main,_ 179 | 180 | # Bad variable names which should always be refused, separated by a comma 181 | bad-names= 182 | 183 | # Colon-delimited sets of names that determine each other's naming style when 184 | # the name regexes allow several styles. 185 | name-group= 186 | 187 | # Include a hint for the correct naming format with invalid-name 188 | include-naming-hint=no 189 | 190 | # List of decorators that produce properties, such as abc.abstractproperty. Add 191 | # to this list to register other decorators that produce valid properties. 192 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 193 | 194 | # Regular expression matching correct function names 195 | function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 196 | 197 | # Regular expression matching correct variable names 198 | variable-rgx=^[a-z][a-z0-9_]*$ 199 | 200 | # Regular expression matching correct constant names 201 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 202 | 203 | # Regular expression matching correct attribute names 204 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 205 | 206 | # Regular expression matching correct argument names 207 | argument-rgx=^[a-z][a-z0-9_]*$ 208 | 209 | # Regular expression matching correct class attribute names 210 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 211 | 212 | # Regular expression matching correct inline iteration names 213 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 214 | 215 | # Regular expression matching correct class names 216 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 217 | 218 | # Regular expression matching correct module names 219 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 220 | 221 | # Regular expression matching correct method names 222 | method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 223 | 224 | # Regular expression which should only match function or class names that do 225 | # not require a docstring. 226 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 227 | 228 | # Minimum line length for functions/classes that require docstrings, shorter 229 | # ones are exempt. 230 | docstring-min-length=10 231 | 232 | 233 | [TYPECHECK] 234 | 235 | # List of decorators that produce context managers, such as 236 | # contextlib.contextmanager. 
Add to this list to register other decorators that 237 | # produce valid context managers. 238 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 239 | 240 | # Tells whether missing members accessed in mixin class should be ignored. A 241 | # mixin class is detected if its name ends with "mixin" (case insensitive). 242 | ignore-mixin-members=yes 243 | 244 | # List of module names for which member attributes should not be checked 245 | # (useful for modules/projects where namespaces are manipulated during runtime 246 | # and thus existing member attributes cannot be deduced by static analysis. It 247 | # supports qualified module names, as well as Unix pattern matching. 248 | ignored-modules= 249 | 250 | # List of class names for which member attributes should not be checked (useful 251 | # for classes with dynamically set attributes). This supports the use of 252 | # qualified names. 253 | ignored-classes=optparse.Values,thread._local,_thread._local 254 | 255 | # List of members which are set dynamically and missed by pylint inference 256 | # system, and so shouldn't trigger E1101 when accessed. Python regular 257 | # expressions are accepted. 258 | generated-members= 259 | 260 | 261 | [FORMAT] 262 | 263 | # Maximum number of characters on a single line. 264 | max-line-length=88 265 | 266 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 267 | # lines made too long by directives to pytype. 268 | 269 | # Regexp for a line that is allowed to be longer than the limit. 270 | ignore-long-lines=(?x)( 271 | ^\s*(\#\ )??$| 272 | ^\s*(from\s+\S+\s+)?import\s+.+$) 273 | 274 | # Allow the body of an if to be on the same line as the test if there is no 275 | # else. 276 | single-line-if-stmt=yes 277 | 278 | # Maximum number of lines in a module 279 | max-module-lines=99999 280 | 281 | # String used as indentation unit. The internal Google style guide mandates 2 282 | # spaces. Google's externaly-published style guide says 4, consistent with 283 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 284 | # projects (like TensorFlow). 285 | indent-string=" " 286 | 287 | # Number of spaces of indent required inside a hanging or continued line. 288 | indent-after-paren=4 289 | 290 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 291 | expected-line-ending-format= 292 | 293 | 294 | [MISCELLANEOUS] 295 | 296 | # List of note tags to take in consideration, separated by a comma. 297 | notes=TODO 298 | 299 | 300 | [STRING] 301 | 302 | # This flag controls whether inconsistent-quotes generates a warning when the 303 | # character used as a quote delimiter is used inconsistently within a module. 304 | check-quote-consistency=yes 305 | 306 | 307 | [VARIABLES] 308 | 309 | # Tells whether we should check for unused import in __init__ files. 310 | init-import=no 311 | 312 | # A regular expression matching the name of dummy variables (i.e. expectedly 313 | # not used). 314 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 315 | 316 | # List of additional names supposed to be defined in builtins. Remember that 317 | # you should avoid to define new builtins when possible. 318 | additional-builtins= 319 | 320 | # List of strings which can identify a callback function by name. A callback 321 | # name must start or end with one of those strings. 322 | callbacks=cb_,_cb 323 | 324 | # List of qualified module names which can have objects that can redefine 325 | # builtins. 
326 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 327 | 328 | 329 | [LOGGING] 330 | 331 | # Logging modules to check that the string format arguments are in logging 332 | # function parameter format 333 | logging-modules=logging,absl.logging,tensorflow.io.logging 334 | 335 | 336 | [SIMILARITIES] 337 | 338 | # Minimum lines number of a similarity. 339 | min-similarity-lines=4 340 | 341 | # Ignore comments when computing similarities. 342 | ignore-comments=yes 343 | 344 | # Ignore docstrings when computing similarities. 345 | ignore-docstrings=yes 346 | 347 | # Ignore imports when computing similarities. 348 | ignore-imports=no 349 | 350 | 351 | [SPELLING] 352 | 353 | # Spelling dictionary name. Available dictionaries: none. To make it working 354 | # install python-enchant package. 355 | spelling-dict= 356 | 357 | # List of comma separated words that should not be checked. 358 | spelling-ignore-words= 359 | 360 | # A path to a file that contains private dictionary; one word per line. 361 | spelling-private-dict-file= 362 | 363 | # Tells whether to store unknown words to indicated private dictionary in 364 | # --spelling-private-dict-file option instead of raising a message. 365 | spelling-store-unknown-words=no 366 | 367 | 368 | [IMPORTS] 369 | 370 | # Deprecated modules which should not be used, separated by a comma 371 | deprecated-modules=regsub, 372 | TERMIOS, 373 | Bastion, 374 | rexec, 375 | sets 376 | 377 | # Create a graph of every (i.e. internal and external) dependencies in the 378 | # given file (report RP0402 must not be disabled) 379 | import-graph= 380 | 381 | # Create a graph of external dependencies in the given file (report RP0402 must 382 | # not be disabled) 383 | ext-import-graph= 384 | 385 | # Create a graph of internal dependencies in the given file (report RP0402 must 386 | # not be disabled) 387 | int-import-graph= 388 | 389 | # Force import order to recognize a module as part of the standard 390 | # compatibility libraries. 391 | known-standard-library= 392 | 393 | # Force import order to recognize a module as part of a third party library. 394 | known-third-party=enchant, absl 395 | 396 | # Analyse import fallback blocks. This can be used to support both Python 2 and 397 | # 3 compatible code, which means that the block might have code that exists 398 | # only in one or another interpreter, leading to false positives when analysed. 399 | analyse-fallback-blocks=no 400 | 401 | 402 | [CLASSES] 403 | 404 | # List of method names used to declare (i.e. assign) instance attributes. 405 | defining-attr-methods=__init__, 406 | __new__, 407 | setUp 408 | 409 | # List of member names, which should be excluded from the protected access 410 | # warning. 411 | exclude-protected=_asdict, 412 | _fields, 413 | _replace, 414 | _source, 415 | _make 416 | 417 | # List of valid names for the first argument in a class method. 418 | valid-classmethod-first-arg=cls, 419 | class_ 420 | 421 | # List of valid names for the first argument in a metaclass class method. 422 | valid-metaclass-classmethod-first-arg=mcs 423 | 424 | 425 | [EXCEPTIONS] 426 | 427 | # Exceptions that will emit a warning when being caught. 
Defaults to 428 | # "Exception" 429 | overgeneral-exceptions=StandardError, 430 | Exception, 431 | BaseException 432 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=64", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/scripts/evaluate.py -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | from cw2 import experiment, cw_error, cluster_work 5 | from cw2.cw_data import cw_logging 6 | from cw2.cw_data.cw_wandb_logger import WandBLogger 7 | 8 | from tqdm import tqdm 9 | from vdd.utils import process_cw2_train_rep_config_file 10 | 11 | from vdd.agents.vdd import VDD 12 | 13 | from vdd.workspaces.manager_factory import create_experiment_manager 14 | 15 | from vdd.utils import global_seeding 16 | 17 | import numpy as np 18 | 19 | 20 | def count_parameters(model): 21 | return sum(p.numel() for p in model.parameters()) 22 | 23 | 24 | class VDDtrain(experiment.AbstractIterativeExperiment): 25 | 26 | def initialize( 27 | self, cw_config: dict, rep: int, logger: cw_logging.LoggerArray 28 | ) -> None: 29 | algo_config = cw_config["params"] 30 | 31 | gpu_id = algo_config.get("gpu_id", None) 32 | 33 | cpu_cores = cw_config.get("cpu_cores", None) 34 | 35 | if gpu_id is not None: 36 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 37 | 38 | if cpu_cores is None: 39 | num_cpus = algo_config.get("num_cpus", 1) 40 | cpu_start_id = algo_config.get("cpu_start_id", 0) 41 | cpu_cores = tuple(np.arange(cpu_start_id, cpu_start_id + num_cpus)) 42 | 43 | self.device = algo_config['train_params']['device'] 44 | 45 | experiment_config = algo_config['experiment_params'] 46 | 47 | experiment_config['device'] = self.device 48 | 49 | experiment_config['cpu_cores'] = tuple(cpu_cores) 50 | 51 | self.experiment_manager = create_experiment_manager(experiment_config) 52 | 53 | train_config = algo_config['train_params'] 54 | self.num_final_rollouts = train_config.get('num_final_rollouts', 100) 55 | self.num_final_contexts = train_config.get('num_final_contexts', 60) 56 | print(f'num_final_rollouts: {self.num_final_rollouts}, num_final_contexts: {self.num_final_contexts}') 57 | 58 | self.train_dataset, self.test_dataset = self.experiment_manager.get_train_and_test_datasets() 59 | 60 | score_function = self.experiment_manager.get_score_function() 61 | 62 | cw_config['seed'] = rep 63 | 64 | global_seeding(cw_config['seed']) 65 | 66 | self.vid_agent = VDD.create_vdd_agent(policy_params=algo_config['policy_params'], 67 | optimizer_params=algo_config['optimizer_params'], 68 | training_params=algo_config['train_params'], 69 | score_function=score_function, 70 | train_dataset=self.train_dataset, 71 | 
test_dataset=self.test_dataset, ) 72 | 73 | self.vid_agent.get_scaler(self.experiment_manager.scaler) 74 | self.vid_agent.get_exp_manager(self.experiment_manager) 75 | print(f'Number of parameters: {count_parameters(self.vid_agent.agent)}') 76 | 77 | self.test_interval = algo_config['train_params']['test_interval'] 78 | self.env_rollout_interval = algo_config['train_params']['env_rollout_interval'] 79 | self.num_env_rollouts = algo_config['train_params']['num_rollouts'] 80 | self.num_env_contexts = algo_config['train_params']['num_contexts'] 81 | self.save_model_dir = os.path.join(cw_config['_rep_log_path'], 'model') 82 | self.model_select_metric = algo_config['experiment_params']['model_select_metric'] 83 | self.max_reward = -1e10 84 | 85 | self.progress_bar = tqdm(total=algo_config["train_params"]["max_train_iters"], disable=False) 86 | 87 | def iterate(self, cw_config: dict, rep: int, n: int) -> dict: 88 | """ 89 | Arguments: 90 | cw_config {dict} -- clusterwork experiment configuration 91 | rep {int} -- repetition counter 92 | n {int} -- iteration counter 93 | """ 94 | 95 | train_metric_dict = self.vid_agent.iterative_train(n) 96 | 97 | if n % self.test_interval == 0: 98 | test_metric_dict = self.vid_agent.iterative_evaluate() 99 | else: 100 | test_metric_dict = {} 101 | 102 | if n % self.env_rollout_interval == 0 and n > 0: 103 | rollout_dict = self.experiment_manager.env_rollout(self.vid_agent.agent, self.num_env_rollouts, 104 | num_ctxts=self.num_env_contexts) 105 | if rollout_dict[self.model_select_metric] > self.max_reward: 106 | self.vid_agent.save_best_model(path=self.save_model_dir) 107 | self.max_reward = rollout_dict[self.model_select_metric] 108 | print( 109 | f'save new best model at iteration {n}, with {self.model_select_metric}: {rollout_dict[self.model_select_metric]}') 110 | else: 111 | rollout_dict = {} 112 | 113 | self.progress_bar.update(1) 114 | 115 | if n == cw_config["iterations"] - 1: 116 | best_model_path = os.path.join(self.save_model_dir, 'best_model.pt') 117 | best_agent = torch.load(best_model_path) 118 | self.experiment_manager.goal_idx_offset = 0 119 | final_rollout_dict = self.experiment_manager.env_rollout(best_agent, self.num_final_rollouts, 120 | num_ctxts=self.num_final_contexts) 121 | final_rollout_dict = {"final_" + k: v for k, v in final_rollout_dict.items()} 122 | print(f'Final rollout with best model: {final_rollout_dict}') 123 | else: 124 | final_rollout_dict = {} 125 | 126 | return {**train_metric_dict, **test_metric_dict, **rollout_dict, **final_rollout_dict} 127 | 128 | def save_state(self, cw_config: dict, rep: int, n: int) -> None: 129 | 130 | if (n + 1) % (cw_config['iterations'] // cw_config['num_checkpoints']) == 0 \ 131 | or (n + 1) == cw_config["params"]["train_params"]["max_train_iters"]: 132 | self.vid_agent.save_model(iteration=n + 1, path=self.save_model_dir) 133 | 134 | def finalize(self, surrender: cw_error.ExperimentSurrender = None, crash: bool = False) -> None: 135 | print('Finalizing') 136 | pass 137 | 138 | 139 | if __name__ == "__main__": 140 | cw = cluster_work.ClusterWork(VDDtrain) 141 | if cw.config.exp_configs[0]['enable_wandb']: 142 | cw.add_logger(WandBLogger()) 143 | 144 | process_cw2_train_rep_config_file(cw.config, overwrite=True) 145 | 146 | cw.run() 147 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='vdd', 5 | 
version='1.0.0', 6 | author='Hongyi Zhou', 7 | author_email='hongyi.zhou@kit.edu', 8 | description='Variational Diffusion Distillation', 9 | packages=find_packages(), 10 | ) 11 | -------------------------------------------------------------------------------- /static/image/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/static/image/overview.png -------------------------------------------------------------------------------- /static/image/toy_task_ov.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/static/image/toy_task_ov.png -------------------------------------------------------------------------------- /vdd/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/__init__.py -------------------------------------------------------------------------------- /vdd/agents/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/agents/__init__.py -------------------------------------------------------------------------------- /vdd/agents/em.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/agents/em.py -------------------------------------------------------------------------------- /vdd/agents/imc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/agents/imc.py -------------------------------------------------------------------------------- /vdd/agents/vdd.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict 3 | 4 | import abc 5 | import copy 6 | 7 | import torch 8 | import torch as ch 9 | from torch.nn import functional as F 10 | from torch.utils.data import DataLoader, Dataset 11 | 12 | 13 | import einops 14 | 15 | from vdd.networks.network_utils import get_lr_schedule, get_optimizer 16 | 17 | from vdd.score_functions.score_base import ScoreFunction 18 | 19 | from vdd.networks.moes import GaussianMoE 20 | from vdd.networks.gpt import GPTNetwork 21 | 22 | from vdd.networks.gating import SoftCrossEntropyLoss 23 | 24 | 25 | 26 | 27 | class VDD(abc.ABC): 28 | def __init__(self, 29 | agent: GaussianMoE, 30 | cmps_optimizer: ch.optim.Optimizer, 31 | gating_optimizer: ch.optim.Optimizer, 32 | score_function: ScoreFunction, 33 | train_dataloader: DataLoader, 34 | test_dataloader: DataLoader, 35 | train_batch_size: int = 512, 36 | vi_batch_size: int = 2, 37 | data_shuffle: bool = True, 38 | cmps_lr_scheduler = None, 39 | gating_lr_scheduler = None, 40 | cmp_steps: int = 1, 41 | gating_steps: int = 1, 42 | learn_gating: bool = True, 43 | fix_gating_after_iters: int = 0, 44 | max_train_iters: int = 10000, 45 | seed: int = 0, 46 | device: str = 'cuda', 47 | dtype: str = 'float32', 48 | **kwargs 49 | ): 50 | self.agent = agent 51 | self.is_vision_task = False 52 | 53 | self.batch_size = train_batch_size 54 | self.data_shuffle 
= data_shuffle 55 | self.max_train_iters = max_train_iters 56 | self.goal_idx_offset = 0 57 | 58 | self.score_function = score_function 59 | 60 | self.cmps_optimizer = cmps_optimizer 61 | self.gating_optimizer = gating_optimizer 62 | self.cmps_lr_scheduler = cmps_lr_scheduler 63 | self.gating_lr_scheduler = gating_lr_scheduler 64 | 65 | self.learn_gating = learn_gating 66 | 67 | self.test_dataloader = test_dataloader 68 | self.train_dataloader = train_dataloader 69 | self.iter_train_dataloader = iter(train_dataloader) 70 | 71 | self.vi_batch_size = vi_batch_size 72 | self.scaler = None 73 | self.exp_manager = None 74 | 75 | self.cmp_steps = cmp_steps 76 | self.gating_steps = gating_steps 77 | self.fix_gating_after_iters = fix_gating_after_iters 78 | 79 | self.train_dataset = train_dataloader 80 | 81 | self.device = device 82 | self.dtype = ch.float64 if dtype == 'float64' else ch.float32 83 | self.seed = seed 84 | 85 | def get_scaler(self, scaler): 86 | self.scaler = scaler 87 | 88 | def get_exp_manager(self, exp_manager): 89 | self.exp_manager = exp_manager 90 | 91 | def init_agent(self,): 92 | """ 93 | update the normalizer and reset the agent 94 | :return: 95 | """ 96 | pass 97 | 98 | def iterative_train(self, n: int): 99 | """ 100 | train the agent 101 | :return: 102 | """ 103 | train_metric_dict = {} 104 | 105 | for i in range(self.cmp_steps): 106 | try: 107 | batch = next(self.iter_train_dataloader) 108 | except StopIteration: 109 | self.iter_train_dataloader = iter(self.train_dataloader) 110 | batch = next(self.iter_train_dataloader) 111 | batch = self.exp_manager.preprocess_data(batch) 112 | train_metric_dict = self.iterative_train_cmp(batch, n) 113 | 114 | if self.learn_gating and n < self.fix_gating_after_iters: 115 | for i in range(self.gating_steps): 116 | try: 117 | batch = next(self.iter_train_dataloader) 118 | except StopIteration: 119 | self.iter_train_dataloader = iter(self.train_dataloader) 120 | batch = next(self.iter_train_dataloader) 121 | batch = self.exp_manager.preprocess_data(batch) 122 | train_metric_dict.update(self.iterative_train_gating(batch[0], batch[1], batch[2])) 123 | 124 | if self.learn_gating and n == self.fix_gating_after_iters: 125 | self.agent.gating_network.train(mode=False) 126 | 127 | return train_metric_dict 128 | 129 | @ch.no_grad() 130 | def iterative_evaluate(self, ): 131 | """ 132 | evaluate the agent 133 | :return: 134 | """ 135 | ch.cuda.empty_cache() 136 | self.agent.eval() 137 | test_losses = [] 138 | gatings = [] 139 | 140 | logging_dict = {} 141 | for batch in self.test_dataloader: 142 | batch = self.exp_manager.preprocess_data(batch) 143 | inputs = batch[0] 144 | outputs = batch[1] 145 | if batch[2] is not None: 146 | goals = batch[2].to(self.device).to(self.dtype) 147 | goals = self.scaler.scale_input(goals) 148 | else: 149 | goals = None 150 | 151 | cmp_means, _, gating = self.agent(inputs, goals, train=False) 152 | 153 | cmp_means = einops.rearrange(cmp_means, 'b c t a -> (b t) c a') 154 | outputs = einops.rearrange(outputs, 'b t a -> (b t) a') 155 | 156 | gating = einops.rearrange(gating, 'b c t -> (b t) c') 157 | 158 | if torch.isnan(cmp_means).any(): 159 | print("nan in cmp_means") 160 | if torch.isnan(gating).any(): 161 | print("nan in gating_prediction") 162 | gating_dist = ch.distributions.Categorical(gating) 163 | 164 | batch_indices = ch.arange(0, cmp_means.size(0)).unsqueeze(-1) 165 | 166 | indices = gating_dist.sample([1]).swapaxes(0, 1) 167 | 168 | pred_outputs = cmp_means[batch_indices, indices, :].squeeze(1) 169 | 170 | 
loss = F.mse_loss(pred_outputs, outputs, reduction='mean') 171 | test_losses.append(loss.item()) 172 | gatings.append(gating) 173 | 174 | logging_dict['test_mean_mse'] = sum(test_losses)/len(test_losses) 175 | 176 | gatings = ch.cat(gatings, dim=0) 177 | avrg_entropy = ch.distributions.Categorical(probs=gatings).entropy().mean().item() 178 | logging_dict['test_mean_gating_entropy'] = avrg_entropy 179 | 180 | return logging_dict 181 | 182 | 183 | def iterative_train_gating(self, inputs, actions, goals=None): 184 | 185 | pred_means, pred_chols, pred_gatings = self.agent.forward(inputs, goals) 186 | pred_log_gatings = pred_gatings.log() 187 | 188 | pred_means = einops.rearrange(pred_means, 'b c t a -> (b t) c a') 189 | pred_chols = einops.rearrange(pred_chols, 'b c t a1 a2 -> (b t) c a1 a2') 190 | pred_log_gatings = einops.rearrange(pred_log_gatings, 'b c t -> (b t) c') 191 | actions = einops.rearrange(actions, 'b t a -> (b t) a') 192 | 193 | with ch.no_grad(): 194 | actions = actions[:, None, :].repeat(1, self.agent.n_components, 1) 195 | log_probs = self.agent.gmm_head.log_prob(x=actions, mean=pred_means, chol=pred_chols) 196 | ### log the log_probs per component 197 | log_resps = log_probs + pred_log_gatings 198 | log_resps = log_resps - ch.logsumexp(log_resps, dim=1, keepdim=True) 199 | 200 | loss_fn = SoftCrossEntropyLoss() 201 | 202 | targets = log_resps.exp() + 1e-8 203 | 204 | if torch.isnan(log_probs).any(): 205 | print("Nan in log probs") 206 | print(log_probs) 207 | if torch.isnan(log_resps).any(): 208 | print("Nan in log resps") 209 | print(log_resps) 210 | if torch.isnan(pred_log_gatings).any(): 211 | print("Nan in pred log gatings") 212 | print(pred_log_gatings) 213 | if torch.isnan(targets).any(): 214 | print("Nan in gating targets") 215 | print(targets) 216 | 217 | loss = loss_fn(pred_log_gatings, targets) 218 | self.gating_optimizer.zero_grad() 219 | loss.backward() 220 | self.gating_optimizer.step() 221 | if self.gating_lr_scheduler is not None: 222 | self.gating_lr_scheduler.step() 223 | 224 | ret_dict = {'gating_loss': loss.item()} 225 | 226 | return ret_dict 227 | 228 | 229 | def iterative_train_cmp(self, batch, iter): 230 | """ 231 | Train the joint GMM policy 232 | b -- state batch 233 | c -- the number of components 234 | v -- the number of vi samples 235 | a -- action dimension 236 | o -- observation dimension 237 | """ 238 | if not self.is_vision_task: 239 | states = batch[0].to(self.device).to(self.dtype) 240 | # states = self.scaler.scale_input(states) 241 | else: 242 | states = batch[0] 243 | 244 | if batch[2] is not None: 245 | goals = batch[2].to(self.device).to(self.dtype) 246 | else: 247 | goals = None 248 | 249 | 250 | logging_dict = {} 251 | 252 | # input : (b, t, o) 253 | # pred_means, pred_chols : (b, c, t, a), (b, c, t, a, a) 254 | pred_means, pred_chols, pred_gatings = self.agent(states, goals) 255 | 256 | pred_gatings = pred_gatings.detach() 257 | 258 | # sampled actions : (v, b, c, t, a) 259 | sampled_actions = self.agent.gmm_head.rsample(mean=pred_means, chol=pred_chols, n=self.vi_batch_size) 260 | 261 | # rearrange the sampled actions to (b, c, v, t, a) 262 | if len(sampled_actions.size()) == 4: 263 | # sampled_actions = sampled_actions.permute(1, 2, 0, 3) 264 | sampled_actions = einops.rearrange(sampled_actions, 'v b c a -> b c v a') 265 | 266 | elif len(sampled_actions.size()) == 5: 267 | sampled_actions = einops.rearrange(sampled_actions, 'v b c t a -> b c v t a') 268 | 269 | 270 | # Query the scores function 271 | # input : states (b, c, v, t, 
o), actions (b, c, v, t, a)
272 |         # output : scores (b, c, v, t, a)
273 |         if self.is_vision_task:
274 |             # for vision tasks, first encode the states and then repeat the latent states to save VRAM
275 |             score_states = states
276 |         else:
277 |             score_states = einops.repeat(states, 'b t o -> b c v t o', c=self.agent.n_components, v=self.vi_batch_size)
278 |
279 |         score_goals = einops.repeat(goals, 'b t o -> b c v t o', c=self.agent.n_components, v=self.vi_batch_size) if goals is not None else None
280 |         with ch.no_grad():
281 |             scores, noise_level = self.score_function(sampled_actions, score_states, score_goals, iter, self.is_vision_task)
282 |
283 |         ### pack the scores to ((b,t), c, v, a)
284 |         scores = einops.rearrange(scores, 'b c v t a -> (b t) c v a')
285 |         sampled_actions = einops.rearrange(sampled_actions, 'b c v t a -> (b t) c v a')
286 |         pred_means = einops.rearrange(pred_means, 'b c t a -> (b t) c a')
287 |         pred_chols = einops.rearrange(pred_chols, 'b c t a1 a2 -> (b t) c a1 a2')
288 |         pred_gatings = einops.rearrange(pred_gatings, 'b c t -> (b t) c')
289 |
290 |         ### Note: the einsum multiplies element-wise and sums over the action dimension,
291 |         # i.e. a per-sample dot product <score, action> : (b, c, v)
292 |         score_w_act = torch.einsum('...va,...va->...v', scores, sampled_actions)
293 |
294 |         # log responsibilities : (b, c, v)
295 |         responsibilities = self.agent.gmm_head.log_responsibilities(pred_means.clone().detach(),
296 |                                                                      pred_chols.clone().detach(),
297 |                                                                      pred_gatings,
298 |                                                                      sampled_actions)
299 |
300 |         # entropies : (b, c)
301 |         entropies = self.agent.gmm_head.entropy(mean=pred_means, chol=pred_chols)
302 |
303 |         # expectation for r_sample_terms: (b, c)
304 |         r_sample_term = (score_w_act + responsibilities).mean(dim=-1)
305 |
306 |         ### FIXME: weighted update
307 |         unweighted_vi_loss = r_sample_term + entropies
308 |         vi_loss = - (unweighted_vi_loss * pred_gatings).mean()
309 |
310 |         self.cmps_optimizer.zero_grad(set_to_none=True)
311 |         vi_loss.backward()
312 |         self.cmps_optimizer.step()
313 |         if self.cmps_lr_scheduler is not None:
314 |             self.cmps_lr_scheduler.step()
315 |
316 |         logging_dict["vi_loss"] = vi_loss.item()
317 |         logging_dict["score_w_loss"] = -score_w_act.mean().item()
318 |         logging_dict["responsibility_loss"] = -responsibilities.mean().item()
319 |         logging_dict["entropy_loss_cmp"] = -entropies.mean().item()
320 |         logging_dict["noise_level"] = noise_level.float().mean().item()
321 |
322 |         if torch.isnan(vi_loss).any():
323 |             print("vi loss is nan")
324 |         if torch.isnan(score_w_act).any():
325 |             print("score_w_act is nan")
326 |         if torch.isnan(responsibilities).any():
327 |             print("responsibilities is nan")
328 |         if torch.isnan(entropies).any():
329 |             print("entropies is nan")
330 |         if torch.isnan(pred_gatings).any():
331 |             print("pred_gatings is nan")
332 |             exit(0)
333 |
334 |         return logging_dict
335 |
336 |     @staticmethod
337 |     def create_vdd_agent(policy_params, optimizer_params, training_params, score_function,
338 |                          train_dataset=None, test_dataset=None):
339 |         ### TODO: add the backbone here, support mlp and transformer
340 |         backbone_params = policy_params['backbone_params']
341 |
342 |         backbone = GPTNetwork(obs_dim=policy_params['moe_params']['obs_dim'],
343 |                               goal_dim=policy_params['moe_params']['goal_dim'],
344 |                               output_dim=policy_params['moe_params']['cmp_mean_hidden_dims'],
345 |                               goal_conditional=policy_params['moe_params']['goal_conditional'],
346 |                               **backbone_params).to(training_params['device']) if backbone_params.pop('use_transformer') else None
347 |
348 |         policy =
335 | 336 | @staticmethod 337 | def create_vdd_agent(policy_params, optimizer_params, training_params, score_function, 338 | train_dataset=None, test_dataset=None): 339 | ### TODO: add the backbone here, support mlp and transformer 340 | backbone_params = policy_params['backbone_params'] 341 | 342 | backbone = GPTNetwork(obs_dim=policy_params['moe_params']['obs_dim'], 343 | goal_dim=policy_params['moe_params']['goal_dim'], 344 | output_dim=policy_params['moe_params']['cmp_mean_hidden_dims'], 345 | goal_conditional=policy_params['moe_params']['goal_conditional'], 346 | **backbone_params).to(training_params['device']) if backbone_params.pop('use_transformer') else None 347 | 348 | policy = GaussianMoE(**policy_params['moe_params'], backbone=backbone, device=training_params['device'], dtype=training_params['dtype']) 349 | 350 | if policy_params['vision_task']: 351 | if policy_params['copy_vision_encoder']: 352 | policy.vision_encoder = copy.deepcopy(score_function.model.model.obs_encoder) 353 | if policy_params['train_vision_encoder']: 354 | cmps_optimizer = get_optimizer(optimizer_type=optimizer_params['optimizer_type'], 355 | model_parameters=list(policy.joint_cmps.parameters())+list(policy.vision_encoder.parameters()), 356 | learning_rate=optimizer_params['cmps_lr'], 357 | weight_decay=optimizer_params['cmps_weight_decay']) 358 | else: 359 | cmps_optimizer = get_optimizer(optimizer_type=optimizer_params['optimizer_type'], 360 | model_parameters=policy.joint_cmps.parameters(), 361 | learning_rate=optimizer_params['cmps_lr'], 362 | weight_decay=optimizer_params['cmps_weight_decay']) 363 | else: 364 | cmps_optimizer = get_optimizer(optimizer_type=optimizer_params['optimizer_type'], 365 | model_parameters=policy.get_parameter('cmps'), 366 | learning_rate=optimizer_params['cmps_lr'], 367 | weight_decay=optimizer_params['cmps_weight_decay']) 368 | 369 | cmps_lr_scheduler = get_lr_schedule(optimizer_params['cmps_lr_schedule'], 370 | cmps_optimizer, training_params['max_train_iters']) \ 371 | if optimizer_params['cmps_lr_schedule'] is not None else None 372 | 373 | if policy_params['moe_params']['learn_gating']: 374 | gating_net_optimizer = get_optimizer(optimizer_type=optimizer_params['optimizer_type'], 375 | model_parameters=policy.get_parameter('gating'), 376 | learning_rate=optimizer_params['gating_lr'], 377 | weight_decay=optimizer_params['gating_weight_decay']) 378 | gating_lr_scheduler = get_lr_schedule(optimizer_params['gating_lr_schedule'], 379 | gating_net_optimizer, training_params['max_train_iters']) \ 380 | if optimizer_params['gating_lr_schedule'] is not None else None 381 | else: 382 | gating_net_optimizer = None 383 | gating_lr_scheduler = None 384 | 385 | vdd_agent = VDD(agent=policy, cmps_optimizer=cmps_optimizer, gating_optimizer=gating_net_optimizer, 386 | cmps_lr_scheduler=cmps_lr_scheduler, 387 | gating_lr_scheduler=gating_lr_scheduler, 388 | score_function=score_function, 389 | train_dataloader=train_dataset, 390 | test_dataloader=test_dataset, 391 | learn_gating=policy_params['moe_params']['learn_gating'], 392 | **training_params) 393 | return vdd_agent 394 | 395 | def save_best_model(self, path): 396 | """ 397 | Save the full agent module as the best checkpoint so far. 398 | :param path: directory to write best_model.pt to 399 | """ 400 | ch.save(self.agent, os.path.join(path, "best_model.pt")) 401 | 402 | def save_model(self, iteration, path): 403 | save_path = os.path.join(path, f"model_state_dict_{iteration}.pth") 404 | ch.save(self.agent.state_dict(), save_path) 405 | 406 | def save_debug_model(self, path): 407 | """ 408 | Save the full agent module for debugging. 409 | :param path: directory to write debug_model.pt to 410 | """ 411 | ch.save(self.agent, os.path.join(path, "debug_model.pt")) 412 | 413 |
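### A rough usage sketch for the factory above. The dictionary keys mirror the
### ones read in create_vdd_agent; every value here is a placeholder (the real
### settings live in the YAML files under configs/, e.g. configs/vdd_toytask2d.yml),
### training_params must match whatever VDD's constructor accepts, and
### score_function is assumed to wrap a pretrained diffusion model.
# policy_params = {
#     'backbone_params': {'use_transformer': False},
#     'vision_task': False,
#     'moe_params': {'num_components': 4, 'obs_dim': 2, 'act_dim': 2, 'goal_dim': 0,
#                    'goal_conditional': False, 'cmp_mean_hidden_dims': 64,
#                    'prior_type': 'uniform', 'cmp_init': 'orthogonal', 'learn_gating': False},
# }
# optimizer_params = {'optimizer_type': 'adam', 'cmps_lr': 1e-4,
#                     'cmps_weight_decay': 0., 'cmps_lr_schedule': None}
# training_params = {'device': 'cuda', 'dtype': 'float32', 'max_train_iters': 100000}
# agent = VDD.create_vdd_agent(policy_params, optimizer_params, training_params, score_function)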
-------------------------------------------------------------------------------- /vdd/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/networks/__init__.py -------------------------------------------------------------------------------- /vdd/networks/gating.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vdd.networks.mlp import ResidualMLPNetwork 5 | 6 | 7 | class GatingNet(nn.Module): 8 | def __init__(self, 9 | obs_dim, 10 | n_components, 11 | num_hidden_layer, 12 | hidden_dim, 13 | device='cuda'): 14 | 15 | super(GatingNet, self).__init__() 16 | 17 | self.n_components = n_components 18 | 19 | self.network = ResidualMLPNetwork(obs_dim, n_components, hidden_dim, num_hidden_layer, 20 | dropout=0., device=device) 21 | 22 | self.trained = False 23 | 24 | def forward(self, observation): 25 | logits = self.network(observation) 26 | return torch.nn.functional.log_softmax(logits[..., :self.n_components], dim=-1) 27 | 28 | def sample(self, contexts): 29 | p = self.probabilities(contexts) 30 | thresholds = torch.cumsum(p, dim=-1) 31 | thresholds[:, -1] = 1.0 32 | eps = torch.rand(size=[contexts.shape[0], 1], device=p.device)  # draw on the same device as p 33 | samples = torch.argmax((eps < thresholds) * 1., dim=-1)  # first component whose CDF exceeds eps 34 | return samples 35 | 36 | def probabilities(self, contexts): 37 | return torch.exp(self(contexts)) 38 | 39 | def log_probabilities(self, contexts): 40 | return self(contexts) 41 | 42 | def entropies(self, contexts): 43 | p = self.probabilities(contexts) 44 | return -torch.sum(p * torch.log(p + 1e-25), dim=-1) 45 | 46 | def expected_entropy(self, contexts): 47 | return torch.mean(self.entropies(contexts)) 48 | 49 | def kls(self, contexts, other): 50 | p = self.probabilities(contexts) 51 | other_log_p = other.log_probabilities(contexts) 52 | return torch.sum(p * (torch.log(p + 1e-25) - other_log_p), dim=-1) 53 | 54 | def expected_kl(self, contexts, other): 55 | return torch.mean(self.kls(contexts, other)) 56 | 57 | def check_trained(self): 58 | if self.trained: 59 | return True 60 | else: 61 | raise ValueError('Inference network is not trained.') 62 | 63 | @property 64 | def params(self): 65 | return list(self.parameters()) 66 | 67 | @property 68 | def param_norm(self): 69 | """ 70 | Calculates the norm of network parameters. 71 | """ 72 | return torch.norm(torch.stack([torch.norm(p.detach()) for p in self.parameters()])) 73 | 74 | @property 75 | def grad_norm(self): 76 | """ 77 | Calculates the norm of current gradients. 78 | """ 79 | return torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in self.parameters()])) 80 | 81 | def add_component(self): 82 | self.mask[self.n_components] = 1  # assumes a `mask` buffer maintained by the owner; GatingNet itself does not define one 83 | self.n_components += 1 84 | 85 | def to_gpu(self): 86 | self.to(torch.device('cuda')) 87 | 88 | def to_cpu(self): 89 | self.to(torch.device('cpu')) 90 | 91 | 92 | class SoftCrossEntropyLoss(torch.nn.Module): 93 | def __init__(self): 94 | super(SoftCrossEntropyLoss, self).__init__() 95 | 96 | def forward(self, pred_log_resp, resp): 97 | return -(resp * pred_log_resp).mean()
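### A small usage sketch for GatingNet (toy dimensions; all values here are
### placeholders). `sample` implements inverse-CDF sampling: it thresholds a
### uniform draw against the cumulative gating probabilities and takes the
### first component whose cumulative mass exceeds the draw.
# net = GatingNet(obs_dim=4, n_components=3, num_hidden_layer=2, hidden_dim=32, device='cpu')
# contexts = torch.randn(5, 4)
# log_p = net(contexts)                    # (5, 3) log gating probabilities
# idx = net.sample(contexts)               # (5,) sampled component indices
# ent = net.expected_entropy(contexts)     # scalar mean gating entropy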
78 | """ 79 | return torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in self.parameters()])) 80 | 81 | def add_component(self): 82 | self.mask[self.n_components] = 1 83 | self.n_components += 1 84 | 85 | def to_gpu(self): 86 | self.to(torch.device('cuda')) 87 | 88 | def to_cpu(self): 89 | self.to(torch.device('cpu')) 90 | 91 | 92 | class SoftCrossEntropyLoss(torch.nn.Module): 93 | def __init__(self): 94 | super(SoftCrossEntropyLoss, self).__init__() 95 | 96 | def forward(self, pred_log_resp, resp): 97 | return -(resp * pred_log_resp).mean() -------------------------------------------------------------------------------- /vdd/networks/gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | 5 | from vdd.networks.network_utils import inverse_softplus, fill_triangular, diag_bijector 6 | 7 | def get_gmm_head(n_dim, n_components, init_std, minimal_std, type='full', device='cuda'): 8 | if type == 'full': 9 | return FullGMMHead(n_dim, n_components, init_std, minimal_std, device=device) 10 | elif type == 'diag': 11 | return DiagGMMHead(n_dim, n_components, init_std, minimal_std, device=device) 12 | else: 13 | raise NotImplementedError(f"Unknown GMM head type {type}") 14 | 15 | 16 | class AbstractGaussianHead(nn.Module): 17 | def __init__(self, n_dim, init_std, minimal_std, device='cuda'): 18 | super(AbstractGaussianHead, self).__init__() 19 | self.device = device 20 | self.n_dim = n_dim 21 | self.input_dim = 2 * n_dim 22 | self.minimal_std = torch.tensor(minimal_std, device=device) 23 | self.init_std = torch.tensor(init_std, device=device) 24 | self.diag_activation = nn.Softplus() 25 | self.diag_activation_inv = inverse_softplus 26 | ##TODO: check if this is correct 27 | self._pre_activation_shift = inverse_softplus(self.init_std - self.minimal_std) 28 | 29 | def forward(self, mean, chol, train=True): 30 | raise NotImplementedError 31 | 32 | @staticmethod 33 | def log_prob(x, mean, chol): 34 | return torch.distributions.MultivariateNormal(mean, scale_tril=chol, validate_args=False).log_prob(x) 35 | 36 | @staticmethod 37 | def rsample(mean, chol, n=1): 38 | return torch.distributions.MultivariateNormal(mean, scale_tril=chol, validate_args=False).rsample((n,)) 39 | 40 | @staticmethod 41 | def entropy(mean, chol): 42 | return torch.distributions.MultivariateNormal(mean, scale_tril=chol, validate_args=False).entropy() 43 | 44 | def get_device(self, device: torch.device): 45 | self.device = device 46 | 47 | def get_params(self): 48 | return self.parameters() 49 | 50 | 51 | class AbstractGMMHead(AbstractGaussianHead): 52 | def __init__(self, n_dim, n_components, init_std, minimal_std, device='cuda'): 53 | super(AbstractGMMHead, self).__init__(n_dim, init_std, minimal_std, device=device) 54 | self.n_components = n_components 55 | self.flat_mean_dim = n_dim * n_components 56 | self.flat_chol_dim = 2 * n_dim * n_components 57 | 58 | def forward(self, flat_mean, flat_chol, train=True): 59 | raise NotImplementedError 60 | 61 | @staticmethod 62 | def gmm_log_prob(x, means, chols, gating): 63 | comps = torch.distributions.MultivariateNormal(means, scale_tril=chols) 64 | gmm = torch.distributions.MixtureSameFamily(gating, comps) 65 | return gmm.log_prob(x) 66 | 67 | @staticmethod 68 | def gmm_sample(means, chols, gating, n=1): 69 | comps = torch.distributions.MultivariateNormal(means, scale_tril=chols) 70 | gmm = torch.distributions.MixtureSameFamily(gating, comps) 71 | return gmm.sample((n,)) 72 | 73 
| @staticmethod 74 | def log_responsibilities(pred_means, pred_chols, pred_gatings, samples): 75 | """ 76 | b -- state batch 77 | c -- the number of components 78 | v -- the number of vi samples 79 | a -- action dimension 80 | """ 81 | c = pred_means.shape[1] 82 | v = samples.shape[-2] 83 | 84 | ### pred_means: (b, c, a) 85 | ### pred_chols: (b, c, a, a) 86 | pred_means = pred_means[:, None, :, None, ...].repeat(1, 1, 1, v, 1) 87 | pred_chols = pred_chols[:, None, :, None, ...].repeat(1, 1, 1, v, 1, 1) 88 | 89 | samples = samples.unsqueeze(2).repeat(1, 1, c, 1, 1) 90 | 91 | ### samples: (b, c, c, v, a) 92 | ### log_probs_cmps: (b, c, c, v) 93 | log_probs_cmps = AbstractGaussianHead.log_prob(samples, pred_means, pred_chols) 94 | 95 | ### log_probs: (b, c, v) 96 | log_probs = log_probs_cmps.clone() 97 | log_probs = torch.einsum('ijj...->ij...', log_probs) 98 | 99 | log_gating = torch.log(pred_gatings) 100 | 101 | probs_cmps = log_probs_cmps.exp() 102 | 103 | margin = torch.einsum('ijkl,ik->ijl', probs_cmps, pred_gatings) 104 | 105 | log_margin = torch.log(margin + 1e-8) 106 | 107 | return log_probs + log_gating.unsqueeze(-1) - log_margin 108 | 109 | 110 | class FullGMMHead(AbstractGMMHead): 111 | def __init__(self, n_dim, n_components, init_std=1., minimal_std=1e-3, device='cuda'): 112 | super(FullGMMHead, self).__init__(n_dim, n_components, init_std, minimal_std, device=device) 113 | self.flat_mean_dim = n_dim * n_components 114 | self.flat_chol_dim = n_dim * (n_dim + 1) // 2 * n_components 115 | 116 | def forward(self, flat_means, flat_chols, train=True): 117 | assert flat_means.shape[-1] == self.flat_mean_dim, f"Expected {self.flat_mean_dim} but got {flat_means.shape[-1]}" 118 | assert flat_chols.shape[-1] == self.flat_chol_dim, f"Expected {self.flat_chol_dim} but got {flat_chols.shape[-1]}" 119 | means = einops.rearrange(flat_means, '... (n d) -> ... n d', n=self.n_components, d=self.n_dim) 120 | chols = einops.rearrange(flat_chols, '... (n d) -> ... n d', n=self.n_components) 121 | chols = fill_triangular(chols) 122 | chols = diag_bijector(lambda z: self.diag_activation(z + self._pre_activation_shift) + self.minimal_std, chols) 123 | return means, chols 124 | 125 | 126 | class DiagGMMHead(AbstractGMMHead): 127 | def __init__(self, n_dim, n_components, init_std=1., minimal_std=1e-3, device='cuda'): 128 | super(DiagGMMHead, self).__init__(n_dim, n_components, init_std, minimal_std, device=device) 129 | self.flat_mean_dim = n_dim * n_components 130 | self.flat_chol_dim = n_dim * n_components 131 | 132 | def forward(self, flat_mean, flat_chol, train=True): 133 | assert flat_mean.shape[-1] == self.flat_mean_dim, f"Expected {self.flat_mean_dim} but got {flat_mean.shape[-1]}" 134 | assert flat_chol.shape[-1] == self.flat_chol_dim, f"Expected {self.flat_chol_dim} but got {flat_chol.shape[-1]}" 135 | means = einops.rearrange(flat_mean, '... (n d) -> ... n d', n=self.n_components, d=self.n_dim) 136 | chols = einops.rearrange(flat_chol, '... (n d) -> ... 
n d', n=self.n_components, d=self.n_dim) 137 | chols = torch.diag_embed(chols) 138 | chols = diag_bijector(lambda z: self.diag_activation(z + self._pre_activation_shift) + self.minimal_std, chols) 139 | return means, chols -------------------------------------------------------------------------------- /vdd/networks/gpt.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class CausalSelfAttention(nn.Module): 12 | """ 13 | A vanilla multi-head masked self-attention layer with a projection at the end. 14 | It is possible to use torch.nn.MultiheadAttention here but I am including an 15 | explicit implementation here to show that there is nothing too scary here. 16 | """ 17 | 18 | def __init__( 19 | self, 20 | n_embd: int, 21 | n_heads: int, 22 | attn_pdrop: float, 23 | resid_pdrop: float, 24 | block_size: int, 25 | ): 26 | super().__init__() 27 | assert n_embd % n_heads == 0 28 | # key, query, value projections for all heads 29 | self.key = nn.Linear(n_embd, n_embd) 30 | self.query = nn.Linear(n_embd, n_embd) 31 | self.value = nn.Linear(n_embd, n_embd) 32 | # regularization 33 | self.attn_drop = nn.Dropout(attn_pdrop) 34 | self.resid_drop = nn.Dropout(resid_pdrop) 35 | # output projection 36 | self.proj = nn.Linear(n_embd, n_embd) 37 | # causal mask to ensure that attention is only applied to the left in the input sequence 38 | self.register_buffer( 39 | "mask", 40 | torch.tril(torch.ones(block_size, block_size)).view( 41 | 1, 1, block_size, block_size 42 | ), 43 | ) 44 | self.n_head = n_heads 45 | 46 | def forward(self, x): 47 | ( 48 | B, 49 | T, 50 | C, 51 | ) = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 52 | 53 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 54 | k = ( 55 | self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 56 | ) # (B, nh, T, hs) 57 | q = ( 58 | self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 59 | ) # (B, nh, T, hs) 60 | v = ( 61 | self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) 62 | ) # (B, nh, T, hs) 63 | 64 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 65 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 66 | att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf")) 67 | att = F.softmax(att, dim=-1) 68 | att = self.attn_drop(att) 69 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 70 | y = ( 71 | y.transpose(1, 2).contiguous().view(B, T, C) 72 | ) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_drop(self.proj(y)) 76 | return y 77 | 78 | 79 | class Block(nn.Module): 80 | """an unassuming Transformer block""" 81 | 82 | def __init__( 83 | self, 84 | n_embd: int, 85 | n_heads: int, 86 | attn_pdrop: float, 87 | resid_pdrop: float, 88 | block_size: int, 89 | 90 | ): 91 | super().__init__() 92 | self.ln1 = nn.LayerNorm(n_embd) 93 | self.ln2 = nn.LayerNorm(n_embd) 94 | self.attn = CausalSelfAttention( 95 | n_embd, 96 | n_heads, 97 | attn_pdrop, 98 | resid_pdrop, 99 | block_size, 100 | ) 101 | self.mlp = nn.Sequential( 102 | nn.Linear(n_embd, 4 * n_embd), 103 | nn.GELU(), 104 | nn.Linear(4 * n_embd, n_embd), 105 | nn.Dropout(resid_pdrop), 106 | ) 107 | 108 | def forward(self, x): 109 | x = x + 
self.attn(self.ln1(x)) 110 | x = x + self.mlp(self.ln2(x)) 111 | return x 112 | 113 | 114 | class GPTNetwork(nn.Module): 115 | 116 | def __init__(self, 117 | obs_dim: int, 118 | goal_dim: int, 119 | output_dim: int, 120 | embed_dim: int, 121 | embed_pdrop: float, 122 | atten_pdrop: float, 123 | resid_pdrop: float, 124 | n_layers: int, 125 | n_heads: int, 126 | window_size: int, 127 | goal_conditional: bool, 128 | goal_seq_len: int = 1, 129 | linear_output: bool = False, 130 | pre_out_hidden_dim: int = 100, 131 | encode_actions: bool = False, 132 | action_dim: int = 0, 133 | device: str = 'cuda', ): 134 | 135 | super(GPTNetwork, self).__init__() 136 | self.device = device 137 | self.goal_conditional = goal_conditional 138 | 139 | self.goal_seq_len = goal_seq_len 140 | if not goal_conditional: 141 | goal_dim = 0 142 | self.goal_seq_len = 0 143 | 144 | ### window size is only for the state sequence, by default only one goal and one readout token 145 | ### window size: for (state, action) pairs, the window size is 2 * window_size 146 | ### the goal sequence length is 1 147 | ### TODO: extend to multiple readout tokens 148 | if encode_actions: 149 | block_size = self.goal_seq_len + 2 * window_size 150 | else: 151 | block_size = self.goal_seq_len + window_size 152 | 153 | ### sequence size for the state sequence, every (state, action) pair at the same timestep share the same PE 154 | sequence_size = self.goal_seq_len + window_size 155 | 156 | ### output dim can differ from action dim, 157 | ### as we can predict the means and Cholesky factors of all components 158 | self.out_dim = output_dim 159 | self.obs_dim = obs_dim 160 | self.goal_dim = goal_dim 161 | 162 | self.encode_actions = encode_actions 163 | 164 | if encode_actions: 165 | self.action_dim = action_dim 166 | self.action_emb = nn.Linear(action_dim, embed_dim) 167 | 168 | # embedding layers 169 | ### Here we assume that the goal and state have the same dimension 170 | self.tok_emb = nn.Linear(obs_dim, embed_dim) 171 | self.pos_emb = nn.Parameter(torch.zeros(1, sequence_size, embed_dim)) 172 | self.drop = nn.Dropout(embed_pdrop) 173 | 174 | self.embed_dim = embed_dim 175 | 176 | self.block_size = block_size 177 | self.sequence_size = sequence_size 178 | self.window_size = window_size 179 | 180 | # transformer blocks 181 | self.blocks = nn.Sequential( 182 | *[Block(embed_dim, n_heads, atten_pdrop, resid_pdrop, block_size) for _ in range(n_layers)] 183 | ) 184 | 185 | # decoder head 186 | self.ln_f = nn.LayerNorm(embed_dim) 187 | 188 | if linear_output: 189 | self.head = nn.Linear(embed_dim, self.out_dim) 190 | else: 191 | self.head = nn.Sequential( 192 | nn.Linear(embed_dim, pre_out_hidden_dim), 193 | nn.SiLU(), 194 | nn.Linear(pre_out_hidden_dim, self.out_dim) 195 | ) 196 | 197 | self.apply(self._init_weights) 198 | 199 | logger.info(f"Number of parameters in GPT: {sum(p.numel() for p in self.parameters())}") 200 | 201 | 202 | def _init_weights(self, module): 203 | if isinstance(module, (nn.Linear, nn.Embedding)): 204 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 205 | if isinstance(module, nn.Linear) and module.bias is not None: 206 | torch.nn.init.zeros_(module.bias) 207 | elif isinstance(module, nn.LayerNorm): 208 | torch.nn.init.zeros_(module.bias) 209 | torch.nn.init.ones_(module.weight) 210 | elif isinstance(module, GPTNetwork): 211 | torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02) 212 | 213 | 214 | def configure_optimizers(self, train_config): 215 | """ 216 | This long function is unfortunately doing something 
very simple and is being very defensive: 217 | We are separating out all parameters of the model into two buckets: those that will experience 218 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 219 | We are then returning the PyTorch optimizer object. 220 | """ 221 | 222 | # separate out all parameters to those that will and won't experience regularizing weight decay 223 | decay = set() 224 | no_decay = set() 225 | whitelist_weight_modules = (torch.nn.Linear,) 226 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 227 | for mn, m in self.named_modules(): 228 | for pn, p in m.named_parameters(): 229 | fpn = "%s.%s" % (mn, pn) if mn else pn # full param name 230 | 231 | if pn.endswith("bias"): 232 | # all biases will not be decayed 233 | no_decay.add(fpn) 234 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 235 | # weights of whitelist modules will be weight decayed 236 | decay.add(fpn) 237 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 238 | # weights of blacklist modules will NOT be weight decayed 239 | no_decay.add(fpn) 240 | 241 | # special case the position embedding parameter in the root GPT module as not decayed 242 | no_decay.add("pos_emb") 243 | 244 | # validate that we considered every parameter 245 | param_dict = {pn: p for pn, p in self.named_parameters()} 246 | inter_params = decay & no_decay 247 | union_params = decay | no_decay 248 | assert ( 249 | len(inter_params) == 0 250 | ), "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 251 | assert ( 252 | len(param_dict.keys() - union_params) == 0 253 | ), "parameters %s were not separated into either decay/no_decay set!" % ( 254 | str(param_dict.keys() - union_params), 255 | ) 256 | 257 | # create the pytorch optimizer object 258 | optim_groups = [ 259 | { 260 | "params": [param_dict[pn] for pn in sorted(list(decay))], 261 | "weight_decay": train_config.weight_decay, 262 | }, 263 | { 264 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 265 | "weight_decay": 0.0, 266 | }, 267 | ] 268 | optimizer = torch.optim.AdamW( 269 | optim_groups, lr=train_config.learning_rate, betas=train_config.betas 270 | ) 271 | return optimizer 272 | 273 | 274 | def forward(self, states: torch.Tensor, goals: torch.Tensor = None, actions: torch.Tensor = None): 275 | """ 276 | Run the model forward. 277 | states: (B, T, obs_dim) 278 | actions: (B, T-1, action_dim) 279 | goals: (B, T, goal_dim) 280 | """ 281 | batch_size, window_size, dim = states.size() 282 | 283 | if self.encode_actions: 284 | assert actions.size(1) == window_size - 1, "Expected actions to have length T-1" 285 | 286 | assert window_size <= self.block_size, "Cannot forward, model block size is exhausted." 
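# (Layout reminder for the asserts below, following the constructor comments:
#  the transformer consumes [goal tokens | state(/action) tokens], a (state,
#  action) pair at the same timestep shares one positional embedding, and the
#  goal tokens are cut off again after the blocks, before the output head.)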
287 | assert dim == self.obs_dim, f"Expected state dim {self.obs_dim}, got {dim}" 288 | assert window_size <= self.window_size, f"Expected window size {self.window_size}, got {window_size}" 289 | 290 | state_embed = self.tok_emb(states) 291 | 292 | if self.goal_conditional: 293 | assert goals is not None, "Expected goals to be provided" 294 | assert goals.size(1) == self.goal_seq_len, f"Expected goal sequence length to be {self.goal_seq_len}, got {goals.size(1)}" 295 | goal_embed = self.tok_emb(goals) 296 | position_embeddings = self.pos_emb[:, :(window_size + self.goal_seq_len), :] 297 | goal_x = self.drop(goal_embed + position_embeddings[:, :self.goal_seq_len, :]) 298 | else: 299 | position_embeddings = self.pos_emb[:, :window_size, :] 300 | 301 | 302 | state_x = self.drop(state_embed + position_embeddings[:, self.goal_seq_len:, :]) 303 | 304 | if self.encode_actions: 305 | action_embed = self.action_emb(actions) 306 | action_x = self.drop(action_embed + position_embeddings[:, self.goal_seq_len:, :]) 307 | 308 | sa_seq = torch.stack([state_x, action_x], dim=1 309 | ).permute(0, 2, 1, 3).reshape(batch_size, 2 * window_size, self.embed_dim) 310 | else: 311 | sa_seq = state_x 312 | 313 | # next we stack everything together 314 | if self.goal_conditional: 315 | input_seq = torch.cat([goal_x, sa_seq], dim=1) 316 | else: 317 | input_seq = sa_seq 318 | 319 | x = self.blocks(input_seq) 320 | x = self.ln_f(x) 321 | x = x[:, self.goal_seq_len:, :]  # drop the goal tokens before the output head 322 | 323 | assert x.shape[1] == states.shape[1], f"Expected output window size {states.shape[1]}, got {x.shape[1]}" 324 | 325 | out = self.head(x) 326 | 327 | return out 328 | 329 | def get_params(self): 330 | return self.parameters()
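### A quick shape sketch for GPTNetwork (toy sizes on CPU; all values are
### placeholders). Goals and states must share one dimensionality here, since
### both pass through the same token embedding tok_emb.
# net = GPTNetwork(obs_dim=10, goal_dim=10, output_dim=32, embed_dim=64,
#                  embed_pdrop=0.1, atten_pdrop=0.1, resid_pdrop=0.1,
#                  n_layers=2, n_heads=4, window_size=5,
#                  goal_conditional=True, device='cpu')
# out = net(torch.randn(8, 5, 10), goals=torch.randn(8, 1, 10))  # -> (8, 5, 32)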
-------------------------------------------------------------------------------- /vdd/networks/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def get_activation_fn(activation_type: str): 5 | # build the activation layer 6 | if activation_type == "sigmoid": 7 | act = torch.nn.Sigmoid() 8 | elif activation_type == "tanh": 9 | act = torch.nn.Tanh() 10 | elif activation_type == "ReLU": 11 | act = torch.nn.ReLU() 12 | elif activation_type == "PReLU": 13 | act = torch.nn.PReLU() 14 | elif activation_type == "softmax": 15 | act = torch.nn.Softmax(dim=-1) 16 | elif activation_type == "Mish": 17 | act = torch.nn.Mish() 18 | else: 19 | act = torch.nn.PReLU()  # fallback for unknown activation names 20 | return act 21 | 22 | 23 | class MLPNetwork(nn.Module): 24 | """ 25 | Simple multi-layer perceptron network which can be generated with different 26 | activation functions, with or without spectral normalization of the weights. 27 | """ 28 | 29 | def __init__( 30 | self, 31 | input_dim: int, 32 | hidden_dim: int = 100, 33 | num_hidden_layers: int = 1, 34 | output_dim=1, 35 | dropout: float = 0, 36 | activation: str = "ReLU", 37 | use_spectral_norm: bool = False, 38 | device: str = 'cuda' 39 | ): 40 | super(MLPNetwork, self).__init__() 41 | self.network_type = "mlp" 42 | # number of variables in an input sequence 43 | self.input_dim = input_dim 44 | # number of neurons in each hidden layer 45 | self.hidden_dim = hidden_dim 46 | self.num_hidden_layers = num_hidden_layers 47 | # dimension of the network output 48 | self.output_dim = output_dim 49 | self.dropout = dropout 50 | self.spectral_norm = use_spectral_norm 51 | # set up the network 52 | self.layers = nn.ModuleList([nn.Linear(self.input_dim, self.hidden_dim)]) 53 | self.layers.extend( 54 | [ 55 | nn.Linear(self.hidden_dim, self.hidden_dim) 56 | for i in range(1, self.num_hidden_layers) 57 | ] 58 | ) 59 | self.layers.append(nn.Linear(self.hidden_dim, self.output_dim)) 60 | 61 | # build the activation layer 62 | self.act = get_activation_fn(activation).to(device) 63 | self._device = device 64 | self.layers.to(self._device) 65 | 66 | def forward(self, x): 67 | for idx, layer in enumerate(self.layers): 68 | if idx == 0: 69 | out = layer(x) 70 | else: 71 | if idx < len(self.layers) - 2: 72 | out = layer(out) # + out 73 | else: 74 | out = layer(out) 75 | if idx < len(self.layers) - 1: 76 | out = self.act(out) 77 | return out 78 | 79 | def get_device(self, device: torch.device): 80 | self._device = device 81 | self.layers.to(device) 82 | 83 | def get_params(self): 84 | return self.layers.parameters() 85 | 86 | 87 | class ResidualMLPNetwork(nn.Module): 88 | """ 89 | Simple multi-layer perceptron network with residual connections for 90 | benchmarking the performance of different networks. The residual blocks 91 | are based on the IBC paper implementation, which uses two-layer residual 92 | blocks with pre-activation, with or without dropout and normalization. 93 | """ 94 | 95 | def __init__( 96 | self, 97 | input_dim: int, 98 | output_dim: int, 99 | hidden_dim: int = 100, 100 | num_hidden_layers: int = 1, 101 | dropout: float = 0., 102 | activation: str = "Mish", 103 | use_norm: bool = False, 104 | norm_style: str = 'BatchNorm', 105 | device: str = 'cuda' 106 | ): 107 | super(ResidualMLPNetwork, self).__init__() 108 | self.network_type = "mlp" 109 | self._device = device 110 | # set up the network 111 | 112 | assert num_hidden_layers % 2 == 0 113 | 114 | self.layers = nn.ModuleList([nn.Linear(input_dim, hidden_dim)]) 115 | self.layers.extend( 116 | [ 117 | TwoLayerPreActivationResNetLinear( 118 | hidden_dim=hidden_dim, 119 | activation=activation, 120 | dropout_rate=dropout, 121 | use_norm=use_norm, 122 | norm_style=norm_style 123 | ) 124 | for i in range(1, num_hidden_layers, 2) 125 | ] 126 | ) 127 | self.layers.append(nn.Linear(hidden_dim, output_dim)) 128 | self.layers.to(self._device) 129 | 130 | def forward(self, x): 131 | 132 | for idx, layer in enumerate(self.layers): 133 | x = layer(x.to(torch.float32)) 134 | return x 135 | 136 | def get_device(self, device: torch.device): 137 | self._device = device 138 | self.layers.to(device) 139 | 140 | def get_params(self): 141 | return self.layers.parameters() 142 | 143 | 144 | class TwoLayerPreActivationResNetLinear(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | hidden_dim: int = 100, 149 | activation: str = 'relu', 150 | dropout_rate: float = 0.25, 151 | use_norm: bool = False, 152 | norm_style: str = 'BatchNorm' 153 | ) -> None: 154 | super().__init__() 155 | 156 | self.l1 = nn.Linear(hidden_dim, hidden_dim) 157 | self.l2 = nn.Linear(hidden_dim, hidden_dim) 158 | self.dropout = nn.Dropout(dropout_rate) 159 | self.use_norm = use_norm 160 | self.act = get_activation_fn(activation) 161 | 162 | if use_norm: 163 | if norm_style == 'BatchNorm': 164 | self.normalizer = nn.BatchNorm1d(hidden_dim) 165 | elif norm_style == 'LayerNorm': 166 | self.normalizer = torch.nn.LayerNorm(hidden_dim, eps=1e-06) 167 | else: 168 | raise ValueError('not a defined norm type') 169 | 170 | def forward(self, x): 171 | x_input = x 172 | if self.use_norm: 173 | x = self.normalizer(x) 174 | x = self.l1(self.dropout(self.act(x))) 175 | if self.use_norm: 176 | x = self.normalizer(x) 177 | x = self.l2(self.dropout(self.act(x))) 178 | return x + x_input 179 | 
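### Minimal shape checks for the networks above (toy sizes on CPU; all values
### are placeholders). Note num_hidden_layers must be even for the residual
### variant, since each residual block consumes two hidden layers.
# mlp = MLPNetwork(input_dim=4, output_dim=2, hidden_dim=32, num_hidden_layers=2, device='cpu')
# res = ResidualMLPNetwork(input_dim=4, output_dim=2, hidden_dim=32, num_hidden_layers=2, device='cpu')
# y = res(torch.randn(7, 4))  # -> (7, 2)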
-------------------------------------------------------------------------------- /vdd/networks/moes.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import torch 3 | import torch.distributions as D 4 | from torch import nn 5 | 6 | import einops 7 | 8 | from vdd.networks.gaussian import get_gmm_head 9 | from vdd.networks.gating import GatingNet 10 | from vdd.networks.mlp import ResidualMLPNetwork, MLPNetwork 11 | from vdd.networks.network_utils import str2torchdtype, initialize_weights 12 | 13 | 14 | class GaussianMoE(nn.Module): 15 | def __init__(self, num_components, obs_dim, act_dim, goal_dim, 16 | prior_type, cmp_init, cmp_cov_type='diag', 17 | bias_init_bound=0.5, 18 | backbone = None, 19 | moe_network_type = 'residual', 20 | cmp_mean_hidden_dims = 64, 21 | cmp_mean_hidden_layers = 2, 22 | cmp_cov_hidden_dims = 64, 23 | cmp_cov_hidden_layers = 2, 24 | cmp_activation="tanh", 25 | cmp_init_std=1., cmp_minimal_std=1e-5, 26 | learn_gating=False, gating_hidden_layers=4, gating_hidden_dims = 64, 27 | dtype="float32", device="cpu", **kwargs): 28 | super(GaussianMoE, self).__init__() 29 | self.n_components = num_components 30 | self.obs_dim = obs_dim 31 | self.act_dim = act_dim 32 | 33 | self.device = device 34 | self.dtype = str2torchdtype(dtype) 35 | 36 | self.backbone = backbone 37 | input_dim = obs_dim + goal_dim if backbone is None else cmp_mean_hidden_dims, 38 | self.gmm_head = get_gmm_head(act_dim, num_components, cmp_init_std, cmp_minimal_std, cmp_cov_type, device=device) 39 | 40 | if moe_network_type == 'residual': 41 | self.gmm_mean_net = ResidualMLPNetwork(input_dim=input_dim[0], 42 | output_dim=self.gmm_head.flat_mean_dim, 43 | hidden_dim=cmp_mean_hidden_dims, 44 | num_hidden_layers=cmp_mean_hidden_layers, 45 | activation=cmp_activation, 46 | device=device) 47 | self.gmm_cov_net = ResidualMLPNetwork(input_dim=input_dim[0], 48 | output_dim=self.gmm_head.flat_chol_dim, 49 | hidden_dim=cmp_cov_hidden_dims, 50 | num_hidden_layers=cmp_cov_hidden_layers, 51 | activation=cmp_activation, 52 | device=device) 53 | else: 54 | self.gmm_mean_net = MLPNetwork(input_dim=input_dim[0], 55 | output_dim=self.gmm_head.flat_mean_dim, 56 | hidden_dim=cmp_mean_hidden_dims, 57 | num_hidden_layers=cmp_mean_hidden_layers, 58 | activation=cmp_activation, 59 | device=device) 60 | self.gmm_cov_net = MLPNetwork(input_dim=input_dim[0], 61 | output_dim=self.gmm_head.flat_chol_dim, 62 | hidden_dim=cmp_cov_hidden_dims, 63 | num_hidden_layers=cmp_cov_hidden_layers, 64 | activation=cmp_activation, 65 | device=device) 66 | 67 | initialize_weights(self.gmm_mean_net, cmp_init, 0.01, 1e-4) 68 | initialize_weights(self.gmm_cov_net, cmp_init, 0.01, 1e-4) 69 | self.uniform_distribute_components(bound=bias_init_bound) 70 | 71 | if hasattr(self.backbone, 'window_size'): 72 | self.window_size = self.backbone.window_size 73 | else: 74 | self.window_size = 1 75 | 76 | self.obs_contexts = deque(maxlen=self.window_size) 77 | 78 | self.learn_gating = learn_gating 79 | 80 | self.greedy_predict = kwargs.get("greedy_predict", False) 81 | 82 | self.gating_network = GatingNet(input_dim[0], num_components, gating_hidden_layers, 83 | gating_hidden_dims, device=device) if learn_gating else None 84 | 85 | if prior_type == 'uniform': 86 | self._prior = torch.ones(num_components, device=self.device, dtype=self.dtype) / num_components 87 | else: 88 | raise NotImplementedError(f"Prior type {prior_type} not implemented.") 89 | 90 | 91 | def 
uniform_distribute_components(self, bound=0.5): 92 | self.gmm_mean_net.layers[-1].weight.data.fill_(0) 93 | self.gmm_mean_net.layers[-1].bias.data = torch.rand(self.act_dim * self.n_components, device=self.device) * 2 * bound - bound 94 | 95 | def reset(self): 96 | self.obs_contexts.clear() 97 | 98 | def forward(self, states, goals=None, train=True): 99 | self.train(train) 100 | 101 | if self.backbone is not None: 102 | x = self.backbone(states=states, goals=goals) 103 | else: 104 | x = torch.cat([states, goals], dim=-1) if goals is not None else states 105 | 106 | pre_means = self.gmm_mean_net(x) 107 | pre_chols = self.gmm_cov_net(x) 108 | cmp_means, cmp_chols = self.gmm_head(pre_means, pre_chols) 109 | cmp_means = einops.rearrange(cmp_means, 'b t c d -> b c t d') 110 | cmp_chols = einops.rearrange(cmp_chols, 'b t c d1 d2 -> b c t d1 d2') 111 | 112 | if self.gating_network is None: 113 | gating_probs = einops.repeat(self._prior, 'c -> b c t', b=states.shape[0], t=states.shape[1]) 114 | else: 115 | x = x.clone().detach() 116 | gating_probs = self.gating_network(x).exp() + 1e-8 117 | gating_probs = einops.repeat(gating_probs, 'b t c -> b c t') 118 | 119 | return cmp_means, cmp_chols, gating_probs 120 | 121 | def sample(self, cmp_means, cmp_chols, gating=None, n=1): 122 | if gating is None: 123 | prior = self._prior.unsqueeze(0).repeat(cmp_means.shape[0], 1) 124 | gating = D.Categorical(probs=prior) 125 | else: 126 | gating = D.Categorical(gating) 127 | 128 | return self.gmm_head.gmm_sample(cmp_means, cmp_chols, gating, n) 129 | 130 | 131 | @torch.no_grad() 132 | def act(self, state, goal=None, vision_task=False): 133 | if vision_task: 134 | self.agentview_image_contexts.append(state[0]) 135 | self.inhand_image_contexts.append(state[1]) 136 | self.robot_ee_pos_contexts.append(state[2]) 137 | agentview_image_seq = torch.stack(list(self.agentview_image_contexts), dim=1) 138 | inhand_image_seq = torch.stack(list(self.inhand_image_contexts), dim=1) 139 | robot_ee_pos_seq = torch.stack(list(self.robot_ee_pos_contexts), dim=1) 140 | input_states = (agentview_image_seq, inhand_image_seq, robot_ee_pos_seq) 141 | else: 142 | self.obs_contexts.append(state) 143 | input_states = torch.stack(list(self.obs_contexts), dim=1) 144 | 145 | if goal is not None and len(goal.size()) == 2: 146 | goal = goal.unsqueeze(0) 147 | 148 | cmp_means, cmp_chols, gating = self(input_states, goal, train=False) 149 | 150 | cmp_means = cmp_means[..., -1, :].squeeze(0) 151 | gating = gating[..., -1].squeeze(0) 152 | 153 | if self.greedy_predict: 154 | indexs = gating.argmax(-1) 155 | else: 156 | gating_dist = D.Categorical(gating) 157 | indexs = gating_dist.sample([1]) 158 | action_means = cmp_means[indexs, :] 159 | 160 | return action_means 161 | 162 | def get_parameter(self, target: str) -> "Parameter": 163 | if target == "gating": 164 | return self.gating_network.parameters() 165 | elif target == "cmps": 166 | if self.backbone is None: 167 | return list(self.gmm_mean_net.parameters()) + list(self.gmm_cov_net.parameters()) + list(self.gmm_head.parameters()) 168 | return list(self.gmm_mean_net.parameters()) + list(self.gmm_cov_net.parameters()) + list(self.gmm_head.parameters()) + list(self.backbone.parameters()) 169 | else: 170 | raise ValueError(f"Unknown target {target}") -------------------------------------------------------------------------------- /vdd/networks/network_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import torch 
4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim.optimizer import Optimizer 7 | 8 | import numpy as np 9 | 10 | def str2torchdtype(str_dtype: str = 'float32'): 11 | if str_dtype == 'float32': 12 | return torch.float32 13 | elif str_dtype == 'float64': 14 | return torch.float64 15 | elif str_dtype == 'float16': 16 | return torch.float16 17 | elif str_dtype == 'int32': 18 | return torch.int32 19 | elif str_dtype == 'int64': 20 | return torch.int64 21 | else: 22 | raise NotImplementedError 23 | 24 | 25 | def fanin_init(tensor, scale=1 / 3): 26 | size = tensor.size() 27 | if len(size) == 2: 28 | fan_in = size[0] 29 | elif len(size) > 2: 30 | fan_in = np.prod(size[1:]) 31 | else: 32 | raise Exception("Shape must have dimension at least 2.") 33 | bound = np.sqrt(3 * scale / fan_in) 34 | return tensor.data.uniform_(-bound, bound) 35 | 36 | 37 | def initialize_weights(mod, initialization_type, gain: float = 2 ** 0.5, scale=1 / 3, init_w=3e-3): 38 | """ 39 | Weight initializer for the models. 40 | Initializes the parameters of the given module in place. 41 | """ 42 | for p in mod.parameters(): 43 | if initialization_type == "normal": 44 | if len(p.data.shape) >= 2: 45 | p.data.normal_(0., init_w)  # zero-mean with a small std 46 | else: 47 | p.data.zero_() 48 | elif initialization_type == "uniform": 49 | if len(p.data.shape) >= 2: 50 | p.data.uniform_(-init_w, init_w) 51 | else: 52 | p.data.zero_() 53 | elif initialization_type == "fanin": 54 | if len(p.data.shape) >= 2: 55 | fanin_init(p, scale) 56 | else: 57 | p.data.zero_() 58 | elif initialization_type == "xavier": 59 | if len(p.data.shape) >= 2: 60 | nn.init.xavier_uniform_(p.data) 61 | else: 62 | p.data.zero_() 63 | elif initialization_type == "orthogonal": 64 | if len(p.data.shape) >= 2: 65 | nn.init.orthogonal_(p.data, gain=gain) 66 | else: 67 | p.data.zero_() 68 | else: 69 | raise ValueError("Need a valid initialization key") 70 | 71 | 72 | def inverse_softplus(x): 73 | 74 | """ 75 | Inverse of softplus, i.e. x = inverse_softplus(softplus(x)). 76 | Args: 77 | x: data (post-softplus values, strictly positive) 78 | 79 | Returns: 80 | the pre-activation value whose softplus is x 81 | """ 82 | return (x.exp() - 1.).log() 83 | 
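### Round-trip sketch: softplus(y) = log(1 + e^y), so its inverse is log(e^x - 1).
# x = torch.nn.functional.softplus(torch.tensor(0.5))  # ~0.974
# y = inverse_softplus(x)                               # ~0.5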
84 | def fill_triangular(x, upper=False): 85 | """ 86 | From: https://github.com/tensorflow/probability/blob/c833ee5cd9f60f3257366b25447b9e50210b0590/tensorflow_probability/python/math/linalg.py#L787 87 | License: Apache-2.0 88 | 89 | Creates a (batch of) triangular matrix from a vector of inputs. 90 | 91 | Created matrix can be lower- or upper-triangular. (It is more efficient to 92 | create the matrix as upper or lower, rather than transpose.) 93 | 94 | Triangular matrix elements are filled in a clockwise spiral. See example, 95 | below. 96 | 97 | If `x.shape` is `[b1, b2, ..., bB, d]` then the output shape is 98 | `[b1, b2, ..., bB, n, n]` where `n` is such that `d = n(n+1)/2`, i.e., 99 | `n = int(np.sqrt(0.25 + 2. * m) - 0.5)`. 100 | 101 | Example: 102 | 103 | ```python 104 | fill_triangular([1, 2, 3, 4, 5, 6]) 105 | # ==> [[4, 0, 0], 106 | # [6, 5, 0], 107 | # [3, 2, 1]] 108 | 109 | fill_triangular([1, 2, 3, 4, 5, 6], upper=True) 110 | # ==> [[1, 2, 3], 111 | # [0, 5, 6], 112 | # [0, 0, 4]] 113 | ``` 114 | 115 | The key trick is to create an upper triangular matrix by concatenating `x` 116 | and a tail of itself, then reshaping. 117 | 118 | Suppose that we are filling the upper triangle of an `n`-by-`n` matrix `M` 119 | from a vector `x`. The matrix `M` contains n**2 entries total. The vector `x` 120 | contains `n * (n+1) / 2` entries. For concreteness, we'll consider `n = 5` 121 | (so `x` has `15` entries and `M` has `25`). We'll concatenate `x` and `x` with 122 | the first (`n = 5`) elements removed and reversed: 123 | 124 | ```python 125 | x = np.arange(15) + 1 126 | xc = np.concatenate([x, x[5:][::-1]]) 127 | # ==> array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 15, 14, 13, 128 | # 12, 11, 10, 9, 8, 7, 6]) 129 | 130 | # (We add one to the arange result to disambiguate the zeros below the 131 | # diagonal of our upper-triangular matrix from the first entry in `x`.) 132 | 133 | # Now we lay this out as a matrix: 134 | y = np.reshape(xc, [5, 5]) 135 | # ==> array([[ 1, 2, 3, 4, 5], 136 | # [ 6, 7, 8, 9, 10], 137 | # [11, 12, 13, 14, 15], 138 | # [15, 14, 13, 12, 11], 139 | # [10, 9, 8, 7, 6]]) 140 | 141 | # Finally, zero the elements below the diagonal: 142 | y = np.triu(y, k=0) 143 | # ==> array([[ 1, 2, 3, 4, 5], 144 | # [ 0, 7, 8, 9, 10], 145 | # [ 0, 0, 13, 14, 15], 146 | # [ 0, 0, 0, 12, 11], 147 | # [ 0, 0, 0, 0, 6]]) 148 | ``` 149 | 150 | From this example we see that the resulting matrix is upper-triangular, and 151 | contains all the entries of `x`, as desired. The rest is details: 152 | 153 | - If `n` is even, `x` doesn't exactly fill an even number of rows (it fills 154 | `n / 2` rows and half of an additional row), but the whole scheme still 155 | works. 156 | - If we want a lower triangular matrix instead of an upper triangular, 157 | we remove the first `n` elements from `x` rather than from the reversed 158 | `x`. 159 | 160 | For additional comparisons, a pure numpy version of this function can be found 161 | in `distribution_util_test.py`, function `_fill_triangular`. 162 | 163 | Args: 164 | x: `Tensor` representing lower (or upper) triangular elements. 165 | upper: Python `bool` representing whether output matrix should be upper 166 | triangular (`True`) or lower triangular (`False`, default). 167 | 168 | Returns: 169 | tril: `Tensor` with lower (or upper) triangular elements filled from `x`. 170 | 171 | Raises: 172 | ValueError: if `x` cannot be mapped to a triangular matrix. 173 | """ 174 | 175 | m = np.int32(x.shape[-1]) 176 | # Formula derived by solving for n: m = n(n+1)/2. 177 | n = np.sqrt(0.25 + 2. * m) - 0.5 178 | if n != np.floor(n): 179 | raise ValueError('Input right-most shape ({}) does not ' 180 | 'correspond to a triangular matrix.'.format(m)) 181 | n = np.int32(n) 182 | new_shape = x.shape[:-1] + (n, n) 183 | 184 | ndims = len(x.shape) 185 | if upper: 186 | x_list = [x, torch.flip(x[..., n:], dims=[ndims - 1])] 187 | else: 188 | x_list = [x[..., n:], torch.flip(x, dims=[ndims - 1])] 189 | 190 | x = torch.cat(x_list, dim=-1).reshape(new_shape) 191 | x = torch.triu(x) if upper else torch.tril(x) 192 | return x 193 | 
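### How the GMM heads above use this: a flat vector of n(n+1)/2 entries per
### component becomes a lower-triangular Cholesky factor, and diag_bijector
### (defined just below) then makes the diagonal positive. Toy shapes:
# flat = torch.randn(8, 4, 6)                   # batch 8, 4 components, d=3 -> 6 entries each
# chols = fill_triangular(flat)                 # (8, 4, 3, 3) lower-triangular
# chols = diag_bijector(torch.nn.functional.softplus, chols)  # positive diagonal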
194 | def diag_bijector(f: callable, x): 195 | """ 196 | Apply transformation f(x) on the diagonal of a batched matrix. 197 | Args: 198 | f: callable to apply to the diagonal 199 | x: data 200 | 201 | Returns: 202 | the transformed matrix x 203 | """ 204 | return x.tril(-1) + f(x.diagonal(dim1=-2, dim2=-1)).diag_embed() + x.triu(1) 205 | 206 | def get_optimizer(optimizer_type: str, model_parameters: Union[Iterable[torch.Tensor], Iterable[dict]], 207 | learning_rate: float, **kwargs): 208 | """ 209 | Get an optimizer instance for the given model parameters. 210 | Args: 211 | model_parameters: parameters (or parameter groups) to optimize 212 | optimizer_type: one of sgd, sgd_momentum, adam, adamw, adagrad 213 | learning_rate: learning rate 214 | **kwargs: forwarded to the torch.optim constructor 215 | 216 | Returns: 217 | a torch.optim.Optimizer instance 218 | """ 219 | if optimizer_type.lower() == "sgd": 220 | return optim.SGD(model_parameters, learning_rate, **kwargs) 221 | elif optimizer_type.lower() == "sgd_momentum": 222 | momentum = kwargs.pop("momentum") if kwargs.get("momentum") else 0.9 223 | return optim.SGD(model_parameters, learning_rate, momentum=momentum, **kwargs) 224 | elif optimizer_type.lower() == "adam": 225 | return optim.Adam(model_parameters, learning_rate, **kwargs) 226 | elif optimizer_type.lower() == "adamw": 227 | return optim.AdamW(model_parameters, learning_rate, betas=(0.95, 0.999), eps=1e-8, **kwargs) 228 | elif optimizer_type.lower() == "adagrad": 229 | return optim.Adagrad(model_parameters, learning_rate, **kwargs) 230 | else: 231 | raise ValueError(f"Optimizer {optimizer_type} is not supported.") 232 | 233 | 234 | def get_lr_schedule(schedule_type: str, optimizer: Optimizer, total_iters) -> Union[ 235 | optim.lr_scheduler._LRScheduler, None]: 236 | if not schedule_type or schedule_type.isspace(): 237 | return None 238 | 239 | elif schedule_type.lower() == "linear": 240 | return optim.lr_scheduler.LinearLR(optimizer, start_factor=1., end_factor=0., total_iters=total_iters) 241 | 242 | elif schedule_type.lower() == "papi": 243 | # Multiply the learning rate by 0.8 every time the backtracking fails 244 | return optim.lr_scheduler.MultiplicativeLR(optimizer, lambda n_calls: 0.8) 245 | 246 | elif schedule_type.lower() == "performance": 247 | return optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 0.8), \ 248 | optim.lr_scheduler.MultiplicativeLR(optimizer, lambda epoch: 1.01) 249 | 250 | else: 251 | raise ValueError( 252 | f"Learning rate schedule {schedule_type} is not supported. 
Select one of [None, linear, papi, performance].") -------------------------------------------------------------------------------- /vdd/score_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/score_functions/__init__.py -------------------------------------------------------------------------------- /vdd/score_functions/beso_score.py: -------------------------------------------------------------------------------- 1 | from vdd.score_functions.score_base import ScoreFunction 2 | 3 | from beso.agents.diffusion_agents.beso_agent import BesoAgent 4 | 5 | import torch 6 | import einops 7 | 8 | class BesoScoreFunction(ScoreFunction): 9 | def __init__(self, model: BesoAgent, sigma_index=-1, obs_dim=10, goal_dim=10, weights_type='srpo', 10 | sigma_min=0.1, sigma_max=1.0, anneal_end_iter=1e6, 11 | noise_level_type='uniform', device='cuda', **kwargs): 12 | super().__init__(model) 13 | self.sigma_index = sigma_index 14 | self.normalize_score = False 15 | self.goal_dim = goal_dim 16 | self.obs_dim = obs_dim 17 | self.weights_type = weights_type 18 | self.sigma_min = sigma_min 19 | self.sigma_max = sigma_max 20 | self.annealing_end_iter = anneal_end_iter 21 | self.noise_level_type = noise_level_type 22 | self.device = device 23 | 24 | def __call__(self, samples, states, goals=None, iter=None, vision_task=False): 25 | return self._get_score(samples, states, goals, iter, vision_task) 26 | 27 | @torch.no_grad() 28 | def _get_score(self, samples, state, goal, iter=None, vision_task=False): 29 | self.model.model.eval() 30 | 31 | noise_level = self._get_noise_level(samples, noise_level_type=self.noise_level_type, iter=iter).to(self.device) 32 | 33 | weights = self._get_weights(noise_level[..., None, None], weights_type=self.weights_type).to(self.device) 34 | 35 | ### einpack the samples 36 | # b = samples.shape[0] 37 | # c = samples.shape[1] 38 | # v = samples.shape[2] 39 | 40 | (b, c, v, t) = samples.shape[:4] 41 | 42 | if vision_task: 43 | # self.model.model.obs_encoder.eval() 44 | ### hack for vision-based tasks 45 | agent_view_image = einops.rearrange(state[0], 'b t ... -> (b t) ... ') 46 | in_hand_image = einops.rearrange(state[1], 'b t ... -> (b t) ... ') 47 | robot_ee_pos = einops.rearrange(state[2], 'b t ... -> (b t) ... ') 48 | state_dict = {"agentview_image": agent_view_image, 49 | "in_hand_image": in_hand_image, 50 | "robot_ee_pos": robot_ee_pos} 51 | try: 52 | state = self.model.model.obs_encoder(state_dict) 53 | except Exception as e: 54 | print("error: ", e) 55 | print("Error in encoding the state") 56 | 57 | pack_state = einops.rearrange(state, '(b t) ... -> b t ...', b=b, t=t) 58 | pack_state = einops.repeat(pack_state, 'b t ... -> b c v t ...', c=c, v=v) 59 | pack_state = einops.rearrange(pack_state, 'b c v t ... -> (b c v) t ...') 60 | else: 61 | pack_state = einops.rearrange(state, 'b c v t ... -> (b c v) t ...') 62 | pack_samples = einops.rearrange(samples, 'b c v t ... -> (b c v) t ...') 63 | pack_goal = einops.rearrange(goal, 'b c v t ... 
-> (b c v) t ...') if goal is not None else None 64 | pack_noise_level = einops.rearrange(noise_level, 'b c v -> (b c v)') 65 | 66 | s_in = torch.ones_like(pack_noise_level).to(self.device) 67 | 68 | if vision_task: 69 | denoised = self.model.model.model(state=pack_state, action=pack_samples, goal=pack_goal, sigma=pack_noise_level * s_in) 70 | else: 71 | denoised = self.model.model(state=pack_state, action=pack_samples, goal=pack_goal, sigma=pack_noise_level * s_in) 72 | 73 | ### unpack the denoised samples 74 | denoised = einops.rearrange(denoised, '(b c v) t d -> b c v t d', b=b, c=c, v=v) 75 | 76 | ### score D(x;sigma) - x / sigma^2 77 | ###TODO:FIXME: check if the broadcast is correct 78 | score = (denoised - samples) / noise_level[..., None, None] ** 2 79 | 80 | if self.normalize_score: 81 | score = score/torch.norm(score, dim=-1, keepdim=True) 82 | return score * weights, noise_level 83 | 84 | 85 | def _get_noise_level(self, samples, noise_level_type='uniform', iter=None): 86 | if noise_level_type == 'uniform': 87 | return torch.rand(samples.shape[:3]) * (self.sigma_max - self.sigma_min) + self.sigma_min 88 | elif noise_level_type == 'last_sigma': 89 | return torch.ones(samples.shape[:3]) * self.sigma_min 90 | elif noise_level_type == 'anneal': 91 | iter = min(iter, self.annealing_end_iter) 92 | annealed_sigma = self.sigma_min + (self.sigma_max - self.sigma_min) * (1 - iter / self.annealing_end_iter) 93 | return torch.ones(samples.shape[:3]) * annealed_sigma 94 | else: 95 | raise ValueError(f"Unknown noise level type: {noise_level_type}") 96 | 97 | def _get_weights(self, noise_level, weights_type='stable'): 98 | if weights_type == 'stable': 99 | return torch.ones_like(noise_level) 100 | else: 101 | raise ValueError(f"Unknown weights type: {weights_type}") -------------------------------------------------------------------------------- /vdd/score_functions/ddpm_score.py: -------------------------------------------------------------------------------- 1 | from vdd.score_functions.score_base import ScoreFunction 2 | 3 | from agents.ddpm_agent import DiffusionAgent 4 | 5 | from copy import deepcopy 6 | 7 | import torch 8 | import einops 9 | 10 | class DDPMScoreFunction(ScoreFunction): 11 | def __init__(self, model: DiffusionAgent, sigma_index=-1, obs_dim=10, goal_dim=10, weights_type='srpo', 12 | t_min=1, t_max=8, t_bound=8, anneal_end_iter=1e6, 13 | noise_level_type='uniform', device='cuda', **kwargs): 14 | super().__init__(model) 15 | self.sigma_index = sigma_index 16 | self.goal_dim = goal_dim 17 | self.obs_dim = obs_dim 18 | self.weights_type = weights_type 19 | self.t_min = t_min 20 | self.t_max = t_max 21 | self.t_bound = t_bound 22 | self.annealing_end_iter = anneal_end_iter 23 | self.noise_level_type = noise_level_type 24 | self.device = device 25 | 26 | self.vision_task = kwargs.get('vision_task', False) 27 | 28 | if self.vision_task: 29 | self.noise_network = self.model.model.model.model 30 | self.betas = self.model.model.model.betas.clone() 31 | self.sqrt_one_minus_alphas_cumprod = self.model.model.model.sqrt_one_minus_alphas_cumprod.clone() 32 | else: 33 | self.noise_network = self.model.model.model 34 | self.betas = self.model.model.betas.clone() 35 | self.sqrt_one_minus_alphas_cumprod = self.model.model.sqrt_one_minus_alphas_cumprod.clone() 36 | 37 | print("DDPM Score Function Initialized") 38 | 39 | def __call__(self, samples, states, goals=None, iter=None, vision_task=False): 40 | return self._get_score(samples, states, goals, iter, vision_task) 41 | 42 | @torch.no_grad() 
43 | def _get_score(self, samples, state, goal, iter=None, vision_task=False): 44 | self.noise_network.eval() 45 | 46 | noise_level = self._get_noise_level(samples, noise_level_type=self.noise_level_type, iter=iter).to(self.device) 47 | 48 | weights = self._get_weights(noise_level[..., None, None], weights_type=self.weights_type).to(self.device) 49 | 50 | ### einpack the samples 51 | # b = samples.shape[0] 52 | # c = samples.shape[1] 53 | # v = samples.shape[2] 54 | 55 | (b, c, v, t) = samples.shape[:4] 56 | 57 | if vision_task: 58 | # self.model.model.obs_encoder.eval() 59 | ### hack for vision-based tasks 60 | agent_view_image = einops.rearrange(state[0], 'b t ... -> (b t) ... ') 61 | in_hand_image = einops.rearrange(state[1], 'b t ... -> (b t) ... ') 62 | robot_ee_pos = einops.rearrange(state[2], 'b t ... -> (b t) ... ') 63 | state_dict = {"agentview_image": agent_view_image, 64 | "in_hand_image": in_hand_image, 65 | "robot_ee_pos": robot_ee_pos} 66 | try: 67 | state = self.model.model.obs_encoder(state_dict) 68 | except Exception as e: 69 | print("error: ", e) 70 | print("Error in encoding the state") 71 | 72 | pack_state = einops.rearrange(state, '(b t) ... -> b t ...', b=b, t=t) 73 | pack_state = einops.repeat(pack_state, 'b t ... -> b c v t ...', c=c, v=v) 74 | pack_state = einops.rearrange(pack_state, 'b c v t ... -> (b c v) t ...') 75 | else: 76 | pack_state = einops.rearrange(state, 'b c v t ... -> (b c v) t ...') 77 | pack_samples = einops.rearrange(samples, 'b c v t ... -> (b c v) t ...') 78 | pack_goal = einops.rearrange(goal, 'b c v t ... -> (b c v) t ...') if goal is not None else None 79 | pack_noise_level = einops.rearrange(noise_level, 'b c v -> (b c v)') 80 | 81 | noise = self.noise_network(actions=pack_samples, time=pack_noise_level, states=pack_state, goals=pack_goal) 82 | 83 | ### unpack the denoised samples 84 | noise = einops.rearrange(noise, '(b c v) t d -> b c v t d', b=b, c=c, v=v) 85 | 86 | ### score epsilon(x;sigma) / sqrt(beta) 87 | ###TODO:FIXME: check if the broadcast is correct 88 | if self.noise_level_type == 'discrete': 89 | score = - noise / self.sqrt_one_minus_alphas_cumprod[noise_level][..., None, None] 90 | elif self.noise_level_type == 'uniform': 91 | score = - noise / ((noise_level[..., None, None].float() + 1e-4)/self.t_bound) 92 | else: 93 | raise ValueError(f"Unknown noise level type: {self.noise_level_type}, expected 'discrete' or 'uniform'") 94 | 95 | return score * weights, noise_level 96 | 97 | 98 | def _get_noise_level(self, samples, noise_level_type='uniform', iter=None): 99 | 100 | if noise_level_type == 'discrete': 101 | # torch.randint is exclusive of the upper bound 102 | sampled_t = torch.randint(self.t_min, self.t_max+1, samples.shape[:3]) 103 | elif noise_level_type == 'uniform': 104 | sampled_t = torch.rand(samples.shape[:3]) * (self.t_max - self.t_min) + self.t_min 105 | else: 106 | raise ValueError(f"Unknown noise level type: {noise_level_type}, expected 'discrete' or 'uniform'") 107 | return sampled_t 108 | 109 | def _get_weights(self, noise_level, weights_type='stable'): 110 | if weights_type == 'stable': 111 | return torch.ones_like(noise_level) 112 | else: 113 | raise ValueError(f"Unknown weights type: {weights_type}") -------------------------------------------------------------------------------- /vdd/score_functions/gmm_score.py: -------------------------------------------------------------------------------- 1 | from vdd.score_functions.score_base import ScoreFunction 2 | 3 | import torch as ch 4 | import 
torch.distributions as D 5 | import torch.nn.functional as F 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from vdd.score_functions.score_utils import plot_2d_gaussians, plot_2d_gaussians_color_map 10 | 11 | class GMMScoreFunction(ScoreFunction): 12 | def __init__(self, model=None, means: ch.Tensor = None, chols: ch.Tensor = None, prior=None, device='cuda'): 13 | super().__init__(model) 14 | self.dim = means.shape[-1] 15 | self.n_components = means.shape[0] 16 | self.means = means.to(device) 17 | self.chols = chols.to(device) 18 | assert self.means.shape == self.chols.shape[:-1], f"Means shape: {self.means.shape}, Chols shape: {self.chols.shape}" 19 | self.prior = prior if prior is not None else ch.ones(means.shape[0]) / self.n_components 20 | self.prior = self.prior.to(device) 21 | self.device = device 22 | 23 | def log_probability(self, x): 24 | gating = self.prior.view(1, -1).repeat(x.shape[0], 1) 25 | gating_dist = D.Categorical(gating) 26 | cmps = D.MultivariateNormal(self.means, scale_tril=self.chols, validate_args=False) 27 | gmm = D.MixtureSameFamily(gating_dist, cmps) 28 | return gmm.log_prob(x) 29 | 30 | def sample(self, n: int = 1): 31 | gating = D.Categorical(self.prior) 32 | comps = D.MultivariateNormal(self.means, scale_tril=self.chols, validate_args=False) 33 | gmm = D.MixtureSameFamily(gating, comps) 34 | return gmm.sample((n,)) 35 | 36 | def __call__(self, samples, states=None, score_goals=False, iter=None, is_vision=False): 37 | with ch.enable_grad(): 38 | samples = samples.clone().detach().requires_grad_(True) 39 | samples.retain_grad() 40 | log_probs = self.log_probability(samples) 41 | log_probs = log_probs.sum()/samples.shape[0] 42 | log_probs.backward() 43 | return samples.grad, ch.zeros(samples.shape[0]) 44 | 45 | def visualize_cmps(self, ax=None): 46 | cmp_means = self.means.clone().cpu() 47 | cmp_chols = self.chols.clone().cpu() 48 | if ax is None: 49 | fig, ax = plt.subplots(1, 1) 50 | plot_2d_gaussians(cmp_means, cmp_chols, ax, title="GT GMM") 51 | ax.set_aspect('equal') 52 | plt.show() 53 | else: 54 | plot_2d_gaussians(cmp_means, cmp_chols, ax, title="GT GMM") 55 | ax.set_aspect('equal') 56 | 57 | def visualize_gradient_field(self, n=20, x_range=[-1, 1], y_range=[-1, 1], ax=None): 58 | raw_x, raw_y = ch.meshgrid(ch.linspace(x_range[0], x_range[1], n), ch.linspace(y_range[0], y_range[1], n), indexing='ij') 59 | raw_x = raw_x.to(self.device) 60 | raw_y = raw_y.to(self.device) 61 | grid_actions = ch.stack([raw_x, raw_y], dim=-1).view(-1, 2) 62 | scores = self(grid_actions, None)[0] 63 | scores = scores.view(n, n, 2).cpu() 64 | u = scores[..., 0] 65 | v = scores[..., 1] 66 | ax.quiver(raw_x.cpu(), raw_y.cpu(), u, v, 67 | color="white") 68 | 69 | def visualize_grad_and_cmps(self, x_range=[-1, 1], y_range=[-1, 1], n=20): 70 | fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=150) 71 | 72 | plot_2d_gaussians_color_map(self.means.cpu(), self.chols.cpu(), ax, 73 | x_range=x_range, y_range=y_range, 74 | title="GT GMM") 75 | self.visualize_gradient_field(n=n, x_range=x_range, y_range=y_range, 76 | ax=ax) 77 | return fig, ax 78 | 79 | @staticmethod 80 | def generate_random_params(n_components, dim): 81 | means = ch.rand(n_components, dim) * 2.0 - 1.0 82 | chols = ch.rand(n_components, dim, dim) 83 | chols = ch.tril(chols, diagonal=-1) 84 | diag = ch.rand(n_components, dim) 85 | diag = F.softplus(diag) + 1e-4 86 | chols = chols + ch.diag_embed(diag) 87 | chols = 0.15 * chols 88 | return means, chols
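### A small end-to-end sketch for the toy GMM score above (placeholder sizes):
# means, chols = GMMScoreFunction.generate_random_params(n_components=4, dim=2)
# score_fn = GMMScoreFunction(means=means, chols=chols, device='cpu')
# x = score_fn.sample(16)    # (16, 2) samples from the ground-truth GMM
# grad, _ = score_fn(x)      # (16, 2) scores, i.e. grad_x log p(x)
--------------------------------------------------------------------------------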
/vdd/score_functions/score_base.py:
--------------------------------------------------------------------------------
1 | class ScoreFunction:
2 |     """Minimal interface: subclasses implement __call__ and return the score for the given samples."""
3 |     def __init__(self, model, **kwargs):
4 |         self.model = model
5 | 
6 |     def __call__(self, samples, states, **kwargs):
7 |         raise NotImplementedError
--------------------------------------------------------------------------------
/vdd/score_functions/score_utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from matplotlib.patches import Ellipse
3 | from sklearn.mixture import GaussianMixture
4 | from matplotlib.colors import LinearSegmentedColormap
5 | import numpy as np
6 | import torch
7 | 
8 | colors = ["#471365", "#24878E", "#F7E621"]  # dark purple over teal to light yellow (viridis-like)
9 | color_map = LinearSegmentedColormap.from_list("custom_cmap", colors)
10 | 
11 | def plot_gaussian_ellipse(mean, covariance, ax, color):
12 |     """
13 |     Plot the mean and covariance of a Gaussian distribution as an ellipse.
14 | 
15 |     Parameters:
16 |         mean (np.array): Mean vector; ax (matplotlib Axes): axes to draw on.
17 |         covariance (np.array): Covariance matrix; color: edge color of the ellipse.
18 |     """
19 |     # Compute eigenvalues and eigenvectors of the covariance matrix
20 |     eigenvalues, eigenvectors = np.linalg.eigh(covariance)
21 | 
22 |     # Angle between the x-axis and the first eigenvector (eigh returns eigenvalues in ascending order)
23 |     angle = np.degrees(np.arctan2(*eigenvectors[:, 0][::-1]))
24 | 
25 |     # Create an ellipse representing the covariance matrix
26 |     width, height = 2 * np.sqrt(eigenvalues)  # full width/height of 2 standard deviations (a 1-sigma ellipse)
27 |     ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle, alpha=1.0, fill=False, color=color, linewidth=4.0)
28 | 
29 |     # Plot the ellipse and its mean
30 |     ax.add_patch(ellipse)
31 |     ax.scatter(*mean, c=color)
32 | 
33 | def plot_2d_gaussians(means, chols, ax, title: str = '2D Gaussian', color: str = 'green'):
34 |     for i in range(means.shape[0]):
35 |         plot_gaussian_ellipse(means[i], chols[i] @ chols[i].T, ax, color)
36 |     # plt.title(title)
37 | 
38 | def plot_2d_gaussians_color_map(means, chols, ax,
39 |                                 x_range, y_range,
40 |                                 title: str = '2D Gaussian',
41 |                                 color: str = 'green'):
42 |     cov = chols @ chols.transpose(-1, -2)
43 |     n_components = means.shape[0]
44 |     gmm = GaussianMixture(n_components=n_components, covariance_type='full')
45 |     gmm.means_ = means
46 |     gmm.covariances_ = cov
47 | 
48 |     gmm.weights_ = np.ones(n_components) / n_components
49 |     gmm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(cov))
50 | 
51 |     # Create a grid of points
52 |     x = np.linspace(x_range[0], x_range[1], 500)
53 |     y = np.linspace(y_range[0], y_range[1], 500)
54 |     X_, Y_ = np.meshgrid(x, y)
55 |     XX = np.array([X_.ravel(), Y_.ravel()]).T
56 | 
57 |     # Evaluate the GMM density on the grid
58 |     log_density = gmm.score_samples(XX)
59 |     density = np.exp(log_density)
60 |     # Z = log_density.reshape(X_.shape)  # Use log density
61 |     Z = density.reshape(X_.shape)  # Use density
62 | 
63 |     ax.contourf(X_, Y_, Z, levels=10, cmap=color_map)
64 | 
65 | 
66 | def distribute_components_torch(n):
67 |     # Calculate grid size
68 |     grid_side = int(torch.ceil(torch.sqrt(torch.tensor(n).float())))  # Number of points along one dimension
69 | 
70 |     # Generate grid points
71 |     linspace = torch.linspace(-0.5, 0.5, grid_side)
72 |     grid_x, grid_y = torch.meshgrid(linspace, linspace, indexing='ij')
73 | 
74 |     # Flatten the grid and take the first n points
75 |     points_x = grid_x.flatten()[:n]
76 |     points_y = grid_y.flatten()[:n]
77 | 
78 |     return points_x, 
points_y 79 | 80 | 81 | def plot_distribution_torch(n): 82 | x, y = distribute_components_torch(n) 83 | plt.scatter(x.numpy(), y.numpy()) # Convert to numpy for plotting 84 | plt.xlim(-1.1, 1.1) 85 | plt.ylim(-1.1, 1.1) 86 | plt.gca().set_aspect('equal', adjustable='box') 87 | plt.title(f'Uniform distribution of {n} components in PyTorch') 88 | plt.show() 89 | 90 | -------------------------------------------------------------------------------- /vdd/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | # from GPUtil import GPUtil 3 | import os 4 | import copy 5 | import shutil 6 | import yaml 7 | 8 | from pathlib import Path 9 | from datetime import datetime 10 | 11 | import logging 12 | 13 | import pathlib 14 | 15 | import torch 16 | import numpy as np 17 | 18 | 19 | def global_seeding(seed: int): 20 | """ 21 | Set the seed for numpy and torch 22 | Args: 23 | seed: seed value 24 | """ 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) # Sets the seed for generating random numbers. 27 | torch.cuda.manual_seed(seed) # Sets the seed for generating random numbers for the current GPU. 28 | torch.cuda.manual_seed_all(seed) # Sets the seed for generating random numbers on all GPUs. 29 | torch.backends.cudnn.deterministic = True # Forces deterministic algorithm selections for convolution. 30 | torch.backends.cudnn.benchmark = False # Disables the inbuilt cudnn auto-tuner. 31 | print(f"Setting global seed: {seed}") 32 | 33 | 34 | # def get_free_gpus(): 35 | # return GPUtil.getAvailable(order='memory', limit=1, maxLoad=0.8, maxMemory=0.8, includeNan=False, 36 | # excludeID=[], excludeUUID=[]) 37 | 38 | def assign_process_to_cpu(pid, cpus): 39 | os.sched_setaffinity(pid, cpus) 40 | 41 | def mkdir(directory: str, overwrite: bool = False): 42 | """ 43 | 44 | Args: 45 | directory: dir path to make 46 | overwrite: overwrite exist dir 47 | 48 | Returns: 49 | None 50 | 51 | Raise: 52 | FileExistsError if dir exists and overwrite is False 53 | """ 54 | path = Path(directory) 55 | try: 56 | path.mkdir(parents=True, exist_ok=overwrite) 57 | except FileExistsError: 58 | logging.error("Directory already exists, remove it before make a new one.") 59 | raise 60 | 61 | def set_value_in_nest_dict(config, key, value): 62 | """ 63 | Set value of a certain key in a recursive way in a nested dictionary 64 | 65 | Args: 66 | config: configuration dictionary 67 | key: key to ref 68 | value: value to set 69 | 70 | Returns: 71 | config 72 | """ 73 | for k in config.keys(): 74 | if k == key: 75 | config[k] = value 76 | if isinstance(config[k], dict): 77 | set_value_in_nest_dict(config[k], key, value) 78 | return config 79 | 80 | def remove_file_dir(path: str) -> bool: 81 | """ 82 | Remove file or directory 83 | Args: 84 | path: path to directory or file 85 | 86 | Returns: 87 | True if successfully remove file or directory 88 | 89 | """ 90 | if not os.path.exists(path): 91 | return False 92 | elif os.path.isfile(path) or os.path.islink(path): 93 | os.unlink(path) 94 | return True 95 | else: 96 | shutil.rmtree(path) 97 | return True 98 | 99 | 100 | def dump_config(config_dict: dict, config_name: str, dump_dir: str): 101 | """ 102 | Dump configuration into yaml file 103 | Args: 104 | config_dict: config dictionary to be dumped 105 | config_name: config file name 106 | dump_dir: dir to dump 107 | Returns: 108 | None 109 | """ 110 | 111 | # Generate config path 112 | dump_path = os.path.join(dump_dir, config_name + ".yaml") 113 | 114 | # Remove old config 
if exists 115 | # remove_file_dir(dump_path) 116 | 117 | # Write new config to file 118 | with open(dump_path, "w") as f: 119 | yaml.dump(config_dict, f) 120 | 121 | def dir_go_up(num_level: int = 2, current_file_dir: str = "default") -> str: 122 | """ 123 | Go to upper n level of current file directory 124 | Args: 125 | num_level: number of level to go up 126 | current_file_dir: current dir 127 | 128 | Returns: 129 | dir n level up 130 | """ 131 | if current_file_dir == "default": 132 | current_file_dir = os.path.realpath(__file__) 133 | while num_level != 0: 134 | current_file_dir = os.path.dirname(current_file_dir) 135 | num_level -= 1 136 | return current_file_dir 137 | 138 | 139 | def get_formatted_date_time() -> str: 140 | """ 141 | Get formatted date and time, e.g. May-01-2021 22:14:31 142 | Returns: 143 | dt_string: date time string 144 | """ 145 | now = datetime.now() 146 | dt_string = now.strftime("%b-%2d-%Y-%H:%M:%S") 147 | return dt_string 148 | 149 | def make_log_dir_with_time_stamp(log_name: str) -> str: 150 | """ 151 | Get the dir to the log 152 | Args: 153 | log_name: log's name 154 | 155 | Returns: 156 | directory to log file 157 | """ 158 | 159 | return os.path.join(dir_go_up(3), "log", log_name, 160 | get_formatted_date_time()) 161 | 162 | def process_cw2_train_rep_config_file(config_obj, overwrite: bool = False): 163 | """ 164 | Given processed cw2 configuration, do further process, including: 165 | - Overwrite log path with time stamp 166 | - Create model save folders 167 | - Overwrite random seed by the repetition number 168 | - Save the current repository commits 169 | - Make a copy of the config and restore the exp path to the original 170 | - Dump this copied config into yaml file into the model save folder 171 | - Dump the current time stamped config file in log folder to make slurm 172 | call bug free 173 | Args: 174 | exp_configs: list of configs processed by cw2 already 175 | 176 | Returns: 177 | None 178 | 179 | """ 180 | exp_configs = config_obj.exp_configs 181 | formatted_time = get_formatted_date_time() 182 | # Loop over the config of each repetition 183 | for i, rep_config in enumerate(exp_configs): 184 | 185 | # Make model save directory 186 | model_save_dir = os.path.join(rep_config["_rep_log_path"], "model") 187 | 188 | try: 189 | mkdir(os.path.abspath(model_save_dir), overwrite=overwrite) 190 | except FileExistsError: 191 | import logging 192 | logging.error(formatted_time) 193 | raise 194 | 195 | # Set random seed to the repetition number 196 | set_value_in_nest_dict(rep_config, "seed", 197 | rep_config['_rep_idx']) 198 | 199 | 200 | # Make a hard copy of the config 201 | copied_rep_config = copy.deepcopy(rep_config) 202 | 203 | # Recover the path to its original 204 | copied_rep_config["path"] = copied_rep_config["_basic_path"] 205 | 206 | # Reset the repetition number to 1 for future test usage 207 | copied_rep_config["repetitions"] = 1 208 | if copied_rep_config.get("reps_in_parallel", False): 209 | del copied_rep_config["reps_in_parallel"] 210 | if copied_rep_config.get("reps_per_job", False): 211 | del copied_rep_config["reps_per_job"] 212 | 213 | # Delete the generated cw2 configs 214 | for key in rep_config.keys(): 215 | if key[0] == "_": 216 | del copied_rep_config[key] 217 | del copied_rep_config["log_path"] 218 | 219 | # Save this copied subconfig file 220 | dump_config(copied_rep_config, "config", 221 | os.path.abspath(model_save_dir)) 222 | 223 | # Save the time stamped config file in local /log directory 224 | time_stamped_config_path = 
make_log_dir_with_time_stamp("") 225 | mkdir(time_stamped_config_path, overwrite=True) 226 | 227 | config_obj.to_yaml(time_stamped_config_path, 228 | relpath=False) 229 | config_obj.config_path = \ 230 | os.path.join(time_stamped_config_path, 231 | "relative_" + config_obj.f_name) 232 | 233 | def load_metadata_from_yaml(path: str, file_name: str = 'config.yaml') -> Dict: 234 | """ 235 | Load meta data from yaml file 236 | """ 237 | load_path = pathlib.Path(path, file_name) 238 | with open(load_path, 'r') as f: 239 | try: 240 | meta_data = yaml.load(f, Loader=yaml.FullLoader) 241 | except yaml.YAMLError as exc: 242 | print(exc) 243 | return meta_data 244 | 245 | 246 | def sort_filenames(filenames): 247 | # Function to extract the numerical part from the filename 248 | def extract_number(filename): 249 | # This assumes the format is always "model_state_dict_NUMBER.pth" 250 | k = int(filename.split('.')[0].split('_')[-1]) 251 | return k 252 | 253 | sorted_filenames = sorted(filenames, key=extract_number) 254 | return sorted_filenames 255 | 256 | 257 | def get_files_with_prefix(path, prefix): 258 | file_name_list = os.listdir(path) 259 | file_name_list = [file_name for file_name in file_name_list if file_name.startswith(prefix)] 260 | return sort_filenames(file_name_list) -------------------------------------------------------------------------------- /vdd/workspaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/intuitive-robots/vdd/e1395f38dd896968abae21d4571c74cfb02ea0f8/vdd/workspaces/__init__.py -------------------------------------------------------------------------------- /vdd/workspaces/base_manager.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | class BaseManager(abc.ABC): 4 | 5 | def __init__(self, seed, device, **kwargs): 6 | self.seed = seed 7 | self.device = device 8 | 9 | @abc.abstractmethod 10 | def env_rollout(self, agent, env, n_episodes: int, **kwargs): 11 | pass 12 | 13 | @abc.abstractmethod 14 | def get_train_and_test_datasets(self, **kwargs): 15 | pass 16 | 17 | @abc.abstractmethod 18 | def get_scaler(self, **kwargs): 19 | pass 20 | 21 | @abc.abstractmethod 22 | def get_score_function(self, **kwargs): 23 | pass 24 | 25 | def preprocess_data(self, batch_data): 26 | return batch_data 27 | -------------------------------------------------------------------------------- /vdd/workspaces/block_push_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch as ch 4 | import numpy as np 5 | 6 | from vdd.workspaces.base_manager import BaseManager 7 | from vdd.score_functions.beso_score import BesoScoreFunction 8 | 9 | import hydra 10 | from omegaconf import OmegaConf 11 | def assign_process_to_cpu(pid, cpus): 12 | os.sched_setaffinity(pid, cpus) 13 | 14 | class AgentWrapper: 15 | def __init__(self, gmm_agent, scaler): 16 | self.gmm_agent = gmm_agent 17 | self.scaler = scaler 18 | 19 | @ch.no_grad() 20 | def predict(self, obs_dict, **kwargs): 21 | obs = obs_dict['observation'] 22 | goal = obs_dict['goal_observation'] 23 | obs = self.scaler.scale_input(obs) 24 | goal = self.scaler.scale_input(goal) 25 | goal[..., [2, 5, 6, 7, 8, 9]] = 0 26 | act = self.gmm_agent.act(obs, goal) 27 | act = self.scaler.clip_action(act) 28 | act = self.scaler.inverse_scale_output(act) 29 | return act 30 | 31 | def reset(self): 32 | self.gmm_agent.reset() 33 | 34 | class 
BlockPushManager(BaseManager): 35 | 36 | def __init__(self, model_path, sv_name, seed, device, score_fn_params=None, **kwargs): 37 | super().__init__(seed, device, **kwargs) 38 | datasets_config = kwargs.get("datasets_config", None) 39 | self.beso_agent, self.workspace_manager = self.get_agent_and_workspace(model_path, sv_name, datasets_config) 40 | if score_fn_params is not None: 41 | self.score_function = BesoScoreFunction(self.beso_agent, obs_dim=10, goal_dim=10, **score_fn_params) 42 | else: 43 | self.score_function = None 44 | self.scaler = self.workspace_manager.scaler 45 | self.seed = seed 46 | self.train_fraction = self.workspace_manager.train_fraction 47 | self.goal_idx_offset = 0 48 | self.push_traj = self.workspace_manager.push_traj 49 | self.cpu_cores = kwargs.get("cpu_cores", None) 50 | 51 | def env_rollout(self, agent, n_episodes: int, **kwargs): 52 | """ 53 | evaluate the agent 54 | :return: 55 | """ 56 | if self.cpu_cores is not None: 57 | assign_process_to_cpu(os.getpid(), set([int(list(self.cpu_cores)[0])])) 58 | ch.cuda.empty_cache() 59 | agent.eval() 60 | wrapped_agent = AgentWrapper(agent, self.scaler) 61 | self.workspace_manager.eval_n_times = n_episodes 62 | return_dict = self.workspace_manager.test_agent(wrapped_agent, log_wandb=False) 63 | agent.train() 64 | return return_dict 65 | 66 | def get_train_and_test_datasets(self, **kwargs): 67 | return self.workspace_manager.data_loader['train'], self.workspace_manager.data_loader['test'] 68 | 69 | def preprocess_data(self, batch_data): 70 | scaled_obs = self.scaler.scale_input(batch_data['observation']) 71 | scaled_actions = self.scaler.scale_output(batch_data['action']) 72 | scaled_goals = self.scaler.scale_input(batch_data['goal_observation']) 73 | scaled_goals[..., [2, 5, 6, 7, 8, 9]] = 0 74 | return scaled_obs, scaled_actions, scaled_goals 75 | 76 | def get_scaler(self,): 77 | return self.workspace_manager.scaler 78 | 79 | def get_score_function(self, **kwargs): 80 | return self.score_function 81 | 82 | def get_agent_and_workspace(self, model_path, sv_name, datasets_config=None): 83 | cfg_store_path = os.path.join(model_path, ".hydra", "config.yaml") 84 | config = OmegaConf.load(cfg_store_path) 85 | np.random.seed(self.seed) 86 | ch.manual_seed(self.seed) 87 | config.seed = self.seed 88 | agent = hydra.utils.instantiate(config.agents) 89 | agent.load_pretrained_model(model_path, sv_name=sv_name) 90 | 91 | if datasets_config is not None: 92 | config.workspaces.seed = self.seed 93 | config.workspaces.dataset_fn.random_seed = self.seed 94 | config.workspaces.goal_fn.seed = self.seed 95 | assert datasets_config[ 96 | 'window_size'] == config.workspaces.dataset_fn.window_size, "Window size mismatch" 97 | if 'goal_seq_len' in datasets_config: 98 | assert datasets_config['goal_seq_len'] == config.workspaces.goal_fn.goal_seq_len, "Goal sequence length mismatch" 99 | config.workspaces.dataset_fn.train_fraction = datasets_config['train_fraction'] 100 | config.workspaces.num_workers = datasets_config['num_workers'] 101 | config.workspaces.train_batch_size = datasets_config['train_batch_size'] 102 | config.workspaces.test_batch_size = datasets_config['test_batch_size'] 103 | 104 | workspace_manager = hydra.utils.instantiate(config.workspaces) 105 | 106 | return agent, workspace_manager -------------------------------------------------------------------------------- /vdd/workspaces/d3il_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as ch 3 | import 
numpy as np
4 | import einops
5 | 
6 | import hydra
7 | from omegaconf import OmegaConf
8 | 
9 | import torch.utils.data as Data
10 | 
11 | from vdd.workspaces.base_manager import BaseManager
12 | from vdd.score_functions.beso_score import BesoScoreFunction
13 | from vdd.score_functions.ddpm_score import DDPMScoreFunction
14 | 
15 | from matplotlib import pyplot as plt
16 | 
17 | class D3ILAgent:
18 |     def __init__(self, gmm_agent, scaler):
19 |         self.gmm_agent = gmm_agent
20 |         self.scaler = scaler
21 | 
22 |     def predict(self, obs, if_vision=False):
23 | 
24 |         if if_vision:
25 |             agent_view_image = ch.from_numpy(obs[0]).to(self.gmm_agent.device).to(ch.float32).unsqueeze(0)
26 |             in_hand_image = ch.from_numpy(obs[1]).to(self.gmm_agent.device).to(ch.float32).unsqueeze(0)
27 |             robot_ee_pos = self.scaler.scale_input(ch.from_numpy(obs[2]).to(self.gmm_agent.device)).to(ch.float32).unsqueeze(0)
28 |             obs = (agent_view_image, in_hand_image, robot_ee_pos)
29 |             act = self.gmm_agent.act(obs, vision_task=True)
30 |             act = self.scaler.inverse_scale_output(act)
31 |             return act.cpu().numpy()
32 | 
33 |         obs = ch.from_numpy(obs).unsqueeze(0).to(self.gmm_agent.device).to(ch.float32)
34 |         obs = self.scaler.scale_input(obs)
35 |         act = self.gmm_agent.act(obs)
36 |         act = self.scaler.inverse_scale_output(act)
37 |         return act.cpu().numpy()
38 | 
39 |     def reset(self):
40 |         self.gmm_agent.reset()
41 | 
42 | class D3ILManager(BaseManager):
43 |     def __init__(self, model_path, sv_name, seed=0, device='cuda', score_fn_params=None, vision_task=False,
44 |                  goal_conditioned=False, score_type='beso', **kwargs):
45 |         super().__init__(seed, device, **kwargs)
46 |         datasets_config = kwargs.get("datasets_config", None)
47 |         self.agent, self.env_sim = self.get_agent_and_workspace(model_path, sv_name, datasets_config)
48 |         self.scaler = self.agent.scaler
49 |         if score_fn_params is not None:
50 |             if score_type == 'beso':
51 |                 self.score_function = BesoScoreFunction(self.agent, **score_fn_params)
52 |             elif score_type == 'ddpm':
53 |                 self.score_function = DDPMScoreFunction(self.agent, **score_fn_params)
54 |             else:
55 |                 raise NotImplementedError
56 |         else:
57 |             self.score_function = None
58 |         self.cpu_cores = kwargs.get("cpu_cores", None)
59 |         self.is_vision_task = vision_task
60 |         self.goal_conditioned = goal_conditioned
61 | 
62 |     def env_rollout(self, agent, n_episodes: int, render=False, **kwargs):
63 |         agent.eval()
64 |         ch.cuda.empty_cache()
65 |         d3il_agent = D3ILAgent(agent, self.scaler)
66 |         self.env_sim.n_trajectories = n_episodes
67 |         self.env_sim.render = render
68 |         eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores)
69 |         print(eval_dict)
70 |         agent.train()
71 |         return eval_dict
72 | 
73 |     def get_scaler(self, **kwargs):
74 |         return self.scaler
75 | 
76 |     def preprocess_dataloader(self, dataloader, **kwargs):
77 |         obs = []
78 |         actions = []
79 |         for batch in dataloader:
80 |             obs.append(batch[0])
81 |             actions.append(batch[1])
82 |         obs = ch.cat(obs, dim=0).squeeze(1).to(self.device)
83 |         obs = self.scaler.scale_input(obs)
84 |         actions = ch.cat(actions, dim=0).squeeze(1).to(self.device)
85 |         actions = self.scaler.scale_output(actions)
86 |         idx = ch.arange(obs.shape[0]).long().to(self.device)
87 |         return Data.TensorDataset(idx, obs, actions)
88 | 
89 |     def get_score_function(self, **kwargs):
90 |         return self.score_function
91 | 
92 |     def get_train_and_test_datasets(self, **kwargs):
93 |         # TODO: hacky way to get train and test datasets
94 |         return self.agent.train_dataloader, self.agent.test_dataloader
95 | 
96 |     def 
plot_grad_field(self, state): 97 | 98 | state_0, state_1, state_2 = state 99 | 100 | raw_x, raw_y = np.meshgrid(np.linspace(-2, 2, 10), np.linspace(-2, 2, 10)) 101 | 102 | x = ch.tensor(raw_x).to(self.device).to(ch.float32).unsqueeze(-1) 103 | y = ch.tensor(raw_y).to(self.device).to(ch.float32).unsqueeze(-1) 104 | 105 | grid_actions = ch.cat([x, y], axis=-1).to(self.device).to(ch.float32).reshape(-1, 2) 106 | 107 | t = 5 108 | 109 | grid_actions = einops.rearrange(grid_actions, 'n d -> 1 1 n d') 110 | 111 | grid_actions = einops.repeat(grid_actions, '1 1 n d -> 1 1 n t d', t=t) 112 | 113 | state_0 = einops.rearrange(state_0, '... -> 1 ...') 114 | state_1 = einops.rearrange(state_1, '... -> 1 ...') 115 | state_2 = einops.rearrange(state_2, '... -> 1 ...') 116 | 117 | state = (state_0, state_1, state_2) 118 | 119 | score = self.score_function(samples=grid_actions, states=state, vision_task=True) 120 | 121 | score = score[0].squeeze() 122 | score = score[:, -1, :] 123 | 124 | score = score.reshape(10, 10, 2).cpu().numpy() 125 | 126 | u = score[..., 0] 127 | v = score[..., 1] 128 | 129 | plt.quiver(raw_x, raw_y, u, v) 130 | 131 | def preprocess_data(self, batch_data): 132 | if self.is_vision_task: 133 | obs_batch = batch_data[:3] 134 | agentview_image = obs_batch[0].to(self.device) 135 | in_hand_image = obs_batch[1].to(self.device) 136 | robot_ee_pos = self.scaler.scale_input(obs_batch[2].to(self.device)) 137 | 138 | obs_batch = (agentview_image, in_hand_image, robot_ee_pos) 139 | 140 | action_batch = self.scaler.scale_output(batch_data[3]).to(self.device) 141 | 142 | # for idx in range(10): 143 | # self.plot_grad_field((obs_batch[0][idx], obs_batch[1][idx], obs_batch[2][idx])) 144 | # normalized_action = action_batch[idx, -1, :].cpu().numpy() 145 | # plt.plot(normalized_action[0], normalized_action[1], 'ro') 146 | # plt.show() 147 | 148 | # print("finished") 149 | 150 | if self.goal_conditioned: 151 | return obs_batch, action_batch, batch_data[4] 152 | else: 153 | return obs_batch, action_batch, None 154 | else: 155 | obs_batch = batch_data[0] 156 | action_batch = batch_data[1] 157 | scaled_obs = self.scaler.scale_input(obs_batch).to(self.device) 158 | scaled_action = self.scaler.scale_output(action_batch).to(self.device) 159 | return scaled_obs, scaled_action, None 160 | def preprocess_config(self, config): 161 | return config 162 | 163 | def get_agent_and_workspace(self, model_path, sv_name, datasets_config=None): 164 | cfg_store_path = os.path.join(model_path, ".hydra", "config.yaml") 165 | config = OmegaConf.load(cfg_store_path) 166 | 167 | config = self.preprocess_config(config) 168 | 169 | np.random.seed(self.seed) 170 | ch.manual_seed(self.seed) 171 | 172 | if datasets_config is not None: 173 | assert datasets_config['window_size'] == config.agents.window_size, "Window size mismatch" 174 | if 'goal_seq_len' in datasets_config: 175 | assert datasets_config['goal_seq_len'] == config.agents.goal_window_size, "Goal sequence length mismatch" 176 | config.agents.num_workers = datasets_config['num_workers'] 177 | config.agents.train_batch_size = datasets_config['train_batch_size'] 178 | config.agents.val_batch_size = datasets_config['test_batch_size'] 179 | 180 | agent = hydra.utils.instantiate(config.agents) 181 | agent.load_pretrained_model(model_path, sv_name=sv_name) 182 | env_sim = hydra.utils.instantiate(config.simulation) 183 | return agent, env_sim 184 | 185 | 186 | class D3ILAlignManager(D3ILManager): 187 | def env_rollout(self, agent, n_episodes: int, num_ctxts: int = 10, **kwargs): 
188 |         agent.eval()
189 |         ch.cuda.empty_cache()
190 |         d3il_agent = D3ILAgent(agent, self.scaler)
191 |         self.env_sim.n_trajectories_per_context = n_episodes
192 |         self.env_sim.n_contexts = num_ctxts
193 |         self.env_sim.render = False
194 |         if self.cpu_cores is not None and len(self.cpu_cores) > 20:
195 |             self.cpu_cores = self.cpu_cores[:20]
196 |         eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores)
197 |         print(eval_dict)
198 |         agent.train()
199 |         return eval_dict
200 | 
201 |     def preprocess_config(self, config):
202 |         config['trainset']['_target_'] = 'environments.dataset.aligning_dataset.Aligning_Dataset'
203 |         config['valset']['_target_'] = 'environments.dataset.aligning_dataset.Aligning_Dataset'
204 |         config['simulation']['_target_'] = 'simulation.aligning_sim.Aligning_Sim'
205 |         config['train_data_path'] = 'environments/dataset/data/aligning/train_files.pkl'
206 |         config['eval_data_path'] = 'environments/dataset/data/aligning/eval_files.pkl'
207 |         return config
208 | 
209 | class D3ILSortingVisionManager(D3ILManager):
210 |     def env_rollout(self, agent, n_episodes: int, num_ctxts: int = 10, **kwargs):
211 |         agent.eval()
212 |         ch.cuda.empty_cache()
213 |         d3il_agent = D3ILAgent(agent, self.scaler)
214 |         self.env_sim.n_trajectories_per_context = n_episodes
215 |         self.env_sim.n_contexts = num_ctxts
216 |         self.env_sim.render = False
217 |         ## Limiting the number of cores to 10, otherwise it will run out of memory even on Horeka
218 |         if self.cpu_cores is not None and len(self.cpu_cores) > 10:
219 |             self.cpu_cores = self.cpu_cores[:10]
220 |         eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores)
221 |         print(eval_dict)
222 |         agent.train()
223 |         return eval_dict
224 | 
225 |     def preprocess_config(self, config):
226 |         return config
227 | 
228 | class D3ILStackingManager(D3ILManager):
229 |     def env_rollout(self, agent, n_episodes: int, num_ctxts: int = 10, **kwargs):
230 |         agent.eval()
231 |         ch.cuda.empty_cache()
232 |         d3il_agent = D3ILAgent(agent, self.scaler)
233 |         self.env_sim.n_trajectories_per_context = n_episodes
234 |         self.env_sim.n_contexts = num_ctxts
235 |         self.env_sim.render = False
236 |         if self.cpu_cores is not None and len(self.cpu_cores) > 20:
237 |             self.cpu_cores = self.cpu_cores[:20]
238 |         eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores)
239 |         print(eval_dict)
240 |         agent.train()
241 |         return eval_dict
242 |     def preprocess_config(self, config):
243 |         config['trainset']['_target_'] = 'environments.dataset.stacking_dataset.Stacking_Dataset'
244 |         config['valset']['_target_'] = 'environments.dataset.stacking_dataset.Stacking_Dataset'
245 |         config['simulation']['_target_'] = 'simulation.stacking_sim.Stacking_Sim'
246 |         config['train_data_path'] = 'environments/dataset/data/stacking/train_files.pkl'
247 |         config['eval_data_path'] = 'environments/dataset/data/stacking/eval_files.pkl'
248 |         return config
249 | 
250 | 
251 | class D3ILStackingVisionManager(D3ILManager):
252 |     def env_rollout(self, agent, n_episodes: int, num_ctxts: int = 10, **kwargs):
253 |         agent.eval()
254 |         ch.cuda.empty_cache()
255 |         d3il_agent = D3ILAgent(agent, self.scaler)
256 |         self.env_sim.n_trajectories_per_context = n_episodes
257 |         self.env_sim.n_contexts = num_ctxts
258 |         self.env_sim.render = False
259 |         ## Limiting the number of cores to 10, otherwise it will run out of memory even on Horeka
260 |         if self.cpu_cores is not None and len(self.cpu_cores) > 10:
261 |             self.cpu_cores = self.cpu_cores[:10]
262 |         eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores)
263 |         print(eval_dict)
264 | agent.train() 265 | return eval_dict 266 | 267 | def preprocess_config(self, config): 268 | config['trainset']['_target_'] = 'environments.dataset.stacking_dataset.Stacking_Img_Dataset' 269 | config['valset']['_target_'] = 'environments.dataset.stacking_dataset.Stacking_Img_Dataset' 270 | config['simulation']['_target_'] = 'simulation.stacking_vision_sim.Stacking_Sim' 271 | config['train_data_path'] = 'environments/dataset/data/stacking/vision_train_files.pkl' 272 | config['eval_data_path'] = 'environments/dataset/data/stacking/vision_eval_files.pkl' 273 | return config 274 | 275 | class D3ILAvoidingManager(D3ILManager): 276 | def env_rollout(self, agent, n_episodes: int, render=False, **kwargs): 277 | agent.eval() 278 | ch.cuda.empty_cache() 279 | d3il_agent = D3ILAgent(agent, self.scaler) 280 | self.env_sim.n_trajectories = n_episodes 281 | self.env_sim.render = render 282 | eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores) 283 | print(eval_dict) 284 | agent.train() 285 | return eval_dict 286 | 287 | def preprocess_config(self, config): 288 | config['trainset']['_target_'] = 'environments.dataset.avoiding_dataset.Avoiding_Dataset' 289 | config['valset']['_target_'] = 'environments.dataset.avoiding_dataset.Avoiding_Dataset' 290 | config['simulation']['_target_'] = 'simulation.avoiding_sim.Avoiding_Sim' 291 | config['data_directory'] = 'environments/dataset/data/avoiding/data' 292 | return config 293 | 294 | class D3ILPushingManager(D3ILManager): 295 | def env_rollout(self, agent, n_episodes: int, num_ctxts:int=10, **kwargs): 296 | agent.eval() 297 | ch.cuda.empty_cache() 298 | d3il_agent = D3ILAgent(agent, self.scaler) 299 | self.env_sim.n_trajectories_per_context = n_episodes 300 | self.env_sim.n_contexts = num_ctxts 301 | self.env_sim.render = False 302 | if self.cpu_cores is not None and len(self.cpu_cores) > 20: 303 | self.cpu_cores = self.cpu_cores[:20] 304 | eval_dict = self.env_sim.test_agent(d3il_agent, self.cpu_cores) 305 | print(eval_dict) 306 | agent.train() 307 | return eval_dict 308 | def preprocess_config(self, config): 309 | config['trainset']['_target_'] = 'environments.dataset.pushing_dataset.Pushing_Dataset' 310 | config['valset']['_target_'] = 'environments.dataset.pushing_dataset.Pushing_Dataset' 311 | config['simulation']['_target_'] = 'simulation.pushing_sim.Pushing_Sim' 312 | config['train_data_path'] = 'environments/dataset/data/pushing/train_files.pkl' 313 | config['eval_data_path'] = 'environments/dataset/data/pushing/eval_files.pkl' 314 | return config -------------------------------------------------------------------------------- /vdd/workspaces/kitchen_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch as ch 3 | import torch.utils.data as Data 4 | import numpy as np 5 | 6 | from vdd.workspaces.base_manager import BaseManager 7 | from vdd.score_functions.beso_score import BesoScoreFunction 8 | 9 | import hydra 10 | from omegaconf import OmegaConf 11 | 12 | from copy import deepcopy 13 | 14 | import gym 15 | import adept_envs 16 | 17 | class KitchenAgentWrapper: 18 | def __init__(self, gmm_agent, scaler): 19 | self.gmm_agent = gmm_agent 20 | self.scaler = scaler 21 | 22 | def predict(self, obs_dict, **kwargs): 23 | obs = obs_dict['observation'] 24 | goal = obs_dict['goal_observation'] 25 | obs = self.scaler.scale_input(obs) 26 | goal = self.scaler.scale_input(goal) 27 | act = self.gmm_agent.act(obs, goal) 28 | act = self.scaler.clip_action(act) 29 | act = 
self.scaler.inverse_scale_output(act) 30 | return act 31 | 32 | def reset(self): 33 | self.gmm_agent.reset() 34 | 35 | class KitchenManager(BaseManager): 36 | 37 | def __init__(self, model_path, sv_name, score_fn_params, seed, device, **kwargs): 38 | super().__init__(seed, device, **kwargs) 39 | datasets_config = kwargs.get("datasets_config", None) 40 | self.agent, self.workspace_manager = self.get_agent_and_workspace(model_path, sv_name, datasets_config) 41 | self.score_function = BesoScoreFunction(self.agent, obs_dim=30, goal_dim=30, **score_fn_params) 42 | self.scaler = self.workspace_manager.scaler 43 | self.cpu_cores = kwargs.get("cpu_cores", None) 44 | self.goal_idx_offset = 0 45 | 46 | def env_rollout(self, agent, n_episodes: int, **kwargs): 47 | if self.cpu_cores is not None: 48 | os.sched_setaffinity(os.getpid(), set([int(list(self.cpu_cores)[0])])) 49 | ch.cuda.empty_cache() 50 | agent.eval() 51 | wrapped_agent = KitchenAgentWrapper(agent, self.scaler) 52 | self.workspace_manager.eval_n_times = n_episodes 53 | return_dict, _ = self.workspace_manager.test_agent(wrapped_agent, log_wandb=False) 54 | agent.train() 55 | return return_dict 56 | 57 | def pre_process_dataset(self, dataset, keep_window=True): 58 | obs = [] 59 | actions = [] 60 | goals = [] 61 | for i in range(len(dataset)): 62 | obs.append(dataset[i]['observation']) 63 | actions.append(dataset[i]['action']) 64 | goals.append(dataset[i]['goal_observation']) 65 | 66 | obs = ch.cat(obs, dim=0).to(self.device).to(ch.float32) 67 | actions = ch.cat(actions, dim=0).to(self.device).to(ch.float32) 68 | goals = ch.cat(goals, dim=0).to(self.device).to(ch.float32) 69 | idx = ch.arange(0, obs.shape[0]).long().to(self.device) 70 | inputs = ch.cat([obs, goals], dim=-1) 71 | return Data.TensorDataset(idx, inputs, actions) 72 | 73 | def get_train_and_test_datasets(self, **kwargs): 74 | return self.workspace_manager.data_loader['train'], self.workspace_manager.data_loader['test'] 75 | 76 | def preprocess_data(self, batch_data): 77 | scaled_obs = self.scaler.scale_input(batch_data['observation']) 78 | scaled_output = self.scaler.scale_output(batch_data['action']) 79 | scaled_goal_obs = self.scaler.scale_input(batch_data['goal_observation']) 80 | return scaled_obs, scaled_output, scaled_goal_obs 81 | 82 | def get_scaler(self, **kwargs): 83 | return deepcopy(self.agent.scaler) 84 | 85 | def get_score_function(self, **kwargs): 86 | return self.score_function 87 | 88 | def get_agent_and_workspace(self, model_path, sv_name, datasets_config=None): 89 | cfg_store_path = os.path.join(model_path, ".hydra", "config.yaml") 90 | config = OmegaConf.load(cfg_store_path) 91 | np.random.seed(self.seed) 92 | ch.manual_seed(self.seed) 93 | agent = hydra.utils.instantiate(config.agents) 94 | agent.load_pretrained_model(model_path, sv_name=sv_name) 95 | 96 | if datasets_config is not None: 97 | assert datasets_config['window_size'] == config.workspaces.dataset_fn.window_size, "Window size mismatch" 98 | # config.workspaces.dataset_fn.window_size = datasets_config['window_size'] 99 | config.workspaces.dataset_fn.train_fraction = datasets_config['train_fraction'] 100 | config.workspaces.num_workers = datasets_config['num_workers'] 101 | config.workspaces.train_batch_size = datasets_config['train_batch_size'] 102 | config.workspaces.test_batch_size = datasets_config['test_batch_size'] 103 | 104 | workspace_manager = hydra.utils.instantiate(config.workspaces) 105 | 106 | return agent, workspace_manager 
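
For context, a hypothetical driver sketch of how the training scripts presumably use a manager like this one. model_path and sv_name are placeholders for a pretrained BeSO kitchen run, score_fn_params would come from configs/vdd_beso_kitchen.yml, and importing this module additionally requires the kitchen dependencies (gym, adept_envs):

from vdd.workspaces.kitchen_manager import KitchenManager

manager = KitchenManager(model_path='/path/to/pretrained/kitchen_run',  # placeholder
                         sv_name='model_state_dict.pth',                # placeholder
                         score_fn_params={},  # fill in from the yml config
                         seed=0, device='cuda')

train_loader, test_loader = manager.get_train_and_test_datasets()
score_fn = manager.get_score_function()

for batch in train_loader:
    obs, actions, goals = manager.preprocess_data(batch)  # scaled tensors
    # ... one distillation update using score_fn would go here ...
    break

# result = manager.env_rollout(distilled_agent, n_episodes=10)  # distilled_agent: the trained policy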
--------------------------------------------------------------------------------
/vdd/workspaces/manager_factory.py:
--------------------------------------------------------------------------------
1 | from vdd.workspaces.toytask2d_manager import ToyTask2DManager
2 | from vdd.workspaces.block_push_manager import BlockPushManager
3 | from vdd.workspaces.kitchen_manager import KitchenManager
4 | from vdd.workspaces.d3il_manager import D3ILAlignManager, D3ILAvoidingManager, D3ILStackingManager, \
5 |     D3ILSortingVisionManager, D3ILPushingManager, D3ILStackingVisionManager
6 | 
7 | def create_experiment_manager(config):
8 |     if config['experiment_name'] == 'toytask2d':
9 |         return ToyTask2DManager(**config)
10 |     elif config['experiment_name'] == 'block_push':
11 |         return BlockPushManager(**config)
12 |     elif config['experiment_name'] == 'kitchen':
13 |         return KitchenManager(**config)
14 |     elif config['experiment_name'] == 'd3il_avoiding':
15 |         return D3ILAvoidingManager(**config)
16 |     elif config['experiment_name'] == 'd3il_aligning':
17 |         return D3ILAlignManager(**config)
18 |     elif config['experiment_name'] == 'd3il_stacking':
19 |         return D3ILStackingManager(**config)
20 |     elif config['experiment_name'] == 'd3il_sorting_vision':
21 |         return D3ILSortingVisionManager(**config)
22 |     elif config['experiment_name'] == 'd3il_stacking_vision':
23 |         return D3ILStackingVisionManager(**config)
24 |     elif config['experiment_name'] == 'd3il_pushing':
25 |         return D3ILPushingManager(**config)
26 |     else:
27 |         raise ValueError(f"Unknown experiment name: {config['experiment_name']}")
--------------------------------------------------------------------------------
/vdd/workspaces/toytask1d_manager.py:
--------------------------------------------------------------------------------
1 | import torch as ch
2 | import torch.utils.data as Data
3 | import numpy as np
4 | 
5 | from vdd.workspaces.base_manager import BaseManager
6 | from vdd.score_functions.gmm_score import GMMScoreFunction
7 | 
8 | import matplotlib.pyplot as plt
9 | 
10 | def plot_1d_gaussians(means, stds, ax, title=""):
11 |     x = ch.linspace(-2.0, 2.0, 100)
12 |     if len(means.shape) == 0:
13 |         means = means[None, ...]
14 |         stds = stds[None, ...]
15 |     for i in range(means.shape[0]):
16 |         y = ch.exp(-0.5 * ((x - means[i]) / stds[i]) ** 2) / (stds[i] * (2 * np.pi) ** 0.5)
17 |         ax.plot(x, y, label=f"Component {i}")
18 |     ax.set_title(title)
19 |     ax.legend()
20 | 
21 | def plot_1d_gmm(means, stds, weights, ax, title=""):
22 |     x = ch.linspace(-2.0, 2.0, 100)
23 |     y = ch.zeros_like(x)
24 |     if len(means.shape) == 0:
25 |         means = means[None, ...]
26 |         stds = stds[None, ...]
27 |         weights = weights[None, ...]
28 |     for i in range(means.shape[0]):
29 |         y += weights[i] * ch.exp(-0.5 * ((x - means[i]) / stds[i]) ** 2) / (stds[i] * (2 * np.pi) ** 0.5)
30 |     ax.plot(x, y, label="GMM", linestyle='dashed', zorder=-1, linewidth=10.0)
31 |     ax.set_title(title)
32 |     ax.legend()
33 | 
34 | class ToyTask1DManager(BaseManager):
35 | 
36 |     def __init__(self, n_component, seed, device, **kwargs):
37 |         super().__init__(seed, device, **kwargs)
38 |         # ground truth for the 1D toy task: a random GMM with n_component modes
39 |         means, chols = GMMScoreFunction.generate_random_params(n_component, 1)
40 |         self.score_function = GMMScoreFunction(means=means, chols=chols, device=device)
--------------------------------------------------------------------------------
/vdd/workspaces/toytask2d_manager.py:
--------------------------------------------------------------------------------
1 | import torch as ch
2 | import torch.utils.data as Data
3 | import numpy as np
4 | 
5 | from vdd.workspaces.base_manager import BaseManager
6 | from vdd.score_functions.gmm_score import GMMScoreFunction
7 | 
8 | import einops
9 | 
10 | from vdd.score_functions.score_utils import plot_2d_gaussians
11 | import matplotlib.pyplot as plt
12 | 
13 | 
14 | class ToyTask2DManager(BaseManager):
15 | 
16 |     def __init__(self, num_datapoints=1000, score_fn_params={}, datasets_config={}, seed=0, device='cuda', **kwargs):
17 |         super().__init__(seed, device, **kwargs)
18 |         r = score_fn_params.get("r", 1)
19 |         std = score_fn_params.get("std", 0.4)
20 |         n_component = score_fn_params.get("num_components", 4)
21 |         # place the components uniformly on a circle of radius r
22 |         thetas = ch.linspace(0, 2 * np.pi, n_component + 1)[:-1]
23 |         means = ch.stack([r * ch.cos(thetas), r * ch.sin(thetas)], dim=-1)
24 |         chols = ch.eye(2).view(1, 2, 2).repeat(n_component, 1, 1) * std
25 |         self.score_function = GMMScoreFunction(means=means, chols=chols, device=device)
26 |         self.scaler = None
27 |         self.r = r
28 |         self.actions = self.score_function.sample(num_datapoints).unsqueeze(1).to(device)
29 |         self.test_actions = self.score_function.sample(num_datapoints).unsqueeze(1).to(device)
30 |         self.states = ch.ones_like(self.actions[:, 0]).unsqueeze(1).to(device)
31 |         self.dataset = Data.TensorDataset(self.states, self.actions)
32 |         self.test_dataset = Data.TensorDataset(self.states, self.test_actions)
33 |         self.train_loader = Data.DataLoader(self.dataset, batch_size=datasets_config['batch_size'], shuffle=True)
34 |         self.test_loader = Data.DataLoader(self.test_dataset, batch_size=datasets_config['batch_size'], shuffle=True)
35 |         self.iter = 0
36 | 
37 |     def env_rollout(self, agent, n_episodes: int, **kwargs):
38 |         agent.eval()
39 |         # overlay the learned components (orange) on the ground-truth density and its score field
40 |         fig, ax = self.score_function.visualize_grad_and_cmps(x_range=[-2 * self.r, 2 * self.r],
41 |                                                               y_range=[-2 * self.r, 2 * self.r], n=15)
42 |         means, chols, gating = agent(self.states[:1, ...])
43 |         means = einops.rearrange(means, '1 n 1 d -> n d').cpu().detach().numpy()
44 |         chols = einops.rearrange(chols, '1 n 1 d1 d2 -> n d1 d2').cpu().detach().numpy()
45 |         plot_2d_gaussians(means, chols, ax, color='orange')
46 |         ax.set_aspect('equal')
47 |         plt.show()
48 |         agent.train()
49 |         self.iter += 1
50 |         return {'iter': self.iter}
51 | 
52 |     def get_scaler(self, **kwargs):
53 |         return None
54 | 
55 |     def preprocess_data(self, batch_data):
56 |         return batch_data[0], batch_data[1], None
57 | 
58 |     def get_score_function(self, **kwargs):
59 |         return self.score_function
60 | 
61 |     def get_train_and_test_datasets(self, **kwargs):
62 |         return self.train_loader, self.test_loader
--------------------------------------------------------------------------------
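
To close the loop, a self-contained sketch of the 2D toy task (illustrative, not one of the repository files). The keyword values mirror the constructor defaults above, the batch size is an arbitrary choice, and create_experiment_manager(config) with 'experiment_name': 'toytask2d' routes to the same class once the remaining managers' dependencies are installed:

from vdd.workspaces.toytask2d_manager import ToyTask2DManager

manager = ToyTask2DManager(num_datapoints=1000,
                           score_fn_params={'r': 1.0, 'std': 0.4, 'num_components': 4},
                           datasets_config={'batch_size': 256},
                           seed=0, device='cpu')

train_loader, test_loader = manager.get_train_and_test_datasets()
states, actions = next(iter(train_loader))                     # states: (256, 1, 2) dummies; actions: (256, 1, 2)
obs, acts, goals = manager.preprocess_data((states, actions))  # goals is None for this task

score_fn = manager.get_score_function()
grad, _ = score_fn(acts.squeeze(1))  # ground-truth GMM score at the batch actions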