├── src ├── __init__.py ├── data │ ├── __init__.py │ ├── atari │ │ ├── requirements.txt │ │ └── README.md │ ├── untar_files.sh │ ├── composuite │ │ └── README.md │ ├── parallel_copy.py │ └── mimicgen │ │ └── README.md ├── envs │ ├── __init__.py │ ├── compatibility_wrapper.py │ ├── dummy_env_utils.py │ ├── composuite_utils.py │ ├── procgen_utils.py │ ├── hn_scores.py │ └── dmcontrol_utils.py ├── augmentations │ ├── __init__.py │ └── augs.py ├── callbacks │ ├── __init__.py │ ├── builder.py │ └── validation_callback.py ├── schedulers │ ├── __init__.py │ ├── lr_schedulers.py │ └── visualize_schedulers.py ├── utils │ ├── __init__.py │ └── debug.py ├── buffers │ ├── __init__.py │ ├── dataloaders.py │ ├── multi_domain_buffer.py │ ├── buffer_utils.py │ └── trajectory.py ├── algos │ ├── models │ │ ├── __init__.py │ │ ├── rms_norm.py │ │ ├── token_learner.py │ │ ├── model_utils.py │ │ ├── extractors.py │ │ ├── multi_domain_discrete_dt_model.py │ │ └── rope.py │ ├── decision_xlstm.py │ ├── __init__.py │ ├── discrete_decision_transformer_sb3.py │ ├── decision_mamba.py │ └── builder.py ├── tokenizers_custom │ ├── base_tokenizer.py │ ├── __init__.py │ ├── mu_law_tokenizer.py │ └── minmax_tokenizer.py └── optimizers │ └── __init__.py ├── configs ├── run_params │ ├── base.yaml │ ├── evaluate.yaml │ ├── finetune.yaml │ └── pretrain.yaml ├── agent_params │ ├── model_kwargs │ │ ├── default.yaml │ │ ├── atari.yaml │ │ ├── procgen.yaml │ │ ├── dmcontrol.yaml │ │ ├── mt_disc.yaml │ │ ├── dark_room.yaml │ │ └── multi_domain.yaml │ ├── data_paths │ │ ├── d4rl.yaml │ │ ├── atari.yaml │ │ ├── names │ │ │ ├── dmcontrol11.yaml │ │ │ ├── atari41.yaml │ │ │ └── mt45_v2.yaml │ │ ├── dmcontrol11.yaml │ │ ├── procgen12.yaml │ │ ├── mt45_v2.yaml │ │ ├── mt45v2_dmc11.yaml │ │ ├── mt45v2_dmc11_pg12.yaml │ │ ├── dark_room_10x10.yaml │ │ ├── dark_keydoor_10x10.yaml │ │ ├── mt45v2_dmc11_pg12_atari41.yaml │ │ └── mimicgen83.yaml │ ├── lr_sched_kwargs │ │ └── cosine.yaml │ ├── huggingface │ │ ├── dt_large.yaml │ │ ├── dt_larger.yaml │ │ ├── dt_medium.yaml │ │ ├── dt_large_64.yaml │ │ ├── mamba_huge.yaml │ │ ├── mamba_large.yaml │ │ ├── mamba_medium.yaml │ │ ├── mamba_huge_half.yaml │ │ ├── mamba_mediumplus.yaml │ │ ├── mamba_hugeplus.yaml │ │ ├── dt_huge.yaml │ │ ├── dt_medium_64.yaml │ │ ├── dt_largeplus_64.yaml │ │ ├── dt_mediumplus_64.yaml │ │ ├── dt_hugeplus.yaml │ │ ├── xlstm_huge.yaml │ │ ├── xlstm_large.yaml │ │ ├── xlstm_medium.yaml │ │ ├── xlstm_huge_half.yaml │ │ ├── xlstm_large_half.yaml │ │ ├── xlstm_medium_half.yaml │ │ ├── xlstm_mediumplus.yaml │ │ ├── xlstm_ms_mediumplus.yaml │ │ ├── xlstm_mediumplus_half.yaml │ │ └── xlstm_hugeplus.yaml │ ├── replay_buffer_kwargs │ │ ├── single_domain_disc.yaml │ │ └── multi_domain_mtdmccs.yaml │ ├── multi_domain.yaml │ └── darkroom.yaml ├── eval_params │ ├── base.yaml │ ├── finetune.yaml │ ├── pretrain_icl.yaml │ └── pretrain.yaml ├── wandb_callback_params │ └── pretrain.yaml ├── env_params │ ├── mujoco_gym.yaml │ ├── mimicgen.yaml │ ├── composuite.yaml │ ├── dmcontrol_icl.yaml │ ├── mt45.yaml │ ├── atari_freeway.yaml │ ├── atari.yaml │ ├── procgen.yaml │ ├── mt_dmc_procgen.yaml │ ├── mt_dmc_procgen_atari.yaml │ ├── mt_dmc_procgen_atari_cs_mg.yaml │ ├── mt_dmc_procgen_atari_cs.yaml │ ├── dark_room.yaml │ └── dark_keydoor.yaml └── config.yaml ├── figures └── lram.png ├── .gitmodules ├── dmc2gym_custom ├── setup.py ├── README.md └── dmc2gym_custom │ ├── __init__.py │ └── wrappers.py ├── requirements.txt ├── LICENSE ├── evaluate.py ├── .gitignore ├── environment.yaml └── main.py 
/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_env 2 | -------------------------------------------------------------------------------- /src/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augs import make_augmentations -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_callbacks 2 | -------------------------------------------------------------------------------- /configs/run_params/base.yaml: -------------------------------------------------------------------------------- 1 | total_timesteps: 1e6 2 | log_interval: 10 -------------------------------------------------------------------------------- /configs/run_params/evaluate.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1 2 | total_timesteps: 0 3 | -------------------------------------------------------------------------------- /src/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_schedulers import make_lr_scheduler 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import maybe_split, safe_mean, multiply 2 | -------------------------------------------------------------------------------- /configs/run_params/finetune.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1000 2 | total_timesteps: 100000 3 | -------------------------------------------------------------------------------- /configs/run_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1000 2 | total_timesteps: 200000 3 | -------------------------------------------------------------------------------- /figures/lram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/figures/lram.png -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/default.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | relative_pos_embds: False 3 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/d4rl.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/d4rl 2 | names: hopper-medium-v2.pkl 3 | 4 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/atari.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/atari_1M_64rgb 2 | defaults: 3 | - names: atari41 4 | 
-------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/atari.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: False 3 | relative_pos_embds: False 4 | 5 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "continual_world"] 2 | path = continual_world 3 | url = https://github.com/awarelab/continual_world.git 4 | -------------------------------------------------------------------------------- /configs/agent_params/lr_sched_kwargs/cosine.yaml: -------------------------------------------------------------------------------- 1 | kind: cosine 2 | T_max: ${run_params.total_timesteps} 3 | #T_max: 150000 4 | eta_min: 0.000001 5 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/procgen.yaml: -------------------------------------------------------------------------------- 1 | tokenize_a: False 2 | reward_condition: True 3 | relative_pos_embds: False 4 | use_time_embds: False 5 | -------------------------------------------------------------------------------- /configs/eval_params/base.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: False 2 | n_eval_episodes: 10 3 | eval_freq: 10000 4 | max_no_improvement_evals: 0 5 | deterministic: False -------------------------------------------------------------------------------- /configs/eval_params/finetune.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 5 3 | eval_freq: 25000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | -------------------------------------------------------------------------------- /configs/wandb_callback_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | gradient_save_freq: 0 2 | verbose: 1 3 | model_save_path: models 4 | model_sync_wandb: False 5 | model_save_freq: 50000 -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 8 4 | n_head: 8 5 | max_ep_len: 1000 6 | hidden_size: 768 7 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_larger.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 6 4 | n_head: 16 5 | max_ep_len: 1000 6 | hidden_size: 1024 7 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 4 4 | n_head: 4 5 | max_ep_len: 1000 6 | hidden_size: 512 7 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dmcontrol.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: True 3 | tokenize_rtg: False 4 | 
action_channels: 64 5 | relative_pos_embds: False 6 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/mt_disc.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: True 3 | action_channels: 64 4 | relative_pos_embds: False 5 | tokenize_rtg: False 6 | 7 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 8 4 | n_head: 12 5 | max_ep_len: 1000 6 | hidden_size: 768 7 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/single_domain_disc.yaml: -------------------------------------------------------------------------------- 1 | kind: default 2 | max_act_dim: 6 3 | max_state_dim: 204 4 | num_workers: 16 5 | pin_memory: False 6 | init_top_p: 1 7 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_huge.yaml: -------------------------------------------------------------------------------- 1 | max_length: 50 2 | n_embd: 512 3 | n_layer: 20 4 | n_head: 1 5 | max_ep_len: 1000 6 | d_model: 1280 7 | d_intermediate: 0 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_large.yaml: -------------------------------------------------------------------------------- 1 | max_length: 50 2 | n_embd: 512 3 | n_layer: 16 4 | n_head: 1 5 | max_ep_len: 1000 6 | d_model: 1024 7 | d_intermediate: 0 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_medium.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 8 4 | n_head: 1 5 | max_ep_len: 1000 6 | d_model: 512 7 | d_intermediate: 0 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_huge_half.yaml: -------------------------------------------------------------------------------- 1 | max_length: 50 2 | n_embd: 512 3 | n_layer: 10 4 | n_head: 1 5 | max_ep_len: 1000 6 | d_model: 1792 7 | d_intermediate: 0 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_mediumplus.yaml: -------------------------------------------------------------------------------- 1 | max_length: 50 2 | n_embd: 512 3 | n_layer: 12 4 | n_head: 1 5 | max_ep_len: 1000 6 | d_model: 768 7 | d_intermediate: 0 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/env_params/mujoco_gym.yaml: -------------------------------------------------------------------------------- 1 | envid: Hopper-v3 2 | reward_scale: 1000 3 | target_return: 3600 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dark_room.yaml: -------------------------------------------------------------------------------- 1 | tokenize_a: False 2 
| reward_condition: True 3 | relative_pos_embds: False 4 | use_time_embds: False 5 | shared_a_head: True 6 | action_condition: False 7 | -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/multi_domain_mtdmccs.yaml: -------------------------------------------------------------------------------- 1 | kind: domain 2 | max_act_dim: 8 3 | max_state_dim: 204 4 | num_workers: 16 5 | pin_memory: False 6 | init_top_p: 1 7 | p_valid: 0.025 8 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_hugeplus.yaml: -------------------------------------------------------------------------------- 1 | # 408M 2 | max_length: 50 3 | n_embd: 512 4 | n_layer: 28 5 | n_head: 1 6 | max_ep_len: 1000 7 | d_model: 1536 8 | d_intermediate: 0 9 | output_attentions: True -------------------------------------------------------------------------------- /configs/eval_params/pretrain_icl.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 40 3 | eval_freq: 25000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | log_eval_trj: True 7 | eval_on_train: True -------------------------------------------------------------------------------- /configs/eval_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 5 3 | eval_freq: 50000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | log_eval_trj: True 7 | eval_on_train: True 8 | n_jobs: 4 -------------------------------------------------------------------------------- /configs/env_params/mimicgen.yaml: -------------------------------------------------------------------------------- 1 | envid: "mimicgen83" 2 | reward_scale: 1 3 | target_return: 20000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | eval_env_names: "mimicgen2" -------------------------------------------------------------------------------- /configs/env_params/composuite.yaml: -------------------------------------------------------------------------------- 1 | envid: "composuite240" 2 | reward_scale: 200 3 | target_return: 20000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | eval_env_names: "composuite16" -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_huge.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 10 4 | n_head: 20 5 | max_ep_len: 1000 6 | hidden_size: 1280 7 | output_attentions: True 8 | 9 | resid_pdrop: 0 10 | embd_pdrop: 0 11 | attn_pdrop: 0 -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 4 4 | n_head: 8 5 | max_ep_len: 1000 6 | hidden_size: 512 7 | output_attentions: True 8 | 9 | resid_pdrop: 0 10 | embd_pdrop: 0 11 | attn_pdrop: 0 -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_largeplus_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 
8 4 | n_head: 16 5 | max_ep_len: 1000 6 | hidden_size: 1024 7 | output_attentions: True 8 | 9 | resid_pdrop: 0 10 | embd_pdrop: 0 11 | attn_pdrop: 0 -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_mediumplus_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | n_embd: 512 3 | n_layer: 6 4 | n_head: 12 5 | max_ep_len: 1000 6 | hidden_size: 768 7 | output_attentions: True 8 | 9 | resid_pdrop: 0 10 | embd_pdrop: 0 11 | attn_pdrop: 0 -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_hugeplus.yaml: -------------------------------------------------------------------------------- 1 | # 408M parameters 2 | max_length: 20 3 | n_embd: 512 4 | n_layer: 14 5 | n_head: 24 6 | max_ep_len: 1000 7 | hidden_size: 1536 8 | output_attentions: True 9 | 10 | resid_pdrop: 0 11 | embd_pdrop: 0 12 | attn_pdrop: 0 -------------------------------------------------------------------------------- /configs/env_params/dmcontrol_icl.yaml: -------------------------------------------------------------------------------- 1 | envid: dmcontrol11_icl 2 | reward_scale: 100 3 | target_return: 1000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | eval_env_names: "dmcontrol5_icl" 10 | dmc_env_kwargs: 11 | flatten_obs: False 12 | -------------------------------------------------------------------------------- /src/data/atari/requirements.txt: -------------------------------------------------------------------------------- 1 | ## on CentOS: 2 | #d3rlpy==1.0.0 3 | #numpy==1.19.5 4 | ## on Windows or Ubuntu: 5 | d3rlpy==1.1.0 6 | numpy==1.19.5 7 | git+https://github.com/takuseno/d4rl-atari 8 | gym[atari]==0.24.1 9 | #autorom[accept-rom-license] 10 | tqdm==4.64.0 11 | matplotlib==3.5.1 12 | -------------------------------------------------------------------------------- /configs/env_params/mt45.yaml: -------------------------------------------------------------------------------- 1 | envid: "mt45_v2" 2 | reward_scale: 200 3 | target_return: 20000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | randomization: random_init_all 10 | remove_task_ids: True 11 | add_task_ids: False 12 | eval_env_names: "mt5_v2" -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/dmcontrol11.yaml: -------------------------------------------------------------------------------- 1 | - finger_turn_easy.npz 2 | - fish_upright.npz 3 | - hopper_stand.npz 4 | - point_mass_easy.npz 5 | - walker_stand.npz 6 | - walker_run.npz 7 | - ball_in_cup_catch.npz 8 | - cartpole_swingup.npz 9 | - cheetah_run.npz 10 | - finger_spin.npz 11 | - reacher_easy.npz -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/multi_domain.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: True 3 | tokenize_rtg: False 4 | action_channels: 256 5 | discrete_actions: 18 6 | state_dim: 204 7 | image_shape: [3,64,64] 8 | relative_pos_embds: False 9 | use_time_embds: False 10 | action_condition: False 11 | shared_a_head: True -------------------------------------------------------------------------------- /src/buffers/__init__.py: 
-------------------------------------------------------------------------------- 1 | def make_buffer_class(kind): 2 | if kind == "domain": 3 | from .multi_domain_buffer import MultiDomainTrajectoryReplayBuffer 4 | return MultiDomainTrajectoryReplayBuffer 5 | else: 6 | from .trajectory_buffer import TrajectoryReplayBuffer 7 | return TrajectoryReplayBuffer 8 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dmcontrol11.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/dmc 2 | names: 3 | - finger_turn_easy 4 | - fish_upright 5 | - hopper_stand 6 | - point_mass_easy 7 | - walker_stand 8 | - walker_run 9 | - ball_in_cup_catch 10 | - cartpole_swingup 11 | - cheetah_run 12 | - finger_spin 13 | - reacher_easy -------------------------------------------------------------------------------- /configs/agent_params/data_paths/procgen12.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/procgen/processed_25M_custom 2 | names: 3 | - "bigfish" 4 | - "bossfight" 5 | - "caveflyer" 6 | - "chaser" 7 | - "coinrun" 8 | - "dodgeball" 9 | - "fruitbot" 10 | - "heist" 11 | - "leaper" 12 | - "maze" 13 | - "miner" 14 | - "starpilot" -------------------------------------------------------------------------------- /configs/env_params/atari_freeway.yaml: -------------------------------------------------------------------------------- 1 | envid: FreewayNoFrameskip-v4 2 | reward_scale: 20 3 | target_return: 21 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | atari_env_kwargs: 10 | full_action_space: True 11 | wrapper_kwargs: 12 | to_rgb: True 13 | screen_size: 64 14 | -------------------------------------------------------------------------------- /src/algos/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .online_decision_transformer_model import OnlineDecisionTransformerModel 2 | from .discrete_decision_transformer_model import DiscreteDTModel 3 | from .custom_critic import CustomContinuousCritic, MultiHeadContinuousCritic, StateValueFn 4 | from .multi_domain_discrete_dt_model import MultiDomainDiscreteDTModel 5 | -------------------------------------------------------------------------------- /dmc2gym_custom/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='dmc2gym_custom', 6 | version='1.0.0', 7 | author='Thomas Schmied', 8 | description=('a gym like wrapper for dm_control'), 9 | license='', 10 | keywords='gym dm_control openai deepmind', 11 | packages=find_packages(), 12 | ) -------------------------------------------------------------------------------- /configs/env_params/atari.yaml: -------------------------------------------------------------------------------- 1 | envid: PongNoFrameskip-v4 2 | reward_scale: 20 3 | target_return: 21 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | eval_env_names: "atari41" 10 | atari_env_kwargs: 11 | full_action_space: True 12 | wrapper_kwargs: 13 | to_rgb: True 14 | screen_size: 64 -------------------------------------------------------------------------------- /configs/env_params/procgen.yaml: -------------------------------------------------------------------------------- 1 | envid: "procgen12" 2 | 
reward_scale: 1 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | norm_reward: False 7 | record: False 8 | record_freq: 1000000 9 | record_length: 2000 10 | eval_env_names: procgen4 11 | distribution_mode: "easy" 12 | train_eval_seeds: True 13 | time_limit: 400 14 | env_kwargs: 15 | # data was generated with 0 to 199 16 | num_levels: 200 17 | start_level: 0 18 | eval_env_kwargs: 19 | num_levels: 1 20 | start_level: 200 21 | -------------------------------------------------------------------------------- /src/tokenizers_custom/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | class BaseTokenizer: 2 | 3 | def __init__(self, vocab_size=256, shift=0): 4 | self._vocab_size = vocab_size 5 | self._shift = shift 6 | 7 | def tokenize(self, x): 8 | raise NotImplementedError() 9 | 10 | def inv_tokenize(self, x): 11 | raise NotImplementedError() 12 | 13 | @property 14 | def vocab_size(self): 15 | return self._vocab_size 16 | 17 | @property 18 | def shift(self): 19 | return self._shift 20 | -------------------------------------------------------------------------------- /src/data/untar_files.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | src_dir="$1" 4 | dst_dir="$2" 5 | if [ -z "$src_dir" ] || [ -z "$dst_dir" ]; then 6 | echo "Usage: $0 <src_dir> <dst_dir>" 7 | exit 1 8 | fi 9 | mkdir -p "$dst_dir" 10 | extract_archive() { 11 | archive="$1" 12 | dst="$2" 13 | tar xvf "$archive" -C "$dst" 14 | } 15 | export -f extract_archive 16 | export dst_dir 17 | find "$src_dir" -name "*.tar.gz" -print0 | xargs -0 -I {} -n 1 -P 64 bash -c 'extract_archive "{}" "$dst_dir"' _ "$dst_dir" 18 | echo "All .tar.gz files have been extracted to $dst_dir." -------------------------------------------------------------------------------- /src/tokenizers_custom/__init__.py: -------------------------------------------------------------------------------- 1 | def make_tokenizer(kind, tokenizer_kwargs=None): 2 | if tokenizer_kwargs is None: 3 | tokenizer_kwargs = {} 4 | if kind == 'mulaw': 5 | from .mu_law_tokenizer import MuLawTokenizer 6 | return MuLawTokenizer(**tokenizer_kwargs) 7 | elif kind == 'minmax': 8 | from .minmax_tokenizer import MinMaxTokenizer 9 | return MinMaxTokenizer(**tokenizer_kwargs) 10 | elif kind == 'minmax2': 11 | from .minmax_tokenizer import MinMaxTokenizer2 12 | return MinMaxTokenizer2(**tokenizer_kwargs) 13 | else: 14 | raise ValueError(f"Unknown tokenizer type {kind}") 15 | -------------------------------------------------------------------------------- /dmc2gym_custom/README.md: -------------------------------------------------------------------------------- 1 | # dmc2gym_custom 2 | The original repository is available at: https://github.com/denisyarats/dmc2gym 3 | 4 | The original code base always flattens the observations by default. However, this is impractical for our purposes, 5 | as we need to construct the full observation space. Therefore, we make `flatten_obs` optional. 6 | 7 | ### Installation 8 | To install `dmc2gym_custom`, run: 9 | ``` 10 | pip install -e . 11 | ``` 12 | from the root of this directory.
13 | 14 | ### Usage 15 | ```python 16 | import dmc2gym_custom 17 | env = dmc2gym_custom.make(domain_name='point_mass', task_name='easy', seed=1, flatten_obs=True) 18 | done = False 19 | obs = env.reset() 20 | while not done: 21 | action = env.action_space.sample() 22 | obs, reward, done, info = env.step(action) 23 | ``` -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen.yaml: -------------------------------------------------------------------------------- 1 | # envid: mt45v2_dmc11 2 | envid: bigfish 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | 10 | # Meta-world specific 11 | randomization: random_init_all 12 | remove_task_ids: True 13 | add_task_ids: False 14 | 15 | # procgen specific 16 | distribution_mode: "easy" 17 | # time_limit: 400 18 | env_kwargs: 19 | # data was generated with 0 to 199 20 | num_levels: 200 21 | start_level: 0 22 | eval_env_kwargs: 23 | num_levels: 200 24 | start_level: 0 25 | 26 | # DMC specific 27 | dmc_env_kwargs: 28 | flatten_obs: False 29 | 30 | # multi domain evaluation 31 | # eval_env_names: "mt5v2_dmc5" 32 | eval_env_names: "mt45v2_dmc11_pg12" 33 | 34 | reward_scale: 35 | mt50: 1 36 | dmcontrol: 1 37 | procgen: 1 38 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/atari41.yaml: -------------------------------------------------------------------------------- 1 | # same ones as used in MGDT 2 | # - alien 3 | - amidar 4 | - assault 5 | - asterix 6 | - atlantis 7 | - bank-heist 8 | - battle-zone 9 | - beam-rider 10 | - boxing 11 | - breakout 12 | - carnival 13 | - centipede 14 | - chopper-command 15 | - crazy-climber 16 | - demon-attack 17 | - double-dunk 18 | - enduro 19 | - fishing-derby 20 | - freeway 21 | - frostbite 22 | - gopher 23 | - gravitar 24 | - hero 25 | - ice-hockey 26 | - jamesbond 27 | - kangaroo 28 | - krull 29 | - kung-fu-master 30 | # - ms-pacman 31 | - name-this-game 32 | - phoenix 33 | # - pong 34 | - pooyan 35 | - qbert 36 | - riverraid 37 | - road-runner 38 | - robotank 39 | - seaquest 40 | # - space-invaders 41 | # - star-gunner 42 | - time-pilot 43 | - up-n-down 44 | - video-pinball 45 | - wizard-of-wor 46 | - yars-revenge 47 | - zaxxon 48 | -------------------------------------------------------------------------------- /configs/agent_params/multi_domain.yaml: -------------------------------------------------------------------------------- 1 | kind: "MDDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 2000000000 12 | buffer_weight_by: len 13 | target_return_type: predefined 14 | warmup_steps: 4000 15 | use_amp: True 16 | compile: True 17 | bfloat16: True 18 | 19 | defaults: 20 | - huggingface: dt_medium_64 21 | - data_paths: mt45v2_dmc11_pg12_atari41_cs240_mg83 22 | - model_kwargs: multi_domain 23 | - lr_sched_kwargs: cosine 24 | - replay_buffer_kwargs: multi_domain_mtdmccs 25 | 26 | huggingface: 27 | activation_function: gelu 28 | max_length: 50 29 | use_fast_attn: True 30 | n_positions: 1600 31 | eval_context_len: ${agent_params.huggingface.max_length} 32 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_huge.yaml: 
-------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 20 4 | hidden_size: 1280 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_large.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 16 4 | hidden_size: 1024 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_medium.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 8 4 | hidden_size: 512 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 
23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_huge_half.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 10 4 | hidden_size: 1792 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_large_half.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 8 4 | hidden_size: 1432 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_medium_half.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 4 4 | hidden_size: 704 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 
23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_mediumplus.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 12 4 | hidden_size: 768 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_ms_mediumplus.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 12 4 | hidden_size: 768 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_mediumplus_half.yaml: -------------------------------------------------------------------------------- 1 | max_ep_len: 1000 2 | max_length: 50 3 | n_layer: 6 4 | hidden_size: 1064 5 | n_head: 4 6 | 7 | xlstm_config: 8 | mlstm_block: 9 | mlstm: 10 | conv1d_kernel_size: 4 11 | qkv_proj_blocksize: 4 12 | num_heads: ${agent_params.huggingface.n_head} 13 | slstm_block: 14 | slstm: 15 | backend: cuda 16 | num_heads: ${agent_params.huggingface.n_head} 17 | conv1d_kernel_size: 4 18 | bias_init: powerlaw_blockdependent 19 | feedforward: 20 | proj_factor: 1.3 21 | act_fn: gelu 22 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 
23 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 24 | context_length: 150 25 | num_blocks: ${agent_params.huggingface.n_layer} 26 | embedding_dim: ${agent_params.huggingface.hidden_size} 27 | # slstm_at: [1] 28 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_hugeplus.yaml: -------------------------------------------------------------------------------- 1 | # 408M 2 | max_ep_len: 1000 3 | max_length: 50 4 | n_layer: 28 5 | hidden_size: 1536 6 | n_head: 4 7 | 8 | xlstm_config: 9 | mlstm_block: 10 | mlstm: 11 | conv1d_kernel_size: 4 12 | qkv_proj_blocksize: 4 13 | num_heads: ${agent_params.huggingface.n_head} 14 | slstm_block: 15 | slstm: 16 | backend: cuda 17 | num_heads: ${agent_params.huggingface.n_head} 18 | conv1d_kernel_size: 4 19 | bias_init: powerlaw_blockdependent 20 | feedforward: 21 | proj_factor: 1.3 22 | act_fn: gelu 23 | # context length needs to be set to 3 times the max_length --> s/a/r/rtg tokens. 24 | # context_length: ${multiply:${agent_params.huggingface.max_length},4} 25 | context_length: 150 26 | num_blocks: ${agent_params.huggingface.n_layer} 27 | embedding_dim: ${agent_params.huggingface.hidden_size} 28 | # slstm_at: [1] 29 | -------------------------------------------------------------------------------- /src/algos/models/rms_norm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class LlamaRMSNorm(nn.Module): 9 | def __init__(self, hidden_size, eps=1e-6): 10 | """ 11 | LlamaRMSNorm is equivalent to T5LayerNorm 12 | """ 13 | super().__init__() 14 | self.weight = nn.Parameter(torch.ones(hidden_size)) 15 | self.variance_epsilon = eps 16 | 17 | def forward(self, hidden_states): 18 | input_dtype = hidden_states.dtype 19 | hidden_states = hidden_states.to(torch.float32) 20 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 21 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 22 | return self.weight * hidden_states.to(input_dtype) 23 | 24 | def reset_parameters(self): 25 | nn.init.ones_(self.weight) 26 | -------------------------------------------------------------------------------- /configs/agent_params/darkroom.yaml: -------------------------------------------------------------------------------- 1 | kind: "DDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 2000000000 12 | buffer_weight_by: len 13 | target_return_type: predefined 14 | warmup_steps: 4000 15 | use_amp: True 16 | compile: True 17 | bfloat16: True 18 | persist_context: True 19 | 20 | replay_buffer_kwargs: 21 | num_workers: 16 22 | pin_memory: False 23 | init_top_p: 1 24 | seqs_per_sample: 2 25 | seq_sample_kind: sequential 26 | full_context_trjs: True 27 | 28 | defaults: 29 | - huggingface: dt_medium_64 30 | - data_paths: dark_room_10x10_sfixed_grand_train 31 | - model_kwargs: dark_room 32 | - lr_sched_kwargs: cosine 33 | 34 | huggingface: 35 | activation_function: gelu 36 | max_length: 50 37 | use_fast_attn: True 38 | n_positions: 1600 39 | eval_context_len: ${agent_params.huggingface.max_length} 40 | 
-------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari.yaml: -------------------------------------------------------------------------------- 1 | # envid: mt45v2_dmc11_pg12_atari41 2 | envid: bigfish 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | 10 | # Meta-world specific 11 | randomization: random_init_all 12 | remove_task_ids: True 13 | add_task_ids: False 14 | 15 | # procgen specific 16 | distribution_mode: "easy" 17 | # time_limit: 400 18 | env_kwargs: 19 | # data was generated with 0 to 199 20 | num_levels: 200 21 | start_level: 0 22 | eval_env_kwargs: 23 | num_levels: 200 24 | start_level: 0 25 | 26 | # DMC specific 27 | dmc_env_kwargs: 28 | flatten_obs: False 29 | 30 | # atari specific 31 | atari_env_kwargs: 32 | full_action_space: True 33 | wrapper_kwargs: 34 | to_rgb: True 35 | screen_size: 64 36 | 37 | # multi domain evaluation 38 | # eval_env_names: "mt5v2_dmc5_pg4_atari5" 39 | eval_env_names: "mt45v2_dmc11_pg12_atari41" 40 | 41 | reward_scale: 42 | mt50: 200 43 | dmcontrol: 100 44 | procgen: 1 45 | atari: 20 46 | -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari_cs_mg.yaml: -------------------------------------------------------------------------------- 1 | # envid: mt45v2_dmc11_pg12_atari41_cs240_mg24 2 | envid: bigfish 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | 10 | # Meta-world specific 11 | randomization: random_init_all 12 | remove_task_ids: True 13 | add_task_ids: False 14 | 15 | # procgen specific 16 | distribution_mode: "easy" 17 | # time_limit: 400 18 | env_kwargs: 19 | # data was generated with 0 to 199 20 | num_levels: 200 21 | start_level: 0 22 | eval_env_kwargs: 23 | num_levels: 200 24 | start_level: 0 25 | 26 | # DMC specific 27 | dmc_env_kwargs: 28 | flatten_obs: False 29 | 30 | # atari specific 31 | atari_env_kwargs: 32 | full_action_space: True 33 | wrapper_kwargs: 34 | to_rgb: True 35 | screen_size: 64 36 | 37 | # multi domain evaluation 38 | eval_env_names: "mt45v2_dmc11_pg12_atari41_cs240_mg83" 39 | 40 | reward_scale: 41 | mt50: 200 42 | dmcontrol: 100 43 | procgen: 1 44 | atari: 20 45 | composuite: 1 46 | mimicgen: 1 47 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # on server 2 | LOG_DIR: ./logs 3 | DATA_DIR: ./data 4 | SSD_DATA_DIR: ./data 5 | MODELS_DIR: ./models 6 | 7 | # directory creation is handled by Hydra 8 | hydra: 9 | sweep: 10 | dir: ${LOG_DIR}/${experiment_name}/${now:%Y-%m-%d_%H-%M-%S} 11 | subdir: ${maybe_split:${hydra.job.override_dirname}}/seed=${seed} 12 | run: 13 | dir: ${LOG_DIR}/${experiment_name}/${now:%Y-%m-%d_%H-%M-%S} 14 | job: 15 | config: 16 | override_dirname: 17 | exclude_keys: 18 | - seed 19 | 20 | defaults: 21 | - agent_params: odt 22 | - env_params: mujoco_gym 23 | - eval_params: base 24 | - run_params: base 25 | 26 | # General 27 | experiment_name: test 28 | device: "auto" 29 | seed: 42 30 | # Hydra does the logging for us 31 | logdir: '.' 
32 | use_wandb: True 33 | 34 | wandb_params: 35 | project: "LRAM" 36 | sync_tensorboard: True 37 | monitor_gym: True 38 | save_code: True 39 | entity: "X" 40 | key: X 41 | host: X 42 | 43 | wandb_callback_params: 44 | gradient_save_freq: 0 45 | verbose: 1 46 | model_save_path: -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari_cs.yaml: -------------------------------------------------------------------------------- 1 | # envid: mt45v2_dmc11_pg12_atari41_cs240 2 | envid: bigfish 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 1000000 8 | record_length: 2000 9 | 10 | # Meta-world specific 11 | randomization: random_init_all 12 | remove_task_ids: True 13 | add_task_ids: False 14 | 15 | # procgen specific 16 | distribution_mode: "easy" 17 | # time_limit: 400 18 | env_kwargs: 19 | # data was generated with 0 to 199 20 | num_levels: 200 21 | start_level: 0 22 | eval_env_kwargs: 23 | num_levels: 200 24 | start_level: 0 25 | 26 | # DMC specific 27 | dmc_env_kwargs: 28 | flatten_obs: False 29 | 30 | # atari specific 31 | atari_env_kwargs: 32 | full_action_space: True 33 | wrapper_kwargs: 34 | to_rgb: True 35 | screen_size: 64 36 | 37 | # multi domain evaluation 38 | # eval_env_names: "mt5v2_dmc5_pg4_atari5_cs16" 39 | eval_env_names: "mt45v2_dmc11_pg12_atari41_cs240" 40 | 41 | reward_scale: 42 | mt50: 200 43 | dmcontrol: 100 44 | procgen: 1 45 | atari: 20 46 | composuite: 200 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf==3.20.1 2 | wheel==0.38.0 3 | setuptools==65.5.0 4 | gym==0.21.0 5 | dm-control==1.0.3.post1 6 | dm-env==1.5 7 | dm-tree==0.1.7 8 | einops==0.7.0 9 | hydra-core==1.2.0 10 | matplotlib==3.5.1 11 | lockfile==0.12.2 12 | mujoco_py==2.0.2.5 13 | numpy==1.22.3 14 | omegaconf 15 | pandas==1.4.2 16 | seaborn==0.11.2 17 | tensorflow==2.8.0 18 | tqdm==4.64.0 19 | transformers==4.39.1 20 | datasets 21 | wandb==0.14.0 22 | gym[atari]==0.21.0 23 | autorom[accept-rom-license] 24 | stable_baselines3[extra]==1.5.0 25 | ale-py==0.7.4 26 | procgen 27 | cloudpickle==2.1.0 28 | # fsspec==2022.1.0 29 | git+https://github.com/denisyarats/dmc2gym.git 30 | torch==2.2.2+cu121 31 | torchvision==0.17.2+cu121 32 | torchaudio==2.2.2+cu121 33 | --extra-index-url https://download.pytorch.org/whl/cu121 34 | torchmetrics==1.2.0 35 | h5py==3.6.0 36 | scikit-learn==1.1.3 37 | gymnasium==0.28.1 38 | dacite==1.8.1 39 | # may results in issues with nle: https://github.com/facebookresearch/nle/issues/246 40 | # afterwards add to bashrc --> LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$CONDA_PREFIX/lib" 41 | # minihack==0.1.5 42 | # opencv-python==4.6.0.66 -------------------------------------------------------------------------------- /src/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def filter_params(params): 5 | # filter out params that don't require gradients 6 | params = list(params) 7 | if isinstance(params[0], dict): 8 | params_filtered = [] 9 | for group in params: 10 | group['params'] = filter(lambda p: p.requires_grad, group['params']) 11 | params_filtered.append(group) 12 | params = params_filtered 13 | else: 14 | params = filter(lambda p: p.requires_grad, params) 15 | return params 16 | 17 | 18 | def make_optimizer(kind, params, lr): 19 | params = filter_params(params) 
20 | if kind == "adamw": 21 | optimizer = torch.optim.AdamW(params, lr=lr) 22 | elif kind == "adam": 23 | optimizer = torch.optim.Adam(params, lr=lr) 24 | elif kind == "radam": 25 | optimizer = torch.optim.RAdam(params, lr=lr) 26 | elif kind == "sgd": 27 | optimizer = torch.optim.SGD(params, lr=lr) 28 | elif kind == "rmsprop": 29 | optimizer = torch.optim.RMSprop(params, lr=lr) 30 | else: 31 | raise NotImplementedError() 32 | return optimizer 33 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45_v2.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/metaworld 2 | names: 3 | - reach-v2 4 | - push-v2 5 | - pick-place-v2 6 | - door-open-v2 7 | - drawer-open-v2 8 | - drawer-close-v2 9 | - button-press-topdown-v2 10 | - peg-insert-side-v2 11 | - window-open-v2 12 | - window-close-v2 13 | - door-close-v2 14 | - reach-wall-v2 15 | - pick-place-wall-v2 16 | - push-wall-v2 17 | - button-press-v2 18 | - button-press-topdown-wall-v2 19 | - button-press-wall-v2 20 | - peg-unplug-side-v2 21 | - disassemble-v2 22 | - hammer-v2 23 | - plate-slide-v2 24 | - plate-slide-side-v2 25 | - plate-slide-back-v2 26 | - plate-slide-back-side-v2 27 | - handle-press-v2 28 | - handle-pull-v2 29 | - handle-press-side-v2 30 | - handle-pull-side-v2 31 | - stick-push-v2 32 | - stick-pull-v2 33 | - basketball-v2 34 | - soccer-v2 35 | - faucet-open-v2 36 | - faucet-close-v2 37 | - coffee-push-v2 38 | - coffee-pull-v2 39 | - coffee-button-v2 40 | - sweep-v2 41 | - sweep-into-v2 42 | - pick-out-of-hole-v2 43 | - assembly-v2 44 | - shelf-place-v2 45 | - push-back-v2 46 | - lever-pull-v2 47 | - dial-turn-v2 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Institute for Machine Learning, Johannes Kepler University Linz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/mt45_v2.yaml: -------------------------------------------------------------------------------- 1 | - reach-v2.pkl 2 | - push-v2.pkl 3 | - pick-place-v2.pkl 4 | - door-open-v2.pkl 5 | - drawer-open-v2.pkl 6 | - drawer-close-v2.pkl 7 | - button-press-topdown-v2.pkl 8 | - peg-insert-side-v2.pkl 9 | - window-open-v2.pkl 10 | - window-close-v2.pkl 11 | - door-close-v2.pkl 12 | - reach-wall-v2.pkl 13 | - pick-place-wall-v2.pkl 14 | - push-wall-v2.pkl 15 | - button-press-v2.pkl 16 | - button-press-topdown-wall-v2.pkl 17 | - button-press-wall-v2.pkl 18 | - peg-unplug-side-v2.pkl 19 | - disassemble-v2.pkl 20 | - hammer-v2.pkl 21 | - plate-slide-v2.pkl 22 | - plate-slide-side-v2.pkl 23 | - plate-slide-back-v2.pkl 24 | - plate-slide-back-side-v2.pkl 25 | - handle-press-v2.pkl 26 | - handle-pull-v2.pkl 27 | - handle-press-side-v2.pkl 28 | - handle-pull-side-v2.pkl 29 | - stick-push-v2.pkl 30 | - stick-pull-v2.pkl 31 | - basketball-v2.pkl 32 | - soccer-v2.pkl 33 | - faucet-open-v2.pkl 34 | - faucet-close-v2.pkl 35 | - coffee-push-v2.pkl 36 | - coffee-pull-v2.pkl 37 | - coffee-button-v2.pkl 38 | - sweep-v2.pkl 39 | - sweep-into-v2.pkl 40 | - pick-out-of-hole-v2.pkl 41 | - assembly-v2.pkl 42 | - shelf-place-v2.pkl 43 | - push-back-v2.pkl 44 | - lever-pull-v2.pkl 45 | - dial-turn-v2.pkl -------------------------------------------------------------------------------- /src/buffers/dataloaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | class MultiEpochsDataLoader(DataLoader): 5 | 6 | def __init__(self, *args, **kwargs): 7 | """ 8 | From: 9 | https://discuss.pytorch.org/t/enumerate-dataloader-slow/87778/4 10 | 11 | Ensures that the Dataset is iterated over multiple times without destroying the workers. 12 | 13 | """ 14 | super().__init__(*args, **kwargs) 15 | self._DataLoader__initialized = False 16 | self.batch_sampler = _RepeatSampler(self.batch_sampler) 17 | self._DataLoader__initialized = True 18 | self.iterator = super().__iter__() 19 | 20 | def __len__(self): 21 | return len(self.batch_sampler.sampler) 22 | 23 | def __iter__(self): 24 | for i in range(len(self)): 25 | yield next(self.iterator) 26 | 27 | 28 | class _RepeatSampler(object): 29 | """ Sampler that repeats forever.
30 | Args: 31 | sampler (Sampler) 32 | """ 33 | 34 | def __init__(self, sampler): 35 | self.sampler = sampler 36 | 37 | def __iter__(self): 38 | while True: 39 | yield from iter(self.sampler) 40 | -------------------------------------------------------------------------------- /src/augmentations/augs.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import torch.nn.functional as F 3 | 4 | 5 | class CustomRandomCrop(T.RandomCrop): 6 | def __init__(self, size=84, padding=4, **kwargs): 7 | super().__init__(size, **kwargs) 8 | self.padding = padding 9 | 10 | def __call__(self, img): 11 | # first pad image by 4 pixels on each side 12 | img = F.pad(img, (self.padding, self.padding, self.padding, self.padding), mode='replicate') 13 | # crop to original size 14 | return super().__call__(img) 15 | 16 | 17 | def make_augmentations(aug_kwargs=None): 18 | if aug_kwargs is None: 19 | aug_kwargs = {} 20 | aug_kwargs = aug_kwargs.copy() 21 | kind = aug_kwargs.pop("kind", "crop_rotate") 22 | p_aug = aug_kwargs.get("p_aug", 0.5) 23 | if kind == "crop": 24 | return T.RandomApply([CustomRandomCrop(**aug_kwargs)], p=p_aug) 25 | elif kind == "rotate": 26 | degrees = aug_kwargs.pop("degrees", 30) 27 | return T.RandomApply([T.RandomRotation(degrees=degrees, **aug_kwargs)], p=p_aug) 28 | elif kind == "crop_rotate": 29 | degrees = aug_kwargs.pop("degrees", 30) 30 | return T.Compose([ 31 | T.RandomApply([CustomRandomCrop(**aug_kwargs)], p=p_aug), 32 | T.RandomApply([T.RandomRotation(degrees=degrees, **aug_kwargs)], p=p_aug) 33 | ]) 34 | raise ValueError(f"Unknown augmentation kind: {kind}") 35 | -------------------------------------------------------------------------------- /src/schedulers/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import CosineAnnealingLR 2 | 3 | 4 | class CosineAnnealingLRSingleCycle(CosineAnnealingLR): 5 | 6 | def get_lr(self): 7 | # in case T_max is reached, always return eta_min, don't go up in the cycle again 8 | lrs = super().get_lr() 9 | if self.last_epoch >= self.T_max: 10 | lrs = [self.eta_min for _ in self.optimizer.param_groups] 11 | return lrs 12 | 13 | 14 | def make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=None): 15 | if sched_kwargs is None: 16 | sched_kwargs = {} 17 | if kind == "cosine": 18 | return CosineAnnealingLRSingleCycle(optimizer, **sched_kwargs) 19 | elif kind == "cosine_restart": 20 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 21 | return CosineAnnealingWarmRestarts(optimizer, **sched_kwargs) 22 | elif kind == "step": 23 | from torch.optim.lr_scheduler import StepLR 24 | return StepLR(optimizer, **sched_kwargs) 25 | elif kind == "plateau": 26 | from torch.optim.lr_scheduler import ReduceLROnPlateau 27 | return ReduceLROnPlateau(optimizer, **sched_kwargs) 28 | elif kind == "cyclic": 29 | from torch.optim.lr_scheduler import CyclicLR 30 | return CyclicLR(optimizer, cycle_momentum=False, **sched_kwargs) 31 | elif kind == "exp": 32 | from torch.optim.lr_scheduler import ExponentialLR 33 | return ExponentialLR(optimizer, **sched_kwargs) 34 | raise ValueError(f"Unknown scheduler {kind}") 35 | -------------------------------------------------------------------------------- /src/data/composuite/README.md: -------------------------------------------------------------------------------- 1 | # Composuite 2 | - Paper: https://arxiv.org/pdf/2207.04136 3 | - Code: 
https://github.com/Lifelong-ML/CompoSuite 4 | - Documentation & Data Download: https://datadryad.org/stash/dataset/doi:10.5061/dryad.9cnp5hqps 5 | 6 | ## Installation 7 | Composuite uses mujoco and robosuite underneath. We use mujoco 2.3.0 and robosuite 1.4.1 to remain compatible with mimicgen. 8 | Composuite officially requires robosuite==1.4.0, but it is possible to use robosuite==1.4.1. 9 | In addition, gymnasium==0.28.1 must be installed. Consequently, we make use of compatibility wrappers during env creation. 10 | 11 | Install `composuite` as follows: 12 | ``` 13 | # mujoco 14 | pip install mujoco==2.3.2 15 | 16 | # robosuite 17 | pip install robosuite==1.4.1 18 | # requires 19 | pip install gymnasium==0.28.1 20 | 21 | # git clone, install composuite: https://github.com/Lifelong-ML/CompoSuite.git 22 | git clone https://github.com/Lifelong-ML/CompoSuite.git 23 | cd CompoSuite 24 | pip install -e . 25 | ``` 26 | 27 | ## Troubleshooting 28 | `evdev` installation may cause an error when installing robosuite: 29 | ``` 30 | mamba install -c conda-forge evdev=1.7.1 31 | ``` 32 | This may result in an issue with libffi: 33 | ``` 34 | pip uninstall cffi 35 | pip install cffi==1.15.0 36 | ``` 37 | 38 | ## Data preparation 39 | First, download and extract the `expert` datasets from https://datadryad.org/stash/dataset/doi:10.5061/dryad.9cnp5hqps. 40 | 41 | Then prepare the datasets accordingly: 42 | ``` 43 | cd src/data/composuite 44 | python prepare_data.py --add_rtgs --compress --data_dir=DATA_DIR --save_dir=SAVE_DIR 45 | ``` 46 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import wandb 3 | import omegaconf 4 | from src.utils import maybe_split 5 | from src.envs import make_env 6 | from src.callbacks import make_callbacks 7 | from src.algos.builder import make_agent 8 | from main import setup_logging 9 | 10 | 11 | def evaluate(agent, callbacks): 12 | _, callback = agent._setup_learn(0, None, callbacks) 13 | agent.policy.eval() 14 | callback.on_training_start(locals(), globals()) 15 | agent._dump_logs() 16 | 17 | 18 | @hydra.main(config_path="configs", config_name="config") 19 | def main(config): 20 | print("Config: \n", omegaconf.OmegaConf.to_yaml(config, resolve=True, sort_keys=True)) 21 | config.agent_params.offline_steps = 0 22 | config.agent_params.data_paths = None 23 | run, logdir = setup_logging(config) 24 | env, eval_env, train_eval_env = make_env(config, logdir) 25 | agent = make_agent(config, env, logdir) 26 | callbacks = make_callbacks(config, env=env, eval_env=eval_env, logdir=logdir, train_eval_env=train_eval_env) 27 | try: 28 | # set training steps to 0 to avoid training 29 | agent.learn( 30 | total_timesteps=0, 31 | eval_env=eval_env, 32 | callback=callbacks 33 | ) 34 | finally: 35 | print("Finalizing run...") 36 | if config.use_wandb: 37 | run.finish() 38 | wandb.finish() 39 | # return last avg reward for hparam optimization 40 | if hasattr(agent, "cache"): 41 | agent.cache.cleanup_cache() 42 | 43 | 44 | if __name__ == "__main__": 45 | omegaconf.OmegaConf.register_new_resolver("maybe_split", maybe_split) 46 | main() 47 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11.yaml: -------------------------------------------------------------------------------- 1 | mt40_v2: 2 | base: ${DATA_DIR}/metaworld 3 | names: 4 | - reach-v2.pkl 5 | - push-v2.pkl 6 | -
pick-place-v2.pkl 7 | - door-open-v2.pkl 8 | - drawer-open-v2.pkl 9 | - drawer-close-v2.pkl 10 | - button-press-topdown-v2.pkl 11 | - peg-insert-side-v2.pkl 12 | - window-open-v2.pkl 13 | - window-close-v2.pkl 14 | - door-close-v2.pkl 15 | - reach-wall-v2.pkl 16 | - pick-place-wall-v2.pkl 17 | - push-wall-v2.pkl 18 | - button-press-v2.pkl 19 | - button-press-topdown-wall-v2.pkl 20 | - button-press-wall-v2.pkl 21 | - peg-unplug-side-v2.pkl 22 | - disassemble-v2.pkl 23 | - hammer-v2.pkl 24 | - plate-slide-v2.pkl 25 | - plate-slide-side-v2.pkl 26 | - plate-slide-back-v2.pkl 27 | - plate-slide-back-side-v2.pkl 28 | - handle-press-v2.pkl 29 | - handle-pull-v2.pkl 30 | - handle-press-side-v2.pkl 31 | - handle-pull-side-v2.pkl 32 | - stick-push-v2.pkl 33 | - stick-pull-v2.pkl 34 | - basketball-v2.pkl 35 | - soccer-v2.pkl 36 | - faucet-open-v2.pkl 37 | - faucet-close-v2.pkl 38 | - coffee-push-v2.pkl 39 | - coffee-pull-v2.pkl 40 | - coffee-button-v2.pkl 41 | - sweep-v2.pkl 42 | - sweep-into-v2.pkl 43 | - pick-out-of-hole-v2.pkl 44 | - assembly-v2.pkl 45 | - shelf-place-v2.pkl 46 | - push-back-v2.pkl 47 | - lever-pull-v2.pkl 48 | - dial-turn-v2.pkl 49 | dmcontrol: 50 | base: ${DATA_DIR}/dmc 51 | names: 52 | - finger_turn_easy.npz 53 | - fish_upright.npz 54 | - hopper_stand.npz 55 | - point_mass_easy.npz 56 | - walker_stand.npz 57 | - walker_run.npz 58 | - ball_in_cup_catch.npz 59 | - cartpole_swingup.npz 60 | - cheetah_run.npz 61 | - finger_spin.npz 62 | - reacher_easy.npz 63 | -------------------------------------------------------------------------------- /src/envs/compatibility_wrapper.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gymnasium 3 | from gym import spaces 4 | 5 | 6 | class GymCompatibilityWrapper(gym.Wrapper): 7 | 8 | def __init__(self, env) -> None: 9 | """ 10 | Minimum compatibility for gymnasium to gym envs. 11 | Handles the conversions of obs/act spaces + step/reset calls. 12 | 13 | Necessary, such that we can use robosuite==1.4.1 and mimicgen, while 14 | keeping composuite/mimicgen the same. 15 | 16 | Args: 17 | env: gym.Env. 18 | """ 19 | super().__init__(env) 20 | self.observation_space = self._convert_space(self.observation_space) 21 | self.action_space = self._convert_space(self.action_space) 22 | 23 | def _convert_space(self, space): 24 | """ 25 | Converts gymnasium spaces to gym spaces. 
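Only Discrete and Box spaces are handled; any other space type raises NotImplementedError.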
26 | """ 27 | if isinstance(space, gymnasium.spaces.Discrete): 28 | return spaces.Discrete(space.n) 29 | elif isinstance(space, spaces.Discrete): 30 | # correct already 31 | return space 32 | elif isinstance(space, gymnasium.spaces.Box): 33 | return spaces.Box(low=space.low, high=space.high, shape=space.shape, dtype=space.dtype) 34 | elif isinstance(space, spaces.Box): 35 | # correct already 36 | return space 37 | else: 38 | raise NotImplementedError("This space type is not supported yet.") 39 | 40 | def seed(self, seed=None): 41 | pass 42 | 43 | def step(self, action): 44 | obs, reward, terminated, truncated, info = self.env.step(action) 45 | done = terminated or truncated 46 | return obs, reward, done, info 47 | 48 | def reset(self): 49 | obs, _ = self.env.reset() 50 | return obs 51 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12.yaml: -------------------------------------------------------------------------------- 1 | mt40_v2: 2 | base: ${SSD_DATA_DIR}/metaworld 3 | names: 4 | - reach-v2 5 | - push-v2 6 | - pick-place-v2 7 | - door-open-v2 8 | - drawer-open-v2 9 | - drawer-close-v2 10 | - button-press-topdown-v2 11 | - peg-insert-side-v2 12 | - window-open-v2 13 | - window-close-v2 14 | - door-close-v2 15 | - reach-wall-v2 16 | - pick-place-wall-v2 17 | - push-wall-v2 18 | - button-press-v2 19 | - button-press-topdown-wall-v2 20 | - button-press-wall-v2 21 | - peg-unplug-side-v2 22 | - disassemble-v2 23 | - hammer-v2 24 | - plate-slide-v2 25 | - plate-slide-side-v2 26 | - plate-slide-back-v2 27 | - plate-slide-back-side-v2 28 | - handle-press-v2 29 | - handle-pull-v2 30 | - handle-press-side-v2 31 | - handle-pull-side-v2 32 | - stick-push-v2 33 | - stick-pull-v2 34 | - basketball-v2 35 | - soccer-v2 36 | - faucet-open-v2 37 | - faucet-close-v2 38 | - coffee-push-v2 39 | - coffee-pull-v2 40 | - coffee-button-v2 41 | - sweep-v2 42 | - sweep-into-v2 43 | - pick-out-of-hole-v2 44 | - assembly-v2 45 | - shelf-place-v2 46 | - push-back-v2 47 | - lever-pull-v2 48 | - dial-turn-v2 49 | dmcontrol: 50 | base: ${SSD_DATA_DIR}/dmc 51 | names: 52 | - finger_turn_easy 53 | - fish_upright 54 | - hopper_stand 55 | - point_mass_easy 56 | - walker_stand 57 | - walker_run 58 | - ball_in_cup_catch 59 | - cartpole_swingup 60 | - cheetah_run 61 | - finger_spin 62 | - reacher_easy 63 | procgen: 64 | base: ${SSD_DATA_DIR}/procgen/processed_25M_custom 65 | names: 66 | - "bigfish" 67 | - "bossfight" 68 | - "caveflyer" 69 | - "chaser" 70 | - "coinrun" 71 | - "dodgeball" 72 | - "fruitbot" 73 | - "heist" 74 | - "leaper" 75 | - "maze" 76 | - "miner" 77 | - "starpilot" -------------------------------------------------------------------------------- /src/algos/decision_xlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .universal_decision_transformer_sb3 import UDT 3 | from .discrete_decision_transformer_sb3 import DiscreteDecisionTransformerSb3 4 | 5 | 6 | class DecisionXLSTM(UDT): 7 | 8 | def __init__(self, policy, env, **kwargs): 9 | super().__init__(policy, env, **kwargs) 10 | 11 | def pad_inputs(self, states, actions, returns_to_go, timesteps, context_len=5, rewards=None): 12 | if self.use_inference_cache: 13 | # no need to pad inputs 14 | context_len = 1 15 | attention_mask = torch.ones(actions.shape[1], device=self.device, dtype=torch.long).reshape(1, -1) 16 | if self.replay_buffer.max_state_dim is not None and len(states.shape) == 3 and not self.s_proj_raw: 17 
| # pad state input to max_state_dim, in case of continous state 18 | s_pad = self.replay_buffer.max_state_dim - states.shape[-1] 19 | states = torch.cat([states, torch.zeros((*states.shape[:-1], s_pad), device=self.device)], dim=-1) 20 | if self.replay_buffer.max_act_dim is not None and actions.is_floating_point(): 21 | # check if observations are images --> discrete action 22 | if len(states.shape) != 5: 23 | a_pad = self.replay_buffer.max_act_dim - actions.shape[-1] 24 | actions = torch.cat([actions, torch.zeros((*actions.shape[:-1], a_pad), device=self.device)], dim=-1) 25 | return states.float(), actions, returns_to_go.float(), timesteps, attention_mask, rewards 26 | else: 27 | return super().pad_inputs(states, actions, returns_to_go, timesteps, 28 | context_len=context_len, rewards=rewards) 29 | 30 | 31 | 32 | class DiscreteDecisionXLSTM(DiscreteDecisionTransformerSb3, DecisionXLSTM): 33 | 34 | def __init__(self, policy, env, **kwargs): 35 | super().__init__(policy, env, **kwargs) 36 | -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.registration import register 3 | 4 | 5 | def make( 6 | domain_name, 7 | task_name, 8 | seed=1, 9 | visualize_reward=True, 10 | from_pixels=False, 11 | height=84, 12 | width=84, 13 | camera_id=0, 14 | frame_skip=1, 15 | episode_length=1000, 16 | environment_kwargs=None, 17 | time_limit=None, 18 | channels_first=True, 19 | flatten_obs=True, 20 | deterministic=False 21 | ): 22 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) 23 | 24 | if from_pixels: 25 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels' 26 | 27 | # shorten episode length 28 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip 29 | 30 | # make env kwargs 31 | env_kwargs = dict( 32 | environment_kwargs=environment_kwargs, 33 | visualize_reward=visualize_reward, 34 | from_pixels=from_pixels, 35 | height=height, 36 | width=width, 37 | camera_id=camera_id, 38 | frame_skip=frame_skip, 39 | channels_first=channels_first, 40 | flatten_obs=flatten_obs, 41 | deterministic=deterministic 42 | ) 43 | 44 | if not env_id in gym.envs.registry.env_specs: 45 | task_kwargs = {} 46 | if seed is not None: 47 | task_kwargs['random'] = seed 48 | if time_limit is not None: 49 | task_kwargs['time_limit'] = time_limit 50 | register( 51 | id=env_id, 52 | entry_point='dmc2gym_custom.wrappers:DMCWrapper', 53 | kwargs=dict( 54 | domain_name=domain_name, 55 | task_name=task_name, 56 | task_kwargs=task_kwargs, 57 | **env_kwargs 58 | ), 59 | max_episode_steps=max_episode_steps, 60 | ) 61 | return gym.make(env_id, **env_kwargs) -------------------------------------------------------------------------------- /src/data/parallel_copy.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import joblib 3 | import subprocess 4 | from pathlib import Path 5 | from tqdm import tqdm 6 | from joblib import delayed 7 | 8 | 9 | class ProgressParallel(joblib.Parallel): 10 | def __init__(self, use_tqdm=True, total=None, *args, **kwargs): 11 | self._use_tqdm = use_tqdm 12 | self._total = total 13 | super().__init__(*args, **kwargs) 14 | 15 | def __call__(self, *args, **kwargs): 16 | with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar: 17 | return joblib.Parallel.__call__(self, *args, **kwargs) 18 | 19 | def 
print_progress(self): 20 | if self._total is None: 21 | self._pbar.total = self.n_dispatched_tasks 22 | self._pbar.n = self.n_completed_tasks 23 | self._pbar.refresh() 24 | 25 | 26 | def copy_file(file_path, source, destination): 27 | relative_path = Path(file_path).relative_to(source) 28 | destination_path = Path(destination) / relative_path 29 | destination_path.parent.mkdir(parents=True, exist_ok=True) 30 | subprocess.run(['cp', '-p', str(file_path), str(destination_path)], check=True) 31 | 32 | 33 | def main(source, destination, n_jobs, suffix=None): 34 | print("Collecting files.") 35 | pattern = f"**/*.{suffix}" if suffix is not None else "*" 36 | file_list = [file for file in Path(source).rglob(pattern) if file.is_file()] 37 | print(f"Copying {len(file_list)} files.") 38 | ProgressParallel(n_jobs=n_jobs, total=len(file_list))(delayed(copy_file)(file, source, destination) for file in file_list) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser(description="Parallel cp with progress bar.") 43 | parser.add_argument('--src', type=str, help='Source directory to copy from.') 44 | parser.add_argument('--dst', type=str, help='Destination directory to copy to.') 45 | parser.add_argument('--suffix', type=str) 46 | parser.add_argument('--n_jobs', type=int, default=1, help='Number of parallel jobs.') 47 | args = parser.parse_args() 48 | main(args.src, args.dst, args.n_jobs, args.suffix) -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_room_10x10.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_room_10x10_v1 2 | names: 3 | - "[0,0]_[2,0].pkl" 4 | - "[0,0]_[0,2].pkl" 5 | - "[0,0]_[1,5].pkl" 6 | - "[0,0]_[2,2].pkl" 7 | - "[0,0]_[5,7].pkl" 8 | - "[0,0]_[9,1].pkl" 9 | - "[0,0]_[6,9].pkl" 10 | - "[0,0]_[5,5].pkl" 11 | - "[0,0]_[1,1].pkl" 12 | - "[0,0]_[7,9].pkl" 13 | - "[0,0]_[0,9].pkl" 14 | - "[0,0]_[3,8].pkl" 15 | - "[0,0]_[8,5].pkl" 16 | - "[0,0]_[0,0].pkl" 17 | - "[0,0]_[8,9].pkl" 18 | - "[0,0]_[1,3].pkl" 19 | - "[0,0]_[0,5].pkl" 20 | - "[0,0]_[0,1].pkl" 21 | - "[0,0]_[9,5].pkl" 22 | - "[0,0]_[8,3].pkl" 23 | - "[0,0]_[4,4].pkl" 24 | - "[0,0]_[1,2].pkl" 25 | - "[0,0]_[7,8].pkl" 26 | - "[0,0]_[9,4].pkl" 27 | - "[0,0]_[3,7].pkl" 28 | - "[0,0]_[9,2].pkl" 29 | - "[0,0]_[9,7].pkl" 30 | - "[0,0]_[5,6].pkl" 31 | - "[0,0]_[6,3].pkl" 32 | - "[0,0]_[4,6].pkl" 33 | - "[0,0]_[0,8].pkl" 34 | - "[0,0]_[3,3].pkl" 35 | - "[0,0]_[4,5].pkl" 36 | - "[0,0]_[1,9].pkl" 37 | - "[0,0]_[1,4].pkl" 38 | - "[0,0]_[9,3].pkl" 39 | - "[0,0]_[7,3].pkl" 40 | - "[0,0]_[3,9].pkl" 41 | - "[0,0]_[2,4].pkl" 42 | - "[0,0]_[0,6].pkl" 43 | - "[0,0]_[6,2].pkl" 44 | - "[0,0]_[2,3].pkl" 45 | - "[0,0]_[1,8].pkl" 46 | - "[0,0]_[4,2].pkl" 47 | - "[0,0]_[8,0].pkl" 48 | - "[0,0]_[8,6].pkl" 49 | - "[0,0]_[3,1].pkl" 50 | - "[0,0]_[6,7].pkl" 51 | - "[0,0]_[2,7].pkl" 52 | - "[0,0]_[1,0].pkl" 53 | - "[0,0]_[4,0].pkl" 54 | - "[0,0]_[7,0].pkl" 55 | - "[0,0]_[9,6].pkl" 56 | - "[0,0]_[8,8].pkl" 57 | - "[0,0]_[2,6].pkl" 58 | - "[0,0]_[5,4].pkl" 59 | - "[0,0]_[7,1].pkl" 60 | - "[0,0]_[2,9].pkl" 61 | - "[0,0]_[2,5].pkl" 62 | - "[0,0]_[4,3].pkl" 63 | - "[0,0]_[4,1].pkl" 64 | - "[0,0]_[7,2].pkl" 65 | - "[0,0]_[0,3].pkl" 66 | - "[0,0]_[8,1].pkl" 67 | - "[0,0]_[5,3].pkl" 68 | - "[0,0]_[7,7].pkl" 69 | - "[0,0]_[6,1].pkl" 70 | - "[0,0]_[0,7].pkl" 71 | - "[0,0]_[2,8].pkl" 72 | - "[0,0]_[9,9].pkl" 73 | - "[0,0]_[5,2].pkl" 74 | - "[0,0]_[4,8].pkl" 75 | - "[0,0]_[8,2].pkl" 76 | - "[0,0]_[6,0].pkl" 
77 | - "[0,0]_[7,4].pkl" 78 | - "[0,0]_[3,2].pkl" 79 | - "[0,0]_[4,7].pkl" 80 | - "[0,0]_[7,6].pkl" 81 | - "[0,0]_[3,6].pkl" 82 | - "[0,0]_[0,4].pkl" -------------------------------------------------------------------------------- /src/algos/models/token_learner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adjusted from the TF implementation provided by: 3 | https://github.com/google-research/robotics_transformer/blob/master/tokenizers/token_learner.py 4 | 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class MlpBlock(nn.Module): 12 | 13 | def __init__(self, mlp_dim, out_dim, dropout_rate: float = 0.1): 14 | """ 15 | Initializer for the MLP Block. 16 | 17 | This computes outer_dense(gelu(hidden_dense(input))), with dropout 18 | applied as necessary. 19 | 20 | Args: 21 | mlp_dim: The dimension of the inner representation (output of hidden 22 | layer). Usually larger than the input/output dim. 23 | out_dim: The output dimension of the block. If None, the model output dim 24 | is equal to the input dim (usually desired) 25 | dropout_rate: Dropout rate to be applied after dense ( & activation) 26 | 27 | """ 28 | super().__init__() 29 | self.net = nn.Sequential( 30 | nn.Linear(mlp_dim, mlp_dim), 31 | nn.GELU(), 32 | nn.Dropout(dropout_rate), 33 | nn.Linear(mlp_dim, out_dim), 34 | nn.Dropout(dropout_rate) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.net(x) 39 | 40 | 41 | class TokenLearnerModule(nn.Module): 42 | """TokenLearner module V1.1 (https://arxiv.org/abs/2106.11297).""" 43 | 44 | def __init__(self, 45 | num_tokens: int = 8, 46 | bottleneck_dim: int = 64, 47 | dropout_rate: float = 0.): 48 | super().__init__() 49 | self.mlp = MlpBlock(mlp_dim=bottleneck_dim, out_dim=num_tokens, dropout_rate=dropout_rate) 50 | self.layernorm = nn.LayerNorm(bottleneck_dim) 51 | 52 | def forward(self, x): 53 | if len(x.shape) == 4: 54 | batch_size, height, width, channels = x.shape 55 | x = x.reshape(batch_size, height * width, channels) 56 | 57 | selected = self.layernorm(x) 58 | # Shape: [bs, h*w, n_token]. 59 | selected = self.mlp(x) 60 | # Shape: [bs, n_token, h*w]. 
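# NOTE: the MLP above is applied to the raw input x; the layernormed tensor from the previous line is not used further. The permute + softmax below turn the MLP output into per-token attention weights over the h*w spatial positions, and the final einsum uses them to pool the input features into num_tokens learned tokens.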
61 | selected = selected.permute(0, 2, 1) 62 | selected = F.softmax(selected, dim=-1) 63 | 64 | # Shape: [bs, n_token, c] 65 | return torch.einsum("...si,...id->...sd", selected, x) 66 | -------------------------------------------------------------------------------- /src/envs/dummy_env_utils.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym import spaces 3 | import numpy as np 4 | from stable_baselines3.common.monitor import Monitor 5 | from stable_baselines3.common.env_util import DummyVecEnv 6 | 7 | 8 | class DummyEnv(gym.Env): 9 | def __init__(self, obs_dim=10, act_dim=1, ep_len=1000): 10 | super(DummyEnv, self).__init__() 11 | self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(obs_dim,), dtype=np.float32) 12 | self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(act_dim,), dtype=np.float32) 13 | self.ep_len = ep_len 14 | self.current_step = 0 15 | 16 | def reset(self): 17 | """Reset the environment and return the initial observation.""" 18 | self.current_step = 0 19 | # Return an initial dummy observation 20 | return self.observation_space.sample() 21 | 22 | def step(self, action): 23 | """Take a step in the environment.""" 24 | self.current_step += 1 25 | observation = self.observation_space.sample() 26 | done = self.current_step >= self.ep_len 27 | return observation, 1, done, {} 28 | 29 | def render(self, mode="human"): 30 | """Render the environment (optional, can be extended).""" 31 | pass 32 | 33 | def close(self): 34 | """Clean up resources (optional).""" 35 | pass 36 | 37 | 38 | def get_dummyenv_constructor(envid, env_kwargs=None): 39 | env_kwargs = dict(env_kwargs) if env_kwargs is not None else {} 40 | def make(): 41 | env = DummyEnv(**env_kwargs) 42 | env.name = envid 43 | return Monitor(env) 44 | return make 45 | 46 | 47 | def get_dummyenv_constructors(envid, env_kwargs=None): 48 | if not isinstance(envid, list): 49 | envid = [envid] 50 | return [get_dummyenv_constructor(eid, env_kwargs=env_kwargs) for eid in envid] 51 | 52 | 53 | def make_dummyenv_envs(env_params, envid, make_eval_env=True): 54 | const_kwargs = { 55 | "envid": envid, 56 | "env_kwargs": env_params.get("env_kwargs", {}), 57 | } 58 | env = DummyVecEnv(get_dummyenv_constructors(**const_kwargs)) 59 | eval_env = None 60 | if make_eval_env: 61 | eval_env = DummyVecEnv(get_dummyenv_constructors(**const_kwargs)) 62 | eval_env.num_envs = 1 63 | env.num_envs = 1 64 | return env, eval_env 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # custom 132 | .tmp/ 133 | .idea 134 | .vscode 135 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_keydoor_10x10.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_keydoor_10x10_v1 2 | names: 3 | - "[0,0]_[2,0]_[8,3].pkl" 4 | - "[0,0]_[0,2]_[5,3].pkl" 5 | - "[0,0]_[1,5]_[7,0].pkl" 6 | - "[0,0]_[2,2]_[4,5].pkl" 7 | - "[0,0]_[5,7]_[4,4].pkl" 8 | - "[0,0]_[9,1]_[3,9].pkl" 9 | - "[0,0]_[6,9]_[2,2].pkl" 10 | - "[0,0]_[5,5]_[8,0].pkl" 11 | - "[0,0]_[1,1]_[1,0].pkl" 12 | - "[0,0]_[7,9]_[0,0].pkl" 13 | - "[0,0]_[0,9]_[1,8].pkl" 14 | - "[0,0]_[3,8]_[3,0].pkl" 15 | - "[0,0]_[8,5]_[7,3].pkl" 16 | - "[0,0]_[0,0]_[3,3].pkl" 17 | - "[0,0]_[8,9]_[9,0].pkl" 18 | - "[0,0]_[1,3]_[0,4].pkl" 19 | - "[0,0]_[0,5]_[7,6].pkl" 20 | - "[0,0]_[0,1]_[7,7].pkl" 21 | - "[0,0]_[9,5]_[1,2].pkl" 22 | - "[0,0]_[8,3]_[3,1].pkl" 23 | - "[0,0]_[4,4]_[5,5].pkl" 24 | - "[0,0]_[1,2]_[8,8].pkl" 25 | - "[0,0]_[7,8]_[2,6].pkl" 26 | - "[0,0]_[9,4]_[4,2].pkl" 27 | - "[0,0]_[3,7]_[6,9].pkl" 28 | - "[0,0]_[9,2]_[1,5].pkl" 29 | - "[0,0]_[9,7]_[4,0].pkl" 30 | - "[0,0]_[5,6]_[9,6].pkl" 31 | - "[0,0]_[6,3]_[0,9].pkl" 32 | - "[0,0]_[4,6]_[7,2].pkl" 33 | - "[0,0]_[0,8]_[1,1].pkl" 34 | - "[0,0]_[3,3]_[4,7].pkl" 35 | - "[0,0]_[4,5]_[8,5].pkl" 36 | - "[0,0]_[1,9]_[2,8].pkl" 37 | - "[0,0]_[1,4]_[9,3].pkl" 38 | - "[0,0]_[9,3]_[0,5].pkl" 39 | - "[0,0]_[7,3]_[6,6].pkl" 40 | - "[0,0]_[3,9]_[6,5].pkl" 41 | - "[0,0]_[2,4]_[3,5].pkl" 42 | - "[0,0]_[0,6]_[1,6].pkl" 43 | - "[0,0]_[6,2]_[4,9].pkl" 44 | - "[0,0]_[2,3]_[3,4].pkl" 45 | - "[0,0]_[1,8]_[0,7].pkl" 46 | - "[0,0]_[4,2]_[9,5].pkl" 47 | - "[0,0]_[8,0]_[2,7].pkl" 
48 | - "[0,0]_[8,6]_[1,9].pkl" 49 | - "[0,0]_[3,1]_[8,1].pkl" 50 | - "[0,0]_[6,7]_[2,5].pkl" 51 | - "[0,0]_[2,7]_[6,2].pkl" 52 | - "[0,0]_[1,0]_[1,3].pkl" 53 | - "[0,0]_[4,0]_[2,4].pkl" 54 | - "[0,0]_[7,0]_[0,3].pkl" 55 | - "[0,0]_[9,6]_[1,7].pkl" 56 | - "[0,0]_[8,8]_[3,8].pkl" 57 | - "[0,0]_[2,6]_[0,8].pkl" 58 | - "[0,0]_[5,4]_[7,8].pkl" 59 | - "[0,0]_[7,1]_[0,6].pkl" 60 | - "[0,0]_[2,9]_[6,4].pkl" 61 | - "[0,0]_[2,5]_[3,6].pkl" 62 | - "[0,0]_[4,3]_[8,9].pkl" 63 | - "[0,0]_[4,1]_[5,6].pkl" 64 | - "[0,0]_[7,2]_[9,9].pkl" 65 | - "[0,0]_[0,3]_[5,4].pkl" 66 | - "[0,0]_[8,1]_[4,3].pkl" 67 | - "[0,0]_[5,3]_[5,0].pkl" 68 | - "[0,0]_[7,7]_[6,7].pkl" 69 | - "[0,0]_[6,1]_[4,6].pkl" 70 | - "[0,0]_[0,7]_[6,8].pkl" 71 | - "[0,0]_[2,8]_[6,1].pkl" 72 | - "[0,0]_[9,9]_[9,7].pkl" 73 | - "[0,0]_[5,2]_[7,9].pkl" 74 | - "[0,0]_[4,8]_[4,1].pkl" 75 | - "[0,0]_[8,2]_[5,8].pkl" 76 | - "[0,0]_[6,0]_[4,8].pkl" 77 | - "[0,0]_[7,4]_[9,8].pkl" 78 | - "[0,0]_[3,2]_[5,7].pkl" 79 | - "[0,0]_[4,7]_[7,5].pkl" 80 | - "[0,0]_[7,6]_[3,2].pkl" 81 | - "[0,0]_[3,6]_[9,4].pkl" 82 | - "[0,0]_[0,4]_[5,9].pkl" -------------------------------------------------------------------------------- /src/utils/debug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.lines import Line2D 4 | from pathlib import Path 5 | from datetime import datetime 6 | 7 | 8 | class GradPlotter: 9 | 10 | def __init__(self, y_min=-0.01, y_max=0.5, base_dir=None): 11 | self.base_dir = base_dir 12 | if self.base_dir is None: 13 | self.base_dir = "./debug" 14 | time = datetime.now().strftime("%d-%m-%Y_%Hh%Mm") 15 | self.base_dir = Path(self.base_dir) / time 16 | self.base_dir.mkdir(exist_ok=True, parents=True) 17 | self.y_min = y_min 18 | self.y_max = y_max 19 | 20 | def plot_grad_flow(self, named_parameters, file_name): 21 | """ 22 | Adjusted from: 23 | https://gist.github.com/Flova/8bed128b41a74142a661883af9e51490 24 | 25 | Plots the gradients flowing through different layers in the net during training. 26 | Can be used for checking for possible gradient vanishing / exploding problems. 
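Each call writes a bar chart of the mean and max absolute gradient per (non-bias) parameter to base_dir/file_name.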
27 | 28 | Usage: Plug this function in Trainer class after loss.backwards() as 29 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow 30 | 31 | E.g., call using: 32 | if self._n_updates % 1000 == 0: 33 | plot_grad_flow(self.critic.named_parameters(), f"critic_update,critic,step={self._n_updates}.png") 34 | plot_grad_flow(self.policy.named_parameters(), f"critic_update,policy,step={self._n_updates}.png") 35 | 36 | """ 37 | ave_grads, max_grads, layers = [], [], [] 38 | for n, p in named_parameters: 39 | if (p.requires_grad) and ("bias" not in n): 40 | layers.append(n) 41 | ave_grads.append(p.grad.abs().mean().item() if p.grad is not None else 0) 42 | max_grads.append(p.grad.abs().max().item() if p.grad is not None else 0) 43 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.5, lw=1, color="c") 44 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.5, lw=1, color="b") 45 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 46 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 47 | plt.xlim(left=0, right=len(ave_grads)) 48 | # plt.ylim(bottom=self.y_min, top=self.y_max) 49 | plt.xlabel("Layers") 50 | plt.ylabel("average gradient") 51 | plt.title("Gradient flow") 52 | plt.grid(True) 53 | plt.legend([Line2D([0], [0], color="c", lw=4), 54 | Line2D([0], [0], color="b", lw=4), 55 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) 56 | plt.savefig(self.base_dir / file_name, bbox_inches='tight') 57 | plt.close() 58 | -------------------------------------------------------------------------------- /src/algos/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=0.5): 8 | """ 9 | Adjusted from: 10 | - https://github.com/google-research/google-research/tree/master/multi_game_dt 11 | - https://github.com/etaoxing/multigame-dt 12 | """ 13 | logits = logits.double() 14 | if top_p > 0.0: 15 | # percentile: 0 to 100, quantile: 0 to 1 16 | # torch.quantile cannot handle float16 17 | percentile = torch.quantile(logits, top_p, dim=-1) 18 | if percentile != logits.max(): 19 | # otherwise all logits would become -inf 20 | logits = torch.where(logits > percentile.unsqueeze(-1), logits, -float("inf")) 21 | if top_k > 0: 22 | logits, top_indices = torch.topk(logits, top_k) 23 | sample = torch.distributions.Categorical(logits=temperature * logits).sample() 24 | if top_k > 0: 25 | sample_shape = sample.shape 26 | # Flatten top-k indices and samples for easy indexing. 27 | top_indices = torch.reshape(top_indices, [-1, top_k]) 28 | sample = sample.flatten() 29 | sample = top_indices[torch.arange(len(sample)), sample] 30 | # Reshape samples back to original dimensions. 
31 | sample = torch.reshape(sample, sample_shape) 32 | return sample 33 | 34 | 35 | def position_encoding_init(n_position, d_pos_vec): 36 | position_enc = np.array([ 37 | [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)] 38 | for pos in range(n_position)]) 39 | 40 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i 41 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 42 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 43 | 44 | 45 | def make_sinusoidal_embd(n_positions, embed_dim): 46 | position_enc = torch.nn.Embedding(n_positions, embed_dim) 47 | position_enc.weight.data = position_encoding_init(n_positions, embed_dim) 48 | return position_enc 49 | 50 | 51 | def symlog(x): 52 | return torch.sign(x) * torch.log(1 + torch.abs(x)) 53 | 54 | 55 | class SwiGLU(nn.Module): 56 | # SwiGLU https://arxiv.org/abs/2002.05202 57 | def forward(self, x): 58 | x, gate = x.chunk(2, dim=-1) 59 | return F.silu(gate) * x 60 | 61 | 62 | class GEGLU(nn.Module): 63 | """ 64 | References: 65 | Shazeer et al., "GLU Variants Improve Transformer," 2020. 66 | https://arxiv.org/abs/2002.05202 67 | """ 68 | 69 | def geglu(self, x): 70 | assert x.shape[-1] % 2 == 0 71 | a, b = x.chunk(2, dim=-1) 72 | return a * F.gelu(b) 73 | 74 | def forward(self, x): 75 | return self.geglu(x) 76 | -------------------------------------------------------------------------------- /src/tokenizers_custom/mu_law_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adjusted from: https://github.com/G-Wang/WaveRNN-Pytorch/blob/master/utils.py 3 | 4 | """ 5 | import torch 6 | import numpy as np 7 | from .base_tokenizer import BaseTokenizer 8 | 9 | 10 | class MuLawTokenizer(BaseTokenizer): 11 | 12 | def __init__(self, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | def tokenize(self, x): 16 | """ 17 | Encode signal based on mu-law companding. For more info see the 18 | `Wikipedia Entry `_ 19 | This algorithm assumes the signal has been scaled to between -1 and 1 and 20 | returns a signal encoded with values from 0 to quantization_channels - 1 21 | Args: 22 | quantization_channels (int): Number of channels. default: 256 23 | """ 24 | mu = self.vocab_size - 1 25 | if isinstance(x, np.ndarray): 26 | x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) 27 | tokens = ((x_mu + 1) / 2 * mu + 0.5).astype(int) 28 | if self.shift != 0: 29 | return tokens + self.shift 30 | return tokens 31 | elif isinstance(x, (torch.Tensor, torch.LongTensor)): 32 | if isinstance(x, torch.LongTensor): 33 | x = x.float() 34 | mu = torch.FloatTensor([mu]).to(device=x.device) 35 | x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) 36 | tokens = ((x_mu + 1) / 2 * mu + 0.5).long() 37 | if self.shift != 0: 38 | return tokens + self.shift 39 | return tokens 40 | raise NotImplementedError() 41 | 42 | def inv_tokenize(self, x_mu): 43 | """ 44 | Decode mu-law encoded signal. For more info see the 45 | `Wikipedia Entry `_ 46 | This expects an input with values between 0 and quantization_channels - 1 47 | and returns a signal scaled between -1 and 1. 48 | Args: 49 | quantization_channels (int): Number of channels. default: 256 50 | """ 51 | mu = self.vocab_size - 1. 52 | if self.shift != 0: 53 | x_mu = x_mu - self.shift 54 | if isinstance(x_mu, np.ndarray): 55 | x = ((x_mu) / mu) * 2 - 1. 56 | return np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) 
/ mu 57 | elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)): 58 | if isinstance(x_mu, (torch.LongTensor, torch.cuda.LongTensor)): 59 | x_mu = x_mu.float() 60 | mu = torch.FloatTensor([mu]).to(x_mu.device) 61 | x = ((x_mu) / mu) * 2 - 1. 62 | return torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu 63 | raise NotImplementedError() 64 | -------------------------------------------------------------------------------- /configs/env_params/dark_room.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-Room-Dark-10x10-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | 20 | train_goal_pos: 21 | - [2, 0] 22 | 23 | # train tasks for sfixed, grand 24 | eval_start_pos: 25 | - [8, 3] 26 | - [5, 3] 27 | - [7, 0] 28 | - [4, 5] 29 | - [4, 4] 30 | - [3, 9] 31 | - [2, 2] 32 | - [8, 0] 33 | - [1, 0] 34 | - [0, 0] 35 | - [1, 8] 36 | - [3, 0] 37 | - [7, 3] 38 | - [3, 3] 39 | - [9, 0] 40 | - [0, 4] 41 | - [7, 6] 42 | - [7, 7] 43 | - [1, 2] 44 | - [3, 1] 45 | - [5, 5] 46 | - [8, 8] 47 | - [2, 6] 48 | - [4, 2] 49 | - [6, 9] 50 | - [1, 5] 51 | - [4, 0] 52 | - [9, 6] 53 | - [0, 9] 54 | - [7, 2] 55 | - [1, 1] 56 | - [4, 7] 57 | - [8, 5] 58 | - [2, 8] 59 | - [9, 3] 60 | - [0, 5] 61 | - [6, 6] 62 | - [6, 5] 63 | - [3, 5] 64 | - [1, 6] 65 | - [4, 9] 66 | - [3, 4] 67 | - [0, 7] 68 | - [9, 5] 69 | - [2, 7] 70 | - [1, 9] 71 | - [8, 1] 72 | - [2, 5] 73 | - [6, 2] 74 | - [1, 3] 75 | - [2, 4] 76 | - [0, 3] 77 | - [1, 7] 78 | - [3, 8] 79 | - [0, 8] 80 | - [7, 8] 81 | - [0, 6] 82 | - [6, 4] 83 | - [3, 6] 84 | - [8, 9] 85 | - [5, 6] 86 | - [9, 9] 87 | - [5, 4] 88 | - [4, 3] 89 | - [5, 0] 90 | - [6, 7] 91 | - [4, 6] 92 | - [6, 8] 93 | - [6, 1] 94 | - [9, 7] 95 | - [7, 9] 96 | - [4, 1] 97 | - [5, 8] 98 | - [4, 8] 99 | - [9, 8] 100 | - [5, 7] 101 | - [7, 5] 102 | - [3, 2] 103 | - [9, 4] 104 | - [5, 9] 105 | 106 | 107 | eval_goal_pos: 108 | - [2, 0] 109 | - [0, 2] 110 | - [1, 5] 111 | - [2, 2] 112 | - [5, 7] 113 | - [9, 1] 114 | - [6, 9] 115 | - [5, 5] 116 | - [1, 1] 117 | - [7, 9] 118 | - [0, 9] 119 | - [3, 8] 120 | - [8, 5] 121 | - [0, 0] 122 | - [8, 9] 123 | - [1, 3] 124 | - [0, 5] 125 | - [0, 1] 126 | - [9, 5] 127 | - [8, 3] 128 | - [4, 4] 129 | - [1, 2] 130 | - [7, 8] 131 | - [9, 4] 132 | - [3, 7] 133 | - [9, 2] 134 | - [9, 7] 135 | - [5, 6] 136 | - [6, 3] 137 | - [4, 6] 138 | - [0, 8] 139 | - [3, 3] 140 | - [4, 5] 141 | - [1, 9] 142 | - [1, 4] 143 | - [9, 3] 144 | - [7, 3] 145 | - [3, 9] 146 | - [2, 4] 147 | - [0, 6] 148 | - [6, 2] 149 | - [2, 3] 150 | - [1, 8] 151 | - [4, 2] 152 | - [8, 0] 153 | - [8, 6] 154 | - [3, 1] 155 | - [6, 7] 156 | - [2, 7] 157 | - [1, 0] 158 | - [4, 0] 159 | - [7, 0] 160 | - [9, 6] 161 | - [8, 8] 162 | - [2, 6] 163 | - [5, 4] 164 | - [7, 1] 165 | - [2, 9] 166 | - [2, 5] 167 | - [4, 3] 168 | - [4, 1] 169 | - [7, 2] 170 | - [0, 3] 171 | - [8, 1] 172 | - [5, 3] 173 | - [7, 7] 174 | - [6, 1] 175 | - [0, 7] 176 | - [2, 8] 177 | - [9, 9] 178 | - [5, 2] 179 | - [4, 8] 180 | - [8, 2] 181 | - [6, 0] 182 | - [7, 4] 183 | - [3, 2] 184 | - [4, 7] 185 | - [7, 6] 186 | - [3, 6] 187 | - [0, 4] 188 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41.yaml: 
-------------------------------------------------------------------------------- 1 | mt40_v2: 2 | base: ${SSD_DATA_DIR}/metaworld 3 | names: 4 | - reach-v2 5 | - push-v2 6 | - pick-place-v2 7 | - door-open-v2 8 | - drawer-open-v2 9 | - drawer-close-v2 10 | - button-press-topdown-v2 11 | - peg-insert-side-v2 12 | - window-open-v2 13 | - window-close-v2 14 | - door-close-v2 15 | - reach-wall-v2 16 | - pick-place-wall-v2 17 | - push-wall-v2 18 | - button-press-v2 19 | - button-press-topdown-wall-v2 20 | - button-press-wall-v2 21 | - peg-unplug-side-v2 22 | - disassemble-v2 23 | - hammer-v2 24 | - plate-slide-v2 25 | - plate-slide-side-v2 26 | - plate-slide-back-v2 27 | - plate-slide-back-side-v2 28 | - handle-press-v2 29 | - handle-pull-v2 30 | - handle-press-side-v2 31 | - handle-pull-side-v2 32 | - stick-push-v2 33 | - stick-pull-v2 34 | - basketball-v2 35 | - soccer-v2 36 | - faucet-open-v2 37 | - faucet-close-v2 38 | - coffee-push-v2 39 | - coffee-pull-v2 40 | - coffee-button-v2 41 | - sweep-v2 42 | - sweep-into-v2 43 | - pick-out-of-hole-v2 44 | - assembly-v2 45 | - shelf-place-v2 46 | - push-back-v2 47 | - lever-pull-v2 48 | - dial-turn-v2 49 | dmcontrol: 50 | base: ${SSD_DATA_DIR}/dmc 51 | names: 52 | - finger_turn_easy 53 | - fish_upright 54 | - hopper_stand 55 | - point_mass_easy 56 | - walker_stand 57 | - walker_run 58 | - ball_in_cup_catch 59 | - cartpole_swingup 60 | - cheetah_run 61 | - finger_spin 62 | - reacher_easy 63 | procgen: 64 | base: ${SSD_DATA_DIR}/procgen/processed_25M_custom 65 | names: 66 | - "bigfish" 67 | - "bossfight" 68 | - "caveflyer" 69 | - "chaser" 70 | - "coinrun" 71 | - "dodgeball" 72 | - "fruitbot" 73 | - "heist" 74 | - "leaper" 75 | - "maze" 76 | - "miner" 77 | - "starpilot" 78 | atari: 79 | base: ${SSD_DATA_DIR}/atari_1M_64rgb 80 | names: 81 | - amidar 82 | - assault 83 | - asterix 84 | - atlantis 85 | - bank-heist 86 | - battle-zone 87 | - beam-rider 88 | - boxing 89 | - breakout 90 | - carnival 91 | - centipede 92 | - chopper-command 93 | - crazy-climber 94 | - demon-attack 95 | - double-dunk 96 | - enduro 97 | - fishing-derby 98 | - freeway 99 | - frostbite 100 | - gopher 101 | - gravitar 102 | - hero 103 | - ice-hockey 104 | - jamesbond 105 | - kangaroo 106 | - krull 107 | - kung-fu-master 108 | - name-this-game 109 | - phoenix 110 | - pooyan 111 | - qbert 112 | - riverraid 113 | - road-runner 114 | - robotank 115 | - seaquest 116 | - time-pilot 117 | - up-n-down 118 | - video-pinball 119 | - wizard-of-wor 120 | - yars-revenge 121 | - zaxxon -------------------------------------------------------------------------------- /src/tokenizers_custom/minmax_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_tokenizer import BaseTokenizer 3 | 4 | 5 | class MinMaxTokenizer(BaseTokenizer): 6 | 7 | def __init__(self, min_val=-1, max_val=1, one_hot=False, **kwargs): 8 | super().__init__(**kwargs) 9 | self.min_val = min_val 10 | self.max_val = max_val 11 | self.one_hot = one_hot 12 | self.bin_width = (max_val - min_val) / self.vocab_size 13 | 14 | def tokenize(self, x): 15 | # Reshape the input tensor to have shape (batch_size, num_features) 16 | batch_size, num_features = x.shape[0], x.shape[1:] 17 | x = x.view(batch_size, -1) 18 | 19 | # Compute the indices of the bins 20 | tokens = ((x - self.min_val) / self.bin_width).long().clamp(min=0, max=self.vocab_size - 1) 21 | 22 | # Reshape the output tensor to have the same shape as the input tensor 23 | tokens = tokens.view(batch_size, 
*num_features) 24 | 25 | if self.shift != 0: 26 | return tokens + self.shift 27 | if self.one_hot: 28 | return torch.nn.functional.one_hot(tokens, num_classes=self.vocab_size).float().flatten(-2) 29 | return tokens 30 | 31 | def inv_tokenize(self, x): 32 | if self.one_hot: 33 | x = torch.argmax(x, dim=-1) 34 | if self.shift != 0: 35 | x = x - self.shift 36 | # can't be smaller than 0 37 | x[x < 0] = 0 38 | 39 | # Reshape the input tensor to have shape (batch_size, num_features) 40 | batch_size, num_features = x.shape[0], x.shape[1:] 41 | x = x.view(batch_size, -1) 42 | 43 | # Compute the values of the bins 44 | values = x.float() * self.bin_width + self.min_val 45 | 46 | # Reshape the output tensor to have the same shape as the input tensor 47 | return values.view(batch_size, *num_features) 48 | 49 | 50 | class MinMaxTokenizer2(BaseTokenizer): 51 | 52 | def __init__(self, min_val=-1, max_val=1, **kwargs): 53 | """ 54 | Tokenizes a given (action) input as described by: https://arxiv.org/abs/2212.06817 55 | Args: 56 | **kwargs: 57 | 58 | """ 59 | super().__init__(**kwargs) 60 | self.min_val = min_val 61 | self.max_val = max_val 62 | 63 | def tokenize(self, x): 64 | x = torch.clamp(x, self.min_val, self.max_val) 65 | # Normalize the action [batch, actions_size] 66 | tokens = (x - self.min_val) / (self.max_val - self.min_val) 67 | # Bucket and discretize the action to vocab_size, [batch, actions_size] 68 | tokens = (tokens * (self.vocab_size - 1)).long() 69 | if self.shift != 0: 70 | return tokens + self.shift 71 | return tokens 72 | 73 | def inv_tokenize(self, x): 74 | if self.shift != 0: 75 | x = x - self.shift 76 | x[x < 0] = 0 77 | x = x.float() / (self.vocab_size - 1) 78 | x = (x * (self.max_val - self.min_val)) + self.min_val 79 | return x 80 | -------------------------------------------------------------------------------- /src/envs/composuite_utils.py: -------------------------------------------------------------------------------- 1 | import composuite 2 | from stable_baselines3.common.monitor import Monitor 3 | from stable_baselines3.common.env_util import DummyVecEnv 4 | from composuite.env.gym_wrapper import GymWrapper 5 | from .compatibility_wrapper import GymCompatibilityWrapper 6 | 7 | 8 | class CustomGymWrapper(GymWrapper): 9 | """ 10 | Overwrites original GymWrapper to allow for pickling. 
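Pickling is enabled via __reduce__, which reconstructs the wrapper from (env, keys); __getattr__ transparently forwards other attribute access to the wrapped robosuite env.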
11 | """ 12 | def __getattr__(self, attr): 13 | # using getattr ensures that both __getattribute__ and __getattr__ (fallback) get called 14 | # (see https://stackoverflow.com/questions/3278077/difference-between-getattr-vs-getattribute) 15 | if attr == "env": 16 | return self.env 17 | orig_attr = getattr(self.env, attr) 18 | if callable(orig_attr): 19 | def hooked(*args, **kwargs): 20 | result = orig_attr(*args, **kwargs) 21 | # prevent wrapped_class from becoming unwrapped 22 | # NOTE: had to use "is" to prevent errors when returning numpy arrays from a wrapped method 23 | if result is self.env: 24 | return self 25 | return result 26 | 27 | return hooked 28 | else: 29 | return orig_attr 30 | 31 | def __reduce__(self): 32 | return (self.__class__, (self.env, self.keys)) 33 | 34 | 35 | def get_composuite_constructor(envid, env_kwargs=None): 36 | env_kwargs = dict(env_kwargs) if env_kwargs is not None else {} 37 | render_mode = env_kwargs.pop("render_mode", None) 38 | use_task_id_obs = env_kwargs.pop("use_task_id_obs", True) 39 | def make(): 40 | robot_name, object_name, obstacle_name, objective_name = envid.split("_") 41 | env = composuite.make(robot_name, object_name, obstacle_name, objective_name, 42 | use_task_id_obs=use_task_id_obs, ignore_done=False, **env_kwargs) 43 | # overwrite original GymWrapper to allow for env pickling. 44 | env = CustomGymWrapper(env.unwrapped) 45 | # make gymnasium env compatible with gym 46 | env = GymCompatibilityWrapper(env) 47 | # rename for easier metric tracking 48 | env.name = envid 49 | if render_mode is not None: 50 | env.metadata.update({"render.modes": [render_mode]}) 51 | return Monitor(env) 52 | return make 53 | 54 | 55 | def get_composuite_constructors(envid, env_kwargs=None): 56 | if not isinstance(envid, list): 57 | envid = [envid] 58 | return [get_composuite_constructor(eid, env_kwargs=env_kwargs) for eid in envid] 59 | 60 | 61 | def make_composuite_envs(env_params, envid, make_eval_env=True): 62 | const_kwargs = { 63 | "envid": envid, 64 | "env_kwargs": env_params.get("cs_env_kwargs", {}), 65 | } 66 | env = DummyVecEnv(get_composuite_constructors(**const_kwargs)) 67 | eval_env = None 68 | if make_eval_env: 69 | eval_env = DummyVecEnv(get_composuite_constructors(**const_kwargs)) 70 | eval_env.num_envs = 1 71 | env.num_envs = 1 72 | return env, eval_env 73 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: lram 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - _tflow_select=2.1.0 9 | - aiohttp=3.8.1 10 | - aiosignal=1.2.0 11 | - astor=0.8.1 12 | - astunparse=1.6.3 13 | - async-timeout=4.0.1 14 | - attrs=21.4.0 15 | - blas=1.0 16 | - blinker=1.4 17 | - brotlipy=0.7.0 18 | - c-ares=1.18.1 19 | - ca-certificates=2022.3.29 20 | - cffi=1.15.0 21 | - click=8.0.4 22 | - cryptography=3.4.8 23 | - dataclasses=0.8 24 | - frozenlist=1.2.0 25 | - gast=0.4.0 26 | - google-pasta=0.2.0 27 | - hdf5=1.10.6 28 | - idna=3.3 29 | - importlib-metadata=4.11.3 30 | - intel-openmp=2021.4.0 31 | - keras-preprocessing=1.1.2 32 | - ld_impl_linux-64=2.35.1 33 | - libffi=3.3 34 | - libgcc-ng=9.3.0 35 | - libgfortran-ng=7.5.0 36 | - libgfortran4=7.5.0 37 | - libgomp=9.3.0 38 | - libprotobuf=3.19.1 39 | - libstdcxx-ng=9.3.0 40 | - mkl=2021.4.0 41 | - mkl-service=2.4.0 42 | - mkl_fft=1.3.1 43 | - mkl_random=1.2.2 44 | - multidict=5.2.0 45 | - ncurses=6.3 46 | - 
oauthlib=3.2.0 47 | - openssl=1.1.1n 48 | - opt_einsum=3.3.0 49 | - pip=21.2.4 50 | - pyasn1=0.4.8 51 | - pycparser=2.21 52 | - pyjwt=2.1.0 53 | - pyopenssl=21.0.0 54 | - pysocks=1.7.1 55 | - python=3.9.12 56 | - python-flatbuffers=2.0 57 | - readline=8.1.2 58 | - requests=2.27.1 59 | - six=1.16.0 60 | - sqlite=3.38.2 61 | - tk=8.6.11 62 | - typing_extensions 63 | - tzdata=2022a 64 | - urllib3=1.26.9 65 | - xz=5.2.5 66 | - yarl=1.6.3 67 | - zlib=1.2.12 68 | - pip: 69 | - absl-py==1.0.0 70 | - antlr4-python3-runtime==4.8 71 | - cachetools==5.0.0 72 | - certifi==2021.10.8 73 | - charset-normalizer==2.0.12 74 | - cloudpickle==2.0.0 75 | - cycler==0.11.0 76 | - cython==0.29.28 77 | - docker-pycreds==0.4.0 78 | - fasteners==0.17.3 79 | - filelock==3.7.0 80 | - lockfile==0.12.2 81 | - flatbuffers==1.12 82 | - fonttools==4.33.3 83 | - future==0.18.2 84 | - gitdb==4.0.9 85 | - gitpython==3.1.27 86 | - glfw==2.5.3 87 | - google-auth==2.6.6 88 | - google-auth-oauthlib==0.4.6 89 | - grpcio==1.44.0 90 | - wheel==0.38.0 91 | - setuptools==65.5.0 92 | - gym-notices==0.0.6 93 | - hydra-core==1.1.2 94 | - imageio==2.18.0 95 | - kiwisolver==1.4.2 96 | - labmaze==1.0.5 97 | - libclang==14.0.1 98 | - lxml==4.8.0 99 | - markdown==3.3.6 100 | - matplotlib==3.5.1 101 | - numpy 102 | - omegaconf==2.1.2 103 | - packaging==21.3 104 | - pandas==1.4.2 105 | - pathtools==0.1.2 106 | - pillow==9.1.0 107 | - promise==2.3 108 | - protobuf==3.20.1 109 | - psutil==5.9.0 110 | - pyasn1-modules==0.2.8 111 | - pybullet==3.2.4 112 | - pyopengl==3.1.6 113 | - pyparsing==2.4.7 114 | - python-dateutil==2.8.2 115 | - pytz==2022.1 116 | - pyyaml==6.0 117 | - regex==2022.4.24 118 | - requests-oauthlib==1.3.1 119 | - rsa==4.8 120 | - scipy==1.8.0 121 | - seaborn==0.11.2 122 | - sentry-sdk==1.5.12 123 | - setproctitle==1.2.3 124 | - shortuuid==1.0.9 125 | - smmap==5.0.0 126 | - termcolor==1.1.0 127 | - tqdm==4.64.0 128 | - typing-extensions 129 | - werkzeug==2.1.2 130 | - wrapt==1.12.1 131 | - zipp==3.8.0 -------------------------------------------------------------------------------- /src/algos/__init__.py: -------------------------------------------------------------------------------- 1 | MODEL_CLASSES = { 2 | "DT": None, 3 | "ODT": None, 4 | "UDT": None, 5 | "DummyUDT": None, 6 | "MPDT": None, 7 | "DDT": None, 8 | "MDDT": None, 9 | "DecisionMamba": None, 10 | "DiscreteDecisionMamba": None, 11 | "MDDMamba": None, 12 | "DecisionXLSTM": None, 13 | "DiscreteDecisionXLSTM": None, 14 | "MDDXLSTM": None, 15 | } 16 | 17 | AGENT_CLASSES = { 18 | "DT": None, 19 | "ODT": None, 20 | "UDT": None, 21 | "DummyUDT": None, 22 | "MPDT": None, 23 | "DDT": None, 24 | "MDDT": None, 25 | "DecisionMamba": None, 26 | "DiscreteDecisionMamba": None, 27 | "MDDMamba": None, 28 | "DecisionXLSTM": None, 29 | "DiscreteDecisionXLSTM": None, 30 | "MDDXLSTM": None, 31 | } 32 | 33 | 34 | def get_model_class(kind): 35 | if kind in ["DT", "ODT", "UDT"]: 36 | from .models.online_decision_transformer_model import OnlineDecisionTransformerModel 37 | MODEL_CLASSES[kind] = OnlineDecisionTransformerModel 38 | elif kind in ["DDT"]: 39 | from .models.discrete_decision_transformer_model import DiscreteDTModel 40 | MODEL_CLASSES[kind] = DiscreteDTModel 41 | elif kind in ["MDDT"]: 42 | from .models.multi_domain_discrete_dt_model import MultiDomainDiscreteDTModel 43 | MODEL_CLASSES[kind] = MultiDomainDiscreteDTModel 44 | elif "mamba" in kind.lower(): 45 | from .models.decision_mamba import DecisionMambaModel, DiscreteDecisionMambaModel, MultiDomainDecisionMambaModel 46 | 
MODEL_CLASSES["DecisionMamba"] = DecisionMambaModel 47 | MODEL_CLASSES["DiscreteDecisionMamba"] = DiscreteDecisionMambaModel 48 | MODEL_CLASSES["MDDMamba"] = MultiDomainDecisionMambaModel 49 | elif "xlstm" in kind.lower(): 50 | from .models.decision_xlstm import DecisionXLSTMModel, DiscreteDecisionXLSTMModel, MultiDomainDiscreteDecisionXLSTMModel 51 | MODEL_CLASSES["DecisionXLSTM"] = DecisionXLSTMModel 52 | MODEL_CLASSES["DiscreteDecisionXLSTM"] = DiscreteDecisionXLSTMModel 53 | MODEL_CLASSES["MDDXLSTM"] = MultiDomainDiscreteDecisionXLSTMModel 54 | assert kind in MODEL_CLASSES, f"Unknown kind: {kind}" 55 | return MODEL_CLASSES[kind] 56 | 57 | 58 | def get_agent_class(kind): 59 | assert kind in AGENT_CLASSES, f"Unknown kind: {kind}" 60 | # lazy imports only when needed 61 | if kind in ["DT", "ODT"]: 62 | from .decision_transformer_sb3 import DecisionTransformerSb3 63 | AGENT_CLASSES[kind] = DecisionTransformerSb3 64 | elif kind in ["UDT"]: 65 | from .universal_decision_transformer_sb3 import UDT 66 | AGENT_CLASSES[kind] = UDT 67 | elif kind in ["DDT", "MDDT"]: 68 | from .discrete_decision_transformer_sb3 import DiscreteDecisionTransformerSb3 69 | AGENT_CLASSES[kind] = DiscreteDecisionTransformerSb3 70 | elif kind == "DecisionMamba": 71 | from .decision_mamba import DecisionMamba 72 | AGENT_CLASSES[kind] = DecisionMamba 73 | elif kind in ["DiscreteDecisionMamba", "MDDMamba"]: 74 | from .decision_mamba import DiscreteDecisionMamba 75 | AGENT_CLASSES[kind] = DiscreteDecisionMamba 76 | elif kind in ["DecisionXLSTM"]: 77 | from .decision_xlstm import DecisionXLSTM 78 | AGENT_CLASSES[kind] = DecisionXLSTM 79 | elif kind in ["DiscreteDecisionXLSTM", "MDDXLSTM"]: 80 | from .decision_xlstm import DiscreteDecisionXLSTM 81 | AGENT_CLASSES[kind] = DiscreteDecisionXLSTM 82 | return AGENT_CLASSES[kind] 83 | -------------------------------------------------------------------------------- /src/algos/models/extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, create_mlp 4 | 5 | 6 | class FlattenExtractorWithMLP(FlattenExtractor): 7 | 8 | def __init__(self, observation_space, net_arch=None): 9 | super().__init__(observation_space) 10 | if net_arch is None: 11 | net_arch = [128, 128] 12 | 13 | mlp = create_mlp(self.features_dim, net_arch[-1], net_arch) 14 | self.mlp = nn.Sequential(*mlp) 15 | self._features_dim = net_arch[-1] 16 | 17 | def forward(self, observations): 18 | return self.mlp(self.flatten(observations)) 19 | 20 | 21 | class TextureFeatureExtractor(BaseFeaturesExtractor): 22 | """ 23 | Textures Feature Extractor for Crafter. 
Textures start at dim 21. 24 | """ 25 | def __init__(self, observation_space, features_dim=256, texture_start_dim=21, num_textures=63, 26 | texture_embed_dim=4, textures_shape=(9,7), hidden_dim=192, **kwargs): 27 | super().__init__(observation_space, features_dim=features_dim) 28 | self.texture_start_dim = texture_start_dim 29 | self.texture_emb = nn.Embedding(num_textures + 1, texture_embed_dim) 30 | self.texture_net = nn.Sequential( 31 | nn.Linear(texture_embed_dim * textures_shape[0] * textures_shape[1], hidden_dim), 32 | nn.LeakyReLU(), 33 | nn.Linear(hidden_dim, hidden_dim), 34 | nn.LayerNorm(hidden_dim) 35 | ) 36 | self.out = nn.Linear(texture_start_dim + hidden_dim, self.features_dim) 37 | 38 | def forward(self, observations): 39 | # receives flattened info + textures 40 | info, textures = observations[..., :self.texture_start_dim], observations[..., self.texture_start_dim:].long() 41 | texture_embeds = self.texture_emb(textures).flatten(-2) 42 | texture_features = self.texture_net(texture_embeds) 43 | x = torch.cat((info, texture_features), dim=-1) 44 | return self.out(x) 45 | 46 | 47 | def create_cwnet( 48 | input_dim: int, 49 | output_dim: int, 50 | net_arch=(256,256,256), 51 | # activation_fn=lambda: nn.LeakyReLU(negative_slope=0.2), 52 | activation_fn=nn.LeakyReLU, 53 | squash_output: bool = False, 54 | ): 55 | """ 56 | Creates the same Net as described in https://arxiv.org/pdf/2105.10919.pdf 57 | Basically just adds LayerNorm + Tanh after the first Dense layer. 58 | 59 | :param input_dim: Dimension of the input vector 60 | :param output_dim: Dimension of the output vector 61 | :param net_arch: Architecture of the neural net 62 | It represents the number of units per layer. 63 | The length of this list is the number of layers. 64 | :param activation_fn: The activation function 65 | to use after each layer.
66 | :param squash_output: Whether to squash the output using a Tanh 67 | activation function 68 | :return: 69 | """ 70 | 71 | if len(net_arch) > 0: 72 | modules = [nn.Linear(input_dim, net_arch[0])] 73 | else: 74 | modules = [] 75 | 76 | modules.append(nn.LayerNorm(net_arch[0])) 77 | modules.append(nn.Tanh()) 78 | 79 | for idx in range(len(net_arch) - 1): 80 | modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) 81 | modules.append(activation_fn()) 82 | 83 | if output_dim > 0: 84 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim 85 | modules.append(nn.Linear(last_layer_dim, output_dim)) 86 | if squash_output: 87 | modules.append(nn.Tanh()) 88 | return modules 89 | -------------------------------------------------------------------------------- /configs/env_params/dark_keydoor.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-KeyDoor-Dark-Dense-10x10-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | train_goal_pos: 20 | - [2,0] 21 | - [0,2] 22 | - [1,5] 23 | - [2,2] 24 | - [5,7] 25 | - [9,1] 26 | - [6,9] 27 | - [5,5] 28 | - [1,1] 29 | - [7,9] 30 | - [0,9] 31 | - [3,8] 32 | - [8,5] 33 | - [0,0] 34 | - [8,9] 35 | - [1,3] 36 | - [0,5] 37 | - [0,1] 38 | - [9,5] 39 | - [8,3] 40 | - [4,4] 41 | - [1,2] 42 | - [7,8] 43 | - [9,4] 44 | - [3,7] 45 | - [9,2] 46 | - [9,7] 47 | - [5,6] 48 | - [6,3] 49 | - [4,6] 50 | - [0,8] 51 | - [3,3] 52 | - [4,5] 53 | - [1,9] 54 | - [1,4] 55 | - [9,3] 56 | - [7,3] 57 | - [3,9] 58 | - [2,4] 59 | - [0,6] 60 | - [6,2] 61 | - [2,3] 62 | - [1,8] 63 | - [4,2] 64 | - [8,0] 65 | - [8,6] 66 | - [3,1] 67 | - [6,7] 68 | - [2,7] 69 | - [1,0] 70 | - [4,0] 71 | - [7,0] 72 | - [9,6] 73 | - [8,8] 74 | - [2,6] 75 | - [5,4] 76 | - [7,1] 77 | - [2,9] 78 | - [2,5] 79 | - [4,3] 80 | - [4,1] 81 | - [7,2] 82 | - [0,3] 83 | - [8,1] 84 | - [5,3] 85 | - [7,7] 86 | - [6,1] 87 | - [0,7] 88 | - [2,8] 89 | - [9,9] 90 | - [5,2] 91 | - [4,8] 92 | - [8,2] 93 | - [6,0] 94 | - [7,4] 95 | - [3,2] 96 | - [4,7] 97 | - [7,6] 98 | - [3,6] 99 | - [0,4] 100 | train_key_pos: 101 | - [8,3] 102 | - [5,3] 103 | - [7,0] 104 | - [4,5] 105 | - [4,4] 106 | - [3,9] 107 | - [2,2] 108 | - [8,0] 109 | - [1,0] 110 | - [0,0] 111 | - [1,8] 112 | - [3,0] 113 | - [7,3] 114 | - [3,3] 115 | - [9,0] 116 | - [0,4] 117 | - [7,6] 118 | - [7,7] 119 | - [1,2] 120 | - [3,1] 121 | - [5,5] 122 | - [8,8] 123 | - [2,6] 124 | - [4,2] 125 | - [6,9] 126 | - [1,5] 127 | - [4,0] 128 | - [9,6] 129 | - [0,9] 130 | - [7,2] 131 | - [1,1] 132 | - [4,7] 133 | - [8,5] 134 | - [2,8] 135 | - [9,3] 136 | - [0,5] 137 | - [6,6] 138 | - [6,5] 139 | - [3,5] 140 | - [1,6] 141 | - [4,9] 142 | - [3,4] 143 | - [0,7] 144 | - [9,5] 145 | - [2,7] 146 | - [1,9] 147 | - [8,1] 148 | - [2,5] 149 | - [6,2] 150 | - [1,3] 151 | - [2,4] 152 | - [0,3] 153 | - [1,7] 154 | - [3,8] 155 | - [0,8] 156 | - [7,8] 157 | - [0,6] 158 | - [6,4] 159 | - [3,6] 160 | - [8,9] 161 | - [5,6] 162 | - [9,9] 163 | - [5,4] 164 | - [4,3] 165 | - [5,0] 166 | - [6,7] 167 | - [4,6] 168 | - [6,8] 169 | - [6,1] 170 | - [9,7] 171 | - [7,9] 172 | - [4,1] 173 | - [5,8] 174 | - [4,8] 175 | - [9,8] 176 | - [5,7] 177 | - [7,5] 178 | - [3,2] 179 | - [9,4] 180 | - [5,9] 181 | 182 | # train tasks for sfixed, grand 183 | eval_start_pos: 
[0, 0] 184 | eval_goal_pos: 185 | - [9,0] 186 | - [8,4] 187 | - [3,4] 188 | - [6,5] 189 | - [7,5] 190 | - [5,0] 191 | - [3,5] 192 | - [9,8] 193 | - [8,7] 194 | - [3,0] 195 | - [6,6] 196 | - [5,9] 197 | - [1,7] 198 | - [5,1] 199 | - [1,6] 200 | - [5,8] 201 | - [2,1] 202 | - [4,9] 203 | - [6,4] 204 | - [6,8] 205 | eval_key_pos: 206 | - [6,3] 207 | - [3,1] 208 | - [3,7] 209 | - [2,9] 210 | - [0,1] 211 | - [5,2] 212 | - [2,1] 213 | - [0,2] 214 | - [2,3] 215 | - [8,7] 216 | - [9,1] 217 | - [7,4] 218 | - [8,6] 219 | - [8,2] 220 | - [2,0] 221 | - [6,0] 222 | - [7,1] 223 | - [1,4] 224 | - [9,2] 225 | - [5,1] 226 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | import hydra 4 | import wandb 5 | import omegaconf 6 | from datetime import timedelta 7 | from pathlib import Path 8 | from torch.distributed import init_process_group, destroy_process_group 9 | from src.utils import maybe_split, safe_mean, multiply 10 | 11 | 12 | def setup_logging(config): 13 | config_dict = omegaconf.OmegaConf.to_container(config, resolve=True, throw_on_missing=True) 14 | config_dict["PID"] = os.getpid() 15 | print(f"PID: {os.getpid()}") 16 | # hydra changes working directories automatically 17 | logdir = str(Path.joinpath(Path(os.getcwd()), config.logdir)) 18 | Path(logdir).mkdir(exist_ok=True, parents=True) 19 | print(f"Logdir: {logdir}") 20 | 21 | run = None 22 | if config.use_wandb: 23 | print("Setting up logging to Weights & Biases.") 24 | # make "wandb" path, otherwise WSL might block writing to dir 25 | wandb_path = Path.joinpath(Path(logdir), "wandb") 26 | wandb_path.mkdir(exist_ok=True, parents=True) 27 | wandb_params = omegaconf.OmegaConf.to_container(config.wandb_params, resolve=True, throw_on_missing=True) 28 | key, host = wandb_params.pop("key", None), wandb_params.pop("host", None) 29 | if key is not None and host is not None: 30 | wandb.login(key=key, host=host) 31 | config.wandb_params.update({"key": None, "host": None}) 32 | run = wandb.init(tags=[config.experiment_name], 33 | config=config_dict, **wandb_params) 34 | print(f"Writing Weights & Biases logs to: {str(wandb_path)}") 35 | run.log_code(hydra.utils.get_original_cwd()) 36 | return run, logdir 37 | 38 | 39 | def setup_ddp(): 40 | init_process_group(backend="nccl", timeout=timedelta(minutes=240)) 41 | 42 | 43 | @hydra.main(config_path="configs", config_name="config", version_base="1.1") 44 | def main(config): 45 | ddp = config.get("ddp", False) 46 | if ddp: 47 | setup_ddp() 48 | # make sure only global rank0 writes to wandb 49 | logdir = None 50 | global_rank = int(os.environ["RANK"]) 51 | if global_rank == 0: 52 | run, logdir = setup_logging(config) 53 | else: 54 | run, logdir = setup_logging(config) 55 | print("Config: \n", omegaconf.OmegaConf.to_yaml(config, resolve=True, sort_keys=True)) 56 | 57 | # imports after initializing ddp to avoid fork/spawn issues 58 | from src.envs import make_env 59 | from src.callbacks import make_callbacks 60 | from src.algos.builder import make_agent 61 | 62 | env, eval_env, train_eval_env = make_env(config, logdir) 63 | agent = make_agent(config, env, logdir) 64 | callbacks = make_callbacks(config, env=env, eval_env=eval_env, logdir=logdir, train_eval_env=train_eval_env) 65 | res, score = None, None 66 | try: 67 | res = agent.learn( 68 | **config.run_params, 69 | eval_env=eval_env, 70 | callback=callbacks 71 | ) 72 | except Exception as e: 73 | 
print(f"Exception: {e}") 74 | traceback.print_exc() 75 | finally: 76 | print("Finalizing run...") 77 | if config.use_wandb: 78 | if config.env_params.record: 79 | env.video_recorder.close() 80 | if not ddp or (ddp and global_rank == 0): 81 | run.finish() 82 | wandb.finish() 83 | # return last avg reward for hparam optimization 84 | score = None if res is None else safe_mean([ep_info["r"] for ep_info in res.ep_info_buffer]) 85 | if ddp: 86 | destroy_process_group() 87 | if hasattr(agent, "cache"): 88 | agent.cache.cleanup_cache() 89 | return score 90 | 91 | 92 | if __name__ == "__main__": 93 | omegaconf.OmegaConf.register_new_resolver("maybe_split", maybe_split) 94 | omegaconf.OmegaConf.register_new_resolver("multiply", multiply) 95 | main() 96 | -------------------------------------------------------------------------------- /src/data/mimicgen/README.md: -------------------------------------------------------------------------------- 1 | # Mimicgen 2 | - Paper: https://arxiv.org/pdf/2310.17596 3 | - Code: https://github.com/NVlabs/mimicgen 4 | - Documentation: https://mimicgen.github.io/docs/introduction/overview.html 5 | 6 | ## Installation 7 | Mimicgen requires mujoco, robosuite, robomimic, and robosuite_task_zoo. 8 | 9 | Install `mimicgen` as follows: 10 | ``` 11 | # mujoco 12 | pip install mujoco==2.3.2 13 | 14 | # robosuite 15 | pip install robosuite==1.4.1 16 | # requires 17 | pip install gymnasium==0.28.1 18 | 19 | # git clone, install robomimic: https://github.com/ARISE-Initiative/robomimic 20 | git clone https://github.com/ARISE-Initiative/robomimic.git 21 | cd robomimic 22 | pip install -e . 23 | 24 | # git clone, install: https://mimicgen.github.io/docs/introduction/installation.html 25 | git clone https://github.com/NVlabs/mimicgen.git 26 | cd mimicgen 27 | pip install -e . 28 | 29 | # git clone, install: https://github.com/ARISE-Initiative/robosuite-task-zoo 30 | git clone https://github.com/ARISE-Initiative/robosuite-task-zoo.git 31 | cd robosuite-task-zoo 32 | pip install -e . 33 | pip install mujoco==2.3.2 34 | pip install mujoco_py==2.0.2.5 35 | ``` 36 | 37 | ## Troubleshooting 38 | `egl-probe` may fail. Solve by: 39 | ``` 40 | pip install cmake 41 | ``` 42 | 43 | ## Data download 44 | Download the 26 original `core` datasets provided by the Mimicgen publication: 45 | ``` 46 | # using gdown 47 | pip install gdown 48 | gdown --folder https://drive.google.com/drive/folders/14uywHbSdletLBJUmR8c5UrBUZkALFcUz 49 | # or any of the methods described here: https://mimicgen.github.io/docs/datasets/mimicgen_corl_2023.html 50 | ``` 51 | 52 | ## Data structure 53 | Every `.hdf5` file contains a field `data` and individual fields for each episode. 54 | Each episode contains `states`, `actions`, `rewards`, `dones`, and `obs`. 55 | - `states` contains the simulation state, not the actual continuous observation state. 56 | - `obs` contains the individual observation attributes. 57 | 58 | `mimicgen` uses `robomimic` underneath, which leverages the `EnvRobosuite` env. This class removes a number of fields 59 | from the actual `obs` that would be returned by the original `robosuite` env. Consequently, the data does not contain 60 | these fields, and the interaction environment therefore also has to be an `EnvRobosuite` env. 61 | 62 | In particular, `EnvRobosuite` removes: 63 | - `object-state` containing the state of the object --> added back via `object` field. 64 | - `robot0_proprio-state` containing the state of the robot --> has to be reconstructed from the individual robot states.
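For orientation, here is a minimal sketch of how such an episode file could be inspected with `h5py`, assuming the layout described above (the file name and the image-key suffix are illustrative assumptions, not part of the repository):
```
import h5py
import numpy as np

# file name is an example; the processed core .hdf5 files are assumed to follow the layout above
with h5py.File("stack_d0.hdf5", "r") as f:
    episodes = list(f["data"].keys())              # one entry per episode
    ep = f["data"][episodes[0]]
    actions, rewards, dones = ep["actions"][:], ep["rewards"][:], ep["dones"][:]
    # low-dim observation: concatenate the non-image obs attributes, sorted alphabetically
    obs_keys = sorted(k for k in ep["obs"].keys() if not k.endswith("_image"))
    obs = np.concatenate([ep["obs"][k][:] for k in obs_keys], axis=-1)
    print(episodes[0], obs.shape, actions.shape, rewards.shape, dones.shape)
```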
65 | 66 | The stored dataset additionally contain image fields `agentview_image` and `robot0_eye_in_hand_image`. 67 | Consequently, we sort fields alphabetically and remove the image fields to get the actual observation fields. 68 | 69 | ## Data preparation 70 | Prepare the 26 original `core` datasets used for our experiments: 71 | ``` 72 | # low dim keys, binary reward 73 | python prepare_data.py --add_rtgs --low_dim_keys --compress --sparse_reward --save_dir=./data/mimicgen/core_processed_sparse 74 | ``` 75 | 76 | Then download our generated datasets (additional robot arms) and extract them (e.g., using `untar_files.sh`) to the same directory: 77 | ``` 78 | cd ./data/mimicgen 79 | huggingface-cli download ml-jku/mimicgen59 --local-dir=./mimicgen59 --repo-type dataset 80 | bash untar_files.sh mimicgen59 core_processed_sparse 81 | ``` 82 | 83 | Other splits can be produced similarly: 84 | ``` 85 | # low dim keys only 86 | python prepare_data.py --add_rtgs --low_dim_keys --compress --save_dir=./data/mimicgen/core_processed 87 | 88 | # img observations 89 | python prepare_data.py --add_rtgs --compress --img_key=agentview_image --crop_dim=64 --save_dir=./data/mimicgen/core_processed_agentview 90 | ``` 91 | 92 | ## Data generation 93 | To generate your own datasets using the `mimicgen` framework we refer to: 94 | - https://github.com/NVlabs/mimicgen/blob/ea0988523f468ccf7570475f1906023f854962e9/docs/tutorials/getting_started.md -------------------------------------------------------------------------------- /src/data/atari/README.md: -------------------------------------------------------------------------------- 1 | # Atari data 2 | 3 | ## Overview 4 | For Atari, we make use of the [DQN Replay Dataset](https://research.google/tools/datasets/dqn-replay/). 5 | For details on the dataset see the website. 6 | 7 | For loading the individual episodes, we rely on [d3rlpy](https://github.com/takuseno/d3rlpy) (a popular library for offline RL) and 8 | [d4rl-atari](https://github.com/takuseno/d4rl-atari). 9 | 10 | The individual `.hdf5` episode files adhere to the following folder structure: 11 | ``` 12 | environment family (e.g. atari) 13 | - environment name (e.g. breakout) 14 | -- one .hdf5 file (numbered, zero padded to 9 digits) per episode with fields: states, actions, rewards, dones 15 | ``` 16 | 17 | ## Installation 18 | As `d3rlpy` and `d4rl-atari` have different dependencies than the regular codebase, we use a separate conda environment 19 | to download the datasets. 20 | 21 | Create conda environment: 22 | 23 | ``` 24 | conda create -n atari_data python=3.9 25 | conda activate atari_data 26 | ``` 27 | 28 | Then install the requirements: 29 | ``` 30 | pip install -r requirements.txt 31 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113 32 | ``` 33 | 34 | If problems are encountered with ROM licenses: 35 | ``` 36 | pip install autorom 37 | AutoROM 38 | ``` 39 | 40 | ## Usage: 41 | For each game, the Atari datasets are provided in a [GCP Bucket](https://console.cloud.google.com/storage/browser/atari-replay-datasets) and stored in 50 slices each containing 1M transitions (200M frames + frame-stack=4 --> 50M transistions). 42 | 43 | To collect all 50 slices for Breakout, run: 44 | ``` 45 | python download_atari_datasets.py --save_dir="SAVE_DIR" --num_slices=50 46 | ``` 47 | 48 | The episode quality can be specified using `--quality` ('random', or 'mixed', 'expert'). This is only used if `num_slices < 50`. 
49 | ``` 50 | python download_atari_datasets.py --save_dir="SAVE_DIR" --quality="mixed" --num_slices=49 51 | ``` 52 | 53 | The maximum number of episodes to collect can be specified using `--max_episodes`. 54 | ``` 55 | python download_atari_datasets.py --save_dir="SAVE_DIR" --max_episodes=20000 --quality="mixed" --num_slices=49 56 | ``` 57 | 58 | This creates one `.hdf5` per episode in the specified `save_dir`. A `stats.json` file is written 59 | to the same folder and contains information about the collected episodes (e.g., number of transitions, return/length of episodes). 60 | In addition, `episode_lengths.json` is written to the same directory that contains the episode lengths for each episode. 61 | 62 | There is an option for using differen file formats. Compressed files require less space, but data loading takes longer. 63 | We do not use compression by default. The default file format is `.hdf5`, as data loading is faster. 64 | 65 | Furthermore, return-to-gos can be pre-computed and added to the datasets. This is not done by default. 66 | 67 | ## Collecting datasets 68 | We map actions to the full action space using `a_to_full_space` and add return-to-gos using `add_rtgs`. 69 | The cache dir for the original splits can be specified using the environment variable `D4RL_DATASET_DIR`. 70 | 71 | We collect 5M transitions for each of the 46 games used in [Multi-game Decision Transformer (MDGT)](https://arxiv.org/abs/2205.15241). 72 | ``` 73 | python download_atari_datasets.py --save_dir="DATA_DIR/atari_1M/atari_5M_64rgb" --max_transitions=4000000 --envs pong asterix breakout qbert seaquest alien beam-rider freeway ms-pacman space-invaders amidar assault atlantis bank-heist battle-zone boxing carnival centipede chopper-command crazy-climber demon-attack double-dunk enduro fishing-derby frostbite gopher gravitar hero ice-hockey jamesbond kangaroo krull kung-fu-master name-this-game phoenix pooyan riverraid road-runner robotank star-gunner time-pilot up-n-down video-pinball wizard-of-wor yars-revenge zaxxon --quality="expert" --num_slices=49 --a_to_full_space --add_rtgs --to_rgb --crop_dim=64 74 | ``` 75 | ## Troubleshooting: 76 | In case the data download fails, this may be useful: 77 | - https://github.com/opencv/opencv-python/issues/370#issuecomment-996657018 78 | 79 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mimicgen83.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/mimicgen/core_processed_sparse 2 | names: 3 | # original mimicgen tasks 4 | - coffee_d0 5 | - coffee_d1 6 | - coffee_d2 7 | - coffee_preparation_d0 8 | - coffee_preparation_d1 9 | - hammer_cleanup_d0 10 | - hammer_cleanup_d1 11 | - kitchen_d0 12 | - kitchen_d1 13 | - mug_cleanup_d0 14 | - mug_cleanup_d1 15 | - nut_assembly_d0 16 | - pick_place_d0 17 | - square_d0 18 | - square_d1 19 | - square_d2 20 | - stack_d0 21 | - stack_d1 22 | - stack_three_d0 23 | - stack_three_d1 24 | - threading_d0 25 | - threading_d1 26 | - three_piece_assembly_d0 27 | - three_piece_assembly_d1 28 | # self-generate mimicgen data for other robots 29 | - coffee_D0_robot_IIWA_gripper_Robotiq85Gripper 30 | - coffee_D0_robot_Sawyer_gripper_RethinkGripper 31 | - coffee_D0_robot_UR5e_gripper_Robotiq85Gripper 32 | - coffee_D1_robot_IIWA_gripper_Robotiq85Gripper 33 | - coffee_D1_robot_Sawyer_gripper_RethinkGripper 34 | - coffee_D1_robot_UR5e_gripper_Robotiq85Gripper 35 | - coffee_D2_robot_IIWA_gripper_Robotiq85Gripper 36 
| - coffee_D2_robot_UR5e_gripper_Robotiq85Gripper 37 | - hammer_cleanup_D0_robot_IIWA_gripper_Robotiq85Gripper 38 | - hammer_cleanup_D0_robot_Sawyer_gripper_RethinkGripper 39 | - hammer_cleanup_D0_robot_UR5e_gripper_Robotiq85Gripper 40 | - hammer_cleanup_D1_robot_IIWA_gripper_Robotiq85Gripper 41 | - hammer_cleanup_D1_robot_Sawyer_gripper_RethinkGripper 42 | - hammer_cleanup_D1_robot_UR5e_gripper_Robotiq85Gripper 43 | - kitchen_D0_robot_IIWA_gripper_Robotiq85Gripper 44 | - kitchen_D0_robot_UR5e_gripper_Robotiq85Gripper 45 | - kitchen_D1_robot_UR5e_gripper_Robotiq85Gripper 46 | - mug_cleanup_D0_robot_IIWA_gripper_Robotiq85Gripper 47 | - mug_cleanup_D1_robot_IIWA_gripper_Robotiq85Gripper 48 | - mug_cleanup_D1_robot_UR5e_gripper_Robotiq85Gripper 49 | # - mug_cleanup_O2_robot_UR5e_gripper_Robotiq85Gripper 50 | - nut_assembly_D0_robot_IIWA_gripper_Robotiq85Gripper 51 | - nut_assembly_D0_robot_Sawyer_gripper_RethinkGripper 52 | - nut_assembly_D0_robot_UR5e_gripper_Robotiq85Gripper 53 | - pick_place_D0_robot_IIWA_gripper_Robotiq85Gripper 54 | - pick_place_D0_robot_Sawyer_gripper_RethinkGripper 55 | - pick_place_D0_robot_UR5e_gripper_Robotiq85Gripper 56 | - square_D0_robot_IIWA_gripper_Robotiq85Gripper 57 | - square_D0_robot_Sawyer_gripper_RethinkGripper 58 | - square_D0_robot_UR5e_gripper_Robotiq85Gripper 59 | - square_D1_robot_IIWA_gripper_Robotiq85Gripper 60 | - square_D1_robot_Sawyer_gripper_RethinkGripper 61 | - square_D1_robot_UR5e_gripper_Robotiq85Gripper 62 | - stack_D0_robot_IIWA_gripper_Robotiq85Gripper 63 | - stack_D0_robot_Sawyer_gripper_RethinkGripper 64 | - stack_D0_robot_UR5e_gripper_Robotiq85Gripper 65 | - stack_D1_robot_IIWA_gripper_Robotiq85Gripper 66 | - stack_D1_robot_Sawyer_gripper_RethinkGripper 67 | - stack_D1_robot_UR5e_gripper_Robotiq85Gripper 68 | - stack_three_D0_robot_IIWA_gripper_Robotiq85Gripper 69 | - stack_three_D0_robot_Sawyer_gripper_RethinkGripper 70 | - stack_three_D0_robot_UR5e_gripper_Robotiq85Gripper 71 | - stack_three_D1_robot_IIWA_gripper_Robotiq85Gripper 72 | - stack_three_D1_robot_Sawyer_gripper_RethinkGripper 73 | - stack_three_D1_robot_UR5e_gripper_Robotiq85Gripper 74 | - threading_D0_robot_IIWA_gripper_Robotiq85Gripper 75 | - threading_D0_robot_Sawyer_gripper_RethinkGripper 76 | - threading_D0_robot_UR5e_gripper_Robotiq85Gripper 77 | - threading_D1_robot_IIWA_gripper_Robotiq85Gripper 78 | - threading_D1_robot_Sawyer_gripper_RethinkGripper 79 | - threading_D1_robot_UR5e_gripper_Robotiq85Gripper 80 | - three_piece_assembly_D0_robot_IIWA_gripper_Robotiq85Gripper 81 | - three_piece_assembly_D0_robot_Sawyer_gripper_RethinkGripper 82 | - three_piece_assembly_D0_robot_UR5e_gripper_Robotiq85Gripper 83 | - three_piece_assembly_D1_robot_IIWA_gripper_Robotiq85Gripper 84 | - three_piece_assembly_D1_robot_Sawyer_gripper_RethinkGripper 85 | - three_piece_assembly_D1_robot_UR5e_gripper_Robotiq85Gripper 86 | - three_piece_assembly_D2_robot_IIWA_gripper_Robotiq85Gripper 87 | - three_piece_assembly_D2_robot_Sawyer_gripper_RethinkGripper 88 | - three_piece_assembly_D2_robot_UR5e_gripper_Robotiq85Gripper 89 | -------------------------------------------------------------------------------- /src/algos/discrete_decision_transformer_sb3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .universal_decision_transformer_sb3 import UDT 3 | from .models.model_utils import sample_from_logits 4 | 5 | 6 | class DiscreteDecisionTransformerSb3(UDT): 7 | 8 | def __init__(self, policy, env, loss_fn="ce", rtg_sample_kwargs=None, 
a_sample_kwargs=None, **kwargs): 9 | super().__init__(policy, env, loss_fn=loss_fn, **kwargs) 10 | self.rtg_sample_kwargs = {} if rtg_sample_kwargs is None else rtg_sample_kwargs 11 | self.a_sample_kwargs = a_sample_kwargs 12 | 13 | def get_action_pred(self, policy, states, actions, rewards, returns_to_go, timesteps, attention_mask, 14 | deterministic, prompt, is_eval=False, task_id=None, env_act_dim=None): 15 | inputs = { 16 | "states": states, 17 | "actions": actions, 18 | "rewards": rewards, 19 | "returns_to_go": returns_to_go, 20 | "timesteps": timesteps, 21 | "attention_mask": attention_mask, 22 | "return_dict": True, 23 | "deterministic": deterministic, 24 | "prompt": prompt, 25 | "task_id": task_id, 26 | "ddp_kwargs": self.ddp_kwargs, 27 | "use_inference_cache": self.use_inference_cache, 28 | "past_key_values": self.past_key_values # None by default 29 | } 30 | 31 | # exper-action inference mechanism 32 | if self.target_return_type == "infer": 33 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 34 | policy_output = policy(**inputs) 35 | return_logits = policy_output.return_preds[:, -1] 36 | return_sample = policy.sample_from_rtg_logits(return_logits, **self.rtg_sample_kwargs) 37 | inputs["returns_to_go"][0, -1] = return_sample 38 | 39 | if not self.policy.tok_a_target_only and not self.policy.shared_a_head: 40 | # autoregressive action prediction 41 | # e.g., for discretized continuous action space predict actions dims one after another 42 | act_dim = actions.shape[-1] if env_act_dim is None else env_act_dim 43 | for i in range(act_dim): 44 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 45 | policy_output = policy(**inputs) 46 | if not is_eval and self.num_timesteps % 10000 == 0 and self.log_attn_maps: 47 | self._record_attention_maps(policy_output.attentions, step=self.num_timesteps, prefix="rollout") 48 | if policy_output.cross_attentions is not None: 49 | self._record_attention_maps(policy_output.cross_attentions, step=self.num_timesteps + i, 50 | prefix="rollout_cross", lower_triu=False) 51 | if self.a_sample_kwargs is not None: 52 | action_logits = policy_output.action_logits[0, -1, i] 53 | inputs["actions"][0, -1, i] = sample_from_logits(action_logits, **self.a_sample_kwargs) 54 | else: 55 | inputs["actions"][0, -1, i] = policy_output.action_preds[0, -1, i] 56 | if self.use_inference_cache: 57 | self.past_key_values = policy_output.past_key_values 58 | inputs["past_key_values"] = self.past_key_values 59 | action = inputs["actions"][0, -1] 60 | else: 61 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 62 | policy_output = policy(**inputs) 63 | if self.a_sample_kwargs is not None: 64 | action = sample_from_logits(policy_output.action_logits[0, -1], **self.a_sample_kwargs) 65 | else: 66 | action = policy_output.action_preds[0, -1] 67 | if self.use_inference_cache: 68 | self.past_key_values = policy_output.past_key_values 69 | 70 | if env_act_dim is not None: 71 | action = action[:env_act_dim] 72 | return action, inputs["returns_to_go"][0, -1] if self.target_return_type == "infer" else action 73 | -------------------------------------------------------------------------------- /src/envs/procgen_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from copy import deepcopy 4 | from procgen import ProcgenEnv 5 | from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor, 
VecTransposeImage, VecNormalize 6 | from stable_baselines3.common.env_util import DummyVecEnv 7 | 8 | 9 | def get_procgen_constructor(envid, distribution_mode="easy", time_limit=None, env_kwargs=None): 10 | env_kwargs = dict(env_kwargs) if env_kwargs is not None else {} 11 | num_envs = env_kwargs.pop("num_envs", 1) 12 | norm_reward = env_kwargs.pop("norm_reward", False) 13 | def make(): 14 | env = ProcgenEnv(env_name=envid, num_envs=num_envs, 15 | distribution_mode=distribution_mode, **env_kwargs) 16 | # monitor to obtain ep_rew_mean, ep_rew_len + extract rgb images from dict states 17 | env = CustomVecMonitor(VecExtractDictObs(env, 'rgb'), time_limit=time_limit) 18 | env = VecTransposeImage(env) 19 | if norm_reward: 20 | env = VecNormalize(env, norm_obs=False, norm_reward=True) 21 | env.name = envid 22 | return env 23 | return make 24 | 25 | 26 | class CustomVecMonitor(VecMonitor): 27 | """ 28 | Custom version of VecMonitor that allows for a timelimit. 29 | Once, timelimit is hit, we also need to reset the environment. 30 | We can however, not save the reset state there. 31 | """ 32 | def __init__( 33 | self, 34 | venv, 35 | filename=None, 36 | info_keywords=(), 37 | time_limit=None 38 | ): 39 | super().__init__(venv, filename, info_keywords) 40 | self.time_limit = time_limit 41 | 42 | def step_wait(self): 43 | obs, rewards, dones, infos = self.venv.step_wait() 44 | self.episode_returns += rewards 45 | self.episode_lengths += 1 46 | new_infos = list(infos[:]) 47 | if self.time_limit is not None and (self.episode_lengths >= self.time_limit).any(): 48 | # check if any is over timelimit, if yes, set done 49 | over_time = self.episode_lengths >= self.time_limit 50 | # send action -1 to reset ProcgenEnv: https://github.com/openai/procgen/issues/40#issuecomment-633720234 51 | reset_action = over_time * -1 52 | reset_obs, reset_rewards, reset_done, reset_info = self.venv.step(reset_action) 53 | # get reset observation, ignore rest 54 | obs[over_time] = reset_obs[over_time] 55 | # set done where done or over_time 56 | dones = dones | over_time 57 | 58 | for i in range(len(dones)): 59 | if dones[i]: 60 | info = infos[i].copy() 61 | episode_return = self.episode_returns[i] 62 | episode_length = self.episode_lengths[i] 63 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} 64 | for key in self.info_keywords: 65 | episode_info[key] = info[key] 66 | info["episode"] = episode_info 67 | self.episode_count += 1 68 | self.episode_returns[i] = 0 69 | self.episode_lengths[i] = 0 70 | if self.results_writer: 71 | self.results_writer.write_row(episode_info) 72 | new_infos[i] = info 73 | return obs, rewards, dones, new_infos 74 | 75 | 76 | class CustomDummyVecEnv(DummyVecEnv): 77 | """ 78 | Custom version of DummyVecEnv that allows wrapping ProcgenEnvs. 79 | By default, ProcgenEnvs are vectorized already. 80 | Therefore wrapping different tasks in a single DummyVecEnv fails, due to returning of vectorized infor buffers. 
81 | """ 82 | def step_wait(self): 83 | for env_idx in range(self.num_envs): 84 | action = self.actions[env_idx] 85 | if not isinstance(action, np.ndarray): 86 | action = np.array([action]) 87 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 88 | action 89 | ) 90 | if self.buf_dones[env_idx]: 91 | # save final observation where user can get it, then reset 92 | # self.buf_infos[env_idx]terminal_observation"] = obs 93 | self.buf_infos[env_idx][0]["terminal_observation"] = obs 94 | obs = self.envs[env_idx].reset() 95 | self._save_obs(env_idx, obs) 96 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 97 | -------------------------------------------------------------------------------- /src/callbacks/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import functools 3 | import wandb 4 | from wandb.integration.sb3 import WandbCallback 5 | from stable_baselines3.common.callbacks import CallbackList 6 | from ..algos import AGENT_CLASSES 7 | from ..envs.env_names import ID_TO_NAMES, MT50_ENVS_v2, ATARI_ENVS, DM_CONTROL_ENVS, \ 8 | MINIHACK_ENVS, GYM_ENVS, PROCGEN_ENVS, COMPOSUITE_ENVS, MIMICGEN_ENVS 9 | 10 | 11 | class CustomWandbCallback(WandbCallback): 12 | 13 | def __init__(self, model_sync_wandb=False, **kwargs): 14 | super().__init__(**kwargs) 15 | self.model_sync_wandb = model_sync_wandb 16 | 17 | def save_model(self) -> None: 18 | print(f"Saving model checkpoint to {self.path}") 19 | self.model.save(self.path) 20 | if self.model_sync_wandb: 21 | wandb.save(self.path, base_path=self.model_save_path) 22 | 23 | 24 | def make_callbacks(config, env=None, eval_env=None, logdir=None, train_eval_env=None): 25 | callbacks = [] 26 | if config.use_wandb and logdir is not None: 27 | model_save_path = None 28 | if config.wandb_callback_params.model_save_path is not None: 29 | model_save_path = f"{logdir}/{config.wandb_callback_params.model_save_path}" 30 | if config.get("ddp", False): 31 | global_rank = int(os.environ["RANK"]) 32 | model_save_path = model_save_path if global_rank == 0 else None 33 | callbacks.append( 34 | CustomWandbCallback( 35 | gradient_save_freq=config.wandb_callback_params.gradient_save_freq, 36 | verbose=config.wandb_callback_params.verbose, 37 | model_save_path=model_save_path, 38 | model_sync_wandb=config.wandb_callback_params.get("model_sync_wandb", False), 39 | model_save_freq=config.wandb_callback_params.get("model_save_freq", 0) 40 | ) 41 | ) 42 | if config.eval_params.use_eval_callback: 43 | if config.agent_params.kind in AGENT_CLASSES.keys(): 44 | from .custom_eval_callback import CustomEvalCallback, MultiEnvEvalCallback 45 | if config.env_params.envid not in [*list(ID_TO_NAMES.keys()), *ATARI_ENVS, *MT50_ENVS_v2, 46 | *DM_CONTROL_ENVS, *MINIHACK_ENVS, *GYM_ENVS, 47 | *PROCGEN_ENVS, *COMPOSUITE_ENVS, *MIMICGEN_ENVS, 48 | "Dummy-v1"]: 49 | eval_callback_class = functools.partial(CustomEvalCallback, use_wandb=config.use_wandb) 50 | else: 51 | eval_callback_class = functools.partial(MultiEnvEvalCallback, use_wandb=config.use_wandb) 52 | else: 53 | from stable_baselines3.common.callbacks import EvalCallback 54 | eval_callback_class = EvalCallback 55 | if config.eval_params.max_no_improvement_evals > 0: 56 | from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement 57 | stop_training_callback = StopTrainingOnNoModelImprovement( 58 | 
max_no_improvement_evals=config.eval_params.max_no_improvement_evals, verbose=1) 59 | else: 60 | stop_training_callback = None 61 | eval_callback_kwargs = { 62 | "n_eval_episodes": config.eval_params.n_eval_episodes, "eval_freq": config.eval_params.eval_freq, 63 | "callback_after_eval": stop_training_callback, "deterministic": config.eval_params.deterministic, 64 | "first_step": config.eval_params.get("first_step", True), 65 | "log_eval_trj": config.eval_params.get("log_eval_trj", False), 66 | "n_jobs": config.eval_params.get("n_jobs", 0) 67 | } 68 | if config.eval_params.get("track_gpu_stats", False): 69 | eval_callback_kwargs["track_gpu_stats"] = config.eval_params.track_gpu_stats 70 | if config.eval_params.get("eval_on_train", False): 71 | train_eval_callback = eval_callback_class(eval_env=env, prefix="train_eval", **eval_callback_kwargs) 72 | callbacks.append(train_eval_callback) 73 | if train_eval_env is not None: 74 | train_eval_seeds_callback = eval_callback_class(eval_env=train_eval_env, prefix="train_eval_seeds", 75 | **eval_callback_kwargs) 76 | callbacks.append(train_eval_seeds_callback) 77 | 78 | eval_callback = eval_callback_class(eval_env=eval_env, **eval_callback_kwargs) 79 | callbacks.append(eval_callback) 80 | if hasattr(config.eval_params, "use_valid_callback") and config.eval_params.use_valid_callback: 81 | from .validation_callback import ValidationCallback 82 | valid_callback_kwargs = { 83 | "eval_freq": config.eval_params.eval_freq, 84 | "first_step": config.eval_params.get("first_step", True), 85 | **config.eval_params.get("valid_kwargs", {}) 86 | } 87 | valid_callback = ValidationCallback(**valid_callback_kwargs) 88 | callbacks.append(valid_callback) 89 | return CallbackList(callbacks) 90 | -------------------------------------------------------------------------------- /src/schedulers/visualize_schedulers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | from pathlib import Path 6 | from torch.optim.lr_scheduler import SequentialLR 7 | from lr_schedulers import make_lr_scheduler 8 | 9 | 10 | def visualize_lr_scheduler(lr_scheduler, total_steps=1e6, title="", save_dir=None): 11 | total_steps = int(total_steps) 12 | lrs = [] 13 | for _ in range(total_steps): 14 | lrs.append(lr_scheduler.get_last_lr()) 15 | lr_scheduler.step() 16 | plt.plot(lrs) 17 | plt.title(title) 18 | plt.xlabel("Steps") 19 | plt.ylabel("Learning rate") 20 | if save_dir is not None: 21 | save_dir = Path(save_dir) 22 | save_dir.mkdir(parents=True, exist_ok=True) 23 | plt.savefig(save_dir / f"{title}.png", bbox_inches='tight') 24 | plt.show() 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--save_dir", type=str, default="./figures") 30 | args = parser.parse_args() 31 | 32 | sns.set_style("whitegrid") 33 | net = torch.nn.Linear(10, 10) 34 | total_steps = int(1e6) 35 | lr = 0.0001 36 | eta_min = 0.000001 37 | warmup_steps = 50000 38 | 39 | # cosine 40 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 41 | sched_kwargs = {"eta_min": eta_min, "T_max": total_steps} 42 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=sched_kwargs) 43 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 44 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 45 | visualize_lr_scheduler(lr_scheduler, 
total_steps=total_steps, title="cosine", save_dir=args.save_dir) 46 | 47 | # cosine 48 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 49 | sched_kwargs = {"eta_min": eta_min, "T_max": total_steps / 5} 50 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=sched_kwargs) 51 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 52 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 53 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_2", save_dir=args.save_dir) 54 | 55 | # cosine_restart 56 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 57 | sched_kwargs = {"eta_min": eta_min, "T_0": int(total_steps / 5)} 58 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine_restart", sched_kwargs=sched_kwargs) 59 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 60 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 61 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_restart", save_dir=args.save_dir) 62 | 63 | # cosine_restart 64 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 65 | sched_kwargs = {"eta_min": eta_min, "T_0": int(total_steps / 5), "T_mult": 2} 66 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine_restart", sched_kwargs=sched_kwargs) 67 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 68 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 69 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_restart_2", save_dir=args.save_dir) 70 | 71 | # cyclic 72 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 73 | sched_kwargs = {"base_lr": lr, "max_lr": 0.001, "step_size_up": 1e5} 74 | lr_scheduler = make_lr_scheduler(optimizer, kind="cyclic", sched_kwargs=sched_kwargs) 75 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 76 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 77 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cyclic", save_dir=args.save_dir) 78 | 79 | # step 80 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 81 | sched_kwargs = {"gamma": 0.1, "step_size": 2e5} 82 | lr_scheduler = make_lr_scheduler(optimizer, kind="step", sched_kwargs=sched_kwargs) 83 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 84 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 85 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="step", save_dir=args.save_dir) 86 | 87 | # exponential 88 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 89 | sched_kwargs = {"gamma": eta_min} 90 | lr_scheduler = make_lr_scheduler(optimizer, kind="exp", sched_kwargs=sched_kwargs) 91 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 92 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 93 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="exponential", save_dir=args.save_dir) 94 | -------------------------------------------------------------------------------- /src/envs/hn_scores.py: 
-------------------------------------------------------------------------------- 1 | # Adjusted from: 2 | # - https://github.com/deepmind/dqn_zoo/blob/807379c19f8819e407329ac1b95dcaccb9d536c3/dqn_zoo/atari_data.py 3 | # - https://github.com/etaoxing/multigame-dt/blob/master/atari_data.py 4 | 5 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ============================================================================== 19 | """Utilities to compute human-normalized Atari scores. 20 | The data used in this module is human and random performance data on Atari-57. 21 | It comprises of evaluation scores (undiscounted returns), each averaged 22 | over at least 3 episode runs, on each of the 57 Atari games. Each episode begins 23 | with the environment already stepped with a uniform random number (between 1 and 24 | 30 inclusive) of noop actions. 25 | The two agents are: 26 | * 'random' (agent choosing its actions uniformly randomly on each step) 27 | * 'human' (professional human game tester) 28 | Scores are obtained by averaging returns over the episodes played by each agent, 29 | with episode length capped to 108,000 frames (i.e. timeout after 30 minutes). 30 | The term 'human-normalized' here means a linear per-game transformation of 31 | a game score in such a way that 0 corresponds to random performance and 1 32 | corresponds to human performance. 33 | """ 34 | 35 | import math 36 | from .env_names import ATARI_NAME_TO_ENVID 37 | 38 | 39 | # Game: score-tuple dictionary. Each score tuple contains 40 | # 0: score random (float) and 1: score human (float). 
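# Worked example (illustrative only): using the (random, human) tuple for "breakout" below,
# a raw return of 20.0 maps to (20.0 - 1.7) / (30.5 - 1.7) ≈ 0.64, i.e. the agent closes roughly
# 64% of the gap between random and human play (see get_human_normalized_score at the bottom of this file).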
41 | ENVID_TO_HNS = { 42 | "alien": (227.8, 7127.7), 43 | "amidar": (5.8, 1719.5), 44 | "assault": (222.4, 742.0), 45 | "asterix": (210.0, 8503.3), 46 | "asteroids": (719.1, 47388.7), 47 | "atlantis": (12850.0, 29028.1), 48 | "bank-heist": (14.2, 753.1), 49 | "battle-zone": (2360.0, 37187.5), 50 | "beam-rider": (363.9, 16926.5), 51 | "berzerk": (123.7, 2630.4), 52 | "bowling": (23.1, 160.7), 53 | "boxing": (0.1, 12.1), 54 | "breakout": (1.7, 30.5), 55 | "centipede": (2090.9, 12017.0), 56 | "chopper-command": (811.0, 7387.8), 57 | "crazy-climber": (10780.5, 35829.4), 58 | "defender": (2874.5, 18688.9), 59 | "demon-attack": (152.1, 1971.0), 60 | "double-dunk": (-18.6, -16.4), 61 | "enduro": (0.0, 860.5), 62 | "fishing-derby": (-91.7, -38.7), 63 | "freeway": (0.0, 29.6), 64 | "frostbite": (65.2, 4334.7), 65 | "gopher": (257.6, 2412.5), 66 | "gravitar": (173.0, 3351.4), 67 | "hero": (1027.0, 30826.4), 68 | "ice-hockey": (-11.2, 0.9), 69 | "jamesbond": (29.0, 302.8), 70 | "kangaroo": (52.0, 3035.0), 71 | "krull": (1598.0, 2665.5), 72 | "kung-fu-master": (258.5, 22736.3), 73 | "montezuma-revenge": (0.0, 4753.3), 74 | "ms-pacman": (307.3, 6951.6), 75 | "name-this-game": (2292.3, 8049.0), 76 | "phoenix": (761.4, 7242.6), 77 | "pitfall": (-229.4, 6463.7), 78 | "pong": (-20.7, 14.6), 79 | "private-eye": (24.9, 69571.3), 80 | "qbert": (163.9, 13455.0), 81 | "riverraid": (1338.5, 17118.0), 82 | "road-runner": (11.5, 7845.0), 83 | "robotank": (2.2, 11.9), 84 | "seaquest": (68.4, 42054.7), 85 | "skiing": (-17098.1, -4336.9), 86 | "solaris": (1236.3, 12326.7), 87 | "space-invaders": (148.0, 1668.7), 88 | "star-gunner": (664.0, 10250.0), 89 | "surround": (-10.0, 6.5), 90 | "tennis": (-23.8, -8.3), 91 | "time-pilot": (3568.0, 5229.2), 92 | "tutankham": (11.4, 167.6), 93 | "up-n-down": (533.4, 11693.2), 94 | "venture": (0.0, 1187.5), 95 | # Note the random agent score on Video Pinball is sometimes greater than the 96 | # human score under other evaluation methods. 
97 | "video-pinball": (16256.9, 17667.9), 98 | "wizard-of-wor": (563.5, 4756.5), 99 | "yars-revenge": (3092.9, 54576.9), 100 | "zaxxon": (32.5, 9173.3), 101 | # extracted from procgen paper 102 | "bigfish": (1,40), 103 | "bossfight": (0.5,13), 104 | "caveflyer": (3.5,12), 105 | "chaser": (0.5,13), 106 | "climber": (2,12.6), 107 | "coinrun": (5,10), 108 | "dodgeball": (1.5,19), 109 | "fruitbot": (-1.5,32.4), 110 | "heist": (3.5,10), 111 | "jumper": (3,10), 112 | "leaper": (3,10), 113 | "maze": (5,10), 114 | "miner": (1.5,13), 115 | "ninja": (3.5,10), 116 | "plunder": (4.5,30), 117 | "starpilot": (2.5,64), 118 | } 119 | 120 | # add scores for actual env ids 121 | keys = list(ENVID_TO_HNS.keys()) 122 | for k in keys: 123 | if k not in ATARI_NAME_TO_ENVID: 124 | continue 125 | envid = ATARI_NAME_TO_ENVID[k] 126 | ENVID_TO_HNS[envid] = ENVID_TO_HNS[k] 127 | 128 | 129 | def get_human_normalized_score(game: str, raw_score: float, random_col=0, human_col=1) -> float: 130 | """Converts game score to human-normalized score.""" 131 | game_scores = ENVID_TO_HNS.get(game, (math.nan, math.nan)) 132 | random, human = game_scores[random_col], game_scores[human_col] 133 | return (raw_score - random) / (human - random) 134 | -------------------------------------------------------------------------------- /src/algos/models/multi_domain_discrete_dt_model.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from .image_encoders import make_image_encoder 6 | from .discrete_decision_transformer_model import DiscreteDTModel 7 | from ...tokenizers_custom import make_tokenizer 8 | 9 | 10 | class MultiDomainDiscreteDTModel(DiscreteDTModel): 11 | 12 | def __init__(self, config, observation_space, action_space, action_channels=256, discrete_actions=18, 13 | state_dim=39, image_shape=(1,84,84), **kwargs): 14 | """ 15 | Discrete DT version that supports multi-domain training. 16 | Different domains have different state and action spaces. This model takes care of that. 17 | Input observation space and action space arguments are irrelevant, as they only account for a single train env. 18 | Instead, the class demands an image shape and state dim, which are used to make persistent observation encoders. 19 | All actions need to be discrete to be used with the discrete action embeddings. 20 | This class should only be used in offline setting, as online data collection is not currently supported. 21 | Args: 22 | config: Huggingface config. 23 | observation_space: gym.Space. 24 | action_space: gym.Space 25 | image_shape: Tuple or List. Shape of image observations. 26 | state_dim: Int. Dimension/shape of state observations. 27 | discrete_actions: Int. Defaults to 18 (full Atari action space). Number of discrete actions. 28 | Also used as shift for the action tokenizer. 
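action_channels: Int. Defaults to 256. Number of discrete bins used for tokenized continuous actions;
                together with discrete_actions this forms a shared action vocabulary of size discrete_actions + action_channels.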
29 | 30 | """ 31 | self.discrete_actions = discrete_actions 32 | self.num_actions = discrete_actions + action_channels 33 | super().__init__(config, observation_space, action_space, action_channels=action_channels, **kwargs) 34 | 35 | # make persistent state/image encoders 36 | self.image_shape = image_shape 37 | self.state_dim = state_dim 38 | if self.image_shape is not None: 39 | # overwrite if exists 40 | if self.patch_size is not None: 41 | self.setup_patch_encoder() 42 | else: 43 | self.embed_image = make_image_encoder( 44 | observation_space=gym.spaces.Box(0, 255, self.image_shape, dtype=np.uint8), 45 | features_dim=config.hidden_size, encoder_kwargs=self.encoder_kwargs 46 | ) 47 | if self.state_dim is not None and not self.tokenize_s: 48 | del self.embed_state 49 | self.embed_state = torch.nn.Linear(self.state_dim, config.hidden_size) 50 | 51 | # make action tokenizer with shift 52 | assert self.tokenize_a or self.action_channels == 0, "If not tokenizing, action channels must be 0." 53 | if self.tokenize_a: 54 | a_tok_kind = self.a_tok_kwargs.pop('kind', 'minmax') 55 | # add shift argument to shift tokenization to the right by num of discrete actions 56 | self.action_tokenizer = make_tokenizer( 57 | a_tok_kind, 58 | {'vocab_size': self.action_channels, "shift": self.discrete_actions, **self.a_tok_kwargs} 59 | ) 60 | 61 | # make universal action embeddings 62 | self.action_pad_token = self.num_actions if self.use_action_pad else None 63 | self.embed_action_disc = nn.Embedding( 64 | self.num_actions + 1, config.hidden_size, padding_idx=self.action_pad_token 65 | ) 66 | 67 | def setup_policy(self): 68 | out_dim = self.num_actions 69 | if self.shared_a_head: 70 | # predict all action dimensions at once 71 | out_dim = out_dim * self.config.act_dim 72 | 73 | if self.stochastic_policy: 74 | raise NotImplementedError("Stochastic policy not implemented for multi-domain discrete DT.") 75 | if self.num_task_heads > 1: 76 | self.action_net = nn.ModuleList( 77 | [self.make_head(self.config.hidden_size, out_dim, self.n_layer_head) 78 | for _ in range(self.num_task_heads)] 79 | ) 80 | else: 81 | self.action_net = self.make_head(self.config.hidden_size, out_dim, self.n_layer_head) 82 | 83 | def get_action_from_logits(self, action_logits, is_discrete=False): 84 | if action_logits.shape[-2] == 1 and is_discrete: 85 | # safeguard for discrete action spaces to avoid selecting actions > num discrete actions 86 | # we assume discrete action spaces have action dim of 1 87 | action = torch.argmax(action_logits[..., :self.discrete_actions], dim=-1) 88 | else: 89 | action = torch.argmax(action_logits, dim=-1) 90 | if self.tokenize_a and action.shape[-1] > 1: 91 | action = self.inv_tokenize_actions(action) 92 | if len(action.shape) == 2: 93 | action = action.unsqueeze(0) 94 | return action 95 | 96 | def prepare_action_logits(self, action_logits, is_discrete=False): 97 | if self.tok_a_target_only: 98 | # action_logits contains action dim predictions together --> split up 99 | action_logits = action_logits.reshape((*action_logits.shape[:-1], self.config.act_dim, self.action_channels)) 100 | if self.shared_a_head: 101 | if action_logits.shape[-2] == 1 and is_discrete: 102 | # is discrete 103 | action_logits = action_logits[..., :self.num_actions] 104 | else: 105 | # only preserve predictions from first action dimension --> shared head 106 | orig_shape = action_logits.shape 107 | action_logits = action_logits[:, :, 0].reshape(*orig_shape[:-2], self.config.act_dim, self.num_actions) 108 | return action_logits 109 
| -------------------------------------------------------------------------------- /src/envs/dmcontrol_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import gym 4 | import dmc2gym_custom 5 | from gym import spaces 6 | from stable_baselines3.common.monitor import Monitor 7 | from dm_control import suite 8 | 9 | 10 | def extract_obs_dims(exclude_domains=[]): 11 | obstype_to_dim = {} 12 | for domain_name, task_name in suite.BENCHMARKING: 13 | env = suite.load(domain_name, task_name) 14 | time_step = env.reset() 15 | print(f"{domain_name}-{task_name}", 16 | {k: v.shape for k, v in time_step.observation.items()}, 17 | env.action_spec().shape, 18 | "\n") 19 | if any(domain in domain_name for domain in exclude_domains): 20 | continue 21 | for k, v in time_step.observation.items(): 22 | v = np.array([v]) if np.isscalar(v) else v.ravel() 23 | obstype_to_dim[k] = max(obstype_to_dim.get(k, 0), v.shape[0]) 24 | return obstype_to_dim 25 | 26 | def extract_obstype_to_startidx(obstype_to_dim): 27 | cum_dim = 0 28 | obstype_to_start_idx = {} 29 | for k, v in obstype_to_dim.items(): 30 | obstype_to_start_idx[k] = cum_dim 31 | cum_dim += v 32 | return obstype_to_start_idx 33 | 34 | 35 | DMC_OBSTYPE_TO_DIM = { 36 | 'orientations': 14, 'velocity': 27, 'position': 8, 'touch': 5, 'target_position': 2, 'dist_to_target': 1, 37 | 'joint_angles': 21, 'upright': 1, 'target': 3, 'head_height': 1, 'extremities': 12, 'torso_vertical': 3, 38 | 'com_velocity': 3, 'arm_pos': 16, 'arm_vel': 8, 'hand_pos': 4, 'object_pos': 4, 'object_vel': 3, 'target_pos': 4, 39 | 'orientation': 2, 'to_target': 2, 'joints': 14, 'body_velocities': 45, 'height': 1 40 | } 41 | 42 | DMC_FULL_OBS_DIM = sum(DMC_OBSTYPE_TO_DIM.values()) 43 | 44 | DMC_OBSTYPE_TO_STARTIDX = { 45 | 'orientations': 0, 'velocity': 14, 'position': 41, 'touch': 49, 'target_position': 54, 'dist_to_target': 56, 46 | 'joint_angles': 57, 'upright': 78, 'target': 79, 'head_height': 82, 'extremities': 83, 'torso_vertical': 95, 47 | 'com_velocity': 98, 'arm_pos': 101, 'arm_vel': 117, 'hand_pos': 125, 'object_pos': 129, 'object_vel': 133, 48 | 'target_pos': 136, 'orientation': 140, 'to_target': 142, 'joints': 144, 'body_velocities': 158, 'height': 203 49 | } 50 | 51 | 52 | def map_obs_to_full_space(obs): 53 | dtype = obs.dtype if hasattr(obs, "dtype") else np.float32 54 | full_obs = np.zeros(DMC_FULL_OBS_DIM, dtype=dtype) 55 | for k, v in obs.items(): 56 | start_idx = DMC_OBSTYPE_TO_STARTIDX[k] 57 | v = np.array([v]) if np.isscalar(v) else v.ravel() 58 | full_obs[start_idx: start_idx + v.shape[0]] = v 59 | return full_obs 60 | 61 | 62 | def map_flattened_obs_to_full_space(obs, obs_spec): 63 | if not isinstance(obs, np.ndarray): 64 | obs = np.array(obs) 65 | is_one_dim = len(obs.shape) == 1 66 | if is_one_dim: 67 | obs = np.expand_dims(obs, axis=0) 68 | full_obs = np.zeros((*obs.shape[:-1], DMC_FULL_OBS_DIM)) 69 | flat_start_idx = 0 70 | for k, v in obs_spec.items(): 71 | dim = np.prod(v.shape) if len(v.shape) > 0 else 1 72 | full_start_idx = DMC_OBSTYPE_TO_STARTIDX[k] 73 | full_obs[..., full_start_idx: full_start_idx + dim] = obs[..., flat_start_idx: flat_start_idx + dim] 74 | flat_start_idx += dim 75 | if is_one_dim: 76 | full_obs = full_obs.ravel() 77 | return full_obs 78 | 79 | 80 | class DmcFullObsWrapper(gym.ObservationWrapper): 81 | """ 82 | Converts a given state observation to the full observation space of all DMControl environments. 
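Each observation field is written into a fixed slice (given by DMC_OBSTYPE_TO_STARTIDX) of a zero-initialized
vector of length DMC_FULL_OBS_DIM, so fields that a particular task does not provide remain zero.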
83 | 84 | Unforunately, dmc2gym always flattens the obsevation by default. Therefore, this wrapper should 85 | always be used with dmc2gym custom, which make flattening the observation optional. 86 | 87 | Args: 88 | env: Gym environment. 89 | """ 90 | 91 | def __init__(self, env: gym.Env): 92 | gym.ObservationWrapper.__init__(self, env) 93 | low, high = np.array([-float("inf")] * DMC_FULL_OBS_DIM), np.array([float("inf")] * DMC_FULL_OBS_DIM) 94 | self.observation_space = spaces.Box( 95 | low=low, high=high, dtype=np.float32 96 | ) 97 | 98 | def observation(self, obs): 99 | return map_obs_to_full_space(obs) 100 | 101 | 102 | class GrayscaleWrapper(gym.ObservationWrapper): 103 | """ 104 | Converts a given frame to grayscale. The given frame must be channel last. 105 | 106 | Args: 107 | env: Gym environment. 108 | """ 109 | 110 | def __init__(self, env: gym.Env): 111 | gym.ObservationWrapper.__init__(self, env) 112 | channels, height, width, = env.observation_space.shape 113 | assert channels != 1, "Image is grayscale already." 114 | self.observation_space = spaces.Box( 115 | low=0, high=255, shape=(1, height, width), dtype=env.observation_space.dtype 116 | ) 117 | 118 | def observation(self, frame): 119 | frame = cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2GRAY) 120 | return np.expand_dims(frame, 0) 121 | 122 | 123 | def get_dmcontrol_constructor(envid, env_kwargs=None): 124 | env_kwargs = dict(env_kwargs) if env_kwargs is not None else {} 125 | render_mode = env_kwargs.pop("render_mode", None) 126 | def make(): 127 | domain_name, task_name = envid.split("-") 128 | env = dmc2gym_custom.make(domain_name=domain_name, task_name=task_name, **env_kwargs) 129 | # change envid to make more readable than default in dmc2gym_custom 130 | env.spec.id = f"{domain_name}-{task_name}" 131 | env.name = f"{domain_name}-{task_name}" 132 | if env_kwargs.get("from_pixels", False): 133 | env = GrayscaleWrapper(env) 134 | if not env_kwargs.get("flatten_obs", True): 135 | env = DmcFullObsWrapper(env) 136 | if render_mode is not None: 137 | env.metadata.update({"render.modes": [render_mode]}) 138 | return Monitor(env) 139 | return make 140 | -------------------------------------------------------------------------------- /src/algos/models/rope.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class LlamaRotaryEmbedding(nn.Module): 9 | def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): 10 | super().__init__() 11 | self.scaling_factor = scaling_factor 12 | self.dim = dim 13 | self.max_position_embeddings = max_position_embeddings 14 | self.base = base 15 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) 16 | self.register_buffer("inv_freq", inv_freq, persistent=False) 17 | # For BC we register cos and sin cached 18 | self.max_seq_len_cached = max_position_embeddings 19 | t = torch.arange(self.max_seq_len_cached, dtype=torch.int64).type_as(self.inv_freq) 20 | t = t / self.scaling_factor 21 | freqs = torch.outer(t, self.inv_freq) 22 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 23 | emb = torch.cat((freqs, freqs), dim=-1) 24 | self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) 25 | 
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) 26 | 27 | @property 28 | def sin_cached(self): 29 | print( 30 | "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " 31 | "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" 32 | ) 33 | return self._sin_cached 34 | 35 | @property 36 | def cos_cached(self): 37 | print( 38 | "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " 39 | "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" 40 | ) 41 | return self._cos_cached 42 | 43 | @torch.no_grad() 44 | def forward(self, x, position_ids): 45 | # x: [bs, num_attention_heads, seq_len, head_size] 46 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 47 | position_ids_expanded = position_ids[:, None, :].float() 48 | # Force float32 since bfloat16 loses precision on long contexts 49 | # See https://github.com/huggingface/transformers/pull/29285 50 | # device_type = x.device.type 51 | # device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" 52 | # with torch.autocast(device_type=device_type, enabled=False): 53 | # does not work for some reason, if we compile 54 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) 55 | emb = torch.cat((freqs, freqs), dim=-1) 56 | cos = emb.cos() 57 | sin = emb.sin() 58 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 59 | 60 | 61 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 62 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 63 | 64 | def forward(self, x, position_ids): 65 | # difference to the original RoPE: a scaling factor is aplied to the position ids 66 | position_ids = position_ids.float() / self.scaling_factor 67 | cos, sin = super().forward(x, position_ids) 68 | return cos, sin 69 | 70 | 71 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 72 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 73 | 74 | def forward(self, x, position_ids): 75 | # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length 76 | seq_len = torch.max(position_ids) + 1 77 | if seq_len > self.max_position_embeddings: 78 | base = self.base * ( 79 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 80 | ) ** (self.dim / (self.dim - 2)) 81 | inv_freq = 1.0 / ( 82 | base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) 83 | ) 84 | self.register_buffer("inv_freq", inv_freq, persistent=False) 85 | 86 | cos, sin = super().forward(x, position_ids) 87 | return cos, sin 88 | 89 | 90 | def rotate_half(x): 91 | """Rotates half the hidden dims of the input.""" 92 | x1 = x[..., : x.shape[-1] // 2] 93 | x2 = x[..., x.shape[-1] // 2 :] 94 | return torch.cat((-x2, x1), dim=-1) 95 | 96 | 97 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 98 | """Applies Rotary Position Embedding to the query and key tensors. 99 | 100 | Args: 101 | q (`torch.Tensor`): The query tensor. 102 | k (`torch.Tensor`): The key tensor. 103 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 104 | sin (`torch.Tensor`): The sine part of the rotary embedding. 
105 | position_ids (`torch.Tensor`, *optional*): 106 | Deprecated and unused. 107 | unsqueeze_dim (`int`, *optional*, defaults to 1): 108 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 109 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 110 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 111 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 112 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 113 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 114 | Returns: 115 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 116 | """ 117 | cos = cos.unsqueeze(unsqueeze_dim) 118 | sin = sin.unsqueeze(unsqueeze_dim) 119 | q_embed = (q * cos) + (rotate_half(q) * sin) 120 | k_embed = (k * cos) + (rotate_half(k) * sin) 121 | return q_embed, k_embed 122 | -------------------------------------------------------------------------------- /src/algos/decision_mamba.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from dataclasses import dataclass, field 3 | from .universal_decision_transformer_sb3 import UDT 4 | from .discrete_decision_transformer_sb3 import DiscreteDecisionTransformerSb3 5 | from .models.model_utils import sample_from_logits 6 | 7 | 8 | @dataclass 9 | class InferenceParams: 10 | """Inference parameters that are passed to the main model in order 11 | to efficienly calculate and store the context during inference.""" 12 | 13 | max_seqlen: int 14 | max_batch_size: int 15 | seqlen_offset: int = 0 16 | batch_size_offset: int = 0 17 | key_value_memory_dict: dict = field(default_factory=dict) 18 | lengths_per_sample=None 19 | 20 | def reset(self): 21 | self.max_seqlen = self.max_seqlen 22 | self.max_batch_size = self.max_batch_size 23 | self.seqlen_offset = 0 24 | if self.lengths_per_sample is not None: 25 | self.lengths_per_sample.zero_() 26 | 27 | 28 | class DecisionMamba(UDT): 29 | 30 | def __init__(self, policy, env, use_inference_cache=False, **kwargs): 31 | super().__init__(policy, env, **kwargs) 32 | self.use_inference_cache = use_inference_cache 33 | self.inference_params = None 34 | if self.use_inference_cache: 35 | self.inference_params = InferenceParams( 36 | max_seqlen=self.policy.config.max_length, 37 | max_batch_size=1, 38 | ) 39 | 40 | def get_action_pred(self, policy, states, actions, rewards, returns_to_go, timesteps, attention_mask, 41 | deterministic, prompt, is_eval=False, task_id=None, env_act_dim=None): 42 | if self.use_inference_cache: 43 | # only last step 44 | states, actions, rewards, returns_to_go, timesteps, attention_mask = states[:, -1:], actions[:, -1:],\ 45 | rewards[:, -1:], returns_to_go[:, -1:], timesteps[:, -1:], attention_mask[:, -1:] 46 | 47 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 48 | policy_output = policy( 49 | states=states, 50 | actions=actions, 51 | rewards=rewards, 52 | returns_to_go=returns_to_go, 53 | timesteps=timesteps, 54 | attention_mask=attention_mask, 55 | return_dict=True, 56 | deterministic=deterministic, 57 | prompt=prompt, 58 | task_id=task_id, 59 | ddp_kwargs=self.ddp_kwargs, 60 | inference_params=self.inference_params 61 | ) 62 | 63 | if not is_eval and self.num_timesteps % 
10000 == 0 and self.log_attn_maps: 64 |             self._record_attention_maps(policy_output.attentions, step=self.num_timesteps, prefix="rollout") 65 |             if policy_output.cross_attentions is not None: 66 |                 self._record_attention_maps(policy_output.cross_attentions, step=self.num_timesteps, 67 |                                             prefix="rollout_cross", lower_triu=False) 68 |         action_preds = policy_output.action_preds 69 |         if env_act_dim is not None: 70 |             action_preds = action_preds[..., :env_act_dim] 71 |         return action_preds[0, -1], action_preds[0, -1] 72 | 73 | 74 | class DiscreteDecisionMamba(DiscreteDecisionTransformerSb3, DecisionMamba): 75 | 76 |     def get_action_pred(self, policy, states, actions, rewards, returns_to_go, timesteps, attention_mask, 77 |                         deterministic, prompt, is_eval=False, task_id=None, env_act_dim=None): 78 |         inputs = { 79 |             "states": states, 80 |             "actions": actions, 81 |             "rewards": rewards, 82 |             "returns_to_go": returns_to_go, 83 |             "timesteps": timesteps, 84 |             "attention_mask": attention_mask, 85 |             "return_dict": True, 86 |             "deterministic": deterministic, 87 |             "prompt": prompt, 88 |             "task_id": task_id, 89 |             "ddp_kwargs": self.ddp_kwargs, 90 |             "inference_params": self.inference_params 91 |         } 92 | 93 |         if self.use_inference_cache: 94 |             inputs.update({"states": states[:, -1:], "actions": actions[:, -1:], "rewards": rewards[:, -1:], 95 |                            "returns_to_go": returns_to_go[:, -1:], "timesteps": timesteps[:, -1:], 96 |                            "attention_mask": attention_mask[:, -1:]}) 97 | 98 |         # expert action inference mechanism (predict a return token first, then condition on it) 99 |         if self.target_return_type == "infer": 100 |             with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 101 |                 policy_output = policy(**inputs) 102 |             return_logits = policy_output.return_preds[:, -1] 103 |             return_sample = policy.sample_from_rtg_logits(return_logits, **self.rtg_sample_kwargs) 104 |             inputs["returns_to_go"][0, -1] = return_sample 105 | 106 |         # autoregressive action prediction 107 |         # e.g., for a discretized continuous action space, each action dim is predicted one after another 108 |         act_dim = actions.shape[-1] if env_act_dim is None else env_act_dim 109 |         for i in range(act_dim): 110 |             with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 111 |                 policy_output = policy(**inputs) 112 | 113 |             if not is_eval and self.num_timesteps % 10000 == 0 and self.log_attn_maps: 114 |                 self._record_attention_maps(policy_output.attentions, step=self.num_timesteps, prefix="rollout") 115 |                 if policy_output.cross_attentions is not None: 116 |                     self._record_attention_maps(policy_output.cross_attentions, step=self.num_timesteps + i, 117 |                                                 prefix="rollout_cross", lower_triu=False) 118 |             if self.a_sample_kwargs is not None: 119 |                 action_logits = policy_output.action_logits[0, -1, i] 120 |                 inputs["actions"][0, -1, i] = sample_from_logits(action_logits, **self.a_sample_kwargs) 121 |             else: 122 |                 inputs["actions"][0, -1, i] = policy_output.action_preds[0, -1, i] 123 | 124 |         action = inputs["actions"][0, -1] 125 |         if env_act_dim is not None: 126 |             action = action[:act_dim] 127 |         return action, inputs["returns_to_go"][0, -1] if self.target_return_type == "infer" else action 128 | -------------------------------------------------------------------------------- /src/buffers/multi_domain_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import collections 3 | from .trajectory_buffer import TrajectoryReplayBuffer 4 | from .samplers import DomainWeightedRandomSampler, DistributedSamplerWrapper, MixedBatchRandomSampler 5 | from
.trajectory import Trajectory 6 | 7 | 8 | class MultiDomainTrajectoryReplayBuffer(TrajectoryReplayBuffer): 9 | 10 |     def __init__(self, buffer_size, observation_space, action_space, mixed=False, mixed_weighted=False, 11 |                  domain_weights=None, **kwargs): 12 |         """ 13 |         A trajectory replay buffer that can handle trajectories from multiple domains. 14 |         Different domains have different observation spaces and action spaces. 15 |         Data is loaded from different data paths. 16 |         When sampling, each batch can only contain trajectories from one domain (or from domains with the same shapes). 17 |         Otherwise the batch collating will fail. This buffer should only be used in offline mode for pre-training. 18 |         It assumes that all trajectories from a particular domain are either in memory or on disk, but not a mixture of both. 19 | 20 |         Args: 21 |             buffer_size (int): size of the buffer 22 |             observation_space (gym.Space): observation space 23 |             action_space (gym.Space): action space 24 |             mixed: Bool. Whether batches contain sequences from multiple domains. 25 |             mixed_weighted: Bool. Whether to weight the samples in each batch per domain and by length. 26 |             domain_weights: None or Dict. Fixed per-domain sampling weights; computed from the data if None. 27 | 28 |         """ 29 |         super().__init__(buffer_size, observation_space, action_space, **kwargs) 30 |         # domain specific 31 |         self.mixed = mixed 32 |         self.domain_id = 0 33 |         self.domain_to_indices = collections.defaultdict(list) 34 |         self.valid_domain_to_indices = collections.defaultdict(list) 35 |         self.task_to_domain = {} 36 |         self.domain_weights = domain_weights 37 |         self.fixed_domain_weights = domain_weights is not None 38 |         self.mixed_weighted = mixed_weighted if domain_weights is None else True 39 |         self.domain_names = None 40 | 41 |     def compute_trajectory_probs(self, top_k=5, weight_by="len"): 42 |         if self.mixed: 43 |             return super().compute_trajectory_probs(top_k=top_k, weight_by=weight_by) 44 |         # only supports weighting by len or uniform weighting 45 |         if self.trj_ds_has_changed or self.trajectory_probs is None: 46 |             self.trajectory_probs = {} 47 |             if weight_by == "uniform": 48 |                 for i, indices in self.domain_to_indices.items(): 49 |                     num_trjs = len(indices) 50 |                     self.trajectory_probs[i] = [1 / num_trjs] * num_trjs 51 |             elif weight_by == "len": 52 |                 for i, indices in self.domain_to_indices.items(): 53 |                     trj_lens = [len(self.trajectories[idx]) if isinstance(self.trajectories[idx], Trajectory) 54 |                                 else self.trajectory_lengths[str(self.trajectories[idx])] 55 |                                 for idx in indices] 56 |                     total_samples = sum(trj_lens) 57 |                     self.trajectory_probs[i] = [l / total_samples for l in trj_lens] 58 |             else: 59 |                 raise NotImplementedError() 60 |         return self.trajectory_probs 61 | 62 |     def make_sampler(self, dataset, trajectory_probs, batch_size): 63 |         if self.mixed: 64 |             return super().make_sampler(dataset, trajectory_probs, batch_size) 65 | 66 |         mult = 100 if not self.ddp else 10 67 |         batch_size = batch_size if not self.ddp else batch_size * int(os.environ["WORLD_SIZE"]) 68 | 69 |         if self.mixed_weighted: 70 |             # mix batches in proportion to domain size 71 |             if (self.trj_ds_has_changed or self.domain_weights is None) and not self.fixed_domain_weights: 72 |                 total_samples_per_domain = {} 73 |                 for i, indices in self.domain_to_indices.items(): 74 |                     total_samples_per_domain[i] = sum([ 75 |                         len(self.trajectories[idx]) if isinstance(self.trajectories[idx], Trajectory) 76 |                         else self.trajectory_lengths[str(self.trajectories[idx])] 77 |                         for idx in indices 78 |                     ]) 79 |                 total_samples = sum(total_samples_per_domain.values()) 80 |                 self.domain_weights = {i:
total_samples_per_domain[i] / total_samples for i in total_samples_per_domain} 81 | sampler = MixedBatchRandomSampler(weights=trajectory_probs, domain_weights=self.domain_weights, 82 | batch_size=batch_size, num_samples=len(dataset) * mult, replacement=True) 83 | else: 84 | sampler = DomainWeightedRandomSampler(weights=trajectory_probs, batch_size=batch_size, 85 | num_samples=len(dataset) * mult, replacement=True) 86 | if self.ddp: 87 | sampler = DistributedSamplerWrapper(sampler) 88 | return sampler 89 | 90 | def init_buffer_from_dataset(self, paths): 91 | assert isinstance(paths, (list, tuple, dict)) 92 | if isinstance(paths, dict): 93 | self.domain_names = list(paths.keys()) 94 | paths = list(paths.values()) 95 | else: 96 | self.domain_names = list(range(len(paths))) 97 | for i, p in enumerate(paths): 98 | self.domain_id = i 99 | start_idx = len(self) 100 | valid_start_idx = len(self._valid_trajectories) 101 | self.task_id = max(self.task_to_trj.keys()) + 1 if len(self.task_to_trj) > 0 else 0 102 | super().init_buffer_from_dataset(p) 103 | end_idx = len(self) 104 | self.domain_to_indices[i] = list(range(start_idx, end_idx)) 105 | self.valid_domain_to_indices[i] = list(range(valid_start_idx, len(self._valid_trajectories))) 106 | 107 | def init_from_existing_buffer(self, buffer, validation=False): 108 | super().init_from_existing_buffer(buffer, validation=validation) 109 | self.domain_to_indices = buffer.domain_to_indices if not validation else buffer.valid_domain_to_indices 110 | self.domain_weights = buffer.domain_weights 111 | self.fixed_domain_weights = buffer.fixed_domain_weights 112 | self.mixed_weighted = buffer.mixed_weighted 113 | self.task_to_domain = buffer.task_to_domain 114 | self.domain_names = buffer.domain_names 115 | 116 | def get_current_domain_name(self, accumulation_steps): 117 | idx = max(self.num_sampled_batches - 1, 0) % accumulation_steps 118 | domain_name = self.domain_names[idx] 119 | return domain_name 120 | -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/wrappers.py: -------------------------------------------------------------------------------- 1 | from gym import core, spaces 2 | from dm_control import suite 3 | from dm_env import specs 4 | import numpy as np 5 | 6 | 7 | def _spec_to_box(spec, dtype): 8 | def extract_min_max(s): 9 | assert s.dtype == np.float64 or s.dtype == np.float32 10 | dim = np.int(np.prod(s.shape)) 11 | if type(s) == specs.Array: 12 | bound = np.inf * np.ones(dim, dtype=np.float32) 13 | return -bound, bound 14 | elif type(s) == specs.BoundedArray: 15 | zeros = np.zeros(dim, dtype=np.float32) 16 | return s.minimum + zeros, s.maximum + zeros 17 | 18 | mins, maxs = [], [] 19 | for s in spec: 20 | mn, mx = extract_min_max(s) 21 | mins.append(mn) 22 | maxs.append(mx) 23 | low = np.concatenate(mins, axis=0).astype(dtype) 24 | high = np.concatenate(maxs, axis=0).astype(dtype) 25 | assert low.shape == high.shape 26 | return spaces.Box(low, high, dtype=dtype) 27 | 28 | 29 | def _flatten_obs(obs): 30 | obs_pieces = [] 31 | for v in obs.values(): 32 | flat = np.array([v]) if np.isscalar(v) else v.ravel() 33 | obs_pieces.append(flat) 34 | return np.concatenate(obs_pieces, axis=0) 35 | 36 | 37 | class DMCWrapper(core.Env): 38 | def __init__( 39 | self, 40 | domain_name, 41 | task_name, 42 | task_kwargs=None, 43 | visualize_reward={}, 44 | from_pixels=False, 45 | height=84, 46 | width=84, 47 | camera_id=0, 48 | frame_skip=1, 49 | environment_kwargs=None, 50 | channels_first=True, 51 | 
flatten_obs=True, 52 | deterministic=False 53 | ): 54 | assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' 55 | self._domain_name = domain_name 56 | self._task_name = task_name 57 | self._task_kwargs = task_kwargs 58 | self._visualize_reward = visualize_reward 59 | self._from_pixels = from_pixels 60 | self._height = height 61 | self._width = width 62 | self._camera_id = camera_id 63 | self._frame_skip = frame_skip 64 | self._channels_first = channels_first 65 | self._flatten_obs = flatten_obs 66 | self.deterministic = deterministic 67 | self._environment_kwargs = environment_kwargs 68 | 69 | # create task 70 | self._env = suite.load( 71 | domain_name=domain_name, 72 | task_name=task_name, 73 | task_kwargs=task_kwargs, 74 | visualize_reward=visualize_reward, 75 | environment_kwargs=environment_kwargs 76 | ) 77 | 78 | # true and normalized action spaces 79 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) 80 | self._norm_action_space = spaces.Box( 81 | low=-1.0, 82 | high=1.0, 83 | shape=self._true_action_space.shape, 84 | dtype=np.float32 85 | ) 86 | 87 | # create observation space 88 | if from_pixels: 89 | shape = [3, height, width] if channels_first else [height, width, 3] 90 | self._observation_space = spaces.Box( 91 | low=0, high=255, shape=shape, dtype=np.uint8 92 | ) 93 | else: 94 | self._observation_space = _spec_to_box( 95 | self._env.observation_spec().values(), 96 | np.float64 97 | ) 98 | 99 | self._state_space = _spec_to_box( 100 | self._env.observation_spec().values(), 101 | np.float64 102 | ) 103 | 104 | self.current_state = None 105 | 106 | # set seed 107 | self._seed = task_kwargs.get('random', 1) 108 | self.seed(seed=self._seed) 109 | 110 | def __getattr__(self, name): 111 | return getattr(self._env, name) 112 | 113 | def __reduce__(self): 114 | return (self.__class__, (self._domain_name, self._task_name, self._task_kwargs, 115 | self._visualize_reward, self._from_pixels, self._height, 116 | self._width, self._camera_id, self._frame_skip, 117 | self._environment_kwargs, self._channels_first, 118 | self._flatten_obs, self.deterministic)) 119 | 120 | def _get_obs(self, time_step): 121 | if self._from_pixels: 122 | obs = self.render( 123 | height=self._height, 124 | width=self._width, 125 | camera_id=self._camera_id 126 | ) 127 | if self._channels_first: 128 | obs = obs.transpose(2, 0, 1).copy() 129 | else: 130 | obs = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 131 | return obs 132 | 133 | def _convert_action(self, action): 134 | action = action.astype(np.float64) 135 | true_delta = self._true_action_space.high - self._true_action_space.low 136 | norm_delta = self._norm_action_space.high - self._norm_action_space.low 137 | action = (action - self._norm_action_space.low) / norm_delta 138 | action = action * true_delta + self._true_action_space.low 139 | action = action.astype(np.float32) 140 | return action 141 | 142 | @property 143 | def observation_space(self): 144 | return self._observation_space 145 | 146 | @property 147 | def state_space(self): 148 | return self._state_space 149 | 150 | @property 151 | def action_space(self): 152 | return self._norm_action_space 153 | 154 | @property 155 | def reward_range(self): 156 | return 0, self._frame_skip 157 | 158 | def seed(self, seed): 159 | self._true_action_space.seed(seed) 160 | self._norm_action_space.seed(seed) 161 | self._observation_space.seed(seed) 162 | 163 | def step(self, action): 164 | assert 
self._norm_action_space.contains(action) 165 | action = self._convert_action(action) 166 | assert self._true_action_space.contains(action) 167 | reward = 0 168 | extra = {'internal_state': self._env.physics.get_state().copy()} 169 | 170 | for _ in range(self._frame_skip): 171 | time_step = self._env.step(action) 172 | reward += time_step.reward or 0 173 | done = time_step.last() 174 | if done: 175 | break 176 | obs = self._get_obs(time_step) 177 | self.current_state = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 178 | extra['discount'] = time_step.discount 179 | return obs, reward, done, extra 180 | 181 | def reset(self): 182 | if self.deterministic: 183 | self._env.task._random = np.random.RandomState(self._seed) 184 | time_step = self._env.reset() 185 | self.current_state = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 186 | obs = self._get_obs(time_step) 187 | return obs 188 | 189 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 190 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 191 | height = height or self._height 192 | width = width or self._width 193 | camera_id = camera_id or self._camera_id 194 | return self._env.physics.render( 195 | height=height, width=width, camera_id=camera_id 196 | ) -------------------------------------------------------------------------------- /src/buffers/buffer_utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def discount_cumsum(x, gamma): 8 | new_x = np.zeros_like(x) 9 | new_x[-1] = x[-1] 10 | for t in reversed(range(x.shape[0] - 1)): 11 | new_x[t] = x[t] + gamma * new_x[t + 1] 12 | return new_x 13 | 14 | def discount_cumsum_np(x, gamma): 15 | # much faster version of the above 16 | new_x = np.zeros_like(x) 17 | rev_cumsum = np.cumsum(np.flip(x, 0)) 18 | new_x = np.flip(rev_cumsum * gamma ** np.arange(0, x.shape[0]), 0) 19 | new_x = np.ascontiguousarray(new_x).astype(np.float32) 20 | return new_x 21 | 22 | 23 | def discount_cumsum_torch(x, gamma): 24 | new_x = torch.zeros_like(x) 25 | rev_cumsum = torch.cumsum(torch.flip(x, [0]), 0) 26 | new_x = torch.flip(rev_cumsum * gamma ** torch.arange(0, x.shape[0], device=x.device), [0]) 27 | new_x = new_x.contiguous().to(dtype=torch.float32) 28 | return new_x 29 | 30 | 31 | def compute_rtg_from_target(x, target_return): 32 | new_x = np.zeros_like(x) 33 | new_x[0] = target_return 34 | for i in range(1, x.shape[0]): 35 | new_x[i] = min(new_x[i - 1] - x[i - 1], target_return) 36 | return new_x 37 | 38 | 39 | def split_train_valid(trajectories, p=1, trj_len_dict=None, seed=1): 40 | # split trajectories into train and validation sets, by trj len weights 41 | if trj_len_dict is not None: 42 | trj_lens = [trj_len_dict[str(t)] for t in trajectories] 43 | else: 44 | trj_lens = [len(t["observations"]) for t in trajectories] 45 | p_train = 1 - p 46 | total_samples = sum(trj_lens) 47 | trajectory_probs = [l / total_samples for l in trj_lens] 48 | # always performs same split via random state 49 | random_state = np.random.RandomState(seed=seed) 50 | idx = random_state.choice(len(trajectories), size=int(len(trajectories) * p_train), 51 | p=trajectory_probs, replace=False) 52 | train_trjs = [trajectories[i] for i in idx] 53 | idx = set(idx) 54 | valid_trjs = [trajectories[i] for i in range(len(trajectories)) if i not in idx] 55 | return train_trjs, valid_trjs 56 | 57 | 
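# --- Illustrative sketch (added for clarity; not part of the original module) ---
# discount_cumsum_np matches the loop-based discount_cumsum when gamma == 1
# (plain returns-to-go), which is how Trajectory.compute_returns_to_go calls it.
# For gamma < 1 the two differ, because the vectorized version applies the
# discount once per position rather than per offset from the current timestep.
def _check_discount_cumsum_equivalence():
    # hypothetical helper name; both calls yield [6., 5., 3.] here
    x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    assert np.allclose(discount_cumsum(x, 1.0), discount_cumsum_np(x, 1.0))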
58 | def filter_top_p_trajectories(trajectories, top_p=1, epname_to_return=None, bottom=False): 59 | start = len(trajectories) - int(len(trajectories) * top_p) 60 | if epname_to_return is None: 61 | if hasattr(trajectories[0], "rewards"): 62 | def sort_fn(x): return np.array(x.rewards).sum() 63 | else: 64 | def sort_fn(x): return np.array(x.get("rewards")).sum() 65 | else: 66 | def sort_fn(x): return epname_to_return[str(x)] 67 | sorted_trajectories = sorted(trajectories, key=sort_fn, reverse=bottom) 68 | return sorted_trajectories[start:] 69 | 70 | 71 | def filter_trajectories_uniform(trajectories, p=1): 72 | # sample uniformly with trj len weights 73 | trj_lens = [len(t["observations"]) for t in trajectories] 74 | total_samples = sum(trj_lens) 75 | trajectory_probs = [l / total_samples for l in trj_lens] 76 | idx = np.random.choice(len(trajectories), size=int(len(trajectories) * p), p=trajectory_probs, replace=False) 77 | return [trajectories[i] for i in idx] 78 | 79 | 80 | def filter_trajectories_first(trajectories, p=1): 81 | return trajectories[:int(len(trajectories) * p)] 82 | 83 | 84 | def filter_trajectories_last(trajectories, p=1): 85 | return trajectories[int(len(trajectories) * p): ] 86 | 87 | 88 | def load_npz(path, start_idx=None, end_idx=None): 89 | returns_to_go = None 90 | with np.load(path, mmap_mode="r" if start_idx and end_idx else None) as trj: 91 | if start_idx is not None and end_idx is not None: 92 | # subtrajectory only 93 | observations, actions, rewards = trj["states"][start_idx: end_idx].astype(np.float32), \ 94 | trj["actions"][start_idx: end_idx].astype(np.float32), trj["rewards"][start_idx: end_idx].astype(np.float32) 95 | if "returns_to_go" in trj: 96 | returns_to_go = trj["returns_to_go"][start_idx: end_idx].astype(np.float32) 97 | else: 98 | # fully trajectory 99 | observations, actions, rewards = trj["states"], trj["actions"], trj["rewards"], 100 | if "returns_to_go" in trj: 101 | returns_to_go = trj["returns_to_go"].astype(np.float32) 102 | dones = np.array([trj["dones"]]) 103 | return observations, actions, rewards, dones, returns_to_go 104 | 105 | 106 | def load_hdf5(path, start_idx=None, end_idx=None, img_is_encoded=False): 107 | returns_to_go, dones = None, None 108 | with h5py.File(path, "r") as f: 109 | if start_idx is not None and end_idx is not None: 110 | # subtrajectory only 111 | if img_is_encoded: 112 | observations = f['states_encoded'][start_idx: end_idx] 113 | else: 114 | observations = f['states'][start_idx: end_idx] 115 | actions = f['actions'][start_idx: end_idx] 116 | rewards = f['rewards'][start_idx: end_idx] 117 | if "returns_to_go" in f: 118 | returns_to_go = f["returns_to_go"][start_idx: end_idx] 119 | if "dones" in f: 120 | try: 121 | dones = f['dones'][start_idx: end_idx] 122 | except Exception as e: 123 | pass 124 | else: 125 | # fully trajectory 126 | if img_is_encoded: 127 | observations = f['states_encoded'][:] 128 | else: 129 | observations = f['states'][:] 130 | actions = f['actions'][:] 131 | rewards = f['rewards'][:] 132 | if "returns_to_go" in f: 133 | returns_to_go = f["returns_to_go"][:] 134 | if "dones" in f: 135 | try: 136 | dones = f['dones'][:] 137 | except Exception as e: 138 | pass 139 | if dones is None: 140 | dones = np.array([f['dones'][()]]) 141 | return observations, actions, rewards, dones, returns_to_go 142 | 143 | 144 | def append_to_hdf5(path, new_vals, compress_kwargs=None): 145 | compress_kwargs = {"compression": "gzip", "compression_opts": 1} if compress_kwargs is None \ 146 | else compress_kwargs 147 
| # open in append mode, add new vals 148 | with h5py.File(str(path), 'a') as f: 149 | for k, v in new_vals.items(): 150 | if k in f: 151 | del f[k] 152 | f.create_dataset(k, data=v, **compress_kwargs) 153 | 154 | 155 | def load_pkl(path, start_idx=None, end_idx=None): 156 | returns_to_go = None 157 | with open(path, "rb") as f: 158 | trj = pickle.load(f) 159 | if start_idx is not None and end_idx is not None: 160 | # subtrajectory only 161 | observations, actions, rewards = trj["states"][start_idx: end_idx], \ 162 | trj["actions"][start_idx: end_idx], trj["rewards"][start_idx: end_idx] 163 | if "returns_to_go" in trj: 164 | returns_to_go = trj["returns_to_go"][start_idx: end_idx] 165 | else: 166 | # fully trajectory 167 | observations, actions, rewards = trj["states"], trj["actions"], trj["rewards"], 168 | if "returns_to_go" in trj: 169 | returns_to_go = trj["returns_to_go"] 170 | dones = np.array([trj["dones"]]) 171 | return observations, actions, rewards, dones, returns_to_go 172 | 173 | 174 | def compute_start_end_context_idx(idx, seq_len, cache_len, future_cache_len, full_context_len=True, dynamic_len=False): 175 | start = max(0, idx - cache_len) 176 | end = min(seq_len, idx + future_cache_len) 177 | if dynamic_len: 178 | start = np.random.randint(start, idx + 1) 179 | end = np.random.randint(idx, end + 1) 180 | elif full_context_len: 181 | total_cache_len = cache_len + future_cache_len 182 | if end - start < total_cache_len: 183 | if start > 0: 184 | start -= total_cache_len - (end - start) 185 | else: 186 | end += total_cache_len - (end - start) 187 | start = max(0, start) 188 | end = min(seq_len, end) 189 | return start, end 190 | -------------------------------------------------------------------------------- /src/algos/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import omegaconf 3 | import numpy as np 4 | from pathlib import Path 5 | from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise 6 | from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape 7 | from transformers import DecisionTransformerConfig 8 | from . 
import get_model_class, get_agent_class, AGENT_CLASSES 9 | from ..buffers import make_buffer_class 10 | 11 | 12 | def make_agent(config, env, logdir): 13 | state_dim = get_obs_shape(env.observation_space)[0] 14 | act_dim = orig_act_dim = get_action_dim(env.action_space) 15 | agent_params_dict = omegaconf.OmegaConf.to_container(config.agent_params, resolve=True, throw_on_missing=True) 16 | agent_kind = agent_params_dict.pop("kind") 17 | agent_load_path = agent_params_dict.pop("load_path", None) 18 | agent_load_path = Path(agent_load_path["dir_path"]) / agent_load_path["file_name"] \ 19 | if isinstance(agent_load_path, dict) else agent_load_path 20 | if agent_kind in AGENT_CLASSES.keys(): 21 | if agent_kind in ["MDDT", "DDT", "MDDXLSTM", "MDDMamba"]: 22 | # https://github.com/pytorch/pytorch/issues/11201#issuecomment-421146936 23 | import torch.multiprocessing 24 | torch.multiprocessing.set_sharing_strategy('file_system') 25 | if config.get("ddp", False): 26 | import torch.multiprocessing 27 | torch.multiprocessing.set_start_method('spawn', force=True) 28 | local_rank = int(os.environ["LOCAL_RANK"]) 29 | torch.cuda.set_device(f"cuda:{local_rank}") 30 | 31 | # prespecified state/action dims in case of mixed spaces 32 | max_state_dim, max_act_dim = config.agent_params.replay_buffer_kwargs.get("max_state_dim", None), \ 33 | config.agent_params.replay_buffer_kwargs.get("max_act_dim", None) 34 | if max_state_dim is not None: 35 | state_dim = max_state_dim 36 | if max_act_dim is not None: 37 | act_dim = max_act_dim 38 | 39 | # state/action projections for randomization 40 | s_proj_dim, a_proj_dim = config.agent_params.get("s_proj_dim", None), \ 41 | config.agent_params.get("a_proj_dim", None) 42 | if s_proj_dim is not None: 43 | state_dim = s_proj_dim 44 | if a_proj_dim is not None: 45 | act_dim = a_proj_dim 46 | 47 | # huggingface specific params 48 | agent_huggingface_params = agent_params_dict.pop("huggingface") 49 | config_class = DecisionTransformerConfig 50 | if "mamba" in agent_kind.lower(): 51 | from .models.decision_mamba import MambaConfig 52 | config_class = MambaConfig 53 | elif "xlstm" in agent_kind.lower(): 54 | from .models.decision_xlstm import xLSTMConfig 55 | config_class = xLSTMConfig 56 | dt_config = config_class( 57 | state_dim=state_dim, 58 | act_dim=act_dim, 59 | **agent_huggingface_params 60 | ) 61 | 62 | # model specific params 63 | model_kwargs = agent_params_dict.pop("model_kwargs", {}) 64 | if max_act_dim is not None: 65 | model_kwargs["max_act_dim"] = max_act_dim 66 | if a_proj_dim is not None: 67 | model_kwargs["orig_act_dim"] = orig_act_dim 68 | 69 | # exploration specific params 70 | action_noise = make_action_noise(act_dim, agent_params_dict) 71 | 72 | # replay buffer class 73 | replay_buffer_class = make_buffer_class(agent_params_dict["replay_buffer_kwargs"].pop("kind", "default")) 74 | 75 | # compose additional agent kwargs 76 | target_return = config.env_params.target_return 77 | reward_scale = config.env_params.reward_scale 78 | add_agent_kwargs = { 79 | "device": config.device, 80 | "seed": config.seed, 81 | "action_noise": action_noise, 82 | "load_path": agent_load_path, 83 | "replay_buffer_class": replay_buffer_class, 84 | "ddp": config.get("ddp", False), 85 | "tensorboard_log": logdir if config.use_wandb else None, 86 | "target_return": target_return / reward_scale if isinstance(reward_scale, (int, float)) \ 87 | and isinstance(target_return, (int, float)) else target_return, 88 | "reward_scale": reward_scale if isinstance(reward_scale, (int, float)) else 
dict(reward_scale), 89 | } 90 | 91 | # make DT model 92 | policy = get_model_class(agent_kind)( 93 | dt_config, env.observation_space, env.action_space, 94 | stochastic_policy=agent_params_dict["stochastic_policy"], 95 | **model_kwargs 96 | ) 97 | 98 | # make DT agent 99 | agent = get_agent_class(agent_kind)( 100 | policy, 101 | env, 102 | **add_agent_kwargs, 103 | **agent_params_dict 104 | ) 105 | elif agent_kind in ["SAC"]: 106 | from stable_baselines3 import SAC 107 | policy, policy_kwargs = agent_params_dict.pop("policy"), agent_params_dict.pop("policy_kwargs", {}) 108 | extra_encoder = agent_params_dict.pop("extra_encoder") 109 | share_features_extractor = agent_params_dict.pop("share_features_extractor") 110 | features_extractor_arch = agent_params_dict.pop("features_extractor_arch") 111 | if extra_encoder: 112 | from src.algos.models.extractors import FlattenExtractorWithMLP 113 | policy_kwargs.update({"features_extractor_class": FlattenExtractorWithMLP, 114 | "share_features_extractor": share_features_extractor, 115 | "features_extractor_kwargs": {"net_arch": features_extractor_arch}}) 116 | agent = SAC(policy=policy, 117 | env=env, 118 | device=config.device, 119 | seed=config.seed, 120 | tensorboard_log=logdir if config.use_wandb else None, 121 | verbose=1, 122 | policy_kwargs=policy_kwargs, 123 | **agent_params_dict) 124 | print(agent.policy) 125 | elif agent_kind == "TD3": 126 | from stable_baselines3 import TD3 127 | policy = agent_params_dict.pop("policy") 128 | agent = TD3(policy=policy, 129 | env=env, 130 | device=config.device, 131 | seed=config.seed, 132 | tensorboard_log=logdir if config.use_wandb else None, 133 | verbose=1, 134 | action_noise=NormalActionNoise(mean=np.zeros(act_dim), sigma=0.1 * np.ones(act_dim)), 135 | **agent_params_dict) 136 | print(agent.policy) 137 | elif agent_kind in ["PPO", "DQN"]: 138 | policy = agent_params_dict.pop("policy") 139 | if agent_kind == "PPO": 140 | from .ppo_with_buffer import PPOWithBuffer 141 | agent_class = PPOWithBuffer 142 | elif agent_kind == "DQN": 143 | from stable_baselines3 import DQN 144 | agent_class = DQN 145 | agent = agent_class(policy=policy, 146 | env=env, 147 | device=config.device, 148 | seed=config.seed, 149 | tensorboard_log=logdir if config.use_wandb else None, 150 | verbose=1, 151 | **agent_params_dict) 152 | print(agent.policy) 153 | else: 154 | raise NotImplementedError 155 | return agent 156 | 157 | 158 | def make_action_noise(act_dim, agent_params_dict): 159 | action_noise_std = agent_params_dict.pop("action_noise_std", None) 160 | ou_noise = agent_params_dict.pop("ou_noise", False) 161 | if ou_noise: 162 | return OrnsteinUhlenbeckActionNoise(mean=np.zeros(act_dim), 163 | sigma=action_noise_std * np.ones(act_dim)) \ 164 | if action_noise_std is not None else None 165 | else: 166 | return NormalActionNoise(mean=np.zeros(act_dim), sigma=action_noise_std * np.ones(act_dim)) \ 167 | if action_noise_std is not None else None 168 | -------------------------------------------------------------------------------- /src/callbacks/validation_callback.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import collections 3 | import torch 4 | from tqdm import tqdm 5 | from stable_baselines3.common.callbacks import EventCallback 6 | from stable_baselines3.common.logger import HumanOutputFormat 7 | from ..buffers.buffer_utils import filter_top_p_trajectories 8 | 9 | 10 | class ValidationCallback(EventCallback): 11 | """ 12 | Callback to compute loss metrics on 
validation set every n steps. 13 | 14 | """ 15 | def __init__( 16 | self, 17 | eval_freq=10000, 18 | n_batches=2000, 19 | first_step=True, 20 | prefix="valid", 21 | splits=["full", "top_50", "bottom_50"], 22 | **kwargs 23 | ): 24 | super().__init__(**kwargs) 25 | self.first_step = first_step 26 | self.prefix = prefix 27 | self.eval_freq = eval_freq 28 | self.n_batch = n_batches 29 | self.splits = splits 30 | 31 | def init_callback(self, model) -> None: 32 | super().init_callback(model) 33 | if self.callback is not None: 34 | self.callback.init_callback(self.model) 35 | self._setup_validation_buffer() 36 | # configure Logger to display more than 36 characters --> this kills runs due to duplicate keys 37 | # increase max_length for now. 38 | for format in self.logger.output_formats: 39 | if isinstance(format, HumanOutputFormat): 40 | format.max_length = 96 41 | 42 | def _setup_validation_buffer(self): 43 | buffer_class = self.model.replay_buffer_class 44 | original_buffer = self.model.replay_buffer 45 | self.validation_buffer = buffer_class( 46 | self.model.buffer_size, 47 | self.model.observation_space, 48 | self.model.action_space, 49 | **self.model.replay_buffer_kwargs, 50 | ) 51 | assert original_buffer._valid_trajectories is not None 52 | # extract validation set from original buffer 53 | self.validation_buffer.init_from_existing_buffer(original_buffer, validation=True) 54 | 55 | def _on_step(self) -> bool: 56 | continue_training = True 57 | if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0: 58 | print("Validating...") 59 | self.logger.dump(self.num_timesteps) 60 | # compute metrics for given splits --> e.g., all, top50%, bottom50% 61 | self.compute_metrics_for_all_splits() 62 | self.logger.dump(self.num_timesteps) 63 | if hasattr(self.model, "ddp") and self.model.ddp: 64 | torch.distributed.barrier() 65 | return continue_training 66 | 67 | def _on_training_start(self): 68 | if self.first_step: 69 | # Do an initial validation before training 70 | print("Initial validation...") 71 | self._on_step() 72 | 73 | def compute_metrics_for_all_splits(self): 74 | original_trjs = copy.deepcopy(self.validation_buffer._trajectories) 75 | if hasattr(self.validation_buffer, "domain_to_indices"): 76 | original_domain_to_indices = copy.deepcopy(self.validation_buffer.domain_to_indices) 77 | for split in self.splits: 78 | n_batch = self.n_batch 79 | if split != "full": 80 | # no need to filter for full split 81 | top_bottom, p = split.split("_") 82 | p = int(p) / 100 83 | n_batch = round(n_batch * p) 84 | # for each only keep top/bottom p trajectories 85 | filtered_trjs, filtered_domain_to_indices = self.extract_filtered_trajectories(top_bottom, p) 86 | self.validation_buffer._trajectories = filtered_trjs 87 | self.reset_valid_buffer() 88 | if filtered_domain_to_indices is not None: 89 | # for multidomain buffer 90 | self.validation_buffer.domain_to_indices = filtered_domain_to_indices 91 | # compute metrics for the split 92 | self.compute_metrics_for_single_split(split, n_batch=n_batch) 93 | # reset trajectory split to original 94 | self.validation_buffer._trajectories = original_trjs 95 | self.reset_valid_buffer() 96 | if hasattr(self.validation_buffer, "domain_to_indices"): 97 | self.validation_buffer.domain_to_indices = original_domain_to_indices 98 | 99 | def extract_filtered_trajectories(self, top_bottom, p): 100 | filtered_trjs = collections.deque(maxlen=self.validation_buffer.buffer_size) 101 | filtered_domain_to_indices = None 102 | if hasattr(self.validation_buffer, 
"domain_to_indices"): 103 | filtered_domain_to_indices = collections.defaultdict(list) 104 | for task_id, trjs in self.validation_buffer.task_to_trj.items(): 105 | task_trjs = filter_top_p_trajectories(trjs, top_p=p, 106 | epname_to_return=self.validation_buffer.trj_to_return, 107 | bottom=top_bottom == "bottom") 108 | filtered_trjs += task_trjs 109 | if hasattr(self.validation_buffer, "domain_to_indices"): 110 | # multidomain buffer 111 | domain = self.validation_buffer.task_to_domain[task_id] 112 | start_idx = filtered_domain_to_indices[domain][-1] + 1 \ 113 | if len(filtered_domain_to_indices[domain]) > 0 else 0 114 | domain_indices = list(range(start_idx, start_idx + len(task_trjs))) 115 | filtered_domain_to_indices[domain] += domain_indices 116 | return filtered_trjs, filtered_domain_to_indices 117 | 118 | def compute_metrics_for_single_split(self, split_prefix, n_batch): 119 | for _ in tqdm(range(n_batch), desc=f"Validating {split_prefix}"): 120 | metrics = self._compute_metrics() 121 | for k, v in metrics.items(): 122 | self.logger.record_mean(f"{self.prefix}/{split_prefix}/{k}", v) 123 | if self.model.accumulation_steps > 1 and hasattr(self.validation_buffer, "domain_names") \ 124 | and self.validation_buffer.domain_names is not None: 125 | domain_name = self.validation_buffer.get_current_domain_name(self.model.accumulation_steps) 126 | self.logger.record_mean(f"{self.prefix}/{split_prefix}/{domain_name}/{k}", v) 127 | 128 | @torch.no_grad() 129 | def _compute_metrics(self): 130 | # get batch 131 | observations, actions, next_observations, rewards, rewards_to_go, timesteps, attention_mask, \ 132 | dones, task_ids, trj_ids, action_targets, action_mask, prompt, _, trj_seeds = self.model.sample_batch( 133 | self.model.batch_size, buffer=self.validation_buffer 134 | ) 135 | with torch.autocast(device_type='cuda', dtype=self.model.amp_dtype, enabled=self.model.use_amp): 136 | # compute model output 137 | policy_output = self.model.policy( 138 | states=observations, 139 | actions=actions, 140 | rewards=rewards, 141 | returns_to_go=rewards_to_go, 142 | timesteps=timesteps.long(), 143 | attention_mask=attention_mask, 144 | return_dict=True, 145 | with_log_probs=self.model.stochastic_policy, 146 | deterministic=False, 147 | prompt=prompt, 148 | task_id=self.model.current_task_id_tensor, 149 | ddp_kwargs=self.model.ddp_kwargs, 150 | ) 151 | # compute loss 152 | _, loss_dict = self.model.compute_policy_loss( 153 | policy_output, action_targets, attention_mask, 0, 154 | ent_tuning=False, return_targets=rewards_to_go, 155 | reward_targets=rewards, state_targets=observations, dones=dones, 156 | timesteps=timesteps, next_states=next_observations, action_mask=action_mask 157 | ) 158 | return loss_dict 159 | 160 | def reset_valid_buffer(self): 161 | self.validation_buffer.trj_ds_has_changed = True 162 | self.validation_buffer.trj_loader = None 163 | self.validation_buffer.trj_dataset = None 164 | self.validation_buffer.num_sampled_batches = 0 165 | -------------------------------------------------------------------------------- /src/buffers/trajectory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .buffer_utils import discount_cumsum_np, compute_rtg_from_target 3 | 4 | 5 | class Trajectory: 6 | 7 | def __init__(self, obs_shape, action_dim, max_len=1024, task_id=0, trj_id=0, trj_seed=0, 8 | relative_pos_embds=False, handle_timeout_termination=True, 9 | sample_full_seqs_only=False, last_seq_only=False, init_trj_buffers=True, 10 | 
episodic=False): 11 |         self.obs_shape = obs_shape 12 |         self.action_dim = action_dim 13 |         self.max_len = max_len 14 |         self.relative_pos_embds = relative_pos_embds 15 |         self.handle_timeout_termination = handle_timeout_termination 16 |         self.sample_full_seqs_only = sample_full_seqs_only 17 |         self.task_id = task_id 18 |         self.trj_id = trj_id 19 |         self.trj_seed = trj_seed 20 |         self.last_seq_only = last_seq_only 21 |         self.init_trj_buffers = init_trj_buffers 22 |         self.episodic = episodic 23 |         if self.init_trj_buffers: 24 |             assert obs_shape is not None and action_dim is not None, "obs_shape and action_dim must be provided" 25 |             self.observations = np.zeros((self.max_len, ) + self.obs_shape, dtype=np.float32) 26 |             self.next_observations = np.zeros((self.max_len, ) + self.obs_shape, dtype=np.float32) 27 |             self.actions = np.zeros((self.max_len, self.action_dim), dtype=np.float32) 28 |             self.rewards = np.zeros((self.max_len), dtype=np.float32) 29 |             self.timesteps = np.zeros((self.max_len), dtype=np.float32) 30 |             self.timeouts = np.zeros((self.max_len), dtype=np.float32) 31 |         self.pos = 0 32 |         self.full = False 33 |         self.returns_to_go = None 34 | 35 |     def add(self, obs, next_obs, action, reward, done, infos=None): 36 |         self.observations[self.pos] = np.array(obs).copy() 37 |         self.next_observations[self.pos] = np.array(next_obs).copy() 38 |         self.actions[self.pos] = np.array(action).copy() 39 |         self.rewards[self.pos] = np.array(reward).copy() 40 |         self.timesteps[self.pos] = np.array(self.pos).copy() 41 |         if self.handle_timeout_termination: 42 |             # we assume there is only one environment --> index 0 of infos 43 |             self.timeouts[self.pos] = np.array(infos[0].get("TimeLimit.truncated", False)) 44 |         self.pos += 1 45 |         if done: 46 |             self.add_dones() 47 |         if self.pos == self.max_len: 48 |             self.full = True 49 |         return self.full 50 | 51 |     def add_full_trj(self, obs, next_obs, action, reward, done, task_id, trj_id, returns_to_go=None, trj_seed=None): 52 |         self.pos = len(obs) 53 |         if self.episodic: 54 |             total_reward = np.sum(reward) 55 |             reward = np.zeros_like(reward) 56 |             reward[-1] = total_reward 57 |         if self.init_trj_buffers: 58 |             # buffers already exist, populate them with the given trajectory 59 |             self.observations[:self.pos] = obs 60 |             if next_obs is not None: 61 |                 self.next_observations[:self.pos] = next_obs 62 |             self.actions[:self.pos] = action 63 |             self.rewards[:self.pos] = reward 64 |             self.timesteps[:self.pos] = np.arange(0, self.pos) 65 |         else: 66 |             # buffers do not exist, assign using the given trajectory 67 |             self.observations = obs 68 |             self.next_observations = next_obs 69 |             self.actions = action 70 |             self.rewards = reward 71 |             self.timesteps = np.arange(0, self.pos) 72 |             self.timeouts = np.zeros((self.pos), dtype=np.float32) 73 |         self.add_dones(is_done=done[-1]) 74 |         self.full = True 75 |         self.task_id = task_id 76 |         self.trj_id = trj_id 77 |         self.trj_seed = trj_seed 78 |         self.returns_to_go = returns_to_go 79 | 80 |     def sample(self, context_len=1, full_trj=False, end_idx=None, sample_start=False): 81 |         """ 82 |         Samples a trajectory from the buffer. 83 | 84 |         It is important to sample the end_idx first. 85 |         Otherwise we may get trajectories that can't actually happen during evaluation: 86 |         E.g., for context_len=20, we could have start_idx=5, end_idx=10 --> can't happen. 87 |         If we sample end_idx first: end_idx=10 and start_idx=max(10 - context_len, 0), which is a valid trajectory. 88 | 89 |         Args: 90 |             context_len (int, optional): The length of the context to include in the trajectory. Defaults to 1.
91 | full_trj (bool, optional): Whether to sample the full trajectory. Defaults to False. 92 | end_idx (int, optional): The index of the end of the trajectory. Defaults to None. 93 | 94 | Returns: 95 | tuple: A tuple containing the sampled trajectory. 96 | """ 97 | if full_trj: 98 | start, end = 0, self.pos 99 | elif end_idx is not None: 100 | end = min(end_idx, self.pos) 101 | if sample_start: 102 | start = np.random.randint(max(end - context_len, 0), end, size=1)[0] 103 | else: 104 | start = max(0, end - context_len) 105 | else: 106 | end = np.random.randint(1, self.pos, size=1)[0] if self.pos > 1 else 1 107 | if sample_start: 108 | start = np.random.randint(max(end - context_len, 0), end, size=1)[0] 109 | else: 110 | start = max(0, end - context_len) 111 | if self.last_seq_only and start < context_len: 112 | # ensure that, the agent also has the possibility to see the first n steps of a trajectory 113 | end = np.random.randint(start + 1, min(start + context_len, self.pos)) 114 | if self.sample_full_seqs_only and (end - start) < context_len: 115 | # ensure that only full sequences are sampled 116 | residual = context_len - (end - start) 117 | if context_len > self.pos: 118 | start = 0 119 | end = self.pos 120 | elif start - residual >= 0: 121 | start -= residual 122 | elif end + residual < self.pos: 123 | end += residual 124 | return self._get_samples(start, end) 125 | 126 | def _get_samples(self, start, end): 127 | timesteps = self.timesteps[start: end] 128 | dones = self.dones[start: end] 129 | if self.relative_pos_embds: 130 | timesteps = np.arange(len(timesteps)) 131 | if self.handle_timeout_termination: 132 | dones = (dones * (1 - self.timeouts[start: end])) 133 | obs = self.observations[start: end] 134 | return obs, \ 135 | self.next_observations[start: end] if self.next_observations is not None else np.zeros_like(obs), \ 136 | self.actions[start: end], self.rewards[start: end], \ 137 | self.returns_to_go[start: end], \ 138 | timesteps, dones, self.task_id, self.trj_id, self.trj_seed 139 | 140 | def prune_trajectory(self): 141 | # to avoid OOM issues. 
142 | self.observations = self.observations[:self.pos] 143 | self.next_observations = self.next_observations[:self.pos] if self.next_observations is not None else None 144 | self.actions = self.actions[:self.pos] 145 | self.rewards = self.rewards[:self.pos] 146 | self.timesteps = self.timesteps[:self.pos] 147 | self.timeouts = self.timeouts[:self.pos] 148 | 149 | def setup_final_trj(self, target_return=None, compute_stats=True): 150 | self.compute_returns_to_go(target_return=target_return) 151 | if compute_stats: 152 | self.compute_mean_reward() 153 | self.compute_std_reward() 154 | self.prune_trajectory() 155 | 156 | def compute_returns_to_go(self, target_return=None): 157 | if self.returns_to_go is not None: 158 | # was already initialized when adding full trajectory 159 | return 160 | if target_return is not None: 161 | self.returns_to_go = compute_rtg_from_target(self.rewards, target_return) 162 | self.total_return = self.rewards.sum() 163 | else: 164 | self.returns_to_go = discount_cumsum_np(self.rewards[:self.pos], 1) 165 | self.total_return = self.returns_to_go[0] 166 | 167 | def compute_mean_reward(self): 168 | self.mean_reward = self.rewards[:self.pos].mean() 169 | 170 | def compute_std_reward(self): 171 | self.std_reward = self.rewards[:self.pos].std() 172 | 173 | def add_dones(self, is_done=True): 174 | self.dones = np.zeros(self.pos) 175 | if is_done: 176 | self.dones[-1] = 1 177 | 178 | def size(self): 179 | return self.pos 180 | 181 | def __len__(self): 182 | return self.pos 183 | --------------------------------------------------------------------------------
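# --- Usage sketch (editor's addition, not part of the repository) -------------
# Minimal example of how a Trajectory is filled and sampled, based only on the
# methods shown above. The shapes, the import path, and all variable names
# below are illustrative assumptions.
import numpy as np
from src.buffers.trajectory import Trajectory  # assumed import path from the repo layout

obs = np.random.randn(100, 4).astype(np.float32)    # 100 steps, obs_shape=(4,)
acts = np.random.randn(100, 2).astype(np.float32)   # action_dim=2
rews = np.random.randn(100).astype(np.float32)
dones = np.zeros(100, dtype=np.float32)
dones[-1] = 1.0

trj = Trajectory(obs_shape=(4,), action_dim=2, max_len=1024, init_trj_buffers=True)
trj.add_full_trj(obs, None, acts, rews, dones, task_id=0, trj_id=0, trj_seed=0)
trj.setup_final_trj()  # computes undiscounted returns-to-go and prunes the buffers to length 100

# sample a context window of up to 20 steps ending at a random position
obs_w, next_obs_w, acts_w, rews_w, rtg_w, t_w, dones_w, task_id, trj_id, trj_seed = trj.sample(context_len=20)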