├── .gitignore
├── README.md
├── assets
│   └── framework.png
└── skilldiffuser
    ├── README.md
    ├── eval_lorel.sh
    ├── h5py2pkl.py
    ├── hrl
    │   ├── conf
    │   │   ├── agent
    │   │   │   └── big_dt.yaml
    │   │   ├── config.yaml
    │   │   ├── env
    │   │   │   ├── hopper.yaml
    │   │   │   ├── lorel_franka.yaml
    │   │   │   ├── lorel_sawyer_obs.yaml
    │   │   │   └── lorel_sawyer_state.yaml
    │   │   └── model
    │   │       ├── option.yaml
    │   │       ├── traj_option.yaml
    │   │       └── vanilla.yaml
    │   ├── dec_diffuser.py
    │   ├── dec_encoder.py
    │   ├── decision_transformer.py
    │   ├── difftrainer.py
    │   ├── diffuser
    │   │   ├── __init__.py
    │   │   ├── datasets
    │   │   │   ├── __init__.py
    │   │   │   ├── buffer.py
    │   │   │   ├── d4rl.py
    │   │   │   ├── normalization.py
    │   │   │   ├── preprocessing.py
    │   │   │   └── sequence.py
    │   │   ├── environments
    │   │   │   ├── __init__.py
    │   │   │   ├── ant.py
    │   │   │   ├── assets
    │   │   │   │   ├── ant.xml
    │   │   │   │   ├── half_cheetah.xml
    │   │   │   │   ├── hopper.xml
    │   │   │   │   └── walker2d.xml
    │   │   │   ├── half_cheetah.py
    │   │   │   ├── hopper.py
    │   │   │   ├── registration.py
    │   │   │   └── walker2d.py
    │   │   ├── models
    │   │   │   ├── __init__.py
    │   │   │   ├── diffusion.py
    │   │   │   ├── helpers.py
    │   │   │   └── temporal.py
    │   │   └── utils
    │   │       ├── __init__.py
    │   │       ├── arrays.py
    │   │       ├── cloud.py
    │   │       ├── colab.py
    │   │       ├── config.py
    │   │       ├── git_utils.py
    │   │       ├── iql.py
    │   │       ├── progress.py
    │   │       ├── pybullet_utils.py
    │   │       ├── rendering.py
    │   │       ├── serialization.py
    │   │       ├── setup.py
    │   │       ├── timer.py
    │   │       ├── training.py
    │   │       ├── transformations.py
    │   │       └── video.py
    │   ├── env.py
    │   ├── eval.py
    │   ├── eval_orig.py
    │   ├── expert_dataset.py
    │   ├── hrl_model.py
    │   ├── img_encoder.py
    │   ├── iq.py
    │   ├── main_hot.py
    │   ├── option_selector.py
    │   ├── option_transformer.py
    │   ├── reconstructors.py
    │   ├── trainer.py
    │   ├── trajectory_gpt2.py
    │   ├── trajectory_model.py
    │   ├── utils.py
    │   ├── vector_quantize_pytorch.py
    │   ├── viz-lorl.ipynb
    │   ├── viz-lorl.py
    │   └── viz.py
    ├── matrix.npy
    ├── pkl2pkl.py
    ├── requirements.txt
    └── train_lorel_compose.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | outputs/
3 | checkpoints/
4 | wandb/
5 | alfred/data
6 | alfworld/storage
7 | eval_*/
8 | skilldiffuser/outputs
9 | buckets
10 | alfred/exp
11 | lorel
12 | metaworld
13 | hrl/gifs
14 |
15 | # Jupyter Notebook
16 | .ipynb_checkpoints
17 |
18 | .venv
19 | .vscode
20 |
21 | babyai/babyai.egg-info/
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SkillDiffuser on LOReL Compositional Tasks
2 |
3 | This is part of the official PyTorch implementation of the paper:
4 |
5 | > ### [SkillDiffuser: Interpretable Hierarchical Planning via Skill Abstractions in Diffusion-Based Task Execution](https://arxiv.org/abs/2312.11598)
6 | > Zhixuan Liang, Yao Mu, Hengbo Ma, Masayoshi Tomizuka, Mingyu Ding, Ping Luo
7 | >
8 | > CVPR 2024
9 | >
10 | > [Project Page](https://skilldiffuser.github.io/) | [Paper](https://arxiv.org/abs/2312.11598)
11 |
12 | ### Framework of SkillDiffuser
13 |
14 | ![SkillDiffuser framework](assets/framework.png)
15 |
16 | SkillDiffuser is a hierarchical planning model that combines interpretable skill abstractions at the higher level with a skill-conditioned diffusion model at the lower level for task execution in a multi-task learning environment. The high-level skill abstraction is produced by a skill predictor together with a vector quantization operation, generating sub-goals (a skill set) that the diffusion model uses to determine the appropriate future states. Future states are then converted to actions using an inverse dynamics model. This fusion enables a consistent underlying planner across different tasks, with variation only in the inverse dynamics model. A minimal sketch of this control loop is given below.
17 |
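The sketch below is purely illustrative: the module names, tensor sizes, and the linear layers standing in for the skill predictor, the skill-conditioned diffusion sampler and the inverse dynamics model are placeholders, not the actual implementations (those live under [skilldiffuser/hrl](skilldiffuser/hrl)).

```python
import torch
import torch.nn as nn

# Illustrative dimensions only
state_dim, lang_dim, skill_dim, action_dim, horizon, num_skills = 16, 32, 8, 5, 8, 20

skill_predictor = nn.Linear(state_dim + lang_dim, skill_dim)            # high-level skill predictor
codebook = torch.randn(num_skills, skill_dim)                           # VQ codebook of discrete skills
state_planner = nn.Linear(state_dim + skill_dim, horizon * state_dim)   # stand-in for the diffusion model
inv_dyn = nn.Linear(2 * state_dim, action_dim)                          # inverse dynamics model

def act(state, lang_emb):
    # 1) predict a continuous skill embedding from the current state and the language instruction
    z = skill_predictor(torch.cat([state, lang_emb], dim=-1))
    # 2) vector-quantize it to the nearest codebook entry (the interpretable skill / sub-goal)
    skill = codebook[torch.cdist(z[None], codebook).argmin()]
    # 3) the skill-conditioned diffusion model would denoise a plan of future states;
    #    a single linear layer stands in for that sampler here
    plan = state_planner(torch.cat([state, skill], dim=-1)).view(horizon, state_dim)
    # 4) the inverse dynamics model maps the current and next planned state to an action
    return inv_dyn(torch.cat([state, plan[0]], dim=-1))

action = act(torch.randn(state_dim), torch.randn(lang_dim))  # tensor of shape (action_dim,)
```
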
18 | ### Citation
19 |
20 | ```bibtex
21 | @article{liang2023skilldiffuser,
22 | title={Skilldiffuser: Interpretable hierarchical planning via skill abstractions in diffusion-based task execution},
23 | author={Liang, Zhixuan and Mu, Yao and Ma, Hengbo and Tomizuka, Masayoshi and Ding, Mingyu and Luo, Ping},
24 | journal={arXiv preprint arXiv:2312.11598},
25 | year={2023}
26 | }
27 | ```
28 |
29 | ### Code
30 |
31 | To install and use SkillDiffuser, follow the instructions provided in the [skilldiffuser](skilldiffuser) folder.
32 |
33 |
34 | ### Acknowledgements
35 |
36 | The diffusion model implementation is based on Michael Janner's [diffuser](https://github.com/jannerm/diffuser) repo.
37 | The organization of this repo and remote launcher is based on the [LISA](https://github.com/Div99/LISA) repo.
38 |
39 | ### Questions
40 | Please email us if you have any questions.
41 |
42 | Zhixuan Liang ([liangzx@connect.hku.hk](mailto:liangzx@connect.hku.hk?subject=[GitHub]%20skilldiffuser))
43 |
44 |
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-ZX/SkillDiffuser/bc9da40fb98b8e65797005d3d9293d3d75665d9d/assets/framework.png
--------------------------------------------------------------------------------
/skilldiffuser/README.md:
--------------------------------------------------------------------------------
1 | # SkillDiffuser
2 |
3 | Zhixuan Liang, Yao Mu, Hengbo Ma, Masayoshi Tomizuka, Mingyu Ding, Ping Luo
4 |
5 | ## Usage
6 | ### Setup Python Environment
7 | 1. Install MuJoCo 200
8 | ```shell
9 | unzip mujoco200_linux.zip
10 | mkdir -p ~/.mujoco && mv mujoco200_linux ~/.mujoco/mujoco200
11 | cp mjkey.txt ~/.mujoco
12 | cp mjkey.txt ~/.mujoco/mujoco200/bin
13 |
14 | # test the install
15 | cd ~/.mujoco/mujoco200/bin
16 | ./simulate ../model/humanoid.xml
17 |
18 | # add environment variables
19 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin
20 | export MUJOCO_KEY_PATH=~/.mujoco/${MUJOCO_KEY_PATH}
21 | ```
22 | 2. Install Pypi Packages
23 | ```shell
24 | pip install -r requirements.txt
25 | ```
26 | 3. Install LOReL Environment
27 | ```shell
28 | git clone https://github.com/suraj-nair-1/lorel.git
29 |
30 | cd lorel/env
31 | pip install -e .
32 | ```
33 |
34 | ### Setup LOReL Dataset
35 | 1. Download the dataset from [LOReL](https://drive.google.com/file/d/1pLnctqkOzyWZa1F1zTFqkNgUzSkUCtEv/view?usp=sharing)
36 | 2. Convert the dataset from HDF5 to pickle format
37 | ```shell
38 | python h5py2pkl.py --root_path /path/to/lorel_dataset --output_name prep_data.pkl
39 | ```
40 | 3. Point the `expert_location` field in `hrl/conf/env/lorel_sawyer_obs.yaml` at the processed dataset (see the snippet below).
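The relevant part of that config looks roughly like the snippet below; the path shown is a placeholder and should point at wherever `h5py2pkl.py` wrote the pickle:

```yaml
env:
  train_dataset:
    expert_location: '/path/to/prep_data.pkl'   # output of h5py2pkl.py
```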
41 |
42 | ## Instructions
43 |
44 | Our code for running the SkillDiffuser experiments is in the `hrl` folder.
45 |
46 | To run the code, please use the following command:
47 |
48 | `./train_lorel_compose.sh`
49 |
50 | This is a sample command intended to show the usage of the different flags available. The checkpoints can be downloaded from [here](https://connecthkuhk-my.sharepoint.com/:f:/g/personal/liangzx_connect_hku_hk/Em3qBc3AxWpOkYR7Pgd7lnUBH0bkLsILMpgUX2Xg5l3YGg?e=QslO6n).
51 | (The checkpoint is intended for fine-tuning, not for direct evaluation.)
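Since the configuration is managed with Hydra (see `hrl/conf/config.yaml`), individual flags can also be overridden directly on the command line. The invocation below is adapted from `eval_lorel.sh` in this folder; the checkpoint path is a placeholder:

```shell
python hrl/main.py env=lorel_sawyer_obs method=traj_option \
    batch_size=256 seed=1 warmup_steps=5000 \
    eval=True render=True checkpoint_path=/path/to/model_500.ckpt
```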
52 |
53 | If you would like to evaluate the model directly, please see [this issue](https://github.com/Liang-ZX/SkillDiffuser/issues/2#issuecomment-2377606631).
54 |
55 | ## License
56 |
57 | The code is made available for academic, non-commercial usage.
58 |
59 | For any inquiry, contact: Zhixuan Liang (liangzx@connect.hku.hk)
60 |
--------------------------------------------------------------------------------
/skilldiffuser/eval_lorel.sh:
--------------------------------------------------------------------------------
1 | BASE_PATH="/home/zxliang/new-code/LISA/lisa/outputs/2023-03-09/20-56-04/checkpoints/LorlEnv-v0-40108-traj_option-2023-03-09-20:56:04"
2 | ITER=500
3 | CKPT="${BASE_PATH}/model_${ITER}.ckpt"
4 |
5 | #method=option_dt
6 | python hrl/main.py env=lorel_sawyer_obs method=traj_option dt.n_layer=1 dt.n_head=4 option_selector.option_transformer.n_layer=1 option_selector.option_transformer.n_head=4 option_selector.commitment_weight=0.1 option_selector.option_transformer.hidden_size=128 batch_size=256 seed=1 warmup_steps=5000 eval=True render=True checkpoint_path=${CKPT}
7 |
--------------------------------------------------------------------------------
/skilldiffuser/h5py2pkl.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import h5py
4 | import numpy as np
5 | import pickle
6 | import pandas as pd
7 | import os
8 | import argparse
9 |
10 | # parse args
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('--root_path', default=None, type=str, help='Absolute path to the root directory of the dataset')
13 | parser.add_argument('--output_name', default="prep_data.pkl", type=str, help='Output file name')
14 | args = parser.parse_args()
15 |
16 | root_path = args.root_path
17 | f = h5py.File(os.path.join(root_path, 'data.hdf5'),'r')
18 | # f.visit(lambda x: print(x))
19 |
20 | data = dict()
21 |
22 | df = pd.read_table(os.path.join(root_path, "labels.csv"), sep=",")
23 | langs = df["Text Description"].str.strip().to_numpy().reshape(-1)
24 | langs = np.array(['' if pd.isna(x) else x for x in langs])  # blank out missing descriptions
25 | filtr1 = np.array([int((l == "") or ("nothing" in l) or ("nan" in l) or ("wave" in l)) for l in langs])
26 | filtr = filtr1 == 0
27 | data['langs'] = langs[filtr]
28 |
29 | for group in f.keys():
30 | for key in f[group].keys():
31 | print(group, key)
32 | data[key] = f[group][key][:][filtr]
33 |
34 | with open(os.path.join(root_path, args.output_name), "wb") as fo:
35 | pickle.dump(data, fo, protocol=4)
36 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/agent/big_dt.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | agent:
4 | hidden_size: 288
5 | n_layer: 6
6 | n_head: 6
7 | activation_function: 'relu'
8 | n_positions: 1024
9 | dropout: 0.1
10 |
11 |
12 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | # Set default options
3 | - _self_
4 | - model: traj_option
5 | - env: lorel_sawyer_obs
6 |
7 | cuda_deterministic: False
8 |
9 | wandb: True
10 |
11 | seed: 0
12 | resume: False
13 | load_options: False
14 | load_frozen: False
15 | freeze_loaded_options: False
16 | checkpoint_path:
17 |
18 | eval: False
19 |
20 | render: False # False
21 | render_path: ./eval_${env.name}/
22 |
23 | batch_size: 64 # 512
24 | max_iters: 500 # TODO
25 | warmup_steps: 2500 # 5000
26 |
27 | lr_decay: 0.1
28 | decay_steps: 100000
29 |
30 | # options config
31 | option_dim: 128
32 | codebook_dim: 16
33 |
34 | parallel: False # True
35 | savedir: 'checkpoints'
36 | savepath: ## to be filled in code
37 |
38 | method: ## to be filled in code
39 | use_iq: False ## use IQ-Learn objective instead of BC
40 |
41 | learning_rate: 1e-4
42 | lm_learning_rate: 1e-6
43 | weight_decay: 1e-4
44 | os_learning_rate: 1e-4
45 |
46 | trainer:
47 | device: ## to be filled in code
48 | state_il: False
49 | num_eval_episodes: 100
50 | eval_every: 1
51 | K: ${model.K}
52 |
53 | model:
54 | # Model specific configuration
55 |
56 | env:
57 | # Env specific configuration
58 | skip_words: ['go', 'to', 'the', 'a', '[SEP]']
59 |
60 | option_selector:
61 | # Option configuration
62 | option_transformer:
63 |
64 | iq:
65 | alpha: 0.1
66 | div: chi
67 | loss: value
68 | gamma: 0.99
69 | # Don't use target updates
70 | use_target: False
71 |
72 | # Extra args
73 | log_interval: 1 # Log every this many iterations
74 | save_interval: 50 # Save networks every this many iterations
75 | hydra_base_dir: ""
76 | exp_name: ''
77 | project_name: ${env.name}
78 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/env/hopper.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | env:
4 | name: Hopper-v2
5 | state_dim: 11
6 | action_dim: 3
7 | discrete: False
8 | eval_offline: False
9 | eval_episode_factor: 2
10 | eval_env:
11 |
12 | train_dataset:
13 | expert_location: '/sailhome/divgarg/implicit-irl/experts/Hopper-v2_25.pkl'
14 | num_trajectories: 10
15 | normalize_states: False
16 | no_lang: False
17 | seed: ${seed}
18 |
19 | val_dataset:
20 | expert_location:
21 | num_trajectories: ${trainer.num_eval_episodes}
22 | normalize_states: False
23 | seed: ${seed}
24 |
25 | codebook_dim: 8
26 |
27 | trainer:
28 | device: ## to be filled in code
29 | state_il: False
30 | num_eval_episodes: 5
31 | eval_every: 5
32 | K: ${model.K}
33 |
34 | model:
35 | train_lm: False
36 |
37 | option_selector:
38 | commitment_weight: 20
39 | option_transformer:
40 | n_layer: 2
41 | n_head: 4
42 |
43 | dt:
44 | hidden_size: 128
45 | n_layer: 2
46 | n_head: 4
47 | activation_function: 'relu'
48 | n_positions: 1024
49 | dropout: 0.1
50 | no_states: false
51 | no_actions: false
52 |
53 | learning_rate: 1e-4
54 | lm_learning_rate: 1e-6
55 | weight_decay: 1e-4
56 | os_learning_rate: 1e-4
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/env/lorel_franka.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | env:
4 | name:
5 | state_dim: (12, 64, 64)
6 | action_dim: 5
7 | discrete: False
8 | eval_offline: True
9 | use_state: False
10 |
11 | train_dataset:
12 | expert_location: '/atlas/u/divgarg/datasets/lorel/may_06_franka_3k/prep_data3.pkl'
13 | num_trajectories: 10000
14 | normalize_states: True
15 | no_lang: False
16 | seed: ${seed}
17 | aug: True
18 |
19 | val_dataset:
20 | expert_location:
21 | num_trajectories: ${trainer.num_eval_episodes}
22 | normalize_states: True
23 | seed: ${seed}
24 |
25 | trainer:
26 | device: ## to be filled in code
27 | state_il: False
28 | num_eval_episodes: 5
29 | eval_every: 5
30 | K: ${model.K}
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/env/lorel_sawyer_obs.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | learning_rate: 1e-5
4 | lm_learning_rate: 1e-7
5 | weight_decay: 1e-4
6 | os_learning_rate: 1e-6 # 1e-5
7 |
8 | env:
9 | name: LorlEnv-v0
10 | state_dim: (3, 64, 64)
11 | action_dim: 5
12 | discrete: False
13 | eval_offline: False
14 | use_state: False
15 | eval_episode_factor: 10 # 10
16 | eval_env:
17 |
18 | train_dataset:
19 | expert_location: '/path/to/prep_data.pkl'
20 | num_trajectories: 40108
21 | normalize_states: False
22 | no_lang: False
23 | seed: ${seed}
24 | aug: True
25 |
26 | val_dataset:
27 | expert_location:
28 | num_trajectories: ${trainer.num_eval_episodes}
29 | normalize_states: False
30 | seed: ${seed}
31 | aug: True
32 |
33 |
34 | trainer:
35 | device: ## to be filled in code
36 | state_il: False
37 | num_eval_episodes: 5 # 5
38 | eval_every: 20
39 | K: ${model.K}
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/env/lorel_sawyer_state.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | env:
4 | name: LorlEnv-v0
5 | state_dim: 15
6 | action_dim: 5
7 | discrete: False
8 | eval_offline: False
9 | use_state: True
10 | eval_episode_factor: 10
11 |
12 | train_dataset:
13 | expert_location: '/home/zxliang/dataset/lorel/may_08_sawyer_50k/prep_data3.pkl'
14 | num_trajectories: 40108
15 | normalize_states: False
16 | seed: ${seed}
17 |
18 | val_dataset:
19 | expert_location:
20 | num_trajectories: ${trainer.num_eval_episodes}
21 | normalize_states: False
22 | no_lang: False
23 | seed: ${seed}
24 |
25 |
26 | trainer:
27 | device: ## to be filled in code
28 | state_il: False
29 | num_eval_episodes: 5
30 | eval_every: 5
31 | K: ${model.K}
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/model/option.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | ema_decay: 0.995
4 | gradient_accumulate_every: 2
5 |
6 | diffuser:
7 | # model: 'models.TemporalUnet'
8 | # diffusion: 'models.GaussianDiffusion'
9 | # horizon: ${model.horizon}
10 | ## dim_mults: '(1, 4, 8)'
11 | # n_diffusion_steps: 128 # 512
12 | # loss_type: 'l2'
13 | # clip_denoised: True
14 | # predict_epsilon: False
15 | # ## loss weighting
16 | # action_weight: 1
17 | ## loss_weights: None
18 | # loss_discount: 1
19 | # savepath: '/home/zxliang/new-code/LISA/lisa/outputs/diffuser_debug'
20 |
21 | model: 'models.TemporalUnet'
22 | diffusion: 'models.GaussianInvDynDiffusion'
23 | horizon: ${model.horizon} # 100
24 | dim_mults: '(1, 4, 8)'
25 | n_diffusion_steps: 200 # 512
26 | # loss_type: 'l2'
27 | # clip_denoised: True
28 | predict_epsilon: True
29 | ## loss weighting
30 | action_weight: 10
31 | loss_weights: None
32 | loss_discount: 1
33 |
34 | returns_condition: True
35 | calc_energy: False
36 | condition_dropout: 0.25
37 | condition_guidance_w: 1.2
38 | test_ret: 0.9
39 | # renderer: 'utils.MuJoCoRenderer'
40 | dim: 128
41 | savepath: 'outputs/diffuser_debug'
42 |
43 | ## dataset
44 | loader: 'datasets.SequenceDataset'
45 | normalizer: 'CDFNormalizer'
46 | preprocess_fns: []
47 | clip_denoised: True
48 | use_padding: True
49 | include_returns: True
50 | discount: 0.99
51 | max_path_length: 1000
52 | inv_hidden_dim: 256
53 | ar_inv: False
54 | train_only_inv: False
55 | termination_penalty: -100
56 | returns_scale: 400.0 # Determined using rewards from the dataset
57 |
58 | ## training
59 | n_steps_per_epoch: 10000
60 | loss_type: 'l2'
61 | n_train_steps: 8e5 # 8e5 # 1e6
62 | batch_size: 64 # equal to args.batch_size * fold num
63 | learning_rate: 5e-3 # 2e-4
64 | gradient_accumulate_every: 1 # 2
65 | ema_decay: 0.995
66 | log_freq: 1000
67 | save_freq: 10000
68 | sample_freq: 10000
69 | n_saves: 5
70 | save_parallel: False
71 | n_reference: 8
72 | save_checkpoints: True
73 | loadpath: ## to be filled in code
74 |
75 | ## misc
76 | bucket: ''
77 | seed: 100
78 |
79 | model:
80 | name: option
81 |
82 | horizon: 8
83 | K: 8
84 | train_lm: False
85 | use_iq: ${use_iq}
86 | method: ${model.name}
87 | state_reconstruct: False
88 | lang_reconstruct: False
89 |
90 | state_reconstructor:
91 | num_hidden: 2
92 | hidden_size: 128
93 |
94 | lang_reconstructor:
95 | num_hidden: 2
96 | hidden_size: 128
97 | max_options: ## to be filled in code
98 |
99 | option_selector:
100 | horizon: ${model.horizon}
101 | use_vq: True
102 | kmeans_init: True
103 | commitment_weight: 0.25
104 | num_options: 20
105 | num_hidden: 2
106 | hidden_size: 128
107 |
108 | dt:
109 | hidden_size: 128
110 | n_layer: 4
111 | n_head: 4
112 | option_il: False
113 | activation_function: 'relu'
114 | n_positions: 1024
115 | dropout: 0.1
116 | no_actions: False
117 | no_states: False
118 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/model/traj_option.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | ema_decay: 0.995
4 | gradient_accumulate_every: 2
5 |
6 | diffuser:
7 | # model: 'models.TemporalUnet'
8 | # diffusion: 'models.GaussianDiffusion'
9 | # horizon: ${model.horizon}
10 | ## dim_mults: '(1, 4, 8)'
11 | # n_diffusion_steps: 128 # 512
12 | # loss_type: 'l2'
13 | # clip_denoised: True
14 | # predict_epsilon: False
15 | # ## loss weighting
16 | # action_weight: 1
17 | ## loss_weights: None
18 | # loss_discount: 1
19 | # savepath: '/home/zxliang/new-code/LISA/lisa/outputs/diffuser_debug'
20 |
21 | model: 'models.TemporalUnet'
22 | diffusion: 'models.GaussianInvDynDiffusion'
23 | horizon: ${model.horizon} # 100
24 | dim_mults: '(1, 4, 8)'
25 | n_diffusion_steps: 200 # 512
26 | # loss_type: 'l2'
27 | # clip_denoised: True
28 | predict_epsilon: True
29 | ## loss weighting
30 | action_weight: 10
31 | loss_weights: None
32 | loss_discount: 1
33 |
34 | returns_condition: True
35 | calc_energy: False
36 | condition_dropout: 0.25
37 | condition_guidance_w: 1.2
38 | test_ret: 0.9
39 | # renderer: 'utils.MuJoCoRenderer'
40 | dim: 128
41 | savepath: 'outputs/diffuser_debug'
42 |
43 | ## dataset
44 | loader: 'datasets.SequenceDataset'
45 | normalizer: 'CDFNormalizer'
46 | preprocess_fns: []
47 | clip_denoised: True
48 | use_padding: True
49 | include_returns: True
50 | discount: 0.99
51 | max_path_length: 1000
52 | inv_hidden_dim: 256
53 | ar_inv: False
54 | train_only_inv: False
55 | termination_penalty: -100
56 | returns_scale: 400.0 # Determined using rewards from the dataset
57 |
58 | ## training
59 | n_steps_per_epoch: 10000
60 | loss_type: 'l2'
61 | n_train_steps: 5e5 # 8e5 # 1e6
62 | batch_size: 64 # 32 equal to args.batch_size
63 | learning_rate: 1e-3 # 2e-4
64 | gradient_accumulate_every: 1 # 2
65 | ema_decay: 0.995
66 | log_freq: 1000
67 | save_freq: 10000
68 | sample_freq: 10000
69 | n_saves: 5
70 | save_parallel: False
71 | n_reference: 8
72 | save_checkpoints: True
73 | loadpath: ## to be filled in code
74 |
75 | ## misc
76 | bucket: ''
77 | seed: 100
78 |
79 | model:
80 | name: traj_option
81 |
82 | horizon: 8 # 10
83 | K: 8 # 10
84 | train_lm: False
85 | use_iq: ${use_iq}
86 | method: ${model.name}
87 | state_reconstruct: False
88 | lang_reconstruct: False
89 |
90 | state_reconstructor:
91 | num_hidden: 2
92 | hidden_size: 128
93 |
94 | lang_reconstructor:
95 | num_hidden: 2
96 | hidden_size: 128
97 | max_options: ## to be filled in code
98 |
99 | option_selector:
100 | horizon: ${model.horizon}
101 | use_vq: True
102 | kmeans_init: True
103 | commitment_weight: 0.25
104 | num_options: 20
105 | num_hidden: 2
106 | option_transformer:
107 | hidden_size: 128
108 | n_layer: 1
109 | n_head: 4
110 | max_length:
111 | max_ep_len:
112 | activation_function: 'relu'
113 | n_positions: 1024
114 | dropout: 0.1
115 | output_attention: False
116 |
117 | dt:
118 | hidden_size: 128
119 | n_layer: 1
120 | n_head: 4
121 | option_il: False
122 | activation_function: 'relu'
123 | n_positions: 1024
124 | dropout: 0.1
125 | no_actions: False
126 | no_states: False
127 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/conf/model/vanilla.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | model:
4 | name: vanilla
5 |
6 | horizon: 5
7 | K: 5
8 | train_lm: True
9 | method: ${model.name}
10 | state_reconstruct: False
11 | lang_reconstruct: False
12 |
13 | state_reconstructor:
14 | num_hidden: 2
15 | hidden_size: 128
16 |
17 | lang_reconstructor:
18 | num_hidden: 2
19 | hidden_size: 128
20 | max_options: ## to be filled in code
21 |
22 | dt:
23 | option_il: False
24 | hidden_size: 128
25 | n_layer: 2
26 | n_head: 4
27 | activation_function: 'relu'
28 | n_positions: 1024
29 | dropout: 0.1
30 | no_states: False
31 | no_actions: False
--------------------------------------------------------------------------------
/skilldiffuser/hrl/decision_transformer.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import transformers
8 |
9 | from trajectory_model import TrajectoryModel
10 | from trajectory_gpt2 import GPT2Model
11 | from img_encoder import Encoder
12 |
13 | import iq
14 | from utils import pad
15 |
16 |
17 | class DecisionTransformer(TrajectoryModel):
18 |
19 | """
20 | This model uses GPT to model (Lang, state_1, action_1, state_2, action_2, ...) or (option_1, state_1, action_1, option_2, state_2, action_2, ...)
21 | """
22 |
23 | def __init__(
24 | self,
25 | state_dim,
26 | action_dim,
27 | option_dim,
28 | lang_dim,
29 | discrete,
30 | hidden_size,
31 | use_language=False,
32 | use_options=True,
33 | option_il=False,
34 | predict_q=False,
35 | max_length=None,
36 | max_ep_len=4096,
37 | action_tanh=False,
38 | no_states=False,
39 | no_actions=False,
40 | ** kwargs):
41 | # max_length used to be K
42 | super().__init__(state_dim, action_dim, max_length=max_length)
43 |
44 | self.use_options = use_options
45 | self.use_language = use_language
46 | self.option_il = option_il
47 | self.predict_q = predict_q
48 |
49 | if use_language and use_options:
50 | raise ValueError("Cannot use language and options!")
51 | if not use_language and not use_options:
52 | raise ValueError("Have to use language or options!")
53 | self.option_dim = option_dim
54 | self.discrete = discrete
55 |
56 | self.hidden_size = hidden_size
57 | config = transformers.GPT2Config(
58 | vocab_size=1, # doesn't matter -- we don't use the vocab
59 | n_embd=hidden_size,
60 | **kwargs
61 | )
62 |
63 | if isinstance(state_dim, tuple):
64 | # LORL
65 | if state_dim[0] == 3:
66 | # LORL Sawyer
67 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False)
68 | else:
69 | # LORL Franka
70 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True)
71 | else:
72 | self.embed_state = nn.Linear(self.state_dim, hidden_size)
73 |
74 | # note: the only difference between this GPT2Model and the default Huggingface version
75 | # is that the positional embeddings are removed (since we'll add those ourselves)
76 | self.transformer = GPT2Model(config)
77 |
78 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
79 |
80 | self.embed_action = nn.Linear(self.act_dim, hidden_size)
81 |
82 | self.no_states = no_states
83 | self.no_actions = no_actions
84 |
85 | if use_options:
86 | self.embed_option = nn.Linear(self.option_dim, hidden_size)
87 |
88 | if use_language:
89 | self.embed_lang = nn.Linear(lang_dim, hidden_size)
90 |
91 | self.embed_ln = nn.LayerNorm(hidden_size)
92 | # note: we don't predict states or returns for the paper
93 | if isinstance(self.state_dim, int):
94 | self.predict_state = torch.nn.Linear(hidden_size, self.state_dim)
95 | self.predict_action = nn.Sequential(
96 | *([nn.Linear(hidden_size, self.act_dim)] + ([nn.Tanh()] if action_tanh and not discrete else []))
97 | )
98 | if use_options:
99 | self.predict_option = torch.nn.Linear(hidden_size, self.option_dim)
100 | if predict_q:
101 | self.predict_q = torch.nn.Linear(hidden_size, self.act_dim)
102 |
103 | def forward(self, states, actions, timesteps, options=None, word_embeddings=None, attention_mask=None):
104 |
105 | batch_size, seq_length = states.shape[0], states.shape[1]
106 |
107 | if attention_mask is None:
108 | raise ValueError('Should not have attention_mask NONE')
109 | # attention mask for GPT: 1 if can be attended to, 0 if not
110 | attention_mask = torch.ones((batch_size, seq_length), dtype=torch.long)
111 |
112 | if self.use_options:
113 | assert options is not None
114 | option_embeddings = self.embed_option(options)
115 | time_embeddings = self.embed_timestep(timesteps)
116 |
117 | # time embeddings are treated similar to positional embeddings
118 | option_embeddings = option_embeddings + time_embeddings
119 |
120 | if self.no_states:
121 | # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT
122 | state_embeddings = self.embed_state(torch.zeros_like(states))
123 | else:
124 | state_embeddings = self.embed_state(states)
125 | state_embeddings = state_embeddings + time_embeddings
126 |
127 | if self.no_actions:
128 | # IMP: MAKE SURE THIS IS NOT SET ON BY DEFAULT
129 | action_embeddings = self.embed_action(torch.zeros_like(actions))
130 | else:
131 | action_embeddings = self.embed_action(actions)
132 | action_embeddings = action_embeddings + time_embeddings
133 |
134 | # this makes the sequence look like (o1, s1, a1, o2, s2, a2, ...)
135 | # which works nicely in an autoregressive sense since states predict actions
136 | # note that o1 and o2 need not be different
137 | stacked_inputs = torch.stack(
138 | (option_embeddings, state_embeddings, action_embeddings),
139 | dim=1).permute(
140 | 0, 2, 1, 3).reshape(
141 | batch_size, 3 * seq_length, self.hidden_size)
142 | # LAYERNORM
143 | stacked_inputs = self.embed_ln(stacked_inputs)
144 |
145 | # to make the attention mask fit the stacked inputs, have to stack it as well
146 | stacked_attention_mask = torch.stack(
147 | (attention_mask, attention_mask, attention_mask), dim=1
148 | ).permute(0, 2, 1).reshape(batch_size, 3 * seq_length)
149 |
150 | # we feed in the input embeddings (not word indices as in NLP) to the model
151 | transformer_outputs = self.transformer(
152 | inputs_embeds=stacked_inputs,
153 | attention_mask=stacked_attention_mask,
154 | )
155 | x = transformer_outputs['last_hidden_state']
156 |
157 | # reshape x so that the second dimension corresponds to the original
158 | # options (0), states (1) or actions (2); i.e. x[:,0,t] is the token for s_t
159 | traj_out = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)
160 | # get predictions
161 | # predict next state given option, state and action. skip the last state for prediction
162 | if isinstance(self.state_dim, int):
163 | state_preds = self.predict_state(traj_out[:, 2])[:, :-1, :]
164 | else:
165 | state_preds = None
166 | # predict next action given state and option
167 | action_preds = self.predict_action(traj_out[:, 1])
168 |
169 | # reconstruct current option given current option
170 | if self.option_il:
171 | option_preds = self.predict_option(traj_out[:, 0])
172 | options_loss = F.mse_loss(option_preds, options.detach())
173 | else:
174 | options_loss = None
175 |
176 | outputs = {'state_preds': state_preds,
177 | 'action_preds': action_preds,
178 | 'options_loss': options_loss}
179 |
180 | if self.predict_q:
181 | # predict next Q given state and option ## IMP: Don't use current action
182 | q_preds = self.predict_q(traj_out[:, 1])
183 | outputs.update({'q_preds': q_preds})
184 |
185 | return outputs
186 |
187 | if self.use_language:
188 | assert word_embeddings is not None
189 | num_tokens = word_embeddings.shape[1]
190 | state_embeddings = self.embed_state(states)
191 | lang_embeddings = self.embed_lang(word_embeddings)
192 | action_embeddings = self.embed_action(actions)
193 | time_embeddings = self.embed_timestep(timesteps)
194 |
195 | # time embeddings are treated similar to positional embeddings
196 | state_embeddings = state_embeddings + time_embeddings
197 | action_embeddings = action_embeddings + time_embeddings
198 |
199 | stacked_inputs = torch.stack(
200 | (state_embeddings, action_embeddings),
201 | dim=1).permute(
202 | 0, 2, 1, 3).reshape(
203 | batch_size, 2 * seq_length, self.hidden_size)
204 | lang_and_inputs = torch.cat([lang_embeddings, stacked_inputs], dim=1)
205 | # LAYERNORM AFTER LANGUAGE
206 | stacked_inputs = self.embed_ln(lang_and_inputs)
207 |
208 | # to make the attention mask fit the stacked inputs, have to stack it as well
209 | stacked_attention_mask = torch.stack(
210 | (attention_mask, attention_mask), dim=1
211 | ).permute(0, 2, 1).reshape(batch_size, 2*seq_length)
212 | lang_attn_mask = torch.cat(
213 | [torch.ones((batch_size, num_tokens), device=states.device), stacked_attention_mask], dim=1)
214 |
215 | # we feed in the input embeddings (not word indices as in NLP) to the model
216 | transformer_outputs = self.transformer(
217 | inputs_embeds=stacked_inputs,
218 | attention_mask=lang_attn_mask,
219 | )
220 | x = transformer_outputs['last_hidden_state']
221 |
222 | # reshape x so that the second dimension corresponds to the original
223 | # states (0), or actions (1); i.e. x[:,0,t] is the token for s_t
224 | lang_out = x[:, :num_tokens, :].reshape(
225 | batch_size, num_tokens, 1, self.hidden_size).permute(0, 2, 1, 3)
226 | traj_out = x[:, num_tokens:, :].reshape(
227 | batch_size, seq_length, 2, self.hidden_size).permute(0, 2, 1, 3)
228 |
229 | # get predictions
230 | # predict state given state, action. skip the last prediction
231 | if isinstance(self.state_dim, int):
232 | state_preds = self.predict_state(traj_out[:, 1])[:, :-1, :]
233 | else:
234 | state_preds = None
235 | action_preds = self.predict_action(traj_out[:, 0]) # predict next action given state
236 |
237 | outputs = {'state_preds': state_preds,
238 | 'action_preds': action_preds}
239 |
240 | if self.predict_q:
241 | # predict next Q given state ## IMP: Don't use current action
242 | q_preds = self.predict_q(traj_out[:, 0])
243 | outputs.update({'q_preds': q_preds})
244 |
245 | return outputs
246 |
247 | def get_action(self, states, actions, timesteps, options=None, word_embeddings=None, **kwargs):
248 |
249 | if self.use_options:
250 | assert options is not None
251 | if isinstance(self.state_dim, tuple):
252 | states = states.reshape(1, -1, *self.state_dim)
253 | else:
254 | states = states.reshape(1, -1, self.state_dim)
255 | options = options.reshape(1, -1, self.option_dim)
256 | actions = actions.reshape(1, -1, self.act_dim)
257 | timesteps = timesteps.reshape(1, -1)
258 |
259 | if self.max_length is not None:
260 | states = states[:, -self.max_length:]
261 | options = options[:, -self.max_length:]
262 | actions = actions[:, -self.max_length:]
263 | timesteps = timesteps[:, -self.max_length:]
264 |
265 | # pad all tokens to sequence length
266 | attention_mask = pad(torch.ones(1, states.shape[1]), self.max_length).to(
267 | dtype=torch.long, device=states.device).reshape(1, -1)
268 | states = pad(states, self.max_length).to(dtype=torch.float32)
269 | options = pad(options, self.max_length).to(dtype=torch.float32)
270 | actions = pad(actions, self.max_length).to(dtype=torch.float32)
271 | timesteps = pad(timesteps, self.max_length).to(dtype=torch.long)
272 | else:
273 | raise ValueError('Should not have max_length NONE')
274 | attention_mask = None
275 |
276 | preds = self.forward(
277 | states, actions, timesteps, options=options, attention_mask=attention_mask)
278 |
279 | if self.use_language:
280 | assert word_embeddings is not None
281 | if isinstance(self.state_dim, tuple):
282 | states = states.reshape(1, -1, *self.state_dim)
283 | else:
284 | states = states.reshape(1, -1, self.state_dim)
285 | actions = actions.reshape(1, -1, self.act_dim)
286 | timesteps = timesteps.reshape(1, -1)
287 |
288 | if self.max_length is not None:
289 | states = states[:, -self.max_length:]
290 | actions = actions[:, -self.max_length:]
291 | timesteps = timesteps[:, -self.max_length:]
292 |
293 | # pad all tokens to sequence length
294 | attention_mask = pad(
295 | torch.ones(1, states.shape[1]),
296 | self.max_length).to(
297 | dtype=torch.long, device=states.device).reshape(
298 | 1, -1)
299 | states = pad(states, self.max_length).to(dtype=torch.float32)
300 | actions = pad(actions, self.max_length).to(dtype=torch.float32)
301 | timesteps = pad(timesteps, self.max_length).to(dtype=torch.long)
302 | else:
303 | attention_mask = None
304 |
305 | preds = self.forward(
306 | states, actions, timesteps, word_embeddings=word_embeddings, attention_mask=attention_mask, **kwargs)
307 |
308 | return preds
309 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/difftrainer.py:
--------------------------------------------------------------------------------
1 | # import pdb
2 |
3 | import numpy as np
4 | import torch
5 | import copy
6 | from diffuser.utils.training import EMA
7 | from ml_logger import logger
8 | import os
9 | import torch.nn as nn
10 |
11 | class DiffTrainer(nn.Module):
12 | def __init__(
13 | self,
14 | args,
15 | diffusion_model,
16 | diff_trainer_args,
17 | ):
18 | # max_length used to be K
19 | super().__init__()
20 |
21 | self.model = diffusion_model
22 | self.ema = EMA(diff_trainer_args['ema_decay'])
23 | self.ema_model = copy.deepcopy(self.model)
24 | self.update_ema_every = diff_trainer_args['update_ema_every']
25 | self.save_checkpoints = diff_trainer_args['save_checkpoints']
26 | self.step_start_ema = 2000
27 |
28 | self.log_freq = diff_trainer_args['log_freq']
29 | self.sample_freq = diff_trainer_args['sample_freq']
30 | self.save_freq = diff_trainer_args['save_freq']
31 | self.label_freq = diff_trainer_args['label_freq']
32 | self.save_parallel = diff_trainer_args['save_parallel']
33 |
34 | # self.batch_size = diff_trainer_args['train_batch_size']
35 | self.gradient_accumulate_every = diff_trainer_args['gradient_accumulate_every']
36 |
37 | self.bucket = os.path.join(args.hydra_base_dir, "buckets")
38 | self.n_reference = diff_trainer_args['n_reference']
39 |
40 | self.reset_parameters()
41 | self.step = 0
42 |
43 | self.device = diff_trainer_args['train_device']
44 |
45 | logger.configure(args.hydra_base_dir,
46 | prefix=f"diffuser_log")
47 | # self.model = self.model.to(device=self.device) # TODO
48 | # self.ema_model = self.ema_model.to(device=self.device)
49 |
50 | self.lr = diff_trainer_args['train_lr']
51 | self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
52 |
53 | def reset_parameters(self):
54 | self.ema_model.load_state_dict(self.model.state_dict())
55 |
56 | def step_ema(self):
57 | if self.step < self.step_start_ema:
58 | self.reset_parameters()
59 | return
60 | self.ema.update_model_average(self.ema_model, self.model)
61 |
62 | # -----------------------------------------------------------------------------#
63 | # ------------------------------------ api ------------------------------------#
64 | # -----------------------------------------------------------------------------#
65 |
66 | def train_iteration(self, batch):
67 | # timer = Timer()
68 | # for step in range(n_train_steps):
69 | # for i in range(self.gradient_accumulate_every):
70 | # batch = next(self.dataloader)
71 | # batch = batch_to_device(batch, device=self.device)
72 | loss, infos = self.model.loss(*batch)
73 | # loss = loss / self.gradient_accumulate_every
74 |
75 | self.optimizer.zero_grad()
76 | loss.backward(retain_graph=True)
77 | self.optimizer.step()
78 |
79 | if self.step % self.update_ema_every == 0:
80 | self.step_ema()
81 |
82 | if self.step % self.save_freq == 0:
83 | self.save()
84 |
85 | infos.pop('obs')
86 |
87 | if self.step % self.log_freq == 0:
88 | infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
89 | logger.print(f'{self.step}: {loss:8.4f} | {infos_str}')
90 | metrics = {k: v.detach().item() for k, v in infos.items()}
91 | metrics['steps'] = self.step
92 | metrics['loss'] = loss.detach().item()
93 | logger.log_metrics_summary(metrics, default_stats='mean')
94 |
95 | # if self.step == 0 and self.sample_freq:
96 | # self.render_reference(self.n_reference)
97 |
98 | # if self.sample_freq and self.step % self.sample_freq == 0:
99 | # if self.model.__class__ == diffuser.models.diffusion.GaussianInvDynDiffusion:
100 | # self.inv_render_samples()
101 | # elif self.model.__class__ == diffuser.models.diffusion.ActionGaussianDiffusion:
102 | # pass
103 | # else:
104 | # self.render_samples()
105 |
106 | self.step += 1
107 | return
108 |
109 | def save(self):
110 | '''
111 | saves model and ema to disk;
112 | syncs to storage bucket if a bucket is specified
113 | '''
114 | data = {
115 | 'step': self.step,
116 | 'model': self.model.state_dict(),
117 | 'ema': self.ema_model.state_dict()
118 | }
119 | savepath = os.path.join(self.bucket, logger.prefix, 'checkpoint')
120 | os.makedirs(savepath, exist_ok=True)
121 | # logger.save_torch(data, savepath)
122 | if self.save_checkpoints:
123 | savepath = os.path.join(savepath, f'state_{self.step}.pt')
124 | else:
125 | savepath = os.path.join(savepath, 'state.pt')
126 | torch.save(data, savepath)
127 | logger.print(f'[ utils/training ] Saved model to {savepath}')
128 |
129 | def load(self, loadpath):
130 | '''
131 | loads model and ema from disk
132 | '''
133 | # loadpath = os.path.join(self.bucket, logger.prefix, f'checkpoint/state.pt')
134 | # data = logger.load_torch(loadpath)
135 | data = torch.load(loadpath)
136 |
137 | # self.step = data['step']
138 | self.model.load_state_dict(data['model'])
139 | self.ema_model.load_state_dict(data['ema'])
140 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/__init__.py:
--------------------------------------------------------------------------------
1 | from . import environments
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .sequence import *
2 | from .d4rl import load_environment
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/datasets/buffer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | def atleast_2d(x):
4 | while x.ndim < 2:
5 | x = np.expand_dims(x, axis=-1)
6 | return x
7 |
8 | class ReplayBuffer:
9 |
10 | def __init__(self, max_n_episodes, max_path_length, termination_penalty):
11 | self._dict = {
12 | 'path_lengths': np.zeros(max_n_episodes, dtype=np.int64),  # np.int is removed in newer NumPy versions
13 | }
14 | self._count = 0
15 | self.max_n_episodes = max_n_episodes
16 | self.max_path_length = max_path_length
17 | self.termination_penalty = termination_penalty
18 |
19 | def __repr__(self):
20 | return '[ datasets/buffer ] Fields:\n' + '\n'.join(
21 | f' {key}: {val.shape}'
22 | for key, val in self.items()
23 | )
24 |
25 | def __getitem__(self, key):
26 | return self._dict[key]
27 |
28 | def __setitem__(self, key, val):
29 | self._dict[key] = val
30 | self._add_attributes()
31 |
32 | @property
33 | def n_episodes(self):
34 | return self._count
35 |
36 | @property
37 | def n_steps(self):
38 | return sum(self['path_lengths'])
39 |
40 | def _add_keys(self, path):
41 | if hasattr(self, 'keys'):
42 | return
43 | self.keys = list(path.keys())
44 |
45 | def _add_attributes(self):
46 | '''
47 | can access fields with `buffer.observations`
48 | instead of `buffer['observations']`
49 | '''
50 | for key, val in self._dict.items():
51 | setattr(self, key, val)
52 |
53 | def items(self):
54 | return {k: v for k, v in self._dict.items()
55 | if k != 'path_lengths'}.items()
56 |
57 | def _allocate(self, key, array):
58 | assert key not in self._dict
59 | dim = array.shape[-1]
60 | shape = (self.max_n_episodes, self.max_path_length, dim)
61 | self._dict[key] = np.zeros(shape, dtype=np.float32)
62 | # print(f'[ utils/mujoco ] Allocated {key} with size {shape}')
63 |
64 | def add_path(self, path):
65 | path_length = len(path['observations'])
66 | assert path_length <= self.max_path_length
67 |
68 | if path['terminals'].any():
69 | assert (path['terminals'][-1] == True) and (not path['terminals'][:-1].any())
70 |
71 | ## if first path added, set keys based on contents
72 | self._add_keys(path)
73 |
74 | ## add tracked keys in path
75 | for key in self.keys:
76 | array = atleast_2d(path[key])
77 | if key not in self._dict: self._allocate(key, array)
78 | self._dict[key][self._count, :path_length] = array
79 |
80 | ## penalize early termination
81 | if path['terminals'].any() and self.termination_penalty is not None:
82 | assert not path['timeouts'].any(), 'Penalized a timeout episode for early termination'
83 | self._dict['rewards'][self._count, path_length - 1] += self.termination_penalty
84 |
85 | ## record path length
86 | self._dict['path_lengths'][self._count] = path_length
87 |
88 | ## increment path counter
89 | self._count += 1
90 |
91 | def truncate_path(self, path_ind, step):
92 | old = self._dict['path_lengths'][path_ind]
93 | new = min(step, old)
94 | self._dict['path_lengths'][path_ind] = new
95 |
96 | def finalize(self):
97 | ## remove extra slots
98 | for key in self.keys + ['path_lengths']:
99 | self._dict[key] = self._dict[key][:self._count]
100 | self._add_attributes()
101 | print(f'[ datasets/buffer ] Finalized replay buffer | {self._count} episodes')
102 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/datasets/d4rl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import numpy as np
4 | import gym
5 | import pdb
6 |
7 | from contextlib import (
8 | contextmanager,
9 | redirect_stderr,
10 | redirect_stdout,
11 | )
12 |
13 | @contextmanager
14 | def suppress_output():
15 | """
16 | A context manager that redirects stdout and stderr to devnull
17 | https://stackoverflow.com/a/52442331
18 | """
19 | with open(os.devnull, 'w') as fnull:
20 | with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
21 | yield (err, out)
22 |
23 | with suppress_output():
24 | ## d4rl prints out a variety of warnings
25 | import d4rl
26 |
27 | #-----------------------------------------------------------------------------#
28 | #-------------------------------- general api --------------------------------#
29 | #-----------------------------------------------------------------------------#
30 |
31 | def load_environment(name):
32 | if type(name) != str:
33 | ## name is already an environment
34 | return name
35 | with suppress_output():
36 | wrapped_env = gym.make(name)
37 | env = wrapped_env.unwrapped
38 | env.max_episode_steps = wrapped_env._max_episode_steps
39 | env.name = name
40 | return env
41 |
42 | def get_dataset(env):
43 | dataset = env.get_dataset()
44 |
45 | if 'antmaze' in str(env).lower():
46 | ## the antmaze-v0 environments have a variety of bugs
47 | ## involving trajectory segmentation, so manually reset
48 | ## the terminal and timeout fields
49 | dataset = antmaze_fix_timeouts(dataset)
50 | dataset = antmaze_scale_rewards(dataset)
51 | get_max_delta(dataset)
52 |
53 | return dataset
54 |
55 | def sequence_dataset(env, preprocess_fn):
56 | """
57 | Returns an iterator through trajectories.
58 | Args:
59 | env: An OfflineEnv object.
60 | dataset: An optional dataset to pass in for processing. If None,
61 | the dataset will default to env.get_dataset()
62 | **kwargs: Arguments to pass to env.get_dataset().
63 | Returns:
64 | An iterator through dictionaries with keys:
65 | observations
66 | actions
67 | rewards
68 | terminals
69 | """
70 | dataset = get_dataset(env)
71 | dataset = preprocess_fn(dataset)
72 |
73 | N = dataset['rewards'].shape[0]
74 | data_ = collections.defaultdict(list)
75 |
76 | # The newer version of the dataset adds an explicit
77 | # timeouts field. Keep old method for backwards compatability.
78 | use_timeouts = 'timeouts' in dataset
79 |
80 | episode_step = 0
81 | for i in range(N):
82 | done_bool = bool(dataset['terminals'][i])
83 | if use_timeouts:
84 | final_timestep = dataset['timeouts'][i]
85 | else:
86 | final_timestep = (episode_step == env._max_episode_steps - 1)
87 |
88 | for k in dataset:
89 | if 'metadata' in k or 'infos' in k: continue
90 | data_[k].append(dataset[k][i])
91 | if done_bool or final_timestep:
92 | episode_step = 0
93 | episode_data = {}
94 | for k in data_:
95 | episode_data[k] = np.array(data_[k])
96 | if 'maze2d' in env.name:
97 | episode_data = process_maze2d_episode(episode_data)
98 | yield episode_data
99 | data_ = collections.defaultdict(list)
100 |
101 | episode_step += 1
102 |
103 |
104 | #-----------------------------------------------------------------------------#
105 | #-------------------------------- maze2d fixes -------------------------------#
106 | #-----------------------------------------------------------------------------#
107 |
108 | def process_maze2d_episode(episode):
109 | '''
110 | adds in `next_observations` field to episode
111 | '''
112 | assert 'next_observations' not in episode
113 | length = len(episode['observations'])
114 | next_observations = episode['observations'][1:].copy()
115 | for key, val in episode.items():
116 | episode[key] = val[:-1]
117 | episode['next_observations'] = next_observations
118 | return episode
119 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/datasets/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.interpolate as interpolate
3 | import pdb
4 |
5 | POINTMASS_KEYS = ['observations', 'actions', 'next_observations', 'deltas']
6 |
7 | #-----------------------------------------------------------------------------#
8 | #--------------------------- multi-field normalizer --------------------------#
9 | #-----------------------------------------------------------------------------#
10 |
11 | class DatasetNormalizer:
12 |
13 | def __init__(self, dataset, normalizer, path_lengths=None):
14 | dataset = flatten(dataset, path_lengths)
15 |
16 | self.observation_dim = dataset['observations'].shape[1]
17 | self.action_dim = dataset['actions'].shape[1]
18 |
19 | if type(normalizer) == str:
20 | normalizer = eval(normalizer)
21 |
22 | self.normalizers = {}
23 | for key, val in dataset.items():
24 | try:
25 | self.normalizers[key] = normalizer(val)
26 | except:
27 | print(f'[ utils/normalization ] Skipping {key} | {normalizer}')
28 | # key: normalizer(val)
29 | # for key, val in dataset.items()
30 |
31 | def __repr__(self):
32 | string = ''
33 | for key, normalizer in self.normalizers.items():
34 | string += f'{key}: {normalizer}]\n'
35 | return string
36 |
37 | def __call__(self, *args, **kwargs):
38 | return self.normalize(*args, **kwargs)
39 |
40 | def normalize(self, x, key):
41 | return self.normalizers[key].normalize(x)
42 |
43 | def unnormalize(self, x, key):
44 | return self.normalizers[key].unnormalize(x)
45 |
46 | def flatten(dataset, path_lengths):
47 | '''
48 | flattens dataset of { key: [ n_episodes x max_path_lenth x dim ] }
49 | to { key : [ (n_episodes * sum(path_lengths)) x dim ]}
50 | '''
51 | flattened = {}
52 | for key, xs in dataset.items():
53 | assert len(xs) == len(path_lengths)
54 | flattened[key] = np.concatenate([
55 | x[:length]
56 | for x, length in zip(xs, path_lengths)
57 | ], axis=0)
58 | return flattened
59 |
60 | #-----------------------------------------------------------------------------#
61 | #------------------------------- @TODO: remove? ------------------------------#
62 | #-----------------------------------------------------------------------------#
63 |
64 | class PointMassDatasetNormalizer(DatasetNormalizer):
65 |
66 | def __init__(self, preprocess_fns, dataset, normalizer, keys=POINTMASS_KEYS):
67 |
68 | reshaped = {}
69 | for key, val in dataset.items():
70 | dim = val.shape[-1]
71 | reshaped[key] = val.reshape(-1, dim)
72 |
73 | self.observation_dim = reshaped['observations'].shape[1]
74 | self.action_dim = reshaped['actions'].shape[1]
75 |
76 | if type(normalizer) == str:
77 | normalizer = eval(normalizer)
78 |
79 | self.normalizers = {
80 | key: normalizer(reshaped[key])
81 | for key in keys
82 | }
83 |
84 | #-----------------------------------------------------------------------------#
85 | #-------------------------- single-field normalizers -------------------------#
86 | #-----------------------------------------------------------------------------#
87 |
88 | class Normalizer:
89 | '''
90 | parent class, subclass by defining the `normalize` and `unnormalize` methods
91 | '''
92 |
93 | def __init__(self, X):
94 | self.X = X.astype(np.float32)
95 | self.mins = X.min(axis=0)
96 | self.maxs = X.max(axis=0)
97 |
98 | def __repr__(self):
99 | return (
100 | f'''[ Normalizer ] dim: {self.mins.size}\n -: '''
101 | f'''{np.round(self.mins, 2)}\n +: {np.round(self.maxs, 2)}\n'''
102 | )
103 |
104 | def __call__(self, x):
105 | return self.normalize(x)
106 |
107 | def normalize(self, *args, **kwargs):
108 | raise NotImplementedError()
109 |
110 | def unnormalize(self, *args, **kwargs):
111 | raise NotImplementedError()
112 |
113 |
114 | class DebugNormalizer(Normalizer):
115 | '''
116 | identity function
117 | '''
118 |
119 | def normalize(self, x, *args, **kwargs):
120 | return x
121 |
122 | def unnormalize(self, x, *args, **kwargs):
123 | return x
124 |
125 |
126 | class GaussianNormalizer(Normalizer):
127 | '''
128 | normalizes to zero mean and unit variance
129 | '''
130 |
131 | def __init__(self, *args, **kwargs):
132 | super().__init__(*args, **kwargs)
133 | self.means = self.X.mean(axis=0)
134 | self.stds = self.X.std(axis=0)
135 | self.z = 1
136 |
137 | def __repr__(self):
138 | return (
139 | f'''[ Normalizer ] dim: {self.mins.size}\n '''
140 | f'''means: {np.round(self.means, 2)}\n '''
141 | f'''stds: {np.round(self.z * self.stds, 2)}\n'''
142 | )
143 |
144 | def normalize(self, x):
145 | return (x - self.means) / self.stds
146 |
147 | def unnormalize(self, x):
148 | return x * self.stds + self.means
149 |
150 |
151 | class LimitsNormalizer(Normalizer):
152 | '''
153 | maps [ xmin, xmax ] to [ -1, 1 ]
154 | '''
155 |
156 | def normalize(self, x):
157 | ## [ 0, 1 ]
158 | x = (x - self.mins) / (self.maxs - self.mins)
159 | ## [ -1, 1 ]
160 | x = 2 * x - 1
161 | return x
162 |
163 | def unnormalize(self, x, eps=1e-4):
164 | '''
165 | x : [ -1, 1 ]
166 | '''
167 | if x.max() > 1 + eps or x.min() < -1 - eps:
168 | # print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})')
169 | x = np.clip(x, -1, 1)
170 |
171 | ## [ -1, 1 ] --> [ 0, 1 ]
172 | x = (x + 1) / 2.
173 |
174 | return x * (self.maxs - self.mins) + self.mins
175 |
176 | class SafeLimitsNormalizer(LimitsNormalizer):
177 | '''
178 | functions like LimitsNormalizer, but can handle data for which a dimension is constant
179 | '''
180 |
181 | def __init__(self, *args, eps=1, **kwargs):
182 | super().__init__(*args, **kwargs)
183 | for i in range(len(self.mins)):
184 | if self.mins[i] == self.maxs[i]:
185 | print(f'''
186 | [ utils/normalization ] Constant data in dimension {i} | '''
187 | f'''max = min = {self.maxs[i]}'''
188 | )
189 | self.mins -= eps
190 | self.maxs += eps
191 |
192 | #-----------------------------------------------------------------------------#
193 | #------------------------------- CDF normalizer ------------------------------#
194 | #-----------------------------------------------------------------------------#
195 |
196 | class CDFNormalizer(Normalizer):
197 | '''
198 | makes training data uniform (over each dimension) by transforming it with marginal CDFs
199 | '''
200 |
201 | def __init__(self, X):
202 | super().__init__(atleast_2d(X))
203 | self.dim = self.X.shape[1]
204 | self.cdfs = [CDFNormalizer1d(self.X[:, i]) for i in range(self.dim)]
205 |
206 | def __repr__(self):
207 | return f'[ CDFNormalizer ] dim: {self.mins.size}\n' + ' | '.join(
208 | f'{i:3d}: {cdf}' for i, cdf in enumerate(self.cdfs)
209 | )
210 |
211 | def wrap(self, fn_name, x):
212 | shape = x.shape
213 | ## reshape to 2d
214 | x = x.reshape(-1, self.dim)
215 | out = np.zeros_like(x)
216 | for i, cdf in enumerate(self.cdfs):
217 | fn = getattr(cdf, fn_name)
218 | out[:, i] = fn(x[:, i])
219 | return out.reshape(shape)
220 |
221 | def normalize(self, x):
222 | return self.wrap('normalize', x)
223 |
224 | def unnormalize(self, x):
225 | return self.wrap('unnormalize', x)
226 |
227 | class CDFNormalizer1d:
228 | '''
229 | CDF normalizer for a single dimension
230 | '''
231 |
232 | def __init__(self, X):
233 | assert X.ndim == 1
234 | self.X = X.astype(np.float32)
235 | if self.X.max() == self.X.min():
236 | self.constant = True
237 | else:
238 | self.constant = False
239 | quantiles, cumprob = empirical_cdf(self.X)
240 | self.fn = interpolate.interp1d(quantiles, cumprob)
241 | self.inv = interpolate.interp1d(cumprob, quantiles)
242 |
243 | self.xmin, self.xmax = quantiles.min(), quantiles.max()
244 | self.ymin, self.ymax = cumprob.min(), cumprob.max()
245 |
246 | def __repr__(self):
247 | return (
248 | f'[{np.round(self.xmin, 2):.4f}, {np.round(self.xmax, 2):.4f}]'
249 | )
250 |
251 | def normalize(self, x):
252 | if self.constant:
253 | return x
254 |
255 | x = np.clip(x, self.xmin, self.xmax)
256 | ## [ 0, 1 ]
257 | y = self.fn(x)
258 | ## [ -1, 1 ]
259 | y = 2 * y - 1
260 | return y
261 |
262 | def unnormalize(self, x, eps=1e-4):
263 | '''
264 | X : [ -1, 1 ]
265 | '''
266 | ## [ -1, 1 ] --> [ 0, 1 ]
267 | if self.constant:
268 | return x
269 |
270 | x = (x + 1) / 2.
271 |
272 | if (x < self.ymin - eps).any() or (x > self.ymax + eps).any():
273 | print(
274 | f'''[ dataset/normalization ] Warning: out of range in unnormalize: '''
275 | f'''[{x.min()}, {x.max()}] | '''
276 | f'''x : [{self.xmin}, {self.xmax}] | '''
277 | f'''y: [{self.ymin}, {self.ymax}]'''
278 | )
279 |
280 | x = np.clip(x, self.ymin, self.ymax)
281 |
282 | y = self.inv(x)
283 | return y
284 |
285 | def empirical_cdf(sample):
286 | ## https://stackoverflow.com/a/33346366
287 |
288 | # find the unique values and their corresponding counts
289 | quantiles, counts = np.unique(sample, return_counts=True)
290 |
291 | # take the cumulative sum of the counts and divide by the sample size to
292 | # get the cumulative probabilities between 0 and 1
293 | cumprob = np.cumsum(counts).astype(np.double) / sample.size
294 |
295 | return quantiles, cumprob
296 |
297 | def atleast_2d(x):
298 | if x.ndim < 2:
299 | x = x[:,None]
300 | return x
301 |
302 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/datasets/preprocessing.py:
--------------------------------------------------------------------------------
1 | import gym
2 | import numpy as np
3 | import einops
4 | from scipy.spatial.transform import Rotation as R
5 | import pdb
6 |
7 | from .d4rl import load_environment
8 |
9 | #-----------------------------------------------------------------------------#
10 | #-------------------------------- general api --------------------------------#
11 | #-----------------------------------------------------------------------------#
12 |
13 | def compose(*fns):
14 |
15 | def _fn(x):
16 | for fn in fns:
17 | x = fn(x)
18 | return x
19 |
20 | return _fn
21 |
22 | def get_preprocess_fn(fn_names, env):
23 | fns = [eval(name)(env) for name in fn_names]
24 | return compose(*fns)
25 |
26 | def get_policy_preprocess_fn(fn_names):
27 | fns = [eval(name) for name in fn_names]
28 | return compose(*fns)
29 |
30 | #-----------------------------------------------------------------------------#
31 | #-------------------------- preprocessing functions --------------------------#
32 | #-----------------------------------------------------------------------------#
33 |
34 | #------------------------ @TODO: remove some of these ------------------------#
35 |
36 | def arctanh_actions(*args, **kwargs):
37 | epsilon = 1e-4
38 |
39 | def _fn(dataset):
40 | actions = dataset['actions']
41 | assert actions.min() >= -1 and actions.max() <= 1, \
42 | f'applying arctanh to actions in range [{actions.min()}, {actions.max()}]'
43 | actions = np.clip(actions, -1 + epsilon, 1 - epsilon)
44 | dataset['actions'] = np.arctanh(actions)
45 | return dataset
46 |
47 | return _fn
48 |
49 | def add_deltas(env):
50 |
51 | def _fn(dataset):
52 | deltas = dataset['next_observations'] - dataset['observations']
53 | dataset['deltas'] = deltas
54 | return dataset
55 |
56 | return _fn
57 |
58 |
59 | def maze2d_set_terminals(env):
60 | env = load_environment(env) if type(env) == str else env
61 | goal = np.array(env._target)
62 | threshold = 0.5
63 |
64 | def _fn(dataset):
65 | xy = dataset['observations'][:,:2]
66 | distances = np.linalg.norm(xy - goal, axis=-1)
67 | at_goal = distances < threshold
68 | timeouts = np.zeros_like(dataset['timeouts'])
69 |
70 | ## timeout at time t iff
71 | ## at goal at time t and
72 | ## not at goal at time t + 1
73 | timeouts[:-1] = at_goal[:-1] * ~at_goal[1:]
74 |
75 | timeout_steps = np.where(timeouts)[0]
76 | path_lengths = timeout_steps[1:] - timeout_steps[:-1]
77 |
78 | print(
79 | f'[ utils/preprocessing ] Segmented {env.name} | {len(path_lengths)} paths | '
80 | f'min length: {path_lengths.min()} | max length: {path_lengths.max()}'
81 | )
82 |
83 | dataset['timeouts'] = timeouts
84 | return dataset
85 |
86 | return _fn
87 |
88 |
89 | #-------------------------- block-stacking --------------------------#
90 |
91 | def blocks_quat_to_euler(observations):
92 | '''
93 | input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ]
94 | xyz: 3
95 | quat: 4
96 | contact: 1
97 |
98 | returns : [ N x robot_dim + n_blocks * 10] = [ N x 47 ]
99 | xyz: 3
100 | sin: 3
101 | cos: 3
102 | contact: 1
103 | '''
104 | robot_dim = 7
105 | block_dim = 8
106 | n_blocks = 4
107 | assert observations.shape[-1] == robot_dim + n_blocks * block_dim
108 |
109 | X = observations[:, :robot_dim]
110 |
111 | for i in range(n_blocks):
112 | start = robot_dim + i * block_dim
113 | end = start + block_dim
114 |
115 | block_info = observations[:, start:end]
116 |
117 | xpos = block_info[:, :3]
118 | quat = block_info[:, 3:-1]
119 | contact = block_info[:, -1:]
120 |
121 | euler = R.from_quat(quat).as_euler('xyz')
122 | sin = np.sin(euler)
123 | cos = np.cos(euler)
124 |
125 | X = np.concatenate([
126 | X,
127 | xpos,
128 | sin,
129 | cos,
130 | contact,
131 | ], axis=-1)
132 |
133 | return X
134 |
135 | def blocks_euler_to_quat_2d(observations):
136 | robot_dim = 7
137 | block_dim = 10
138 | n_blocks = 4
139 |
140 | assert observations.shape[-1] == robot_dim + n_blocks * block_dim
141 |
142 | X = observations[:, :robot_dim]
143 |
144 | for i in range(n_blocks):
145 | start = robot_dim + i * block_dim
146 | end = start + block_dim
147 |
148 | block_info = observations[:, start:end]
149 |
150 | xpos = block_info[:, :3]
151 | sin = block_info[:, 3:6]
152 | cos = block_info[:, 6:9]
153 | contact = block_info[:, 9:]
154 |
155 | euler = np.arctan2(sin, cos)
156 | quat = R.from_euler('xyz', euler, degrees=False).as_quat()
157 |
158 | X = np.concatenate([
159 | X,
160 | xpos,
161 | quat,
162 | contact,
163 | ], axis=-1)
164 |
165 | return X
166 |
167 | def blocks_euler_to_quat(paths):
168 | return np.stack([
169 | blocks_euler_to_quat_2d(path)
170 | for path in paths
171 | ], axis=0)
172 |
173 | def blocks_process_cubes(env):
174 |
175 | def _fn(dataset):
176 | for key in ['observations', 'next_observations']:
177 | dataset[key] = blocks_quat_to_euler(dataset[key])
178 | return dataset
179 |
180 | return _fn
181 |
182 | def blocks_remove_kuka(env):
183 |
184 | def _fn(dataset):
185 | for key in ['observations', 'next_observations']:
186 | dataset[key] = dataset[key][:, 7:]
187 | return dataset
188 |
189 | return _fn
190 |
191 | def blocks_add_kuka(observations):
192 | '''
193 | observations : [ batch_size x horizon x 32 ]
194 | '''
195 | robot_dim = 7
196 | batch_size, horizon, _ = observations.shape
197 | observations = np.concatenate([
198 | np.zeros((batch_size, horizon, 7)),
199 | observations,
200 | ], axis=-1)
201 | return observations
202 |
203 | def blocks_cumsum_quat(deltas):
204 | '''
205 | deltas : [ batch_size x horizon x transition_dim ]
206 | '''
207 | robot_dim = 7
208 | block_dim = 8
209 | n_blocks = 4
210 | assert deltas.shape[-1] == robot_dim + n_blocks * block_dim
211 |
212 | batch_size, horizon, _ = deltas.shape
213 |
214 | cumsum = deltas.cumsum(axis=1)
215 | for i in range(n_blocks):
216 | start = robot_dim + i * block_dim + 3
217 | end = start + 4
218 |
219 | quat = deltas[:, :, start:end].copy()
220 |
221 | quat = einops.rearrange(quat, 'b h q -> (b h) q')
222 | euler = R.from_quat(quat).as_euler('xyz')
223 | euler = einops.rearrange(euler, '(b h) e -> b h e', b=batch_size)
224 | cumsum_euler = euler.cumsum(axis=1)
225 |
226 | cumsum_euler = einops.rearrange(cumsum_euler, 'b h e -> (b h) e')
227 | cumsum_quat = R.from_euler('xyz', cumsum_euler).as_quat()
228 | cumsum_quat = einops.rearrange(cumsum_quat, '(b h) q -> b h q', b=batch_size)
229 |
230 | cumsum[:, :, start:end] = cumsum_quat.copy()
231 |
232 | return cumsum
233 |
234 | def blocks_delta_quat_helper(observations, next_observations):
235 | '''
236 | input : [ N x robot_dim + n_blocks * 8 ] = [ N x 39 ]
237 | xyz: 3
238 | quat: 4
239 | contact: 1
240 | '''
241 | robot_dim = 7
242 | block_dim = 8
243 | n_blocks = 4
244 | assert observations.shape[-1] == next_observations.shape[-1] == robot_dim + n_blocks * block_dim
245 |
246 | deltas = (next_observations - observations)[:, :robot_dim]
247 |
248 | for i in range(n_blocks):
249 | start = robot_dim + i * block_dim
250 | end = start + block_dim
251 |
252 | block_info = observations[:, start:end]
253 | next_block_info = next_observations[:, start:end]
254 |
255 | xpos = block_info[:, :3]
256 | next_xpos = next_block_info[:, :3]
257 |
258 | quat = block_info[:, 3:-1]
259 | next_quat = next_block_info[:, 3:-1]
260 |
261 | contact = block_info[:, -1:]
262 | next_contact = next_block_info[:, -1:]
263 |
264 | delta_xpos = next_xpos - xpos
265 | delta_contact = next_contact - contact
266 |
267 | rot = R.from_quat(quat)
268 | next_rot = R.from_quat(next_quat)
269 |
270 | delta_quat = (next_rot * rot.inv()).as_quat()
271 | w = delta_quat[:, -1:]
272 |
273 | ## make w positive to avoid [0, 0, 0, -1]
274 | delta_quat = delta_quat * np.sign(w)
275 |
276 | ## apply rot then delta to ensure we end at next_rot
277 | ## delta * rot = next_rot * rot' * rot = next_rot
278 | next_euler = next_rot.as_euler('xyz')
279 | next_euler_check = (R.from_quat(delta_quat) * rot).as_euler('xyz')
280 | assert np.allclose(next_euler, next_euler_check)
281 |
282 | deltas = np.concatenate([
283 | deltas,
284 | delta_xpos,
285 | delta_quat,
286 | delta_contact,
287 | ], axis=-1)
288 |
289 | return deltas
290 |
291 | def blocks_add_deltas(env):
292 |
293 | def _fn(dataset):
294 | deltas = blocks_delta_quat_helper(dataset['observations'], dataset['next_observations'])
295 | # deltas = dataset['next_observations'] - dataset['observations']
296 | dataset['deltas'] = deltas
297 | return dataset
298 |
299 | return _fn
300 |
--------------------------------------------------------------------------------
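A small sketch of how the `compose` / preprocessing helpers above are meant to be chained. The toy dataset dict is an illustrative assumption, and importing the module requires its `d4rl`/gym dependencies to be importable:

```python
import numpy as np
from diffuser.datasets.preprocessing import compose, arctanh_actions, add_deltas

# toy dataset dict with the keys the preprocessing functions expect
dataset = {
    'observations': np.zeros((5, 3), dtype=np.float32),
    'next_observations': np.ones((5, 3), dtype=np.float32),
    'actions': np.random.uniform(-0.9, 0.9, size=(5, 2)).astype(np.float32),
}

# each helper returns a dataset -> dataset function; compose chains them in order
preprocess_fn = compose(arctanh_actions(), add_deltas(env=None))
dataset = preprocess_fn(dataset)

print(dataset['deltas'])         # next_observations - observations
print(dataset['actions'].max())  # actions are now in arctanh space
```
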
/skilldiffuser/hrl/diffuser/datasets/sequence.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import numpy as np
3 | import torch
4 | import pdb
5 |
6 | from .preprocessing import get_preprocess_fn
7 | from .d4rl import load_environment, sequence_dataset
8 | from .normalization import DatasetNormalizer
9 | from .buffer import ReplayBuffer
10 |
11 | RewardBatch = namedtuple('Batch', 'trajectories conditions returns')
12 | Batch = namedtuple('Batch', 'trajectories conditions')
13 | ValueBatch = namedtuple('ValueBatch', 'trajectories conditions values')
14 |
15 | class SequenceDataset(torch.utils.data.Dataset):
16 |
17 | def __init__(self, env='hopper-medium-replay', horizon=64,
18 | normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000,
19 | max_n_episodes=10000, termination_penalty=0, use_padding=True, discount=0.99, returns_scale=1000, include_returns=False):
20 | self.preprocess_fn = get_preprocess_fn(preprocess_fns, env)
21 | self.env = env = load_environment(env)
22 | self.returns_scale = returns_scale
23 | self.horizon = horizon
24 | self.max_path_length = max_path_length
25 | self.discount = discount
26 | self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]
27 | self.use_padding = use_padding
28 | self.include_returns = include_returns
29 | itr = sequence_dataset(env, self.preprocess_fn)
30 |
31 | fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)
32 | for i, episode in enumerate(itr):
33 | fields.add_path(episode)
34 | fields.finalize()
35 |
36 | self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths'])
37 | self.indices = self.make_indices(fields.path_lengths, horizon)
38 |
39 | self.observation_dim = fields.observations.shape[-1]
40 | self.action_dim = fields.actions.shape[-1]
41 | self.fields = fields
42 | self.n_episodes = fields.n_episodes
43 | self.path_lengths = fields.path_lengths
44 | self.normalize()
45 |
46 | print(fields)
47 | # shapes = {key: val.shape for key, val in self.fields.items()}
48 | # print(f'[ datasets/mujoco ] Dataset fields: {shapes}')
49 |
50 | def normalize(self, keys=['observations', 'actions']):
51 | '''
52 | normalize fields that will be predicted by the diffusion model
53 | '''
54 | for key in keys:
55 | array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1)
56 | normed = self.normalizer(array, key)
57 | self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
58 |
59 | def make_indices(self, path_lengths, horizon):
60 | '''
61 | makes indices for sampling from dataset;
62 | each index maps to a datapoint
63 | '''
64 | indices = []
65 | for i, path_length in enumerate(path_lengths):
66 | max_start = min(path_length - 1, self.max_path_length - horizon)
67 | if not self.use_padding:
68 | max_start = min(max_start, path_length - horizon)
69 | for start in range(max_start):
70 | end = start + horizon
71 | indices.append((i, start, end))
72 | indices = np.array(indices)
73 | return indices
74 |
75 | def get_conditions(self, observations):
76 | '''
77 | condition on current observation for planning
78 | '''
79 | return {0: observations[0]}
80 |
81 | def __len__(self):
82 | return len(self.indices)
83 |
84 | def __getitem__(self, idx, eps=1e-4):
85 | path_ind, start, end = self.indices[idx]
86 |
87 | observations = self.fields.normed_observations[path_ind, start:end]
88 | actions = self.fields.normed_actions[path_ind, start:end]
89 |
90 | conditions = self.get_conditions(observations)
91 | trajectories = np.concatenate([actions, observations], axis=-1)
92 |
93 | if self.include_returns:
94 | rewards = self.fields.rewards[path_ind, start:]
95 | discounts = self.discounts[:len(rewards)]
96 | returns = (discounts * rewards).sum()
97 | returns = np.array([returns/self.returns_scale], dtype=np.float32)
98 | batch = RewardBatch(trajectories, conditions, returns)
99 | else:
100 | batch = Batch(trajectories, conditions)
101 |
102 | return batch
103 |
104 | class CondSequenceDataset(torch.utils.data.Dataset):
105 |
106 | def __init__(self, env='hopper-medium-replay', horizon=64,
107 | normalizer='LimitsNormalizer', preprocess_fns=[], max_path_length=1000,
108 | max_n_episodes=10000, termination_penalty=0, use_padding=True, discount=0.99, returns_scale=1000, include_returns=False):
109 | self.preprocess_fn = get_preprocess_fn(preprocess_fns, env)
110 | self.env = env = load_environment(env)
111 | self.returns_scale = returns_scale
112 | self.horizon = horizon
113 | self.max_path_length = max_path_length
114 | self.discount = discount
115 | self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]
116 | self.use_padding = use_padding
117 | self.include_returns = include_returns
118 | itr = sequence_dataset(env, self.preprocess_fn)
119 |
120 | fields = ReplayBuffer(max_n_episodes, max_path_length, termination_penalty)
121 | for i, episode in enumerate(itr):
122 | fields.add_path(episode)
123 | fields.finalize()
124 |
125 | self.normalizer = DatasetNormalizer(fields, normalizer, path_lengths=fields['path_lengths'])
126 | self.indices = self.make_indices(fields.path_lengths, horizon)
127 |
128 | self.observation_dim = fields.observations.shape[-1]
129 | self.action_dim = fields.actions.shape[-1]
130 | self.fields = fields
131 | self.n_episodes = fields.n_episodes
132 | self.path_lengths = fields.path_lengths
133 | self.normalize()
134 |
135 | print(fields)
136 | # shapes = {key: val.shape for key, val in self.fields.items()}
137 | # print(f'[ datasets/mujoco ] Dataset fields: {shapes}')
138 |
139 | def normalize(self, keys=['observations', 'actions']):
140 | '''
141 | normalize fields that will be predicted by the diffusion model
142 | '''
143 | for key in keys:
144 | array = self.fields[key].reshape(self.n_episodes*self.max_path_length, -1)
145 | normed = self.normalizer(array, key)
146 | self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
147 |
148 | def make_indices(self, path_lengths, horizon):
149 | '''
150 | makes indices for sampling from dataset;
151 | each index maps to a datapoint
152 | '''
153 | indices = []
154 | for i, path_length in enumerate(path_lengths):
155 | max_start = min(path_length - 1, self.max_path_length - horizon)
156 | if not self.use_padding:
157 | max_start = min(max_start, path_length - horizon)
158 | for start in range(max_start):
159 | end = start + horizon
160 | indices.append((i, start, end))
161 | indices = np.array(indices)
162 | return indices
163 |
164 | def __len__(self):
165 | return len(self.indices)
166 |
167 | def __getitem__(self, idx, eps=1e-4):
168 | path_ind, start, end = self.indices[idx]
169 |
170 | t_step = np.random.randint(0, self.horizon)
171 |
172 | observations = self.fields.normed_observations[path_ind, start:end]
173 | actions = self.fields.normed_actions[path_ind, start:end]
174 |
175 | traj_dim = self.action_dim + self.observation_dim
176 |
177 | conditions = np.ones((self.horizon, 2*traj_dim)).astype(np.float32)
178 |
179 | # Set up conditional masking
180 | conditions[t_step:,:self.action_dim] = 0
181 | conditions[:,traj_dim:] = 0
182 | conditions[t_step,traj_dim:traj_dim+self.action_dim] = 1
183 |
184 | if t_step < self.horizon-1:
185 | observations[t_step+1:] = 0
186 |
187 | trajectories = np.concatenate([actions, observations], axis=-1)
188 |
189 | if self.include_returns:
190 | rewards = self.fields.rewards[path_ind, start:]
191 | discounts = self.discounts[:len(rewards)]
192 | returns = (discounts * rewards).sum()
193 | returns = np.array([returns/self.returns_scale], dtype=np.float32)
194 | batch = RewardBatch(trajectories, conditions, returns)
195 | else:
196 | batch = Batch(trajectories, conditions)
197 |
198 | return batch
199 |
200 | class GoalDataset(SequenceDataset):
201 |
202 | def get_conditions(self, observations):
203 | '''
204 | condition on both the current observation and the last observation in the plan
205 | '''
206 | return {
207 | 0: observations[0],
208 | self.horizon - 1: observations[-1],
209 | }
210 |
211 | class ValueDataset(SequenceDataset):
212 | '''
213 | adds a value field to the datapoints for training the value function
214 | '''
215 |
216 | def __init__(self, *args, discount=0.99, **kwargs):
217 | super().__init__(*args, **kwargs)
218 | self.discount = discount
219 | self.discounts = self.discount ** np.arange(self.max_path_length)[:,None]
220 |
221 | def __getitem__(self, idx):
222 | batch = super().__getitem__(idx)
223 | path_ind, start, end = self.indices[idx]
224 | rewards = self.fields['rewards'][path_ind, start:]
225 | discounts = self.discounts[:len(rewards)]
226 | value = (discounts * rewards).sum()
227 | value = np.array([value], dtype=np.float32)
228 | value_batch = ValueBatch(*batch, value)
229 | return value_batch
230 |
--------------------------------------------------------------------------------
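To make the window indexing in `SequenceDataset.make_indices` concrete, here is a standalone re-implementation of the same loop on toy path lengths (a sketch mirroring the method above, not an import of it):

```python
import numpy as np

def make_indices(path_lengths, horizon, max_path_length, use_padding=True):
    # one (episode, start, end) triple per sampleable window of length `horizon`
    indices = []
    for i, path_length in enumerate(path_lengths):
        max_start = min(path_length - 1, max_path_length - horizon)
        if not use_padding:
            max_start = min(max_start, path_length - horizon)
        for start in range(max_start):
            indices.append((i, start, start + horizon))
    return np.array(indices)

# two episodes of lengths 10 and 6, horizon 4, max_path_length 10:
# episode 0 yields starts 0..5, episode 1 yields starts 0..4
print(make_indices([10, 6], horizon=4, max_path_length=10))
```
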
/skilldiffuser/hrl/diffuser/environments/__init__.py:
--------------------------------------------------------------------------------
1 | from .registration import register_environments
2 |
3 | registered_environments = register_environments()
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/ant.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 |
6 | '''
7 | qpos : 15
8 | qvel : 14
9 | 0-2: root x, y, z
10 | 3-7: root quat
11 | 7 : front L hip
12 | 8 : front L ankle
13 | 9 : front R hip
14 | 10 : front R ankle
15 | 11 : back L hip
16 | 12 : back L ankle
17 | 13 : back R hip
18 | 14 : back R ankle
19 |
20 | '''
21 |
22 | class AntFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle):
23 | def __init__(self):
24 | asset_path = os.path.join(
25 | os.path.dirname(__file__), 'assets/ant.xml')
26 | mujoco_env.MujocoEnv.__init__(self, asset_path, 5)
27 | utils.EzPickle.__init__(self)
28 |
29 | def step(self, a):
30 | xposbefore = self.get_body_com("torso")[0]
31 | self.do_simulation(a, self.frame_skip)
32 | xposafter = self.get_body_com("torso")[0]
33 | forward_reward = (xposafter - xposbefore) / self.dt
34 | ctrl_cost = 0.5 * np.square(a).sum()
35 | contact_cost = (
36 | 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
37 | )
38 | survive_reward = 1.0
39 | reward = forward_reward - ctrl_cost - contact_cost + survive_reward
40 | state = self.state_vector()
41 | notdone = np.isfinite(state).all() and state[2] >= 0.2 and state[2] <= 1.0
42 | done = not notdone
43 | ob = self._get_obs()
44 | return (
45 | ob,
46 | reward,
47 | done,
48 | dict(
49 | reward_forward=forward_reward,
50 | reward_ctrl=-ctrl_cost,
51 | reward_contact=-contact_cost,
52 | reward_survive=survive_reward,
53 | ),
54 | )
55 |
56 | def _get_obs(self):
57 | return np.concatenate(
58 | [
59 | self.sim.data.qpos.flat[2:],
60 | self.sim.data.qvel.flat,
61 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
62 | ]
63 | )
64 |
65 | def reset_model(self):
66 | qpos = self.init_qpos + self.np_random.uniform(
67 | size=self.model.nq, low=-0.1, high=0.1
68 | )
69 | qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
70 | self.set_state(qpos, qvel)
71 | return self._get_obs()
72 |
73 | def viewer_setup(self):
74 | self.viewer.cam.distance = self.model.stat.extent * 0.5
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/assets/ant.xml:
--------------------------------------------------------------------------------
[ MuJoCo XML model definition for the fully-observed ant environment; markup not preserved in this text dump ]
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/assets/half_cheetah.xml:
--------------------------------------------------------------------------------
[ MuJoCo XML model definition for the fully-observed half-cheetah environment; markup not preserved in this text dump ]
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/assets/hopper.xml:
--------------------------------------------------------------------------------
[ MuJoCo XML model definition for the fully-observed hopper environment; markup not preserved in this text dump ]
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/assets/walker2d.xml:
--------------------------------------------------------------------------------
[ MuJoCo XML model definition for the fully-observed walker2d environment; markup not preserved in this text dump ]
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/half_cheetah.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 |
6 | class HalfCheetahFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle):
7 | def __init__(self):
8 | asset_path = os.path.join(
9 | os.path.dirname(__file__), 'assets/half_cheetah.xml')
10 | mujoco_env.MujocoEnv.__init__(self, asset_path, 5)
11 | utils.EzPickle.__init__(self)
12 |
13 | def step(self, action):
14 | xposbefore = self.sim.data.qpos[0]
15 | self.do_simulation(action, self.frame_skip)
16 | xposafter = self.sim.data.qpos[0]
17 | ob = self._get_obs()
18 | reward_ctrl = - 0.1 * np.square(action).sum()
19 | reward_run = (xposafter - xposbefore)/self.dt
20 | reward = reward_ctrl + reward_run
21 | done = False
22 | return ob, reward, done, dict(reward_run=reward_run, reward_ctrl=reward_ctrl)
23 |
24 | def _get_obs(self):
25 | return np.concatenate([
26 | # self.sim.data.qpos.flat[1:],
27 | self.sim.data.qpos.flat, #[1:],
28 | self.sim.data.qvel.flat,
29 | ])
30 |
31 | def reset_model(self):
32 | qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
33 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
34 | self.set_state(qpos, qvel)
35 | return self._get_obs()
36 |
37 | def viewer_setup(self):
38 | self.viewer.cam.distance = self.model.stat.extent * 0.5
39 |
40 | def set(self, state):
41 | qpos_dim = self.sim.data.qpos.size
42 | qpos = state[:qpos_dim]
43 | qvel = state[qpos_dim:]
44 | self.set_state(qpos, qvel)
45 | return self._get_obs()
46 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/hopper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 |
6 | class HopperFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle):
7 | def __init__(self):
8 | asset_path = os.path.join(
9 | os.path.dirname(__file__), 'assets/hopper.xml')
10 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
11 | utils.EzPickle.__init__(self)
12 |
13 | def step(self, a):
14 | posbefore = self.sim.data.qpos[0]
15 | self.do_simulation(a, self.frame_skip)
16 | posafter, height, ang = self.sim.data.qpos[0:3]
17 | alive_bonus = 1.0
18 | reward = (posafter - posbefore) / self.dt
19 | reward += alive_bonus
20 | reward -= 1e-3 * np.square(a).sum()
21 | s = self.state_vector()
22 | done = not (np.isfinite(s).all() and (np.abs(s[2:]) < 100).all() and
23 | (height > .7) and (abs(ang) < .2))
24 | ob = self._get_obs()
25 | return ob, reward, done, {}
26 |
27 | def _get_obs(self):
28 | return np.concatenate([
29 | # self.sim.data.qpos.flat[1:],
30 | self.sim.data.qpos.flat,
31 | np.clip(self.sim.data.qvel.flat, -10, 10)
32 | ])
33 |
34 | def reset_model(self):
35 | qpos = self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq)
36 | qvel = self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
37 | self.set_state(qpos, qvel)
38 | return self._get_obs()
39 |
40 | def viewer_setup(self):
41 | self.viewer.cam.trackbodyid = 2
42 | self.viewer.cam.distance = self.model.stat.extent * 0.75
43 | self.viewer.cam.lookat[2] = 1.15
44 | self.viewer.cam.elevation = -20
45 |
46 | def set(self, state):
47 | qpos_dim = self.sim.data.qpos.size
48 | qpos = state[:qpos_dim]
49 | qvel = state[qpos_dim:]
50 | self.set_state(qpos, qvel)
51 | return self._get_obs()
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/environments/registration.py:
--------------------------------------------------------------------------------
1 | import gym
2 |
3 | ENVIRONMENT_SPECS = (
4 | {
5 | 'id': 'HopperFullObs-v2',
6 | 'entry_point': ('diffuser.environments.hopper:HopperFullObsEnv'),
7 | },
8 | {
9 | 'id': 'HalfCheetahFullObs-v2',
10 | 'entry_point': ('diffuser.environments.half_cheetah:HalfCheetahFullObsEnv'),
11 | },
12 | {
13 | 'id': 'Walker2dFullObs-v2',
14 | 'entry_point': ('diffuser.environments.walker2d:Walker2dFullObsEnv'),
15 | },
16 | {
17 | 'id': 'AntFullObs-v2',
18 | 'entry_point': ('diffuser.environments.ant:AntFullObsEnv'),
19 | },
20 | )
21 |
22 | def register_environments():
23 | try:
24 | for environment in ENVIRONMENT_SPECS:
25 | gym.register(**environment)
26 |
27 | gym_ids = tuple(
28 | environment_spec['id']
29 | for environment_spec in ENVIRONMENT_SPECS)
30 |
31 | return gym_ids
32 | except:
33 | print('[ diffuser/environments/registration ] WARNING: not registering diffuser environments')
34 | return tuple()
--------------------------------------------------------------------------------
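Assuming `mujoco_py` and the old `gym` API are installed, the fully-observed variants registered above can be created like any other gym environment (a usage sketch, not code from the repo):

```python
import gym
import diffuser.environments  # noqa: F401 -- registers the *FullObs-v2 ids on import

env = gym.make('HopperFullObs-v2')
obs = env.reset()
print(obs.shape)  # full qpos (including root x) plus clipped qvel
```
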
/skilldiffuser/hrl/diffuser/environments/walker2d.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from gym import utils
4 | from gym.envs.mujoco import mujoco_env
5 |
6 | class Walker2dFullObsEnv(mujoco_env.MujocoEnv, utils.EzPickle):
7 |
8 | def __init__(self):
9 | asset_path = os.path.join(
10 | os.path.dirname(__file__), 'assets/walker2d.xml')
11 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
12 | utils.EzPickle.__init__(self)
13 |
14 | def step(self, a):
15 | posbefore = self.sim.data.qpos[0]
16 | self.do_simulation(a, self.frame_skip)
17 | posafter, height, ang = self.sim.data.qpos[0:3]
18 | alive_bonus = 1.0
19 | reward = ((posafter - posbefore) / self.dt)
20 | reward += alive_bonus
21 | reward -= 1e-3 * np.square(a).sum()
22 | done = not (height > 0.8 and height < 2.0 and
23 | ang > -1.0 and ang < 1.0)
24 | ob = self._get_obs()
25 | return ob, reward, done, {}
26 |
27 | def _get_obs(self):
28 | qpos = self.sim.data.qpos
29 | qvel = self.sim.data.qvel
30 | return np.concatenate([qpos, qvel]).ravel()
31 |
32 | def reset_model(self):
33 | self.set_state(
34 | self.init_qpos + self.np_random.uniform(low=-.005, high=.005, size=self.model.nq),
35 | self.init_qvel + self.np_random.uniform(low=-.005, high=.005, size=self.model.nv)
36 | )
37 | return self._get_obs()
38 |
39 | def viewer_setup(self):
40 | self.viewer.cam.trackbodyid = 2
41 | self.viewer.cam.distance = self.model.stat.extent * 0.5
42 | self.viewer.cam.lookat[2] = 1.15
43 | self.viewer.cam.elevation = -20
44 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .temporal import TemporalUnet, TemporalValue, MLPnet
2 | from .diffusion import GaussianDiffusion, ActionGaussianDiffusion, GaussianInvDynDiffusion
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/models/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import einops
7 | from einops.layers.torch import Rearrange
8 | import pdb
9 |
10 | import diffuser.utils as utils
11 |
12 | #-----------------------------------------------------------------------------#
13 | #---------------------------------- modules ----------------------------------#
14 | #-----------------------------------------------------------------------------#
15 |
16 | class SinusoidalPosEmb(nn.Module):
17 | def __init__(self, dim):
18 | super().__init__()
19 | self.dim = dim
20 |
21 | def forward(self, x):
22 | device = x.device
23 | half_dim = self.dim // 2
24 | emb = math.log(10000) / (half_dim - 1)
25 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
26 | emb = x[:, None] * emb[None, :]
27 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
28 | return emb
29 |
30 | class Downsample1d(nn.Module):
31 | def __init__(self, dim):
32 | super().__init__()
33 | self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
34 |
35 | def forward(self, x):
36 | return self.conv(x)
37 |
38 | class Upsample1d(nn.Module):
39 | def __init__(self, dim):
40 | super().__init__()
41 | self.conv = nn.ConvTranspose1d(dim, dim, 4, 2, 1)
42 |
43 | def forward(self, x):
44 | return self.conv(x)
45 |
46 | class Conv1dBlock(nn.Module):
47 | '''
48 | Conv1d --> GroupNorm --> Mish
49 | '''
50 |
51 | def __init__(self, inp_channels, out_channels, kernel_size, mish=True, n_groups=8):
52 | super().__init__()
53 |
54 | if mish:
55 | act_fn = nn.Mish()
56 | else:
57 | act_fn = nn.SiLU()
58 |
59 | self.block = nn.Sequential(
60 | nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2),
61 | Rearrange('batch channels horizon -> batch channels 1 horizon'),
62 | nn.GroupNorm(n_groups, out_channels),
63 | Rearrange('batch channels 1 horizon -> batch channels horizon'),
64 | act_fn,
65 | )
66 |
67 | def forward(self, x):
68 | return self.block(x)
69 |
70 |
71 | #-----------------------------------------------------------------------------#
72 | #---------------------------------- sampling ---------------------------------#
73 | #-----------------------------------------------------------------------------#
74 |
75 | def extract(a, t, x_shape):
76 | b, *_ = t.shape
77 | out = a.gather(-1, t)
78 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
79 |
80 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
81 | """
82 | cosine schedule
83 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
84 | """
85 | steps = timesteps + 1
86 | x = np.linspace(0, steps, steps)
87 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
88 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
89 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
90 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
91 | return torch.tensor(betas_clipped, dtype=dtype)
92 |
93 | def apply_conditioning(x, conditions, action_dim):
94 | for t, val in conditions.items():
95 | x[:, t, action_dim:] = val.clone()
96 | return x
97 |
98 | #-----------------------------------------------------------------------------#
99 | #---------------------------------- losses -----------------------------------#
100 | #-----------------------------------------------------------------------------#
101 |
102 | class WeightedLoss(nn.Module):
103 |
104 | def __init__(self, weights, action_dim):
105 | super().__init__()
106 | self.register_buffer('weights', weights)
107 | self.action_dim = action_dim
108 |
109 | def forward(self, pred, targ):
110 | '''
111 | pred, targ : tensor
112 | [ batch_size x horizon x transition_dim ]
113 | '''
114 | loss = self._loss(pred, targ)
115 | weighted_loss = (loss * self.weights).mean()
116 | a0_loss = (loss[:, 0, :self.action_dim] / self.weights[0, :self.action_dim]).mean()
117 | return weighted_loss, {'a0_loss': a0_loss}
118 |
119 | class WeightedStateLoss(nn.Module):
120 |
121 | def __init__(self, weights):
122 | super().__init__()
123 | self.register_buffer('weights', weights)
124 |
125 | def forward(self, pred, targ):
126 | '''
127 | pred, targ : tensor
128 | [ batch_size x horizon x transition_dim ]
129 | '''
130 | loss = self._loss(pred, targ)
131 | weighted_loss = (loss * self.weights).mean()
132 | return weighted_loss, {'a0_loss': weighted_loss}
133 |
134 | class ValueLoss(nn.Module):
135 | def __init__(self, *args):
136 | super().__init__()
137 | pass
138 |
139 | def forward(self, pred, targ):
140 | loss = self._loss(pred, targ).mean()
141 |
142 | if len(pred) > 1:
143 | corr = np.corrcoef(
144 | utils.to_np(pred).squeeze(),
145 | utils.to_np(targ).squeeze()
146 | )[0,1]
147 | else:
148 |             corr = np.nan
149 |
150 | info = {
151 | 'mean_pred': pred.mean(), 'mean_targ': targ.mean(),
152 | 'min_pred': pred.min(), 'min_targ': targ.min(),
153 | 'max_pred': pred.max(), 'max_targ': targ.max(),
154 | 'corr': corr,
155 | }
156 |
157 | return loss, info
158 |
159 | class WeightedL1(WeightedLoss):
160 |
161 | def _loss(self, pred, targ):
162 | return torch.abs(pred - targ)
163 |
164 | class WeightedL2(WeightedLoss):
165 |
166 | def _loss(self, pred, targ):
167 | return F.mse_loss(pred, targ, reduction='none')
168 |
169 | class WeightedStateL2(WeightedStateLoss):
170 |
171 | def _loss(self, pred, targ):
172 | return F.mse_loss(pred, targ, reduction='none')
173 |
174 | class ValueL1(ValueLoss):
175 |
176 | def _loss(self, pred, targ):
177 | return torch.abs(pred - targ)
178 |
179 | class ValueL2(ValueLoss):
180 |
181 | def _loss(self, pred, targ):
182 | return F.mse_loss(pred, targ, reduction='none')
183 |
184 | Losses = {
185 | 'l1': WeightedL1,
186 | 'l2': WeightedL2,
187 | 'state_l2': WeightedStateL2,
188 | 'value_l1': ValueL1,
189 | 'value_l2': ValueL2,
190 | }
191 |
--------------------------------------------------------------------------------
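A short sketch of the schedule and broadcasting helpers above, assuming the `diffuser` package and its heavier dependencies import cleanly; the tensor shapes are illustrative:

```python
import torch
from diffuser.models.helpers import cosine_beta_schedule, extract, SinusoidalPosEmb

T = 20
betas = cosine_beta_schedule(T)                    # [T] noise schedule
alphas_cumprod = torch.cumprod(1. - betas, dim=0)  # cumulative signal retention

# gather one schedule value per batch element and broadcast over the trajectory
t = torch.randint(0, T, (8,))                      # a diffusion timestep per sample
coef = extract(alphas_cumprod, t, x_shape=(8, 32, 14))
print(coef.shape)                                  # torch.Size([8, 1, 1])

# sinusoidal timestep embedding used by the temporal U-Net
emb = SinusoidalPosEmb(dim=64)(t.float())
print(emb.shape)                                   # torch.Size([8, 64])
```
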
/skilldiffuser/hrl/diffuser/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .serialization import *
2 | from .training import *
3 | from .progress import *
4 | # from .setup import *
5 | from .config import *
6 | from .rendering import *
7 | from .arrays import *
8 | from .colab import *
9 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/arrays.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import numpy as np
3 | import torch
4 | import pdb
5 |
6 | DTYPE = torch.float
7 | DEVICE = 'cuda'
8 |
9 | #-----------------------------------------------------------------------------#
10 | #------------------------------ numpy <--> torch -----------------------------#
11 | #-----------------------------------------------------------------------------#
12 |
13 | def to_np(x):
14 | if torch.is_tensor(x):
15 | x = x.detach().cpu().numpy()
16 | return x
17 |
18 | def to_torch(x, dtype=None, device=None):
19 | dtype = dtype or DTYPE
20 | device = device or DEVICE
21 | if type(x) is dict:
22 | return {k: to_torch(v, dtype, device) for k, v in x.items()}
23 | elif torch.is_tensor(x):
24 | return x.to(device).type(dtype)
25 | # import pdb; pdb.set_trace()
26 | return torch.tensor(x, dtype=dtype, device=device)
27 |
28 | def to_device(x, device=DEVICE):
29 | if torch.is_tensor(x):
30 | return x.to(device)
31 | elif type(x) is dict:
32 | return {k: to_device(v, device) for k, v in x.items()}
33 | else:
34 | print(f'Unrecognized type in `to_device`: {type(x)}')
35 | pdb.set_trace()
36 | # return [x.to(device) for x in xs]
37 |
38 | # def atleast_2d(x, axis=0):
39 | # '''
40 | # works for both np arrays and torch tensors
41 | # '''
42 | # while len(x.shape) < 2:
43 | # shape = (1, *x.shape) if axis == 0 else (*x.shape, 1)
44 | # x = x.reshape(*shape)
45 | # return x
46 |
47 | # def to_2d(x):
48 | # dim = x.shape[-1]
49 | # return x.reshape(-1, dim)
50 |
51 | def batchify(batch, device):
52 | '''
53 | convert a single dataset item to a batch suitable for passing to a model by
54 | 1) converting np arrays to torch tensors and
55 |       2) ensuring that everything has a batch dimension
56 | '''
57 | fn = lambda x: to_torch(x[None], device=device)
58 |
59 | batched_vals = []
60 | for field in batch._fields:
61 | val = getattr(batch, field)
62 | val = apply_dict(fn, val) if type(val) is dict else fn(val)
63 | batched_vals.append(val)
64 | return type(batch)(*batched_vals)
65 |
66 | def apply_dict(fn, d, *args, **kwargs):
67 | return {
68 | k: fn(v, *args, **kwargs)
69 | for k, v in d.items()
70 | }
71 |
72 | def normalize(x):
73 | """
74 | scales `x` to [0, 1]
75 | """
76 | x = x - x.min()
77 | x = x / x.max()
78 | return x
79 |
80 | def to_img(x):
81 | normalized = normalize(x)
82 | array = to_np(normalized)
83 | array = np.transpose(array, (1,2,0))
84 | return (array * 255).astype(np.uint8)
85 |
86 | def set_device(device):
87 |     global DEVICE; DEVICE = device  # update the module-level default used by to_torch / to_device
88 | if 'cuda' in device:
89 | torch.set_default_tensor_type(torch.cuda.FloatTensor)
90 |
91 | def batch_to_device(batch, device='cuda:0'):
92 | vals = [
93 | to_device(getattr(batch, field), device)
94 | for field in batch._fields
95 | ]
96 | return type(batch)(*vals)
97 |
98 | def _to_str(num):
99 | if num >= 1e6:
100 | return f'{(num/1e6):.2f} M'
101 | else:
102 | return f'{(num/1e3):.2f} k'
103 |
104 | #-----------------------------------------------------------------------------#
105 | #----------------------------- parameter counting ----------------------------#
106 | #-----------------------------------------------------------------------------#
107 |
108 | def param_to_module(param):
109 | module_name = param[::-1].split('.', maxsplit=1)[-1][::-1]
110 | return module_name
111 |
112 | def report_parameters(model, topk=10):
113 | counts = {k: p.numel() for k, p in model.named_parameters()}
114 | n_parameters = sum(counts.values())
115 | print(f'[ utils/arrays ] Total parameters: {_to_str(n_parameters)}')
116 |
117 | modules = dict(model.named_modules())
118 | sorted_keys = sorted(counts, key=lambda x: -counts[x])
119 | max_length = max([len(k) for k in sorted_keys])
120 | for i in range(topk):
121 | key = sorted_keys[i]
122 | count = counts[key]
123 | module = param_to_module(key)
124 | print(' '*8, f'{key:10}: {_to_str(count)} | {modules[module]}')
125 |
126 | remaining_parameters = sum([counts[k] for k in sorted_keys[topk:]])
127 | print(' '*8, f'... and {len(counts)-topk} others accounting for {_to_str(remaining_parameters)} parameters')
128 | return n_parameters
129 |
--------------------------------------------------------------------------------
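A minimal sketch of `batchify` / `to_np` on a namedtuple item like the datasets above produce; the shapes are illustrative, and importing `diffuser.utils` pulls in its rendering/logging dependencies:

```python
from collections import namedtuple
import numpy as np
from diffuser.utils.arrays import batchify, to_np

Batch = namedtuple('Batch', 'trajectories conditions')

# a single dataset item: a trajectory array plus a {timestep: observation} dict
item = Batch(
    trajectories=np.zeros((32, 14), dtype=np.float32),
    conditions={0: np.zeros(11, dtype=np.float32)},
)

# batchify adds a leading batch dimension and converts everything to torch tensors
batch = batchify(item, device='cpu')
print(batch.trajectories.shape)         # torch.Size([1, 32, 14])
print(batch.conditions[0].shape)        # torch.Size([1, 11])
print(to_np(batch.trajectories).shape)  # back to numpy: (1, 32, 14)
```
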
/skilldiffuser/hrl/diffuser/utils/cloud.py:
--------------------------------------------------------------------------------
1 | import shlex
2 | import subprocess
3 | import pdb
4 |
5 | def sync_logs(logdir, bucket, background=False):
6 | ## remove prefix 'logs' on google cloud
7 | destination = 'logs' + logdir.split('logs')[-1]
8 | upload_blob(logdir, destination, bucket, background)
9 |
10 | def upload_blob(source, destination, bucket, background):
11 | command = f'gsutil -m -o GSUtil:parallel_composite_upload_threshold=150M rsync -r {source} {bucket}/{destination}'
12 | print(f'[ utils/cloud ] Syncing bucket: {command}')
13 | command = shlex.split(command)
14 |
15 | if background:
16 | subprocess.Popen(command)
17 | else:
18 | subprocess.call(command)
19 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/colab.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import einops
4 | import matplotlib.pyplot as plt
5 | from tqdm import tqdm
6 |
7 | try:
8 | import io
9 | import base64
10 | from IPython.display import HTML
11 | from IPython import display as ipythondisplay
12 | except:
13 | print('[ utils/colab ] Warning: not importing colab dependencies')
14 |
15 | from .serialization import mkdir
16 | from .arrays import to_torch, to_np
17 | from .video import save_video
18 |
19 |
20 | def run_diffusion(model, dataset, obs, n_samples=1, device='cuda:0', **diffusion_kwargs):
21 | ## normalize observation for model
22 | obs = dataset.normalizer.normalize(obs, 'observations')
23 |
24 | ## add a batch dimension and repeat for multiple samples
25 | ## [ observation_dim ] --> [ n_samples x observation_dim ]
26 | obs = obs[None].repeat(n_samples, axis=0)
27 |
28 | ## format `conditions` input for model
29 | conditions = {
30 | 0: to_torch(obs, device=device)
31 | }
32 |
33 | samples, diffusion = model.conditional_sample(conditions,
34 | return_diffusion=True, verbose=False, **diffusion_kwargs)
35 |
36 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x (action_dim + observation_dim)]
37 | diffusion = to_np(diffusion)
38 |
39 | ## extract observations
40 | ## [ n_samples x (n_diffusion_steps + 1) x horizon x observation_dim ]
41 | normed_observations = diffusion[:, :, :, dataset.action_dim:]
42 |
43 | ## unnormalize observation samples from model
44 | observations = dataset.normalizer.unnormalize(normed_observations, 'observations')
45 |
46 | ## [ (n_diffusion_steps + 1) x n_samples x horizon x observation_dim ]
47 | observations = einops.rearrange(observations,
48 | 'batch steps horizon dim -> steps batch horizon dim')
49 |
50 | return observations
51 |
52 |
53 | def show_diffusion(renderer, observations, n_repeat=100, substep=1, filename='diffusion.mp4', savebase='/content/videos'):
54 | '''
55 | observations : [ n_diffusion_steps x batch_size x horizon x observation_dim ]
56 | '''
57 | mkdir(savebase)
58 | savepath = os.path.join(savebase, filename)
59 |
60 | subsampled = observations[::substep]
61 |
62 | images = []
63 | for t in tqdm(range(len(subsampled))):
64 | observation = subsampled[t]
65 |
66 | img = renderer.composite(None, observation)
67 | images.append(img)
68 | images = np.stack(images, axis=0)
69 |
70 | ## pause at the end of video
71 | images = np.concatenate([
72 | images,
73 | images[-1:].repeat(n_repeat, axis=0)
74 | ], axis=0)
75 |
76 | save_video(savepath, images)
77 | show_video(savepath)
78 |
79 |
80 | def show_sample(renderer, observations, filename='sample.mp4', savebase='/content/videos'):
81 | '''
82 | observations : [ batch_size x horizon x observation_dim ]
83 | '''
84 |
85 | mkdir(savebase)
86 | savepath = os.path.join(savebase, filename)
87 |
88 | images = []
89 | for rollout in observations:
90 | ## [ horizon x height x width x channels ]
91 | img = renderer._renders(rollout, partial=True)
92 | images.append(img)
93 |
94 | ## [ horizon x height x (batch_size * width) x channels ]
95 | images = np.concatenate(images, axis=2)
96 |
97 | save_video(savepath, images)
98 | show_video(savepath, height=200)
99 |
100 |
101 | def show_samples(renderer, observations_l, figsize=12):
102 | '''
103 | observations_l : [ [ n_diffusion_steps x batch_size x horizon x observation_dim ], ... ]
104 | '''
105 |
106 | images = []
107 | for observations in observations_l:
108 | path = observations[-1]
109 | img = renderer.composite(None, path)
110 | images.append(img)
111 | images = np.concatenate(images, axis=0)
112 |
113 | plt.imshow(images)
114 | plt.axis('off')
115 | plt.gcf().set_size_inches(figsize, figsize)
116 |
117 |
118 | def show_video(path, height=400):
119 |     video = io.open(path, 'r+b').read()
120 |     encoded = base64.b64encode(video)
121 |     ipythondisplay.display(HTML(data='''<video alt="video" autoplay
122 |                 loop controls style="height: {0}px;">
123 |                 <source src="data:video/mp4;base64,{1}" type="video/mp4" />
124 |              </video>'''.format(height, encoded.decode('ascii'))))
125 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections.abc
3 | import importlib
4 | import pickle
5 | from ml_logger import logger
6 |
7 | def import_class(_class):
8 | if type(_class) is not str: return _class
9 | ## 'diffusion' on standard installs
10 | repo_name = __name__.split('.')[0]
11 | ## eg, 'utils'
12 | module_name = '.'.join(_class.split('.')[:-1])
13 | ## eg, 'Renderer'
14 | class_name = _class.split('.')[-1]
15 | ## eg, 'diffusion.utils'
16 | module = importlib.import_module(f'{repo_name}.{module_name}')
17 | ## eg, diffusion.utils.Renderer
18 | _class = getattr(module, class_name)
19 | print(f'[ utils/config ] Imported {repo_name}.{module_name}:{class_name}')
20 | return _class
21 |
22 | class Config(collections.abc.Mapping):
23 |
24 | def __init__(self, _class, verbose=True, savepath=None, device=None, **kwargs):
25 | self._class = import_class(_class)
26 | self._device = device
27 | self._dict = {}
28 |
29 | for key, val in kwargs.items():
30 | self._dict[key] = val
31 |
32 | if verbose:
33 | print(self)
34 |
35 | if savepath is not None:
36 | logger.save_pkl(self, savepath)
37 | print(f'[ utils/config ] Saved config to: {savepath}\n')
38 |
39 | def __repr__(self):
40 | string = f'\n[utils/config ] Config: {self._class}\n'
41 | for key in sorted(self._dict.keys()):
42 | val = self._dict[key]
43 | string += f' {key}: {val}\n'
44 | return string
45 |
46 | def __iter__(self):
47 | return iter(self._dict)
48 |
49 | def __getitem__(self, item):
50 | return self._dict[item]
51 |
52 | def __len__(self):
53 | return len(self._dict)
54 |
55 | def __getattr__(self, attr):
56 | if attr == '_dict' and '_dict' not in vars(self):
57 | self._dict = {}
58 | return self._dict
59 | try:
60 | return self._dict[attr]
61 | except KeyError:
62 | raise AttributeError(attr)
63 |
64 | def __call__(self, *args, **kwargs):
65 | instance = self._class(*args, **kwargs, **self._dict)
66 | if self._device:
67 | instance = instance.to(self._device)
68 | return instance
69 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/git_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import git
3 | import pdb
4 |
5 | PROJECT_PATH = os.path.dirname(
6 | os.path.realpath(os.path.join(__file__, '..', '..')))
7 |
8 | def get_repo(path=PROJECT_PATH, search_parent_directories=True):
9 | repo = git.Repo(
10 | path, search_parent_directories=search_parent_directories)
11 | return repo
12 |
13 | def get_git_rev(*args, **kwargs):
14 | try:
15 | repo = get_repo(*args, **kwargs)
16 | if repo.head.is_detached:
17 | git_rev = repo.head.object.name_rev
18 | else:
19 | git_rev = repo.active_branch.commit.name_rev
20 | except:
21 | git_rev = None
22 |
23 | return git_rev
24 |
25 | def git_diff(*args, **kwargs):
26 | repo = get_repo(*args, **kwargs)
27 | diff = repo.git.diff()
28 | return diff
29 |
30 | def save_git_diff(savepath, *args, **kwargs):
31 | diff = git_diff(*args, **kwargs)
32 | with open(savepath, 'w') as f:
33 | f.write(diff)
34 |
35 | if __name__ == '__main__':
36 |
37 | git_rev = get_git_rev()
38 | print(git_rev)
39 |
40 | save_git_diff('diff_test.txt')
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/iql.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import jax
4 | import jax.numpy as jnp
5 | import functools
6 | import pdb
7 |
8 | from diffuser.iql.common import Model
9 | from diffuser.iql.value_net import DoubleCritic
10 |
11 | def load_q(env, loadpath, hidden_dims=(256, 256), seed=42):
12 | print(f'[ utils/iql ] Loading Q: {loadpath}')
13 | observations = env.observation_space.sample()[np.newaxis]
14 | actions = env.action_space.sample()[np.newaxis]
15 |
16 | rng = jax.random.PRNGKey(seed)
17 | rng, key = jax.random.split(rng)
18 |
19 | critic_def = DoubleCritic(hidden_dims)
20 | critic = Model.create(critic_def,
21 | inputs=[key, observations, actions])
22 |
23 | ## allows for relative paths
24 | loadpath = os.path.expanduser(loadpath)
25 | critic = critic.load(loadpath)
26 | return critic
27 |
28 | class JaxWrapper:
29 |
30 | def __init__(self, env, loadpath, *args, **kwargs):
31 | self.model = load_q(env, loadpath)
32 |
33 | @functools.partial(jax.jit, static_argnames=('self'), device=jax.devices('cpu')[0])
34 | def forward(self, xs):
35 | Qs = self.model(*xs)
36 | Q = jnp.minimum(*Qs)
37 | return Q
38 |
39 | def __call__(self, *xs):
40 | Q = self.forward(xs)
41 | return np.array(Q)
42 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/progress.py:
--------------------------------------------------------------------------------
1 | import time
2 | import math
3 | import pdb
4 |
5 | class Progress:
6 |
7 | def __init__(self, total, name = 'Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100):
8 | self.total = total
9 | self.name = name
10 | self.ncol = ncol
11 | self.max_length = max_length
12 | self.indent = indent
13 | self.line_width = line_width
14 | self._speed_update_freq = speed_update_freq
15 |
16 | self._step = 0
17 | self._prev_line = '\033[F'
18 | self._clear_line = ' ' * self.line_width
19 |
20 | self._pbar_size = self.ncol * self.max_length
21 | self._complete_pbar = '#' * self._pbar_size
22 | self._incomplete_pbar = ' ' * self._pbar_size
23 |
24 | self.lines = ['']
25 | self.fraction = '{} / {}'.format(0, self.total)
26 |
27 | self.resume()
28 |
29 |
30 | def update(self, description, n=1):
31 | self._step += n
32 | if self._step % self._speed_update_freq == 0:
33 | self._time0 = time.time()
34 | self._step0 = self._step
35 | self.set_description(description)
36 |
37 | def resume(self):
38 | self._skip_lines = 1
39 | print('\n', end='')
40 | self._time0 = time.time()
41 | self._step0 = self._step
42 |
43 | def pause(self):
44 | self._clear()
45 | self._skip_lines = 1
46 |
47 | def set_description(self, params=[]):
48 |
49 | if type(params) == dict:
50 | params = sorted([
51 | (key, val)
52 | for key, val in params.items()
53 | ])
54 |
55 | ############
56 | # Position #
57 | ############
58 | self._clear()
59 |
60 | ###########
61 | # Percent #
62 | ###########
63 | percent, fraction = self._format_percent(self._step, self.total)
64 | self.fraction = fraction
65 |
66 | #########
67 | # Speed #
68 | #########
69 | speed = self._format_speed(self._step)
70 |
71 | ##########
72 | # Params #
73 | ##########
74 | num_params = len(params)
75 | nrow = math.ceil(num_params / self.ncol)
76 | params_split = self._chunk(params, self.ncol)
77 | params_string, lines = self._format(params_split)
78 | self.lines = lines
79 |
80 |
81 | description = '{} | {}{}'.format(percent, speed, params_string)
82 | print(description)
83 | self._skip_lines = nrow + 1
84 |
85 | def append_description(self, descr):
86 | self.lines.append(descr)
87 |
88 | def _clear(self):
89 | position = self._prev_line * self._skip_lines
90 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)])
91 | print(position, end='')
92 | print(empty)
93 | print(position, end='')
94 |
95 | def _format_percent(self, n, total):
96 | if total:
97 | percent = n / float(total)
98 |
99 | complete_entries = int(percent * self._pbar_size)
100 | incomplete_entries = self._pbar_size - complete_entries
101 |
102 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries]
103 | fraction = '{} / {}'.format(n, total)
104 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100))
105 | else:
106 | fraction = '{}'.format(n)
107 | string = '{} iterations'.format(n)
108 | return string, fraction
109 |
110 | def _format_speed(self, n):
111 | num_steps = n - self._step0
112 | t = time.time() - self._time0
113 | speed = num_steps / t
114 | string = '{:.1f} Hz'.format(speed)
115 | if num_steps > 0:
116 | self._speed = string
117 | return string
118 |
119 | def _chunk(self, l, n):
120 | return [l[i:i+n] for i in range(0, len(l), n)]
121 |
122 | def _format(self, chunks):
123 | lines = [self._format_chunk(chunk) for chunk in chunks]
124 | lines.insert(0,'')
125 | padding = '\n' + ' '*self.indent
126 | string = padding.join(lines)
127 | return string, lines
128 |
129 | def _format_chunk(self, chunk):
130 | line = ' | '.join([self._format_param(param) for param in chunk])
131 | return line
132 |
133 | def _format_param(self, param):
134 | k, v = param
135 | return '{} : {}'.format(k, v)[:self.max_length]
136 |
137 | def stamp(self):
138 | if self.lines != ['']:
139 | params = ' | '.join(self.lines)
140 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed)
141 | self._clear()
142 | print(string, end='\n')
143 | self._skip_lines = 1
144 | else:
145 | self._clear()
146 | self._skip_lines = 0
147 |
148 | def close(self):
149 | self.pause()
150 |
151 | class Silent:
152 |
153 | def __init__(self, *args, **kwargs):
154 | pass
155 |
156 | def __getattr__(self, attr):
157 | return lambda *args: None
158 |
159 |
160 | if __name__ == '__main__':
161 | silent = Silent()
162 | silent.update()
163 | silent.stamp()
164 |
165 | num_steps = 1000
166 | progress = Progress(num_steps)
167 | for i in range(num_steps):
168 | progress.update()
169 | params = [
170 | ['A', '{:06d}'.format(i)],
171 | ['B', '{:06d}'.format(i)],
172 | ['C', '{:06d}'.format(i)],
173 | ['D', '{:06d}'.format(i)],
174 | ['E', '{:06d}'.format(i)],
175 | ['F', '{:06d}'.format(i)],
176 | ['G', '{:06d}'.format(i)],
177 | ['H', '{:06d}'.format(i)],
178 | ]
179 | progress.set_description(params)
180 | time.sleep(0.01)
181 | progress.close()
182 |
--------------------------------------------------------------------------------
/skilldiffuser/hrl/diffuser/utils/rendering.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import einops
4 | import imageio
5 | import matplotlib.pyplot as plt
6 | from matplotlib.colors import ListedColormap
7 | import gym
8 | import mujoco_py as mjc
9 | import warnings
10 | import pdb
11 |
12 | from .arrays import to_np
13 | from .video import save_video, save_videos
14 | from ml_logger import logger
15 |
16 | from diffuser.datasets.d4rl import load_environment
17 |
18 | #-----------------------------------------------------------------------------#
19 | #------------------------------- helper structs ------------------------------#
20 | #-----------------------------------------------------------------------------#
21 |
22 | def env_map(env_name):
23 | '''
24 | map D4RL dataset names to custom fully-observed
25 | variants for rendering
26 | '''
27 | if 'halfcheetah' in env_name:
28 | return 'HalfCheetahFullObs-v2'
29 | elif 'hopper' in env_name:
30 | return 'HopperFullObs-v2'
31 | elif 'walker2d' in env_name:
32 | return 'Walker2dFullObs-v2'
33 | else:
34 | return env_name
35 |
36 | #-----------------------------------------------------------------------------#
37 | #------------------------------ helper functions -----------------------------#
38 | #-----------------------------------------------------------------------------#
39 |
40 | def get_image_mask(img):
41 | background = (img == 255).all(axis=-1, keepdims=True)
42 | mask = ~background.repeat(3, axis=-1)
43 | return mask
44 |
45 | def atmost_2d(x):
46 | while x.ndim > 2:
47 | x = x.squeeze(0)
48 | return x
49 |
50 | #-----------------------------------------------------------------------------#
51 | #---------------------------------- renderers --------------------------------#
52 | #-----------------------------------------------------------------------------#
53 |
54 | class MuJoCoRenderer:
55 | '''
56 | default mujoco renderer
57 | '''
58 |
59 | def __init__(self, env):
60 | if type(env) is str:
61 | env = env_map(env)
62 | self.env = gym.make(env)
63 | else:
64 | self.env = env
65 | ## - 1 because the envs in renderer are fully-observed
66 | ## @TODO : clean up
67 | self.observation_dim = np.prod(self.env.observation_space.shape) - 1
68 | self.action_dim = np.prod(self.env.action_space.shape)
69 | try:
70 | self.viewer = mjc.MjRenderContextOffscreen(self.env.sim)
71 | except:
72 | print('[ utils/rendering ] Warning: could not initialize offscreen renderer')
73 | self.viewer = None
74 |
75 | def pad_observation(self, observation):
76 | state = np.concatenate([
77 | np.zeros(1),
78 | observation,
79 | ])
80 | return state
81 |
82 | def pad_observations(self, observations):
83 | qpos_dim = self.env.sim.data.qpos.size
84 | ## xpos is hidden
85 | xvel_dim = qpos_dim - 1
86 | xvel = observations[:, xvel_dim]
87 | xpos = np.cumsum(xvel) * self.env.dt
88 | states = np.concatenate([
89 | xpos[:,None],
90 | observations,
91 | ], axis=-1)
92 | return states
93 |
94 | def render(self, observation, dim=256, partial=False, qvel=True, render_kwargs=None, conditions=None):
95 |
96 | if type(dim) == int:
97 | dim = (dim, dim)
98 |
99 | if self.viewer is None:
100 | return np.zeros((*dim, 3), np.uint8)
101 |
102 | if render_kwargs is None:
103 | xpos = observation[0] if not partial else 0
104 | render_kwargs = {
105 | 'trackbodyid': 2,
106 | 'distance': 3,
107 | 'lookat': [xpos, -0.5, 1],
108 | 'elevation': -20
109 | }
110 |
111 | for key, val in render_kwargs.items():
112 | if key == 'lookat':
113 | self.viewer.cam.lookat[:] = val[:]
114 | else:
115 | setattr(self.viewer.cam, key, val)
116 |
117 | if partial:
118 | state = self.pad_observation(observation)
119 | else:
120 | state = observation
121 |
122 | qpos_dim = self.env.sim.data.qpos.size
123 | if not qvel or state.shape[-1] == qpos_dim:
124 | qvel_dim = self.env.sim.data.qvel.size
125 | state = np.concatenate([state, np.zeros(qvel_dim)])
126 |
127 | set_state(self.env, state)
128 |
129 | self.viewer.render(*dim)
130 | data = self.viewer.read_pixels(*dim, depth=False)
131 | data = data[::-1, :, :]
132 | return data
133 |
134 | def _renders(self, observations, **kwargs):
135 | images = []
136 | for observation in observations:
137 | img = self.render(observation, **kwargs)
138 | images.append(img)
139 | return np.stack(images, axis=0)
140 |
141 | def renders(self, samples, partial=False, **kwargs):
142 | if partial:
143 | samples = self.pad_observations(samples)
144 | partial = False
145 |
146 | sample_images = self._renders(samples, partial=partial, **kwargs)
147 |
148 | composite = np.ones_like(sample_images[0]) * 255
149 |
150 | for img in sample_images:
151 | mask = get_image_mask(img)
152 | composite[mask] = img[mask]
153 |
154 | return composite
155 |
156 | def composite(self, savepath, paths, dim=(1024, 256), **kwargs):
157 |
158 | render_kwargs = {
159 | 'trackbodyid': 2,
160 | 'distance': 10,
161 | 'lookat': [5, 2, 0.5],
162 | 'elevation': 0
163 | }
164 | images = []
165 | for path in paths:
166 | ## [ H x obs_dim ]
167 | path = atmost_2d(path)
168 | img = self.renders(to_np(path), dim=dim, partial=True, qvel=True, render_kwargs=render_kwargs, **kwargs)
169 | images.append(img)
170 | images = np.concatenate(images, axis=0)
171 |
172 | if savepath is not None:
173 | fig = plt.figure()
174 | plt.imshow(images)
175 | logger.savefig(savepath, fig)
176 | print(f'Saved {len(paths)} samples to: {savepath}')
177 |
178 | return images
179 |
180 | def render_rollout(self, savepath, states, **video_kwargs):
181 | if type(states) is list: states = np.array(states)
182 | images = self._renders(states, partial=True)
183 | save_video(savepath, images, **video_kwargs)
184 |
185 | def render_plan(self, savepath, actions, observations_pred, state, fps=30):
186 | ## [ batch_size x horizon x observation_dim ]
187 | observations_real = rollouts_from_state(self.env, state, actions)
188 |
189 | ## there will be one more state in `observations_real`
190 | ## than in `observations_pred` because the last action
191 | ## does not have an associated next_state in the sampled trajectory
192 | observations_real = observations_real[:,:-1]
193 |
194 | images_pred = np.stack([
195 | self._renders(obs_pred, partial=True)
196 | for obs_pred in observations_pred
197 | ])
198 |
199 | images_real = np.stack([
200 | self._renders(obs_real, partial=False)
201 | for obs_real in observations_real
202 | ])
203 |
204 | ## [ batch_size x horizon x H x W x C ]
205 | images = np.concatenate([images_pred, images_real], axis=-2)
206 | save_videos(savepath, *images)
207 |
208 | def render_diffusion(self, savepath, diffusion_path, **video_kwargs):
209 | '''
210 | diffusion_path : [ n_diffusion_steps x batch_size x 1 x horizon x joined_dim ]
211 | '''
212 | render_kwargs = {
213 | 'trackbodyid': 2,
214 | 'distance': 10,
215 | 'lookat': [10, 2, 0.5],
216 | 'elevation': 0,
217 | }
218 |
219 | diffusion_path = to_np(diffusion_path)
220 |
221 | n_diffusion_steps, batch_size, _, horizon, joined_dim = diffusion_path.shape
222 |
223 | frames = []
224 | for t in reversed(range(n_diffusion_steps)):
225 | print(f'[ utils/renderer ] Diffusion: {t} / {n_diffusion_steps}')
226 |
227 | ## [ batch_size x horizon x observation_dim ]
228 | states_l = diffusion_path[t].reshape(batch_size, horizon, joined_dim)[:, :, :self.observation_dim]
229 |
230 | frame = []
231 | for states in states_l:
232 | img = self.composite(None, states, dim=(1024, 256), partial=True, qvel=True, render_kwargs=render_kwargs)
233 | frame.append(img)
234 | frame = np.concatenate(frame, axis=0)
235 |
236 | frames.append(frame)
237 |
238 | save_video(savepath, frames, **video_kwargs)
239 |
240 | def __call__(self, *args, **kwargs):
241 | return self.renders(*args, **kwargs)
242 |
243 | #-----------------------------------------------------------------------------#
244 | #---------------------------------- rollouts ---------------------------------#
245 | #-----------------------------------------------------------------------------#
246 |
247 | def set_state(env, state):
248 | qpos_dim = env.sim.data.qpos.size
249 | qvel_dim = env.sim.data.qvel.size
250 | if not state.size == qpos_dim + qvel_dim:
251 | warnings.warn(
252 | f'[ utils/rendering ] Expected state of size {qpos_dim + qvel_dim}, '
253 | f'but got state of size {state.size}')
254 | state = state[:qpos_dim + qvel_dim]
255 |
256 | env.set_state(state[:qpos_dim], state[qpos_dim:])
257 |
258 | def rollouts_from_state(env, state, actions_l):
259 | rollouts = np.stack([
260 | rollout_from_state(env, state, actions)
261 | for actions in actions_l
262 | ])
263 | return rollouts
264 |
265 | def rollout_from_state(env, state, actions):
266 | qpos_dim = env.sim.data.qpos.size
267 | env.set_state(state[:qpos_dim], state[qpos_dim:])
268 | observations = [env._get_obs()]
269 | for act in actions:
270 | obs, rew, term, _ = env.step(act)
271 | observations.append(obs)
272 | if term:
273 | break
274 | for i in range(len(observations), len(actions)+1):
275 | ## if terminated early, pad with zeros
276 | observations.append( np.zeros(obs.size) )
277 | return np.stack(observations)
278 |
--------------------------------------------------------------------------------
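A minimal usage sketch for `MuJoCoRenderer` above (illustrative, not part of the repository): it assumes `mujoco_py`, `gym` and the `*FullObs-v2` registrations from `diffuser.environments` are importable, and renders a single frame from a zeroed partial observation.

```python
import numpy as np
import diffuser.environments          # assumed to register the *FullObs-v2 rendering variants
from diffuser.utils.rendering import MuJoCoRenderer

# env_map() swaps the D4RL name for its fully-observed counterpart
renderer = MuJoCoRenderer('hopper-medium-replay-v2')

# partial=True prepends the hidden x-position (as 0) before setting the sim state
obs = np.zeros(renderer.observation_dim)
frame = renderer.render(obs, dim=256, partial=True)
print(frame.shape)   # (256, 256, 3); all zeros if the offscreen viewer failed to initialize
```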
/skilldiffuser/hrl/diffuser/utils/serialization.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import glob
4 | import torch
5 | import pdb
6 |
7 | from collections import namedtuple
8 |
9 | DiffusionExperiment = namedtuple('Diffusion', 'dataset renderer model diffusion ema trainer epoch')
10 |
11 | def mkdir(savepath):
12 | """
13 | returns `True` iff `savepath` is created
14 | """
15 | if not os.path.exists(savepath):
16 | os.makedirs(savepath)
17 | return True
18 | else:
19 | return False
20 |
21 | def get_latest_epoch(loadpath):
22 | states = glob.glob1(os.path.join(*loadpath), 'state_*')
23 | latest_epoch = -1
24 | for state in states:
25 | epoch = int(state.replace('state_', '').replace('.pt', ''))
26 | latest_epoch = max(epoch, latest_epoch)
27 | return latest_epoch
28 |
29 | def load_config(*loadpath):
30 | loadpath = os.path.join(*loadpath)
31 | config = pickle.load(open(loadpath, 'rb'))
32 | print(f'[ utils/serialization ] Loaded config from {loadpath}')
33 | print(config)
34 | return config
35 |
36 | def load_diffusion(*loadpath, epoch='latest', device='cuda:0'):
37 | dataset_config = load_config(*loadpath, 'dataset_config.pkl')
38 | render_config = load_config(*loadpath, 'render_config.pkl')
39 | model_config = load_config(*loadpath, 'model_config.pkl')
40 | diffusion_config = load_config(*loadpath, 'diffusion_config.pkl')
41 | trainer_config = load_config(*loadpath, 'trainer_config.pkl')
42 |
43 | ## remove absolute path for results loaded from azure
44 | ## @TODO : remove results folder from within trainer class
45 | trainer_config._dict['results_folder'] = os.path.join(*loadpath)
46 |
47 | dataset = dataset_config()
48 | renderer = render_config()
49 | model = model_config()
50 | diffusion = diffusion_config(model)
51 | trainer = trainer_config(diffusion, dataset, renderer)
52 |
53 | if epoch == 'latest':
54 | epoch = get_latest_epoch(loadpath)
55 |
56 | print(f'\n[ utils/serialization ] Loading model epoch: {epoch}\n')
57 |
58 | trainer.load(epoch)
59 |
60 | return DiffusionExperiment(dataset, renderer, model, diffusion, trainer.ema_model, trainer, epoch)
61 |
--------------------------------------------------------------------------------
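A sketch of restoring a trained run with `load_diffusion` (the results path below is hypothetical): each `*_config.pkl` is a pickled config object that rebuilds its component when called, and everything is returned bundled in the `DiffusionExperiment` namedtuple.

```python
from diffuser.utils.serialization import load_diffusion

# hypothetical results directory written by a previous training run
experiment = load_diffusion(
    'logs', 'hopper-medium-expert-v2', 'diffusion/defaults',
    epoch='latest', device='cuda:0',
)

dataset, renderer = experiment.dataset, experiment.renderer
ema_model = experiment.ema            # EMA weights, typically used for sampling
print(f'restored epoch: {experiment.epoch}')
```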
/skilldiffuser/hrl/diffuser/utils/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import importlib
3 | import random
4 | import numpy as np
5 | import torch
6 | from tap import Tap
7 | import pdb
8 |
9 | from .serialization import mkdir
10 | from .git_utils import (
11 | get_git_rev,
12 | save_git_diff,
13 | )
14 |
15 | def set_seed(seed):
16 | random.seed(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 |
21 | def watch(args_to_watch):
22 | def _fn(args):
23 | exp_name = []
24 | for key, label in args_to_watch:
25 | if not hasattr(args, key):
26 | continue
27 | val = getattr(args, key)
28 | if type(val) == dict:
29 | val = '_'.join(f'{k}-{v}' for k, v in val.items())
30 | exp_name.append(f'{label}{val}')
31 | exp_name = '_'.join(exp_name)
32 | exp_name = exp_name.replace('/_', '/')
33 | exp_name = exp_name.replace('(', '').replace(')', '')
34 | exp_name = exp_name.replace(', ', '-')
35 | return exp_name
36 | return _fn
37 |
38 | def lazy_fstring(template, args):
39 | ## https://stackoverflow.com/a/53671539
40 | return eval(f"f'{template}'")
41 |
42 | class Parser(Tap):
43 |
44 | def save(self):
45 | fullpath = os.path.join(self.savepath, 'args.json')
46 | print(f'[ utils/setup ] Saved args to {fullpath}')
47 | super().save(fullpath, skip_unpicklable=True)
48 |
49 | def parse_args(self, experiment=None):
50 | args = super().parse_args(known_only=True)
51 |         ## if not loading from a config script, skip the rest of the setup
52 | if not hasattr(args, 'config'): return args
53 | args = self.read_config(args, experiment)
54 | self.add_extras(args)
55 | self.eval_fstrings(args)
56 | self.set_seed(args)
57 | self.get_commit(args)
58 | self.generate_exp_name(args)
59 | self.mkdir(args)
60 | self.save_diff(args)
61 | return args
62 |
63 | def read_config(self, args, experiment):
64 | '''
65 | Load parameters from config file
66 | '''
67 | dataset = args.dataset.replace('-', '_')
68 | print(f'[ utils/setup ] Reading config: {args.config}:{dataset}')
69 | module = importlib.import_module(args.config)
70 | params = getattr(module, 'base')[experiment]
71 |
72 | if hasattr(module, dataset) and experiment in getattr(module, dataset):
73 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}')
74 | overrides = getattr(module, dataset)[experiment]
75 | params.update(overrides)
76 | else:
77 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}')
78 |
79 | self._dict = {}
80 | for key, val in params.items():
81 | setattr(args, key, val)
82 | self._dict[key] = val
83 |
84 | return args
85 |
86 | def add_extras(self, args):
87 | '''
88 | Override config parameters with command-line arguments
89 | '''
90 | extras = args.extra_args
91 | if not len(extras):
92 | return
93 |
94 | print(f'[ utils/setup ] Found extras: {extras}')
95 | assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}'
96 | for i in range(0, len(extras), 2):
97 | key = extras[i].replace('--', '')
98 | val = extras[i+1]
99 | assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}'
100 | old_val = getattr(args, key)
101 | old_type = type(old_val)
102 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}')
103 | if val == 'None':
104 | val = None
105 | elif val == 'latest':
106 | val = 'latest'
107 | elif old_type in [bool, type(None)]:
108 | try:
109 | val = eval(val)
110 | except:
111 | print(f'[ utils/setup ] Warning: could not parse {val} (old: {old_val}, {old_type}), using str')
112 | else:
113 | val = old_type(val)
114 | setattr(args, key, val)
115 | self._dict[key] = val
116 |
117 | def eval_fstrings(self, args):
118 | for key, old in self._dict.items():
119 | if type(old) is str and old[:2] == 'f:':
120 | val = old.replace('{', '{args.').replace('f:', '')
121 | new = lazy_fstring(val, args)
122 | print(f'[ utils/setup ] Lazy fstring | {key} : {old} --> {new}')
123 | setattr(self, key, new)
124 | self._dict[key] = new
125 |
126 | def set_seed(self, args):
127 | if not 'seed' in dir(args):
128 | return
129 | print(f'[ utils/setup ] Setting seed: {args.seed}')
130 | set_seed(args.seed)
131 |
132 | def generate_exp_name(self, args):
133 | if not 'exp_name' in dir(args):
134 | return
135 | exp_name = getattr(args, 'exp_name')
136 | if callable(exp_name):
137 | exp_name_string = exp_name(args)
138 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}')
139 | setattr(args, 'exp_name', exp_name_string)
140 | self._dict['exp_name'] = exp_name_string
141 |
142 | def mkdir(self, args):
143 | if 'logbase' in dir(args) and 'dataset' in dir(args) and 'exp_name' in dir(args):
144 | args.savepath = os.path.join(args.logbase, args.dataset, args.exp_name)
145 | self._dict['savepath'] = args.savepath
146 | if 'suffix' in dir(args):
147 | args.savepath = os.path.join(args.savepath, args.suffix)
148 | if mkdir(args.savepath):
149 | print(f'[ utils/setup ] Made savepath: {args.savepath}')
150 | self.save()
151 |
152 | def get_commit(self, args):
153 | args.commit = get_git_rev()
154 |
155 | def save_diff(self, args):
156 | try:
157 | save_git_diff(os.path.join(args.savepath, 'diff.txt'))
158 | except:
159 | print('[ utils/setup ] WARNING: did not save git diff')
160 |
--------------------------------------------------------------------------------
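A minimal sketch of how `Parser` is typically subclassed (the `config.locomotion` module and its `base` dict are assumptions, not files in this repo): defaults come from the config module, `--key value` extras override them, `f:`-prefixed strings are lazily evaluated, and `args.savepath` is created and populated with `args.json` when the config provides `logbase` and `exp_name`.

```python
from diffuser.utils.setup import Parser

class TrainParser(Parser):
    dataset: str = 'hopper-medium-expert-v2'
    config: str = 'config.locomotion'   # assumed module defining base = {'diffusion': {...}}

if __name__ == '__main__':
    # e.g. `python train.py --dataset walker2d-medium-v2 --horizon 64`
    args = TrainParser().parse_args('diffusion')
    print(args.dataset, getattr(args, 'savepath', None))
```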
/skilldiffuser/hrl/diffuser/utils/timer.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | class Timer:
4 |
5 | def __init__(self):
6 | self._start = time.time()
7 |
8 | def __call__(self, reset=True):
9 | now = time.time()
10 | diff = now - self._start
11 | if reset:
12 | self._start = now
13 | return diff
--------------------------------------------------------------------------------
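`Timer` measures wall-clock time between calls and resets itself by default; a small self-contained example:

```python
import time
from diffuser.utils.timer import Timer

timer = Timer()
time.sleep(0.2)
print(f'{timer():.2f}s')             # ~0.20, and the start time is reset
time.sleep(0.1)
print(f'{timer(reset=False):.2f}s')  # ~0.10, start time left untouched
```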
/skilldiffuser/hrl/diffuser/utils/training.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import numpy as np
4 | import torch
5 | import einops
6 | # import pdb
7 | import diffuser
8 | from copy import deepcopy
9 |
10 | from .arrays import batch_to_device, to_np, to_device, apply_dict
11 | from .timer import Timer
12 | # from .cloud import sync_logs
13 | from ml_logger import logger
14 |
15 | def cycle(dl):
16 | while True:
17 | for data in dl:
18 | yield data
19 |
20 | class EMA():
21 | '''
22 |         exponential moving average
23 | '''
24 | def __init__(self, beta):
25 | super().__init__()
26 | self.beta = beta
27 |
28 | def update_model_average(self, ma_model, current_model):
29 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
30 | old_weight, up_weight = ma_params.data, current_params.data
31 | ma_params.data = self.update_average(old_weight, up_weight)
32 |
33 | def update_average(self, old, new):
34 | if old is None:
35 | return new
36 | return old * self.beta + (1 - self.beta) * new
37 |
38 | class Trainer(object):
39 | def __init__(
40 | self,
41 | diffusion_model,
42 | dataset,
43 | renderer,
44 | ema_decay=0.995,
45 | train_batch_size=32,
46 | train_lr=2e-5,
47 | gradient_accumulate_every=2,
48 | step_start_ema=2000,
49 | update_ema_every=10,
50 | log_freq=100,
51 | sample_freq=1000,
52 | save_freq=1000,
53 | label_freq=100000,
54 | save_parallel=False,
55 | n_reference=8,
56 | bucket=None,
57 | train_device='cuda',
58 | save_checkpoints=False,
59 | ):
60 | super().__init__()
61 | self.model = diffusion_model
62 | self.ema = EMA(ema_decay)
63 | self.ema_model = copy.deepcopy(self.model)
64 | self.update_ema_every = update_ema_every
65 | self.save_checkpoints = save_checkpoints
66 |
67 | self.step_start_ema = step_start_ema
68 | self.log_freq = log_freq
69 | self.sample_freq = sample_freq
70 | self.save_freq = save_freq
71 | self.label_freq = label_freq
72 | self.save_parallel = save_parallel
73 |
74 | self.batch_size = train_batch_size
75 | self.gradient_accumulate_every = gradient_accumulate_every
76 |
77 | self.dataset = dataset
78 |
79 | self.dataloader = cycle(torch.utils.data.DataLoader(
80 | self.dataset, batch_size=train_batch_size, num_workers=0, shuffle=True, pin_memory=True
81 | ))
82 | self.dataloader_vis = cycle(torch.utils.data.DataLoader(
83 | self.dataset, batch_size=1, num_workers=0, shuffle=True, pin_memory=True
84 | ))
85 | self.renderer = renderer
86 | self.optimizer = torch.optim.Adam(diffusion_model.parameters(), lr=train_lr)
87 |
88 | self.bucket = bucket
89 | self.n_reference = n_reference
90 |
91 | self.reset_parameters()
92 | self.step = 0
93 |
94 | self.device = train_device
95 |
96 | def reset_parameters(self):
97 | self.ema_model.load_state_dict(self.model.state_dict())
98 |
99 | def step_ema(self):
100 | if self.step < self.step_start_ema:
101 | self.reset_parameters()
102 | return
103 | self.ema.update_model_average(self.ema_model, self.model)
104 |
105 | #-----------------------------------------------------------------------------#
106 | #------------------------------------ api ------------------------------------#
107 | #-----------------------------------------------------------------------------#
108 |
109 | def train(self, n_train_steps):
110 |
111 | timer = Timer()
112 | for step in range(n_train_steps):
113 | for i in range(self.gradient_accumulate_every):
114 | batch = next(self.dataloader)
115 | batch = batch_to_device(batch, device=self.device)
116 | loss, infos = self.model.loss(*batch)
117 | loss = loss / self.gradient_accumulate_every
118 | loss.backward()
119 |
120 | self.optimizer.step()
121 | self.optimizer.zero_grad()
122 |
123 | if self.step % self.update_ema_every == 0:
124 | self.step_ema()
125 |
126 | if self.step % self.save_freq == 0:
127 | self.save()
128 |
129 | if self.step % self.log_freq == 0:
130 | infos_str = ' | '.join([f'{key}: {val:8.4f}' for key, val in infos.items()])
131 | logger.print(f'{self.step}: {loss:8.4f} | {infos_str} | t: {timer():8.4f}')
132 | metrics = {k:v.detach().item() for k, v in infos.items()}
133 | metrics['steps'] = self.step
134 | metrics['loss'] = loss.detach().item()
135 | logger.log_metrics_summary(metrics, default_stats='mean')
136 |
137 | if self.step == 0 and self.sample_freq:
138 | self.render_reference(self.n_reference)
139 |
140 | if self.sample_freq and self.step % self.sample_freq == 0:
141 | if self.model.__class__ == diffuser.models.diffusion.GaussianInvDynDiffusion:
142 | self.inv_render_samples()
143 | elif self.model.__class__ == diffuser.models.diffusion.ActionGaussianDiffusion:
144 | pass
145 | else:
146 | self.render_samples()
147 |
148 | self.step += 1
149 |
150 | def save(self):
151 | '''
152 | saves model and ema to disk;
153 | syncs to storage bucket if a bucket is specified
154 | '''
155 | data = {
156 | 'step': self.step,
157 | 'model': self.model.state_dict(),
158 | 'ema': self.ema_model.state_dict()
159 | }
160 | savepath = os.path.join(self.bucket, logger.prefix, 'checkpoint')
161 | os.makedirs(savepath, exist_ok=True)
162 | # logger.save_torch(data, savepath)
163 | if self.save_checkpoints:
164 | savepath = os.path.join(savepath, f'state_{self.step}.pt')
165 | else:
166 | savepath = os.path.join(savepath, 'state.pt')
167 | torch.save(data, savepath)
168 | logger.print(f'[ utils/training ] Saved model to {savepath}')
169 |
170 | def load(self):
171 | '''
172 | loads model and ema from disk
173 | '''
174 | loadpath = os.path.join(self.bucket, logger.prefix, f'checkpoint/state.pt')
175 | # data = logger.load_torch(loadpath)
176 | data = torch.load(loadpath)
177 |
178 | self.step = data['step']
179 | self.model.load_state_dict(data['model'])
180 | self.ema_model.load_state_dict(data['ema'])
181 |
182 | #-----------------------------------------------------------------------------#
183 | #--------------------------------- rendering ---------------------------------#
184 | #-----------------------------------------------------------------------------#
185 |
186 | def render_reference(self, batch_size=10):
187 | '''
188 | renders training points
189 | '''
190 |
191 | ## get a temporary dataloader to load a single batch
192 | dataloader_tmp = cycle(torch.utils.data.DataLoader(
193 | self.dataset, batch_size=batch_size, num_workers=0, shuffle=True, pin_memory=True
194 | ))
195 | batch = dataloader_tmp.__next__()
196 | dataloader_tmp.close()
197 |
198 | ## get trajectories and condition at t=0 from batch
199 | trajectories = to_np(batch.trajectories)
200 | conditions = to_np(batch.conditions[0])[:,None]
201 |
202 | ## [ batch_size x horizon x observation_dim ]
203 | normed_observations = trajectories[:, :, self.dataset.action_dim:]
204 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
205 |
206 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat
207 | # # observations = conditions + blocks_cumsum_quat(deltas)
208 | # observations = conditions + deltas.cumsum(axis=1)
209 |
210 | #### @TODO: remove block-stacking specific stuff
211 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
212 | # observations = blocks_add_kuka(observations)
213 | ####
214 |
215 | savepath = os.path.join('images', f'sample-reference.png')
216 | self.renderer.composite(savepath, observations)
217 |
218 | def render_samples(self, batch_size=2, n_samples=2):
219 | '''
220 | renders samples from (ema) diffusion model
221 | '''
222 | for i in range(batch_size):
223 |
224 | ## get a single datapoint
225 | batch = self.dataloader_vis.__next__()
226 | conditions = to_device(batch.conditions, self.device)
227 | ## repeat each item in conditions `n_samples` times
228 | conditions = apply_dict(
229 | einops.repeat,
230 | conditions,
231 | 'b d -> (repeat b) d', repeat=n_samples,
232 | )
233 |
234 | ## [ n_samples x horizon x (action_dim + observation_dim) ]
235 | if self.ema_model.returns_condition:
236 | returns = to_device(torch.ones(n_samples, 1), self.device)
237 | else:
238 | returns = None
239 |
240 | if self.ema_model.model.calc_energy:
241 | samples = self.ema_model.grad_conditional_sample(conditions, returns=returns)
242 | else:
243 | samples = self.ema_model.conditional_sample(conditions, returns=returns)
244 |
245 | samples = to_np(samples)
246 |
247 | ## [ n_samples x horizon x observation_dim ]
248 | normed_observations = samples[:, :, self.dataset.action_dim:]
249 |
250 | # [ 1 x 1 x observation_dim ]
251 | normed_conditions = to_np(batch.conditions[0])[:,None]
252 |
253 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat
254 | # observations = conditions + blocks_cumsum_quat(deltas)
255 | # observations = conditions + deltas.cumsum(axis=1)
256 |
257 | ## [ n_samples x (horizon + 1) x observation_dim ]
258 | normed_observations = np.concatenate([
259 | np.repeat(normed_conditions, n_samples, axis=0),
260 | normed_observations
261 | ], axis=1)
262 |
263 | ## [ n_samples x (horizon + 1) x observation_dim ]
264 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
265 |
266 | #### @TODO: remove block-stacking specific stuff
267 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
268 | # observations = blocks_add_kuka(observations)
269 | ####
270 |
271 | savepath = os.path.join('images', f'sample-{i}.png')
272 | self.renderer.composite(savepath, observations)
273 |
274 | def inv_render_samples(self, batch_size=2, n_samples=2):
275 | '''
276 | renders samples from (ema) diffusion model
277 | '''
278 | for i in range(batch_size):
279 |
280 | ## get a single datapoint
281 | batch = self.dataloader_vis.__next__()
282 | conditions = to_device(batch.conditions, self.device)
283 | ## repeat each item in conditions `n_samples` times
284 | conditions = apply_dict(
285 | einops.repeat,
286 | conditions,
287 | 'b d -> (repeat b) d', repeat=n_samples,
288 | )
289 |
290 | ## [ n_samples x horizon x (action_dim + observation_dim) ]
291 | if self.ema_model.returns_condition:
292 | returns = to_device(torch.ones(n_samples, 1), self.device)
293 | else:
294 | returns = None
295 |
296 | if self.ema_model.model.calc_energy:
297 | samples = self.ema_model.grad_conditional_sample(conditions, returns=returns)
298 | else:
299 | samples = self.ema_model.conditional_sample(conditions, returns=returns)
300 |
301 | samples = to_np(samples)
302 |
303 | ## [ n_samples x horizon x observation_dim ]
304 | normed_observations = samples[:, :, :]
305 |
306 | # [ 1 x 1 x observation_dim ]
307 | normed_conditions = to_np(batch.conditions[0])[:,None]
308 |
309 | # from diffusion.datasets.preprocessing import blocks_cumsum_quat
310 | # observations = conditions + blocks_cumsum_quat(deltas)
311 | # observations = conditions + deltas.cumsum(axis=1)
312 |
313 | ## [ n_samples x (horizon + 1) x observation_dim ]
314 | normed_observations = np.concatenate([
315 | np.repeat(normed_conditions, n_samples, axis=0),
316 | normed_observations
317 | ], axis=1)
318 |
319 | ## [ n_samples x (horizon + 1) x observation_dim ]
320 | observations = self.dataset.normalizer.unnormalize(normed_observations, 'observations')
321 |
322 | #### @TODO: remove block-stacking specific stuff
323 | # from diffusion.datasets.preprocessing import blocks_euler_to_quat, blocks_add_kuka
324 | # observations = blocks_add_kuka(observations)
325 | ####
326 |
327 | savepath = os.path.join('images', f'sample-{i}.png')
328 | self.renderer.composite(savepath, observations)
--------------------------------------------------------------------------------
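The `Trainer` keeps an exponential-moving-average copy of the diffusion model and refreshes it every `update_ema_every` steps via `step_ema()`. A self-contained sketch of the same `EMA` update applied to a toy torch model (the toy model is purely illustrative):

```python
import copy
import torch
import torch.nn as nn
from diffuser.utils.training import EMA

model = nn.Linear(4, 2)
ema_model = copy.deepcopy(model)          # mirrors Trainer.reset_parameters()

ema = EMA(beta=0.995)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

for step in range(200):
    loss = model(torch.randn(16, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # ema_weights <- beta * ema_weights + (1 - beta) * online_weights
    ema.update_model_average(ema_model, model)
```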
/skilldiffuser/hrl/diffuser/utils/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import skvideo.io
4 | from ml_logger import logger
5 |
6 | def _make_dir(filename):
7 | folder = os.path.dirname(filename)
8 | if not os.path.exists(folder):
9 | os.makedirs(folder)
10 |
11 | def save_video(filename, video_frames, fps=60, video_format='mp4'):
12 | assert fps == int(fps), fps
13 | # logger.save_video(video_frames, filename, fps=fps, format=video_format)
14 | _make_dir(filename)
15 |
16 | skvideo.io.vwrite(
17 | filename,
18 | video_frames,
19 | inputdict={
20 | '-r': str(int(fps)),
21 | },
22 | outputdict={
23 | '-f': video_format,
24 | '-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74
25 | }
26 | )
27 |
28 | def save_videos(filename, *video_frames, axis=1, **kwargs):
29 | ## video_frame : [ N x H x W x C ]
30 | video_frames = np.concatenate(video_frames, axis=axis)
31 | save_video(filename, video_frames, **kwargs)
32 |
--------------------------------------------------------------------------------
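`save_video` is a thin wrapper over `skvideo.io.vwrite`, so it assumes `scikit-video` and an `ffmpeg` binary are available; a quick sketch with random frames (even frame dimensions keep `yuv420p` encoding happy):

```python
import numpy as np
from diffuser.utils.video import save_video, save_videos

frames = np.random.randint(0, 255, size=(60, 128, 128, 3), dtype=np.uint8)
save_video('videos/noise.mp4', frames, fps=30)

# save_videos concatenates multiple [N x H x W x C] clips (default axis=1) before writing
save_videos('videos/noise_pair.mp4', frames, frames, fps=30)
```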
/skilldiffuser/hrl/env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import gym
4 | from PIL import Image
5 | import pdb
6 |
7 | from utils import String, lorl_gt_reward, lorl_save_im
8 |
9 |
10 | class BaseWrapper(gym.Wrapper):
11 | """Base processing wrapper.
12 |
13 | 1) Adds String command to observations
14 | 2) Preprocess states
15 | """
16 |
17 | def __init__(self, env, dataset):
18 | super(BaseWrapper, self).__init__(env)
19 | self.env = env
20 | self.dataset = dataset
21 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std
22 |
23 | self.state_dim = len(self.state_mean)
24 |
25 | self.observation_space = gym.spaces.Dict(
26 | {
27 | "state": gym.spaces.Box(
28 | low=-np.inf, # 0.0
29 | high=np.inf, # 1.0
30 | shape=(self.state_dim,),
31 | dtype=np.float32,
32 | ),
33 | "lang": String(),
34 | }
35 | )
36 |
37 | self.act_dim = env.action_space.shape[0]
38 | # print(self.action_space.low, self.action_space.high)
39 |
40 | def reset(self, **kwargs):
41 | obs = self.env.reset()
42 | return self.get_state(obs)
43 |
44 | def step(self, action):
45 | obs, reward, done, info = self.env.step(action)
46 |
47 | if done:
48 | success = 0
49 | if reward > 0:
50 | success = 1
51 | info.update({'success': success})
52 |
53 | return self.get_state(obs), reward, done, info
54 |
55 | def get_state(self, obs):
56 | """Returns the observation and lang_cmd"""
57 |
58 | lang = "dummy"
59 | cur_state = (obs.reshape(-1) - self.state_mean) / self.state_std
60 |
61 | return {'state': cur_state, 'lang': lang}
62 |
63 | def get_image(self):
64 | return Image.fromarray(self.env.render(), 'RGB')
65 |
66 |
67 | class BabyAIWrapper(gym.Wrapper):
68 | """BabyAI processing wrapper.
69 |
70 | 1) Adds String command to observations
71 | 2) Preprocess states
72 | """
73 |
74 | def __init__(self, env, dataset):
75 | super(BabyAIWrapper, self).__init__(env)
76 | self.env = env
77 | self.dataset = dataset
78 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std
79 |
80 | self.state_dim = len(self.state_mean)
81 | if self.dataset.kwargs['use_direction']:
82 | self.state_dim += 4 # for the direction part in BabyAI
83 |
84 | self.observation_space = gym.spaces.Dict(
85 | {
86 | "state": gym.spaces.Box(
87 | low=-np.inf, # 0.0
88 | high=np.inf, # 1.0
89 | shape=(self.state_dim,),
90 | dtype=np.float32,
91 | ),
92 | "lang": String(),
93 | }
94 | )
95 |
96 | self.act_dim = env.action_space.n
97 |
98 | def reset(self, **kwargs):
99 | obs = self.env.reset()
100 | return self.get_state(obs)
101 |
102 | def step(self, action):
103 | obs, reward, done, info = self.env.step(action)
104 |
105 | if done:
106 | success = 0
107 | if reward > 0:
108 | success = 1
109 | info.update({'success': success})
110 |
111 | return self.get_state(obs), reward, done, info
112 |
113 | def get_state(self, obs):
114 | """Returns the observation and lang_cmd"""
115 |
116 | lang = obs["mission"]
117 | cur_state = (obs["image"].reshape(-1) - self.state_mean) / self.state_std
118 |
119 | if self.dataset.kwargs['use_direction']:
120 | direction = np.zeros(4)
121 | direction[obs["direction"]] = 1.
122 | cur_state = np.concatenate([cur_state, direction])
123 | return {'state': cur_state, 'lang': lang}
124 |
125 | def get_image(self):
126 | return Image.fromarray(self.env.render(), 'RGB')
127 |
128 |
129 | class LorlWrapper(gym.Wrapper):
130 | """BabyAI processing wrapper.
131 |
132 | 1) Adds String command to observations
133 | 2) Preprocess states
134 | """
135 |
136 | def __init__(self, env, dataset, **kwargs):
137 | super(LorlWrapper, self).__init__(env)
138 |
139 | self.env = env
140 | self.dataset = dataset
141 | self.use_state = dataset.kwargs["use_state"]
142 | self.state_mean, self.state_std = self.dataset.state_mean, self.dataset.state_std
143 |
144 | self.state_dim = self.state_mean.shape
145 | self.act_dim = env.action_space.shape[0]
146 |
147 | if isinstance(self.state_dim, tuple):
148 | self.observation_space = gym.spaces.Dict(
149 | {
150 | "state": gym.spaces.Box(
151 | low=-np.inf,
152 | high=np.inf,
153 | shape=self.state_dim,
154 | dtype=np.float32,
155 | ),
156 | "lang": String(),
157 | }
158 | )
159 | else:
160 | self.observation_space = gym.spaces.Dict(
161 | {
162 | "state": gym.spaces.Box(
163 | low=-np.inf,
164 | high=np.inf,
165 | shape=(self.state_dim,),
166 | dtype=np.float32,
167 | ),
168 | "lang": String(),
169 | }
170 | )
171 |
172 | self.initial_state = None
173 | self.instr = kwargs["instr"]
174 | self.orig_instr = kwargs["orig_instr"]
175 |
176 | def reset(self, render=False, **kwargs):
177 | if render:
178 | render_path, iter_num, i = kwargs['render_path'], kwargs['iter_num'], kwargs["i"]
179 |
180 | env = self.env
181 | orig_instr, instr = self.orig_instr, self.instr
182 | im, _ = env.reset()
183 |
184 | # Initialize state for different tasks
185 | if orig_instr == "open drawer":
186 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0)
187 | elif orig_instr == "close drawer":
188 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
189 | elif orig_instr == "turn faucet right":
190 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
191 | elif orig_instr == "turn faucet left":
192 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
193 | elif orig_instr == "move black mug right":
194 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05)
195 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05)
196 | elif orig_instr == "move white mug down":
197 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05)
198 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05)
199 |         # Don't know if the following are correct
200 | elif orig_instr == 'open drawer and move black mug right':
201 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0)
202 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05)
203 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05)
204 | elif orig_instr == 'pull the handle and move black mug down':
205 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0)
206 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05)
207 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05)
208 | elif orig_instr == 'move white mug right':
209 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05)
210 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05)
211 | elif orig_instr == 'move black mug down':
212 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05)
213 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05)
214 | elif orig_instr == 'close drawer and turn faucet right':
215 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
216 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
217 | elif orig_instr == 'close drawer and turn faucet left':
218 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
219 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
220 | elif orig_instr == 'turn faucet left and move white mug down':
221 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
222 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05)
223 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05)
224 | elif orig_instr == 'turn faucet right and close drawer':
225 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
226 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
227 | elif orig_instr == 'move white mug down and turn faucet left':
228 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
229 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05)
230 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05)
231 | elif orig_instr == 'close the drawer, turn the faucet left and move black mug right':
232 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
233 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
234 | env.sim.data.qpos[11] = -0.2 + np.random.uniform(-0.05, 0.05)
235 | env.sim.data.qpos[12] = 0.65 + np.random.uniform(-0.05, 0.05)
236 | elif instr == "open drawer and turn faucet counterclockwise":
237 | env.sim.data.qpos[14] = 0 + np.random.uniform(-0.05, 0)
238 | env.sim.data.qpos[13] = 0 + np.random.uniform(-np.pi/5, np.pi/5)
239 | elif instr == "slide the drawer closed and then shift white mug down":
240 | env.sim.data.qpos[14] = -0.1 + np.random.uniform(-0.05, 0.05)
241 | env.sim.data.qpos[9] = -0.2 + np.random.uniform(-0.05, 0.05)
242 | env.sim.data.qpos[10] = 0.65 + np.random.uniform(-0.05, 0.05)
243 |
244 | # if orig_instr == "move white mug down":
245 | # env._reset_hand(pos=[-0.1, 0.55, 0.1])
246 | # elif orig_instr == "move black mug right":
247 | # env._reset_hand(pos=[-0.1, 0.55, 0.1])
248 | if "mug" in orig_instr:
249 | env._reset_hand(pos=[-0.1, 0.55, 0.1])
250 | else:
251 | env._reset_hand(pos=[0, 0.45, 0.1])
252 |
253 | for _ in range(50):
254 | env.sim.step()
255 |
256 | reset_state = copy.deepcopy(env.sim.data.qpos[:])
257 | env.sim.data.qpos[:] = reset_state
258 | env.sim.data.qacc[:] = 0
259 | env.sim.data.qvel[:] = 0
260 | env.sim.step()
261 | self.initial_state = copy.deepcopy(env.sim.data.qpos[:])
262 |
263 | if render:
264 | # Initialize goal image for initial state
265 | if orig_instr == "open drawer":
266 | env.sim.data.qpos[14] = -0.15
267 | elif orig_instr == "close drawer":
268 | env.sim.data.qpos[14] = 0.0
269 | elif orig_instr == "turn faucet right":
270 | env.sim.data.qpos[13] -= np.pi/5
271 | elif orig_instr == "turn faucet left":
272 | env.sim.data.qpos[13] += np.pi/5
273 | elif orig_instr == "move black mug right":
274 | env.sim.data.qpos[11] -= 0.1
275 | elif orig_instr == "move white mug down":
276 | env.sim.data.qpos[10] += 0.1
277 |
278 | env.sim.step()
279 | gim = env._get_obs()[:, :, :3]
280 |
281 |             # Reset initial state
282 | env.sim.data.qpos[:] = reset_state
283 | env.sim.data.qacc[:] = 0
284 | env.sim.data.qvel[:] = 0
285 | env.sim.step()
286 |
287 | im = env._get_obs()[:, :, :3]
288 | initim = im
289 | render_path = "."
290 | lorl_save_im(
291 | (initim * 255.0).astype(np.uint8),
292 | render_path + f"/initialim_{iter_num}_{i}_{instr}.jpg")
293 |             lorl_save_im((gim*255.0).astype(np.uint8), render_path + f"/gim_{iter_num}_{i}_{instr}.jpg")
294 |
295 | observation = self.get_state(im)
296 | cur_state, lang = observation['state'], observation['lang']
297 | if self.use_state:
298 | cur_state = (cur_state - self.state_mean) / self.state_std
299 | self.state_dim = len(cur_state)
300 | else:
301 |             im = np.moveaxis(im, 2, 0)  # convert H,W,C to C,H,W
302 | cur_state = (im - self.state_mean) / self.state_std
303 | self.state_dim = cur_state.shape
304 |
305 | return {'state': cur_state, 'lang': lang}
306 |
307 | def step(self, action):
308 | im, _, _, info = self.env.step(action)
309 | dist, s = lorl_gt_reward(self.env.sim.data.qpos[:], self.initial_state, self.orig_instr)
310 |
311 | reward = 0
312 | success = 0
313 | if s:
314 | success = 1
315 | reward = dist
316 |
317 | info.update({'success': success})
318 | return self.get_state(im), reward, s, info
319 |
320 | def get_state(self, obs):
321 | """Returns the observation and lang_cmd"""
322 |
323 | if self.use_state:
324 | obs = self.env.sim.data.qpos[:]
325 | else:
326 |             obs = np.moveaxis(obs, 2, 0)  # convert H,W,C to C,H,W
327 |
328 | state = (obs - self.state_mean) / self.state_std
329 | return {'state': state, 'lang': self.instr}
330 |
331 | def get_image(self, h=1024, w=1024):
332 | # im = self.env._get_obs()
333 | obs = self.sim.render(h, w, camera_name="cam0") / 255.
334 | im = np.flip(obs, 0).copy()
335 | return (im[:, :, :3]*255.0).astype(np.uint8)
336 |
--------------------------------------------------------------------------------
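A sketch of the wrapper interface using `BaseWrapper` on a generic gym task (run from `skilldiffuser/hrl` so `env.py` and `utils.py` import; the stand-in dataset object only supplies the normalization statistics the wrapper reads, and the old 4-tuple gym step API is assumed, as elsewhere in this repo):

```python
import gym
import numpy as np
from types import SimpleNamespace

from env import BaseWrapper

raw_env = gym.make('Pendulum-v0')   # any continuous-action env with the old gym API
obs_dim = raw_env.observation_space.shape[0]
dataset = SimpleNamespace(state_mean=np.zeros(obs_dim), state_std=np.ones(obs_dim))

env = BaseWrapper(raw_env, dataset)
obs = env.reset()                                       # {'state': ..., 'lang': 'dummy'}
obs, reward, done, info = env.step(env.action_space.sample())
print(obs['state'].shape, obs['lang'])
```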
/skilldiffuser/hrl/hrl_model.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import torch
4 | import torch.nn as nn
5 | from transformers import DistilBertModel
6 |
7 | from iq import IQMixin
8 | from option_selector import OptionSelector
9 | from reconstructors import StateReconstructor, LanguageReconstructor
10 | # from decision_transformer import DecisionTransformer
11 | from dec_encoder import DecEncoder
12 | # from utils import pad
13 | # import diffuser.utils as diffutils
14 |
15 |
16 | class HRLModel(nn.Module, IQMixin):
17 | """Base class containing all the models"""
18 |
19 | def __init__(self, args, option_selector_args, state_reconstructor_args, lang_reconstructor_args,
20 | decision_args, iq_args, diff_trainer, device, horizon=5, K=10, train_lm=True,
21 | method='vanilla', state_reconstruct=False, lang_reconstruct=False, **kwargs):
22 | super().__init__()
23 |
24 | self.args = args
25 | self.horizon = horizon
26 | self.K = K
27 | self.lm = DistilBertModel.from_pretrained('distilbert-base-uncased')
28 | self.train_lm = train_lm # whether to train lm or not
29 | self.device = device
30 |
31 | if train_lm:
32 | self.lm.train()
33 | else:
34 | self.lm.eval()
35 |
36 | self.method = method
37 | self.state_reconstruct = state_reconstruct
38 | self.lang_reconstruct = lang_reconstruct
39 |
40 | self.state_dim = decision_args['state_dim']
41 | self.action_dim = decision_args['action_dim']
42 | self.option_dim = decision_args['option_dim']
43 |
44 | # self.decision_transformer = DecisionTransformer(lang_dim=self.lm.config.dim, **decision_transformer_args)
45 |
46 | if method == 'vanilla':
47 | assert decision_args['use_language'] and not decision_args['use_options']
48 | else:
49 | assert decision_args['use_options'] and not decision_args['use_language']
50 |
51 | self.option_selector = OptionSelector(lang_dim=self.lm.config.dim,
52 | method=self.method, **option_selector_args)
53 |
54 | if state_reconstruct:
55 | self.state_reconstructor = StateReconstructor(**state_reconstructor_args)
56 | if lang_reconstruct:
57 | self.lang_reconstructor = LanguageReconstructor(
58 | lang_dim=self.lm.config.dim, **lang_reconstructor_args)
59 |
60 | # TODO
61 | self.diffuser = DecEncoder(lang_dim=self.lm.config.dim, K=K, **decision_args)
62 |
63 | self.diff_trainer = diff_trainer
64 | # self.diff_trainer.set_optimizer(self.diffuser)
65 |
66 | # if self.decision_transformer.predict_q:
67 | # # initialize iq mixins
68 | # IQMixin.__init__(self, self.decision_transformer, iq_args, device)
69 |
70 | def forward(self, lm_input_ids, lm_attention_mask, states, actions, timesteps, iter_num, ind, diff_train_num, attention_mask=None):
71 | # pdb.set_trace()
72 | batch_size, traj_len = states.shape[0], states.shape[1]
73 | if not self.train_lm:
74 | with torch.no_grad():
75 | # (batch_size,num_embeddings,embedding_size)
76 | lm_embeddings = self.lm(lm_input_ids, lm_attention_mask).last_hidden_state
77 | else:
78 | # (batch_size,num_embeddings,embedding_size)
79 | lm_embeddings = self.lm(lm_input_ids, lm_attention_mask).last_hidden_state
80 | cls_embeddings = lm_embeddings[:, 0, :].unsqueeze_(1)
81 | # word_embeddings = lm_embeddings[:, 1:-1, :] # We skip the CLS and SEP tokens. I know there's padding here but we at least always remove the CLS
82 | word_embeddings = lm_embeddings[:, 1:, :] # We skip the CLS tokens
83 |
84 | entropy = None
85 | if self.method == 'vanilla':
86 | raise NotImplementedError
87 | # preds = self.decision_transformer(
88 | # states, actions, timesteps, word_embeddings=word_embeddings, attention_mask=attention_mask)
89 | #
90 | # state_rc_preds = None
91 | # state_rc_targets = None
92 | #
93 | # lang_rc_preds = None
94 | # lang_rc_targets = None
95 | #
96 | # commitment_loss = None
97 | else:
98 | # TODO (important) how does this get padded across batches? some of these horizon states may actually be padding
99 | # we change options only every H states. Say this leads to N states
100 | # N selected options
101 | # B, max_length // H, option_dim
102 | if self.method == 'option':
103 | # selected_options, _, commitment_loss = self.option_selector(cls_embeddings, states)
104 | selected_options, _, commitment_loss, entropy, state_embeddings = self.option_selector(
105 | word_embeddings.mean(1, keepdim=True), states)
106 | else:
107 | selected_options, _, commitment_loss, entropy, state_embeddings = self.option_selector(
108 | word_embeddings, states, timesteps, attention_mask)
109 |
110 | # # need to make options same length as states and actions
111 | options = torch.zeros((batch_size, traj_len, selected_options.shape[-1])).to(selected_options.device)
112 |
113 |         ### This doesn't fully control which options carry gradients and which do not:
114 |         ### the entire options tensor below has gradients after we assign options[...] = selected_options.
115 |         ### It may only carry gradients from the SelectBackward operation -- unsure.
116 |
117 | # Repeated detached options for horizon length
118 | for i in range(selected_options.shape[1]):
119 | options[:, i*self.horizon:(i+1)*self.horizon, :] = selected_options[:,
120 | i, :].unsqueeze(1).clone().detach()
121 | # Make sure to pass gradients for options at each horizon steps
122 | options[:, ::self.horizon, :] = selected_options
123 |
124 | # We reshape sequences to K size sub-sequences, so that the sub-policy only uses the current option
125 | # Here we are choosing K to be horizon
126 | B, L = states.shape[0], states.shape[1]
127 | num_seq = L // self.K # self.K == self.horizon == 8
128 |
129 | # We reshape sequences to K size sub-sequences, so that the sub-policy only uses the current option
130 | # Here we are choosing K to be horizon since it makes sense but technically we can do any K
131 | # This ensures the DT only looks at chunks of size horizon
132 | if isinstance(self.state_dim, tuple):
133 | states = states.reshape(B * num_seq, self.K, *self.state_dim)
134 | else:
135 | states = states.reshape(B * num_seq, self.K, self.state_dim)
136 | state_embeddings = state_embeddings.reshape(B * num_seq, self.K, state_embeddings.shape[2])
137 | options = options.reshape(B * num_seq, self.K, self.option_dim)
138 | actions = actions.reshape(B * num_seq, self.K, self.action_dim)
139 | # Should these timesteps be 1,2,3,4..H,1,2... or just 1,2,3,4...L? Going with 1,2,3,4...L
140 | timesteps = timesteps.reshape(B * num_seq, self.K)
141 | # timesteps = torch.arange(0, self.K).repeat(B * num_seq, 1)
142 | attention_mask = attention_mask.reshape(B * num_seq, self.K)
143 |
144 | # Make sure shapes are okay
145 | assert states.shape[0] == actions.shape[0] == options.shape[0] == timesteps.shape[0] == attention_mask.shape[0] == batch_size * num_seq
146 | assert states.shape[1] == actions.shape[1] == options.shape[1] == timesteps.shape[1] == attention_mask.shape[1] == self.K
147 | # preds = self.decision_transformer(
148 | # states, actions, timesteps, options=options, attention_mask=attention_mask)
149 |
150 | # TODO
151 | encoder_out = self.diffuser.encode(states, actions, timesteps, options=options,
152 | state_embeddings=state_embeddings, attention_mask=attention_mask)
153 |
154 | stacked_inputs = encoder_out['stacked_inputs'].to(self.device)
155 | option_embeddings = encoder_out['option_embeddings'].to(self.device)
156 | conds = {}
157 | returns = option_embeddings # conditions
158 | # num_diff_steps = int(self.args.diffuser.n_train_steps // self.args.max_iters)
159 | for in_step in range(diff_train_num):
160 | # step = iter_num * num_diff_steps + ind * diff_train_num + in_step
161 | self.diff_trainer.train_iteration((stacked_inputs, conds, returns))
162 |
163 | diff_loss, infos = self.diff_trainer.model.loss(stacked_inputs, conds, returns)
164 | obs_recon = infos.pop('obs')
165 |
166 | decode_outputs = self.diffuser.decode(obs_recon)
167 |
168 | self.state_reconstruct = False
169 | self.lang_reconstruct = False
170 | if self.state_reconstruct:
171 | # TODO: Maybe fix?? We now predict an option using trajs
172 | state_rc_preds = self.state_reconstructor(selected_options)
173 | state_rc_targets = states # horizon_states
174 | else:
175 | state_rc_preds = None
176 | state_rc_targets = None
177 |
178 | if self.lang_reconstruct:
179 | # TODO: Maybe fix?? Do we need the max options formulation? Check this
180 | lang_rc_preds = self.lang_reconstructor(selected_options.reshape(batch_size, -1))
181 | lang_rc_targets = cls_embeddings
182 | else:
183 | lang_rc_preds = None
184 | lang_rc_targets = None
185 |
186 | return {'dt': decode_outputs,
187 | 'state_rc': (state_rc_preds, state_rc_targets),
188 | 'lang_rc': (lang_rc_preds, lang_rc_targets),
189 | 'actions': actions,
190 | 'attention_mask': attention_mask,
191 | 'commitment_loss': commitment_loss,
192 | 'diff_loss': diff_loss,
193 | 'entropy': entropy}
194 |
195 | def get_action(self, states, actions, timesteps, options=None, word_embeddings=None, ema_diffusion_model=None):
196 | if self.method == 'vanilla':
197 | preds = self.decision_transformer.get_action(
198 | states, actions, timesteps, word_embeddings=word_embeddings)
199 |
200 | return preds
201 | else:
202 | # preds = self.decision_transformer.get_action(
203 | # states, actions, timesteps, options=options)
204 | encoder_out = self.diffuser.get_action(states, actions, timesteps,
205 | options=options, embed_state=self.option_selector.option_dt.embed_state)
206 | stacked_inputs = encoder_out['stacked_inputs']
207 | option_embeddings = encoder_out['option_embeddings']
208 |
209 | conds = {0: stacked_inputs[:, -1, self.diffuser.act_dim:]}
210 | returns = option_embeddings # conditions
211 |
212 | samples = ema_diffusion_model.conditional_sample(conds, returns=returns)
213 |
214 | action_hist = []
215 | for tt in range(samples.shape[1]-1):
216 | obs_comb = torch.cat([samples[:, tt, :], samples[:, tt+1, :]], dim=-1)
217 | obs_comb = obs_comb.reshape(-1, 2 * self.diffuser.hidden_size)
218 | action = ema_diffusion_model.inv_model(obs_comb)
219 | action_hist.append(action)
220 |
221 | # obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
222 | # obs_comb = obs_comb.reshape(-1, 2 * self.diffuser.hidden_size)
223 | # action = ema_diffusion_model.inv_model(obs_comb)
224 |
225 | if self.diffuser.predict_q:
226 | # Choose actions from q_values
227 | # action = self.iq_choose_action(preds['q_preds'][:, -1], sample=True)
228 | raise NotImplementedError
229 | else:
230 | # Choose actions from direct predictions
231 | # action = preds['action_preds'][:, -1]
232 | if self.diffuser.discrete:
233 | action = action.argmax(dim=1)
234 |
235 | # action = action.squeeze(0)
236 | action = torch.cat(action_hist, dim=0)
237 | return action
238 |
239 | def save(self, iter_num, filepath, config):
240 |         # unwrap DataParallel if present, otherwise use the model directly
241 |         model = self.model.module if hasattr(self.model, 'module') else self.model
242 |
243 | torch.save({'model': model.state_dict(),
244 | 'optimizer': self.optimizer.state_dict(),
245 | 'scheduler': self.scheduler.state_dict(),
246 | 'iter_num': iter_num,
247 | 'train_dataset_max_length': self.train_loader.dataset.max_length,
248 | 'config': config}, filepath)
249 |
250 | def load(self, filepath):
251 | checkpoint = torch.load(filepath)
252 | self.model.load_state_dict(checkpoint['model'])
253 | self.optimizer.load_state_dict(checkpoint['optimizer'])
254 | self.scheduler.load_state_dict(checkpoint['scheduler'])
255 | return {'iter_num': checkpoint['iter_num'], 'train_dataset_max_length': checkpoint['train_dataset_max_length'], 'config': checkpoint['config']}
256 |
--------------------------------------------------------------------------------
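The key detail in `HRLModel.forward` is how each selected option is tiled across the trajectory: it is repeated (detached) for `horizon` steps, and then re-attached with gradients only at the first step of each window. A self-contained toy illustration of that tiling (shapes here are made up):

```python
import torch

B, L, horizon, option_dim = 2, 8, 4, 3
selected_options = torch.randn(B, L // horizon, option_dim, requires_grad=True)

options = torch.zeros(B, L, option_dim)
for i in range(selected_options.shape[1]):
    # repeat the i-th option across its horizon window without gradients
    options[:, i * horizon:(i + 1) * horizon, :] = \
        selected_options[:, i, :].unsqueeze(1).clone().detach()
# re-attach gradients only at the first step of each window
options[:, ::horizon, :] = selected_options

options.sum().backward()
print(selected_options.grad)   # all ones: each option gets gradient from exactly one timestep
```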
/skilldiffuser/hrl/img_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 |
5 |
6 | # Image Encoder
7 | # From https://github.com/suraj-nair-1/lorel/blob/main/models.py
8 | class BaseEncoder(nn.Module):
9 | __constants__ = ['embedding_size']
10 |
11 | def __init__(self):
12 | super().__init__()
13 |
14 |
15 | def preprocess(self, observations):
16 | """
17 | Reshape to 4 dimensions so it works for convolutions
18 | Chunk the time and batch dimensions
19 | """
20 | B, T, C, H, W = observations.shape
21 | return observations.reshape(-1, C, H, W).type(torch.float32).contiguous()
22 |
23 | def unpreprocess(self, embeddings, B, T):
24 | """
25 | Reshape back to 5 dimensions
26 | Unsqueeze the batch and time dimensions
27 | """
28 | BT, E = embeddings.shape
29 | return embeddings.reshape(B, T, E)
30 |
31 |
32 | # Image Encoder
33 | # From https://github.com/suraj-nair-1/lorel/blob/main/models.py
34 | class Encoder(BaseEncoder):
35 | __constants__ = ['embedding_size']
36 |
37 | def __init__(self, hidden_size, activation_function='relu', ch=3, robot=False):
38 | super().__init__()
39 | self.act_fn = getattr(F, activation_function)
40 | self.softmax = nn.Softmax(dim=2)
41 | self.sigmoid = nn.Sigmoid()
42 | self.robot = robot
43 | if self.robot:
44 | g = 4
45 | else:
46 | g = 1
47 | self.conv1 = nn.Conv2d(ch, 32, 4, stride=2, padding=1, groups=g) # 3
48 | self.conv1_2 = nn.Conv2d(32, 32, 4, stride=1, padding=1, groups=g)
49 | self.conv2 = nn.Conv2d(32, 64, 4, stride=2, padding=1, groups=g)
50 | self.conv2_2 = nn.Conv2d(64, 64, 4, stride=1, padding=1, groups=g)
51 | self.conv3 = nn.Conv2d(64, 128, 4, stride=2, padding=1, groups=g)
52 | self.conv3_2 = nn.Conv2d(128, 128, 4, stride=1, padding=1, groups=g)
53 | self.conv4 = nn.Conv2d(128, 256, 4, stride=2, padding=1, groups=g)
54 | self.conv4_2 = nn.Conv2d(256, 256, 4, stride=1, padding=1, groups=g)
55 |
56 | self.fc1 = nn.Linear(1024, 512)
57 | self.fc1_2 = nn.Linear(512, 512)
58 | self.fc1_3 = nn.Linear(512, 512)
59 | self.fc1_4 = nn.Linear(512, 512)
60 | self.fc2 = nn.Linear(512, hidden_size)
61 |
62 | def forward(self, observations):
63 | if self.robot:
64 | observations = torch.cat([
65 | observations[:, :3], observations[:, 12:15], observations[:, 3:6], observations[:, 15:18],
66 | observations[:, 6:9], observations[:, 18:21], observations[:, 9:12], observations[:, 21:],
67 | ], 1)
68 | if len(observations.shape) == 5:
69 | preprocessed_observations = self.preprocess(observations)
70 | else:
71 | preprocessed_observations = observations
72 | hidden = self.act_fn(self.conv1(preprocessed_observations))
73 | hidden = self.act_fn(self.conv1_2(hidden))
74 | hidden = self.act_fn(self.conv2(hidden))
75 | hidden = self.act_fn(self.conv2_2(hidden))
76 | hidden = self.act_fn(self.conv3(hidden))
77 | hidden = self.act_fn(self.conv3_2(hidden))
78 | hidden = self.act_fn(self.conv4(hidden))
79 | hidden = self.act_fn(self.conv4_2(hidden))
80 | hidden = hidden.reshape(preprocessed_observations.shape[0], -1)
81 |
82 | hidden = self.act_fn(self.fc1(hidden))
83 | hidden = self.act_fn(self.fc1_2(hidden))
84 | hidden = self.act_fn(self.fc1_3(hidden))
85 | hidden = self.act_fn(self.fc1_4(hidden))
86 | hidden = self.fc2(hidden)
87 |
88 | if len(observations.shape) == 5:
89 | return self.unpreprocess(hidden, observations.shape[0], observations.shape[1])
90 | else:
91 | return hidden
92 |
--------------------------------------------------------------------------------
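The `Encoder`'s fully-connected head expects a 1024-dim flattened feature map, which works out for 64x64 RGB inputs. A quick shape check (input sizes here are illustrative):

```python
import torch
from img_encoder import Encoder

encoder = Encoder(hidden_size=128, ch=3)

# 5-D input (B, T, C, H, W) is flattened for the convolutions and reshaped back
obs = torch.rand(2, 10, 3, 64, 64)
emb = encoder(obs)
print(emb.shape)    # torch.Size([2, 10, 128])
```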
/skilldiffuser/hrl/iq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numpy.core.shape_base import vstack
3 | import torch
4 | from torch._C import qscheme
5 | import torch.nn.functional as F
6 | from torch.distributions import Categorical
7 | import copy
8 | import wandb
9 |
10 |
11 | class IQMixin(object):
12 | # Mixin to Base model that adds extra IQ
13 |
14 | def __init__(self, q_net, args, device):
15 | super().__init__()
16 |
17 | self.gamma = args.gamma
18 | self.args = args
19 | self.device = device
20 |
21 | self.log_alpha = torch.tensor(np.log(args.alpha)).to(self.device)
22 | self.q_net = q_net
23 |
24 | self.train()
25 |
26 | # Create target network
27 | if args.use_target:
28 | self.target_net = copy.deepcopy(q_net)
29 | self.target_net.load_state_dict(self.q_net.state_dict())
30 |
31 | self.target_net.train()
32 | # self.critic_tau = agent_cfg.critic_tau
33 | # self.critic_target_update_frequency = agent_cfg.critic_target_update_frequency
34 |
35 | def train(self, training=True):
36 | self.training = training
37 | self.q_net.train(training)
38 |
39 | @property
40 | def alpha(self):
41 | return self.log_alpha.exp()
42 |
43 | @property
44 | def critic_net(self):
45 | return self.q_net
46 |
47 | @property
48 | def critic_target_net(self):
49 | return self.target_net
50 |
51 | def iq_choose_action(self, q, sample=True):
52 | with torch.no_grad():
53 | dist = F.softmax(q/self.alpha, dim=1)
54 | if sample:
55 | dist = Categorical(dist)
56 | action = dist.sample() # if sample else dist.mean
57 | else:
58 | action = torch.argmax(dist, dim=1)
59 | return action
60 | # return action.detach().cpu().numpy()[0]
61 |
62 | def getV(self, q):
63 | v = self.alpha * \
64 | torch.logsumexp(q/self.alpha, dim=1, keepdim=True)
65 | return v
66 |
67 | def critic(self, q, action):
68 | return q.gather(1, action.long())
69 |
70 | # Offline IQ-Learn objective
71 | def iq_update_critic(self, expert_batch):
72 | args = self.args
73 |         # expert_batch contains the current-state and next-state Q-values, plus actions and done flags
74 | q, next_q, action, done = expert_batch
75 |
76 | losses = {}
77 | # keep track of v0
78 | v0 = self.getV(q).mean()
79 | losses['v0'] = v0.item()
80 |
81 | # calculate 1st term of loss
82 | # -E_(ρ_expert)[Q(s, a) - γV(s')]
83 | current_q = self.critic(q, action)
84 | next_v = self.getV(next_q)
85 | y = (1 - done) * self.gamma * next_v
86 |
87 |         if args.use_target:
88 |             with torch.no_grad():
89 |                 # NOTE: `next_obs` (raw next observations) is not part of expert_batch and must be supplied by the caller
90 |                 next_v = self.getV(self.target_net(next_obs))
91 |                 y = (1 - done) * self.gamma * next_v
92 |
93 | reward = current_q - y
94 |
95 | with torch.no_grad():
96 | if args.div == "hellinger":
97 | phi_grad = 1/(1+reward)**2
98 | elif args.div == "kl":
99 | phi_grad = torch.exp(-reward-1)
100 | elif args.div == "kl2":
101 | phi_grad = F.softmax(-reward, dim=0) * reward.shape[0]
102 | elif args.div == "kl_fix":
103 | phi_grad = torch.exp(-reward)
104 | elif args.div == "js":
105 | phi_grad = torch.exp(-reward)/(2 - torch.exp(-reward))
106 | else:
107 | phi_grad = 1
108 |
109 | loss = -(phi_grad * reward).mean()
110 | losses['softq_loss'] = loss.item()
111 |
112 | if args.loss == "v0":
113 | # calculate 2nd term for our loss
114 | # (1-γ)E_(ρ0)[V(s0)]
115 | v0_loss = (1 - self.gamma) * v0
116 | loss += v0_loss
117 | losses['v0_loss'] = v0_loss.item()
118 |
119 | elif args.loss == "value":
120 | # alternative 2nd term for our loss (use only expert states)
121 | # E_(ρ)[Q(s,a) - γV(s')]
122 | value_loss = (self.getV(q) - y).mean()
123 | loss += value_loss
124 | losses['value_loss'] = value_loss.item()
125 |
126 | elif args.loss == "skip":
127 | # No loss
128 | pass
129 |
130 | if args.div == "chi":
131 | # Use χ2 divergence (adds a extra term to the loss)
132 |             if args.use_target:
133 |                 with torch.no_grad():
134 |                     # NOTE: as above, `next_obs` is not part of expert_batch and must be supplied by the caller
135 |                     next_v = self.getV(self.target_net(next_obs))
136 | else:
137 | next_v = self.getV(next_q)
138 |
139 | y = (1 - done) * self.gamma * next_v
140 |
141 | current_q = self.critic(q, action)
142 | reward = current_q - y
143 | chi2_loss = 1/2 * (reward**2).mean()
144 | loss += chi2_loss
145 | losses['chi2_loss'] = chi2_loss.item()
146 |
147 | losses['total_loss'] = loss.item()
148 |
149 | return loss, losses
150 |
151 | def iq_critic_loss(self, batch, step):
152 |         # batch contains current-state Q-values, next-state Q-values, actions and done flags for a trajectory
153 | q, next_q, actions, dones = batch
154 |
155 | actions = actions.reshape(-1, 1)
156 | dones = dones.reshape(-1, 1)
157 |
158 | loss, loss_metrics = self.iq_update_critic((q, next_q, actions, dones))
159 |
160 | if self.args.use_target and step % self.critic_target_update_frequency == 0:
161 | self.target_net.load_state_dict(self.q_net.state_dict())
162 |
163 | wandb.log(loss_metrics, step=step)
164 | return loss
165 |
166 | def infer_q(self, state, action):
167 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
168 | action = torch.FloatTensor([action]).unsqueeze(0).to(self.device)
169 |
170 |         with torch.no_grad():
171 |             q = self.critic(self.q_net(state), action)  # compute per-action Q-values first, then gather Q(s, a)
172 |         return q.squeeze(0).cpu().numpy()
173 |
174 | def infer_v(self, state):
175 | state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
176 | with torch.no_grad():
177 |             v = self.getV(self.q_net(state)).squeeze()  # soft value over the per-action Q-values
178 | return v.cpu().numpy()
179 |
--------------------------------------------------------------------------------
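
For reference, a self-contained sketch of the core quantities used by IQMixin above: the soft value V(s) = alpha * logsumexp_a Q(s, a)/alpha and the first (expert) term of the offline IQ-Learn objective. Shapes and numbers are illustrative only.

    import torch

    def soft_value(q, alpha):
        # V(s) = alpha * logsumexp_a Q(s, a) / alpha  (soft maximum over actions)
        return alpha * torch.logsumexp(q / alpha, dim=1, keepdim=True)

    def iq_expert_term(q, next_q, action, done, alpha=0.1, gamma=0.99):
        # -E_expert[ Q(s, a) - gamma * V(s') ], i.e. the negated implicit reward
        current_q = q.gather(1, action.long())
        y = (1 - done) * gamma * soft_value(next_q, alpha)
        return -(current_q - y).mean()

    q = torch.randn(32, 5)                 # per-action Q-values for 5 discrete actions
    next_q = torch.randn(32, 5)
    action = torch.randint(0, 5, (32, 1))
    done = torch.zeros(32, 1)
    print(iq_expert_term(q, next_q, action, done))
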
/skilldiffuser/hrl/option_selector.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from option_transformer import OptionTransformer
7 | from utils import pad, entropy
8 | from img_encoder import Encoder
9 |
10 | from vector_quantize_pytorch import VectorQuantize
11 |
12 |
13 | class OptionSelector(nn.Module):
14 |
15 | """
16 |     Takes the language embedding and the state(s) and predicts a skill latent z.
17 |     Vector quantization (VQ) snaps the prediction onto a learned codebook, giving a discrete option.
18 | """
19 |
20 | def __init__(
21 | self, state_dim, num_options, option_dim, lang_dim, horizon, num_hidden=None, hidden_size=None,
22 | method='traj_option', option_transformer=None, codebook_dim=16, use_vq=True, kmeans_init=False,
23 | commitment_weight=0.25, **kwargs):
24 |
25 | # option_dim and codebook_dim are different because of the way the VQ package is designed
26 | # if they are different, then there is a projection operation that happens inside the VQ layer
27 |
28 | super().__init__()
29 |
30 | if num_hidden is not None:
31 | assert num_hidden >= 2, "We need at least two hidden layers!"
32 |
33 | self.state_dim = state_dim
34 | self.option_dim = option_dim
35 | self.use_vq = use_vq
36 | self.num_options = num_options
37 |
38 | self.horizon = horizon
39 | self.method = method # whether to use full trajectory to get options or just current state
40 |         self.hidden_size = 128  # default hidden size; overridden below when an option-transformer config is given
41 |         hidden_size = 128  # note: this overrides the `hidden_size` constructor argument
42 |
43 | if option_transformer:
44 | self.hidden_size = option_transformer.hidden_size
45 |
46 | self.Z = VectorQuantize(
47 | dim=option_dim,
48 | codebook_dim=codebook_dim, # codebook vector size
49 | codebook_size=num_options, # codebook size
50 | decay=0.99, # the exponential moving average decay, lower means the dictionary will change faster
51 | commitment_weight=commitment_weight, # the weight on the commitment loss
52 | kmeans_init=kmeans_init, # use kmeans init
53 | cpc=False,
54 | # threshold_ema_dead_code=2, # should actively replace any codes that have an exponential moving average cluster size less than 2
55 | use_cosine_sim=False # l2 normalize the codes
56 | )
57 |
58 | if self.method == 'traj_option':
59 | option_transformer_args = {'state_dim': state_dim,
60 | 'lang_dim': lang_dim,
61 | 'option_dim': option_dim,
62 | 'hidden_size': option_transformer.hidden_size,
63 | 'max_length': option_transformer.max_length,
64 | 'max_ep_len': option_transformer.max_ep_len,
65 | 'n_layer': option_transformer.n_layer,
66 | 'n_head': option_transformer.n_head,
67 | 'n_inner': 4*option_transformer.hidden_size,
68 | 'activation_function': option_transformer.activation_function,
69 | 'n_positions': option_transformer.n_positions,
70 | 'resid_pdrop': option_transformer.dropout,
71 | 'attn_pdrop': option_transformer.dropout,
72 | 'output_attentions': True # option_transformer.output_attention,
73 | }
74 | self.option_dt = OptionTransformer(**option_transformer_args)
75 | else:
76 | if isinstance(state_dim, tuple):
77 | # LORL
78 | if state_dim[0] == 3:
79 | # LORL Sawyer
80 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False)
81 | else:
82 | # LORL Franka
83 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True)
84 | else:
85 | self.embed_state = nn.Linear(state_dim, hidden_size)
86 |
87 | z_layers = []
88 | for i in range(num_hidden):
89 | if i == 0:
90 | z_layers.append(nn.Linear(2*hidden_size, hidden_size))
91 | elif i == num_hidden-1:
92 | z_layers.append(nn.Linear(hidden_size, option_dim))
93 | else:
94 | z_layers.append(nn.Linear(hidden_size, hidden_size))
95 | self.pred_options = nn.Sequential(*z_layers)
96 | self.embed_lang = nn.Linear(lang_dim, hidden_size)
97 |
98 | def forward(self, word_embeddings, states, timesteps=None, attention_mask=None, **kwargs):
99 | if self.method == 'traj_option':
100 | dt_ret = self.option_dt(word_embeddings, states, timesteps, attention_mask)
101 | option_preds = dt_ret[0]
102 | state_embeddings = dt_ret[2]
103 | option_preds = option_preds[:, ::self.horizon, :]
104 | else:
105 | ret_state_embeddings = self.embed_state(states).clone().detach()
106 | horizon_states = states[:, ::self.horizon, :]
107 | state_embeddings = self.embed_state(horizon_states)
108 | lang_embeddings = self.embed_lang(word_embeddings) # these will be cls embeddings or word embeddings mean
109 |
110 | inp = torch.cat([lang_embeddings.repeat(
111 | 1, state_embeddings.shape[1], 1), state_embeddings], dim=-1)
112 | option_preds = self.pred_options(inp)
113 |
114 | state_embeddings = ret_state_embeddings
115 |
116 | if self.use_vq:
117 | options, indices, commitment_loss = self.Z(option_preds)
118 | entropies = entropy(self.Z.codebook, options, self.Z.project_in(option_preds))
119 | else:
120 | # TODO: For now simply return the first dim of option
121 | options, indices = option_preds, option_preds[:, :, 0]
122 | commitment_loss = None
123 | entropies = None
124 | return options, indices, commitment_loss, entropies, state_embeddings
125 |
126 | def get_option(self, word_embeddings, states, timesteps=None, **kwargs):
127 |
128 | if 'constant_option' in kwargs:
129 | return self.Z.project_out(
130 | self.Z.codebook[kwargs['constant_option']]), torch.tensor(
131 | kwargs['constant_option'])
132 |
133 | if self.method == 'traj_option':
134 | if isinstance(self.state_dim, tuple):
135 | states = states.reshape(1, -1, *self.state_dim)
136 | else:
137 | states = states.reshape(1, -1, self.state_dim)
138 | timesteps = timesteps.reshape(1, -1)
139 | max_length = self.option_dt.max_length
140 |
141 | if max_length is not None:
142 | states = states[:, -max_length:]
143 | timesteps = timesteps[:, -max_length:]
144 |
145 | # pad all tokens to sequence length
146 | attention_mask = pad(
147 | torch.ones(1, states.shape[1]),
148 | max_length).to(
149 | dtype=torch.long, device=states.device).reshape(
150 | 1, -1)
151 | states = pad(states, max_length).to(dtype=torch.float32)
152 | timesteps = pad(timesteps, max_length).to(dtype=torch.long)
153 | else:
154 | attention_mask = None
155 | raise ValueError('Attention mask should not be none')
156 |
157 | options, option_indx, _, _, _ = self.forward(
158 | word_embeddings, states, timesteps, attention_mask=attention_mask, **kwargs)
159 | else:
160 | states = states[:, ::self.horizon, :]
161 | options, option_indx, _, _, _ = self.forward(
162 | word_embeddings, states, None, attention_mask=None, **kwargs)
163 |
164 | return options[0, -1], option_indx[0, -1]
165 |
--------------------------------------------------------------------------------
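
A short sketch of the vector-quantization step inside OptionSelector.forward above, using the repo-local vector_quantize_pytorch module with the same constructor arguments: continuous option predictions are snapped onto a learned codebook, yielding discrete skill indices and a commitment loss. Dimensions are illustrative.

    import torch
    from vector_quantize_pytorch import VectorQuantize

    vq = VectorQuantize(
        dim=128,              # option_dim
        codebook_dim=16,      # internal codebook vector size (projected in/out)
        codebook_size=20,     # num_options
        decay=0.99,
        commitment_weight=0.25,
        kmeans_init=False,
        cpc=False,
        use_cosine_sim=False,
    )

    option_preds = torch.randn(4, 6, 128)             # (batch, skill slots, option_dim)
    options, indices, commit_loss = vq(option_preds)  # quantized skills, codebook ids, loss
    print(options.shape, indices.shape, commit_loss)
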
/skilldiffuser/hrl/option_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import transformers
5 |
6 | from trajectory_gpt2 import GPT2Model
7 | from img_encoder import Encoder
8 |
9 |
10 | class OptionTransformer(nn.Module):
11 |
12 | """
13 | This model uses GPT-2 to select options for every horizon-th state
14 | """
15 |
16 | def __init__(
17 | self,
18 | state_dim,
19 | lang_dim,
20 | option_dim,
21 | hidden_size,
22 | max_length=None,
23 | max_ep_len=4096,
24 | **kwargs):
25 | super().__init__()
26 |
27 | self.option_dim = option_dim
28 | self.hidden_size = hidden_size
29 |
30 | config = transformers.GPT2Config(
31 | vocab_size=1, # doesn't matter -- we don't use the vocab
32 | n_embd=hidden_size,
33 | **kwargs
34 | )
35 |
36 | self.state_dim = state_dim
37 | self.max_length = max_length
38 | self.output_attentions = kwargs["output_attentions"]
39 |
40 | if isinstance(state_dim, tuple):
41 | # LORL
42 | if state_dim[0] == 3:
43 | # LORL Sawyer
44 | self.embed_state = Encoder(hidden_size=hidden_size, ch=3, robot=False)
45 | else:
46 | # LORL Franka
47 | self.embed_state = Encoder(hidden_size=hidden_size, ch=12, robot=True)
48 |
49 | else:
50 | self.embed_state = nn.Linear(self.state_dim, hidden_size)
51 |
52 | self.embed_lang = nn.Linear(lang_dim, hidden_size)
53 |
54 | # note: the only difference between this GPT2Model and the default Huggingface version
55 | # is that the positional embeddings are removed (since we'll add those ourselves)
56 | self.transformer = GPT2Model(config)
57 |
58 | self.embed_timestep = nn.Embedding(max_ep_len, hidden_size)
59 |
60 | self.embed_ln = nn.LayerNorm(hidden_size)
61 | self.predict_options = torch.nn.Linear(hidden_size, self.option_dim)
62 |
63 | def forward(self, word_embeddings, states, timesteps, attention_mask, **kwargs):
64 | batch_size, seq_length = states.shape[0], states.shape[1]
65 | num_tokens = word_embeddings.shape[1]
66 |
67 |         # attention mask for GPT: 1 if the position can be attended to, 0 if not
68 |         if attention_mask is None:
69 |             # the caller is expected to build and pad the mask (see OptionSelector.get_option)
70 |             raise ValueError('Should not have attention_mask NONE')
71 |
72 | state_embeddings = self.embed_state(states)
73 | ret_state_embeddings = state_embeddings.clone().detach()
74 | lang_embeddings = self.embed_lang(word_embeddings)
75 | time_embeddings = self.embed_timestep(timesteps)
76 |
77 | # time embeddings are treated similar to positional embeddings
78 | state_embeddings = state_embeddings + time_embeddings # (batch_size, seq_length, hidden)
79 | lang_and_inputs = torch.cat([lang_embeddings, state_embeddings], dim=1)
80 | # LAYERNORM AFTER LANGUAGE
81 | stacked_inputs = self.embed_ln(lang_and_inputs)
82 |
83 | lang_attn_mask = torch.cat([torch.ones((batch_size, num_tokens), device=states.device), attention_mask], dim=1)
84 |
85 | # we feed in the input embeddings (not word indices as in NLP) to the model
86 | transformer_outputs = self.transformer(
87 | inputs_embeds=stacked_inputs,
88 | attention_mask=lang_attn_mask,
89 | )
90 |
91 | x = transformer_outputs['last_hidden_state']
92 | lang_out = x[:, :num_tokens, :].reshape(batch_size, num_tokens, self.hidden_size)
93 | traj_out = x[:, num_tokens:, :].reshape(batch_size, seq_length, self.hidden_size)
94 |
95 | # get predictions
96 | # predict option logits given state
97 | option_preds = self.predict_options(traj_out)
98 |
99 | if self.output_attentions:
100 | attentions = transformer_outputs[-1]
101 | return option_preds, attentions, ret_state_embeddings
102 |
103 | return option_preds, None, ret_state_embeddings
104 |
--------------------------------------------------------------------------------
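
A hedged end-to-end sketch of how the transformer above is driven (mirroring the arguments OptionSelector passes to it): language token embeddings are prepended to time-embedded state tokens, option logits come back for every state position, and the selector keeps every horizon-th prediction. All dimensions and hyperparameters below are illustrative.

    import torch
    from option_transformer import OptionTransformer

    model = OptionTransformer(
        state_dim=9, lang_dim=768, option_dim=128, hidden_size=128,
        max_length=20, max_ep_len=4096,
        n_layer=1, n_head=4, n_inner=4 * 128, activation_function='relu',
        n_positions=1024, resid_pdrop=0.1, attn_pdrop=0.1, output_attentions=True)

    word_embeddings = torch.randn(2, 12, 768)            # e.g. frozen language-model token embeddings
    states = torch.randn(2, 20, 9)                       # low-dimensional states
    timesteps = torch.arange(20).repeat(2, 1)
    attention_mask = torch.ones(2, 20, dtype=torch.long)

    option_preds, attentions, state_emb = model(word_embeddings, states, timesteps, attention_mask)
    skill_logits = option_preds[:, ::10, :]              # one skill every horizon (=10) steps, as in OptionSelector
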
/skilldiffuser/hrl/reconstructors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class StateReconstructor(nn.Module):
7 | """
8 | This model tries to reconstruct the states from which the options were chosen for each option
9 | """
10 |
11 | def __init__(self, option_dim, state_dim, num_hidden, hidden_size):
12 | super().__init__()
13 |
14 | assert num_hidden >= 2, "We need at least two hidden layers!"
15 | assert isinstance(state_dim, int), "State dimension has to be integer!"
16 |
17 | layers = []
18 | for i in range(num_hidden):
19 | if i == 0:
20 | layers.append(nn.Linear(option_dim, hidden_size))
21 | elif i == num_hidden-1:
22 | layers.append(nn.Linear(hidden_size, state_dim))
23 | else:
24 | layers.append(nn.Linear(hidden_size, hidden_size))
25 | self.predictor = nn.Sequential(*layers)
26 |
27 | def forward(self, options):
28 | return self.predictor(options)
29 |
30 |
31 | class LanguageReconstructor(nn.Module):
32 | """
33 | This model tries to reconstruct the language from all the options
34 | """
35 |
36 | def __init__(self, option_dim, max_options, lang_dim, num_hidden, hidden_size):
37 | super().__init__()
38 |
39 | assert num_hidden >= 2, "We need at least two hidden layers!"
40 |
41 | self.max_options = max_options
42 |
43 | layers = []
44 | for i in range(num_hidden):
45 | if i == 0:
46 | layers.append(nn.Linear(max_options*option_dim, hidden_size))
47 | elif i == num_hidden-1:
48 | layers.append(nn.Linear(hidden_size, lang_dim))
49 | else:
50 | layers.append(nn.Linear(hidden_size, hidden_size))
51 | self.predictor = nn.Sequential(*layers)
52 |
53 |     def forward(self, options):
54 |         options = F.pad(options, pad=(0, 0, 0, self.max_options - options.shape[1]))  # assumed fix: pad the option axis up to max_options
55 |         return self.predictor(options.flatten(start_dim=1))  # flatten to the max_options*option_dim input expected by the MLP
56 |
--------------------------------------------------------------------------------
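
A brief sketch of how these reconstructors could be used as auxiliary objectives: recover the state each option was chosen from, and the instruction embedding from the (padded) option sequence. Dimensions are invented for illustration and rely on the padded-and-flattened forward of LanguageReconstructor above.

    import torch
    import torch.nn.functional as F
    from reconstructors import StateReconstructor, LanguageReconstructor

    state_rec = StateReconstructor(option_dim=128, state_dim=9, num_hidden=2, hidden_size=256)
    lang_rec = LanguageReconstructor(option_dim=128, max_options=10, lang_dim=768,
                                     num_hidden=2, hidden_size=256)

    options = torch.randn(4, 6, 128)     # six options chosen per trajectory
    states = torch.randn(4, 6, 9)        # states at which each option was selected
    lang = torch.randn(4, 768)           # sentence embedding of the instruction

    aux_loss = F.mse_loss(state_rec(options), states) + F.mse_loss(lang_rec(options), lang)
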
/skilldiffuser/hrl/trajectory_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class TrajectoryModel(nn.Module):
7 |
8 | def __init__(self, state_dim, act_dim, max_length=None):
9 | super().__init__()
10 |
11 | self.state_dim = state_dim
12 | self.act_dim = act_dim
13 | self.max_length = max_length
14 |
15 | def forward(self, states, actions, rewards, masks=None, attention_mask=None):
16 | # "masked" tokens or unspecified inputs can be passed in as None
17 | return None, None, None
18 |
19 | def get_action(self, states, actions, rewards, **kwargs):
20 | # these will come as tensors on the correct device
21 | return torch.zeros_like(actions[-1])
22 |
--------------------------------------------------------------------------------
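
TrajectoryModel is only the abstract base; a concrete model just has to map a trajectory of states, actions and rewards to predictions. A toy subclass, purely for illustration:

    import torch
    import torch.nn as nn
    from trajectory_model import TrajectoryModel

    class MLPTrajectoryModel(TrajectoryModel):
        """Toy subclass: predicts an action from the most recent state only."""

        def __init__(self, state_dim, act_dim, hidden=64, max_length=None):
            super().__init__(state_dim, act_dim, max_length)
            self.net = nn.Sequential(nn.Linear(state_dim, hidden), nn.ReLU(),
                                     nn.Linear(hidden, act_dim))

        def forward(self, states, actions, rewards, masks=None, attention_mask=None):
            # same three-slot return convention as the base class
            return None, self.net(states), None

        def get_action(self, states, actions, rewards, **kwargs):
            return self.net(states[-1])

    model = MLPTrajectoryModel(state_dim=9, act_dim=5)
    action = model.get_action(torch.randn(20, 9), torch.zeros(20, 5), torch.zeros(20))
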
/skilldiffuser/hrl/viz.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | from typing import Dict, Iterable, Callable
3 | from collections import Counter
4 | from itertools import chain
5 | import numpy as np
6 | import torch
7 | import seaborn as sns
8 | import torch.nn as nn
9 | from torch import Tensor
10 | import matplotlib.pyplot as plt
11 | import wandb
12 |
13 |
14 | class Attention(nn.Module):
15 | def __init__(self, model: nn.Module):
16 | super().__init__()
17 | self.model = model
18 | self._attention = None
19 |
20 | model.option_selector.option_dt.register_forward_hook(self.save_attention_hook())
21 |
22 | def save_attention_hook(self) -> Callable:
23 | def fn(model, input, output):
24 | self._attention = output[-1]
25 | return fn
26 |
27 | def forward(self, x: Tensor) -> Dict[str, Tensor]:
28 | _ = self.model(x)
29 | return self._attention
30 |
31 |
32 | def get_tokens(inputs, tokenizer):
33 | input_ids = inputs['input_ids']
34 | input_id_list = input_ids[0].tolist() # Batch index 0
35 | tokens = tokenizer.convert_ids_to_tokens(input_id_list)[1:]
36 | return tokens
37 |
38 |
39 | def viz_matrix(words_dict, num_options, step, skip_words):
40 | # skip_words = ['go', 'to', 'the', 'a', '[SEP]']
41 | words = sorted(set(chain(*words_dict.values())) - set(skip_words))
42 |
43 | def w_to_ind(word):
44 | return words.index(word)
45 |
46 | matrix = np.zeros([len(words), num_options])
47 |
48 | for o in range(num_options):
49 | for w in words_dict[o]:
50 | if w not in skip_words:
51 | matrix[w_to_ind(w), o] += 1
52 |
53 | # plot co-occurence matrix (words x options)
54 | plt.figure(figsize=(30, 10))
55 | sns.heatmap(matrix, yticklabels=words)
56 | # plt.plot()
57 | wandb.log({"Correlation Matrix": wandb.Image(plt)}, step=step)
58 |
59 | # Now if we normalize it by column (word freq for each option)
60 | plt.figure(figsize=(30, 10))
61 | matrix_norm_col = (matrix)/(matrix.sum(axis=0, keepdims=True) + 1e-6)
62 | sns.heatmap(matrix_norm_col, yticklabels=words)
63 | wandb.log({"Word Freq Matrix": wandb.Image(plt)}, step=step)
64 |
65 | # Now if we normalize it by row (option freq for each word)
66 | plt.figure(figsize=(30, 10))
67 | matrix_norm_row = (matrix)/(matrix.sum(axis=1, keepdims=True) + 1e-6)
68 | sns.heatmap(matrix_norm_row, yticklabels=words)
69 | wandb.log({"Option Freq Matrix": wandb.Image(plt)}, step=step)
70 | plt.close()
71 |
72 |
73 | def viz_matrix2(words_dict, num_options, step, skip_words, labelsize=15): # labelsize=12
74 | # skip_words = ['go', 'to', 'the', 'a', '[SEP]']
75 | skip_words.extend(['and',',','.','t'])
76 | words_set = set(chain(*words_dict.values()))
77 | for item in words_set:
78 | if item.startswith('##'):
79 | skip_words.append(item)
80 |
81 | words = sorted(set(chain(*words_dict.values())) - set(skip_words))
82 | # "/home/zxliang/new-code/LISA-hot/lisa"
83 |
84 | def w_to_ind(word):
85 | return words.index(word)
86 |
87 | matrix = np.zeros([len(words), num_options])
88 |
89 | for o in range(num_options):
90 | for w in words_dict[o]:
91 | if w not in skip_words:
92 | matrix[w_to_ind(w), o] += 1
93 |
94 | # pdb.set_trace()
95 |
96 | words = ["faucet" if x == "fa" else x for x in words]
97 |
98 | # plot co-occurence matrix (words x options)
99 | plt.figure(figsize=(20, 10))
100 | plt.tick_params(labelsize=labelsize)
101 | ax=sns.heatmap(matrix, yticklabels=words)
102 | colorbar = ax.collections[0].colorbar
103 | colorbar.ax.tick_params(labelsize=labelsize)
104 | # plt.plot()
105 | wandb.log({"Correlation Matrix": wandb.Image(plt)}, step=step)
106 |
107 | # Now if we normalize it by column (word freq for each option)
108 | plt.figure(figsize=(20, 10))
109 | plt.tick_params(labelsize=labelsize)
110 | matrix_norm_col = (matrix)/(matrix.sum(axis=0, keepdims=True) + 1e-6)
111 | ax=sns.heatmap(matrix_norm_col, yticklabels=words)
112 | colorbar = ax.collections[0].colorbar
113 | colorbar.ax.tick_params(labelsize=labelsize)
114 | wandb.log({"Word Freq Matrix": wandb.Image(plt)}, step=step)
115 |
116 | # Now if we normalize it by row (option freq for each word)
117 | plt.figure(figsize=(20, 10))
118 | plt.tick_params(labelsize=labelsize)
119 | matrix_norm_row = (matrix)/(matrix.sum(axis=1, keepdims=True) + 1e-6)
120 | ax=sns.heatmap(matrix_norm_row, yticklabels=words)
121 | colorbar = ax.collections[0].colorbar
122 | colorbar.ax.tick_params(labelsize=labelsize)
123 | wandb.log({"Option Freq Matrix": wandb.Image(plt)}, step=step)
124 |
125 | # plt.show()
126 | plt.close()
127 |
128 |
129 | def plot_hist(stats):
130 | plt.clf()
131 | plt.bar(stats.keys(), stats.values())
132 | plt.xticks(rotation='vertical')
133 | return wandb.Image(plt)
134 |
--------------------------------------------------------------------------------
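
A hypothetical example of the input viz_matrix expects: a dict mapping each option index to the instruction words observed whenever that option fired. It needs an active wandb run, since the heatmaps are sent through wandb.log; the words below are invented.

    import wandb
    from viz import viz_matrix

    wandb.init(project="skilldiffuser-viz", mode="offline")  # offline run keeps this self-contained

    words_dict = {
        0: ["open", "the", "drawer", "open", "drawer"],
        1: ["close", "drawer"],
        2: ["turn", "faucet", "left"],
    }
    viz_matrix(words_dict, num_options=3, step=0, skip_words=["the", "a", "[SEP]"])
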
/skilldiffuser/matrix.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Liang-ZX/SkillDiffuser/bc9da40fb98b8e65797005d3d9293d3d75665d9d/skilldiffuser/matrix.npy
--------------------------------------------------------------------------------
/skilldiffuser/pkl2pkl.py:
--------------------------------------------------------------------------------
1 | import pdb
2 |
3 | import h5py
4 | import numpy as np
5 | import pickle
6 | import pandas as pd
7 | import os
8 |
9 | root_path = '/home/zxliang/dataset/lorel/may_08_sawyer_50k'
10 | # f = h5py.File(os.path.join(root_path, 'data.hdf5'),'r')
11 | # f.visit(lambda x: print(x))
12 | f = open(os.path.join(root_path, 'prep_data.pkl'),'rb')
13 |
14 | # data = dict()
15 |
16 | df = pd.read_table(os.path.join(root_path, "labels.csv"), sep=",")
17 | langs = df["Text Description"].str.strip().to_numpy().reshape(-1)
18 | langs = np.array(['' if (isinstance(x, float) and np.isnan(x)) else x for x in langs])  # NaN labels become empty strings
19 | filtr1 = np.array([int(("nothing" in l) or ("nan" in l) or ("wave" in l)) for l in langs])
20 | filtr = filtr1 == 0
21 | # data['langs'] = langs[filtr]
22 |
23 | data = pickle.load(f)
24 | for key in data.keys():
25 | data[key] = data[key][filtr, ...]
26 |
27 | with open(os.path.join(root_path, "prep_data4.pkl"),"wb") as fo:
28 | pickle.dump(data, fo)
29 |
--------------------------------------------------------------------------------
/skilldiffuser/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | pandas
3 | # matplotlib is pinned below (matplotlib==3.3.4); listing it twice can make pip fail with a double-requirement error
4 | torch
5 | seaborn
6 | transformers==4.12.0
7 | scipy
8 | tensorboardX
9 | hydra-core
10 | omegaconf
11 | gym==0.15.4
12 | wandb
13 | opencv-python
14 | cython<3
15 | git+https://github.com/rlworkgroup/metaworld.git@b016e6a25e485f1ffa8ccbf52df54ac204a81f31#egg=metaworld
16 | ml_logger==0.7.5
17 | einops
18 | mujoco-py==2.0.2.5
19 | matplotlib==3.3.4
20 | #torch==1.9.1+cu111
21 | #torchvision==0.10.1+cu111
22 | #torch==1.7.1+cu101
23 | #torchvision==0.8.2+cu101
24 | typed-argument-parser
25 | git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl
26 | #git+https://github.com/rail-berkeley/d4rl.git@c39eefd68d2f3277ca68e996a45ce1dd24e65625
27 | scikit-image==0.17.2
28 | scikit-video==1.1.11
29 | gitpython
30 | pillow
31 | tqdm
32 | pygame
33 | networkx
34 | dotmap
--------------------------------------------------------------------------------
/skilldiffuser/train_lorel_compose.sh:
--------------------------------------------------------------------------------
1 | python hrl/main_hot.py env=lorel_sawyer_obs method=traj_option dt.n_layer=1 dt.n_head=4 option_selector.option_transformer.n_layer=1 option_selector.option_transformer.n_head=4 option_selector.commitment_weight=0.1 option_selector.option_transformer.hidden_size=128 batch_size=256 seed=1 warmup_steps=5000 resume=True checkpoint_path=${checkpoint_pretrain_ckpt} diffuser.loadpath=${diffuser_pretrain_ckpt}
2 |
--------------------------------------------------------------------------------