├── .gitignore ├── README.md ├── assets └── framework.png └── skilldiffuser ├── README.md ├── eval_lorel.sh ├── h5py2pkl.py ├── hrl ├── conf │ ├── agent │ │ └── big_dt.yaml │ ├── config.yaml │ ├── env │ │ ├── hopper.yaml │ │ ├── lorel_franka.yaml │ │ ├── lorel_sawyer_obs.yaml │ │ └── lorel_sawyer_state.yaml │ └── model │ │ ├── option.yaml │ │ ├── traj_option.yaml │ │ └── vanilla.yaml ├── dec_diffuser.py ├── dec_encoder.py ├── decision_transformer.py ├── difftrainer.py ├── diffuser │ ├── __init__.py │ ├── datasets │ │ ├── __init__.py │ │ ├── buffer.py │ │ ├── d4rl.py │ │ ├── normalization.py │ │ ├── preprocessing.py │ │ └── sequence.py │ ├── environments │ │ ├── __init__.py │ │ ├── ant.py │ │ ├── assets │ │ │ ├── ant.xml │ │ │ ├── half_cheetah.xml │ │ │ ├── hopper.xml │ │ │ └── walker2d.xml │ │ ├── half_cheetah.py │ │ ├── hopper.py │ │ ├── registration.py │ │ └── walker2d.py │ ├── models │ │ ├── __init__.py │ │ ├── diffusion.py │ │ ├── helpers.py │ │ └── temporal.py │ └── utils │ │ ├── __init__.py │ │ ├── arrays.py │ │ ├── cloud.py │ │ ├── colab.py │ │ ├── config.py │ │ ├── git_utils.py │ │ ├── iql.py │ │ ├── progress.py │ │ ├── pybullet_utils.py │ │ ├── rendering.py │ │ ├── serialization.py │ │ ├── setup.py │ │ ├── timer.py │ │ ├── training.py │ │ ├── transformations.py │ │ └── video.py ├── env.py ├── eval.py ├── eval_orig.py ├── expert_dataset.py ├── hrl_model.py ├── img_encoder.py ├── iq.py ├── main_hot.py ├── option_selector.py ├── option_transformer.py ├── reconstructors.py ├── trainer.py ├── trajectory_gpt2.py ├── trajectory_model.py ├── utils.py ├── vector_quantize_pytorch.py ├── viz-lorl.ipynb ├── viz-lorl.py └── viz.py ├── matrix.npy ├── pkl2pkl.py ├── requirements.txt └── train_lorel_compose.sh /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | outputs/ 3 | checkpoints/ 4 | wandb/ 5 | alfred/data 6 | alfworld/storage 7 | eval_*/ 8 | skilldiffuser/outputs 9 | buckets 10 | alfred/exp 11 | lorel 12 | metaworld 13 | hrl/gifs 14 | 15 | # Jupyter Notebook 16 | .ipynb_checkpoints 17 | 18 | .venv 19 | .vscode 20 | 21 | babyai/babyai.egg-info/ 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SkillDiffuser on LOReL Compositional Tasks 2 | 3 | This is a part of the official PyTorch implementation of paper: 4 | 5 | > ### [SkillDiffuser: Interpretable Hierarchical Planning via Skill Abstractions in Diffusion-Based Task Execution](https://arxiv.org/abs/2312.11598) 6 | > Zhixuan Liang, Yao Mu, Hengbo Ma, Masayoshi Tomizuka, Mingyu Ding, Ping Luo 7 | > 8 | > CVPR 2024 9 | > 10 | > [Project Page](https://skilldiffuser.github.io/) | [Paper](https://arxiv.org/abs/2312.11598) 11 | 12 | ### Framework of SkillDiffuser 13 | 14 | 15 | 16 | SkillDiffuser is a hierarchical planning model that leverages the cooperation of interpretable skill abstractions at the higher level and a skill conditioned diffusion model at the lower level for task execution in a multi-task learning environment. The high-level skill abstraction is achieved through a skill predictor and a vector quantization operation, generating sub-goals (skill set) that the diffusion model employs to determine the appropriate future states. Future states are converted to actions using an inverse dynamics model. This unique fusion enables a consistent underlying planner across different tasks, with the variation only in the inverse dynamics model. 
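To make the data flow concrete, below is a minimal, illustrative sketch of one hierarchical planning step using toy stand-ins for every learned component. The function names (`skill_predictor`, `vector_quantize`, `diffuse_future_states`, `inverse_dynamics`) and the dimensions are placeholders chosen to mirror the description above and the defaults in `skilldiffuser/hrl/conf` (e.g. 20 skill codes of dimension 16, a 5-dim LOReL action space); the actual modules live in `skilldiffuser/hrl` (e.g. `option_selector.py`, `dec_diffuser.py`) and operate on batched PyTorch tensors.

```python
import numpy as np

# Toy stand-ins for the learned components; only the control flow is meaningful here.
codebook = np.random.randn(20, 16)                   # num_options x codebook_dim (cf. conf defaults)

def skill_predictor(lang_emb, state):                # high level: (instruction, state) -> skill embedding
    return np.random.randn(16)                       # placeholder output

def vector_quantize(z):                              # snap the prediction to its nearest codebook entry
    idx = int(np.argmin(np.linalg.norm(codebook - z, axis=-1)))
    return codebook[idx], idx

def diffuse_future_states(skill, state, horizon=8):  # low level: skill-conditioned diffusion over states
    return state + np.cumsum(0.01 * np.random.randn(horizon, state.shape[-1]), axis=0)

def inverse_dynamics(s_t, s_tp1):                    # adjacent planned states -> executed action
    return np.tanh(s_tp1 - s_t)[:5]                  # LOReL uses a 5-dim action space

lang_emb = np.random.randn(768)                      # frozen language-model embedding of the instruction
state = np.random.randn(15)                          # current (encoded) observation
skill, skill_idx = vector_quantize(skill_predictor(lang_emb, state))
plan = diffuse_future_states(skill, state)           # planned future states for the next `horizon` steps
action = inverse_dynamics(state, plan[0])            # execute, then replan when the skill horizon ends
```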
17 | 18 | ### Citation 19 | 20 | ```bibtex 21 | @article{liang2023skilldiffuser, 22 | title={Skilldiffuser: Interpretable hierarchical planning via skill abstractions in diffusion-based task execution}, 23 | author={Liang, Zhixuan and Mu, Yao and Ma, Hengbo and Tomizuka, Masayoshi and Ding, Mingyu and Luo, Ping}, 24 | journal={arXiv preprint arXiv:2312.11598}, 25 | year={2023} 26 | } 27 | ``` 28 | 29 | ### Code 30 | 31 | To install and use SkillDiffuser check the instructions provided in the [skilldiffuser](skilldiffuser) folder. 32 | 33 | 34 | ### Acknowledgements 35 | 36 | The diffusion model implementation is based on Michael Janner's [diffuser](https://github.com/jannerm/diffuser) repo. 37 | The organization of this repo and remote launcher is based on the [LISA](https://github.com/Div99/LISA) repo. 38 | 39 | ### Questions 40 | Please email us if you have any questions. 41 | 42 | Zhixuan Liang ([liangzx@connect.hku.hk](mailto:liangzx@connect.hku.hk?subject=[GitHub]%skilldiffuser)) 43 | 44 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liang-ZX/SkillDiffuser/bc9da40fb98b8e65797005d3d9293d3d75665d9d/assets/framework.png -------------------------------------------------------------------------------- /skilldiffuser/README.md: -------------------------------------------------------------------------------- 1 | # SkillDiffuser 2 | 3 | Zhixuan Liang, Yao Mu, Hengbo Ma, Masayoshi Tomizuka, Mingyu Ding, Ping Luo 4 | 5 | ## Usage 6 | ### Setup Python Environment 7 | 1. Install MuJoCo 200 8 | ```shell 9 | unzip mujoco200_linux.zip 10 | mv mujoco200_linux mujoco200 11 | cp mjkey.txt ~/.mujoco 12 | cp mjkey.txt ~/.mujoco/mujoco200/bin 13 | 14 | # test the install 15 | cd ~/.mujoco/mujoco200/bin 16 | ./simulate ../model/humanoid.xml 17 | 18 | # add environment variables 19 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin 20 | export MUJOCO_KEY_PATH=~/.mujoco/${MUJOCO_KEY_PATH} 21 | ``` 22 | 2. Install Pypi Packages 23 | ```shell 24 | pip install -r requirements.txt 25 | ``` 26 | 3. Install LOReL Environment 27 | ```shell 28 | git clone https://github.com/suraj-nair-1/lorel.git 29 | 30 | cd lorel/env 31 | pip install -e . 32 | ``` 33 | 34 | ### Setup LOReL Dataset 35 | 1. Download the dataset from [LOReL](https://drive.google.com/file/d/1pLnctqkOzyWZa1F1zTFqkNgUzSkUCtEv/view?usp=sharing) 36 | 2. Process the dataset from h5py to pickle 37 | ```shell 38 | python h5py2pkl.py --root_path --output_name 39 | ``` 40 | 3. Change the path in `hrl/conf/env/lorel_sawyer_obs.yaml` to the processed dataset. 41 | 42 | ## Instructions 43 | 44 | Our code for running SkillDiffuser experiments is present in `hrl` folder. 45 | 46 | To run the code, please use the following command: 47 | 48 | `./train_lorel_compose.sh` 49 | 50 | This is a sample command intended to show the usage of different flags available. The checkpoints can be downloaded from [here](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/liangzx_connect_hku_hk/Em3qBc3AxWpOkYR7Pgd7lnUBH0bkLsILMpgUX2Xg5l3YGg?e=QslO6n). 51 | (The checkpoint is used for fine-tuning, not for evaluation directly.) 52 | 53 | If you would like to evaluate the model directly, please see [this issue](https://github.com/Liang-ZX/SkillDiffuser/issues/2#issuecomment-2377606631). 54 | 55 | ## License 56 | 57 | The code is made available for academic, non-commercial usage. 
58 | 59 | For any inquiry, contact: Zhixuan Liang (liangzx@connect.hku.hk) 60 | -------------------------------------------------------------------------------- /skilldiffuser/eval_lorel.sh: -------------------------------------------------------------------------------- 1 | BASE_PATH="/home/zxliang/new-code/LISA/lisa/outputs/2023-03-09/20-56-04/checkpoints/LorlEnv-v0-40108-traj_option-2023-03-09-20:56:04" 2 | ITER=500 3 | CKPT="${BASE_PATH}/model_${ITER}.ckpt" 4 | 5 | #method=option_dt 6 | python hrl/main.py env=lorel_sawyer_obs method=traj_option dt.n_layer=1 dt.n_head=4 option_selector.option_transformer.n_layer=1 option_selector.option_transformer.n_head=4 option_selector.commitment_weight=0.1 option_selector.option_transformer.hidden_size=128 batch_size=256 seed=1 warmup_steps=5000 eval=True render=True checkpoint_path=${CKPT} 7 | -------------------------------------------------------------------------------- /skilldiffuser/h5py2pkl.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import h5py 4 | import numpy as np 5 | import pickle 6 | import pandas as pd 7 | import os 8 | import argparse 9 | 10 | # parse args 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--root_path', default=None, type=str, help='Absolute path to the root directory of the dataset') 13 | parser.add_argument('--output_name', default="prep_data.pkl", type=str, help='Output file name') 14 | args = parser.parse_args() 15 | 16 | root_path = args.root_path 17 | f = h5py.File(os.path.join(root_path, 'data.hdf5'),'r') 18 | # f.visit(lambda x: print(x)) 19 | 20 | data = dict() 21 | 22 | df = pd.read_table(os.path.join(root_path, "labels.csv"), sep=",") 23 | langs = df["Text Description"].str.strip().to_numpy().reshape(-1) 24 | langs = np.array(['' if x is np.isnan else x for x in langs]) 25 | filtr1 = np.array([int(("nothing" in l) or ("nan" in l) or ("wave" in l)) for l in langs]) 26 | filtr = filtr1 == 0 27 | data['langs'] = langs[filtr] 28 | 29 | for group in f.keys(): 30 | for key in f[group].keys(): 31 | print(group, key) 32 | data[key] = f[group][key][:][filtr] 33 | 34 | with open(os.path.join(root_path, args.output_name), "wb") as fo: 35 | pickle.dump(data, fo, protocol=4) 36 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/agent/big_dt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | agent: 4 | hidden_size: 288 5 | n_layer: 6 6 | n_head: 6 7 | activation_function: 'relu' 8 | n_positions: 1024 9 | dropout: 0.1 10 | 11 | 12 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | # Set default options 3 | - _self_ 4 | - model: traj_option 5 | - env: lorel_sawyer_obs 6 | 7 | cuda_deterministic: False 8 | 9 | wandb: True 10 | 11 | seed: 0 12 | resume: False 13 | load_options: False 14 | load_frozen: False 15 | freeze_loaded_options: False 16 | checkpoint_path: 17 | 18 | eval: False 19 | 20 | render: False # False 21 | render_path: ./eval_${env.name}/ 22 | 23 | batch_size: 64 # 512 24 | max_iters: 500 # TODO 25 | warmup_steps: 2500 # 5000 26 | 27 | lr_decay: 0.1 28 | decay_steps: 100000 29 | 30 | # options config 31 | option_dim: 128 32 | codebook_dim: 16 33 | 34 | parallel: False # True 35 | savedir: 'checkpoints' 36 | savepath: ## to be filled in code 37 | 
38 | method: ## to be filled in code 39 | use_iq: False ## use IQ-Learn objective instead of BC 40 | 41 | learning_rate: 1e-4 42 | lm_learning_rate: 1e-6 43 | weight_decay: 1e-4 44 | os_learning_rate: 1e-4 45 | 46 | trainer: 47 | device: ## to be filled in code 48 | state_il: False 49 | num_eval_episodes: 100 50 | eval_every: 1 51 | K: ${model.K} 52 | 53 | model: 54 | # Model specific configuration 55 | 56 | env: 57 | # Env specific configuration 58 | skip_words: ['go', 'to', 'the', 'a', '[SEP]'] 59 | 60 | option_selector: 61 | # Option configuration 62 | option_transformer: 63 | 64 | iq: 65 | alpha: 0.1 66 | div: chi 67 | loss: value 68 | gamma: 0.99 69 | # Don't use target updates 70 | use_target: False 71 | 72 | # Extra args 73 | log_interval: 1 # Log every this many iterations 74 | save_interval: 50 # Save networks every this many iterations 75 | hydra_base_dir: "" 76 | exp_name: '' 77 | project_name: ${env.name} 78 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/env/hopper.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env: 4 | name: Hopper-v2 5 | state_dim: 11 6 | action_dim: 3 7 | discrete: False 8 | eval_offline: False 9 | eval_episode_factor: 2 10 | eval_env: 11 | 12 | train_dataset: 13 | expert_location: '/sailhome/divgarg/implicit-irl/experts/Hopper-v2_25.pkl' 14 | num_trajectories: 10 15 | normalize_states: False 16 | no_lang: False 17 | seed: ${seed} 18 | 19 | val_dataset: 20 | expert_location: 21 | num_trajectories: ${trainer.num_eval_episodes} 22 | normalize_states: False 23 | seed: ${seed} 24 | 25 | codebook_dim: 8 26 | 27 | trainer: 28 | device: ## to be filled in code 29 | state_il: False 30 | num_eval_episodes: 5 31 | eval_every: 5 32 | K: ${model.K} 33 | 34 | model: 35 | train_lm: False 36 | 37 | option_selector: 38 | commitment_weight: 20 39 | option_transformer: 40 | n_layer: 2 41 | n_head: 4 42 | 43 | dt: 44 | hidden_size: 128 45 | n_layer: 2 46 | n_head: 4 47 | activation_function: 'relu' 48 | n_positions: 1024 49 | dropout: 0.1 50 | no_states: false 51 | no_actions: false 52 | 53 | learning_rate: 1e-4 54 | lm_learning_rate: 1e-6 55 | weight_decay: 1e-4 56 | os_learning_rate: 1e-4 -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/env/lorel_franka.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env: 4 | name: 5 | state_dim: (12, 64, 64) 6 | action_dim: 5 7 | discrete: False 8 | eval_offline: True 9 | use_state: False 10 | 11 | train_dataset: 12 | expert_location: '/atlas/u/divgarg/datasets/lorel/may_06_franka_3k/prep_data3.pkl' 13 | num_trajectories: 10000 14 | normalize_states: True 15 | no_lang: False 16 | seed: ${seed} 17 | aug: True 18 | 19 | val_dataset: 20 | expert_location: 21 | num_trajectories: ${trainer.num_eval_episodes} 22 | normalize_states: True 23 | seed: ${seed} 24 | 25 | trainer: 26 | device: ## to be filled in code 27 | state_il: False 28 | num_eval_episodes: 5 29 | eval_every: 5 30 | K: ${model.K} -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/env/lorel_sawyer_obs.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | learning_rate: 1e-5 4 | lm_learning_rate: 1e-7 5 | weight_decay: 1e-4 6 | os_learning_rate: 1e-6 # 1e-5 7 | 8 | env: 9 | name: LorlEnv-v0 10 | state_dim: (3, 64, 
64) 11 | action_dim: 5 12 | discrete: False 13 | eval_offline: False 14 | use_state: False 15 | eval_episode_factor: 10 # 10 16 | eval_env: 17 | 18 | train_dataset: 19 | expert_location: '/path/to/prep_data.pkl' 20 | num_trajectories: 40108 21 | normalize_states: False 22 | no_lang: False 23 | seed: ${seed} 24 | aug: True 25 | 26 | val_dataset: 27 | expert_location: 28 | num_trajectories: ${trainer.num_eval_episodes} 29 | normalize_states: False 30 | seed: ${seed} 31 | aug: True 32 | 33 | 34 | trainer: 35 | device: ## to be filled in code 36 | state_il: False 37 | num_eval_episodes: 5 # 5 38 | eval_every: 20 39 | K: ${model.K} -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/env/lorel_sawyer_state.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | env: 4 | name: LorlEnv-v0 5 | state_dim: 15 6 | action_dim: 5 7 | discrete: False 8 | eval_offline: False 9 | use_state: True 10 | eval_episode_factor: 10 11 | 12 | train_dataset: 13 | expert_location: '/home/zxliang/dataset/lorel/may_08_sawyer_50k/prep_data3.pkl' 14 | num_trajectories: 40108 15 | normalize_states: False 16 | seed: ${seed} 17 | 18 | val_dataset: 19 | expert_location: 20 | num_trajectories: ${trainer.num_eval_episodes} 21 | normalize_states: False 22 | no_lang: False 23 | seed: ${seed} 24 | 25 | 26 | trainer: 27 | device: ## to be filled in code 28 | state_il: False 29 | num_eval_episodes: 5 30 | eval_every: 5 31 | K: ${model.K} -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/model/option.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ema_decay: 0.995 4 | gradient_accumulate_every: 2 5 | 6 | diffuser: 7 | # model: 'models.TemporalUnet' 8 | # diffusion: 'models.GaussianDiffusion' 9 | # horizon: ${model.horizon} 10 | ## dim_mults: '(1, 4, 8)' 11 | # n_diffusion_steps: 128 # 512 12 | # loss_type: 'l2' 13 | # clip_denoised: True 14 | # predict_epsilon: False 15 | # ## loss weighting 16 | # action_weight: 1 17 | ## loss_weights: None 18 | # loss_discount: 1 19 | # savepath: '/home/zxliang/new-code/LISA/lisa/outputs/diffuser_debug' 20 | 21 | model: 'models.TemporalUnet' 22 | diffusion: 'models.GaussianInvDynDiffusion' 23 | horizon: ${model.horizon} # 100 24 | dim_mults: '(1, 4, 8)' 25 | n_diffusion_steps: 200 # 512 26 | # loss_type: 'l2' 27 | # clip_denoised: True 28 | predict_epsilon: True 29 | ## loss weighting 30 | action_weight: 10 31 | loss_weights: None 32 | loss_discount: 1 33 | 34 | returns_condition: True 35 | calc_energy: False 36 | condition_dropout: 0.25 37 | condition_guidance_w: 1.2 38 | test_ret: 0.9 39 | # renderer: 'utils.MuJoCoRenderer' 40 | dim: 128 41 | savepath: 'outputs/diffuser_debug' 42 | 43 | ## dataset 44 | loader: 'datasets.SequenceDataset' 45 | normalizer: 'CDFNormalizer' 46 | preprocess_fns: [] 47 | clip_denoised: True 48 | use_padding: True 49 | include_returns: True 50 | discount: 0.99 51 | max_path_length: 1000 52 | inv_hidden_dim: 256 53 | ar_inv: False 54 | train_only_inv: False 55 | termination_penalty: -100 56 | returns_scale: 400.0 # Determined using rewards from the dataset 57 | 58 | ## training 59 | n_steps_per_epoch: 10000 60 | loss_type: 'l2' 61 | n_train_steps: 8e5 # 8e5 # 1e6 62 | batch_size: 64 # equal to args.batch_size * fold num 63 | learning_rate: 5e-3 # 2e-4 64 | gradient_accumulate_every: 1 # 2 65 | ema_decay: 0.995 66 | log_freq: 1000 67 
| save_freq: 10000 68 | sample_freq: 10000 69 | n_saves: 5 70 | save_parallel: False 71 | n_reference: 8 72 | save_checkpoints: True 73 | loadpath: ## to be filled in code 74 | 75 | ## misc 76 | bucket: '' 77 | seed: 100 78 | 79 | model: 80 | name: option 81 | 82 | horizon: 8 83 | K: 8 84 | train_lm: False 85 | use_iq: ${use_iq} 86 | method: ${model.name} 87 | state_reconstruct: False 88 | lang_reconstruct: False 89 | 90 | state_reconstructor: 91 | num_hidden: 2 92 | hidden_size: 128 93 | 94 | lang_reconstructor: 95 | num_hidden: 2 96 | hidden_size: 128 97 | max_options: ## to be filled in code 98 | 99 | option_selector: 100 | horizon: ${model.horizon} 101 | use_vq: True 102 | kmeans_init: True 103 | commitment_weight: 0.25 104 | num_options: 20 105 | num_hidden: 2 106 | hidden_size: 128 107 | 108 | dt: 109 | hidden_size: 128 110 | n_layer: 4 111 | n_head: 4 112 | option_il: False 113 | activation_function: 'relu' 114 | n_positions: 1024 115 | dropout: 0.1 116 | no_actions: False 117 | no_states: False 118 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/model/traj_option.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | ema_decay: 0.995 4 | gradient_accumulate_every: 2 5 | 6 | diffuser: 7 | # model: 'models.TemporalUnet' 8 | # diffusion: 'models.GaussianDiffusion' 9 | # horizon: ${model.horizon} 10 | ## dim_mults: '(1, 4, 8)' 11 | # n_diffusion_steps: 128 # 512 12 | # loss_type: 'l2' 13 | # clip_denoised: True 14 | # predict_epsilon: False 15 | # ## loss weighting 16 | # action_weight: 1 17 | ## loss_weights: None 18 | # loss_discount: 1 19 | # savepath: '/home/zxliang/new-code/LISA/lisa/outputs/diffuser_debug' 20 | 21 | model: 'models.TemporalUnet' 22 | diffusion: 'models.GaussianInvDynDiffusion' 23 | horizon: ${model.horizon} # 100 24 | dim_mults: '(1, 4, 8)' 25 | n_diffusion_steps: 200 # 512 26 | # loss_type: 'l2' 27 | # clip_denoised: True 28 | predict_epsilon: True 29 | ## loss weighting 30 | action_weight: 10 31 | loss_weights: None 32 | loss_discount: 1 33 | 34 | returns_condition: True 35 | calc_energy: False 36 | condition_dropout: 0.25 37 | condition_guidance_w: 1.2 38 | test_ret: 0.9 39 | # renderer: 'utils.MuJoCoRenderer' 40 | dim: 128 41 | savepath: 'outputs/diffuser_debug' 42 | 43 | ## dataset 44 | loader: 'datasets.SequenceDataset' 45 | normalizer: 'CDFNormalizer' 46 | preprocess_fns: [] 47 | clip_denoised: True 48 | use_padding: True 49 | include_returns: True 50 | discount: 0.99 51 | max_path_length: 1000 52 | inv_hidden_dim: 256 53 | ar_inv: False 54 | train_only_inv: False 55 | termination_penalty: -100 56 | returns_scale: 400.0 # Determined using rewards from the dataset 57 | 58 | ## training 59 | n_steps_per_epoch: 10000 60 | loss_type: 'l2' 61 | n_train_steps: 5e5 # 8e5 # 1e6 62 | batch_size: 64 # 32 equal to args.batch_size 63 | learning_rate: 1e-3 # 2e-4 64 | gradient_accumulate_every: 1 # 2 65 | ema_decay: 0.995 66 | log_freq: 1000 67 | save_freq: 10000 68 | sample_freq: 10000 69 | n_saves: 5 70 | save_parallel: False 71 | n_reference: 8 72 | save_checkpoints: True 73 | loadpath: ## to be filled in code 74 | 75 | ## misc 76 | bucket: '' 77 | seed: 100 78 | 79 | model: 80 | name: traj_option 81 | 82 | horizon: 8 # 10 83 | K: 8 # 10 84 | train_lm: False 85 | use_iq: ${use_iq} 86 | method: ${model.name} 87 | state_reconstruct: False 88 | lang_reconstruct: False 89 | 90 | state_reconstructor: 91 | num_hidden: 2 92 | hidden_size: 128 93 | 94 | 
lang_reconstructor: 95 | num_hidden: 2 96 | hidden_size: 128 97 | max_options: ## to be filled in code 98 | 99 | option_selector: 100 | horizon: ${model.horizon} 101 | use_vq: True 102 | kmeans_init: True 103 | commitment_weight: 0.25 104 | num_options: 20 105 | num_hidden: 2 106 | option_transformer: 107 | hidden_size: 128 108 | n_layer: 1 109 | n_head: 4 110 | max_length: 111 | max_ep_len: 112 | activation_function: 'relu' 113 | n_positions: 1024 114 | dropout: 0.1 115 | output_attention: False 116 | 117 | dt: 118 | hidden_size: 128 119 | n_layer: 1 120 | n_head: 4 121 | option_il: False 122 | activation_function: 'relu' 123 | n_positions: 1024 124 | dropout: 0.1 125 | no_actions: False 126 | no_states: False 127 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/conf/model/vanilla.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | model: 4 | name: vanilla 5 | 6 | horizon: 5 7 | K: 5 8 | train_lm: True 9 | method: ${model.name} 10 | state_reconstruct: False 11 | lang_reconstruct: False 12 | 13 | state_reconstructor: 14 | num_hidden: 2 15 | hidden_size: 128 16 | 17 | lang_reconstructor: 18 | num_hidden: 2 19 | hidden_size: 128 20 | max_options: ## to be filled in code 21 | 22 | dt: 23 | option_il: False 24 | hidden_size: 128 25 | n_layer: 2 26 | n_head: 4 27 | activation_function: 'relu' 28 | n_positions: 1024 29 | dropout: 0.1 30 | no_states: False 31 | no_actions: False -------------------------------------------------------------------------------- /skilldiffuser/hrl/decision_transformer.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import transformers 8 | 9 | from trajectory_model import TrajectoryModel 10 | from trajectory_gpt2 import GPT2Model 11 | from img_encoder import Encoder 12 | 13 | import iq 14 | from utils import pad 15 | 16 | 17 | class DecisionTransformer(TrajectoryModel): 18 | 19 | """ 20 | This model uses GPT to model (Lang, state_1, action_1, state_2, ...) or (state_1, option_1, action_1, ...) 
21 | """ 22 | 23 | def __init__( 24 | self, 25 | state_dim, 26 | action_dim, 27 | option_dim, 28 | lang_dim, 29 | discrete, 30 | hidden_size, 31 | use_language=False, 32 | use_options=True, 33 | option_il=False, 34 | predict_q=False, 35 | max_length=None, 36 | max_ep_len=4096, 37 | action_tanh=False, 38 | no_states=False, 39 | no_actions=False, 40 | ** kwargs): 41 | # max_length used to be K 42 | super().__init__(state_dim, action_dim, max_length=max_length) 43 | 44 | self.use_options = use_options 45 | self.use_language = use_language 46 | self.option_il = option_il 47 | self.predict_q = predict_q 48 | 49 | if use_language and use_options: 50 | raise ValueError("Cannot use language and options!") 51 | if not use_language and not use_options: 52 | raise ValueError("Have to use language or options!") 53 | self.option_dim = option_dim 54 | self.discrete = discrete 55 | 56 | self.hidden_size = hidden_size 57 | config = transformers.GPT2Config( 58 | vocab_size=1, # doesn't matter -- we don't use the vocab 59 | n_embd=hidden_size, 60 | **kwargs 61 | ) 62 | 63 | if isinstance(state_dim, tuple): 64 | # LORL 65 | if state_dim[0] == 3: 66 | # LORL Sawyer 67 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False) 68 | else: 69 | # LORL Franka 70 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True) 71 | else: 72 | self.embed_state = nn.Linear(self.state_dim, hidden_size) 73 | 74 | # note: the only difference between this GPT2Model and the default Huggingface version 75 | # is that the positional embeddings are removed (since we'll add those ourselves) 76 | self.transformer = GPT2Model(config) 77 | 78 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) 79 | 80 | self.embed_action = nn.Linear(self.act_dim, hidden_size) 81 | 82 | self.no_states = no_states 83 | self.no_actions = no_actions 84 | 85 | if use_options: 86 | self.embed_option = nn.Linear(self.option_dim, hidden_size) 87 | 88 | if use_language: 89 | self.embed_lang = nn.Linear(lang_dim, hidden_size) 90 | 91 | self.embed_ln = nn.LayerNorm(hidden_size) 92 | # note: we don't predict states or returns for the paper 93 | if isinstance(self.state_dim, int): 94 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim) 95 | self.predict_action = nn.Sequential( 96 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh and not discrete else [])) 97 | ) 98 | if use_options: 99 | self.predict_option = torch.nn.Linear(hidden_size, self.option_dim) 100 | if predict_q: 101 | self.predict_q = torch.nn.Linear(hidden_size, self.act_dim) 102 | 103 | def forward(self, states, actions, timesteps, options=None, word_embeddings=None, attention_mask=None): 104 | 105 | batch_size, seq_length = states.shape[0], states.shape[1] 106 | 107 | if attention_mask is None: 108 | raise ValueError('Should not have attention_mask NONE') 109 | # attention mask for GPT: 1 if can be attended to, 0 if not 110 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 111 | 112 | if self.use_options: 113 | assert options is not None 114 | option_embeddings = self.embed_option(options) 115 | time_embeddings = self.embed_timestep(timesteps) 116 | 117 | # time embeddings are treated similar to positional embeddings 118 | option_embeddings = option_embeddings + time_embeddings 119 | 120 | if self.no_states: 121 | # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT 122 | state_embeddings = self.embed_state(torch.zeros_like(states)) 123 | else: 124 | state_embeddings = self.embed_state(states) 125 | 
state_embeddings = state_embeddings + time_embeddings 126 | 127 | if self.no_actions: 128 | # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT 129 | action_embeddings = self.embed_action(torch.zeros_like(actions)) 130 | else: 131 | action_embeddings = self.embed_action(actions) 132 | action_embeddings = action_embeddings + time_embeddings 133 | 134 | # this makes the sequence look like (o1, s1, a1,o2, s2, a2, ...) 135 | # which works nice in an autoregressive sense since states predict actions 136 | # note that o1 and o2 need not be different 137 | stacked_inputs = torch.stack( 138 | (option_embeddings, state_embeddings, action_embeddings), 139 | dim=1).permute( 140 | 0, 2, 1, 3).reshape( 141 | batch_size, 3 * seq_length, self.hidden_size) 142 | # LAYERNORM 143 | stacked_inputs = self.embed_ln(stacked_inputs) 144 | 145 | # to make the attention mask fit the stacked inputs, have to stack it as well 146 | stacked_attention_mask = torch.stack( 147 | (attention_mask, attention_mask, attention_mask), dim=1 148 | ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length) 149 | 150 | # we feed in the input embeddings (not word indices as in NLP) to the model 151 | transformer_outputs = self.transformer( 152 | inputs_embeds=stacked_inputs, 153 | attention_mask=stacked_attention_mask, 154 | ) 155 | x = transformer_outputs['last_hidden_state'] 156 | 157 | # reshape x so that the second dimension corresponds to the original 158 | # options (0), states (1) or actions (2); i.e. x[:,0,t] is the token for s_t 159 | traj_out = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3) 160 | # get predictions 161 | # predict next state given option, state and action. skip the last state for prediction 162 | if isinstance(self.state_dim, int): 163 | state_preds = self.predict_state(traj_out[:, 2])[:, :-1, :] 164 | else: 165 | state_preds = None 166 | # predict next action given state and option 167 | action_preds = self.predict_action(traj_out[:, 1]) 168 | 169 | # reconstruct current option given current option 170 | if self.option_il: 171 | option_preds = self.predict_option(traj_out[:, 0]) 172 | options_loss = F.mse_loss(option_preds, options.detach()) 173 | else: 174 | options_loss = None 175 | 176 | outputs = {'state_preds': state_preds, 177 | 'action_preds': action_preds, 178 | 'options_loss': options_loss} 179 | 180 | if self.predict_q: 181 | # predict next Q given state and option ## IMP: Don't use current action 182 | q_preds = self.predict_q(traj_out[:, 1]) 183 | outputs.update({'q_preds': q_preds}) 184 | 185 | return outputs 186 | 187 | if self.use_language: 188 | assert word_embeddings is not None 189 | num_tokens = word_embeddings.shape[1] 190 | state_embeddings = self.embed_state(states) 191 | lang_embeddings = self.embed_lang(word_embeddings) 192 | action_embeddings = self.embed_action(actions) 193 | time_embeddings = self.embed_timestep(timesteps) 194 | 195 | # time embeddings are treated similar to positional embeddings 196 | state_embeddings = state_embeddings + time_embeddings 197 | action_embeddings = action_embeddings + time_embeddings 198 | 199 | stacked_inputs = torch.stack( 200 | (state_embeddings, action_embeddings), 201 | dim=1).permute( 202 | 0, 2, 1, 3).reshape( 203 | batch_size, 2 * seq_length, self.hidden_size) 204 | lang_and_inputs = torch.cat([lang_embeddings, stacked_inputs], dim=1) 205 | # LAYERNORM AFTER LANGUAGE 206 | stacked_inputs = self.embed_ln(lang_and_inputs) 207 | 208 | # to make the attention mask fit the stacked inputs, have to stack it as well 209 | 
stacked_attention_mask = torch.stack( 210 | (attention_mask, attention_mask), dim=1 211 | ).permute(0, 2, 1).reshape(batch_size, 2*seq_length) 212 | lang_attn_mask = torch.cat( 213 | [torch.ones((batch_size, num_tokens), device=states.device), stacked_attention_mask], dim=1) 214 | 215 | # we feed in the input embeddings (not word indices as in NLP) to the model 216 | transformer_outputs = self.transformer( 217 | inputs_embeds=stacked_inputs, 218 | attention_mask=lang_attn_mask, 219 | ) 220 | x = transformer_outputs['last_hidden_state'] 221 | 222 | # reshape x so that the second dimension corresponds to the original 223 | # states (0), or actions (1); i.e. x[:,0,t] is the token for s_t 224 | lang_out = x[:, :num_tokens, :].reshape( 225 | batch_size, num_tokens, 1, self.hidden_size).permute(0, 2, 1, 3) 226 | traj_out = x[:, num_tokens:, :].reshape( 227 | batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3) 228 | 229 | # get predictions 230 | # predict state given state, action. skip the last prediction 231 | if isinstance(self.state_dim, int): 232 | state_preds = self.predict_state(traj_out[:, 1])[:, :-1, :] 233 | else: 234 | state_preds = None 235 | action_preds = self.predict_action(traj_out[:, 0]) # predict next action given state 236 | 237 | outputs = {'state_preds': state_preds, 238 | 'action_preds': action_preds} 239 | 240 | if self.predict_q: 241 | # predict next Q given state ## IMP: Don't use current action 242 | q_preds = self.predict_q(traj_out[:, 0]) 243 | outputs.update({'q_preds': q_preds}) 244 | 245 | return outputs 246 | 247 | def get_action(self, states, actions, timesteps, options=None, word_embeddings=None, **kwargs): 248 | 249 | if self.use_options: 250 | assert options is not None 251 | if isinstance(self.state_dim, tuple): 252 | states = states.reshape(1, -1, *self.state_dim) 253 | else: 254 | states = states.reshape(1, -1, self.state_dim) 255 | options = options.reshape(1, -1, self.option_dim) 256 | actions = actions.reshape(1, -1, self.act_dim) 257 | timesteps = timesteps.reshape(1, -1) 258 | 259 | if self.max_length is not None: 260 | states = states[:, -self.max_length:] 261 | options = options[:, -self.max_length:] 262 | actions = actions[:, -self.max_length:] 263 | timesteps = timesteps[:, -self.max_length:] 264 | 265 | # pad all tokens to sequence length 266 | attention_mask = pad(torch.ones(1, states.shape[1]), self.max_length).to( 267 | dtype=torch.long, device=states.device).reshape(1, -1) 268 | states = pad(states, self.max_length).to(dtype=torch.float32) 269 | options = pad(options, self.max_length).to(dtype=torch.float32) 270 | actions = pad(actions, self.max_length).to(dtype=torch.float32) 271 | timesteps = pad(timesteps, self.max_length).to(dtype=torch.long) 272 | else: 273 | raise ValueError('Should not have max_length NONE') 274 | attention_mask = None 275 | 276 | preds = self.forward( 277 | states, actions, timesteps, options=options, attention_mask=attention_mask) 278 | 279 | if self.use_language: 280 | assert word_embeddings is not None 281 | if isinstance(self.state_dim, tuple): 282 | states = states.reshape(1, -1, *self.state_dim) 283 | else: 284 | states = states.reshape(1, -1, self.state_dim) 285 | actions = actions.reshape(1, -1, self.act_dim) 286 | timesteps = timesteps.reshape(1, -1) 287 | 288 | if self.max_length is not None: 289 | states = states[:, -self.max_length:] 290 | actions = actions[:, -self.max_length:] 291 | timesteps = timesteps[:, -self.max_length:] 292 | 293 | # pad all tokens to sequence length 294 | attention_mask 
= pad( 295 | torch.ones(1, states.shape[1]), 296 | self.max_length).to( 297 | dtype=torch.long, device=states.device).reshape( 298 | 1, -1) 299 | states = pad(states, self.max_length).to(dtype=torch.float32) 300 | actions = pad(actions, self.max_length).to(dtype=torch.float32) 301 | timesteps = pad(timesteps, self.max_length).to(dtype=torch.long) 302 | else: 303 | attention_mask = None 304 | 305 | preds = self.forward( 306 | states, actions, timesteps, word_embeddings=word_embeddings, attention_mask=attention_mask, **kwargs) 307 | 308 | return preds 309 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/difftrainer.py: -------------------------------------------------------------------------------- 1 | # import pdb 2 | 3 | import numpy as np 4 | import torch 5 | import copy 6 | from diffuser.utils.training import EMA 7 | from ml_logger import logger 8 | import os 9 | import torch.nn as nn 10 | 11 | class DiffTrainer(nn.Module): 12 | def __init__( 13 | self, 14 | args, 15 | diffusion_model, 16 | diff_trainer_args, 17 | ): 18 | # max_length used to be K 19 | super().__init__() 20 | 21 | self.model = diffusion_model 22 | self.ema = EMA(diff_trainer_args['ema_decay']) 23 | self.ema_model = copy.deepcopy(self.model) 24 | self.update_ema_every = diff_trainer_args['update_ema_every'] 25 | self.save_checkpoints = diff_trainer_args['save_checkpoints'] 26 | self.step_start_ema = 2000 27 | 28 | self.log_freq = diff_trainer_args['log_freq'] 29 | self.sample_freq = diff_trainer_args['sample_freq'] 30 | self.save_freq = diff_trainer_args['save_freq'] 31 | self.label_freq = diff_trainer_args['label_freq'] 32 | self.save_parallel = diff_trainer_args['save_parallel'] 33 | 34 | # self.batch_size = diff_trainer_args['train_batch_size'] 35 | self.gradient_accumulate_every = diff_trainer_args['gradient_accumulate_every'] 36 | 37 | self.bucket = os.path.join(args.hydra_base_dir, "buckets") 38 | self.n_reference = diff_trainer_args['n_reference'] 39 | 40 | self.reset_parameters() 41 | self.step = 0 42 | 43 | self.device = diff_trainer_args['train_device'] 44 | 45 | logger.configure(args.hydra_base_dir, 46 | prefix=f"diffuser_log") 47 | # self.model = self.model.to(device=self.device) # TODO 48 | # self.ema_model = self.ema_model.to(device=self.device) 49 | 50 | self.lr = diff_trainer_args['train_lr'] 51 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr) 52 | 53 | def reset_parameters(self): 54 | self.ema_model.load_state_dict(self.model.state_dict()) 55 | 56 | def step_ema(self): 57 | if self.step < self.step_start_ema: 58 | self.reset_parameters() 59 | return 60 | self.ema.update_model_average(self.ema_model, self.model) 61 | 62 | # -----------------------------------------------------------------------------# 63 | # ------------------------------------ api ------------------------------------# 64 | # -----------------------------------------------------------------------------# 65 | 66 | def train_iteration(self, batch): 67 | # timer = Timer() 68 | # for step in range(n_train_steps): 69 | # for i in range(self.gradient_accumulate_every): 70 | # batch = next(self.dataloader) 71 | # batch = batch_to_device(batch, device=self.device) 72 | loss, infos = self.model.loss(*batch) 73 | # loss = loss / self.gradient_accumulate_every 74 | 75 | self.optimizer.zero_grad() 76 | loss.backward(retain_graph=True) 77 | self.optimizer.step() 78 | 79 | if self.step % self.update_ema_every == 0: 80 | self.step_ema() 81 | 82 | if self.step % self.save_freq == 
0: 83 | self.save() 84 | 85 | infos.pop('obs') 86 | 87 | if self.step % self.log_freq == 0: 88 | infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) 89 | logger.print(f'{self.step}: {loss:8.4f} | {infos_str}') 90 | metrics = {k: v.detach().item() for k, v in infos.items()} 91 | metrics['steps'] = self.step 92 | metrics['loss'] = loss.detach().item() 93 | logger.log_metrics_summary(metrics, default_stats='mean') 94 | 95 | # if self.step == 0 and self.sample_freq: 96 | # self.render_reference(self.n_reference) 97 | 98 | # if self.sample_freq and self.step % self.sample_freq == 0: 99 | # if self.model.__class__ == diffuser.models.diffusion.GaussianInvDynDiffusion: 100 | # self.inv_render_samples() 101 | # elif self.model.__class__ == diffuser.models.diffusion.ActionGaussianDiffusion: 102 | # pass 103 | # else: 104 | # self.render_samples() 105 | 106 | self.step += 1 107 | return 108 | 109 | def save(self): 110 | ''' 111 | saves model and ema to disk; 112 | syncs to storage bucket if a bucket is specified 113 | ''' 114 | data = { 115 | 'step': self.step, 116 | 'model': self.model.state_dict(), 117 | 'ema': self.ema_model.state_dict() 118 | } 119 | savepath = os.path.join(self.bucket, logger.prefix, 'checkpoint') 120 | os.makedirs(savepath, exist_ok=True) 121 | # logger.save_torch(data, savepath) 122 | if self.save_checkpoints: 123 | savepath = os.path.join(savepath, f'state_{self.step}.pt') 124 | else: 125 | savepath = os.path.join(savepath, 'state.pt') 126 | torch.save(data, savepath) 127 | logger.print(f'[ utils/training ] Saved model to {savepath}') 128 | 129 | def load(self, loadpath): 130 | ''' 131 | loads model and ema from disk 132 | ''' 133 | # loadpath = os.path.join(self.bucket, logger.prefix, f'checkpoint/state.pt') 134 | # data = logger.load_torch(loadpath) 135 | data = torch.load(loadpath) 136 | 137 | # self.step = data['step'] 138 | self.model.load_state_dict(data['model']) 139 | self.ema_model.load_state_dict(data['ema']) 140 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import environments -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence import * 2 | from .d4rl import load_environment -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def atleast_2d(x): 4 | while x.ndim < 2: 5 | x = np.expand_dims(x, axis=-1) 6 | return x 7 | 8 | class ReplayBuffer: 9 | 10 | def __init__(self, max_n_episodes, max_path_length, termination_penalty): 11 | self._dict = { 12 | 'path_lengths': np.zeros(max_n_episodes, dtype=np.int), 13 | } 14 | self._count = 0 15 | self.max_n_episodes = max_n_episodes 16 | self.max_path_length = max_path_length 17 | self.termination_penalty = termination_penalty 18 | 19 | def __repr__(self): 20 | return '[ datasets/buffer ] Fields:\n' + '\n'.join( 21 | f' {key}: {val.shape}' 22 | for key, val in self.items() 23 | ) 24 | 25 | def __getitem__(self, key): 26 | return self._dict[key] 27 | 28 | def __setitem__(self, key, val): 29 | self._dict[key] = val 30 | self._add_attributes() 31 | 32 | @property 33 | def n_episodes(self): 34 | return self._count 35 | 36 | @property 37 | def n_steps(self): 38 | return sum(self['path_lengths']) 39 | 40 | def _add_keys(self, path): 41 | if hasattr(self, 'keys'): 42 | return 43 | self.keys = list(path.keys()) 44 | 45 | def _add_attributes(self): 46 | ''' 47 | can access fields with `buffer.observations` 48 | instead of `buffer['observations']` 49 | ''' 50 | for key, val in self._dict.items(): 51 | setattr(self, key, val) 52 | 53 | def items(self): 54 | return {k: v for k, v in self._dict.items() 55 | if k != 'path_lengths'}.items() 56 | 57 | def _allocate(self, key, array): 58 | assert key not in self._dict 59 | dim = array.shape[-1] 60 | shape = (self.max_n_episodes, self.max_path_length, dim) 61 | self._dict[key] = np.zeros(shape, dtype=np.float32) 62 | # print(f'[ utils/mujoco ] Allocated {key} with size {shape}') 63 | 64 | def add_path(self, path): 65 | path_length = len(path['observations']) 66 | assert path_length <= self.max_path_length 67 | 68 | if path['terminals'].any(): 69 | assert (path['terminals'][-1] == True) and (not path['terminals'][:-1].any()) 70 | 71 | ## if first path added, set keys based on contents 72 | self._add_keys(path) 73 | 74 | ## add tracked keys in path 75 | for key in self.keys: 76 | array = atleast_2d(path[key]) 77 | if key not in self._dict: self._allocate(key, array) 78 | self._dict[key][self._count, :path_length] = array 79 | 80 | ## penalize early termination 81 | if path['terminals'].any() and self.termination_penalty is not None: 82 | assert not path['timeouts'].any(), 'Penalized a timeout episode for early termination' 83 | self._dict['rewards'][self._count, path_length - 1] += self.termination_penalty 84 | 85 | ## record path length 86 | self._dict['path_lengths'][self._count] = path_length 87 | 88 | ## increment path counter 89 | self._count += 1 90 | 91 | def truncate_path(self, path_ind, step): 92 | old = self._dict['path_lengths'][path_ind] 93 | new = min(step, old) 94 | self._dict['path_lengths'][path_ind] = new 95 | 96 | def finalize(self): 97 | ## remove extra slots 98 | for key in self.keys + ['path_lengths']: 99 | self._dict[key] = self._dict[key][:self._count] 100 | self._add_attributes() 101 | 
print(f'[ datasets/buffer ] Finalized replay buffer | {self._count} episodes') 102 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/d4rl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import numpy as np 4 | import gym 5 | import pdb 6 | 7 | from contextlib import ( 8 | contextmanager, 9 | redirect_stderr, 10 | redirect_stdout, 11 | ) 12 | 13 | @contextmanager 14 | def suppress_output(): 15 | """ 16 | A context manager that redirects stdout and stderr to devnull 17 | https://stackoverflow.com/a/52442331 18 | """ 19 | with open(os.devnull, 'w') as fnull: 20 | with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out: 21 | yield (err, out) 22 | 23 | with suppress_output(): 24 | ## d4rl prints out a variety of warnings 25 | import d4rl 26 | 27 | #-----------------------------------------------------------------------------# 28 | #-------------------------------- general api --------------------------------# 29 | #-----------------------------------------------------------------------------# 30 | 31 | def load_environment(name): 32 | if type(name) != str: 33 | ## name is already an environment 34 | return name 35 | with suppress_output(): 36 | wrapped_env = gym.make(name) 37 | env = wrapped_env.unwrapped 38 | env.max_episode_steps = wrapped_env._max_episode_steps 39 | env.name = name 40 | return env 41 | 42 | def get_dataset(env): 43 | dataset = env.get_dataset() 44 | 45 | if 'antmaze' in str(env).lower(): 46 | ## the antmaze-v0 environments have a variety of bugs 47 | ## involving trajectory segmentation, so manually reset 48 | ## the terminal and timeout fields 49 | dataset = antmaze_fix_timeouts(dataset) 50 | dataset = antmaze_scale_rewards(dataset) 51 | get_max_delta(dataset) 52 | 53 | return dataset 54 | 55 | def sequence_dataset(env, preprocess_fn): 56 | """ 57 | Returns an iterator through trajectories. 58 | Args: 59 | env: An OfflineEnv object. 60 | dataset: An optional dataset to pass in for processing. If None, 61 | the dataset will default to env.get_dataset() 62 | **kwargs: Arguments to pass to env.get_dataset(). 63 | Returns: 64 | An iterator through dictionaries with keys: 65 | observations 66 | actions 67 | rewards 68 | terminals 69 | """ 70 | dataset = get_dataset(env) 71 | dataset = preprocess_fn(dataset) 72 | 73 | N = dataset['rewards'].shape[0] 74 | data_ = collections.defaultdict(list) 75 | 76 | # The newer version of the dataset adds an explicit 77 | # timeouts field. Keep old method for backwards compatability. 
78 | use_timeouts = 'timeouts' in dataset 79 | 80 | episode_step = 0 81 | for i in range(N): 82 | done_bool = bool(dataset['terminals'][i]) 83 | if use_timeouts: 84 | final_timestep = dataset['timeouts'][i] 85 | else: 86 | final_timestep = (episode_step == env._max_episode_steps - 1) 87 | 88 | for k in dataset: 89 | if 'metadata' in k or 'infos' in k: continue 90 | data_[k].append(dataset[k][i]) 91 | if done_bool or final_timestep: 92 | episode_step = 0 93 | episode_data = {} 94 | for k in data_: 95 | episode_data[k] = np.array(data_[k]) 96 | if 'maze2d' in env.name: 97 | episode_data = process_maze2d_episode(episode_data) 98 | yield episode_data 99 | data_ = collections.defaultdict(list) 100 | 101 | episode_step += 1 102 | 103 | 104 | #-----------------------------------------------------------------------------# 105 | #-------------------------------- maze2d fixes -------------------------------# 106 | #-----------------------------------------------------------------------------# 107 | 108 | def process_maze2d_episode(episode): 109 | ''' 110 | adds in `next_observations` field to episode 111 | ''' 112 | assert 'next_observations' not in episode 113 | length = len(episode['observations']) 114 | next_observations = episode['observations'][1:].copy() 115 | for key, val in episode.items(): 116 | episode[key] = val[:-1] 117 | episode['next_observations'] = next_observations 118 | return episode 119 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.interpolate as interpolate 3 | import pdb 4 | 5 | POINTMASS_KEYS = ['observations', 'actions', 'next_observations', 'deltas'] 6 | 7 | #-----------------------------------------------------------------------------# 8 | #--------------------------- multi-field normalizer --------------------------# 9 | #-----------------------------------------------------------------------------# 10 | 11 | class DatasetNormalizer: 12 | 13 | def __init__(self, dataset, normalizer, path_lengths=None): 14 | dataset = flatten(dataset, path_lengths) 15 | 16 | self.observation_dim = dataset['observations'].shape[1] 17 | self.action_dim = dataset['actions'].shape[1] 18 | 19 | if type(normalizer) == str: 20 | normalizer = eval(normalizer) 21 | 22 | self.normalizers = {} 23 | for key, val in dataset.items(): 24 | try: 25 | self.normalizers[key] = normalizer(val) 26 | except: 27 | print(f'[ utils/normalization ] Skipping {key} | {normalizer}') 28 | # key: normalizer(val) 29 | # for key, val in dataset.items() 30 | 31 | def __repr__(self): 32 | string = '' 33 | for key, normalizer in self.normalizers.items(): 34 | string += f'{key}: {normalizer}]\n' 35 | return string 36 | 37 | def __call__(self, *args, **kwargs): 38 | return self.normalize(*args, **kwargs) 39 | 40 | def normalize(self, x, key): 41 | return self.normalizers[key].normalize(x) 42 | 43 | def unnormalize(self, x, key): 44 | return self.normalizers[key].unnormalize(x) 45 | 46 | def flatten(dataset, path_lengths): 47 | ''' 48 | flattens dataset of { key: [ n_episodes x max_path_lenth x dim ] } 49 | to { key : [ (n_episodes * sum(path_lengths)) x dim ]} 50 | ''' 51 | flattened = {} 52 | for key, xs in dataset.items(): 53 | assert len(xs) == len(path_lengths) 54 | flattened[key] = np.concatenate([ 55 | x[:length] 56 | for x, length in zip(xs, path_lengths) 57 | ], axis=0) 58 | return flattened 59 | 60 | 
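# --- Illustrative usage sketch (editorial addition, not part of the original module) ---
# A minimal example of how the DatasetNormalizer above is constructed and queried,
# assuming a toy dataset; it is wrapped in a function so nothing executes at import
# time (the concrete normalizer classes are defined further down in this file).
def _example_normalizer_usage():
    toy_dataset = {
        'observations': np.random.randn(2, 10, 3).astype(np.float32),  # n_episodes x max_path_length x dim
        'actions': np.random.randn(2, 10, 2).astype(np.float32),
    }
    path_lengths = [10, 7]                            # valid (unpadded) length of each episode
    normalizer = DatasetNormalizer(toy_dataset, 'LimitsNormalizer', path_lengths=path_lengths)
    obs = toy_dataset['observations'][0, :path_lengths[0]]
    normed = normalizer(obs, 'observations')          # __call__ forwards to normalize(); output lies in [-1, 1]
    recon = normalizer.unnormalize(normed, 'observations')
    return np.allclose(obs, recon, atol=1e-4)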
#-----------------------------------------------------------------------------# 61 | #------------------------------- @TODO: remove? ------------------------------# 62 | #-----------------------------------------------------------------------------# 63 | 64 | class PointMassDatasetNormalizer(DatasetNormalizer): 65 | 66 | def __init__(self, preprocess_fns, dataset, normalizer, keys=POINTMASS_KEYS): 67 | 68 | reshaped = {} 69 | for key, val in dataset.items(): 70 | dim = val.shape[-1] 71 | reshaped[key] = val.reshape(-1, dim) 72 | 73 | self.observation_dim = reshaped['observations'].shape[1] 74 | self.action_dim = reshaped['actions'].shape[1] 75 | 76 | if type(normalizer) == str: 77 | normalizer = eval(normalizer) 78 | 79 | self.normalizers = { 80 | key: normalizer(reshaped[key]) 81 | for key in keys 82 | } 83 | 84 | #-----------------------------------------------------------------------------# 85 | #-------------------------- single-field normalizers -------------------------# 86 | #-----------------------------------------------------------------------------# 87 | 88 | class Normalizer: 89 | ''' 90 | parent class, subclass by defining the `normalize` and `unnormalize` methods 91 | ''' 92 | 93 | def __init__(self, X): 94 | self.X = X.astype(np.float32) 95 | self.mins = X.min(axis=0) 96 | self.maxs = X.max(axis=0) 97 | 98 | def __repr__(self): 99 | return ( 100 | f'''[ Normalizer ] dim: {self.mins.size}\n -: ''' 101 | f'''{np.round(self.mins, 2)}\n +: {np.round(self.maxs, 2)}\n''' 102 | ) 103 | 104 | def __call__(self, x): 105 | return self.normalize(x) 106 | 107 | def normalize(self, *args, **kwargs): 108 | raise NotImplementedError() 109 | 110 | def unnormalize(self, *args, **kwargs): 111 | raise NotImplementedError() 112 | 113 | 114 | class DebugNormalizer(Normalizer): 115 | ''' 116 | identity function 117 | ''' 118 | 119 | def normalize(self, x, *args, **kwargs): 120 | return x 121 | 122 | def unnormalize(self, x, *args, **kwargs): 123 | return x 124 | 125 | 126 | class GaussianNormalizer(Normalizer): 127 | ''' 128 | normalizes to zero mean and unit variance 129 | ''' 130 | 131 | def __init__(self, *args, **kwargs): 132 | super().__init__(*args, **kwargs) 133 | self.means = self.X.mean(axis=0) 134 | self.stds = self.X.std(axis=0) 135 | self.z = 1 136 | 137 | def __repr__(self): 138 | return ( 139 | f'''[ Normalizer ] dim: {self.mins.size}\n ''' 140 | f'''means: {np.round(self.means, 2)}\n ''' 141 | f'''stds: {np.round(self.z * self.stds, 2)}\n''' 142 | ) 143 | 144 | def normalize(self, x): 145 | return (x - self.means) / self.stds 146 | 147 | def unnormalize(self, x): 148 | return x * self.stds + self.means 149 | 150 | 151 | class LimitsNormalizer(Normalizer): 152 | ''' 153 | maps [ xmin, xmax ] to [ -1, 1 ] 154 | ''' 155 | 156 | def normalize(self, x): 157 | ## [ 0, 1 ] 158 | x = (x - self.mins) / (self.maxs - self.mins) 159 | ## [ -1, 1 ] 160 | x = 2 * x - 1 161 | return x 162 | 163 | def unnormalize(self, x, eps=1e-4): 164 | ''' 165 | x : [ -1, 1 ] 166 | ''' 167 | if x.max() > 1 + eps or x.min() < -1 - eps: 168 | # print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})') 169 | x = np.clip(x, -1, 1) 170 | 171 | ## [ -1, 1 ] --> [ 0, 1 ] 172 | x = (x + 1) / 2. 
173 | 174 | return x * (self.maxs - self.mins) + self.mins 175 | 176 | class SafeLimitsNormalizer(LimitsNormalizer): 177 | ''' 178 | functions like LimitsNormalizer, but can handle data for which a dimension is constant 179 | ''' 180 | 181 | def __init__(self, *args, eps=1, **kwargs): 182 | super().__init__(*args, **kwargs) 183 | for i in range(len(self.mins)): 184 | if self.mins[i] == self.maxs[i]: 185 | print(f''' 186 | [ utils/normalization ] Constant data in dimension {i} | ''' 187 | f'''max = min = {self.maxs[i]}''' 188 | ) 189 | self.mins -= eps 190 | self.maxs += eps 191 | 192 | #-----------------------------------------------------------------------------# 193 | #------------------------------- CDF normalizer ------------------------------# 194 | #-----------------------------------------------------------------------------# 195 | 196 | class CDFNormalizer(Normalizer): 197 | ''' 198 | makes training data uniform (over each dimension) by transforming it with marginal CDFs 199 | ''' 200 | 201 | def __init__(self, X): 202 | super().__init__(atleast_2d(X)) 203 | self.dim = self.X.shape[1] 204 | self.cdfs = [CDFNormalizer1d(self.X[:, i]) for i in range(self.dim)] 205 | 206 | def __repr__(self): 207 | return f'[ CDFNormalizer ] dim: {self.mins.size}\n' + ' | '.join( 208 | f'{i:3d}: {cdf}' for i, cdf in enumerate(self.cdfs) 209 | ) 210 | 211 | def wrap(self, fn_name, x): 212 | shape = x.shape 213 | ## reshape to 2d 214 | x = x.reshape(-1, self.dim) 215 | out = np.zeros_like(x) 216 | for i, cdf in enumerate(self.cdfs): 217 | fn = getattr(cdf, fn_name) 218 | out[:, i] = fn(x[:, i]) 219 | return out.reshape(shape) 220 | 221 | def normalize(self, x): 222 | return self.wrap('normalize', x) 223 | 224 | def unnormalize(self, x): 225 | return self.wrap('unnormalize', x) 226 | 227 | class CDFNormalizer1d: 228 | ''' 229 | CDF normalizer for a single dimension 230 | ''' 231 | 232 | def __init__(self, X): 233 | assert X.ndim == 1 234 | self.X = X.astype(np.float32) 235 | if self.X.max() == self.X.min(): 236 | self.constant = True 237 | else: 238 | self.constant = False 239 | quantiles, cumprob = empirical_cdf(self.X) 240 | self.fn = interpolate.interp1d(quantiles, cumprob) 241 | self.inv = interpolate.interp1d(cumprob, quantiles) 242 | 243 | self.xmin, self.xmax = quantiles.min(), quantiles.max() 244 | self.ymin, self.ymax = cumprob.min(), cumprob.max() 245 | 246 | def __repr__(self): 247 | return ( 248 | f'[{np.round(self.xmin, 2):.4f}, {np.round(self.xmax, 2):.4f}' 249 | ) 250 | 251 | def normalize(self, x): 252 | if self.constant: 253 | return x 254 | 255 | x = np.clip(x, self.xmin, self.xmax) 256 | ## [ 0, 1 ] 257 | y = self.fn(x) 258 | ## [ -1, 1 ] 259 | y = 2 * y - 1 260 | return y 261 | 262 | def unnormalize(self, x, eps=1e-4): 263 | ''' 264 | X : [ -1, 1 ] 265 | ''' 266 | ## [ -1, 1 ] --> [ 0, 1 ] 267 | if self.constant: 268 | return x 269 | 270 | x = (x + 1) / 2. 
271 | 272 | if (x < self.ymin - eps).any() or (x > self.ymax + eps).any(): 273 | print( 274 | f'''[ dataset/normalization ] Warning: out of range in unnormalize: ''' 275 | f'''[{x.min()}, {x.max()}] | ''' 276 | f'''x : [{self.xmin}, {self.xmax}] | ''' 277 | f'''y: [{self.ymin}, {self.ymax}]''' 278 | ) 279 | 280 | x = np.clip(x, self.ymin, self.ymax) 281 | 282 | y = self.inv(x) 283 | return y 284 | 285 | def empirical_cdf(sample): 286 | ## https://stackoverflow.com/a/33346366 287 | 288 | # find the unique values and their corresponding counts 289 | quantiles, counts = np.unique(sample, return_counts=True) 290 | 291 | # take the cumulative sum of the counts and divide by the sample size to 292 | # get the cumulative probabilities between 0 and 1 293 | cumprob = np.cumsum(counts).astype(np.double) / sample.size 294 | 295 | return quantiles, cumprob 296 | 297 | def atleast_2d(x): 298 | if x.ndim < 2: 299 | x = x[:,None] 300 | return x 301 | 302 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import einops 4 | from scipy.spatial.transform import Rotation as R 5 | import pdb 6 | 7 | from .d4rl import load_environment 8 | 9 | #-----------------------------------------------------------------------------# 10 | #-------------------------------- general api --------------------------------# 11 | #-----------------------------------------------------------------------------# 12 | 13 | def compose(*fns): 14 | 15 | def _fn(x): 16 | for fn in fns: 17 | x = fn(x) 18 | return x 19 | 20 | return _fn 21 | 22 | def get_preprocess_fn(fn_names, env): 23 | fns = [eval(name)(env) for name in fn_names] 24 | return compose(*fns) 25 | 26 | def get_policy_preprocess_fn(fn_names): 27 | fns = [eval(name) for name in fn_names] 28 | return compose(*fns) 29 | 30 | #-----------------------------------------------------------------------------# 31 | #-------------------------- preprocessing functions --------------------------# 32 | #-----------------------------------------------------------------------------# 33 | 34 | #------------------------ @TODO: remove some of these ------------------------# 35 | 36 | def arctanh_actions(*args, **kwargs): 37 | epsilon = 1e-4 38 | 39 | def _fn(dataset): 40 | actions = dataset['actions'] 41 | assert actions.min() >= -1 and actions.max() <= 1, \ 42 | f'applying arctanh to actions in range [{actions.min()}, {actions.max()}]' 43 | actions = np.clip(actions, -1 + epsilon, 1 - epsilon) 44 | dataset['actions'] = np.arctanh(actions) 45 | return dataset 46 | 47 | return _fn 48 | 49 | def add_deltas(env): 50 | 51 | def _fn(dataset): 52 | deltas = dataset['next_observations'] - dataset['observations'] 53 | dataset['deltas'] = deltas 54 | return dataset 55 | 56 | return _fn 57 | 58 | 59 | def maze2d_set_terminals(env): 60 | env = load_environment(env) if type(env) == str else env 61 | goal = np.array(env._target) 62 | threshold = 0.5 63 | 64 | def _fn(dataset): 65 | xy = dataset['observations'][:,:2] 66 | distances = np.linalg.norm(xy - goal, axis=-1) 67 | at_goal = distances < threshold 68 | timeouts = np.zeros_like(dataset['timeouts']) 69 | 70 | ## timeout at time t iff 71 | ## at goal at time t and 72 | ## not at goal at time t + 1 73 | timeouts[:-1] = at_goal[:-1] * ~at_goal[1:] 74 | 75 | timeout_steps = np.where(timeouts)[0] 76 | path_lengths = timeout_steps[1:] - 
timeout_steps[:-1] 77 | 78 | print( 79 | f'[ utils/preprocessing ] Segmented {env.name} | {len(path_lengths)} paths | ' 80 | f'min length: {path_lengths.min()} | max length: {path_lengths.max()}' 81 | ) 82 | 83 | dataset['timeouts'] = timeouts 84 | return dataset 85 | 86 | return _fn 87 | 88 | 89 | #-------------------------- block-stacking --------------------------# 90 | 91 | def blocks_quat_to_euler(observations): 92 | ''' 93 | input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ] 94 | xyz: 3 95 | quat: 4 96 | contact: 1 97 | 98 | returns : [ N x robot_dim + n_blocks * 10] = [ N x 47 ] 99 | xyz: 3 100 | sin: 3 101 | cos: 3 102 | contact: 1 103 | ''' 104 | robot_dim = 7 105 | block_dim = 8 106 | n_blocks = 4 107 | assert observations.shape[-1] == robot_dim + n_blocks * block_dim 108 | 109 | X = observations[:, :robot_dim] 110 | 111 | for i in range(n_blocks): 112 | start = robot_dim + i * block_dim 113 | end = start + block_dim 114 | 115 | block_info = observations[:, start:end] 116 | 117 | xpos = block_info[:, :3] 118 | quat = block_info[:, 3:-1] 119 | contact = block_info[:, -1:] 120 | 121 | euler = R.from_quat(quat).as_euler('xyz') 122 | sin = np.sin(euler) 123 | cos = np.cos(euler) 124 | 125 | X = np.concatenate([ 126 | X, 127 | xpos, 128 | sin, 129 | cos, 130 | contact, 131 | ], axis=-1) 132 | 133 | return X 134 | 135 | def blocks_euler_to_quat_2d(observations): 136 | robot_dim = 7 137 | block_dim = 10 138 | n_blocks = 4 139 | 140 | assert observations.shape[-1] == robot_dim + n_blocks * block_dim 141 | 142 | X = observations[:, :robot_dim] 143 | 144 | for i in range(n_blocks): 145 | start = robot_dim + i * block_dim 146 | end = start + block_dim 147 | 148 | block_info = observations[:, start:end] 149 | 150 | xpos = block_info[:, :3] 151 | sin = block_info[:, 3:6] 152 | cos = block_info[:, 6:9] 153 | contact = block_info[:, 9:] 154 | 155 | euler = np.arctan2(sin, cos) 156 | quat = R.from_euler('xyz', euler, degrees=False).as_quat() 157 | 158 | X = np.concatenate([ 159 | X, 160 | xpos, 161 | quat, 162 | contact, 163 | ], axis=-1) 164 | 165 | return X 166 | 167 | def blocks_euler_to_quat(paths): 168 | return np.stack([ 169 | blocks_euler_to_quat_2d(path) 170 | for path in paths 171 | ], axis=0) 172 | 173 | def blocks_process_cubes(env): 174 | 175 | def _fn(dataset): 176 | for key in ['observations', 'next_observations']: 177 | dataset[key] = blocks_quat_to_euler(dataset[key]) 178 | return dataset 179 | 180 | return _fn 181 | 182 | def blocks_remove_kuka(env): 183 | 184 | def _fn(dataset): 185 | for key in ['observations', 'next_observations']: 186 | dataset[key] = dataset[key][:, 7:] 187 | return dataset 188 | 189 | return _fn 190 | 191 | def blocks_add_kuka(observations): 192 | ''' 193 | observations : [ batch_size x horizon x 32 ] 194 | ''' 195 | robot_dim = 7 196 | batch_size, horizon, _ = observations.shape 197 | observations = np.concatenate([ 198 | np.zeros((batch_size, horizon, 7)), 199 | observations, 200 | ], axis=-1) 201 | return observations 202 | 203 | def blocks_cumsum_quat(deltas): 204 | ''' 205 | deltas : [ batch_size x horizon x transition_dim ] 206 | ''' 207 | robot_dim = 7 208 | block_dim = 8 209 | n_blocks = 4 210 | assert deltas.shape[-1] == robot_dim + n_blocks * block_dim 211 | 212 | batch_size, horizon, _ = deltas.shape 213 | 214 | cumsum = deltas.cumsum(axis=1) 215 | for i in range(n_blocks): 216 | start = robot_dim + i * block_dim + 3 217 | end = start + 4 218 | 219 | quat = deltas[:, :, start:end].copy() 220 | 221 | quat = einops.rearrange(quat, 'b h q -> (b 
h) q') 222 | euler = R.from_quat(quat).as_euler('xyz') 223 | euler = einops.rearrange(euler, '(b h) e -> b h e', b=batch_size) 224 | cumsum_euler = euler.cumsum(axis=1) 225 | 226 | cumsum_euler = einops.rearrange(cumsum_euler, 'b h e -> (b h) e') 227 | cumsum_quat = R.from_euler('xyz', cumsum_euler).as_quat() 228 | cumsum_quat = einops.rearrange(cumsum_quat, '(b h) q -> b h q', b=batch_size) 229 | 230 | cumsum[:, :, start:end] = cumsum_quat.copy() 231 | 232 | return cumsum 233 | 234 | def blocks_delta_quat_helper(observations, next_observations): 235 | ''' 236 | input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ] 237 | xyz: 3 238 | quat: 4 239 | contact: 1 240 | ''' 241 | robot_dim = 7 242 | block_dim = 8 243 | n_blocks = 4 244 | assert observations.shape[-1] == next_observations.shape[-1] == robot_dim + n_blocks * block_dim 245 | 246 | deltas = (next_observations - observations)[:, :robot_dim] 247 | 248 | for i in range(n_blocks): 249 | start = robot_dim + i * block_dim 250 | end = start + block_dim 251 | 252 | block_info = observations[:, start:end] 253 | next_block_info = next_observations[:, start:end] 254 | 255 | xpos = block_info[:, :3] 256 | next_xpos = next_block_info[:, :3] 257 | 258 | quat = block_info[:, 3:-1] 259 | next_quat = next_block_info[:, 3:-1] 260 | 261 | contact = block_info[:, -1:] 262 | next_contact = next_block_info[:, -1:] 263 | 264 | delta_xpos = next_xpos - xpos 265 | delta_contact = next_contact - contact 266 | 267 | rot = R.from_quat(quat) 268 | next_rot = R.from_quat(next_quat) 269 | 270 | delta_quat = (next_rot * rot.inv()).as_quat() 271 | w = delta_quat[:, -1:] 272 | 273 | ## make w positive to avoid [0, 0, 0, -1] 274 | delta_quat = delta_quat * np.sign(w) 275 | 276 | ## apply rot then delta to ensure we end at next_rot 277 | ## delta * rot = next_rot * rot' * rot = next_rot 278 | next_euler = next_rot.as_euler('xyz') 279 | next_euler_check = (R.from_quat(delta_quat) * rot).as_euler('xyz') 280 | assert np.allclose(next_euler, next_euler_check) 281 | 282 | deltas = np.concatenate([ 283 | deltas, 284 | delta_xpos, 285 | delta_quat, 286 | delta_contact, 287 | ], axis=-1) 288 | 289 | return deltas 290 | 291 | def blocks_add_deltas(env): 292 | 293 | def _fn(dataset): 294 | deltas = blocks_delta_quat_helper(dataset['observations'], dataset['next_observations']) 295 | # deltas = dataset['next_observations'] - dataset['observations'] 296 | dataset['deltas'] = deltas 297 | return dataset 298 | 299 | return _fn 300 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/datasets/sequence.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | import torch 4 | import pdb 5 | 6 | from .preprocessing import get_preprocess_fn 7 | from .d4rl import load_environment, sequence_dataset 8 | from .normalization import DatasetNormalizer 9 | from .buffer import ReplayBuffer 10 | 11 | RewardBatch = namedtuple('Batch', 'trajectories conditions returns') 12 | Batch = namedtuple('Batch', 'trajectories conditions') 13 | ValueBatch = namedtuple('ValueBatch', 'trajectories conditions values') 14 | 15 | class SequenceDataset(torch.utils.data.Dataset): 16 | 17 | def __init__(self, env='hopper-medium-replay', horizon=64, 18 | normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000, 19 | max_n_episodes=10000, termination_penalty=0, use_padding=True, discount=0.99, returns_scale=1000, include_returns=False): 20 | 
self.preprocess_fn = get_preprocess_fn(preprocess_fns, env) 21 | self.env = env = load_environment(env) 22 | self.returns_scale = returns_scale 23 | self.horizon = horizon 24 | self.max_path_length = max_path_length 25 | self.discount = discount 26 | self.discounts = self.discount ** np.arange(self.max_path_length)[:, None] 27 | self.use_padding = use_padding 28 | self.include_returns = include_returns 29 | itr = sequence_dataset(env, self.preprocess_fn) 30 | 31 | fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty) 32 | for i, episode in enumerate(itr): 33 | fields.add_path(episode) 34 | fields.finalize() 35 | 36 | self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths']) 37 | self.indices = self.make_indices(fields.path_lengths, horizon) 38 | 39 | self.observation_dim = fields.observations.shape[-1] 40 | self.action_dim = fields.actions.shape[-1] 41 | self.fields = fields 42 | self.n_episodes = fields.n_episodes 43 | self.path_lengths = fields.path_lengths 44 | self.normalize() 45 | 46 | print(fields) 47 | # shapes = {key: val.shape for key, val in self.fields.items()} 48 | # print(f'[ datasets/mujoco ] Dataset fields: {shapes}') 49 | 50 | def normalize(self, keys=['observations', 'actions']): 51 | ''' 52 | normalize fields that will be predicted by the diffusion model 53 | ''' 54 | for key in keys: 55 | array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1) 56 | normed = self.normalizer(array, key) 57 | self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1) 58 | 59 | def make_indices(self, path_lengths, horizon): 60 | ''' 61 | makes indices for sampling from dataset; 62 | each index maps to a datapoint 63 | ''' 64 | indices = [] 65 | for i, path_length in enumerate(path_lengths): 66 | max_start = min(path_length - 1, self.max_path_length - horizon) 67 | if not self.use_padding: 68 | max_start = min(max_start, path_length - horizon) 69 | for start in range(max_start): 70 | end = start + horizon 71 | indices.append((i, start, end)) 72 | indices = np.array(indices) 73 | return indices 74 | 75 | def get_conditions(self, observations): 76 | ''' 77 | condition on current observation for planning 78 | ''' 79 | return {0: observations[0]} 80 | 81 | def __len__(self): 82 | return len(self.indices) 83 | 84 | def __getitem__(self, idx, eps=1e-4): 85 | path_ind, start, end = self.indices[idx] 86 | 87 | observations = self.fields.normed_observations[path_ind, start:end] 88 | actions = self.fields.normed_actions[path_ind, start:end] 89 | 90 | conditions = self.get_conditions(observations) 91 | trajectories = np.concatenate([actions, observations], axis=-1) 92 | 93 | if self.include_returns: 94 | rewards = self.fields.rewards[path_ind, start:] 95 | discounts = self.discounts[:len(rewards)] 96 | returns = (discounts * rewards).sum() 97 | returns = np.array([returns/self.returns_scale], dtype=np.float32) 98 | batch = RewardBatch(trajectories, conditions, returns) 99 | else: 100 | batch = Batch(trajectories, conditions) 101 | 102 | return batch 103 | 104 | class CondSequenceDataset(torch.utils.data.Dataset): 105 | 106 | def __init__(self, env='hopper-medium-replay', horizon=64, 107 | normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000, 108 | max_n_episodes=10000, termination_penalty=0, use_padding=True, discount=0.99, returns_scale=1000, include_returns=False): 109 | self.preprocess_fn = get_preprocess_fn(preprocess_fns, env) 110 | self.env = env = load_environment(env) 111 | 
self.returns_scale = returns_scale 112 | self.horizon = horizon 113 | self.max_path_length = max_path_length 114 | self.discount = discount 115 | self.discounts = self.discount ** np.arange(self.max_path_length)[:, None] 116 | self.use_padding = use_padding 117 | self.include_returns = include_returns 118 | itr = sequence_dataset(env, self.preprocess_fn) 119 | 120 | fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty) 121 | for i, episode in enumerate(itr): 122 | fields.add_path(episode) 123 | fields.finalize() 124 | 125 | self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths']) 126 | self.indices = self.make_indices(fields.path_lengths, horizon) 127 | 128 | self.observation_dim = fields.observations.shape[-1] 129 | self.action_dim = fields.actions.shape[-1] 130 | self.fields = fields 131 | self.n_episodes = fields.n_episodes 132 | self.path_lengths = fields.path_lengths 133 | self.normalize() 134 | 135 | print(fields) 136 | # shapes = {key: val.shape for key, val in self.fields.items()} 137 | # print(f'[ datasets/mujoco ] Dataset fields: {shapes}') 138 | 139 | def normalize(self, keys=['observations', 'actions']): 140 | ''' 141 | normalize fields that will be predicted by the diffusion model 142 | ''' 143 | for key in keys: 144 | array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1) 145 | normed = self.normalizer(array, key) 146 | self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1) 147 | 148 | def make_indices(self, path_lengths, horizon): 149 | ''' 150 | makes indices for sampling from dataset; 151 | each index maps to a datapoint 152 | ''' 153 | indices = [] 154 | for i, path_length in enumerate(path_lengths): 155 | max_start = min(path_length - 1, self.max_path_length - horizon) 156 | if not self.use_padding: 157 | max_start = min(max_start, path_length - horizon) 158 | for start in range(max_start): 159 | end = start + horizon 160 | indices.append((i, start, end)) 161 | indices = np.array(indices) 162 | return indices 163 | 164 | def __len__(self): 165 | return len(self.indices) 166 | 167 | def __getitem__(self, idx, eps=1e-4): 168 | path_ind, start, end = self.indices[idx] 169 | 170 | t_step = np.random.randint(0, self.horizon) 171 | 172 | observations = self.fields.normed_observations[path_ind, start:end] 173 | actions = self.fields.normed_actions[path_ind, start:end] 174 | 175 | traj_dim = self.action_dim + self.observation_dim 176 | 177 | conditions = np.ones((self.horizon, 2*traj_dim)).astype(np.float32) 178 | 179 | # Set up conditional masking 180 | conditions[t_step:,:self.action_dim] = 0 181 | conditions[:,traj_dim:] = 0 182 | conditions[t_step,traj_dim:traj_dim+self.action_dim] = 1 183 | 184 | if t_step < self.horizon-1: 185 | observations[t_step+1:] = 0 186 | 187 | trajectories = np.concatenate([actions, observations], axis=-1) 188 | 189 | if self.include_returns: 190 | rewards = self.fields.rewards[path_ind, start:] 191 | discounts = self.discounts[:len(rewards)] 192 | returns = (discounts * rewards).sum() 193 | returns = np.array([returns/self.returns_scale], dtype=np.float32) 194 | batch = RewardBatch(trajectories, conditions, returns) 195 | else: 196 | batch = Batch(trajectories, conditions) 197 | 198 | return batch 199 | 200 | class GoalDataset(SequenceDataset): 201 | 202 | def get_conditions(self, observations): 203 | ''' 204 | condition on both the current observation and the last observation in the plan 205 | ''' 206 | return { 207 | 0: 
observations[0], 208 | self.horizon - 1: observations[-1], 209 | } 210 | 211 | class ValueDataset(SequenceDataset): 212 | ''' 213 | adds a value field to the datapoints for training the value function 214 | ''' 215 | 216 | def __init__(self, *args, discount=0.99, **kwargs): 217 | super().__init__(*args, **kwargs) 218 | self.discount = discount 219 | self.discounts = self.discount ** np.arange(self.max_path_length)[:,None] 220 | 221 | def __getitem__(self, idx): 222 | batch = super().__getitem__(idx) 223 | path_ind, start, end = self.indices[idx] 224 | rewards = self.fields['rewards'][path_ind, start:] 225 | discounts = self.discounts[:len(rewards)] 226 | value = (discounts * rewards).sum() 227 | value = np.array([value], dtype=np.float32) 228 | value_batch = ValueBatch(*batch, value) 229 | return value_batch 230 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/__init__.py: -------------------------------------------------------------------------------- 1 | from .registration import register_environments 2 | 3 | registered_environments = register_environments() -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/ant.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from gym.envs.mujoco import mujoco_env 5 | 6 | ''' 7 | qpos : 15 8 | qvel : 14 9 | 0-2: root x, y, z 10 | 3-7: root quat 11 | 7 : front L hip 12 | 8 : front L ankle 13 | 9 : front R hip 14 | 10 : front R ankle 15 | 11 : back L hip 16 | 12 : back L ankle 17 | 13 : back R hip 18 | 14 : back R ankle 19 | 20 | ''' 21 | 22 | class AntFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 23 | def __init__(self): 24 | asset_path = os.path.join( 25 | os.path.dirname(__file__), 'assets/ant.xml') 26 | mujoco_env.MujocoEnv.__init__(self, asset_path, 5) 27 | utils.EzPickle.__init__(self) 28 | 29 | def step(self, a): 30 | xposbefore = self.get_body_com("torso")[0] 31 | self.do_simulation(a, self.frame_skip) 32 | xposafter = self.get_body_com("torso")[0] 33 | forward_reward = (xposafter - xposbefore) / self.dt 34 | ctrl_cost = 0.5 * np.square(a).sum() 35 | contact_cost = ( 36 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 37 | ) 38 | survive_reward = 1.0 39 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward 40 | state = self.state_vector() 41 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0 42 | done = not notdone 43 | ob = self._get_obs() 44 | return ( 45 | ob, 46 | reward, 47 | done, 48 | dict( 49 | reward_forward=forward_reward, 50 | reward_ctrl=-ctrl_cost, 51 | reward_contact=-contact_cost, 52 | reward_survive=survive_reward, 53 | ), 54 | ) 55 | 56 | def _get_obs(self): 57 | return np.concatenate( 58 | [ 59 | self.sim.data.qpos.flat[2:], 60 | self.sim.data.qvel.flat, 61 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 62 | ] 63 | ) 64 | 65 | def reset_model(self): 66 | qpos = self.init_qpos + self.np_random.uniform( 67 | size=self.model.nq, low=-0.1, high=0.1 68 | ) 69 | qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1 70 | self.set_state(qpos, qvel) 71 | return self._get_obs() 72 | 73 | def viewer_setup(self): 74 | self.viewer.cam.distance = self.model.stat.extent * 0.5 -------------------------------------------------------------------------------- 
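A minimal usage sketch for the full-observation environments defined in `diffuser/environments`: importing the package calls `register_environments()` (see `registration.py` below), after which the variants can be created through the standard `gym` API. The environment id is taken from `registration.py`; MuJoCo and `gym` are assumed to be installed as described in the top-level README.

```python
import gym
import diffuser.environments  # side effect: registers the *FullObs-v2 environments

env = gym.make('AntFullObs-v2')    # id from registration.py
obs = env.reset()                  # old gym API: reset() returns the observation
obs, reward, done, info = env.step(env.action_space.sample())
```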
/skilldiffuser/hrl/diffuser/environments/assets/ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/assets/half_cheetah.xml: -------------------------------------------------------------------------------- 1 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/assets/hopper.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/assets/walker2d.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/half_cheetah.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from gym.envs.mujoco import mujoco_env 5 | 6 | class HalfCheetahFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 7 | def __init__(self): 8 | asset_path = os.path.join( 9 | os.path.dirname(__file__), 'assets/half_cheetah.xml') 10 | mujoco_env.MujocoEnv.__init__(self, asset_path, 5) 11 | utils.EzPickle.__init__(self) 12 | 13 | def step(self, action): 14 | xposbefore = self.sim.data.qpos[0] 15 | self.do_simulation(action, self.frame_skip) 16 | xposafter = self.sim.data.qpos[0] 17 | ob = self._get_obs() 18 | reward_ctrl = - 0.1 * np.square(action).sum() 19 | reward_run = (xposafter - xposbefore)/self.dt 20 | reward = reward_ctrl + reward_run 21 | done = False 22 | return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl) 23 | 24 | def _get_obs(self): 25 | return np.concatenate([ 26 | # self.sim.data.qpos.flat[1:], 27 | self.sim.data.qpos.flat, #[1:], 28 | self.sim.data.qvel.flat, 29 | ]) 30 | 31 | def reset_model(self): 32 | qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) 33 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 34 | self.set_state(qpos, qvel) 35 | return self._get_obs() 36 | 37 | def viewer_setup(self): 38 | self.viewer.cam.distance = self.model.stat.extent * 0.5 39 | 40 | def set(self, state): 41 | qpos_dim = self.sim.data.qpos.size 42 | qpos = state[:qpos_dim] 43 | qvel = state[qpos_dim:] 44 | self.set_state(qpos, qvel) 45 | return self._get_obs() 46 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/hopper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from gym.envs.mujoco import mujoco_env 5 | 6 | class HopperFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 7 | def __init__(self): 8 | asset_path = os.path.join( 9 | os.path.dirname(__file__), 'assets/hopper.xml') 10 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4) 11 | utils.EzPickle.__init__(self) 12 | 13 | def step(self, a): 14 | posbefore = self.sim.data.qpos[0] 15 | self.do_simulation(a, self.frame_skip) 16 | posafter, height, ang = self.sim.data.qpos[0:3] 17 | alive_bonus = 1.0 18 | reward = (posafter - posbefore) / self.dt 19 | reward += alive_bonus 20 | reward -= 1e-3 * 
np.square(a).sum() 21 | s = self.state_vector() 22 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and 23 | (height > .7) and (abs(ang) < .2)) 24 | ob = self._get_obs() 25 | return ob, reward, done, {} 26 | 27 | def _get_obs(self): 28 | return np.concatenate([ 29 | # self.sim.data.qpos.flat[1:], 30 | self.sim.data.qpos.flat, 31 | np.clip(self.sim.data.qvel.flat, -10, 10) 32 | ]) 33 | 34 | def reset_model(self): 35 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq) 36 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 37 | self.set_state(qpos, qvel) 38 | return self._get_obs() 39 | 40 | def viewer_setup(self): 41 | self.viewer.cam.trackbodyid = 2 42 | self.viewer.cam.distance = self.model.stat.extent * 0.75 43 | self.viewer.cam.lookat[2] = 1.15 44 | self.viewer.cam.elevation = -20 45 | 46 | def set(self, state): 47 | qpos_dim = self.sim.data.qpos.size 48 | qpos = state[:qpos_dim] 49 | qvel = state[qpos_dim:] 50 | self.set_state(qpos, qvel) 51 | return self._get_obs() -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/registration.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | ENVIRONMENT_SPECS = ( 4 | { 5 | 'id': 'HopperFullObs-v2', 6 | 'entry_point': ('diffuser.environments.hopper:HopperFullObsEnv'), 7 | }, 8 | { 9 | 'id': 'HalfCheetahFullObs-v2', 10 | 'entry_point': ('diffuser.environments.half_cheetah:HalfCheetahFullObsEnv'), 11 | }, 12 | { 13 | 'id': 'Walker2dFullObs-v2', 14 | 'entry_point': ('diffuser.environments.walker2d:Walker2dFullObsEnv'), 15 | }, 16 | { 17 | 'id': 'AntFullObs-v2', 18 | 'entry_point': ('diffuser.environments.ant:AntFullObsEnv'), 19 | }, 20 | ) 21 | 22 | def register_environments(): 23 | try: 24 | for environment in ENVIRONMENT_SPECS: 25 | gym.register(**environment) 26 | 27 | gym_ids = tuple( 28 | environment_spec['id'] 29 | for environment_spec in ENVIRONMENT_SPECS) 30 | 31 | return gym_ids 32 | except: 33 | print('[ diffuser/environments/registration ] WARNING: not registering diffuser environments') 34 | return tuple() -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/environments/walker2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from gym import utils 4 | from gym.envs.mujoco import mujoco_env 5 | 6 | class Walker2dFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle): 7 | 8 | def __init__(self): 9 | asset_path = os.path.join( 10 | os.path.dirname(__file__), 'assets/walker2d.xml') 11 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4) 12 | utils.EzPickle.__init__(self) 13 | 14 | def step(self, a): 15 | posbefore = self.sim.data.qpos[0] 16 | self.do_simulation(a, self.frame_skip) 17 | posafter, height, ang = self.sim.data.qpos[0:3] 18 | alive_bonus = 1.0 19 | reward = ((posafter - posbefore) / self.dt) 20 | reward += alive_bonus 21 | reward -= 1e-3 * np.square(a).sum() 22 | done = not (height > 0.8 and height < 2.0 and 23 | ang > -1.0 and ang < 1.0) 24 | ob = self._get_obs() 25 | return ob, reward, done, {} 26 | 27 | def _get_obs(self): 28 | qpos = self.sim.data.qpos 29 | qvel = self.sim.data.qvel 30 | return np.concatenate([qpos, qvel]).ravel() 31 | 32 | def reset_model(self): 33 | self.set_state( 34 | self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq), 35 | 
self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv) 36 | ) 37 | return self._get_obs() 38 | 39 | def viewer_setup(self): 40 | self.viewer.cam.trackbodyid = 2 41 | self.viewer.cam.distance = self.model.stat.extent * 0.5 42 | self.viewer.cam.lookat[2] = 1.15 43 | self.viewer.cam.elevation = -20 44 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .temporal import TemporalUnet, TemporalValue, MLPnet 2 | from .diffusion import GaussianDiffusion, ActionGaussianDiffusion, GaussianInvDynDiffusion -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/models/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import einops 7 | from einops.layers.torch import Rearrange 8 | import pdb 9 | 10 | import diffuser.utils as utils 11 | 12 | #-----------------------------------------------------------------------------# 13 | #---------------------------------- modules ----------------------------------# 14 | #-----------------------------------------------------------------------------# 15 | 16 | class SinusoidalPosEmb(nn.Module): 17 | def __init__(self, dim): 18 | super().__init__() 19 | self.dim = dim 20 | 21 | def forward(self, x): 22 | device = x.device 23 | half_dim = self.dim // 2 24 | emb = math.log(10000) / (half_dim - 1) 25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 26 | emb = x[:, None] * emb[None, :] 27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 28 | return emb 29 | 30 | class Downsample1d(nn.Module): 31 | def __init__(self, dim): 32 | super().__init__() 33 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1) 34 | 35 | def forward(self, x): 36 | return self.conv(x) 37 | 38 | class Upsample1d(nn.Module): 39 | def __init__(self, dim): 40 | super().__init__() 41 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1) 42 | 43 | def forward(self, x): 44 | return self.conv(x) 45 | 46 | class Conv1dBlock(nn.Module): 47 | ''' 48 | Conv1d --> GroupNorm --> Mish 49 | ''' 50 | 51 | def __init__(self, inp_channels, out_channels, kernel_size, mish=True, n_groups=8): 52 | super().__init__() 53 | 54 | if mish: 55 | act_fn = nn.Mish() 56 | else: 57 | act_fn = nn.SiLU() 58 | 59 | self.block = nn.Sequential( 60 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), 61 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 62 | nn.GroupNorm(n_groups, out_channels), 63 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 64 | act_fn, 65 | ) 66 | 67 | def forward(self, x): 68 | return self.block(x) 69 | 70 | 71 | #-----------------------------------------------------------------------------# 72 | #---------------------------------- sampling ---------------------------------# 73 | #-----------------------------------------------------------------------------# 74 | 75 | def extract(a, t, x_shape): 76 | b, *_ = t.shape 77 | out = a.gather(-1, t) 78 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 79 | 80 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): 81 | """ 82 | cosine schedule 83 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 84 | """ 85 | steps = timesteps + 1 86 | x = np.linspace(0, steps, steps) 87 | 
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 88 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 89 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 90 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999) 91 | return torch.tensor(betas_clipped, dtype=dtype) 92 | 93 | def apply_conditioning(x, conditions, action_dim): 94 | for t, val in conditions.items(): 95 | x[:, t, action_dim:] = val.clone() 96 | return x 97 | 98 | #-----------------------------------------------------------------------------# 99 | #---------------------------------- losses -----------------------------------# 100 | #-----------------------------------------------------------------------------# 101 | 102 | class WeightedLoss(nn.Module): 103 | 104 | def __init__(self, weights, action_dim): 105 | super().__init__() 106 | self.register_buffer('weights', weights) 107 | self.action_dim = action_dim 108 | 109 | def forward(self, pred, targ): 110 | ''' 111 | pred, targ : tensor 112 | [ batch_size x horizon x transition_dim ] 113 | ''' 114 | loss = self._loss(pred, targ) 115 | weighted_loss = (loss * self.weights).mean() 116 | a0_loss = (loss[:, 0, :self.action_dim] / self.weights[0, :self.action_dim]).mean() 117 | return weighted_loss, {'a0_loss': a0_loss} 118 | 119 | class WeightedStateLoss(nn.Module): 120 | 121 | def __init__(self, weights): 122 | super().__init__() 123 | self.register_buffer('weights', weights) 124 | 125 | def forward(self, pred, targ): 126 | ''' 127 | pred, targ : tensor 128 | [ batch_size x horizon x transition_dim ] 129 | ''' 130 | loss = self._loss(pred, targ) 131 | weighted_loss = (loss * self.weights).mean() 132 | return weighted_loss, {'a0_loss': weighted_loss} 133 | 134 | class ValueLoss(nn.Module): 135 | def __init__(self, *args): 136 | super().__init__() 137 | pass 138 | 139 | def forward(self, pred, targ): 140 | loss = self._loss(pred, targ).mean() 141 | 142 | if len(pred) > 1: 143 | corr = np.corrcoef( 144 | utils.to_np(pred).squeeze(), 145 | utils.to_np(targ).squeeze() 146 | )[0,1] 147 | else: 148 | corr = np.NaN 149 | 150 | info = { 151 | 'mean_pred': pred.mean(), 'mean_targ': targ.mean(), 152 | 'min_pred': pred.min(), 'min_targ': targ.min(), 153 | 'max_pred': pred.max(), 'max_targ': targ.max(), 154 | 'corr': corr, 155 | } 156 | 157 | return loss, info 158 | 159 | class WeightedL1(WeightedLoss): 160 | 161 | def _loss(self, pred, targ): 162 | return torch.abs(pred - targ) 163 | 164 | class WeightedL2(WeightedLoss): 165 | 166 | def _loss(self, pred, targ): 167 | return F.mse_loss(pred, targ, reduction='none') 168 | 169 | class WeightedStateL2(WeightedStateLoss): 170 | 171 | def _loss(self, pred, targ): 172 | return F.mse_loss(pred, targ, reduction='none') 173 | 174 | class ValueL1(ValueLoss): 175 | 176 | def _loss(self, pred, targ): 177 | return torch.abs(pred - targ) 178 | 179 | class ValueL2(ValueLoss): 180 | 181 | def _loss(self, pred, targ): 182 | return F.mse_loss(pred, targ, reduction='none') 183 | 184 | Losses = { 185 | 'l1': WeightedL1, 186 | 'l2': WeightedL2, 187 | 'state_l2': WeightedStateL2, 188 | 'value_l1': ValueL1, 189 | 'value_l2': ValueL2, 190 | } 191 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .serialization import * 2 | from .training import * 3 | from .progress import * 4 | # from .setup import * 5 | from .config import * 6 | from .rendering import * 7 | from 
.arrays import * 8 | from .colab import * 9 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/arrays.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import torch 4 | import pdb 5 | 6 | DTYPE = torch.float 7 | DEVICE = 'cuda' 8 | 9 | #-----------------------------------------------------------------------------# 10 | #------------------------------ numpy <--> torch -----------------------------# 11 | #-----------------------------------------------------------------------------# 12 | 13 | def to_np(x): 14 | if torch.is_tensor(x): 15 | x = x.detach().cpu().numpy() 16 | return x 17 | 18 | def to_torch(x, dtype=None, device=None): 19 | dtype = dtype or DTYPE 20 | device = device or DEVICE 21 | if type(x) is dict: 22 | return {k: to_torch(v, dtype, device) for k, v in x.items()} 23 | elif torch.is_tensor(x): 24 | return x.to(device).type(dtype) 25 | # import pdb; pdb.set_trace() 26 | return torch.tensor(x, dtype=dtype, device=device) 27 | 28 | def to_device(x, device=DEVICE): 29 | if torch.is_tensor(x): 30 | return x.to(device) 31 | elif type(x) is dict: 32 | return {k: to_device(v, device) for k, v in x.items()} 33 | else: 34 | print(f'Unrecognized type in `to_device`: {type(x)}') 35 | pdb.set_trace() 36 | # return [x.to(device) for x in xs] 37 | 38 | # def atleast_2d(x, axis=0): 39 | # ''' 40 | # works for both np arrays and torch tensors 41 | # ''' 42 | # while len(x.shape) < 2: 43 | # shape = (1, *x.shape) if axis == 0 else (*x.shape, 1) 44 | # x = x.reshape(*shape) 45 | # return x 46 | 47 | # def to_2d(x): 48 | # dim = x.shape[-1] 49 | # return x.reshape(-1, dim) 50 | 51 | def batchify(batch, device): 52 | ''' 53 | convert a single dataset item to a batch suitable for passing to a model by 54 | 1) converting np arrays to torch tensors and 55 | 2) and ensuring that everything has a batch dimension 56 | ''' 57 | fn = lambda x: to_torch(x[None], device=device) 58 | 59 | batched_vals = [] 60 | for field in batch._fields: 61 | val = getattr(batch, field) 62 | val = apply_dict(fn, val) if type(val) is dict else fn(val) 63 | batched_vals.append(val) 64 | return type(batch)(*batched_vals) 65 | 66 | def apply_dict(fn, d, *args, **kwargs): 67 | return { 68 | k: fn(v, *args, **kwargs) 69 | for k, v in d.items() 70 | } 71 | 72 | def normalize(x): 73 | """ 74 | scales `x` to [0, 1] 75 | """ 76 | x = x - x.min() 77 | x = x / x.max() 78 | return x 79 | 80 | def to_img(x): 81 | normalized = normalize(x) 82 | array = to_np(normalized) 83 | array = np.transpose(array, (1,2,0)) 84 | return (array * 255).astype(np.uint8) 85 | 86 | def set_device(device): 87 | DEVICE = device 88 | if 'cuda' in device: 89 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 90 | 91 | def batch_to_device(batch, device='cuda:0'): 92 | vals = [ 93 | to_device(getattr(batch, field), device) 94 | for field in batch._fields 95 | ] 96 | return type(batch)(*vals) 97 | 98 | def _to_str(num): 99 | if num >= 1e6: 100 | return f'{(num/1e6):.2f} M' 101 | else: 102 | return f'{(num/1e3):.2f} k' 103 | 104 | #-----------------------------------------------------------------------------# 105 | #----------------------------- parameter counting ----------------------------# 106 | #-----------------------------------------------------------------------------# 107 | 108 | def param_to_module(param): 109 | module_name = param[::-1].split('.', maxsplit=1)[-1][::-1] 110 | return module_name 111 | 
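# ---------------------------------------------------------------------------
# Illustrative sketch (names are placeholders): `batchify` above is meant to
# take a single item from one of the sequence datasets and produce a size-1
# batch for a forward pass, e.g.
#
#     item = dataset[0]                        # namedtuple of np arrays / dicts
#     batch = batchify(item, device='cuda:0')  # every field gains a leading
#                                              # batch dimension and becomes a
#                                              # torch tensor; dict fields such
#                                              # as `conditions` are converted
#                                              # value-by-value
# ---------------------------------------------------------------------------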
112 | def report_parameters(model, topk=10): 113 | counts = {k: p.numel() for k, p in model.named_parameters()} 114 | n_parameters = sum(counts.values()) 115 | print(f'[ utils/arrays ] Total parameters: {_to_str(n_parameters)}') 116 | 117 | modules = dict(model.named_modules()) 118 | sorted_keys = sorted(counts, key=lambda x: -counts[x]) 119 | max_length = max([len(k) for k in sorted_keys]) 120 | for i in range(topk): 121 | key = sorted_keys[i] 122 | count = counts[key] 123 | module = param_to_module(key) 124 | print(' '*8, f'{key:10}: {_to_str(count)} | {modules[module]}') 125 | 126 | remaining_parameters = sum([counts[k] for k in sorted_keys[topk:]]) 127 | print(' '*8, f'... and {len(counts)-topk} others accounting for {_to_str(remaining_parameters)} parameters') 128 | return n_parameters 129 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/cloud.py: -------------------------------------------------------------------------------- 1 | import shlex 2 | import subprocess 3 | import pdb 4 | 5 | def sync_logs(logdir, bucket, background=False): 6 | ## remove prefix 'logs' on google cloud 7 | destination = 'logs' + logdir.split('logs')[-1] 8 | upload_blob(logdir, destination, bucket, background) 9 | 10 | def upload_blob(source, destination, bucket, background): 11 | command = f'gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M rsync -r {source} {bucket}/{destination}' 12 | print(f'[ utils/cloud ] Syncing bucket: {command}') 13 | command = shlex.split(command) 14 | 15 | if background: 16 | subprocess.Popen(command) 17 | else: 18 | subprocess.call(command) 19 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/colab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import einops 4 | import matplotlib.pyplot as plt 5 | from tqdm import tqdm 6 | 7 | try: 8 | import io 9 | import base64 10 | from IPython.display import HTML 11 | from IPython import display as ipythondisplay 12 | except: 13 | print('[ utils/colab ] Warning: not importing colab dependencies') 14 | 15 | from .serialization import mkdir 16 | from .arrays import to_torch, to_np 17 | from .video import save_video 18 | 19 | 20 | def run_diffusion(model, dataset, obs, n_samples=1, device='cuda:0', **diffusion_kwargs): 21 | ## normalize observation for model 22 | obs = dataset.normalizer.normalize(obs, 'observations') 23 | 24 | ## add a batch dimension and repeat for multiple samples 25 | ## [ observation_dim ] --> [ n_samples x observation_dim ] 26 | obs = obs[None].repeat(n_samples, axis=0) 27 | 28 | ## format `conditions` input for model 29 | conditions = { 30 | 0: to_torch(obs, device=device) 31 | } 32 | 33 | samples, diffusion = model.conditional_sample(conditions, 34 | return_diffusion=True, verbose=False, **diffusion_kwargs) 35 | 36 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x (action_dim + observation_dim)] 37 | diffusion = to_np(diffusion) 38 | 39 | ## extract observations 40 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x observation_dim ] 41 | normed_observations = diffusion[:, :, :, dataset.action_dim:] 42 | 43 | ## unnormalize observation samples from model 44 | observations = dataset.normalizer.unnormalize(normed_observations, 'observations') 45 | 46 | ## [ (n_diffusion_steps + 1) x n_samples x horizon x observation_dim ] 47 | observations = einops.rearrange(observations, 48 | 'batch steps 
horizon dim -> steps batch horizon dim') 49 | 50 | return observations 51 | 52 | 53 | def show_diffusion(renderer, observations, n_repeat=100, substep=1, filename='diffusion.mp4', savebase='/content/videos'): 54 | ''' 55 | observations : [ n_diffusion_steps x batch_size x horizon x observation_dim ] 56 | ''' 57 | mkdir(savebase) 58 | savepath = os.path.join(savebase, filename) 59 | 60 | subsampled = observations[::substep] 61 | 62 | images = [] 63 | for t in tqdm(range(len(subsampled))): 64 | observation = subsampled[t] 65 | 66 | img = renderer.composite(None, observation) 67 | images.append(img) 68 | images = np.stack(images, axis=0) 69 | 70 | ## pause at the end of video 71 | images = np.concatenate([ 72 | images, 73 | images[-1:].repeat(n_repeat, axis=0) 74 | ], axis=0) 75 | 76 | save_video(savepath, images) 77 | show_video(savepath) 78 | 79 | 80 | def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'): 81 | ''' 82 | observations : [ batch_size x horizon x observation_dim ] 83 | ''' 84 | 85 | mkdir(savebase) 86 | savepath = os.path.join(savebase, filename) 87 | 88 | images = [] 89 | for rollout in observations: 90 | ## [ horizon x height x width x channels ] 91 | img = renderer._renders(rollout, partial=True) 92 | images.append(img) 93 | 94 | ## [ horizon x height x (batch_size * width) x channels ] 95 | images = np.concatenate(images, axis=2) 96 | 97 | save_video(savepath, images) 98 | show_video(savepath, height=200) 99 | 100 | 101 | def show_samples(renderer, observations_l, figsize=12): 102 | ''' 103 | observations_l : [ [ n_diffusion_steps x batch_size x horizon x observation_dim ], ... ] 104 | ''' 105 | 106 | images = [] 107 | for observations in observations_l: 108 | path = observations[-1] 109 | img = renderer.composite(None, path) 110 | images.append(img) 111 | images = np.concatenate(images, axis=0) 112 | 113 | plt.imshow(images) 114 | plt.axis('off') 115 | plt.gcf().set_size_inches(figsize, figsize) 116 | 117 | 118 | def show_video(path, height=400): 119 | video = io.open(path, 'r+b').read() 120 | encoded = base64.b64encode(video) 121 | ipythondisplay.display(HTML(data=''''''.format(height, encoded.decode('ascii')))) 125 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | import importlib 4 | import pickle 5 | from ml_logger import logger 6 | 7 | def import_class(_class): 8 | if type(_class) is not str: return _class 9 | ## 'diffusion' on standard installs 10 | repo_name = __name__.split('.')[0] 11 | ## eg, 'utils' 12 | module_name = '.'.join(_class.split('.')[:-1]) 13 | ## eg, 'Renderer' 14 | class_name = _class.split('.')[-1] 15 | ## eg, 'diffusion.utils' 16 | module = importlib.import_module(f'{repo_name}.{module_name}') 17 | ## eg, diffusion.utils.Renderer 18 | _class = getattr(module, class_name) 19 | print(f'[ utils/config ] Imported {repo_name}.{module_name}:{class_name}') 20 | return _class 21 | 22 | class Config(collections.Mapping): 23 | 24 | def __init__(self, _class, verbose=True, savepath=None, device=None, **kwargs): 25 | self._class = import_class(_class) 26 | self._device = device 27 | self._dict = {} 28 | 29 | for key, val in kwargs.items(): 30 | self._dict[key] = val 31 | 32 | if verbose: 33 | print(self) 34 | 35 | if savepath is not None: 36 | logger.save_pkl(self, savepath) 37 | print(f'[ utils/config ] Saved config to: 
{savepath}\n') 38 | 39 | def __repr__(self): 40 | string = f'\n[utils/config ] Config: {self._class}\n' 41 | for key in sorted(self._dict.keys()): 42 | val = self._dict[key] 43 | string += f' {key}: {val}\n' 44 | return string 45 | 46 | def __iter__(self): 47 | return iter(self._dict) 48 | 49 | def __getitem__(self, item): 50 | return self._dict[item] 51 | 52 | def __len__(self): 53 | return len(self._dict) 54 | 55 | def __getattr__(self, attr): 56 | if attr == '_dict' and '_dict' not in vars(self): 57 | self._dict = {} 58 | return self._dict 59 | try: 60 | return self._dict[attr] 61 | except KeyError: 62 | raise AttributeError(attr) 63 | 64 | def __call__(self, *args, **kwargs): 65 | instance = self._class(*args, **kwargs, **self._dict) 66 | if self._device: 67 | instance = instance.to(self._device) 68 | return instance 69 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/git_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import git 3 | import pdb 4 | 5 | PROJECT_PATH = os.path.dirname( 6 | os.path.realpath(os.path.join(__file__, '..', '..'))) 7 | 8 | def get_repo(path=PROJECT_PATH, search_parent_directories=True): 9 | repo = git.Repo( 10 | path, search_parent_directories=search_parent_directories) 11 | return repo 12 | 13 | def get_git_rev(*args, **kwargs): 14 | try: 15 | repo = get_repo(*args, **kwargs) 16 | if repo.head.is_detached: 17 | git_rev = repo.head.object.name_rev 18 | else: 19 | git_rev = repo.active_branch.commit.name_rev 20 | except: 21 | git_rev = None 22 | 23 | return git_rev 24 | 25 | def git_diff(*args, **kwargs): 26 | repo = get_repo(*args, **kwargs) 27 | diff = repo.git.diff() 28 | return diff 29 | 30 | def save_git_diff(savepath, *args, **kwargs): 31 | diff = git_diff(*args, **kwargs) 32 | with open(savepath, 'w') as f: 33 | f.write(diff) 34 | 35 | if __name__ == '__main__': 36 | 37 | git_rev = get_git_rev() 38 | print(git_rev) 39 | 40 | save_git_diff('diff_test.txt') -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/iql.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import jax 4 | import jax.numpy as jnp 5 | import functools 6 | import pdb 7 | 8 | from diffuser.iql.common import Model 9 | from diffuser.iql.value_net import DoubleCritic 10 | 11 | def load_q(env, loadpath, hidden_dims=(256, 256), seed=42): 12 | print(f'[ utils/iql ] Loading Q: {loadpath}') 13 | observations = env.observation_space.sample()[np.newaxis] 14 | actions = env.action_space.sample()[np.newaxis] 15 | 16 | rng = jax.random.PRNGKey(seed) 17 | rng, key = jax.random.split(rng) 18 | 19 | critic_def = DoubleCritic(hidden_dims) 20 | critic = Model.create(critic_def, 21 | inputs=[key, observations, actions]) 22 | 23 | ## allows for relative paths 24 | loadpath = os.path.expanduser(loadpath) 25 | critic = critic.load(loadpath) 26 | return critic 27 | 28 | class JaxWrapper: 29 | 30 | def __init__(self, env, loadpath, *args, **kwargs): 31 | self.model = load_q(env, loadpath) 32 | 33 | @functools.partial(jax.jit, static_argnames=('self'), device=jax.devices('cpu')[0]) 34 | def forward(self, xs): 35 | Qs = self.model(*xs) 36 | Q = jnp.minimum(*Qs) 37 | return Q 38 | 39 | def __call__(self, *xs): 40 | Q = self.forward(xs) 41 | return np.array(Q) 42 | -------------------------------------------------------------------------------- 
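A brief usage sketch of the `Config` wrapper above: it resolves a class path with `import_class`, stores the remaining keyword arguments, and instantiates the class lazily when called. The class path and keyword arguments below are placeholders, assuming the `TemporalUnet` exported from `diffuser/models`.

```python
from diffuser.utils.config import Config

model_config = Config(
    'models.TemporalUnet',   # resolved to diffuser.models.TemporalUnet by import_class
    horizon=64,              # placeholder kwargs, stored in the config and printed when verbose
    transition_dim=23,
    device='cuda:0',
)
model = model_config()       # __call__ builds TemporalUnet(horizon=64, transition_dim=23) and moves it to cuda:0
```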
/skilldiffuser/hrl/diffuser/utils/progress.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import pdb 4 | 5 | class Progress: 6 | 7 | def __init__(self, total, name = 'Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 8 | self.total = total 9 | self.name = name 10 | self.ncol = ncol 11 | self.max_length = max_length 12 | self.indent = indent 13 | self.line_width = line_width 14 | self._speed_update_freq = speed_update_freq 15 | 16 | self._step = 0 17 | self._prev_line = '\033[F' 18 | self._clear_line = ' ' * self.line_width 19 | 20 | self._pbar_size = self.ncol * self.max_length 21 | self._complete_pbar = '#' * self._pbar_size 22 | self._incomplete_pbar = ' ' * self._pbar_size 23 | 24 | self.lines = [''] 25 | self.fraction = '{} / {}'.format(0, self.total) 26 | 27 | self.resume() 28 | 29 | 30 | def update(self, description, n=1): 31 | self._step += n 32 | if self._step % self._speed_update_freq == 0: 33 | self._time0 = time.time() 34 | self._step0 = self._step 35 | self.set_description(description) 36 | 37 | def resume(self): 38 | self._skip_lines = 1 39 | print('\n', end='') 40 | self._time0 = time.time() 41 | self._step0 = self._step 42 | 43 | def pause(self): 44 | self._clear() 45 | self._skip_lines = 1 46 | 47 | def set_description(self, params=[]): 48 | 49 | if type(params) == dict: 50 | params = sorted([ 51 | (key, val) 52 | for key, val in params.items() 53 | ]) 54 | 55 | ############ 56 | # Position # 57 | ############ 58 | self._clear() 59 | 60 | ########### 61 | # Percent # 62 | ########### 63 | percent, fraction = self._format_percent(self._step, self.total) 64 | self.fraction = fraction 65 | 66 | ######### 67 | # Speed # 68 | ######### 69 | speed = self._format_speed(self._step) 70 | 71 | ########## 72 | # Params # 73 | ########## 74 | num_params = len(params) 75 | nrow = math.ceil(num_params / self.ncol) 76 | params_split = self._chunk(params, self.ncol) 77 | params_string, lines = self._format(params_split) 78 | self.lines = lines 79 | 80 | 81 | description = '{} | {}{}'.format(percent, speed, params_string) 82 | print(description) 83 | self._skip_lines = nrow + 1 84 | 85 | def append_description(self, descr): 86 | self.lines.append(descr) 87 | 88 | def _clear(self): 89 | position = self._prev_line * self._skip_lines 90 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 91 | print(position, end='') 92 | print(empty) 93 | print(position, end='') 94 | 95 | def _format_percent(self, n, total): 96 | if total: 97 | percent = n / float(total) 98 | 99 | complete_entries = int(percent * self._pbar_size) 100 | incomplete_entries = self._pbar_size - complete_entries 101 | 102 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 103 | fraction = '{} / {}'.format(n, total) 104 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 105 | else: 106 | fraction = '{}'.format(n) 107 | string = '{} iterations'.format(n) 108 | return string, fraction 109 | 110 | def _format_speed(self, n): 111 | num_steps = n - self._step0 112 | t = time.time() - self._time0 113 | speed = num_steps / t 114 | string = '{:.1f} Hz'.format(speed) 115 | if num_steps > 0: 116 | self._speed = string 117 | return string 118 | 119 | def _chunk(self, l, n): 120 | return [l[i:i+n] for i in range(0, len(l), n)] 121 | 122 | def _format(self, chunks): 123 | lines = [self._format_chunk(chunk) for chunk in chunks] 124 | lines.insert(0,'') 125 | 
padding = '\n' + ' '*self.indent 126 | string = padding.join(lines) 127 | return string, lines 128 | 129 | def _format_chunk(self, chunk): 130 | line = ' | '.join([self._format_param(param) for param in chunk]) 131 | return line 132 | 133 | def _format_param(self, param): 134 | k, v = param 135 | return '{} : {}'.format(k, v)[:self.max_length] 136 | 137 | def stamp(self): 138 | if self.lines != ['']: 139 | params = ' | '.join(self.lines) 140 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 141 | self._clear() 142 | print(string, end='\n') 143 | self._skip_lines = 1 144 | else: 145 | self._clear() 146 | self._skip_lines = 0 147 | 148 | def close(self): 149 | self.pause() 150 | 151 | class Silent: 152 | 153 | def __init__(self, *args, **kwargs): 154 | pass 155 | 156 | def __getattr__(self, attr): 157 | return lambda *args: None 158 | 159 | 160 | if __name__ == '__main__': 161 | silent = Silent() 162 | silent.update() 163 | silent.stamp() 164 | 165 | num_steps = 1000 166 | progress = Progress(num_steps) 167 | for i in range(num_steps): 168 | progress.update() 169 | params = [ 170 | ['A', '{:06d}'.format(i)], 171 | ['B', '{:06d}'.format(i)], 172 | ['C', '{:06d}'.format(i)], 173 | ['D', '{:06d}'.format(i)], 174 | ['E', '{:06d}'.format(i)], 175 | ['F', '{:06d}'.format(i)], 176 | ['G', '{:06d}'.format(i)], 177 | ['H', '{:06d}'.format(i)], 178 | ] 179 | progress.set_description(params) 180 | time.sleep(0.01) 181 | progress.close() 182 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/rendering.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import einops 4 | import imageio 5 | import matplotlib.pyplot as plt 6 | from matplotlib.colors import ListedColormap 7 | import gym 8 | import mujoco_py as mjc 9 | import warnings 10 | import pdb 11 | 12 | from .arrays import to_np 13 | from .video import save_video, save_videos 14 | from ml_logger import logger 15 | 16 | from diffuser.datasets.d4rl import load_environment 17 | 18 | #-----------------------------------------------------------------------------# 19 | #------------------------------- helper structs ------------------------------# 20 | #-----------------------------------------------------------------------------# 21 | 22 | def env_map(env_name): 23 | ''' 24 | map D4RL dataset names to custom fully-observed 25 | variants for rendering 26 | ''' 27 | if 'halfcheetah' in env_name: 28 | return 'HalfCheetahFullObs-v2' 29 | elif 'hopper' in env_name: 30 | return 'HopperFullObs-v2' 31 | elif 'walker2d' in env_name: 32 | return 'Walker2dFullObs-v2' 33 | else: 34 | return env_name 35 | 36 | #-----------------------------------------------------------------------------# 37 | #------------------------------ helper functions -----------------------------# 38 | #-----------------------------------------------------------------------------# 39 | 40 | def get_image_mask(img): 41 | background = (img == 255).all(axis=-1, keepdims=True) 42 | mask = ~background.repeat(3, axis=-1) 43 | return mask 44 | 45 | def atmost_2d(x): 46 | while x.ndim > 2: 47 | x = x.squeeze(0) 48 | return x 49 | 50 | #-----------------------------------------------------------------------------# 51 | #---------------------------------- renderers --------------------------------# 52 | #-----------------------------------------------------------------------------# 53 | 54 | class MuJoCoRenderer: 55 | ''' 56 | 
default mujoco renderer 57 | ''' 58 | 59 | def __init__(self, env): 60 | if type(env) is str: 61 | env = env_map(env) 62 | self.env = gym.make(env) 63 | else: 64 | self.env = env 65 | ## - 1 because the envs in renderer are fully-observed 66 | ## @TODO : clean up 67 | self.observation_dim = np.prod(self.env.observation_space.shape) - 1 68 | self.action_dim = np.prod(self.env.action_space.shape) 69 | try: 70 | self.viewer = mjc.MjRenderContextOffscreen(self.env.sim) 71 | except: 72 | print('[ utils/rendering ] Warning: could not initialize offscreen renderer') 73 | self.viewer = None 74 | 75 | def pad_observation(self, observation): 76 | state = np.concatenate([ 77 | np.zeros(1), 78 | observation, 79 | ]) 80 | return state 81 | 82 | def pad_observations(self, observations): 83 | qpos_dim = self.env.sim.data.qpos.size 84 | ## xpos is hidden 85 | xvel_dim = qpos_dim - 1 86 | xvel = observations[:, xvel_dim] 87 | xpos = np.cumsum(xvel) * self.env.dt 88 | states = np.concatenate([ 89 | xpos[:,None], 90 | observations, 91 | ], axis=-1) 92 | return states 93 | 94 | def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None): 95 | 96 | if type(dim) == int: 97 | dim = (dim, dim) 98 | 99 | if self.viewer is None: 100 | return np.zeros((*dim, 3), np.uint8) 101 | 102 | if render_kwargs is None: 103 | xpos = observation[0] if not partial else 0 104 | render_kwargs = { 105 | 'trackbodyid': 2, 106 | 'distance': 3, 107 | 'lookat': [xpos, -0.5, 1], 108 | 'elevation': -20 109 | } 110 | 111 | for key, val in render_kwargs.items(): 112 | if key == 'lookat': 113 | self.viewer.cam.lookat[:] = val[:] 114 | else: 115 | setattr(self.viewer.cam, key, val) 116 | 117 | if partial: 118 | state = self.pad_observation(observation) 119 | else: 120 | state = observation 121 | 122 | qpos_dim = self.env.sim.data.qpos.size 123 | if not qvel or state.shape[-1] == qpos_dim: 124 | qvel_dim = self.env.sim.data.qvel.size 125 | state = np.concatenate([state, np.zeros(qvel_dim)]) 126 | 127 | set_state(self.env, state) 128 | 129 | self.viewer.render(*dim) 130 | data = self.viewer.read_pixels(*dim, depth=False) 131 | data = data[::-1, :, :] 132 | return data 133 | 134 | def _renders(self, observations, **kwargs): 135 | images = [] 136 | for observation in observations: 137 | img = self.render(observation, **kwargs) 138 | images.append(img) 139 | return np.stack(images, axis=0) 140 | 141 | def renders(self, samples, partial=False, **kwargs): 142 | if partial: 143 | samples = self.pad_observations(samples) 144 | partial = False 145 | 146 | sample_images = self._renders(samples, partial=partial, **kwargs) 147 | 148 | composite = np.ones_like(sample_images[0]) * 255 149 | 150 | for img in sample_images: 151 | mask = get_image_mask(img) 152 | composite[mask] = img[mask] 153 | 154 | return composite 155 | 156 | def composite(self, savepath, paths, dim=(1024, 256), **kwargs): 157 | 158 | render_kwargs = { 159 | 'trackbodyid': 2, 160 | 'distance': 10, 161 | 'lookat': [5, 2, 0.5], 162 | 'elevation': 0 163 | } 164 | images = [] 165 | for path in paths: 166 | ## [ H x obs_dim ] 167 | path = atmost_2d(path) 168 | img = self.renders(to_np(path), dim=dim, partial=True, qvel=True, render_kwargs=render_kwargs, **kwargs) 169 | images.append(img) 170 | images = np.concatenate(images, axis=0) 171 | 172 | if savepath is not None: 173 | fig = plt.figure() 174 | plt.imshow(images) 175 | logger.savefig(savepath, fig) 176 | print(f'Saved {len(paths)} samples to: {savepath}') 177 | 178 | return images 179 | 180 | def 
render_rollout(self, savepath, states, **video_kwargs): 181 | if type(states) is list: states = np.array(states) 182 | images = self._renders(states, partial=True) 183 | save_video(savepath, images, **video_kwargs) 184 | 185 | def render_plan(self, savepath, actions, observations_pred, state, fps=30): 186 | ## [ batch_size x horizon x observation_dim ] 187 | observations_real = rollouts_from_state(self.env, state, actions) 188 | 189 | ## there will be one more state in `observations_real` 190 | ## than in `observations_pred` because the last action 191 | ## does not have an associated next_state in the sampled trajectory 192 | observations_real = observations_real[:,:-1] 193 | 194 | images_pred = np.stack([ 195 | self._renders(obs_pred, partial=True) 196 | for obs_pred in observations_pred 197 | ]) 198 | 199 | images_real = np.stack([ 200 | self._renders(obs_real, partial=False) 201 | for obs_real in observations_real 202 | ]) 203 | 204 | ## [ batch_size x horizon x H x W x C ] 205 | images = np.concatenate([images_pred, images_real], axis=-2) 206 | save_videos(savepath, *images) 207 | 208 | def render_diffusion(self, savepath, diffusion_path, **video_kwargs): 209 | ''' 210 | diffusion_path : [ n_diffusion_steps x batch_size x 1 x horizon x joined_dim ] 211 | ''' 212 | render_kwargs = { 213 | 'trackbodyid': 2, 214 | 'distance': 10, 215 | 'lookat': [10, 2, 0.5], 216 | 'elevation': 0, 217 | } 218 | 219 | diffusion_path = to_np(diffusion_path) 220 | 221 | n_diffusion_steps, batch_size, _, horizon, joined_dim = diffusion_path.shape 222 | 223 | frames = [] 224 | for t in reversed(range(n_diffusion_steps)): 225 | print(f'[ utils/renderer ] Diffusion: {t} / {n_diffusion_steps}') 226 | 227 | ## [ batch_size x horizon x observation_dim ] 228 | states_l = diffusion_path[t].reshape(batch_size, horizon, joined_dim)[:, :, :self.observation_dim] 229 | 230 | frame = [] 231 | for states in states_l: 232 | img = self.composite(None, states, dim=(1024, 256), partial=True, qvel=True, render_kwargs=render_kwargs) 233 | frame.append(img) 234 | frame = np.concatenate(frame, axis=0) 235 | 236 | frames.append(frame) 237 | 238 | save_video(savepath, frames, **video_kwargs) 239 | 240 | def __call__(self, *args, **kwargs): 241 | return self.renders(*args, **kwargs) 242 | 243 | #-----------------------------------------------------------------------------# 244 | #---------------------------------- rollouts ---------------------------------# 245 | #-----------------------------------------------------------------------------# 246 | 247 | def set_state(env, state): 248 | qpos_dim = env.sim.data.qpos.size 249 | qvel_dim = env.sim.data.qvel.size 250 | if not state.size == qpos_dim + qvel_dim: 251 | warnings.warn( 252 | f'[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, ' 253 | f'but got state of size {state.size}') 254 | state = state[:qpos_dim + qvel_dim] 255 | 256 | env.set_state(state[:qpos_dim], state[qpos_dim:]) 257 | 258 | def rollouts_from_state(env, state, actions_l): 259 | rollouts = np.stack([ 260 | rollout_from_state(env, state, actions) 261 | for actions in actions_l 262 | ]) 263 | return rollouts 264 | 265 | def rollout_from_state(env, state, actions): 266 | qpos_dim = env.sim.data.qpos.size 267 | env.set_state(state[:qpos_dim], state[qpos_dim:]) 268 | observations = [env._get_obs()] 269 | for act in actions: 270 | obs, rew, term, _ = env.step(act) 271 | observations.append(obs) 272 | if term: 273 | break 274 | for i in range(len(observations), len(actions)+1): 275 | ## if terminated early, 
pad with zeros 276 | observations.append( np.zeros(obs.size) ) 277 | return np.stack(observations) 278 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/serialization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import glob 4 | import torch 5 | import pdb 6 | 7 | from collections import namedtuple 8 | 9 | DiffusionExperiment = namedtuple('Diffusion', 'dataset renderer model diffusion ema trainer epoch') 10 | 11 | def mkdir(savepath): 12 | """ 13 | returns `True` iff `savepath` is created 14 | """ 15 | if not os.path.exists(savepath): 16 | os.makedirs(savepath) 17 | return True 18 | else: 19 | return False 20 | 21 | def get_latest_epoch(loadpath): 22 | states = glob.glob1(os.path.join(*loadpath), 'state_*') 23 | latest_epoch = -1 24 | for state in states: 25 | epoch = int(state.replace('state_', '').replace('.pt', '')) 26 | latest_epoch = max(epoch, latest_epoch) 27 | return latest_epoch 28 | 29 | def load_config(*loadpath): 30 | loadpath = os.path.join(*loadpath) 31 | config = pickle.load(open(loadpath, 'rb')) 32 | print(f'[ utils/serialization ] Loaded config from {loadpath}') 33 | print(config) 34 | return config 35 | 36 | def load_diffusion(*loadpath, epoch='latest', device='cuda:0'): 37 | dataset_config = load_config(*loadpath, 'dataset_config.pkl') 38 | render_config = load_config(*loadpath, 'render_config.pkl') 39 | model_config = load_config(*loadpath, 'model_config.pkl') 40 | diffusion_config = load_config(*loadpath, 'diffusion_config.pkl') 41 | trainer_config = load_config(*loadpath, 'trainer_config.pkl') 42 | 43 | ## remove absolute path for results loaded from azure 44 | ## @TODO : remove results folder from within trainer class 45 | trainer_config._dict['results_folder'] = os.path.join(*loadpath) 46 | 47 | dataset = dataset_config() 48 | renderer = render_config() 49 | model = model_config() 50 | diffusion = diffusion_config(model) 51 | trainer = trainer_config(diffusion, dataset, renderer) 52 | 53 | if epoch == 'latest': 54 | epoch = get_latest_epoch(loadpath) 55 | 56 | print(f'\n[ utils/serialization ] Loading model epoch: {epoch}\n') 57 | 58 | trainer.load(epoch) 59 | 60 | return DiffusionExperiment(dataset, renderer, model, diffusion, trainer.ema_model, trainer, epoch) 61 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import random 4 | import numpy as np 5 | import torch 6 | from tap import Tap 7 | import pdb 8 | 9 | from .serialization import mkdir 10 | from .git_utils import ( 11 | get_git_rev, 12 | save_git_diff, 13 | ) 14 | 15 | def set_seed(seed): 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | 21 | def watch(args_to_watch): 22 | def _fn(args): 23 | exp_name = [] 24 | for key, label in args_to_watch: 25 | if not hasattr(args, key): 26 | continue 27 | val = getattr(args, key) 28 | if type(val) == dict: 29 | val = '_'.join(f'{k}-{v}' for k, v in val.items()) 30 | exp_name.append(f'{label}{val}') 31 | exp_name = '_'.join(exp_name) 32 | exp_name = exp_name.replace('/_', '/') 33 | exp_name = exp_name.replace('(', '').replace(')', '') 34 | exp_name = exp_name.replace(', ', '-') 35 | return exp_name 36 | return _fn 37 | 38 | def lazy_fstring(template, args): 39 | ## 
https://stackoverflow.com/a/53671539 40 | return eval(f"f'{template}'") 41 | 42 | class Parser(Tap): 43 | 44 | def save(self): 45 | fullpath = os.path.join(self.savepath, 'args.json') 46 | print(f'[ utils/setup ] Saved args to {fullpath}') 47 | super().save(fullpath, skip_unpicklable=True) 48 | 49 | def parse_args(self, experiment=None): 50 | args = super().parse_args(known_only=True) 51 | ## if not loading from a config script, skip the result of the setup 52 | if not hasattr(args, 'config'): return args 53 | args = self.read_config(args, experiment) 54 | self.add_extras(args) 55 | self.eval_fstrings(args) 56 | self.set_seed(args) 57 | self.get_commit(args) 58 | self.generate_exp_name(args) 59 | self.mkdir(args) 60 | self.save_diff(args) 61 | return args 62 | 63 | def read_config(self, args, experiment): 64 | ''' 65 | Load parameters from config file 66 | ''' 67 | dataset = args.dataset.replace('-', '_') 68 | print(f'[ utils/setup ] Reading config: {args.config}:{dataset}') 69 | module = importlib.import_module(args.config) 70 | params = getattr(module, 'base')[experiment] 71 | 72 | if hasattr(module, dataset) and experiment in getattr(module, dataset): 73 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}') 74 | overrides = getattr(module, dataset)[experiment] 75 | params.update(overrides) 76 | else: 77 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}') 78 | 79 | self._dict = {} 80 | for key, val in params.items(): 81 | setattr(args, key, val) 82 | self._dict[key] = val 83 | 84 | return args 85 | 86 | def add_extras(self, args): 87 | ''' 88 | Override config parameters with command-line arguments 89 | ''' 90 | extras = args.extra_args 91 | if not len(extras): 92 | return 93 | 94 | print(f'[ utils/setup ] Found extras: {extras}') 95 | assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}' 96 | for i in range(0, len(extras), 2): 97 | key = extras[i].replace('--', '') 98 | val = extras[i+1] 99 | assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}' 100 | old_val = getattr(args, key) 101 | old_type = type(old_val) 102 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}') 103 | if val == 'None': 104 | val = None 105 | elif val == 'latest': 106 | val = 'latest' 107 | elif old_type in [bool, type(None)]: 108 | try: 109 | val = eval(val) 110 | except: 111 | print(f'[ utils/setup ] Warning: could not parse {val} (old: {old_val}, {old_type}), using str') 112 | else: 113 | val = old_type(val) 114 | setattr(args, key, val) 115 | self._dict[key] = val 116 | 117 | def eval_fstrings(self, args): 118 | for key, old in self._dict.items(): 119 | if type(old) is str and old[:2] == 'f:': 120 | val = old.replace('{', '{args.').replace('f:', '') 121 | new = lazy_fstring(val, args) 122 | print(f'[ utils/setup ] Lazy fstring | {key} : {old} --> {new}') 123 | setattr(self, key, new) 124 | self._dict[key] = new 125 | 126 | def set_seed(self, args): 127 | if not 'seed' in dir(args): 128 | return 129 | print(f'[ utils/setup ] Setting seed: {args.seed}') 130 | set_seed(args.seed) 131 | 132 | def generate_exp_name(self, args): 133 | if not 'exp_name' in dir(args): 134 | return 135 | exp_name = getattr(args, 'exp_name') 136 | if callable(exp_name): 137 | exp_name_string = exp_name(args) 138 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}') 139 | setattr(args, 'exp_name', exp_name_string) 140 | self._dict['exp_name'] = exp_name_string 141 | 
142 | def mkdir(self, args): 143 | if 'logbase' in dir(args) and 'dataset' in dir(args) and 'exp_name' in dir(args): 144 | args.savepath = os.path.join(args.logbase, args.dataset, args.exp_name) 145 | self._dict['savepath'] = args.savepath 146 | if 'suffix' in dir(args): 147 | args.savepath = os.path.join(args.savepath, args.suffix) 148 | if mkdir(args.savepath): 149 | print(f'[ utils/setup ] Made savepath: {args.savepath}') 150 | self.save() 151 | 152 | def get_commit(self, args): 153 | args.commit = get_git_rev() 154 | 155 | def save_diff(self, args): 156 | try: 157 | save_git_diff(os.path.join(args.savepath, 'diff.txt')) 158 | except: 159 | print('[ utils/setup ] WARNING: did not save git diff') 160 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | class Timer: 4 | 5 | def __init__(self): 6 | self._start = time.time() 7 | 8 | def __call__(self, reset=True): 9 | now = time.time() 10 | diff = now - self._start 11 | if reset: 12 | self._start = now 13 | return diff -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/training.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import torch 5 | import einops 6 | # import pdb 7 | import diffuser 8 | from copy import deepcopy 9 | 10 | from .arrays import batch_to_device, to_np, to_device, apply_dict 11 | from .timer import Timer 12 | # from .cloud import sync_logs 13 | from ml_logger import logger 14 | 15 | def cycle(dl): 16 | while True: 17 | for data in dl: 18 | yield data 19 | 20 | class EMA(): 21 | ''' 22 | exponential moving average 23 | ''' 24 | def __init__(self, beta): 25 | super().__init__() 26 | self.beta = beta 27 | 28 | def update_model_average(self, ma_model, current_model): 29 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 30 | old_weight, up_weight = ma_params.data, current_params.data 31 | ma_params.data = self.update_average(old_weight, up_weight) 32 | 33 | def update_average(self, old, new): 34 | if old is None: 35 | return new 36 | return old * self.beta + (1 - self.beta) * new 37 | 38 | class Trainer(object): 39 | def __init__( 40 | self, 41 | diffusion_model, 42 | dataset, 43 | renderer, 44 | ema_decay=0.995, 45 | train_batch_size=32, 46 | train_lr=2e-5, 47 | gradient_accumulate_every=2, 48 | step_start_ema=2000, 49 | update_ema_every=10, 50 | log_freq=100, 51 | sample_freq=1000, 52 | save_freq=1000, 53 | label_freq=100000, 54 | save_parallel=False, 55 | n_reference=8, 56 | bucket=None, 57 | train_device='cuda', 58 | save_checkpoints=False, 59 | ): 60 | super().__init__() 61 | self.model = diffusion_model 62 | self.ema = EMA(ema_decay) 63 | self.ema_model = copy.deepcopy(self.model) 64 | self.update_ema_every = update_ema_every 65 | self.save_checkpoints = save_checkpoints 66 | 67 | self.step_start_ema = step_start_ema 68 | self.log_freq = log_freq 69 | self.sample_freq = sample_freq 70 | self.save_freq = save_freq 71 | self.label_freq = label_freq 72 | self.save_parallel = save_parallel 73 | 74 | self.batch_size = train_batch_size 75 | self.gradient_accumulate_every = gradient_accumulate_every 76 | 77 | self.dataset = dataset 78 | 79 | self.dataloader = cycle(torch.utils.data.DataLoader( 80 | self.dataset, batch_size=train_batch_size, num_workers=0,
shuffle=True, pin_memory=True 81 | )) 82 | self.dataloader_vis = cycle(torch.utils.data.DataLoader( 83 | self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True 84 | )) 85 | self.renderer = renderer 86 | self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr) 87 | 88 | self.bucket = bucket 89 | self.n_reference = n_reference 90 | 91 | self.reset_parameters() 92 | self.step = 0 93 | 94 | self.device = train_device 95 | 96 | def reset_parameters(self): 97 | self.ema_model.load_state_dict(self.model.state_dict()) 98 | 99 | def step_ema(self): 100 | if self.step < self.step_start_ema: 101 | self.reset_parameters() 102 | return 103 | self.ema.update_model_average(self.ema_model, self.model) 104 | 105 | #-----------------------------------------------------------------------------# 106 | #------------------------------------ api ------------------------------------# 107 | #-----------------------------------------------------------------------------# 108 | 109 | def train(self, n_train_steps): 110 | 111 | timer = Timer() 112 | for step in range(n_train_steps): 113 | for i in range(self.gradient_accumulate_every): 114 | batch = next(self.dataloader) 115 | batch = batch_to_device(batch, device=self.device) 116 | loss, infos = self.model.loss(*batch) 117 | loss = loss / self.gradient_accumulate_every 118 | loss.backward() 119 | 120 | self.optimizer.step() 121 | self.optimizer.zero_grad() 122 | 123 | if self.step % self.update_ema_every == 0: 124 | self.step_ema() 125 | 126 | if self.step % self.save_freq == 0: 127 | self.save() 128 | 129 | if self.step % self.log_freq == 0: 130 | infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()]) 131 | logger.print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}') 132 | metrics = {k:v.detach().item() for k, v in infos.items()} 133 | metrics['steps'] = self.step 134 | metrics['loss'] = loss.detach().item() 135 | logger.log_metrics_summary(metrics, default_stats='mean') 136 | 137 | if self.step == 0 and self.sample_freq: 138 | self.render_reference(self.n_reference) 139 | 140 | if self.sample_freq and self.step % self.sample_freq == 0: 141 | if self.model.__class__ == diffuser.models.diffusion.GaussianInvDynDiffusion: 142 | self.inv_render_samples() 143 | elif self.model.__class__ == diffuser.models.diffusion.ActionGaussianDiffusion: 144 | pass 145 | else: 146 | self.render_samples() 147 | 148 | self.step += 1 149 | 150 | def save(self): 151 | ''' 152 | saves model and ema to disk; 153 | syncs to storage bucket if a bucket is specified 154 | ''' 155 | data = { 156 | 'step': self.step, 157 | 'model': self.model.state_dict(), 158 | 'ema': self.ema_model.state_dict() 159 | } 160 | savepath = os.path.join(self.bucket, logger.prefix, 'checkpoint') 161 | os.makedirs(savepath, exist_ok=True) 162 | # logger.save_torch(data, savepath) 163 | if self.save_checkpoints: 164 | savepath = os.path.join(savepath, f'state_{self.step}.pt') 165 | else: 166 | savepath = os.path.join(savepath, 'state.pt') 167 | torch.save(data, savepath) 168 | logger.print(f'[ utils/training ] Saved model to {savepath}') 169 | 170 | def load(self): 171 | ''' 172 | loads model and ema from disk 173 | ''' 174 | loadpath = os.path.join(self.bucket, logger.prefix, f'checkpoint/state.pt') 175 | # data = logger.load_torch(loadpath) 176 | data = torch.load(loadpath) 177 | 178 | self.step = data['step'] 179 | self.model.load_state_dict(data['model']) 180 | self.ema_model.load_state_dict(data['ema']) 181 | 182 | 
#-----------------------------------------------------------------------------# 183 | #--------------------------------- rendering ---------------------------------# 184 | #-----------------------------------------------------------------------------# 185 | 186 | def render_reference(self, batch_size=10): 187 | ''' 188 | renders training points 189 | ''' 190 | 191 | ## get a temporary dataloader to load a single batch 192 | dataloader_tmp = cycle(torch.utils.data.DataLoader( 193 | self.dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True 194 | )) 195 | batch = dataloader_tmp.__next__() 196 | dataloader_tmp.close() 197 | 198 | ## get trajectories and condition at t=0 from batch 199 | trajectories = to_np(batch.trajectories) 200 | conditions = to_np(batch.conditions[0])[:,None] 201 | 202 | ## [ batch_size x horizon x observation_dim ] 203 | normed_observations = trajectories[:, :, self.dataset.action_dim:] 204 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') 205 | 206 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat 207 | # # observations = conditions + blocks_cumsum_quat(deltas) 208 | # observations = conditions + deltas.cumsum(axis=1) 209 | 210 | #### @TODO: remove block-stacking specific stuff 211 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka 212 | # observations = blocks_add_kuka(observations) 213 | #### 214 | 215 | savepath = os.path.join('images', f'sample-reference.png') 216 | self.renderer.composite(savepath, observations) 217 | 218 | def render_samples(self, batch_size=2, n_samples=2): 219 | ''' 220 | renders samples from (ema) diffusion model 221 | ''' 222 | for i in range(batch_size): 223 | 224 | ## get a single datapoint 225 | batch = self.dataloader_vis.__next__() 226 | conditions = to_device(batch.conditions, self.device) 227 | ## repeat each item in conditions `n_samples` times 228 | conditions = apply_dict( 229 | einops.repeat, 230 | conditions, 231 | 'b d -> (repeat b) d', repeat=n_samples, 232 | ) 233 | 234 | ## [ n_samples x horizon x (action_dim + observation_dim) ] 235 | if self.ema_model.returns_condition: 236 | returns = to_device(torch.ones(n_samples, 1), self.device) 237 | else: 238 | returns = None 239 | 240 | if self.ema_model.model.calc_energy: 241 | samples = self.ema_model.grad_conditional_sample(conditions, returns=returns) 242 | else: 243 | samples = self.ema_model.conditional_sample(conditions, returns=returns) 244 | 245 | samples = to_np(samples) 246 | 247 | ## [ n_samples x horizon x observation_dim ] 248 | normed_observations = samples[:, :, self.dataset.action_dim:] 249 | 250 | # [ 1 x 1 x observation_dim ] 251 | normed_conditions = to_np(batch.conditions[0])[:,None] 252 | 253 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat 254 | # observations = conditions + blocks_cumsum_quat(deltas) 255 | # observations = conditions + deltas.cumsum(axis=1) 256 | 257 | ## [ n_samples x (horizon + 1) x observation_dim ] 258 | normed_observations = np.concatenate([ 259 | np.repeat(normed_conditions, n_samples, axis=0), 260 | normed_observations 261 | ], axis=1) 262 | 263 | ## [ n_samples x (horizon + 1) x observation_dim ] 264 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') 265 | 266 | #### @TODO: remove block-stacking specific stuff 267 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka 268 | # observations = blocks_add_kuka(observations) 269 | #### 
270 | 271 | savepath = os.path.join('images', f'sample-{i}.png') 272 | self.renderer.composite(savepath, observations) 273 | 274 | def inv_render_samples(self, batch_size=2, n_samples=2): 275 | ''' 276 | renders samples from (ema) diffusion model 277 | ''' 278 | for i in range(batch_size): 279 | 280 | ## get a single datapoint 281 | batch = self.dataloader_vis.__next__() 282 | conditions = to_device(batch.conditions, self.device) 283 | ## repeat each item in conditions `n_samples` times 284 | conditions = apply_dict( 285 | einops.repeat, 286 | conditions, 287 | 'b d -> (repeat b) d', repeat=n_samples, 288 | ) 289 | 290 | ## [ n_samples x horizon x (action_dim + observation_dim) ] 291 | if self.ema_model.returns_condition: 292 | returns = to_device(torch.ones(n_samples, 1), self.device) 293 | else: 294 | returns = None 295 | 296 | if self.ema_model.model.calc_energy: 297 | samples = self.ema_model.grad_conditional_sample(conditions, returns=returns) 298 | else: 299 | samples = self.ema_model.conditional_sample(conditions, returns=returns) 300 | 301 | samples = to_np(samples) 302 | 303 | ## [ n_samples x horizon x observation_dim ] 304 | normed_observations = samples[:, :, :] 305 | 306 | # [ 1 x 1 x observation_dim ] 307 | normed_conditions = to_np(batch.conditions[0])[:,None] 308 | 309 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat 310 | # observations = conditions + blocks_cumsum_quat(deltas) 311 | # observations = conditions + deltas.cumsum(axis=1) 312 | 313 | ## [ n_samples x (horizon + 1) x observation_dim ] 314 | normed_observations = np.concatenate([ 315 | np.repeat(normed_conditions, n_samples, axis=0), 316 | normed_observations 317 | ], axis=1) 318 | 319 | ## [ n_samples x (horizon + 1) x observation_dim ] 320 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations') 321 | 322 | #### @TODO: remove block-stacking specific stuff 323 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka 324 | # observations = blocks_add_kuka(observations) 325 | #### 326 | 327 | savepath = os.path.join('images', f'sample-{i}.png') 328 | self.renderer.composite(savepath, observations) -------------------------------------------------------------------------------- /skilldiffuser/hrl/diffuser/utils/video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import skvideo.io 4 | from ml_logger import logger 5 | 6 | def _make_dir(filename): 7 | folder = os.path.dirname(filename) 8 | if not os.path.exists(folder): 9 | os.makedirs(folder) 10 | 11 | def save_video(filename, video_frames, fps=60, video_format='mp4'): 12 | assert fps == int(fps), fps 13 | # logger.save_video(video_frames, filename, fps=fps, format=video_format) 14 | _make_dir(filename) 15 | 16 | skvideo.io.vwrite( 17 | filename, 18 | video_frames, 19 | inputdict={ 20 | '-r': str(int(fps)), 21 | }, 22 | outputdict={ 23 | '-f': video_format, 24 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74 25 | } 26 | ) 27 | 28 | def save_videos(filename, *video_frames, axis=1, **kwargs): 29 | ## video_frame : [ N x H x W x C ] 30 | video_frames = np.concatenate(video_frames, axis=axis) 31 | save_video(filename, video_frames, **kwargs) 32 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/env.py: -------------------------------------------------------------------------------- 1 | import 
numpy as np 2 | import copy 3 | import gym 4 | from PIL import Image 5 | import pdb 6 | 7 | from utils import String, lorl_gt_reward, lorl_save_im 8 | 9 | 10 | class BaseWrapper(gym.Wrapper): 11 | """Base processing wrapper. 12 | 13 | 1) Adds String command to observations 14 | 2) Preprocess states 15 | """ 16 | 17 | def __init__(self, env, dataset): 18 | super(BaseWrapper, self).__init__(env) 19 | self.env = env 20 | self.dataset = dataset 21 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std 22 | 23 | self.state_dim = len(self.state_mean) 24 | 25 | self.observation_space = gym.spaces.Dict( 26 | { 27 | "state": gym.spaces.Box( 28 | low=-np.inf, # 0.0 29 | high=np.inf, # 1.0 30 | shape=(self.state_dim,), 31 | dtype=np.float32, 32 | ), 33 | "lang": String(), 34 | } 35 | ) 36 | 37 | self.act_dim = env.action_space.shape[0] 38 | # print(self.action_space.low, self.action_space.high) 39 | 40 | def reset(self, **kwargs): 41 | obs = self.env.reset() 42 | return self.get_state(obs) 43 | 44 | def step(self, action): 45 | obs, reward, done, info = self.env.step(action) 46 | 47 | if done: 48 | success = 0 49 | if reward > 0: 50 | success = 1 51 | info.update({'success': success}) 52 | 53 | return self.get_state(obs), reward, done, info 54 | 55 | def get_state(self, obs): 56 | """Returns the observation and lang_cmd""" 57 | 58 | lang = "dummy" 59 | cur_state = (obs.reshape(-1) - self.state_mean) / self.state_std 60 | 61 | return {'state': cur_state, 'lang': lang} 62 | 63 | def get_image(self): 64 | return Image.fromarray(self.env.render(), 'RGB') 65 | 66 | 67 | class BabyAIWrapper(gym.Wrapper): 68 | """BabyAI processing wrapper. 69 | 70 | 1) Adds String command to observations 71 | 2) Preprocess states 72 | """ 73 | 74 | def __init__(self, env, dataset): 75 | super(BabyAIWrapper, self).__init__(env) 76 | self.env = env 77 | self.dataset = dataset 78 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std 79 | 80 | self.state_dim = len(self.state_mean) 81 | if self.dataset.kwargs['use_direction']: 82 | self.state_dim += 4 # for the direction part in BabyAI 83 | 84 | self.observation_space = gym.spaces.Dict( 85 | { 86 | "state": gym.spaces.Box( 87 | low=-np.inf, # 0.0 88 | high=np.inf, # 1.0 89 | shape=(self.state_dim,), 90 | dtype=np.float32, 91 | ), 92 | "lang": String(), 93 | } 94 | ) 95 | 96 | self.act_dim = env.action_space.n 97 | 98 | def reset(self, **kwargs): 99 | obs = self.env.reset() 100 | return self.get_state(obs) 101 | 102 | def step(self, action): 103 | obs, reward, done, info = self.env.step(action) 104 | 105 | if done: 106 | success = 0 107 | if reward > 0: 108 | success = 1 109 | info.update({'success': success}) 110 | 111 | return self.get_state(obs), reward, done, info 112 | 113 | def get_state(self, obs): 114 | """Returns the observation and lang_cmd""" 115 | 116 | lang = obs["mission"] 117 | cur_state = (obs["image"].reshape(-1) - self.state_mean) / self.state_std 118 | 119 | if self.dataset.kwargs['use_direction']: 120 | direction = np.zeros(4) 121 | direction[obs["direction"]] = 1. 122 | cur_state = np.concatenate([cur_state, direction]) 123 | return {'state': cur_state, 'lang': lang} 124 | 125 | def get_image(self): 126 | return Image.fromarray(self.env.render(), 'RGB') 127 | 128 | 129 | class LorlWrapper(gym.Wrapper): 130 | """LOReL processing wrapper.
131 | 132 | 1) Adds String command to observations 133 | 2) Preprocess states 134 | """ 135 | 136 | def __init__(self, env, dataset, **kwargs): 137 | super(LorlWrapper, self).__init__(env) 138 | 139 | self.env = env 140 | self.dataset = dataset 141 | self.use_state = dataset.kwargs["use_state"] 142 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std 143 | 144 | self.state_dim = self.state_mean.shape 145 | self.act_dim = env.action_space.shape[0] 146 | 147 | if isinstance(self.state_dim, tuple): 148 | self.observation_space = gym.spaces.Dict( 149 | { 150 | "state": gym.spaces.Box( 151 | low=-np.inf, 152 | high=np.inf, 153 | shape=self.state_dim, 154 | dtype=np.float32, 155 | ), 156 | "lang": String(), 157 | } 158 | ) 159 | else: 160 | self.observation_space = gym.spaces.Dict( 161 | { 162 | "state": gym.spaces.Box( 163 | low=-np.inf, 164 | high=np.inf, 165 | shape=(self.state_dim,), 166 | dtype=np.float32, 167 | ), 168 | "lang": String(), 169 | } 170 | ) 171 | 172 | self.initial_state = None 173 | self.instr = kwargs["instr"] 174 | self.orig_instr = kwargs["orig_instr"] 175 | 176 | def reset(self, render=False, **kwargs): 177 | if render: 178 | render_path, iter_num, i = kwargs['render_path'], kwargs['iter_num'], kwargs["i"] 179 | 180 | env = self.env 181 | orig_instr, instr = self.orig_instr, self.instr 182 | im, _ = env.reset() 183 | 184 | # Initialize state for different tasks 185 | if orig_instr == "open drawer": 186 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0) 187 | elif orig_instr == "close drawer": 188 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 189 | elif orig_instr == "turn faucet right": 190 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 191 | elif orig_instr == "turn faucet left": 192 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 193 | elif orig_instr == "move black mug right": 194 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05) 195 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05) 196 | elif orig_instr == "move white mug down": 197 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05) 198 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05) 199 | # Dont know if the following are correct 200 | elif orig_instr == 'open drawer and move black mug right': 201 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0) 202 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05) 203 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05) 204 | elif orig_instr == 'pull the handle and move black mug down': 205 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0) 206 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05) 207 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05) 208 | elif orig_instr == 'move white mug right': 209 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05) 210 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05) 211 | elif orig_instr == 'move black mug down': 212 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05) 213 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05) 214 | elif orig_instr == 'close drawer and turn faucet right': 215 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 216 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 217 | elif orig_instr == 'close drawer and turn faucet left': 218 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 219 | env.sim.data.qpos[13] = 0 + 
np.random.uniform(-np.pi/5, np.pi/5) 220 | elif orig_instr == 'turn faucet left and move white mug down': 221 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 222 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05) 223 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05) 224 | elif orig_instr == 'turn faucet right and close drawer': 225 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 226 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 227 | elif orig_instr == 'move white mug down and turn faucet left': 228 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 229 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05) 230 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05) 231 | elif orig_instr == 'close the drawer, turn the faucet left and move black mug right': 232 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 233 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 234 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05) 235 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05) 236 | elif instr == "open drawer and turn faucet counterclockwise": 237 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0) 238 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5) 239 | elif instr == "slide the drawer closed and then shift white mug down": 240 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05) 241 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05) 242 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05) 243 | 244 | # if orig_instr == "move white mug down": 245 | # env._reset_hand(pos=[-0.1, 0.55, 0.1]) 246 | # elif orig_instr == "move black mug right": 247 | # env._reset_hand(pos=[-0.1, 0.55, 0.1]) 248 | if "mug" in orig_instr: 249 | env._reset_hand(pos=[-0.1, 0.55, 0.1]) 250 | else: 251 | env._reset_hand(pos=[0, 0.45, 0.1]) 252 | 253 | for _ in range(50): 254 | env.sim.step() 255 | 256 | reset_state = copy.deepcopy(env.sim.data.qpos[:]) 257 | env.sim.data.qpos[:] = reset_state 258 | env.sim.data.qacc[:] = 0 259 | env.sim.data.qvel[:] = 0 260 | env.sim.step() 261 | self.initial_state = copy.deepcopy(env.sim.data.qpos[:]) 262 | 263 | if render: 264 | # Initialize goal image for initial state 265 | if orig_instr == "open drawer": 266 | env.sim.data.qpos[14] = -0.15 267 | elif orig_instr == "close drawer": 268 | env.sim.data.qpos[14] = 0.0 269 | elif orig_instr == "turn faucet right": 270 | env.sim.data.qpos[13] -= np.pi/5 271 | elif orig_instr == "turn faucet left": 272 | env.sim.data.qpos[13] += np.pi/5 273 | elif orig_instr == "move black mug right": 274 | env.sim.data.qpos[11] -= 0.1 275 | elif orig_instr == "move white mug down": 276 | env.sim.data.qpos[10] += 0.1 277 | 278 | env.sim.step() 279 | gim = env._get_obs()[:, :, :3] 280 | 281 | # Reset inital state 282 | env.sim.data.qpos[:] = reset_state 283 | env.sim.data.qacc[:] = 0 284 | env.sim.data.qvel[:] = 0 285 | env.sim.step() 286 | 287 | im = env._get_obs()[:, :, :3] 288 | initim = im 289 | render_path = "." 
290 | lorl_save_im( 291 | (initim * 255.0).astype(np.uint8), 292 | render_path + f"/initialim_{iter_num}_{i}_{instr}.jpg") 293 | lorl_save_im((gim*255.0).astype(np.uint8), render_path+f"gim_{iter_num}_{i}_{instr}.jpg") 294 | 295 | observation = self.get_state(im) 296 | cur_state, lang = observation['state'], observation['lang'] 297 | if self.use_state: 298 | cur_state = (cur_state - self.state_mean) / self.state_std 299 | self.state_dim = len(cur_state) 300 | else: 301 | im = np.moveaxis(im, 2, 0) # make H,W,C to C,H,w 302 | cur_state = (im - self.state_mean) / self.state_std 303 | self.state_dim = cur_state.shape 304 | 305 | return {'state': cur_state, 'lang': lang} 306 | 307 | def step(self, action): 308 | im, _, _, info = self.env.step(action) 309 | dist, s = lorl_gt_reward(self.env.sim.data.qpos[:], self.initial_state, self.orig_instr) 310 | 311 | reward = 0 312 | success = 0 313 | if s: 314 | success = 1 315 | reward = dist 316 | 317 | info.update({'success': success}) 318 | return self.get_state(im), reward, s, info 319 | 320 | def get_state(self, obs): 321 | """Returns the observation and lang_cmd""" 322 | 323 | if self.use_state: 324 | obs = self.env.sim.data.qpos[:] 325 | else: 326 | obs = np.moveaxis(obs, 2, 0) # make H,W,C to C,H,w 327 | 328 | state = (obs - self.state_mean) / self.state_std 329 | return {'state': state, 'lang': self.instr} 330 | 331 | def get_image(self, h=1024, w=1024): 332 | # im = self.env._get_obs() 333 | obs = self.sim.render(h, w, camera_name="cam0") / 255. 334 | im = np.flip(obs, 0).copy() 335 | return (im[:, :, :3]*255.0).astype(np.uint8) 336 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/hrl_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import DistilBertModel 6 | 7 | from iq import IQMixin 8 | from option_selector import OptionSelector 9 | from reconstructors import StateReconstructor, LanguageReconstructor 10 | # from decision_transformer import DecisionTransformer 11 | from dec_encoder import DecEncoder 12 | # from utils import pad 13 | # import diffuser.utils as diffutils 14 | 15 | 16 | class HRLModel(nn.Module, IQMixin): 17 | """Base class containing all the models""" 18 | 19 | def __init__(self, args, option_selector_args, state_reconstructor_args, lang_reconstructor_args, 20 | decision_args, iq_args, diff_trainer, device, horizon=5, K=10, train_lm=True, 21 | method='vanilla', state_reconstruct=False, lang_reconstruct=False, **kwargs): 22 | super().__init__() 23 | 24 | self.args = args 25 | self.horizon = horizon 26 | self.K = K 27 | self.lm = DistilBertModel.from_pretrained('distilbert-base-uncased') 28 | self.train_lm = train_lm # whether to train lm or not 29 | self.device = device 30 | 31 | if train_lm: 32 | self.lm.train() 33 | else: 34 | self.lm.eval() 35 | 36 | self.method = method 37 | self.state_reconstruct = state_reconstruct 38 | self.lang_reconstruct = lang_reconstruct 39 | 40 | self.state_dim = decision_args['state_dim'] 41 | self.action_dim = decision_args['action_dim'] 42 | self.option_dim = decision_args['option_dim'] 43 | 44 | # self.decision_transformer = DecisionTransformer(lang_dim=self.lm.config.dim, **decision_transformer_args) 45 | 46 | if method == 'vanilla': 47 | assert decision_args['use_language'] and not decision_args['use_options'] 48 | else: 49 | assert decision_args['use_options'] and not decision_args['use_language'] 50 | 51 | 
self.option_selector = OptionSelector(lang_dim=self.lm.config.dim, 52 | method=self.method, **option_selector_args) 53 | 54 | if state_reconstruct: 55 | self.state_reconstructor = StateReconstructor(**state_reconstructor_args) 56 | if lang_reconstruct: 57 | self.lang_reconstructor = LanguageReconstructor( 58 | lang_dim=self.lm.config.dim, **lang_reconstructor_args) 59 | 60 | # TODO 61 | self.diffuser = DecEncoder(lang_dim=self.lm.config.dim, K=K, **decision_args) 62 | 63 | self.diff_trainer = diff_trainer 64 | # self.diff_trainer.set_optimizer(self.diffuser) 65 | 66 | # if self.decision_transformer.predict_q: 67 | # # initialize iq mixins 68 | # IQMixin.__init__(self, self.decision_transformer, iq_args, device) 69 | 70 | def forward(self, lm_input_ids, lm_attention_mask, states, actions, timesteps, iter_num, ind, diff_train_num, attention_mask=None): 71 | # pdb.set_trace() 72 | batch_size, traj_len = states.shape[0], states.shape[1] 73 | if not self.train_lm: 74 | with torch.no_grad(): 75 | # (batch_size,num_embeddings,embedding_size) 76 | lm_embeddings = self.lm(lm_input_ids, lm_attention_mask).last_hidden_state 77 | else: 78 | # (batch_size,num_embeddings,embedding_size) 79 | lm_embeddings = self.lm(lm_input_ids, lm_attention_mask).last_hidden_state 80 | cls_embeddings = lm_embeddings[:, 0, :].unsqueeze_(1) 81 | # word_embeddings = lm_embeddings[:, 1:-1, :] # We skip the CLS and SEP tokens. I know there's padding here but we at least always remove the CLS 82 | word_embeddings = lm_embeddings[:, 1:, :] # We skip the CLS tokens 83 | 84 | entropy = None 85 | if self.method == 'vanilla': 86 | raise NotImplementedError 87 | # preds = self.decision_transformer( 88 | # states, actions, timesteps, word_embeddings=word_embeddings, attention_mask=attention_mask) 89 | # 90 | # state_rc_preds = None 91 | # state_rc_targets = None 92 | # 93 | # lang_rc_preds = None 94 | # lang_rc_targets = None 95 | # 96 | # commitment_loss = None 97 | else: 98 | # TODO (important) how does this get padded across batches? some of these horizon states may actually be padding 99 | # we change options only every H states. 
Say this leads to N states 100 | # N selected options 101 | # B, max_length // H, option_dim 102 | if self.method == 'option': 103 | # selected_options, _, commitment_loss = self.option_selector(cls_embeddings, states) 104 | selected_options, _, commitment_loss, entropy, state_embeddings = self.option_selector( 105 | word_embeddings.mean(1, keepdim=True), states) 106 | else: 107 | selected_options, _, commitment_loss, entropy, state_embeddings = self.option_selector( 108 | word_embeddings, states, timesteps, attention_mask) 109 | 110 | # # need to make options same length as states and actions 111 | options = torch.zeros((batch_size, traj_len, selected_options.shape[-1])).to(selected_options.device) 112 | 113 | ### This doesn't really work in making only some messages have gradients and others not having gradients 114 | ### The entire options tensor below has gradients after we do options = selected_options 115 | ### Actually it may only have the gradients related to the selectbackward operation -- unsure 116 | 117 | # Repeated detached options for horizon length 118 | for i in range(selected_options.shape[1]): 119 | options[:, i*self.horizon:(i+1)*self.horizon, :] = selected_options[:, 120 | i, :].unsqueeze(1).clone().detach() 121 | # Make sure to pass gradients for options at each horizon steps 122 | options[:, ::self.horizon, :] = selected_options 123 | 124 | # We reshape sequences to K size sub-sequences, so that the sub-policy only uses the current option 125 | # Here we are choosing K to be horizon 126 | B, L = states.shape[0], states.shape[1] 127 | num_seq = L // self.K # self.K == self.horizon == 8 128 | 129 | # We reshape sequences to K size sub-sequences, so that the sub-policy only uses the current option 130 | # Here we are choosing K to be horizon since it makes sense but technically we can do any K 131 | # This ensures the DT only looks at chunks of size horizon 132 | if isinstance(self.state_dim, tuple): 133 | states = states.reshape(B * num_seq, self.K, *self.state_dim) 134 | else: 135 | states = states.reshape(B * num_seq, self.K, self.state_dim) 136 | state_embeddings = state_embeddings.reshape(B * num_seq, self.K, state_embeddings.shape[2]) 137 | options = options.reshape(B * num_seq, self.K, self.option_dim) 138 | actions = actions.reshape(B * num_seq, self.K, self.action_dim) 139 | # Should these timesteps be 1,2,3,4..H,1,2... or just 1,2,3,4...L? 
Going with 1,2,3,4...L 140 | timesteps = timesteps.reshape(B * num_seq, self.K) 141 | # timesteps = torch.arange(0, self.K).repeat(B * num_seq, 1) 142 | attention_mask = attention_mask.reshape(B * num_seq, self.K) 143 | 144 | # Make sure shapes are okay 145 | assert states.shape[0] == actions.shape[0] == options.shape[0] == timesteps.shape[0] == attention_mask.shape[0] == batch_size * num_seq 146 | assert states.shape[1] == actions.shape[1] == options.shape[1] == timesteps.shape[1] == attention_mask.shape[1] == self.K 147 | # preds = self.decision_transformer( 148 | # states, actions, timesteps, options=options, attention_mask=attention_mask) 149 | 150 | # TODO 151 | encoder_out = self.diffuser.encode(states, actions, timesteps, options=options, 152 | state_embeddings=state_embeddings, attention_mask=attention_mask) 153 | 154 | stacked_inputs = encoder_out['stacked_inputs'].to(self.device) 155 | option_embeddings = encoder_out['option_embeddings'].to(self.device) 156 | conds = {} 157 | returns = option_embeddings # conditions 158 | # num_diff_steps = int(self.args.diffuser.n_train_steps // self.args.max_iters) 159 | for in_step in range(diff_train_num): 160 | # step = iter_num * num_diff_steps + ind * diff_train_num + in_step 161 | self.diff_trainer.train_iteration((stacked_inputs, conds, returns)) 162 | 163 | diff_loss, infos = self.diff_trainer.model.loss(stacked_inputs, conds, returns) 164 | obs_recon = infos.pop('obs') 165 | 166 | decode_outputs = self.diffuser.decode(obs_recon) 167 | 168 | self.state_reconstruct = False 169 | self.lang_reconstruct = False 170 | if self.state_reconstruct: 171 | # TODO: Maybe fix?? We now predict an option using trajs 172 | state_rc_preds = self.state_reconstructor(selected_options) 173 | state_rc_targets = states # horizon_states 174 | else: 175 | state_rc_preds = None 176 | state_rc_targets = None 177 | 178 | if self.lang_reconstruct: 179 | # TODO: Maybe fix?? Do we need the max options formulation? 
Check this 180 | lang_rc_preds = self.lang_reconstructor(selected_options.reshape(batch_size, -1)) 181 | lang_rc_targets = cls_embeddings 182 | else: 183 | lang_rc_preds = None 184 | lang_rc_targets = None 185 | 186 | return {'dt': decode_outputs, 187 | 'state_rc': (state_rc_preds, state_rc_targets), 188 | 'lang_rc': (lang_rc_preds, lang_rc_targets), 189 | 'actions': actions, 190 | 'attention_mask': attention_mask, 191 | 'commitment_loss': commitment_loss, 192 | 'diff_loss': diff_loss, 193 | 'entropy': entropy} 194 | 195 | def get_action(self, states, actions, timesteps, options=None, word_embeddings=None, ema_diffusion_model=None): 196 | if self.method == 'vanilla': 197 | preds = self.decision_transformer.get_action( 198 | states, actions, timesteps, word_embeddings=word_embeddings) 199 | 200 | return preds 201 | else: 202 | # preds = self.decision_transformer.get_action( 203 | # states, actions, timesteps, options=options) 204 | encoder_out = self.diffuser.get_action(states, actions, timesteps, 205 | options=options, embed_state=self.option_selector.option_dt.embed_state) 206 | stacked_inputs = encoder_out['stacked_inputs'] 207 | option_embeddings = encoder_out['option_embeddings'] 208 | 209 | conds = {0: stacked_inputs[:, -1, self.diffuser.act_dim:]} 210 | returns = option_embeddings # conditions 211 | 212 | samples = ema_diffusion_model.conditional_sample(conds, returns=returns) 213 | 214 | action_hist = [] 215 | for tt in range(samples.shape[1]-1): 216 | obs_comb = torch.cat([samples[:, tt, :], samples[:, tt+1, :]], dim=-1) 217 | obs_comb = obs_comb.reshape(-1, 2 * self.diffuser.hidden_size) 218 | action = ema_diffusion_model.inv_model(obs_comb) 219 | action_hist.append(action) 220 | 221 | # obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1) 222 | # obs_comb = obs_comb.reshape(-1, 2 * self.diffuser.hidden_size) 223 | # action = ema_diffusion_model.inv_model(obs_comb) 224 | 225 | if self.diffuser.predict_q: 226 | # Choose actions from q_values 227 | # action = self.iq_choose_action(preds['q_preds'][:, -1], sample=True) 228 | raise NotImplementedError 229 | else: 230 | # Choose actions from direct predictions 231 | # action = preds['action_preds'][:, -1] 232 | if self.diffuser.discrete: 233 | action = action.argmax(dim=1) 234 | 235 | # action = action.squeeze(0) 236 | action = torch.cat(action_hist, dim=0) 237 | return action 238 | 239 | def save(self, iter_num, filepath, config): 240 | if hasattr(self.model, 'module'): 241 | model = self.model.module 242 | 243 | torch.save({'model': model.state_dict(), 244 | 'optimizer': self.optimizer.state_dict(), 245 | 'scheduler': self.scheduler.state_dict(), 246 | 'iter_num': iter_num, 247 | 'train_dataset_max_length': self.train_loader.dataset.max_length, 248 | 'config': config}, filepath) 249 | 250 | def load(self, filepath): 251 | checkpoint = torch.load(filepath) 252 | self.model.load_state_dict(checkpoint['model']) 253 | self.optimizer.load_state_dict(checkpoint['optimizer']) 254 | self.scheduler.load_state_dict(checkpoint['scheduler']) 255 | return {'iter_num': checkpoint['iter_num'], 'train_dataset_max_length': checkpoint['train_dataset_max_length'], 'config': checkpoint['config']} 256 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/img_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | 6 | # Image Encoder 7 | # From 
https://github.com/suraj-nair-1/lorel/blob/main/models.py 8 | class BaseEncoder(nn.Module): 9 | __constants__ = ['embedding_size'] 10 | 11 | def __init__(self): 12 | super().__init__() 13 | 14 | 15 | def preprocess(self, observations): 16 | """ 17 | Reshape to 4 dimensions so it works for convolutions 18 | Chunk the time and batch dimensions 19 | """ 20 | B, T, C, H, W = observations.shape 21 | return observations.reshape(-1, C, H, W).type(torch.float32).contiguous() 22 | 23 | def unpreprocess(self, embeddings, B, T): 24 | """ 25 | Reshape back to 5 dimensions 26 | Unsqueeze the batch and time dimensions 27 | """ 28 | BT, E = embeddings.shape 29 | return embeddings.reshape(B, T, E) 30 | 31 | 32 | # Image Encoder 33 | # From https://github.com/suraj-nair-1/lorel/blob/main/models.py 34 | class Encoder(BaseEncoder): 35 | __constants__ = ['embedding_size'] 36 | 37 | def __init__(self, hidden_size, activation_function='relu', ch=3, robot=False): 38 | super().__init__() 39 | self.act_fn = getattr(F, activation_function) 40 | self.softmax = nn.Softmax(dim=2) 41 | self.sigmoid = nn.Sigmoid() 42 | self.robot = robot 43 | if self.robot: 44 | g = 4 45 | else: 46 | g = 1 47 | self.conv1 = nn.Conv2d(ch, 32, 4, stride=2, padding=1, groups=g) # 3 48 | self.conv1_2 = nn.Conv2d(32, 32, 4, stride=1, padding=1, groups=g) 49 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=1, groups=g) 50 | self.conv2_2 = nn.Conv2d(64, 64, 4, stride=1, padding=1, groups=g) 51 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2, padding=1, groups=g) 52 | self.conv3_2 = nn.Conv2d(128, 128, 4, stride=1, padding=1, groups=g) 53 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2, padding=1, groups=g) 54 | self.conv4_2 = nn.Conv2d(256, 256, 4, stride=1, padding=1, groups=g) 55 | 56 | self.fc1 = nn.Linear(1024, 512) 57 | self.fc1_2 = nn.Linear(512, 512) 58 | self.fc1_3 = nn.Linear(512, 512) 59 | self.fc1_4 = nn.Linear(512, 512) 60 | self.fc2 = nn.Linear(512, hidden_size) 61 | 62 | def forward(self, observations): 63 | if self.robot: 64 | observations = torch.cat([ 65 | observations[:, :3], observations[:, 12:15], observations[:, 3:6], observations[:, 15:18], 66 | observations[:, 6:9], observations[:, 18:21], observations[:, 9:12], observations[:, 21:], 67 | ], 1) 68 | if len(observations.shape) == 5: 69 | preprocessed_observations = self.preprocess(observations) 70 | else: 71 | preprocessed_observations = observations 72 | hidden = self.act_fn(self.conv1(preprocessed_observations)) 73 | hidden = self.act_fn(self.conv1_2(hidden)) 74 | hidden = self.act_fn(self.conv2(hidden)) 75 | hidden = self.act_fn(self.conv2_2(hidden)) 76 | hidden = self.act_fn(self.conv3(hidden)) 77 | hidden = self.act_fn(self.conv3_2(hidden)) 78 | hidden = self.act_fn(self.conv4(hidden)) 79 | hidden = self.act_fn(self.conv4_2(hidden)) 80 | hidden = hidden.reshape(preprocessed_observations.shape[0], -1) 81 | 82 | hidden = self.act_fn(self.fc1(hidden)) 83 | hidden = self.act_fn(self.fc1_2(hidden)) 84 | hidden = self.act_fn(self.fc1_3(hidden)) 85 | hidden = self.act_fn(self.fc1_4(hidden)) 86 | hidden = self.fc2(hidden) 87 | 88 | if len(observations.shape) == 5: 89 | return self.unpreprocess(hidden, observations.shape[0], observations.shape[1]) 90 | else: 91 | return hidden 92 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/iq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.shape_base import vstack 3 | import torch 4 | from torch._C import 
qscheme 5 | import torch.nn.functional as F 6 | from torch.distributions import Categorical 7 | import copy 8 | import wandb 9 | 10 | 11 | class IQMixin(object): 12 | # Mixin to Base model that adds extra IQ 13 | 14 | def __init__(self, q_net, args, device): 15 | super().__init__() 16 | 17 | self.gamma = args.gamma 18 | self.args = args 19 | self.device = device 20 | 21 | self.log_alpha = torch.tensor(np.log(args.alpha)).to(self.device) 22 | self.q_net = q_net 23 | 24 | self.train() 25 | 26 | # Create target network 27 | if args.use_target: 28 | self.target_net = copy.deepcopy(q_net) 29 | self.target_net.load_state_dict(self.q_net.state_dict()) 30 | 31 | self.target_net.train() 32 | # self.critic_tau = agent_cfg.critic_tau 33 | # self.critic_target_update_frequency = agent_cfg.critic_target_update_frequency 34 | 35 | def train(self, training=True): 36 | self.training = training 37 | self.q_net.train(training) 38 | 39 | @property 40 | def alpha(self): 41 | return self.log_alpha.exp() 42 | 43 | @property 44 | def critic_net(self): 45 | return self.q_net 46 | 47 | @property 48 | def critic_target_net(self): 49 | return self.target_net 50 | 51 | def iq_choose_action(self, q, sample=True): 52 | with torch.no_grad(): 53 | dist = F.softmax(q/self.alpha, dim=1) 54 | if sample: 55 | dist = Categorical(dist) 56 | action = dist.sample() # if sample else dist.mean 57 | else: 58 | action = torch.argmax(dist, dim=1) 59 | return action 60 | # return action.detach().cpu().numpy()[0] 61 | 62 | def getV(self, q): 63 | v = self.alpha * \ 64 | torch.logsumexp(q/self.alpha, dim=1, keepdim=True) 65 | return v 66 | 67 | def critic(self, q, action): 68 | return q.gather(1, action.long()) 69 | 70 | # Offline IQ-Learn objective 71 | def iq_update_critic(self, expert_batch): 72 | args = self.args 73 | # Assume the expert_batch contains the current state q and the next state q. 
74 | q, next_q, action, done = expert_batch 75 | 76 | losses = {} 77 | # keep track of v0 78 | v0 = self.getV(q).mean() 79 | losses['v0'] = v0.item() 80 | 81 | # calculate 1st term of loss 82 | # -E_(ρ_expert)[Q(s, a) - γV(s')] 83 | current_q = self.critic(q, action) 84 | next_v = self.getV(next_q) 85 | y = (1 - done) * self.gamma * next_v 86 | 87 | if args.use_target: 88 | with torch.no_grad(): 89 | target_q = self.target_net(next_obs) 90 | next_v = self.get_V(target_q) 91 | y = (1 - done) * self.gamma * next_v 92 | 93 | reward = current_q - y 94 | 95 | with torch.no_grad(): 96 | if args.div == "hellinger": 97 | phi_grad = 1/(1+reward)**2 98 | elif args.div == "kl": 99 | phi_grad = torch.exp(-reward-1) 100 | elif args.div == "kl2": 101 | phi_grad = F.softmax(-reward, dim=0) * reward.shape[0] 102 | elif args.div == "kl_fix": 103 | phi_grad = torch.exp(-reward) 104 | elif args.div == "js": 105 | phi_grad = torch.exp(-reward)/(2 - torch.exp(-reward)) 106 | else: 107 | phi_grad = 1 108 | 109 | loss = -(phi_grad * reward).mean() 110 | losses['softq_loss'] = loss.item() 111 | 112 | if args.loss == "v0": 113 | # calculate 2nd term for our loss 114 | # (1-γ)E_(ρ0)[V(s0)] 115 | v0_loss = (1 - self.gamma) * v0 116 | loss += v0_loss 117 | losses['v0_loss'] = v0_loss.item() 118 | 119 | elif args.loss == "value": 120 | # alternative 2nd term for our loss (use only expert states) 121 | # E_(ρ)[Q(s,a) - γV(s')] 122 | value_loss = (self.getV(q) - y).mean() 123 | loss += value_loss 124 | losses['value_loss'] = value_loss.item() 125 | 126 | elif args.loss == "skip": 127 | # No loss 128 | pass 129 | 130 | if args.div == "chi": 131 | # Use χ2 divergence (adds a extra term to the loss) 132 | if args.use_target: 133 | with torch.no_grad(): 134 | target_q = self.target_net(next_obs) 135 | next_v = self.getV(target_q) 136 | else: 137 | next_v = self.getV(next_q) 138 | 139 | y = (1 - done) * self.gamma * next_v 140 | 141 | current_q = self.critic(q, action) 142 | reward = current_q - y 143 | chi2_loss = 1/2 * (reward**2).mean() 144 | loss += chi2_loss 145 | losses['chi2_loss'] = chi2_loss.item() 146 | 147 | losses['total_loss'] = loss.item() 148 | 149 | return loss, losses 150 | 151 | def iq_critic_loss(self, batch, step): 152 | # assume we get input of trajectory with state, action, q_predictions, dones 153 | q, next_q, actions, dones = batch 154 | 155 | actions = actions.reshape(-1, 1) 156 | dones = dones.reshape(-1, 1) 157 | 158 | loss, loss_metrics = self.iq_update_critic((q, next_q, actions, dones)) 159 | 160 | if self.args.use_target and step % self.critic_target_update_frequency == 0: 161 | self.target_net.load_state_dict(self.q_net.state_dict()) 162 | 163 | wandb.log(loss_metrics, step=step) 164 | return loss 165 | 166 | def infer_q(self, state, action): 167 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 168 | action = torch.FloatTensor([action]).unsqueeze(0).to(self.device) 169 | 170 | with torch.no_grad(): 171 | q = self.critic(state, action) 172 | return q.squeeze(0).cpu().numpy() 173 | 174 | def infer_v(self, state): 175 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0) 176 | with torch.no_grad(): 177 | v = self.getV(state).squeeze() 178 | return v.cpu().numpy() 179 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/option_selector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch._C import Value 4 | import torch.nn as nn 5 | import 
torch.nn.functional as F 6 | from option_transformer import OptionTransformer 7 | from utils import pad, entropy 8 | from img_encoder import Encoder 9 | 10 | from vector_quantize_pytorch import VectorQuantize 11 | 12 | 13 | class OptionSelector(nn.Module): 14 | 15 | """ 16 | This model takes in the language embedding and the state to output a z from a categorical distribution 17 | Use the VQ trick to pick an option z 18 | """ 19 | 20 | def __init__( 21 | self, state_dim, num_options, option_dim, lang_dim, horizon, num_hidden=None, hidden_size=None, 22 | method='traj_option', option_transformer=None, codebook_dim=16, use_vq=True, kmeans_init=False, 23 | commitment_weight=0.25, **kwargs): 24 | 25 | # option_dim and codebook_dim are different because of the way the VQ package is designed 26 | # if they are different, then there is a projection operation that happens inside the VQ layer 27 | 28 | super().__init__() 29 | 30 | if num_hidden is not None: 31 | assert num_hidden >= 2, "We need at least two hidden layers!" 32 | 33 | self.state_dim = state_dim 34 | self.option_dim = option_dim 35 | self.use_vq = use_vq 36 | self.num_options = num_options 37 | 38 | self.horizon = horizon 39 | self.method = method # whether to use full trajectory to get options or just current state 40 | self.hidden_size = 128 41 | hidden_size = 128 42 | 43 | if option_transformer: 44 | self.hidden_size = option_transformer.hidden_size 45 | 46 | self.Z = VectorQuantize( 47 | dim=option_dim, 48 | codebook_dim=codebook_dim, # codebook vector size 49 | codebook_size=num_options, # codebook size 50 | decay=0.99, # the exponential moving average decay, lower means the dictionary will change faster 51 | commitment_weight=commitment_weight, # the weight on the commitment loss 52 | kmeans_init=kmeans_init, # use kmeans init 53 | cpc=False, 54 | # threshold_ema_dead_code=2, # should actively replace any codes that have an exponential moving average cluster size less than 2 55 | use_cosine_sim=False # l2 normalize the codes 56 | ) 57 | 58 | if self.method == 'traj_option': 59 | option_transformer_args = {'state_dim': state_dim, 60 | 'lang_dim': lang_dim, 61 | 'option_dim': option_dim, 62 | 'hidden_size': option_transformer.hidden_size, 63 | 'max_length': option_transformer.max_length, 64 | 'max_ep_len': option_transformer.max_ep_len, 65 | 'n_layer': option_transformer.n_layer, 66 | 'n_head': option_transformer.n_head, 67 | 'n_inner': 4*option_transformer.hidden_size, 68 | 'activation_function': option_transformer.activation_function, 69 | 'n_positions': option_transformer.n_positions, 70 | 'resid_pdrop': option_transformer.dropout, 71 | 'attn_pdrop': option_transformer.dropout, 72 | 'output_attentions': True # option_transformer.output_attention, 73 | } 74 | self.option_dt = OptionTransformer(**option_transformer_args) 75 | else: 76 | if isinstance(state_dim, tuple): 77 | # LORL 78 | if state_dim[0] == 3: 79 | # LORL Sawyer 80 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False) 81 | else: 82 | # LORL Franka 83 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True) 84 | else: 85 | self.embed_state = nn.Linear(state_dim, hidden_size) 86 | 87 | z_layers = [] 88 | for i in range(num_hidden): 89 | if i == 0: 90 | z_layers.append(nn.Linear(2*hidden_size, hidden_size)) 91 | elif i == num_hidden-1: 92 | z_layers.append(nn.Linear(hidden_size, option_dim)) 93 | else: 94 | z_layers.append(nn.Linear(hidden_size, hidden_size)) 95 | self.pred_options = nn.Sequential(*z_layers) 96 | self.embed_lang = 
nn.Linear(lang_dim, hidden_size) 97 | 98 | def forward(self, word_embeddings, states, timesteps=None, attention_mask=None, **kwargs): 99 | if self.method == 'traj_option': 100 | dt_ret = self.option_dt(word_embeddings, states, timesteps, attention_mask) 101 | option_preds = dt_ret[0] 102 | state_embeddings = dt_ret[2] 103 | option_preds = option_preds[:, ::self.horizon, :] 104 | else: 105 | ret_state_embeddings = self.embed_state(states).clone().detach() 106 | horizon_states = states[:, ::self.horizon, :] 107 | state_embeddings = self.embed_state(horizon_states) 108 | lang_embeddings = self.embed_lang(word_embeddings) # these will be cls embeddings or word embeddings mean 109 | 110 | inp = torch.cat([lang_embeddings.repeat( 111 | 1, state_embeddings.shape[1], 1), state_embeddings], dim=-1) 112 | option_preds = self.pred_options(inp) 113 | 114 | state_embeddings = ret_state_embeddings 115 | 116 | if self.use_vq: 117 | options, indices, commitment_loss = self.Z(option_preds) 118 | entropies = entropy(self.Z.codebook, options, self.Z.project_in(option_preds)) 119 | else: 120 | # TODO: For now simply return the first dim of option 121 | options, indices = option_preds, option_preds[:, :, 0] 122 | commitment_loss = None 123 | entropies = None 124 | return options, indices, commitment_loss, entropies, state_embeddings 125 | 126 | def get_option(self, word_embeddings, states, timesteps=None, **kwargs): 127 | 128 | if 'constant_option' in kwargs: 129 | return self.Z.project_out( 130 | self.Z.codebook[kwargs['constant_option']]), torch.tensor( 131 | kwargs['constant_option']) 132 | 133 | if self.method == 'traj_option': 134 | if isinstance(self.state_dim, tuple): 135 | states = states.reshape(1, -1, *self.state_dim) 136 | else: 137 | states = states.reshape(1, -1, self.state_dim) 138 | timesteps = timesteps.reshape(1, -1) 139 | max_length = self.option_dt.max_length 140 | 141 | if max_length is not None: 142 | states = states[:, -max_length:] 143 | timesteps = timesteps[:, -max_length:] 144 | 145 | # pad all tokens to sequence length 146 | attention_mask = pad( 147 | torch.ones(1, states.shape[1]), 148 | max_length).to( 149 | dtype=torch.long, device=states.device).reshape( 150 | 1, -1) 151 | states = pad(states, max_length).to(dtype=torch.float32) 152 | timesteps = pad(timesteps, max_length).to(dtype=torch.long) 153 | else: 154 | attention_mask = None 155 | raise ValueError('Attention mask should not be none') 156 | 157 | options, option_indx, _, _, _ = self.forward( 158 | word_embeddings, states, timesteps, attention_mask=attention_mask, **kwargs) 159 | else: 160 | states = states[:, ::self.horizon, :] 161 | options, option_indx, _, _, _ = self.forward( 162 | word_embeddings, states, None, attention_mask=None, **kwargs) 163 | 164 | return options[0, -1], option_indx[0, -1] 165 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/option_transformer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import transformers 5 | 6 | from trajectory_gpt2 import GPT2Model 7 | from img_encoder import Encoder 8 | 9 | 10 | class OptionTransformer(nn.Module): 11 | 12 | """ 13 | This model uses GPT-2 to select options for every horizon-th state 14 | """ 15 | 16 | def __init__( 17 | self, 18 | state_dim, 19 | lang_dim, 20 | option_dim, 21 | hidden_size, 22 | max_length=None, 23 | max_ep_len=4096, 24 | **kwargs): 25 | super().__init__() 26 | 27 | self.option_dim 
= option_dim 28 | self.hidden_size = hidden_size 29 | 30 | config = transformers.GPT2Config( 31 | vocab_size=1, # doesn't matter -- we don't use the vocab 32 | n_embd=hidden_size, 33 | **kwargs 34 | ) 35 | 36 | self.state_dim = state_dim 37 | self.max_length = max_length 38 | self.output_attentions = kwargs["output_attentions"] 39 | 40 | if isinstance(state_dim, tuple): 41 | # LORL 42 | if state_dim[0] == 3: 43 | # LORL Sawyer 44 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False) 45 | else: 46 | # LORL Franka 47 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True) 48 | 49 | else: 50 | self.embed_state = nn.Linear(self.state_dim, hidden_size) 51 | 52 | self.embed_lang = nn.Linear(lang_dim, hidden_size) 53 | 54 | # note: the only difference between this GPT2Model and the default Huggingface version 55 | # is that the positional embeddings are removed (since we'll add those ourselves) 56 | self.transformer = GPT2Model(config) 57 | 58 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size) 59 | 60 | self.embed_ln = nn.LayerNorm(hidden_size) 61 | self.predict_options = torch.nn.Linear(hidden_size, self.option_dim) 62 | 63 | def forward(self, word_embeddings, states, timesteps, attention_mask, **kwargs): 64 | batch_size, seq_length = states.shape[0], states.shape[1] 65 | num_tokens = word_embeddings.shape[1] 66 | 67 | if attention_mask is None: 68 | raise ValueError('Should not have attention_mask NONE') 69 | # attention mask for GPT: 1 if can be attended to, 0 if not 70 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long) 71 | 72 | state_embeddings = self.embed_state(states) 73 | ret_state_embeddings = state_embeddings.clone().detach() 74 | lang_embeddings = self.embed_lang(word_embeddings) 75 | time_embeddings = self.embed_timestep(timesteps) 76 | 77 | # time embeddings are treated similar to positional embeddings 78 | state_embeddings = state_embeddings + time_embeddings # (batch_size, seq_length, hidden) 79 | lang_and_inputs = torch.cat([lang_embeddings, state_embeddings], dim=1) 80 | # LAYERNORM AFTER LANGUAGE 81 | stacked_inputs = self.embed_ln(lang_and_inputs) 82 | 83 | lang_attn_mask = torch.cat([torch.ones((batch_size, num_tokens), device=states.device), attention_mask], dim=1) 84 | 85 | # we feed in the input embeddings (not word indices as in NLP) to the model 86 | transformer_outputs = self.transformer( 87 | inputs_embeds=stacked_inputs, 88 | attention_mask=lang_attn_mask, 89 | ) 90 | 91 | x = transformer_outputs['last_hidden_state'] 92 | lang_out = x[:, :num_tokens, :].reshape(batch_size, num_tokens, self.hidden_size) 93 | traj_out = x[:, num_tokens:, :].reshape(batch_size, seq_length, self.hidden_size) 94 | 95 | # get predictions 96 | # predict option logits given state 97 | option_preds = self.predict_options(traj_out) 98 | 99 | if self.output_attentions: 100 | attentions = transformer_outputs[-1] 101 | return option_preds, attentions, ret_state_embeddings 102 | 103 | return option_preds, None, ret_state_embeddings 104 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/reconstructors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class StateReconstructor(nn.Module): 7 | """ 8 | This model tries to reconstruct the states from which the options were chosen for each option 9 | """ 10 | 11 | def __init__(self, option_dim, state_dim, 
num_hidden, hidden_size): 12 | super().__init__() 13 | 14 | assert num_hidden >= 2, "We need at least two hidden layers!" 15 | assert isinstance(state_dim, int), "State dimension has to be integer!" 16 | 17 | layers = [] 18 | for i in range(num_hidden): 19 | if i == 0: 20 | layers.append(nn.Linear(option_dim, hidden_size)) 21 | elif i == num_hidden-1: 22 | layers.append(nn.Linear(hidden_size, state_dim)) 23 | else: 24 | layers.append(nn.Linear(hidden_size, hidden_size)) 25 | self.predictor = nn.Sequential(*layers) 26 | 27 | def forward(self, options): 28 | return self.predictor(options) 29 | 30 | 31 | class LanguageReconstructor(nn.Module): 32 | """ 33 | This model tries to reconstruct the language from all the options 34 | """ 35 | 36 | def __init__(self, option_dim, max_options, lang_dim, num_hidden, hidden_size): 37 | super().__init__() 38 | 39 | assert num_hidden >= 2, "We need at least two hidden layers!" 40 | 41 | self.max_options = max_options 42 | 43 | layers = [] 44 | for i in range(num_hidden): 45 | if i == 0: 46 | layers.append(nn.Linear(max_options*option_dim, hidden_size)) 47 | elif i == num_hidden-1: 48 | layers.append(nn.Linear(hidden_size, lang_dim)) 49 | else: 50 | layers.append(nn.Linear(hidden_size, hidden_size)) 51 | self.predictor = nn.Sequential(*layers) 52 | 53 | def forward(self, options): 54 | options = F.pad(options, pad=(1, self.max_options-options.shape[1])) 55 | return self.predictor(options) 56 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/trajectory_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class TrajectoryModel(nn.Module): 7 | 8 | def __init__(self, state_dim, act_dim, max_length=None): 9 | super().__init__() 10 | 11 | self.state_dim = state_dim 12 | self.act_dim = act_dim 13 | self.max_length = max_length 14 | 15 | def forward(self, states, actions, rewards, masks=None, attention_mask=None): 16 | # "masked" tokens or unspecified inputs can be passed in as None 17 | return None, None, None 18 | 19 | def get_action(self, states, actions, rewards, **kwargs): 20 | # these will come as tensors on the correct device 21 | return torch.zeros_like(actions[-1]) 22 | -------------------------------------------------------------------------------- /skilldiffuser/hrl/viz.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | from typing import Dict, Iterable, Callable 3 | from collections import Counter 4 | from itertools import chain 5 | import numpy as np 6 | import torch 7 | import seaborn as sns 8 | import torch.nn as nn 9 | from torch import Tensor 10 | import matplotlib.pyplot as plt 11 | import wandb 12 | 13 | 14 | class Attention(nn.Module): 15 | def __init__(self, model: nn.Module): 16 | super().__init__() 17 | self.model = model 18 | self._attention = None 19 | 20 | model.option_selector.option_dt.register_forward_hook(self.save_attention_hook()) 21 | 22 | def save_attention_hook(self) -> Callable: 23 | def fn(model, input, output): 24 | self._attention = output[-1] 25 | return fn 26 | 27 | def forward(self, x: Tensor) -> Dict[str, Tensor]: 28 | _ = self.model(x) 29 | return self._attention 30 | 31 | 32 | def get_tokens(inputs, tokenizer): 33 | input_ids = inputs['input_ids'] 34 | input_id_list = input_ids[0].tolist() # Batch index 0 35 | tokens = tokenizer.convert_ids_to_tokens(input_id_list)[1:] 36 | return tokens 37 | 38 | 39 | 
def viz_matrix(words_dict, num_options, step, skip_words): 40 | # skip_words = ['go', 'to', 'the', 'a', '[SEP]'] 41 | words = sorted(set(chain(*words_dict.values())) - set(skip_words)) 42 | 43 | def w_to_ind(word): 44 | return words.index(word) 45 | 46 | matrix = np.zeros([len(words), num_options]) 47 | 48 | for o in range(num_options): 49 | for w in words_dict[o]: 50 | if w not in skip_words: 51 | matrix[w_to_ind(w), o] += 1 52 | 53 | # plot co-occurence matrix (words x options) 54 | plt.figure(figsize=(30, 10)) 55 | sns.heatmap(matrix, yticklabels=words) 56 | # plt.plot() 57 | wandb.log({"Correlation Matrix": wandb.Image(plt)}, step=step) 58 | 59 | # Now if we normalize it by column (word freq for each option) 60 | plt.figure(figsize=(30, 10)) 61 | matrix_norm_col = (matrix)/(matrix.sum(axis=0, keepdims=True) + 1e-6) 62 | sns.heatmap(matrix_norm_col, yticklabels=words) 63 | wandb.log({"Word Freq Matrix": wandb.Image(plt)}, step=step) 64 | 65 | # Now if we normalize it by row (option freq for each word) 66 | plt.figure(figsize=(30, 10)) 67 | matrix_norm_row = (matrix)/(matrix.sum(axis=1, keepdims=True) + 1e-6) 68 | sns.heatmap(matrix_norm_row, yticklabels=words) 69 | wandb.log({"Option Freq Matrix": wandb.Image(plt)}, step=step) 70 | plt.close() 71 | 72 | 73 | def viz_matrix2(words_dict, num_options, step, skip_words, labelsize=15): # labelsize=12 74 | # skip_words = ['go', 'to', 'the', 'a', '[SEP]'] 75 | skip_words.extend(['and',',','.','t']) 76 | words_set = set(chain(*words_dict.values())) 77 | for item in words_set: 78 | if item.startswith('##'): 79 | skip_words.append(item) 80 | 81 | words = sorted(set(chain(*words_dict.values())) - set(skip_words)) 82 | # "/home/zxliang/new-code/LISA-hot/lisa" 83 | 84 | def w_to_ind(word): 85 | return words.index(word) 86 | 87 | matrix = np.zeros([len(words), num_options]) 88 | 89 | for o in range(num_options): 90 | for w in words_dict[o]: 91 | if w not in skip_words: 92 | matrix[w_to_ind(w), o] += 1 93 | 94 | # pdb.set_trace() 95 | 96 | words = ["faucet" if x == "fa" else x for x in words] 97 | 98 | # plot co-occurence matrix (words x options) 99 | plt.figure(figsize=(20, 10)) 100 | plt.tick_params(labelsize=labelsize) 101 | ax=sns.heatmap(matrix, yticklabels=words) 102 | colorbar = ax.collections[0].colorbar 103 | colorbar.ax.tick_params(labelsize=labelsize) 104 | # plt.plot() 105 | wandb.log({"Correlation Matrix": wandb.Image(plt)}, step=step) 106 | 107 | # Now if we normalize it by column (word freq for each option) 108 | plt.figure(figsize=(20, 10)) 109 | plt.tick_params(labelsize=labelsize) 110 | matrix_norm_col = (matrix)/(matrix.sum(axis=0, keepdims=True) + 1e-6) 111 | ax=sns.heatmap(matrix_norm_col, yticklabels=words) 112 | colorbar = ax.collections[0].colorbar 113 | colorbar.ax.tick_params(labelsize=labelsize) 114 | wandb.log({"Word Freq Matrix": wandb.Image(plt)}, step=step) 115 | 116 | # Now if we normalize it by row (option freq for each word) 117 | plt.figure(figsize=(20, 10)) 118 | plt.tick_params(labelsize=labelsize) 119 | matrix_norm_row = (matrix)/(matrix.sum(axis=1, keepdims=True) + 1e-6) 120 | ax=sns.heatmap(matrix_norm_row, yticklabels=words) 121 | colorbar = ax.collections[0].colorbar 122 | colorbar.ax.tick_params(labelsize=labelsize) 123 | wandb.log({"Option Freq Matrix": wandb.Image(plt)}, step=step) 124 | 125 | # plt.show() 126 | plt.close() 127 | 128 | 129 | def plot_hist(stats): 130 | plt.clf() 131 | plt.bar(stats.keys(), stats.values()) 132 | plt.xticks(rotation='vertical') 133 | return wandb.Image(plt) 134 | 
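The two heatmap helpers above expect a `words_dict` mapping each skill (option) index to the words of the instructions under which that option was selected. Below is a minimal, hypothetical driver: the instructions, option assignments, and wandb project name are invented for illustration, and it assumes `hrl/` is on the Python path with a logged-in wandb session.

```python
# Hypothetical usage sketch for viz_matrix; the words_dict contents and the
# wandb project name are invented for illustration only.
import wandb
from viz import viz_matrix

wandb.init(project="skilldiffuser-viz-demo")  # assumed project name

# option index -> words observed in instructions while that option was active
words_dict = {
    0: ["open", "the", "drawer", "open", "drawer"],
    1: ["turn", "faucet", "left", "turn", "the", "faucet"],
    2: ["close", "the", "drawer"],
    3: ["move", "black", "mug", "right"],
}

# logs three heatmaps: raw counts, per-option word frequency, per-word option frequency
viz_matrix(words_dict, num_options=4, step=0, skip_words=["the", "a", "[SEP]"])
```

`viz_matrix2` takes the same inputs; it additionally drops BERT-style `##` sub-word pieces and enlarges the tick labels.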
-------------------------------------------------------------------------------- /skilldiffuser/matrix.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Liang-ZX/SkillDiffuser/bc9da40fb98b8e65797005d3d9293d3d75665d9d/skilldiffuser/matrix.npy -------------------------------------------------------------------------------- /skilldiffuser/pkl2pkl.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import h5py 4 | import numpy as np 5 | import pickle 6 | import pandas as pd 7 | import os 8 | 9 | root_path = '/home/zxliang/dataset/lorel/may_08_sawyer_50k' 10 | # f = h5py.File(os.path.join(root_path, 'data.hdf5'),'r') 11 | # f.visit(lambda x: print(x)) 12 | f = open(os.path.join(root_path, 'prep_data.pkl'),'rb') 13 | 14 | # data = dict() 15 | 16 | df = pd.read_table(os.path.join(root_path, "labels.csv"), sep=",") 17 | langs = df["Text Description"].str.strip().to_numpy().reshape(-1) 18 | langs = np.array(['' if pd.isna(x) else x for x in langs]) 19 | filtr1 = np.array([int(("nothing" in l) or ("nan" in l) or ("wave" in l)) for l in langs]) 20 | filtr = filtr1 == 0 21 | # data['langs'] = langs[filtr] 22 | 23 | data = pickle.load(f) 24 | for key in data.keys(): 25 | data[key] = data[key][filtr, ...] 26 | 27 | with open(os.path.join(root_path, "prep_data4.pkl"),"wb") as fo: 28 | pickle.dump(data, fo) 29 | -------------------------------------------------------------------------------- /skilldiffuser/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pandas 3 | matplotlib 4 | torch 5 | seaborn 6 | transformers==4.12.0 7 | scipy 8 | tensorboardX 9 | hydra-core 10 | omegaconf 11 | gym==0.15.4 12 | wandb 13 | opencv-python 14 | cython<3 15 | git+https://github.com/rlworkgroup/metaworld.git@b016e6a25e485f1ffa8ccbf52df54ac204a81f31#egg=metaworld 16 | ml_logger==0.7.5 17 | einops 18 | mujoco-py==2.0.2.5 19 | matplotlib==3.3.4 20 | #torch==1.9.1+cu111 21 | #torchvision==0.10.1+cu111 22 | #torch==1.7.1+cu101 23 | #torchvision==0.8.2+cu101 24 | typed-argument-parser 25 | git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl 26 | #git+https://github.com/rail-berkeley/d4rl.git@c39eefd68d2f3277ca68e996a45ce1dd24e65625 27 | scikit-image==0.17.2 28 | scikit-video==1.1.11 29 | gitpython 30 | pillow 31 | tqdm 32 | pygame 33 | networkx 34 | dotmap -------------------------------------------------------------------------------- /skilldiffuser/train_lorel_compose.sh: -------------------------------------------------------------------------------- 1 | python hrl/main_hot.py env=lorel_sawyer_obs method=traj_option dt.n_layer=1 dt.n_head=4 option_selector.option_transformer.n_layer=1 option_selector.option_transformer.n_head=4 option_selector.commitment_weight=0.1 option_selector.option_transformer.hidden_size=128 batch_size=256 seed=1 warmup_steps=5000 resume=True checkpoint_path=${checkpoint_pretrain_ckpt} diffuser.loadpath=${diffuser_pretrain_ckpt} 2 | --------------------------------------------------------------------------------
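As a quick sanity check of the skill-selection pieces above, the following is a minimal sketch of driving `OptionSelector` through its non-transformer (`method='option'`) branch. All dimensions and batch sizes are made-up illustrative values, and the snippet assumes the `hrl/` modules (including the bundled `vector_quantize_pytorch`) are importable.

```python
# Illustrative only: the dimensions below are invented and are not taken from the configs.
import torch
from option_selector import OptionSelector

selector = OptionSelector(
    state_dim=10,           # flat state vector; an image-shaped tuple would use the CNN encoder instead
    num_options=20,         # VQ codebook size
    option_dim=16,          # skill embedding size
    lang_dim=768,           # e.g. dimension of a pooled language-model embedding
    horizon=10,             # one skill is chosen every `horizon` environment steps
    num_hidden=2,
    hidden_size=128,
    method='option',        # per-state skill prediction instead of the trajectory transformer
    option_transformer=None,
    use_vq=True,
)

states = torch.randn(4, 50, 10)   # (batch, trajectory length, state_dim)
lang = torch.randn(4, 1, 768)     # one pooled instruction embedding per trajectory
options, indices, commit_loss, entropies, _ = selector(lang, states)
print(options.shape, indices.shape)  # expect (4, 5, 16) quantized skills and (4, 5) codebook indices
```

With `method=traj_option`, as used in `train_lorel_compose.sh`, the same forward call instead goes through the `OptionTransformer`, which additionally requires `timesteps` and an `attention_mask`.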