├── src ├── __init__.py ├── data │ └── __init__.py ├── envs │ ├── __init__.py │ ├── hn_scores.py │ ├── procgen_utils.py │ ├── dmcontrol_utils.py │ └── dn_scores.py ├── augmentations │ ├── __init__.py │ └── augs.py ├── callbacks │ ├── __init__.py │ └── builder.py ├── utils │ ├── __init__.py │ ├── loss_functions.py │ ├── debug.py │ └── misc.py ├── exploration │ ├── __init__.py │ └── adaptive_param_noise.py ├── schedulers │ ├── __init__.py │ ├── lr_schedulers.py │ ├── schedulers.py │ └── visualize_schedulers.py ├── buffers │ ├── __init__.py │ ├── dataloaders.py │ ├── cache_dataset.py │ ├── buffer_utils.py │ └── trajectory.py ├── algos │ ├── models │ │ ├── __init__.py │ │ ├── rms_norm.py │ │ ├── token_learner.py │ │ ├── model_utils.py │ │ ├── extractors.py │ │ ├── universal_decision_transformer_model.py │ │ └── rope.py │ ├── __init__.py │ └── discrete_decision_transformer_sb3.py ├── tokenizers_custom │ ├── base_tokenizer.py │ ├── __init__.py │ ├── mu_law_tokenizer.py │ └── minmax_tokenizer.py └── optimizers │ └── __init__.py ├── configs ├── run_params │ ├── base.yaml │ ├── evaluate.yaml │ ├── finetune.yaml │ ├── pretrain.yaml │ └── pretrain_icl.yaml ├── agent_params │ ├── model_kwargs │ │ ├── default.yaml │ │ ├── dark_room.yaml │ │ ├── procgen.yaml │ │ ├── mt_disc.yaml │ │ └── dmcontrol.yaml │ ├── data_paths │ │ ├── mt45_v2.yaml │ │ ├── dmcontrol11_icl.yaml │ │ ├── procgen12.yaml │ │ ├── names │ │ │ ├── dmcontrol11.yaml │ │ │ └── mt45_v2.yaml │ │ ├── mazerunner15x15.yaml │ │ ├── dark_room_10x10_train.yaml │ │ ├── dark_room_20x20_train.yaml │ │ ├── dark_room_40x20_train.yaml │ │ ├── dark_keydoor_10x10_train.yaml │ │ ├── dark_keydoor_20x20_train.yaml │ │ └── dark_keydoor_40x20_train.yaml │ ├── lr_sched_kwargs │ │ ├── cosine_restart.yaml │ │ ├── cyclic.yaml │ │ └── cosine.yaml │ ├── ppo.yaml │ ├── replay_buffer_kwargs │ │ ├── dmcontrol_icl.yaml │ │ └── single_domain_disc.yaml │ ├── huggingface │ │ ├── dt_huge.yaml │ │ ├── dt_large.yaml │ │ ├── dt_larger.yaml │ │ ├── dt_medium.yaml │ │ ├── dt_small.yaml │ │ ├── dt_small_64.yaml │ │ ├── dt_large_64.yaml │ │ ├── dt_largeplus_64.yaml │ │ ├── dt_medium_64.yaml │ │ ├── dt_mediumplus_64.yaml │ │ └── dt_hugeplus.yaml │ ├── odt.yaml │ ├── ppo_procgen.yaml │ ├── ddt.yaml │ ├── retriever_kwargs │ │ ├── discrete_s_a.yaml │ │ ├── discrete_r_rtg.yaml │ │ ├── discrete_r_rtg_mazerunner.yaml │ │ ├── discrete_s_r_rtg.yaml │ │ ├── discrete_s_r_rtg_20x20.yaml │ │ ├── discrete_s_r_rtg_40x20.yaml │ │ ├── cont_r_rtg_dmc.yaml │ │ └── cont_r_rtg.yaml │ ├── cdt_pretrain_disc.yaml │ ├── dt_gridworlds.yaml │ ├── ad_gridworlds.yaml │ ├── radt_icl.yaml │ ├── radt_mazerunner.yaml │ ├── radt_disc_icl.yaml │ └── radt_procgen.yaml ├── eval_params │ ├── base.yaml │ ├── pretrain.yaml │ ├── pretrain_disc.yaml │ ├── pretrain_icl.yaml │ ├── pretrain_icl_mt.yaml │ └── pretrain_icl_grids.yaml ├── wandb_callback_params │ └── pretrain.yaml ├── env_params │ ├── mujoco_gym.yaml │ ├── dmcontrol_icl.yaml │ ├── mt45_icl.yaml │ ├── dmcontrol_icl_eval.yaml │ ├── procgen.yaml │ ├── procgen_eval.yaml │ ├── mazerunner.yaml │ ├── dark_room.yaml │ ├── dark_room_20x20.yaml │ ├── dark_room_40x20.yaml │ ├── dark_keydoor.yaml │ ├── dark_keydoor_20x20.yaml │ └── dark_keydoor_40x20.yaml └── config.yaml ├── figures └── radt.png ├── .gitmodules ├── dmc2gym_custom ├── setup.py ├── README.md └── dmc2gym_custom │ ├── __init__.py │ └── wrappers.py ├── LICENSE ├── requirements.txt ├── evaluate.py ├── .gitignore ├── precompute_img_embeds.py ├── environment.yaml └── main.py /src/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_env 2 | -------------------------------------------------------------------------------- /src/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augs import make_augmentations -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_callbacks 2 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import maybe_split, safe_mean 2 | -------------------------------------------------------------------------------- /configs/run_params/base.yaml: -------------------------------------------------------------------------------- 1 | total_timesteps: 1e6 2 | log_interval: 10 -------------------------------------------------------------------------------- /configs/run_params/evaluate.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1 2 | total_timesteps: 0 3 | -------------------------------------------------------------------------------- /configs/run_params/finetune.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1000 2 | total_timesteps: 100000 3 | -------------------------------------------------------------------------------- /figures/radt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/RA-DT/HEAD/figures/radt.png -------------------------------------------------------------------------------- /configs/run_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1000 2 | total_timesteps: 1000000 3 | -------------------------------------------------------------------------------- /configs/run_params/pretrain_icl.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1000 2 | total_timesteps: 200000 3 | -------------------------------------------------------------------------------- /src/exploration/__init__.py: -------------------------------------------------------------------------------- 1 | from .adaptive_param_noise import AdaptiveParamNoiseSpec 2 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/default.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | relative_pos_embds: False 3 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45_v2.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/metaworld 2 | defaults: 3 | - names: mt45_v2 4 | 
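Note on the `data_paths` group just shown: `${DATA_DIR}` is an OmegaConf interpolation resolved against the root config (`configs/config.yaml` defines `DATA_DIR`), and the `defaults: - names: mt45_v2` entry pulls the task list from `data_paths/names/mt45_v2.yaml`. A minimal sketch of how the resolved node behaves once Hydra has composed the config (illustration only, not a file in the repository; values taken from `config.yaml` and the names list):

```python
from omegaconf import OmegaConf

# Illustration only: mimics the merged result Hydra produces at startup,
# where the root config defines DATA_DIR and the data_paths group refers to it.
root = OmegaConf.create({"DATA_DIR": "./data"})
data_paths = {
    "base": "${DATA_DIR}/metaworld",           # as in data_paths/mt45_v2.yaml
    "names": ["reach-v2.pkl", "push-v2.pkl"],  # truncated names/mt45_v2.yaml list
}
cfg = OmegaConf.merge(root, {"data_paths": data_paths})
print(cfg.data_paths.base)  # -> ./data/metaworld
```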
-------------------------------------------------------------------------------- /src/schedulers/__init__.py: -------------------------------------------------------------------------------- 1 | from .schedulers import make_scheduler 2 | from .lr_schedulers import make_lr_scheduler 3 | -------------------------------------------------------------------------------- /configs/agent_params/lr_sched_kwargs/cosine_restart.yaml: -------------------------------------------------------------------------------- 1 | kind: cosine_restart 2 | T_0: 200000 3 | eta_min: 0.000001 4 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dmcontrol11_icl.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/dm_control 2 | defaults: 3 | - names: dmcontrol11 4 | -------------------------------------------------------------------------------- /configs/agent_params/lr_sched_kwargs/cyclic.yaml: -------------------------------------------------------------------------------- 1 | kind: cyclic 2 | base_lr: 0.0001 3 | max_lr: 0.001 4 | step_size_up: 100000 5 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dark_room.yaml: -------------------------------------------------------------------------------- 1 | tokenize_a: False 2 | reward_condition: True 3 | relative_pos_embds: False 4 | -------------------------------------------------------------------------------- /configs/agent_params/ppo.yaml: -------------------------------------------------------------------------------- 1 | kind: "PPO" 2 | policy: "CnnPolicy" 3 | # n_steps: 256 4 | # n_epochs: 3 5 | ent_coef: 0.01 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "continual_world"] 2 | path = continual_world 3 | url = https://github.com/awarelab/continual_world.git 4 | -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/dmcontrol_icl.yaml: -------------------------------------------------------------------------------- 1 | num_workers: 16 2 | pin_memory: False 3 | max_act_dim: 6 4 | max_state_dim: 204 -------------------------------------------------------------------------------- /configs/agent_params/lr_sched_kwargs/cosine.yaml: -------------------------------------------------------------------------------- 1 | kind: cosine 2 | T_max: ${run_params.total_timesteps} 3 | #T_max: 150000 4 | eta_min: 0.000001 5 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/procgen.yaml: -------------------------------------------------------------------------------- 1 | tokenize_a: False 2 | reward_condition: True 3 | relative_pos_embds: False 4 | use_time_embds: False 5 | -------------------------------------------------------------------------------- /configs/eval_params/base.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: False 2 | n_eval_episodes: 10 3 | eval_freq: 10000 4 | max_no_improvement_evals: 0 5 | deterministic: False -------------------------------------------------------------------------------- /configs/eval_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | 
n_eval_episodes: 10 3 | eval_freq: 10000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | -------------------------------------------------------------------------------- /configs/eval_params/pretrain_disc.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 3 3 | eval_freq: 200000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | -------------------------------------------------------------------------------- /configs/wandb_callback_params/pretrain.yaml: -------------------------------------------------------------------------------- 1 | gradient_save_freq: 250 2 | verbose: 1 3 | model_save_path: models 4 | model_sync_wandb: False 5 | model_save_freq: 50000 -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/mt_disc.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: True 3 | action_channels: 64 4 | relative_pos_embds: False 5 | tokenize_rtg: False 6 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dmcontrol.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | tokenize_a: True 3 | tokenize_rtg: False 4 | action_channels: 64 5 | relative_pos_embds: False 6 | -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/single_domain_disc.yaml: -------------------------------------------------------------------------------- 1 | kind: default 2 | max_act_dim: 6 3 | max_state_dim: 204 4 | num_workers: 16 5 | pin_memory: False 6 | init_top_p: 1 7 | -------------------------------------------------------------------------------- /configs/env_params/mujoco_gym.yaml: -------------------------------------------------------------------------------- 1 | envid: Hopper-v3 2 | reward_scale: 1000 3 | target_return: 3600 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 -------------------------------------------------------------------------------- /configs/eval_params/pretrain_icl.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 40 3 | eval_freq: 25000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | log_eval_trj: True 7 | eval_on_train: True -------------------------------------------------------------------------------- /configs/eval_params/pretrain_icl_mt.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 30 3 | eval_freq: 50000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | log_eval_trj: True 7 | eval_on_train: True -------------------------------------------------------------------------------- /configs/eval_params/pretrain_icl_grids.yaml: -------------------------------------------------------------------------------- 1 | use_eval_callback: True 2 | n_eval_episodes: 40 3 | eval_freq: 100000 4 | max_no_improvement_evals: 0 5 | deterministic: False 6 | log_eval_trj: True 7 | eval_on_train: True 8 | first_step: False -------------------------------------------------------------------------------- /src/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | def 
make_buffer_class(kind): 2 | if kind == "cache": 3 | from .cache import Cache 4 | return Cache 5 | from .trajectory_buffer import TrajectoryReplayBuffer 6 | return TrajectoryReplayBuffer 7 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_huge.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 10 5 | n_head: 20 6 | max_ep_len: 1000 7 | hidden_size: 1280 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 8 5 | n_head: 8 6 | max_ep_len: 1000 7 | hidden_size: 768 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_larger.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 6 5 | n_head: 16 6 | max_ep_len: 1000 7 | hidden_size: 1024 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 4 5 | n_head: 4 6 | max_ep_len: 1000 7 | hidden_size: 512 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_small.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 3 5 | n_head: 2 6 | max_ep_len: 1000 7 | hidden_size: 256 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_small_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 3 5 | n_head: 4 6 | max_ep_len: 1000 7 | hidden_size: 256 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 
3 | n_embd: 512 4 | n_layer: 8 5 | n_head: 12 6 | max_ep_len: 1000 7 | hidden_size: 768 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_largeplus_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 8 5 | n_head: 16 6 | max_ep_len: 1000 7 | hidden_size: 1024 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 4 5 | n_head: 8 6 | max_ep_len: 1000 7 | hidden_size: 512 8 | output_attentions: True 9 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_mediumplus_64.yaml: -------------------------------------------------------------------------------- 1 | max_length: 20 2 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 3 | n_embd: 512 4 | n_layer: 6 5 | n_head: 12 6 | max_ep_len: 1000 7 | hidden_size: 768 8 | output_attentions: True -------------------------------------------------------------------------------- /configs/env_params/dmcontrol_icl.yaml: -------------------------------------------------------------------------------- 1 | envid: dmcontrol11_icl 2 | reward_scale: 100 3 | target_return: 1000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | eval_env_names: "dmcontrol5_icl" 10 | dmc_env_kwargs: 11 | flatten_obs: False 12 | -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_hugeplus.yaml: -------------------------------------------------------------------------------- 1 | # 300M parameters 2 | max_length: 20 3 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 
4 | n_embd: 512 5 | n_layer: 24 6 | n_head: 16 7 | max_ep_len: 1000 8 | hidden_size: 1024 9 | output_attentions: True -------------------------------------------------------------------------------- /configs/agent_params/data_paths/procgen12.yaml: -------------------------------------------------------------------------------- 1 | base: ${SSD_DATA_DIR}/procgen_2M 2 | names: 3 | - "bigfish" 4 | - "bossfight" 5 | - "caveflyer" 6 | - "chaser" 7 | - "coinrun" 8 | - "dodgeball" 9 | - "fruitbot" 10 | - "heist" 11 | - "leaper" 12 | - "maze" 13 | - "miner" 14 | - "starpilot" -------------------------------------------------------------------------------- /configs/env_params/mt45_icl.yaml: -------------------------------------------------------------------------------- 1 | envid: "mt45_v2" 2 | reward_scale: 200 3 | target_return: 20000 4 | num_envs: 1 5 | norm_obs: False 6 | record: False 7 | record_freq: 250000 8 | record_length: 1000 9 | randomization: random_init_all 10 | remove_task_ids: True 11 | add_task_ids: False 12 | eval_env_names: "mt5_v2" -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/dmcontrol11.yaml: -------------------------------------------------------------------------------- 1 | - finger_turn_easy.npz 2 | - fish_upright.npz 3 | - hopper_stand.npz 4 | - point_mass_easy.npz 5 | - walker_stand.npz 6 | - walker_run.npz 7 | - ball_in_cup_catch.npz 8 | - cartpole_swingup.npz 9 | - cheetah_run.npz 10 | - finger_spin.npz 11 | - reacher_easy.npz -------------------------------------------------------------------------------- /dmc2gym_custom/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='dmc2gym_custom', 6 | version='1.0.0', 7 | author='Thomas Schmied', 8 | description=('a gym like wrapper for dm_control'), 9 | license='', 10 | keywords='gym dm_control openai deepmind', 11 | packages=find_packages(), 12 | ) -------------------------------------------------------------------------------- /configs/agent_params/odt.yaml: -------------------------------------------------------------------------------- 1 | kind: "DT" 2 | # should be same as batch size 3 | learning_starts: 1500 4 | buffer_size: 1000 5 | batch_size: 64 6 | gradient_steps: 10 7 | stochastic_policy: True 8 | loss_fn: "mse" 9 | learning_rate: 0.0004 10 | lr_entropy: 0.0004 11 | eval_context_len: 5 12 | 13 | 14 | defaults: 15 | - huggingface: dt_medium 16 | 17 | -------------------------------------------------------------------------------- /configs/agent_params/ppo_procgen.yaml: -------------------------------------------------------------------------------- 1 | kind: "PPO" 2 | policy: "CnnPolicy" 3 | ent_coef: 0.01 4 | # https://arxiv.org/pdf/1912.01588.pdf 5 | gae_lambda: 0.95 6 | batch_size: 2048 7 | gamma: 0.999 8 | # num_envs == 64 --> 16384 steps 9 | n_steps: 256 10 | n_epochs: 3 11 | learning_rate: 5e-4 12 | 13 | policy_kwargs: 14 | impala: True 15 | net_arch: [] 16 | features_extractor_kwargs: 17 | features_dim: 2048 18 | model_size: 1 19 | -------------------------------------------------------------------------------- /configs/env_params/dmcontrol_icl_eval.yaml: -------------------------------------------------------------------------------- 1 | envid: dmcontrol11_icl 2 | reward_scale: 100 3 | target_return: 1000 4 | num_envs: 1 5 | norm_obs: False 6 | record: True 7 | record_eval: True 8 | # we do 30 eval episodes, every ep has 
1000 steps 9 | record_freq: 10000 10 | record_length: 2000 11 | eval_env_names: "dmcontrol5_icl" 12 | dmc_env_kwargs: 13 | flatten_obs: False 14 | render_mode: "rgb_array" 15 | height: 256 16 | width: 256 17 | -------------------------------------------------------------------------------- /src/algos/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .online_decision_transformer_model import OnlineDecisionTransformerModel 2 | from .universal_decision_transformer_model import DummyUDTModel 3 | from .discrete_decision_transformer_model import DiscreteDTModel 4 | from .helm_decision_transformer_model import HelmDTModel, DiscreteHelmDTModel 5 | from .custom_critic import CustomContinuousCritic, MultiHeadContinuousCritic, StateValueFn 6 | from .cache_decision_transformer_model import CacheDTModel, DiscreteCacheDTModel 7 | -------------------------------------------------------------------------------- /configs/env_params/procgen.yaml: -------------------------------------------------------------------------------- 1 | envid: "procgen12" 2 | reward_scale: 1 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | norm_reward: False 7 | record: False 8 | record_freq: 1000000 9 | record_length: 2000 10 | eval_env_names: procgen4 11 | distribution_mode: "easy" 12 | train_eval_seeds: True 13 | time_limit: 400 14 | env_kwargs: 15 | # data was generated with 0 to 199 16 | num_levels: 200 17 | start_level: 0 18 | eval_env_kwargs: 19 | num_levels: 1 20 | start_level: 200 21 | -------------------------------------------------------------------------------- /src/tokenizers_custom/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | class BaseTokenizer: 2 | 3 | def __init__(self, vocab_size=256, shift=0): 4 | self._vocab_size = vocab_size 5 | self._shift = shift 6 | 7 | def tokenize(self, x): 8 | raise NotImplementedError() 9 | 10 | def inv_tokenize(self, x): 11 | raise NotImplementedError() 12 | 13 | @property 14 | def vocab_size(self): 15 | return self._vocab_size 16 | 17 | @property 18 | def shift(self): 19 | return self._shift 20 | -------------------------------------------------------------------------------- /configs/agent_params/ddt.yaml: -------------------------------------------------------------------------------- 1 | kind: "DDT" 2 | # should be same as batch size 3 | learning_starts: 1500 4 | buffer_size: 1000 5 | batch_size: 64 6 | gradient_steps: 10 7 | stochastic_policy: True 8 | loss_fn: "ce" 9 | learning_rate: 0.0004 10 | lr_entropy: 0.0004 11 | eval_context_len: 5 12 | 13 | huggingface: 14 | max_length: 20 15 | # n_embd is not actually used! hidden_size is also used for the embedding dim in the DT implementation... 
16 | n_embd: 512 17 | n_layer: 4 18 | n_head: 4 19 | max_ep_len: 1000 20 | hidden_size: 128 21 | -------------------------------------------------------------------------------- /configs/env_params/procgen_eval.yaml: -------------------------------------------------------------------------------- 1 | envid: "procgen12" 2 | reward_scale: 1 3 | target_return: 1 4 | num_envs: 1 5 | norm_obs: False 6 | norm_reward: False 7 | record: True 8 | record_freq: 2000 9 | record_length: 2000 10 | record_eval: True 11 | time_limit: 400 12 | eval_env_names: procgen4 13 | distribution_mode: "easy" 14 | train_eval_seeds: True 15 | env_kwargs: 16 | # data was generated with 0 to 199 17 | num_levels: 200 18 | start_level: 0 19 | render_mode: rgb_array 20 | eval_env_kwargs: 21 | num_levels: 1 22 | start_level: 200 23 | render_mode: rgb_array -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_s_a.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # rettrieve based on s,a,r,rtg 10 | reward_condition: False 11 | rtg_condition: False 12 | 13 | # frozenhopfield 14 | on_rtgs: False 15 | on_rewards: False 16 | 17 | # tokenization 18 | tokenize_a: False 19 | tokenize_s: True 20 | s_tok_kwargs: 21 | min_val: 0 22 | max_val: 9 23 | vocab_size: 10 24 | one_hot: True 25 | # sinusoidal: True -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_r_rtg.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # rettrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | tokenize_a: False 18 | tokenize_r: True 19 | tokenize_rtg: True 20 | r_tok_kwargs: 21 | min_val: 0 22 | max_val: 1 23 | vocab_size: 2 24 | rtg_tok_kwargs: 25 | min_val: 0 26 | max_val: 100 -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_r_rtg_mazerunner.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # retrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | tokenize_a: False 18 | tokenize_r: True 19 | tokenize_rtg: True 20 | tokenize_s: False 21 | r_tok_kwargs: 22 | min_val: 0 23 | max_val: 1 24 | vocab_size: 2 25 | rtg_tok_kwargs: 26 | min_val: 0 27 | max_val: 3 28 | vocab_size: 4 29 | -------------------------------------------------------------------------------- /src/tokenizers_custom/__init__.py: -------------------------------------------------------------------------------- 1 | def make_tokenizer(kind, tokenizer_kwargs=None): 2 | if tokenizer_kwargs is None: 3 | tokenizer_kwargs = {} 4 | if kind == 'mulaw': 5 | from .mu_law_tokenizer import MuLawTokenizer 6 | return 
MuLawTokenizer(**tokenizer_kwargs) 7 | elif kind == 'minmax': 8 | from .minmax_tokenizer import MinMaxTokenizer 9 | return MinMaxTokenizer(**tokenizer_kwargs) 10 | elif kind == 'minmax2': 11 | from .minmax_tokenizer import MinMaxTokenizer2 12 | return MinMaxTokenizer2(**tokenizer_kwargs) 13 | else: 14 | raise ValueError(f"Unknown tokenizer type {kind}") 15 | -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_s_r_rtg.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # rettrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | tokenize_a: False 18 | tokenize_r: True 19 | tokenize_rtg: True 20 | tokenize_s: True 21 | r_tok_kwargs: 22 | min_val: 0 23 | max_val: 1 24 | vocab_size: 2 25 | rtg_tok_kwargs: 26 | min_val: 0 27 | max_val: 100 28 | s_tok_kwargs: 29 | min_val: 0 30 | max_val: 9 31 | vocab_size: 10 32 | one_hot: True 33 | -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_s_r_rtg_20x20.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # rettrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | tokenize_a: False 18 | tokenize_r: True 19 | tokenize_rtg: True 20 | tokenize_s: True 21 | r_tok_kwargs: 22 | min_val: 0 23 | max_val: 1 24 | vocab_size: 2 25 | rtg_tok_kwargs: 26 | min_val: 0 27 | max_val: 400 28 | vocab_size: 400 29 | s_tok_kwargs: 30 | min_val: 0 31 | max_val: 19 32 | vocab_size: 20 33 | one_hot: True 34 | -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/discrete_s_r_rtg_40x20.yaml: -------------------------------------------------------------------------------- 1 | kind: DHelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # rettrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | tokenize_a: False 18 | tokenize_r: True 19 | tokenize_rtg: True 20 | tokenize_s: True 21 | r_tok_kwargs: 22 | min_val: 0 23 | max_val: 1 24 | vocab_size: 2 25 | rtg_tok_kwargs: 26 | min_val: 0 27 | max_val: 800 28 | vocab_size: 800 29 | s_tok_kwargs: 30 | min_val: 0 31 | max_val: 39 32 | vocab_size: 40 33 | one_hot: True 34 | -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/cont_r_rtg_dmc.yaml: -------------------------------------------------------------------------------- 1 | kind: HelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # retrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # 
frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | # we are using reward_scale=100 for dmc, 18 | # min_rewards=0 max_reward=10 / 100 = 0.1 --> 50 bins --> every bin represents 0.25 rewards 19 | # min_rtg=0, max_rtg=1000 / 100 = 10 --> 400 bins --> every bin represents 5 RTG points 20 | r_tok_kwargs: 21 | min_val: 0 22 | max_val: 0.05 23 | vocab_size: 50 24 | rtg_tok_kwargs: 25 | min_val: 0 26 | max_val: 10 27 | vocab_size: 400 -------------------------------------------------------------------------------- /configs/agent_params/retriever_kwargs/cont_r_rtg.yaml: -------------------------------------------------------------------------------- 1 | kind: HelmDT 2 | 3 | # positions are done by pre-trained model 4 | global_pos_embds: True 5 | relative_pos_embds: False 6 | use_time_embds: False 7 | stochastic_policy: False 8 | 9 | # retrieve based on s,a,r,rtg 10 | reward_condition: True 11 | rtg_condition: True 12 | 13 | # frozenhopfield 14 | on_rtgs: True 15 | 16 | # tokenization 17 | # we are using reward_scale=200 for meta-world, 18 | # min_rewards=0 max_reward=10 / 200 = 0.05 --> 50 bins --> every bin represents 0.25 rewards 19 | # min_rtg=0, max_rtg=2000 / 200 = 10 --> 400 bins --> every bin represents 5 RTG points 20 | r_tok_kwargs: 21 | min_val: 0 22 | max_val: 0.05 23 | vocab_size: 40 24 | rtg_tok_kwargs: 25 | min_val: 0 26 | max_val: 10 27 | vocab_size: 400 -------------------------------------------------------------------------------- /dmc2gym_custom/README.md: -------------------------------------------------------------------------------- 1 | # dmc2gym_custom 2 | The original repository is available at: https://github.com/denisyarats/dmc2gym 3 | 4 | The original code base always flattens the observations by default. However, this is impractical for our purposes, 5 | as we need to construct the full observation space. Therefore, we make `flatten_obs` optional. 6 | 7 | ### Installation 8 | To install `dmc2gym_custom`, run: 9 | ``` 10 | pip install -e . 11 | ``` 12 | from the root of this directory. 
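Since the point of this fork is the optional `flatten_obs` flag, a small sketch of creating an environment without flattening follows (the exact unflattened layout is defined in `wrappers.py`, not shown here, and is assumed to be a dict-style observation space rather than a single flat `Box`):

```python
import dmc2gym_custom

# With flatten_obs=False the wrapper keeps dm_control's structured observations
# (assumed dict-style space) instead of concatenating them into one flat vector.
env = dmc2gym_custom.make(domain_name='point_mass', task_name='easy', seed=1, flatten_obs=False)
print(env.observation_space)
obs = env.reset()
```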
13 | 14 | ### Usage 15 | ```python 16 | import dmc2gym_custom 17 | env = dmc2gym_custom.make(domain_name='point_mass', task_name='easy', seed=1, flatten_obs=True) 18 | done = False 19 | obs = env.reset() 20 | while not done: 21 | action = env.action_space.sample() 22 | obs, reward, done, info = env.step(action) 23 | ``` -------------------------------------------------------------------------------- /configs/agent_params/cdt_pretrain_disc.yaml: -------------------------------------------------------------------------------- 1 | kind: "DDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 256 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | eval_context_len: 5 9 | ent_coef: 0.0 10 | offline_steps: ${run_params.total_timesteps} 11 | buffer_max_len_type: "transition" 12 | buffer_size: 120000000 # 12e7 13 | buffer_weight_by: len 14 | target_return_type: predefined 15 | warmup_steps: 4000 16 | replay_buffer_kwargs: 17 | num_workers: 16 18 | pin_memory: False 19 | init_top_p: 1 20 | use_amp: True 21 | compile: True 22 | defaults: 23 | - huggingface: dt_mediumplus_64 24 | - data_paths: mt40_v2_cwnet_2M 25 | - model_kwargs: default 26 | - lr_sched_kwargs: cosine 27 | huggingface: 28 | # max_ep_len: 201 29 | activation_function: gelu 30 | max_length: 5 31 | -------------------------------------------------------------------------------- /src/algos/models/rms_norm.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py 2 | import torch 3 | import torch.nn as nn 4 | 5 | class LlamaRMSNorm(nn.Module): 6 | def __init__(self, hidden_size, eps=1e-6): 7 | """ 8 | LlamaRMSNorm is equivalent to T5LayerNorm 9 | """ 10 | super().__init__() 11 | self.weight = nn.Parameter(torch.ones(hidden_size)) 12 | self.variance_epsilon = eps 13 | 14 | def forward(self, hidden_states): 15 | input_dtype = hidden_states.dtype 16 | hidden_states = hidden_states.to(torch.float32) 17 | variance = hidden_states.pow(2).mean(-1, keepdim=True) 18 | hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) 19 | return self.weight * hidden_states.to(input_dtype) 20 | -------------------------------------------------------------------------------- /configs/agent_params/dt_gridworlds.yaml: -------------------------------------------------------------------------------- 1 | kind: "DDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | persist_context: False 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 120000000 # 12e7 12 | buffer_weight_by: len 13 | target_return_type: fixed 14 | warmup_steps: 4000 15 | use_amp: True 16 | compile: True 17 | 18 | replay_buffer_kwargs: 19 | num_workers: 16 20 | pin_memory: False 21 | init_top_p: 1 22 | 23 | defaults: 24 | - huggingface: dt_medium_64 25 | - data_paths: dark_room_10x10_sfixed_grand_train 26 | - model_kwargs: dark_room 27 | - lr_sched_kwargs: cosine 28 | 29 | huggingface: 30 | # max_ep_len: 201 31 | activation_function: gelu 32 | max_length: 50 33 | use_fast_attn: True 34 | n_positions: 1600 35 | eval_context_len: ${agent_params.huggingface.max_length} 36 | 37 | model_kwargs: 38 | global_pos_embds: True -------------------------------------------------------------------------------- /src/optimizers/__init__.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def filter_params(params): 5 | # filter out params that don't require gradients 6 | params = list(params) 7 | if isinstance(params[0], dict): 8 | params_filtered = [] 9 | for group in params: 10 | group['params'] = filter(lambda p: p.requires_grad, group['params']) 11 | params_filtered.append(group) 12 | params = params_filtered 13 | else: 14 | params = filter(lambda p: p.requires_grad, params) 15 | return params 16 | 17 | 18 | def make_optimizer(kind, params, lr): 19 | params = filter_params(params) 20 | if kind == "adamw": 21 | optimizer = torch.optim.AdamW(params, lr=lr) 22 | elif kind == "adam": 23 | optimizer = torch.optim.Adam(params, lr=lr) 24 | elif kind == "radam": 25 | optimizer = torch.optim.RAdam(params, lr=lr) 26 | elif kind == "sgd": 27 | optimizer = torch.optim.SGD(params, lr=lr) 28 | elif kind == "rmsprop": 29 | optimizer = torch.optim.RMSprop(params, lr=lr) 30 | else: 31 | raise NotImplementedError() 32 | return optimizer 33 | -------------------------------------------------------------------------------- /configs/agent_params/ad_gridworlds.yaml: -------------------------------------------------------------------------------- 1 | kind: "DDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | persist_context: True 9 | ent_coef: 0.0 10 | offline_steps: ${run_params.total_timesteps} 11 | buffer_max_len_type: "transition" 12 | buffer_size: 120000000 # 12e7 13 | buffer_weight_by: len 14 | target_return_type: predefined 15 | warmup_steps: 4000 16 | use_amp: True 17 | compile: True 18 | 19 | replay_buffer_kwargs: 20 | num_workers: 16 21 | pin_memory: False 22 | init_top_p: 1 23 | seqs_per_sample: 2 24 | seq_sample_kind: sequential 25 | full_context_trjs: True 26 | 27 | defaults: 28 | - huggingface: dt_medium_64 29 | - data_paths: dark_room_10x10_sfixed_grand_train 30 | - model_kwargs: dark_room 31 | - lr_sched_kwargs: cosine 32 | 33 | huggingface: 34 | # max_ep_len: 201 35 | activation_function: gelu 36 | use_fast_attn: True 37 | n_positions: 1600 38 | max_length: 100 39 | # eval_context_len: ${agent_params.huggingface.max_length} 40 | eval_context_len: 200 41 | 42 | model_kwargs: 43 | global_pos_embds: True 44 | rtg_condition: False 45 | -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- 1 | # on server 2 | LOG_DIR: ./logs 3 | DATA_DIR: ./data 4 | SSD_DATA_DIR: ./data 5 | MODELS_DIR: ./models 6 | 7 | # directory creation is handled by Hydra 8 | hydra: 9 | sweep: 10 | dir: ${LOG_DIR}/${experiment_name}/${now:%Y-%m-%d_%H-%M-%S} 11 | subdir: ${maybe_split:${hydra.job.override_dirname}}/seed=${seed} 12 | run: 13 | dir: ${LOG_DIR}/${experiment_name}/${now:%Y-%m-%d_%H-%M-%S} 14 | job: 15 | config: 16 | override_dirname: 17 | exclude_keys: 18 | - seed 19 | 20 | # defaults for variable components --> agent_params, env_params 21 | defaults: 22 | - agent_params: odt 23 | - env_params: mujoco_gym 24 | - eval_params: base 25 | - run_params: base 26 | 27 | # General 28 | experiment_name: test 29 | device: "auto" 30 | seed: 42 31 | # Hydra does the logging for us 32 | logdir: '.' 
33 | use_wandb: True 34 | 35 | wandb_params: 36 | project: "RA-DT" 37 | sync_tensorboard: True 38 | monitor_gym: True 39 | save_code: True 40 | entity: "X" 41 | key: X 42 | host: https://api.wandb.ai 43 | 44 | wandb_callback_params: 45 | gradient_save_freq: 250 46 | verbose: 1 47 | model_save_path: -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Institute for Machine Learning, Johannes Kepler University Linz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/mt45_v2.yaml: -------------------------------------------------------------------------------- 1 | - reach-v2.pkl 2 | - push-v2.pkl 3 | - pick-place-v2.pkl 4 | - door-open-v2.pkl 5 | - drawer-open-v2.pkl 6 | - drawer-close-v2.pkl 7 | - button-press-topdown-v2.pkl 8 | - peg-insert-side-v2.pkl 9 | - window-open-v2.pkl 10 | - window-close-v2.pkl 11 | - door-close-v2.pkl 12 | - reach-wall-v2.pkl 13 | - pick-place-wall-v2.pkl 14 | - push-wall-v2.pkl 15 | - button-press-v2.pkl 16 | - button-press-topdown-wall-v2.pkl 17 | - button-press-wall-v2.pkl 18 | - peg-unplug-side-v2.pkl 19 | - disassemble-v2.pkl 20 | - hammer-v2.pkl 21 | - plate-slide-v2.pkl 22 | - plate-slide-side-v2.pkl 23 | - plate-slide-back-v2.pkl 24 | - plate-slide-back-side-v2.pkl 25 | - handle-press-v2.pkl 26 | - handle-pull-v2.pkl 27 | - handle-press-side-v2.pkl 28 | - handle-pull-side-v2.pkl 29 | - stick-push-v2.pkl 30 | - stick-pull-v2.pkl 31 | - basketball-v2.pkl 32 | - soccer-v2.pkl 33 | - faucet-open-v2.pkl 34 | - faucet-close-v2.pkl 35 | - coffee-push-v2.pkl 36 | - coffee-pull-v2.pkl 37 | - coffee-button-v2.pkl 38 | - sweep-v2.pkl 39 | - sweep-into-v2.pkl 40 | - pick-out-of-hole-v2.pkl 41 | - assembly-v2.pkl 42 | - shelf-place-v2.pkl 43 | - push-back-v2.pkl 44 | - lever-pull-v2.pkl 45 | - dial-turn-v2.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | protobuf==3.20.1 2 | wheel==0.38.0 3 | setuptools==65.5.0 4 | gym==0.21.0 5 | dm-control==1.0.2 6 | dm-env==1.5 7 | dm-tree==0.1.7 8 | einops==0.6.0 9 | hydra-core 10 | matplotlib==3.5.1 11 | lockfile==0.12.2 12 | mujoco_py==2.0.2.5 13 | numpy==1.22.3 14 | omegaconf 
15 | pandas==1.4.2 16 | seaborn==0.11.2 17 | stable_baselines3==1.5.0 18 | tensorflow==2.8.0 19 | tqdm==4.64.0 20 | transformers==4.39.1 21 | wandb==0.14.0 22 | gym==0.21.0 23 | # if atary required only 24 | # gym[atari]==0.21.0 25 | # autorom[accept-rom-license] 26 | # ale-py==0.7.4 27 | procgen 28 | cloudpickle==2.1.0 29 | datasets==2.10.1 30 | autofaiss==2.15.8 31 | # fsspec==2022.1.0 32 | git+https://github.com/denisyarats/dmc2gym.git 33 | torch==2.2.2+cu121 34 | torchvision==0.17.2+cu121 35 | torchaudio==2.2.2+cu121 36 | --extra-index-url https://download.pytorch.org/whl/cu121 37 | torchmetrics==1.2.0 38 | h5py==3.6.0 39 | scikit-learn==1.1.3 40 | gymnasium==0.28.1 41 | # may results in issues with nle: https://github.com/facebookresearch/nle/issues/246 42 | # afterwards add to bashrc --> LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$CONDA_PREFIX/lib" 43 | minihack==0.1.5 44 | opencv-python==4.6.0.66 45 | # installing faiss-gpu: https://github.com/facebookresearch/faiss/blob/main/INSTALL.md 46 | # conda install -c conda-forge faiss-gpu=1.7.4 -------------------------------------------------------------------------------- /src/augmentations/augs.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import torch.nn.functional as F 3 | 4 | 5 | class CustomRandomCrop(T.RandomCrop): 6 | def __init__(self, size=84, padding=4, **kwargs): 7 | super().__init__(size, **kwargs) 8 | self.padding = padding 9 | 10 | def __call__(self, img): 11 | # first pad image by 4 pixels on each side 12 | img = F.pad(img, (self.padding, self.padding, self.padding, self.padding), mode='replicate') 13 | # crop to original size 14 | return super().__call__(img) 15 | 16 | 17 | def make_augmentations(aug_kwargs=None): 18 | if aug_kwargs is None: 19 | aug_kwargs = {} 20 | aug_kwargs = aug_kwargs.copy() 21 | kind = aug_kwargs.pop("kind", "crop_rotate") 22 | p_aug = aug_kwargs.get("p_aug", 0.5) 23 | if kind == "crop": 24 | return T.RandomApply([CustomRandomCrop(**aug_kwargs)], p=p_aug) 25 | elif kind == "rotate": 26 | degrees = aug_kwargs.pop("degrees", 30) 27 | return T.RandomApply([T.RandomRotation(degrees=degrees, **aug_kwargs)], p=p_aug) 28 | elif kind == "crop_rotate": 29 | degrees = aug_kwargs.pop("degrees", 30) 30 | return T.Compose([ 31 | T.RandomApply([CustomRandomCrop(**aug_kwargs)], p=p_aug), 32 | T.RandomApply([T.RandomRotation(degrees=degrees, **aug_kwargs)], p=p_aug) 33 | ]) 34 | raise ValueError(f"Unknown augmentation kind: {kind}") 35 | -------------------------------------------------------------------------------- /src/exploration/adaptive_param_noise.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | From OpenAI Baselines: 4 | https://github.com/openai/baselines/blob/master/baselines/ddpg/noise.py 5 | 6 | """ 7 | 8 | 9 | class AdaptiveParamNoiseSpec(object): 10 | def __init__(self, initial_stddev=0.1, desired_action_stddev=0.2, adaptation_coefficient=1.01): 11 | """ 12 | Note that initial_stddev and current_stddev refer to std of parameter noise, 13 | but desired_action_stddev refers to (as name notes) desired std in action space 14 | """ 15 | self.initial_stddev = initial_stddev 16 | self.desired_action_stddev = desired_action_stddev 17 | self.adaptation_coefficient = adaptation_coefficient 18 | self.current_stddev = initial_stddev 19 | 20 | def adapt(self, distance): 21 | if distance > self.desired_action_stddev: 22 | # Decrease stddev. 
23 | self.current_stddev /= self.adaptation_coefficient 24 | else: 25 | # Increase stddev. 26 | self.current_stddev *= self.adaptation_coefficient 27 | 28 | def get_stats(self): 29 | stats = { 30 | 'param_noise_stddev': self.current_stddev, 31 | } 32 | return stats 33 | 34 | def __repr__(self): 35 | fmt = 'AdaptiveParamNoiseSpec(initial_stddev={}, desired_action_stddev={}, adaptation_coefficient={})' 36 | return fmt.format(self.initial_stddev, self.desired_action_stddev, self.adaptation_coefficient) 37 | 38 | def reset(self): 39 | pass 40 | -------------------------------------------------------------------------------- /src/schedulers/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import CosineAnnealingLR 2 | 3 | 4 | class CosineAnnealingLRSingleCycle(CosineAnnealingLR): 5 | 6 | def get_lr(self): 7 | # in case T_max is reached, always return eta_min, don't go up in the cycle again 8 | lrs = super().get_lr() 9 | if self.last_epoch >= self.T_max: 10 | lrs = [self.eta_min for _ in self.optimizer.param_groups] 11 | return lrs 12 | 13 | 14 | def make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=None): 15 | if sched_kwargs is None: 16 | sched_kwargs = {} 17 | if kind == "cosine": 18 | return CosineAnnealingLRSingleCycle(optimizer, **sched_kwargs) 19 | elif kind == "cosine_restart": 20 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 21 | return CosineAnnealingWarmRestarts(optimizer, **sched_kwargs) 22 | elif kind == "step": 23 | from torch.optim.lr_scheduler import StepLR 24 | return StepLR(optimizer, **sched_kwargs) 25 | elif kind == "plateau": 26 | from torch.optim.lr_scheduler import ReduceLROnPlateau 27 | return ReduceLROnPlateau(optimizer, **sched_kwargs) 28 | elif kind == "cyclic": 29 | from torch.optim.lr_scheduler import CyclicLR 30 | return CyclicLR(optimizer, cycle_momentum=False, **sched_kwargs) 31 | elif kind == "exp": 32 | from torch.optim.lr_scheduler import ExponentialLR 33 | return ExponentialLR(optimizer, **sched_kwargs) 34 | raise ValueError(f"Unknown scheduler {kind}") 35 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import wandb 3 | import omegaconf 4 | from src.utils import maybe_split 5 | from src.envs import make_env 6 | from src.callbacks import make_callbacks 7 | from src.algos.builder import make_agent 8 | from main import setup_logging 9 | 10 | 11 | def evaluate(agent, callbacks): 12 | _, callback = agent._setup_learn(0, None, callbacks) 13 | agent.policy.eval() 14 | callback.on_training_start(locals(), globals()) 15 | agent._dump_logs() 16 | 17 | 18 | @hydra.main(config_path="configs", config_name="config") 19 | def main(config): 20 | print("Config: \n", omegaconf.OmegaConf.to_yaml(config, resolve=True, sort_keys=True)) 21 | config.agent_params.offline_steps = 0 22 | config.agent_params.data_paths = None 23 | run, logdir = setup_logging(config) 24 | env, eval_env, train_eval_env = make_env(config, logdir) 25 | agent = make_agent(config, env, logdir) 26 | callbacks = make_callbacks(config, env=env, eval_env=eval_env, logdir=logdir, train_eval_env=train_eval_env) 27 | try: 28 | # set training steps to 0 to avoid training 29 | agent.learn( 30 | total_timesteps=0, 31 | eval_env=eval_env, 32 | callback=callbacks 33 | ) 34 | finally: 35 | print("Finalizing run...") 36 | if config.use_wandb: 37 | run.finish() 38 | 
wandb.finish() 39 | # return last avg reward for hparam optimization 40 | if hasattr(agent, "cache"): 41 | agent.cache.cleanup_cache() 42 | 43 | 44 | if __name__ == "__main__": 45 | omegaconf.OmegaConf.register_new_resolver("maybe_split", maybe_split) 46 | main() 47 | -------------------------------------------------------------------------------- /configs/env_params/mazerunner.yaml: -------------------------------------------------------------------------------- 1 | envid: MazeRunner-15x15-v0 2 | target_return: 3 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | maze_dim: 15 10 | timelimit: 400 11 | env_kwargs: 12 | reward_on_last_goal: False 13 | 14 | train_maze_seed: 15 | - 0 16 | - 1 17 | - 3 18 | - 4 19 | - 6 20 | - 7 21 | - 8 22 | - 9 23 | - 11 24 | - 12 25 | - 14 26 | - 16 27 | - 17 28 | - 18 29 | - 19 30 | - 20 31 | - 22 32 | - 23 33 | - 25 34 | - 26 35 | - 28 36 | - 30 37 | - 32 38 | - 33 39 | - 34 40 | - 35 41 | - 37 42 | - 38 43 | - 40 44 | # - 41 45 | - 42 46 | - 43 47 | - 44 48 | - 45 49 | - 46 50 | - 47 51 | - 48 52 | - 49 53 | - 50 54 | - 51 55 | - 54 56 | - 55 57 | - 56 58 | - 58 59 | - 60 60 | - 61 61 | - 62 62 | - 63 63 | - 64 64 | - 65 65 | - 66 66 | - 67 67 | - 68 68 | - 69 69 | - 70 70 | - 71 71 | - 72 72 | - 73 73 | - 74 74 | - 75 75 | - 76 76 | - 77 77 | - 79 78 | - 80 79 | - 81 80 | - 82 81 | - 83 82 | - 84 83 | - 85 84 | - 86 85 | - 87 86 | - 88 87 | - 89 88 | - 90 89 | - 91 90 | - 92 91 | - 93 92 | - 95 93 | - 96 94 | - 97 95 | - 98 96 | - 99 97 | - 100 98 | - 101 99 | - 102 100 | - 103 101 | - 104 102 | - 106 103 | - 108 104 | - 109 105 | - 110 106 | - 111 107 | - 112 108 | - 113 109 | - 116 110 | - 117 111 | - 118 112 | - 119 113 | - 120 114 | - 121 115 | 116 | eval_maze_seed: 117 | - 300 118 | - 301 119 | - 302 120 | - 303 121 | - 304 122 | - 305 123 | - 306 124 | - 307 125 | - 308 126 | - 309 127 | - 310 128 | - 311 129 | - 312 130 | - 313 131 | - 314 132 | - 315 133 | - 316 134 | - 317 135 | - 318 136 | - 319 137 | -------------------------------------------------------------------------------- /src/buffers/dataloaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | 3 | 4 | class MultiEpochsDataLoader(DataLoader): 5 | 6 | def __init__(self, *args, **kwargs): 7 | """ 8 | From: 9 | https://discuss.pytorch.org/t/enumerate-dataloader-slow/87778/4 10 | 11 | Ensures that the Dataset is iterated over multiple times wihthout destroying the workers. 12 | 13 | """ 14 | super().__init__(*args, **kwargs) 15 | self._DataLoader__initialized = False 16 | self.batch_sampler = _RepeatSampler(self.batch_sampler) 17 | self._DataLoader__initialized = True 18 | self.iterator = super().__iter__() 19 | 20 | def __len__(self): 21 | return len(self.batch_sampler.sampler) 22 | 23 | def __iter__(self): 24 | for i in range(len(self)): 25 | yield next(self.iterator) 26 | 27 | 28 | class _RepeatSampler(object): 29 | """ Sampler that repeats forever. 
30 | Args: 31 | sampler (Sampler) 32 | """ 33 | 34 | def __init__(self, sampler): 35 | self.sampler = sampler 36 | 37 | def __iter__(self): 38 | while True: 39 | yield from iter(self.sampler) 40 | 41 | 42 | if __name__ == "__main__": 43 | import torch 44 | from torch.utils.data import TensorDataset 45 | 46 | ds = TensorDataset(torch.randn(100, 10)) 47 | loader = MultiEpochsDataLoader(ds, num_workers=4, batch_size=10) 48 | loader_iter = iter(loader) 49 | for i in range(1000): 50 | print(i) 51 | try: 52 | b = next(loader_iter) 53 | except StopIteration as e: 54 | loader_iter = iter(loader) 55 | b = next(loader_iter) 56 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mazerunner15x15.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/mazerunner/15x15 2 | names: 3 | - "0.pkl" 4 | - "1.pkl" 5 | - "3.pkl" 6 | - "4.pkl" 7 | - "6.pkl" 8 | - "7.pkl" 9 | - "8.pkl" 10 | - "9.pkl" 11 | - "11.pkl" 12 | - "12.pkl" 13 | - "14.pkl" 14 | - "16.pkl" 15 | - "17.pkl" 16 | - "18.pkl" 17 | - "19.pkl" 18 | - "20.pkl" 19 | - "22.pkl" 20 | - "23.pkl" 21 | - "25.pkl" 22 | - "26.pkl" 23 | - "28.pkl" 24 | - "30.pkl" 25 | - "32.pkl" 26 | - "33.pkl" 27 | - "34.pkl" 28 | - "35.pkl" 29 | - "37.pkl" 30 | - "38.pkl" 31 | - "40.pkl" 32 | - "41.pkl" 33 | - "42.pkl" 34 | - "43.pkl" 35 | - "44.pkl" 36 | - "45.pkl" 37 | - "46.pkl" 38 | - "47.pkl" 39 | - "48.pkl" 40 | - "49.pkl" 41 | - "50.pkl" 42 | - "51.pkl" 43 | - "54.pkl" 44 | - "55.pkl" 45 | - "56.pkl" 46 | - "58.pkl" 47 | - "60.pkl" 48 | - "61.pkl" 49 | - "62.pkl" 50 | - "63.pkl" 51 | - "64.pkl" 52 | - "65.pkl" 53 | - "66.pkl" 54 | - "67.pkl" 55 | - "68.pkl" 56 | - "69.pkl" 57 | - "70.pkl" 58 | - "71.pkl" 59 | - "72.pkl" 60 | - "73.pkl" 61 | - "74.pkl" 62 | - "75.pkl" 63 | - "76.pkl" 64 | - "77.pkl" 65 | - "79.pkl" 66 | - "80.pkl" 67 | - "81.pkl" 68 | - "82.pkl" 69 | - "83.pkl" 70 | - "84.pkl" 71 | - "85.pkl" 72 | - "86.pkl" 73 | - "87.pkl" 74 | - "88.pkl" 75 | - "89.pkl" 76 | - "90.pkl" 77 | - "91.pkl" 78 | - "92.pkl" 79 | - "93.pkl" 80 | - "95.pkl" 81 | - "96.pkl" 82 | - "97.pkl" 83 | - "98.pkl" 84 | - "99.pkl" 85 | - "100.pkl" 86 | - "101.pkl" 87 | - "102.pkl" 88 | - "103.pkl" 89 | - "104.pkl" 90 | - "106.pkl" 91 | - "108.pkl" 92 | - "109.pkl" 93 | - "110.pkl" 94 | - "111.pkl" 95 | - "112.pkl" 96 | - "113.pkl" 97 | - "116.pkl" 98 | - "117.pkl" 99 | - "118.pkl" 100 | - "119.pkl" 101 | - "120.pkl" 102 | - "121.pkl" -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from gym.envs.registration import register 3 | 4 | 5 | def make( 6 | domain_name, 7 | task_name, 8 | seed=1, 9 | visualize_reward=True, 10 | from_pixels=False, 11 | height=84, 12 | width=84, 13 | camera_id=0, 14 | frame_skip=1, 15 | episode_length=1000, 16 | environment_kwargs=None, 17 | time_limit=None, 18 | channels_first=True, 19 | flatten_obs=True, 20 | deterministic=False 21 | ): 22 | env_id = 'dmc_%s_%s_%s-v1' % (domain_name, task_name, seed) 23 | 24 | if from_pixels: 25 | assert not visualize_reward, 'cannot use visualize reward when learning from pixels' 26 | 27 | # shorten episode length 28 | max_episode_steps = (episode_length + frame_skip - 1) // frame_skip 29 | 30 | # make env kwargs 31 | env_kwargs = dict( 32 | environment_kwargs=environment_kwargs, 33 | visualize_reward=visualize_reward, 34 | 
from_pixels=from_pixels, 35 | height=height, 36 | width=width, 37 | camera_id=camera_id, 38 | frame_skip=frame_skip, 39 | channels_first=channels_first, 40 | flatten_obs=flatten_obs, 41 | deterministic=deterministic 42 | ) 43 | 44 | if not env_id in gym.envs.registry.env_specs: 45 | task_kwargs = {} 46 | if seed is not None: 47 | task_kwargs['random'] = seed 48 | if time_limit is not None: 49 | task_kwargs['time_limit'] = time_limit 50 | register( 51 | id=env_id, 52 | entry_point='dmc2gym_custom.wrappers:DMCWrapper', 53 | kwargs=dict( 54 | domain_name=domain_name, 55 | task_name=task_name, 56 | task_kwargs=task_kwargs, 57 | **env_kwargs 58 | ), 59 | max_episode_steps=max_episode_steps, 60 | ) 61 | return gym.make(env_id, **env_kwargs) -------------------------------------------------------------------------------- /configs/agent_params/radt_icl.yaml: -------------------------------------------------------------------------------- 1 | kind: "CDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "mse" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 100000000 # 8e7 12 | buffer_weight_by: len 13 | warmup_steps: 4000 14 | 15 | learnable_ret: True 16 | use_amp: True 17 | compile: True 18 | frozen: False 19 | sep_eval_cache: True 20 | target_return_type: predefined 21 | # query kwargs 22 | representation_type: mean 23 | cache_steps: 50 24 | cache_len: ${agent_params.cache_steps} 25 | agg_token: s 26 | 27 | freeze_kwargs: 28 | exclude_crossattn: True 29 | 30 | replay_buffer_kwargs: 31 | num_workers: 16 32 | pin_memory: False 33 | 34 | cache_kwargs: 35 | num_workers: 0 36 | prefetch_factor: null 37 | pin_memory: False 38 | init_top_p: 1 39 | exclude_same_trjs: True 40 | task_weight: 1 41 | reweight_top_k: 1 42 | # min_seq_len: 10 43 | share_trjs: True 44 | index_kwargs: 45 | nb_cores: 64 46 | # go with Flat index for now, others have colissions and are slower with high nprobe 47 | index_key: Flat 48 | 49 | eval_cache_kwargs: 50 | index_kwargs: 51 | nb_cores: ${agent_params.cache_kwargs.index_kwargs.nb_cores} 52 | return_weight: 1 53 | top_k: 50 54 | reweight_top_k: ${agent_params.cache_kwargs.reweight_top_k} 55 | 56 | load_path: 57 | dir_path: ${MODELS_DIR}/metaworld_icl 58 | file_name: dt_medium_64.zip 59 | 60 | defaults: 61 | - huggingface: dt_medium_64 62 | - data_paths: mt40_v2_cwnet_2M 63 | - data_paths@cache_data_paths: ${agent_params/data_paths} 64 | - model_kwargs: default 65 | - lr_sched_kwargs: cosine 66 | 67 | huggingface: 68 | activation_function: gelu 69 | max_length: 50 70 | add_cross_attention: True 71 | output_attentions: False 72 | n_positions: 1600 73 | eval_context_len: ${agent_params.huggingface.max_length} 74 | 75 | model_kwargs: 76 | global_pos_embds: True 77 | separate_ca_embed: True 78 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_room_10x10_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_room/10x10 2 | names: 3 | - "[0,0]_[2,0].pkl" 4 | - "[0,0]_[0,2].pkl" 5 | - "[0,0]_[1,5].pkl" 6 | - "[0,0]_[2,2].pkl" 7 | - "[0,0]_[5,7].pkl" 8 | - "[0,0]_[9,1].pkl" 9 | - "[0,0]_[6,9].pkl" 10 | - "[0,0]_[5,5].pkl" 11 | - "[0,0]_[1,1].pkl" 12 | - "[0,0]_[7,9].pkl" 13 | - "[0,0]_[0,9].pkl" 14 | - "[0,0]_[3,8].pkl" 15 | - "[0,0]_[8,5].pkl" 16 | - "[0,0]_[0,0].pkl" 17 | - 
"[0,0]_[8,9].pkl" 18 | - "[0,0]_[1,3].pkl" 19 | - "[0,0]_[0,5].pkl" 20 | - "[0,0]_[0,1].pkl" 21 | - "[0,0]_[9,5].pkl" 22 | - "[0,0]_[8,3].pkl" 23 | - "[0,0]_[4,4].pkl" 24 | - "[0,0]_[1,2].pkl" 25 | - "[0,0]_[7,8].pkl" 26 | - "[0,0]_[9,4].pkl" 27 | - "[0,0]_[3,7].pkl" 28 | - "[0,0]_[9,2].pkl" 29 | - "[0,0]_[9,7].pkl" 30 | - "[0,0]_[5,6].pkl" 31 | - "[0,0]_[6,3].pkl" 32 | - "[0,0]_[4,6].pkl" 33 | - "[0,0]_[0,8].pkl" 34 | - "[0,0]_[3,3].pkl" 35 | - "[0,0]_[4,5].pkl" 36 | - "[0,0]_[1,9].pkl" 37 | - "[0,0]_[1,4].pkl" 38 | - "[0,0]_[9,3].pkl" 39 | - "[0,0]_[7,3].pkl" 40 | - "[0,0]_[3,9].pkl" 41 | - "[0,0]_[2,4].pkl" 42 | - "[0,0]_[0,6].pkl" 43 | - "[0,0]_[6,2].pkl" 44 | - "[0,0]_[2,3].pkl" 45 | - "[0,0]_[1,8].pkl" 46 | - "[0,0]_[4,2].pkl" 47 | - "[0,0]_[8,0].pkl" 48 | - "[0,0]_[8,6].pkl" 49 | - "[0,0]_[3,1].pkl" 50 | - "[0,0]_[6,7].pkl" 51 | - "[0,0]_[2,7].pkl" 52 | - "[0,0]_[1,0].pkl" 53 | - "[0,0]_[4,0].pkl" 54 | - "[0,0]_[7,0].pkl" 55 | - "[0,0]_[9,6].pkl" 56 | - "[0,0]_[8,8].pkl" 57 | - "[0,0]_[2,6].pkl" 58 | - "[0,0]_[5,4].pkl" 59 | - "[0,0]_[7,1].pkl" 60 | - "[0,0]_[2,9].pkl" 61 | - "[0,0]_[2,5].pkl" 62 | - "[0,0]_[4,3].pkl" 63 | - "[0,0]_[4,1].pkl" 64 | - "[0,0]_[7,2].pkl" 65 | - "[0,0]_[0,3].pkl" 66 | - "[0,0]_[8,1].pkl" 67 | - "[0,0]_[5,3].pkl" 68 | - "[0,0]_[7,7].pkl" 69 | - "[0,0]_[6,1].pkl" 70 | - "[0,0]_[0,7].pkl" 71 | - "[0,0]_[2,8].pkl" 72 | - "[0,0]_[9,9].pkl" 73 | - "[0,0]_[5,2].pkl" 74 | - "[0,0]_[4,8].pkl" 75 | - "[0,0]_[8,2].pkl" 76 | - "[0,0]_[6,0].pkl" 77 | - "[0,0]_[7,4].pkl" 78 | - "[0,0]_[3,2].pkl" 79 | - "[0,0]_[4,7].pkl" 80 | - "[0,0]_[7,6].pkl" 81 | - "[0,0]_[3,6].pkl" 82 | - "[0,0]_[0,4].pkl" -------------------------------------------------------------------------------- /configs/env_params/dark_room.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-Room-Dark-Dense-10x10-v0 2 | target_return: [90,5] 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | 20 | train_goal_pos: 21 | - [2, 0] 22 | - [0, 2] 23 | - [1, 5] 24 | - [2, 2] 25 | - [5, 7] 26 | - [9, 1] 27 | - [6, 9] 28 | - [5, 5] 29 | - [1, 1] 30 | - [7, 9] 31 | - [0, 9] 32 | - [3, 8] 33 | - [8, 5] 34 | - [0, 0] 35 | - [8, 9] 36 | - [1, 3] 37 | - [0, 5] 38 | - [0, 1] 39 | - [9, 5] 40 | - [8, 3] 41 | - [4, 4] 42 | - [1, 2] 43 | - [7, 8] 44 | - [9, 4] 45 | - [3, 7] 46 | - [9, 2] 47 | - [9, 7] 48 | - [5, 6] 49 | - [6, 3] 50 | - [4, 6] 51 | - [0, 8] 52 | - [3, 3] 53 | - [4, 5] 54 | - [1, 9] 55 | - [1, 4] 56 | - [9, 3] 57 | - [7, 3] 58 | - [3, 9] 59 | - [2, 4] 60 | - [0, 6] 61 | - [6, 2] 62 | - [2, 3] 63 | - [1, 8] 64 | - [4, 2] 65 | - [8, 0] 66 | - [8, 6] 67 | - [3, 1] 68 | - [6, 7] 69 | - [2, 7] 70 | - [1, 0] 71 | - [4, 0] 72 | - [7, 0] 73 | - [9, 6] 74 | - [8, 8] 75 | - [2, 6] 76 | - [5, 4] 77 | - [7, 1] 78 | - [2, 9] 79 | - [2, 5] 80 | - [4, 3] 81 | - [4, 1] 82 | - [7, 2] 83 | - [0, 3] 84 | - [8, 1] 85 | - [5, 3] 86 | - [7, 7] 87 | - [6, 1] 88 | - [0, 7] 89 | - [2, 8] 90 | - [9, 9] 91 | - [5, 2] 92 | - [4, 8] 93 | - [8, 2] 94 | - [6, 0] 95 | - [7, 4] 96 | - [3, 2] 97 | - [4, 7] 98 | - [7, 6] 99 | - [3, 6] 100 | - [0, 4] 101 | 102 | # train tasks for sfixed, grand 103 | eval_start_pos: [0, 0] 104 | eval_goal_pos: 105 | - [9, 0] 106 | - [8, 4] 107 | - [3, 4] 108 | - [6, 5] 109 | - [7, 5] 110 | - 
[5, 0] 111 | - [3, 5] 112 | - [9, 8] 113 | - [8, 7] 114 | - [3, 0] 115 | - [6, 6] 116 | - [5, 9] 117 | - [1, 7] 118 | - [5, 1] 119 | - [1, 6] 120 | - [5, 8] 121 | - [2, 1] 122 | - [4, 9] 123 | - [6, 4] 124 | - [6, 8] 125 | -------------------------------------------------------------------------------- /src/algos/models/token_learner.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adjusted from the TF implementation provided by: 3 | - https://github.com/google-research/robotics_transformer/blob/master/tokenizers/token_learner.py 4 | 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class MlpBlock(nn.Module): 12 | 13 | def __init__(self, mlp_dim, out_dim, dropout_rate: float = 0.1): 14 | """ 15 | Initializer for the MLP Block. 16 | 17 | This computes outer_dense(gelu(hidden_dense(input))), with dropout 18 | applied as necessary. 19 | 20 | Args: 21 | mlp_dim: The dimension of the inner representation (output of hidden 22 | layer). Usually larger than the input/output dim. 23 | out_dim: The output dimension of the block. If None, the model output dim 24 | is equal to the input dim (usually desired) 25 | dropout_rate: Dropout rate to be applied after dense ( & activation) 26 | 27 | """ 28 | super().__init__() 29 | self.net = nn.Sequential( 30 | nn.Linear(mlp_dim, mlp_dim), 31 | nn.GELU(), 32 | nn.Dropout(dropout_rate), 33 | nn.Linear(mlp_dim, out_dim), 34 | nn.Dropout(dropout_rate) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.net(x) 39 | 40 | 41 | class TokenLearnerModule(nn.Module): 42 | """TokenLearner module V1.1 (https://arxiv.org/abs/2106.11297).""" 43 | 44 | def __init__(self, 45 | num_tokens: int = 8, 46 | bottleneck_dim: int = 64, 47 | dropout_rate: float = 0.): 48 | super().__init__() 49 | self.mlp = MlpBlock(mlp_dim=bottleneck_dim, out_dim=num_tokens, dropout_rate=dropout_rate) 50 | self.layernorm = nn.LayerNorm(bottleneck_dim) 51 | 52 | def forward(self, x): 53 | if len(x.shape) == 4: 54 | batch_size, height, width, channels = x.shape 55 | x = x.reshape(batch_size, height * width, channels) 56 | 57 | selected = self.layernorm(x) 58 | # Shape: [bs, h*w, n_token]. 59 | selected = self.mlp(selected) 60 | # Shape: [bs, n_token, h*w].
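# The MLP output above acts as num_tokens spatial attention maps over the h*w positions:
# after the permute, the softmax normalises each map across positions, and the einsum at the
# end pools the input features into num_tokens learned tokens.
# Illustrative shapes (assuming the channel dim equals bottleneck_dim): a [B, 81, 64] input,
# e.g. a flattened 9x9 grid of 64-dim features, yields a [B, 8, 64] output with num_tokens=8.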
61 | selected = selected.permute(0, 2, 1) 62 | selected = F.softmax(selected, dim=-1) 63 | 64 | # Shape: [bs, n_token, c] 65 | return torch.einsum("...si,...id->...sd", selected, x) 66 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_room_20x20_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_room/20x20 2 | names: 3 | - "[0,0]_[2,15].pkl" 4 | - "[0,0]_[13,11].pkl" 5 | - "[0,0]_[19,10].pkl" 6 | - "[0,0]_[5,5].pkl" 7 | - "[0,0]_[11,3].pkl" 8 | - "[0,0]_[17,18].pkl" 9 | - "[0,0]_[1,0].pkl" 10 | - "[0,0]_[15,9].pkl" 11 | - "[0,0]_[19,15].pkl" 12 | - "[0,0]_[6,19].pkl" 13 | - "[0,0]_[6,18].pkl" 14 | - "[0,0]_[9,5].pkl" 15 | - "[0,0]_[4,14].pkl" 16 | - "[0,0]_[4,2].pkl" 17 | - "[0,0]_[18,1].pkl" 18 | - "[0,0]_[19,2].pkl" 19 | - "[0,0]_[7,3].pkl" 20 | - "[0,0]_[17,15].pkl" 21 | - "[0,0]_[16,8].pkl" 22 | - "[0,0]_[17,9].pkl" 23 | - "[0,0]_[15,1].pkl" 24 | - "[0,0]_[3,4].pkl" 25 | - "[0,0]_[1,7].pkl" 26 | - "[0,0]_[11,8].pkl" 27 | - "[0,0]_[19,16].pkl" 28 | - "[0,0]_[6,16].pkl" 29 | - "[0,0]_[3,7].pkl" 30 | - "[0,0]_[12,3].pkl" 31 | - "[0,0]_[8,19].pkl" 32 | - "[0,0]_[11,14].pkl" 33 | - "[0,0]_[16,18].pkl" 34 | - "[0,0]_[14,4].pkl" 35 | - "[0,0]_[0,12].pkl" 36 | - "[0,0]_[11,13].pkl" 37 | - "[0,0]_[17,1].pkl" 38 | - "[0,0]_[3,2].pkl" 39 | - "[0,0]_[11,10].pkl" 40 | - "[0,0]_[2,19].pkl" 41 | - "[0,0]_[11,9].pkl" 42 | - "[0,0]_[5,3].pkl" 43 | - "[0,0]_[16,2].pkl" 44 | - "[0,0]_[19,9].pkl" 45 | - "[0,0]_[0,15].pkl" 46 | - "[0,0]_[10,7].pkl" 47 | - "[0,0]_[2,4].pkl" 48 | - "[0,0]_[12,17].pkl" 49 | - "[0,0]_[17,5].pkl" 50 | - "[0,0]_[19,1].pkl" 51 | - "[0,0]_[8,16].pkl" 52 | - "[0,0]_[13,3].pkl" 53 | - "[0,0]_[0,3].pkl" 54 | - "[0,0]_[10,5].pkl" 55 | - "[0,0]_[11,1].pkl" 56 | - "[0,0]_[16,1].pkl" 57 | - "[0,0]_[0,17].pkl" 58 | - "[0,0]_[12,7].pkl" 59 | - "[0,0]_[6,1].pkl" 60 | - "[0,0]_[8,5].pkl" 61 | - "[0,0]_[13,1].pkl" 62 | - "[0,0]_[18,18].pkl" 63 | - "[0,0]_[16,5].pkl" 64 | - "[0,0]_[1,9].pkl" 65 | - "[0,0]_[19,8].pkl" 66 | - "[0,0]_[1,10].pkl" 67 | - "[0,0]_[0,0].pkl" 68 | - "[0,0]_[16,7].pkl" 69 | - "[0,0]_[9,10].pkl" 70 | - "[0,0]_[13,6].pkl" 71 | - "[0,0]_[15,12].pkl" 72 | - "[0,0]_[1,18].pkl" 73 | - "[0,0]_[18,4].pkl" 74 | - "[0,0]_[3,9].pkl" 75 | - "[0,0]_[15,6].pkl" 76 | - "[0,0]_[1,16].pkl" 77 | - "[0,0]_[18,8].pkl" 78 | - "[0,0]_[6,6].pkl" 79 | - "[0,0]_[7,0].pkl" 80 | - "[0,0]_[18,0].pkl" 81 | - "[0,0]_[14,9].pkl" 82 | - "[0,0]_[9,3].pkl" 83 | -------------------------------------------------------------------------------- /src/utils/loss_functions.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class LogitNormLoss(nn.Module): 7 | 8 | def __init__(self, t=1.0, reduction="mean"): 9 | """ 10 | # From: https://github.com/hongxin001/logitnorm_ood/blob/main/common/loss_function.py 11 | 12 | """ 13 | super(LogitNormLoss, self).__init__() 14 | self.t = t 15 | self.reduction = reduction 16 | 17 | def forward(self, x, target): 18 | norms = torch.norm(x, p=2, dim=-1, keepdim=True) + 1e-7 19 | logit_norm = torch.div(x, norms) / self.t 20 | return F.cross_entropy(logit_norm, target, reduction=self.reduction) 21 | 22 | 23 | class DistanceSmoothedCrossEntropyLoss(torch.nn.Module): 24 | 25 | def __init__(self, label_smoothing=0.1, ignore_index=-100, reduction='mean'): 26 | super().__init__() 27 | self.label_smoothing = 
label_smoothing 28 | self.ignore_index = ignore_index 29 | self.reduction = reduction 30 | 31 | def forward(self, logits, targets): 32 | """ 33 | logits: the model's output logits 34 | targets: the true class labels 35 | """ 36 | num_classes = logits.size(-1) 37 | 38 | # calculate smooth targets via distance from true class 39 | distances = torch.abs(torch.arange(num_classes, device=targets.device) - targets.unsqueeze(1)) 40 | inv_distances = 1 / distances 41 | inv_distances[inv_distances == float('inf')] = 0 42 | inv_distances_norm = inv_distances / inv_distances.sum(dim=-1, keepdim=True) 43 | smooth_targets = inv_distances_norm * self.label_smoothing 44 | # set the weights for the true class 45 | smooth_targets.scatter_(1, targets.unsqueeze(1), 1 - self.label_smoothing) 46 | loss = F.cross_entropy(logits, smooth_targets, reduction='none') 47 | 48 | # mask out the ignore index if present 49 | if self.ignore_index is not None: 50 | mask = targets != self.ignore_index 51 | loss = loss[mask] 52 | 53 | # reduce the loss if needed 54 | if self.reduction == 'mean': 55 | loss = loss.mean() 56 | elif self.reduction == 'sum': 57 | loss = loss.sum() 58 | 59 | return loss 60 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_room_40x20_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_room/40x20 2 | names: 3 | - "[0,0]_[29,6].pkl" 4 | - "[0,0]_[38,6].pkl" 5 | - "[0,0]_[36,2].pkl" 6 | - "[0,0]_[29,0].pkl" 7 | - "[0,0]_[27,2].pkl" 8 | - "[0,0]_[32,7].pkl" 9 | - "[0,0]_[10,8].pkl" 10 | - "[0,0]_[36,11].pkl" 11 | - "[0,0]_[37,18].pkl" 12 | - "[0,0]_[9,8].pkl" 13 | - "[0,0]_[8,16].pkl" 14 | - "[0,0]_[10,10].pkl" 15 | - "[0,0]_[13,8].pkl" 16 | - "[0,0]_[6,2].pkl" 17 | - "[0,0]_[4,11].pkl" 18 | - "[0,0]_[17,17].pkl" 19 | - "[0,0]_[12,16].pkl" 20 | - "[0,0]_[34,1].pkl" 21 | - "[0,0]_[20,16].pkl" 22 | - "[0,0]_[32,9].pkl" 23 | - "[0,0]_[19,4].pkl" 24 | - "[0,0]_[11,10].pkl" 25 | - "[0,0]_[16,14].pkl" 26 | - "[0,0]_[19,6].pkl" 27 | - "[0,0]_[27,15].pkl" 28 | - "[0,0]_[17,2].pkl" 29 | - "[0,0]_[28,4].pkl" 30 | - "[0,0]_[18,14].pkl" 31 | - "[0,0]_[26,7].pkl" 32 | - "[0,0]_[30,15].pkl" 33 | - "[0,0]_[17,8].pkl" 34 | - "[0,0]_[35,10].pkl" 35 | - "[0,0]_[15,14].pkl" 36 | - "[0,0]_[24,0].pkl" 37 | - "[0,0]_[6,8].pkl" 38 | - "[0,0]_[19,12].pkl" 39 | - "[0,0]_[11,17].pkl" 40 | - "[0,0]_[19,18].pkl" 41 | - "[0,0]_[26,16].pkl" 42 | - "[0,0]_[22,13].pkl" 43 | - "[0,0]_[0,11].pkl" 44 | - "[0,0]_[0,17].pkl" 45 | - "[0,0]_[39,2].pkl" 46 | - "[0,0]_[39,8].pkl" 47 | - "[0,0]_[19,14].pkl" 48 | - "[0,0]_[21,8].pkl" 49 | - "[0,0]_[25,15].pkl" 50 | - "[0,0]_[4,2].pkl" 51 | - "[0,0]_[4,14].pkl" 52 | - "[0,0]_[1,4].pkl" 53 | - "[0,0]_[34,12].pkl" 54 | - "[0,0]_[11,8].pkl" 55 | - "[0,0]_[37,5].pkl" 56 | - "[0,0]_[18,9].pkl" 57 | - "[0,0]_[25,4].pkl" 58 | - "[0,0]_[22,16].pkl" 59 | - "[0,0]_[20,6].pkl" 60 | - "[0,0]_[18,0].pkl" 61 | - "[0,0]_[35,8].pkl" 62 | - "[0,0]_[9,10].pkl" 63 | - "[0,0]_[12,1].pkl" 64 | - "[0,0]_[34,16].pkl" 65 | - "[0,0]_[38,17].pkl" 66 | - "[0,0]_[20,18].pkl" 67 | - "[0,0]_[39,3].pkl" 68 | - "[0,0]_[35,5].pkl" 69 | - "[0,0]_[3,7].pkl" 70 | - "[0,0]_[16,11].pkl" 71 | - "[0,0]_[31,0].pkl" 72 | - "[0,0]_[14,17].pkl" 73 | - "[0,0]_[23,4].pkl" 74 | - "[0,0]_[12,10].pkl" 75 | - "[0,0]_[2,7].pkl" 76 | - "[0,0]_[29,13].pkl" 77 | - "[0,0]_[25,12].pkl" 78 | - "[0,0]_[9,16].pkl" 79 | - "[0,0]_[13,9].pkl" 80 | - "[0,0]_[20,4].pkl" 81 | - "[0,0]_[2,17].pkl" 82 | - 
"[0,0]_[26,10].pkl" 83 | -------------------------------------------------------------------------------- /configs/env_params/dark_room_20x20.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-Room-Dark-Dense-20x20-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | 20 | train_goal_pos: 21 | - [2, 15] 22 | - [13, 11] 23 | - [19, 10] 24 | - [5, 5] 25 | - [11, 3] 26 | - [17, 18] 27 | - [1, 0] 28 | - [15, 9] 29 | - [19, 15] 30 | - [6, 19] 31 | - [6, 18] 32 | - [9, 5] 33 | - [4, 14] 34 | - [4, 2] 35 | - [18, 1] 36 | - [19, 2] 37 | - [7, 3] 38 | - [17, 15] 39 | - [16, 8] 40 | - [17, 9] 41 | - [15, 1] 42 | - [3, 4] 43 | - [1, 7] 44 | - [11, 8] 45 | - [19, 16] 46 | - [6, 16] 47 | - [3, 7] 48 | - [12, 3] 49 | - [8, 19] 50 | - [11, 14] 51 | - [16, 18] 52 | - [14, 4] 53 | - [0, 12] 54 | - [11, 13] 55 | - [17, 1] 56 | - [3, 2] 57 | - [11, 10] 58 | - [2, 19] 59 | - [11, 9] 60 | - [5, 3] 61 | - [16, 2] 62 | - [19, 9] 63 | - [0, 15] 64 | - [10, 7] 65 | - [2, 4] 66 | - [12, 17] 67 | - [17, 5] 68 | - [19, 1] 69 | - [8, 16] 70 | - [13, 3] 71 | - [0, 3] 72 | - [10, 5] 73 | - [11, 1] 74 | - [16, 1] 75 | - [0, 17] 76 | - [12, 7] 77 | - [6, 1] 78 | - [8, 5] 79 | - [13, 1] 80 | - [18, 18] 81 | - [16, 5] 82 | - [1, 9] 83 | - [19, 8] 84 | - [1, 10] 85 | - [0, 0] 86 | - [16, 7] 87 | - [9, 10] 88 | - [13, 6] 89 | - [15, 12] 90 | - [1, 18] 91 | - [18, 4] 92 | - [3, 9] 93 | - [15, 6] 94 | - [1, 16] 95 | - [18, 8] 96 | - [6, 6] 97 | - [7, 0] 98 | - [18, 0] 99 | - [14, 9] 100 | - [9, 3] 101 | 102 | # train tasks for sfixed, grand 103 | eval_start_pos: [0, 0] 104 | eval_goal_pos: 105 | - [16, 12] 106 | - [9, 15] 107 | - [1, 6] 108 | - [5, 17] 109 | - [7, 17] 110 | - [8, 17] 111 | - [10, 14] 112 | - [9, 19] 113 | - [10, 10] 114 | - [16, 17] 115 | - [8, 6] 116 | - [9, 13] 117 | - [15, 8] 118 | - [15, 18] 119 | - [15, 15] 120 | - [3, 18] 121 | - [17, 17] 122 | - [10, 8] 123 | - [5, 2] 124 | - [0, 18] 125 | -------------------------------------------------------------------------------- /configs/agent_params/radt_mazerunner.yaml: -------------------------------------------------------------------------------- 1 | kind: "DCDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 80000000 # 8e7 12 | buffer_weight_by: len 13 | warmup_steps: 4000 14 | learnable_ret: True 15 | use_amp: True 16 | compile: True 17 | frozen: False 18 | sep_eval_cache: True 19 | # return conditioning 20 | target_return_type: fixed 21 | a_sample_kwargs: 22 | top_p: 0.5 23 | # query type 24 | representation_type: mean 25 | cache_steps: 25 26 | cache_len: ${agent_params.cache_steps} 27 | agg_token: s 28 | query_dropout: 0.2 29 | eval_ret_steps: 25 30 | 31 | freeze_kwargs: 32 | exclude_crossattn: True 33 | 34 | replay_buffer_kwargs: 35 | num_workers: 16 36 | pin_memory: False 37 | 38 | cache_kwargs: 39 | num_workers: 0 40 | prefetch_factor: null 41 | pin_memory: False 42 | init_top_p: 1 43 | exclude_same_trjs: True 44 | task_weight: 1 45 | top_k: 50 46 | reweight_top_k: 1 47 | min_seq_len: 10 48 | 
sim_cutoff: 0.98 49 | share_trjs: True 50 | use_gpu: True 51 | deduplicate: True 52 | norm: True 53 | index_kwargs: 54 | nb_cores: 64 55 | # go with Flat index for now, others have colissions and are slower with high nprobe 56 | index_key: Flat 57 | 58 | eval_cache_kwargs: 59 | index_kwargs: 60 | nb_cores: ${agent_params.cache_kwargs.index_kwargs.nb_cores} 61 | return_weight: 1 62 | top_k: 50 63 | reweight_top_k: ${agent_params.cache_kwargs.reweight_top_k} 64 | 65 | load_path: 66 | dir_path: ${MODELS_DIR}/mazerunner_15x15 67 | file_name: dt_medium_64.zip 68 | 69 | defaults: 70 | - huggingface: dt_medium_64 71 | - data_paths: mazerunner15x15 72 | - data_paths@cache_data_paths: ${agent_params/data_paths} 73 | - model_kwargs: dark_room 74 | - lr_sched_kwargs: cosine 75 | 76 | huggingface: 77 | activation_function: gelu 78 | max_length: 50 79 | add_cross_attention: True 80 | output_attentions: False 81 | n_positions: 1600 82 | eval_context_len: ${agent_params.huggingface.max_length} 83 | 84 | model_kwargs: 85 | global_pos_embds: True 86 | tokenize_rtg: False 87 | rtg_tok_kwargs: 88 | min_val: 0 89 | max_val: 100 90 | separate_ca_embed: True 91 | -------------------------------------------------------------------------------- /configs/env_params/dark_room_40x20.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-Room-Dark-Dense-40x20-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | train_goal_pos: 20 | - [29, 6] 21 | - [38, 6] 22 | - [36, 2] 23 | - [29, 0] 24 | - [27, 2] 25 | - [32, 7] 26 | - [10, 8] 27 | - [36, 11] 28 | - [37, 18] 29 | - [9, 8] 30 | - [8, 16] 31 | - [10, 10] 32 | - [13, 8] 33 | - [6, 2] 34 | - [4, 11] 35 | - [17, 17] 36 | - [12, 16] 37 | - [34, 1] 38 | - [20, 16] 39 | - [32, 9] 40 | - [19, 4] 41 | - [11, 10] 42 | - [16, 14] 43 | - [19, 6] 44 | - [27, 15] 45 | - [17, 2] 46 | - [28, 4] 47 | - [18, 14] 48 | - [26, 7] 49 | - [30, 15] 50 | - [17, 8] 51 | - [35, 10] 52 | - [15, 14] 53 | - [24, 0] 54 | - [6, 8] 55 | - [19, 12] 56 | - [11, 17] 57 | - [19, 18] 58 | - [26, 16] 59 | - [22, 13] 60 | - [0, 11] 61 | - [0, 17] 62 | - [39, 2] 63 | - [39, 8] 64 | - [19, 14] 65 | - [21, 8] 66 | - [25, 15] 67 | - [4, 2] 68 | - [4, 14] 69 | - [1, 4] 70 | - [34, 12] 71 | - [11, 8] 72 | - [37, 5] 73 | - [18, 9] 74 | - [25, 4] 75 | - [22, 16] 76 | - [20, 6] 77 | - [18, 0] 78 | - [35, 8] 79 | - [9, 10] 80 | - [12, 1] 81 | - [34, 16] 82 | - [38, 17] 83 | - [20, 18] 84 | - [39, 3] 85 | - [35, 5] 86 | - [3, 7] 87 | - [16, 11] 88 | - [31, 0] 89 | - [14, 17] 90 | - [23, 4] 91 | - [12, 10] 92 | - [2, 7] 93 | - [29, 13] 94 | - [25, 12] 95 | - [9, 16] 96 | - [13, 9] 97 | - [20, 4] 98 | - [2, 17] 99 | - [26, 10] 100 | 101 | # train tasks for sfixed, grand 102 | eval_start_pos: [0, 0] 103 | eval_goal_pos: 104 | - [14, 13] 105 | - [31, 2] 106 | - [5, 19] 107 | - [18, 5] 108 | - [29, 9] 109 | - [5, 11] 110 | - [28, 8] 111 | - [22, 10] 112 | - [1, 15] 113 | - [27, 11] 114 | - [15, 17] 115 | - [5, 5] 116 | - [4, 8] 117 | - [38, 5] 118 | - [21, 6] 119 | - [24, 15] 120 | - [7, 19] 121 | - [13, 16] 122 | - [13, 1] 123 | - [12, 9] 124 | -------------------------------------------------------------------------------- /configs/agent_params/radt_disc_icl.yaml: 
-------------------------------------------------------------------------------- 1 | kind: "DCDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 80000000 # 8e7 12 | buffer_weight_by: len 13 | warmup_steps: 4000 14 | learnable_ret: True 15 | use_amp: True 16 | compile: True 17 | frozen: False 18 | sep_eval_cache: True 19 | # return conditioning 20 | target_return_loss_fn_type: ce 21 | target_return_type: fixed 22 | a_sample_kwargs: 23 | top_p: 0.5 24 | # query type 25 | representation_type: mean 26 | cache_steps: 25 27 | cache_len: ${agent_params.cache_steps} 28 | agg_token: s 29 | query_dropout: 0.2 30 | 31 | freeze_kwargs: 32 | exclude_crossattn: True 33 | 34 | replay_buffer_kwargs: 35 | num_workers: 16 36 | pin_memory: False 37 | 38 | cache_kwargs: 39 | num_workers: 0 40 | prefetch_factor: null 41 | pin_memory: False 42 | init_top_p: 1 43 | exclude_same_trjs: True 44 | task_weight: 1 45 | top_k: 50 46 | reweight_top_k: 1 47 | min_seq_len: 10 48 | sim_cutoff: 0.98 49 | share_trjs: True 50 | use_gpu: True 51 | deduplicate: True 52 | norm: True 53 | index_kwargs: 54 | nb_cores: 64 55 | # go with Flat index for now, others have colissions and are slower with high nprobe 56 | index_key: Flat 57 | 58 | eval_cache_kwargs: 59 | index_kwargs: 60 | nb_cores: ${agent_params.cache_kwargs.index_kwargs.nb_cores} 61 | return_weight: 1 62 | top_k: 50 63 | reweight_top_k: ${agent_params.cache_kwargs.reweight_top_k} 64 | 65 | load_path: 66 | dir_path: ${MODELS_DIR}/minihack/dark_room_dense_10x10 67 | file_name: dt_medium_64.zip 68 | 69 | defaults: 70 | - huggingface: dt_medium_64 71 | - data_paths: dark_room_10x10_train 72 | - data_paths@cache_data_paths: ${agent_params/data_paths} 73 | - model_kwargs: dark_room 74 | - lr_sched_kwargs: cosine 75 | 76 | huggingface: 77 | activation_function: gelu 78 | max_length: 50 79 | add_cross_attention: True 80 | output_attentions: False 81 | n_positions: 1600 82 | eval_context_len: ${agent_params.huggingface.max_length} 83 | 84 | model_kwargs: 85 | global_pos_embds: True 86 | tokenize_rtg: True 87 | rtg_tok_kwargs: 88 | min_val: 0 89 | max_val: 100 90 | separate_ca_embed: True 91 | -------------------------------------------------------------------------------- /configs/agent_params/radt_procgen.yaml: -------------------------------------------------------------------------------- 1 | kind: "DCDT" 2 | use_critic: False 3 | learning_starts: 0 4 | batch_size: 128 5 | gradient_steps: 1 6 | stochastic_policy: False 7 | loss_fn: "ce" 8 | ent_coef: 0.0 9 | offline_steps: ${run_params.total_timesteps} 10 | buffer_max_len_type: "transition" 11 | buffer_size: 80000000 # 8e7 12 | buffer_weight_by: len 13 | warmup_steps: 4000 14 | learnable_ret: True 15 | use_amp: True 16 | compile: True 17 | frozen: False 18 | sep_eval_cache: True 19 | # return conditioning 20 | # target_return_loss_fn_type: ce 21 | target_return_type: predefined 22 | # a_sample_kwargs: 23 | # top_p: 0.5 24 | # query type 25 | representation_type: mean 26 | cache_steps: 50 27 | cache_len: ${agent_params.cache_steps} 28 | agg_token: s 29 | reinit_policy: True 30 | 31 | freeze_kwargs: 32 | exclude_crossattn: True 33 | 34 | replay_buffer_kwargs: 35 | num_workers: 16 36 | pin_memory: False 37 | from_disk: False 38 | 39 | cache_kwargs: 40 | num_workers: 0 41 | prefetch_factor: null 42 | pin_memory: False 43 
| init_top_p: 1 44 | exclude_same_trjs: True 45 | task_weight: 1 46 | reweight_top_k: 1 47 | # min_seq_len: 10 48 | share_trjs: True 49 | from_disk: False 50 | index_kwargs: 51 | nb_cores: 64 52 | # go with Flat index for now, others have colissions and are slower with high nprobe 53 | index_key: Flat 54 | 55 | eval_cache_kwargs: 56 | index_kwargs: 57 | nb_cores: ${agent_params.cache_kwargs.index_kwargs.nb_cores} 58 | return_weight: 1 59 | top_k: 50 60 | reweight_top_k: ${agent_params.cache_kwargs.reweight_top_k} 61 | 62 | # load kwargs for image encoder 63 | load_path: 64 | dir_path: ${MODELS_DIR}/procgen_encoders_custom_v2 65 | file_name: dt_mediumplus_64.zip 66 | 67 | defaults: 68 | - huggingface: dt_mediumplus_64 69 | - data_paths: procgen12 70 | - data_paths@cache_data_paths: ${agent_params/data_paths} 71 | - model_kwargs: procgen 72 | - lr_sched_kwargs: cosine 73 | 74 | huggingface: 75 | activation_function: gelu 76 | max_length: 50 77 | add_cross_attention: True 78 | output_attentions: False 79 | n_positions: 1600 80 | eval_context_len: ${agent_params.huggingface.max_length} 81 | 82 | model_kwargs: 83 | global_pos_embds: True 84 | # tokenize_rtg: True 85 | # rtg_tok_kwargs: 86 | # min_val: 0 87 | # max_val: 100 88 | separate_ca_embed: True 89 | img_is_encoded: True 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode 132 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_keydoor_10x10_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_keydoor/10x10 2 | names: 3 | - "[0,0]_[2,0]_[8,3].pkl" 4 | - "[0,0]_[0,2]_[5,3].pkl" 5 | - "[0,0]_[1,5]_[7,0].pkl" 6 | - "[0,0]_[2,2]_[4,5].pkl" 7 | - "[0,0]_[5,7]_[4,4].pkl" 8 | - "[0,0]_[9,1]_[3,9].pkl" 9 | - "[0,0]_[6,9]_[2,2].pkl" 10 | - "[0,0]_[5,5]_[8,0].pkl" 11 | - "[0,0]_[1,1]_[1,0].pkl" 12 | - "[0,0]_[7,9]_[0,0].pkl" 13 | - "[0,0]_[0,9]_[1,8].pkl" 14 | - "[0,0]_[3,8]_[3,0].pkl" 15 | - "[0,0]_[8,5]_[7,3].pkl" 16 | - "[0,0]_[0,0]_[3,3].pkl" 17 | - "[0,0]_[8,9]_[9,0].pkl" 18 | - "[0,0]_[1,3]_[0,4].pkl" 19 | - "[0,0]_[0,5]_[7,6].pkl" 20 | - "[0,0]_[0,1]_[7,7].pkl" 21 | - "[0,0]_[9,5]_[1,2].pkl" 22 | - "[0,0]_[8,3]_[3,1].pkl" 23 | - "[0,0]_[4,4]_[5,5].pkl" 24 | - "[0,0]_[1,2]_[8,8].pkl" 25 | - "[0,0]_[7,8]_[2,6].pkl" 26 | - "[0,0]_[9,4]_[4,2].pkl" 27 | - "[0,0]_[3,7]_[6,9].pkl" 28 | - "[0,0]_[9,2]_[1,5].pkl" 29 | - "[0,0]_[9,7]_[4,0].pkl" 30 | - "[0,0]_[5,6]_[9,6].pkl" 31 | - "[0,0]_[6,3]_[0,9].pkl" 32 | - "[0,0]_[4,6]_[7,2].pkl" 33 | - "[0,0]_[0,8]_[1,1].pkl" 34 | - "[0,0]_[3,3]_[4,7].pkl" 35 | - "[0,0]_[4,5]_[8,5].pkl" 36 | - "[0,0]_[1,9]_[2,8].pkl" 37 | - "[0,0]_[1,4]_[9,3].pkl" 38 | - "[0,0]_[9,3]_[0,5].pkl" 39 | - "[0,0]_[7,3]_[6,6].pkl" 40 | - "[0,0]_[3,9]_[6,5].pkl" 41 | - "[0,0]_[2,4]_[3,5].pkl" 42 | - "[0,0]_[0,6]_[1,6].pkl" 43 | - "[0,0]_[6,2]_[4,9].pkl" 44 | - "[0,0]_[2,3]_[3,4].pkl" 45 | - "[0,0]_[1,8]_[0,7].pkl" 46 | - "[0,0]_[4,2]_[9,5].pkl" 47 | - "[0,0]_[8,0]_[2,7].pkl" 48 | - "[0,0]_[8,6]_[1,9].pkl" 49 | - "[0,0]_[3,1]_[8,1].pkl" 50 | - "[0,0]_[6,7]_[2,5].pkl" 51 | - "[0,0]_[2,7]_[6,2].pkl" 52 | - "[0,0]_[1,0]_[1,3].pkl" 53 | - "[0,0]_[4,0]_[2,4].pkl" 54 | - "[0,0]_[7,0]_[0,3].pkl" 55 | - "[0,0]_[9,6]_[1,7].pkl" 56 | - "[0,0]_[8,8]_[3,8].pkl" 57 | - "[0,0]_[2,6]_[0,8].pkl" 58 | - "[0,0]_[5,4]_[7,8].pkl" 59 | - "[0,0]_[7,1]_[0,6].pkl" 60 | - "[0,0]_[2,9]_[6,4].pkl" 61 | - "[0,0]_[2,5]_[3,6].pkl" 62 | - "[0,0]_[4,3]_[8,9].pkl" 63 | - "[0,0]_[4,1]_[5,6].pkl" 64 | - "[0,0]_[7,2]_[9,9].pkl" 65 | - "[0,0]_[0,3]_[5,4].pkl" 66 | - "[0,0]_[8,1]_[4,3].pkl" 67 | - "[0,0]_[5,3]_[5,0].pkl" 68 | - "[0,0]_[7,7]_[6,7].pkl" 69 | - "[0,0]_[6,1]_[4,6].pkl" 70 | - "[0,0]_[0,7]_[6,8].pkl" 71 | - "[0,0]_[2,8]_[6,1].pkl" 72 | - "[0,0]_[9,9]_[9,7].pkl" 73 | - "[0,0]_[5,2]_[7,9].pkl" 74 | - "[0,0]_[4,8]_[4,1].pkl" 75 | - "[0,0]_[8,2]_[5,8].pkl" 76 | - "[0,0]_[6,0]_[4,8].pkl" 77 | - "[0,0]_[7,4]_[9,8].pkl" 78 | - "[0,0]_[3,2]_[5,7].pkl" 79 | - "[0,0]_[4,7]_[7,5].pkl" 80 | - "[0,0]_[7,6]_[3,2].pkl" 81 | - "[0,0]_[3,6]_[9,4].pkl" 82 | - "[0,0]_[0,4]_[5,9].pkl" -------------------------------------------------------------------------------- /src/algos/__init__.py: 
-------------------------------------------------------------------------------- 1 | MODEL_CLASSES = { 2 | "DT": None, 3 | "ODT": None, 4 | "UDT": None, 5 | "DDT": None, 6 | "HelmDT": None, 7 | "DHelmDT": None, 8 | "CDT": None, 9 | "DCDT": None, 10 | } 11 | 12 | AGENT_CLASSES = { 13 | "DT": None, 14 | "ODT": None, 15 | "UDT": None, 16 | "DDT": None, 17 | "MDDT": None, 18 | "HelmDT": None, 19 | "DHelmDT": None, 20 | "CDT": None, 21 | "DCDT": None, 22 | } 23 | 24 | 25 | def get_model_class(kind): 26 | if kind in ["DT", "ODT", "UDT"]: 27 | from .models.online_decision_transformer_model import OnlineDecisionTransformerModel 28 | MODEL_CLASSES[kind] = OnlineDecisionTransformerModel 29 | elif kind in ["DDT"]: 30 | from .models.discrete_decision_transformer_model import DiscreteDTModel 31 | MODEL_CLASSES[kind] = DiscreteDTModel 32 | elif kind in ["HelmDT"]: 33 | from .models.helm_decision_transformer_model import HelmDTModel 34 | MODEL_CLASSES[kind] = HelmDTModel 35 | elif kind in ["DHelmDT"]: 36 | from .models.helm_decision_transformer_model import DiscreteHelmDTModel 37 | MODEL_CLASSES[kind] = DiscreteHelmDTModel 38 | elif kind in ["CDT"]: 39 | from .models.cache_decision_transformer_model import CacheDTModel 40 | MODEL_CLASSES[kind] = CacheDTModel 41 | elif kind in ["DCDT"]: 42 | from .models.cache_decision_transformer_model import DiscreteCacheDTModel 43 | MODEL_CLASSES[kind] = DiscreteCacheDTModel 44 | assert kind in MODEL_CLASSES, f"Unknown kind: {kind}" 45 | return MODEL_CLASSES[kind] 46 | 47 | 48 | def get_agent_class(kind): 49 | assert kind in AGENT_CLASSES, f"Unknown kind: {kind}" 50 | # lazy imports only when needed 51 | if kind in ["DT", "ODT", "HelmDT", "DHelmDT"]: 52 | from .decision_transformer_sb3 import DecisionTransformerSb3 53 | AGENT_CLASSES[kind] = DecisionTransformerSb3 54 | elif kind in ["UDT"]: 55 | from .universal_decision_transformer_sb3 import UDT 56 | AGENT_CLASSES[kind] = UDT 57 | elif kind in ["DDT"]: 58 | from .discrete_decision_transformer_sb3 import DiscreteDecisionTransformerSb3 59 | AGENT_CLASSES[kind] = DiscreteDecisionTransformerSb3 60 | elif kind == "CDT": 61 | from .cache_decision_transformer_sb3 import CacheDecisionTransformerSb3, DiscreteCacheDecisionTransformerSb3 62 | AGENT_CLASSES[kind] = CacheDecisionTransformerSb3 63 | elif kind == "DCDT": 64 | from .cache_decision_transformer_sb3 import DiscreteCacheDecisionTransformerSb3 65 | AGENT_CLASSES[kind] = DiscreteCacheDecisionTransformerSb3 66 | return AGENT_CLASSES[kind] 67 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_keydoor_20x20_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_keydoor/20x20 2 | names: 3 | - "[0,0]_[2,15]_[10,9].pkl" 4 | - "[0,0]_[13,11]_[14,0].pkl" 5 | - "[0,0]_[19,10]_[1,13].pkl" 6 | - "[0,0]_[5,5]_[10,10].pkl" 7 | - "[0,0]_[11,3]_[4,13].pkl" 8 | - "[0,0]_[17,18]_[4,4].pkl" 9 | - "[0,0]_[1,0]_[16,9].pkl" 10 | - "[0,0]_[15,9]_[4,14].pkl" 11 | - "[0,0]_[19,15]_[13,6].pkl" 12 | - "[0,0]_[6,19]_[6,6].pkl" 13 | - "[0,0]_[6,18]_[0,9].pkl" 14 | - "[0,0]_[9,5]_[18,1].pkl" 15 | - "[0,0]_[4,14]_[2,16].pkl" 16 | - "[0,0]_[4,2]_[3,12].pkl" 17 | - "[0,0]_[18,1]_[6,12].pkl" 18 | - "[0,0]_[19,2]_[2,2].pkl" 19 | - "[0,0]_[7,3]_[13,18].pkl" 20 | - "[0,0]_[17,15]_[18,16].pkl" 21 | - "[0,0]_[16,8]_[11,11].pkl" 22 | - "[0,0]_[17,9]_[19,5].pkl" 23 | - "[0,0]_[15,1]_[3,17].pkl" 24 | - "[0,0]_[3,4]_[0,15].pkl" 25 | - "[0,0]_[1,7]_[19,11].pkl" 26 | - 
"[0,0]_[11,8]_[13,11].pkl" 27 | - "[0,0]_[19,16]_[0,0].pkl" 28 | - "[0,0]_[6,16]_[19,16].pkl" 29 | - "[0,0]_[3,7]_[5,14].pkl" 30 | - "[0,0]_[12,3]_[11,5].pkl" 31 | - "[0,0]_[8,19]_[13,2].pkl" 32 | - "[0,0]_[11,14]_[5,4].pkl" 33 | - "[0,0]_[16,18]_[19,15].pkl" 34 | - "[0,0]_[14,4]_[9,13].pkl" 35 | - "[0,0]_[0,12]_[13,1].pkl" 36 | - "[0,0]_[11,13]_[2,17].pkl" 37 | - "[0,0]_[17,1]_[11,12].pkl" 38 | - "[0,0]_[3,2]_[5,16].pkl" 39 | - "[0,0]_[11,10]_[5,13].pkl" 40 | - "[0,0]_[2,19]_[17,2].pkl" 41 | - "[0,0]_[11,9]_[7,18].pkl" 42 | - "[0,0]_[5,3]_[7,1].pkl" 43 | - "[0,0]_[16,2]_[19,1].pkl" 44 | - "[0,0]_[19,9]_[2,15].pkl" 45 | - "[0,0]_[0,15]_[3,16].pkl" 46 | - "[0,0]_[10,7]_[1,5].pkl" 47 | - "[0,0]_[2,4]_[4,2].pkl" 48 | - "[0,0]_[12,17]_[19,2].pkl" 49 | - "[0,0]_[17,5]_[7,8].pkl" 50 | - "[0,0]_[19,1]_[9,1].pkl" 51 | - "[0,0]_[8,16]_[1,2].pkl" 52 | - "[0,0]_[13,3]_[8,13].pkl" 53 | - "[0,0]_[0,3]_[2,6].pkl" 54 | - "[0,0]_[10,5]_[16,1].pkl" 55 | - "[0,0]_[11,1]_[16,18].pkl" 56 | - "[0,0]_[16,1]_[3,10].pkl" 57 | - "[0,0]_[0,17]_[18,14].pkl" 58 | - "[0,0]_[12,7]_[1,19].pkl" 59 | - "[0,0]_[6,1]_[11,3].pkl" 60 | - "[0,0]_[8,5]_[8,12].pkl" 61 | - "[0,0]_[13,1]_[1,10].pkl" 62 | - "[0,0]_[18,18]_[7,12].pkl" 63 | - "[0,0]_[16,5]_[6,4].pkl" 64 | - "[0,0]_[1,9]_[14,14].pkl" 65 | - "[0,0]_[19,8]_[12,15].pkl" 66 | - "[0,0]_[1,10]_[3,18].pkl" 67 | - "[0,0]_[0,0]_[5,1].pkl" 68 | - "[0,0]_[16,7]_[1,11].pkl" 69 | - "[0,0]_[9,10]_[17,12].pkl" 70 | - "[0,0]_[13,6]_[13,8].pkl" 71 | - "[0,0]_[15,12]_[19,14].pkl" 72 | - "[0,0]_[1,18]_[3,13].pkl" 73 | - "[0,0]_[18,4]_[16,0].pkl" 74 | - "[0,0]_[3,9]_[7,0].pkl" 75 | - "[0,0]_[15,6]_[0,5].pkl" 76 | - "[0,0]_[1,16]_[2,5].pkl" 77 | - "[0,0]_[18,8]_[19,8].pkl" 78 | - "[0,0]_[6,6]_[12,6].pkl" 79 | - "[0,0]_[7,0]_[11,7].pkl" 80 | - "[0,0]_[18,0]_[18,9].pkl" 81 | - "[0,0]_[14,9]_[8,16].pkl" 82 | - "[0,0]_[9,3]_[14,9].pkl" -------------------------------------------------------------------------------- /src/utils/debug.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from matplotlib.lines import Line2D 4 | from pathlib import Path 5 | from datetime import datetime 6 | 7 | 8 | class GradPlotter: 9 | 10 | def __init__(self, y_min=-0.01, y_max=0.5, base_dir=None): 11 | self.base_dir = base_dir 12 | if self.base_dir is None: 13 | self.base_dir = "/home/thomas/Projects-Linux/CRL_with_transformers/debug" 14 | time = datetime.now().strftime("%d-%m-%Y_%Hh%Mm") 15 | self.base_dir = Path(self.base_dir) / time 16 | self.base_dir.mkdir(exist_ok=True, parents=True) 17 | self.y_min = y_min 18 | self.y_max = y_max 19 | 20 | def plot_grad_flow(self, named_parameters, file_name): 21 | """ 22 | Adjusted from: 23 | https://gist.github.com/Flova/8bed128b41a74142a661883af9e51490 24 | 25 | Plots the gradients flowing through different layers in the net during training. 26 | Can be used for checking for possible gradient vanishing / exploding problems. 
27 | 28 | Usage: Plug this function in Trainer class after loss.backwards() as 29 | "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow 30 | 31 | E.g., call using: 32 | if self._n_updates % 1000 == 0: 33 | plot_grad_flow(self.critic.named_parameters(), f"critic_update,critic,step={self._n_updates}.png") 34 | plot_grad_flow(self.policy.named_parameters(), f"critic_update,policy,step={self._n_updates}.png") 35 | 36 | """ 37 | ave_grads, max_grads, layers = [], [], [] 38 | for n, p in named_parameters: 39 | if (p.requires_grad) and ("bias" not in n): 40 | layers.append(n) 41 | ave_grads.append(p.grad.abs().mean().item() if p.grad is not None else 0) 42 | max_grads.append(p.grad.abs().max().item() if p.grad is not None else 0) 43 | plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.5, lw=1, color="c") 44 | plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.5, lw=1, color="b") 45 | plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k") 46 | plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical") 47 | plt.xlim(left=0, right=len(ave_grads)) 48 | # plt.ylim(bottom=self.y_min, top=self.y_max) 49 | plt.xlabel("Layers") 50 | plt.ylabel("average gradient") 51 | plt.title("Gradient flow") 52 | plt.grid(True) 53 | plt.legend([Line2D([0], [0], color="c", lw=4), 54 | Line2D([0], [0], color="b", lw=4), 55 | Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient']) 56 | plt.savefig(self.base_dir / file_name, bbox_inches='tight') 57 | plt.close() 58 | -------------------------------------------------------------------------------- /src/schedulers/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class ScheduleBase: 5 | def __init__(self, init_val=1.0, max_step_multiplier=1.0, max_step=None, min_val=None): 6 | self._max_step = max_step * max_step_multiplier 7 | self._min_val = min_val 8 | self._init_val = init_val 9 | 10 | @property 11 | def max_step(self): 12 | return self._max_step 13 | 14 | @property 15 | def min_val(self): 16 | return self._min_val 17 | 18 | @property 19 | def init_val(self): 20 | return self._init_val 21 | 22 | def __call__(self, step): 23 | return self.get_value(step) 24 | 25 | def get_value(self, step): 26 | raise NotImplementedError 27 | 28 | 29 | class Linear(ScheduleBase): 30 | """ 31 | Similar to : 32 | https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.LinearLR.html#torch.optim.lr_scheduler.LinearLR 33 | """ 34 | def get_value(self, step): 35 | val = step / self.max_step 36 | if self.min_val is not None and val < self.min_val: 37 | return self.min_val 38 | return val 39 | 40 | 41 | class Step(ScheduleBase): 42 | def __init__(self, step_size, gamma=0.5, **kwargs): 43 | """ 44 | Similar to: 45 | https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html#torch.optim.lr_scheduler.StepLR 46 | Args: 47 | step_size: Int. 48 | gamma: Float. 
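Example (illustrative; ScheduleBase requires max_step even though Step.get_value does not use it):
>>> sched = Step(step_size=1000, gamma=0.5, max_step=10000)
>>> [sched(s) for s in (0, 999, 1000, 2500)]
[1.0, 1.0, 0.5, 0.25]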
49 | """ 50 | super().__init__(**kwargs) 51 | self.step_size = step_size 52 | self.gamma = gamma 53 | 54 | def get_value(self, step): 55 | exponent = int(step / self.step_size) 56 | val = self.init_val * (self.gamma ** exponent) 57 | if self.min_val is not None and val < self.min_val: 58 | return self.min_val 59 | return val 60 | 61 | 62 | class CosineAnnealing(ScheduleBase): 63 | def __init__(self, eta_min=0, **kwargs): 64 | """ 65 | Similar to: 66 | https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingLR.html#torch.optim.lr_scheduler.CosineAnnealingLR 67 | Args: 68 | eta_min: Float. 69 | **kwargs: 70 | """ 71 | super().__init__(**kwargs) 72 | self.eta_min = eta_min 73 | 74 | def get_value(self, step): 75 | if step > self.max_step: 76 | return self.eta_min 77 | return self.eta_min + (self.init_val - self.eta_min) * (1 + math.cos(math.pi * step / self.max_step)) / 2 78 | 79 | 80 | def make_scheduler(kind="linear", **kwargs): 81 | if kind == "linear": 82 | return Linear(**kwargs) 83 | elif kind == "cosine": 84 | return CosineAnnealing(**kwargs) 85 | elif kind == "step": 86 | return Step(**kwargs) 87 | return None 88 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_keydoor_40x20_train.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/minihack/dark_keydoor/40x20 2 | names: 3 | - "[0,0]_[29,6]_[34,16].pkl" 4 | - "[0,0]_[38,6]_[33,7].pkl" 5 | - "[0,0]_[36,2]_[3,3].pkl" 6 | - "[0,0]_[29,0]_[26,13].pkl" 7 | - "[0,0]_[27,2]_[3,6].pkl" 8 | - "[0,0]_[32,7]_[31,1].pkl" 9 | - "[0,0]_[10,8]_[17,6].pkl" 10 | - "[0,0]_[36,11]_[24,10].pkl" 11 | - "[0,0]_[37,18]_[38,0].pkl" 12 | - "[0,0]_[9,8]_[22,16].pkl" 13 | - "[0,0]_[8,16]_[3,5].pkl" 14 | - "[0,0]_[10,10]_[14,6].pkl" 15 | - "[0,0]_[13,8]_[31,15].pkl" 16 | - "[0,0]_[6,2]_[3,7].pkl" 17 | - "[0,0]_[4,11]_[16,7].pkl" 18 | - "[0,0]_[17,17]_[19,15].pkl" 19 | - "[0,0]_[12,16]_[12,4].pkl" 20 | - "[0,0]_[34,1]_[18,17].pkl" 21 | - "[0,0]_[20,16]_[30,10].pkl" 22 | - "[0,0]_[32,9]_[26,6].pkl" 23 | - "[0,0]_[19,4]_[32,18].pkl" 24 | - "[0,0]_[11,10]_[26,9].pkl" 25 | - "[0,0]_[16,14]_[31,2].pkl" 26 | - "[0,0]_[19,6]_[36,1].pkl" 27 | - "[0,0]_[27,15]_[18,0].pkl" 28 | - "[0,0]_[17,2]_[1,10].pkl" 29 | - "[0,0]_[28,4]_[13,0].pkl" 30 | - "[0,0]_[18,14]_[31,17].pkl" 31 | - "[0,0]_[26,7]_[37,6].pkl" 32 | - "[0,0]_[30,15]_[28,10].pkl" 33 | - "[0,0]_[17,8]_[10,15].pkl" 34 | - "[0,0]_[35,10]_[3,18].pkl" 35 | - "[0,0]_[15,14]_[28,18].pkl" 36 | - "[0,0]_[24,0]_[19,3].pkl" 37 | - "[0,0]_[6,8]_[1,19].pkl" 38 | - "[0,0]_[19,12]_[1,3].pkl" 39 | - "[0,0]_[11,17]_[34,12].pkl" 40 | - "[0,0]_[19,18]_[19,18].pkl" 41 | - "[0,0]_[26,16]_[39,16].pkl" 42 | - "[0,0]_[22,13]_[37,1].pkl" 43 | - "[0,0]_[0,11]_[6,19].pkl" 44 | - "[0,0]_[0,17]_[12,10].pkl" 45 | - "[0,0]_[39,2]_[8,14].pkl" 46 | - "[0,0]_[39,8]_[16,3].pkl" 47 | - "[0,0]_[19,14]_[29,15].pkl" 48 | - "[0,0]_[21,8]_[26,5].pkl" 49 | - "[0,0]_[25,15]_[21,3].pkl" 50 | - "[0,0]_[4,2]_[29,16].pkl" 51 | - "[0,0]_[4,14]_[30,4].pkl" 52 | - "[0,0]_[1,4]_[26,14].pkl" 53 | - "[0,0]_[34,12]_[13,5].pkl" 54 | - "[0,0]_[11,8]_[5,9].pkl" 55 | - "[0,0]_[37,5]_[33,6].pkl" 56 | - "[0,0]_[18,9]_[14,14].pkl" 57 | - "[0,0]_[25,4]_[16,16].pkl" 58 | - "[0,0]_[22,16]_[18,8].pkl" 59 | - "[0,0]_[20,6]_[22,6].pkl" 60 | - "[0,0]_[18,0]_[16,13].pkl" 61 | - "[0,0]_[35,8]_[9,18].pkl" 62 | - "[0,0]_[9,10]_[8,8].pkl" 63 | - "[0,0]_[12,1]_[24,1].pkl" 64 | - "[0,0]_[34,16]_[36,3].pkl" 65 | - 
"[0,0]_[38,17]_[32,15].pkl" 66 | - "[0,0]_[20,18]_[15,6].pkl" 67 | - "[0,0]_[39,3]_[11,11].pkl" 68 | - "[0,0]_[35,5]_[21,2].pkl" 69 | - "[0,0]_[3,7]_[2,9].pkl" 70 | - "[0,0]_[16,11]_[39,7].pkl" 71 | - "[0,0]_[31,0]_[17,17].pkl" 72 | - "[0,0]_[14,17]_[4,6].pkl" 73 | - "[0,0]_[23,4]_[27,5].pkl" 74 | - "[0,0]_[12,10]_[9,12].pkl" 75 | - "[0,0]_[2,7]_[39,15].pkl" 76 | - "[0,0]_[29,13]_[1,13].pkl" 77 | - "[0,0]_[25,12]_[1,11].pkl" 78 | - "[0,0]_[9,16]_[9,19].pkl" 79 | - "[0,0]_[13,9]_[18,5].pkl" 80 | - "[0,0]_[20,4]_[24,6].pkl" 81 | - "[0,0]_[2,17]_[29,14].pkl" 82 | - "[0,0]_[26,10]_[28,8].pkl" -------------------------------------------------------------------------------- /src/tokenizers_custom/mu_law_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adjusted from: https://github.com/G-Wang/WaveRNN-Pytorch/blob/master/utils.py 3 | 4 | Could also just use torchaudio: 5 | https://github.com/pytorch/audio/blob/0cd25093626d067e008e1f81ad76e072bd4a1edd/torchaudio/transforms.py#L757 6 | 7 | """ 8 | import torch 9 | import numpy as np 10 | from .base_tokenizer import BaseTokenizer 11 | 12 | 13 | class MuLawTokenizer(BaseTokenizer): 14 | 15 | def __init__(self, **kwargs): 16 | super().__init__(**kwargs) 17 | 18 | def tokenize(self, x): 19 | """ 20 | Encode signal based on mu-law companding. For more info see the 21 | `Wikipedia Entry `_ 22 | This algorithm assumes the signal has been scaled to between -1 and 1 and 23 | returns a signal encoded with values from 0 to quantization_channels - 1 24 | Args: 25 | quantization_channels (int): Number of channels. default: 256 26 | """ 27 | mu = self.vocab_size - 1 28 | if isinstance(x, np.ndarray): 29 | x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) 30 | tokens = ((x_mu + 1) / 2 * mu + 0.5).astype(int) 31 | if self.shift != 0: 32 | return tokens + self.shift 33 | return tokens 34 | elif isinstance(x, (torch.Tensor, torch.LongTensor)): 35 | if isinstance(x, torch.LongTensor): 36 | x = x.float() 37 | mu = torch.FloatTensor([mu]).to(device=x.device) 38 | x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) 39 | tokens = ((x_mu + 1) / 2 * mu + 0.5).long() 40 | if self.shift != 0: 41 | return tokens + self.shift 42 | return tokens 43 | raise NotImplementedError() 44 | 45 | def inv_tokenize(self, x_mu): 46 | """ 47 | Decode mu-law encoded signal. For more info see the 48 | `Wikipedia Entry `_ 49 | This expects an input with values between 0 and quantization_channels - 1 50 | and returns a signal scaled between -1 and 1. 51 | Args: 52 | quantization_channels (int): Number of channels. default: 256 53 | """ 54 | mu = self.vocab_size - 1. 55 | if self.shift != 0: 56 | x_mu = x_mu - self.shift 57 | if isinstance(x_mu, np.ndarray): 58 | x = ((x_mu) / mu) * 2 - 1. 59 | return np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu 60 | elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)): 61 | if isinstance(x_mu, (torch.LongTensor, torch.cuda.LongTensor)): 62 | x_mu = x_mu.float() 63 | mu = torch.FloatTensor([mu]).to(x_mu.device) 64 | x = ((x_mu) / mu) * 2 - 1. 65 | return torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) 
/ mu 66 | raise NotImplementedError() 67 | -------------------------------------------------------------------------------- /src/algos/models/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def sample_from_logits(logits, temperature=1.0, top_k=0, top_p=0.5): 8 | """ 9 | Adjusted from: 10 | - https://github.com/google-research/google-research/tree/master/multi_game_dt 11 | - https://github.com/etaoxing/multigame-dt 12 | """ 13 | logits = logits.double() 14 | if top_p > 0.0: 15 | # percentile: 0 to 100, quantile: 0 to 1 16 | # torch.quantile cannot handle float16 17 | percentile = torch.quantile(logits, top_p, dim=-1) 18 | if percentile != logits.max(): 19 | # otherwise all logits would become -inf 20 | logits = torch.where(logits > percentile.unsqueeze(-1), logits, -float("inf")) 21 | if top_k > 0: 22 | logits, top_indices = torch.topk(logits, top_k) 23 | try: 24 | sample = torch.distributions.Categorical(logits=temperature * logits).sample() 25 | except Exception as e: 26 | print(e, logits) 27 | if (logits == -float("inf")).all(): 28 | # uniformly sample 29 | sample = torch.distributions.Categorical(logits=torch.zeros_like(logits)).sample() 30 | if top_k > 0: 31 | sample_shape = sample.shape 32 | # Flatten top-k indices and samples for easy indexing. 33 | top_indices = torch.reshape(top_indices, [-1, top_k]) 34 | sample = sample.flatten() 35 | sample = top_indices[torch.arange(len(sample)), sample] 36 | # Reshape samples back to original dimensions. 37 | sample = torch.reshape(sample, sample_shape) 38 | return sample 39 | 40 | 41 | def position_encoding_init(n_position, d_pos_vec): 42 | ''' Init the sinusoid position encoding table ''' 43 | position_enc = np.array([ 44 | [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)] 45 | for pos in range(n_position)]) 46 | 47 | position_enc[:, 0::2] = np.sin(position_enc[:, 0::2]) # dim 2i 48 | position_enc[:, 1::2] = np.cos(position_enc[:, 1::2]) # dim 2i+1 49 | return torch.from_numpy(position_enc).type(torch.FloatTensor) 50 | 51 | 52 | def make_sinusoidal_embd(n_positions, embed_dim): 53 | position_enc = torch.nn.Embedding(n_positions, embed_dim) 54 | position_enc.weight.data = position_encoding_init(n_positions, embed_dim) 55 | return position_enc 56 | 57 | 58 | class SwiGLU(nn.Module): 59 | # SwiGLU https://arxiv.org/abs/2002.05202 60 | def forward(self, x): 61 | x, gate = x.chunk(2, dim=-1) 62 | return F.silu(gate) * x 63 | 64 | 65 | class GEGLU(nn.Module): 66 | """ 67 | References: 68 | Shazeer et al., "GLU Variants Improve Transformer," 2020. 
69 | https://arxiv.org/abs/2002.05202 70 | """ 71 | 72 | def geglu(self, x): 73 | assert x.shape[-1] % 2 == 0 74 | a, b = x.chunk(2, dim=-1) 75 | return a * F.gelu(b) 76 | 77 | def forward(self, x): 78 | return self.geglu(x) 79 | -------------------------------------------------------------------------------- /src/tokenizers_custom/minmax_tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_tokenizer import BaseTokenizer 3 | 4 | 5 | class MinMaxTokenizer(BaseTokenizer): 6 | 7 | def __init__(self, min_val=-1, max_val=1, one_hot=False, **kwargs): 8 | super().__init__(**kwargs) 9 | self.min_val = min_val 10 | self.max_val = max_val 11 | self.one_hot = one_hot 12 | self.bin_width = (max_val - min_val) / self.vocab_size 13 | 14 | def tokenize(self, x): 15 | # Reshape the input tensor to have shape (batch_size, num_features) 16 | batch_size, num_features = x.shape[0], x.shape[1:] 17 | x = x.view(batch_size, -1) 18 | 19 | # Compute the indices of the bins 20 | tokens = ((x - self.min_val) / self.bin_width).long().clamp(min=0, max=self.vocab_size - 1) 21 | 22 | # Reshape the output tensor to have the same shape as the input tensor 23 | tokens = tokens.view(batch_size, *num_features) 24 | 25 | if self.shift != 0: 26 | return tokens + self.shift 27 | if self.one_hot: 28 | return torch.nn.functional.one_hot(tokens, num_classes=self.vocab_size).float().flatten(-2) 29 | return tokens 30 | 31 | def inv_tokenize(self, x): 32 | if self.one_hot: 33 | x = torch.argmax(x, dim=-1) 34 | if self.shift != 0: 35 | x = x - self.shift 36 | # can't be smaller than 0 37 | x[x < 0] = 0 38 | 39 | # Reshape the input tensor to have shape (batch_size, num_features) 40 | batch_size, num_features = x.shape[0], x.shape[1:] 41 | x = x.view(batch_size, -1) 42 | 43 | # Compute the values of the bins 44 | values = x.float() * self.bin_width + self.min_val 45 | 46 | # Reshape the output tensor to have the same shape as the input tensor 47 | return values.view(batch_size, *num_features) 48 | 49 | 50 | class MinMaxTokenizer2(BaseTokenizer): 51 | 52 | def __init__(self, min_val=-1, max_val=1, **kwargs): 53 | """ 54 | Tokenizes a given (action) input as described by: https://arxiv.org/abs/2212.06817 55 | Args: 56 | **kwargs: 57 | 58 | """ 59 | super().__init__(**kwargs) 60 | self.min_val = min_val 61 | self.max_val = max_val 62 | 63 | def tokenize(self, x): 64 | x = torch.clamp(x, self.min_val, self.max_val) 65 | # Normalize the action [batch, actions_size] 66 | tokens = (x - self.min_val) / (self.max_val - self.min_val) 67 | # Bucket and discretize the action to vocab_size, [batch, actions_size] 68 | tokens = (tokens * (self.vocab_size - 1)).long() 69 | if self.shift != 0: 70 | return tokens + self.shift 71 | return tokens 72 | 73 | def inv_tokenize(self, x): 74 | if self.shift != 0: 75 | x = x - self.shift 76 | x[x < 0] = 0 77 | x = x.float() / (self.vocab_size - 1) 78 | x = (x * (self.max_val - self.min_val)) + self.min_val 79 | return x 80 | -------------------------------------------------------------------------------- /precompute_img_embeds.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import omegaconf 3 | import functools 4 | import torch 5 | import numpy as np 6 | import h5py 7 | from joblib import delayed 8 | from tqdm import tqdm 9 | from src.utils import maybe_split 10 | from src.utils.misc import ProgressParallel 11 | from src.envs import make_env 12 | from src.algos.builder import 
make_agent 13 | from src.buffers.buffer_utils import load_hdf5, append_to_hdf5 14 | 15 | 16 | def encode_single_img_seq(img_encoder, path, device, max_batch_size=512): 17 | assert path.suffix == ".hdf5", "Only .hdf5 files are supported." 18 | try: 19 | observations, _, _, _, _ = load_hdf5(path) 20 | except Exception as e: 21 | print(f"Error reading from {path}.") 22 | raise e 23 | observations = torch.from_numpy(observations).float().to(device) / 255.0 24 | if observations.shape[0] > max_batch_size: 25 | img_embeds = [] 26 | for i in range(0, observations.shape[0], max_batch_size): 27 | with torch.no_grad(): 28 | # amp here? 29 | embeds = img_encoder(observations[i : i + max_batch_size]).detach().cpu().numpy() 30 | img_embeds.append(embeds) 31 | img_embeds = np.concatenate(img_embeds, axis=0) 32 | else: 33 | with torch.no_grad(): 34 | img_embeds = img_encoder(observations).detach().cpu().numpy() 35 | try: 36 | append_to_hdf5(path, {"states_encoded": img_embeds}) 37 | except Exception as e: 38 | print(f"Error writing to {path}.") 39 | raise e 40 | del observations, img_embeds 41 | if torch.cuda.is_available(): 42 | torch.cuda.empty_cache() 43 | 44 | 45 | def encode_image_sequences(img_encoder, paths, device, n_jobs=-1): 46 | img_encoder.eval() 47 | fn = functools.partial(encode_single_img_seq, img_encoder=img_encoder, device=device) 48 | ProgressParallel(n_jobs=n_jobs, total=len(paths), timeout=5000)(delayed(fn)(path=p) for p in paths) 49 | 50 | 51 | @hydra.main(config_path="configs", config_name="config") 52 | def main(config): 53 | """ 54 | For every sequence path, loads the sequence, encodes images and write encoded images to .hdf5 files. 55 | """ 56 | print("Config: \n", omegaconf.OmegaConf.to_yaml(config, resolve=True, sort_keys=True)) 57 | logdir = None 58 | env, _, _ = make_env(config, logdir) 59 | agent = make_agent(config, env, logdir) 60 | paths = agent.replay_buffer.trajectories 61 | if config.get("missing_only", False): 62 | missing = [] 63 | for p in tqdm(paths): 64 | try: 65 | with h5py.File(p, "r") as f: 66 | if "states_encoded" not in f: 67 | missing.append(p) 68 | except Exception as e: 69 | print(f"Error reading {p}.") 70 | raise e 71 | print(f"Found {len(missing)} missing paths: ", missing) 72 | paths = missing 73 | encode_image_sequences(agent.policy.embed_image, paths, 74 | device=agent.device, n_jobs=config.get("n_jobs", -1)) 75 | 76 | if __name__ == "__main__": 77 | omegaconf.OmegaConf.register_new_resolver("maybe_split", maybe_split) 78 | main() 79 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: radt 2 | channels: 3 | - anaconda 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1 7 | - _openmp_mutex=4.5 8 | - _tflow_select=2.1.0 9 | - aiohttp=3.8.1 10 | - aiosignal=1.2.0 11 | - astor=0.8.1 12 | - astunparse=1.6.3 13 | - async-timeout=4.0.1 14 | - attrs=21.4.0 15 | - blas=1.0 16 | - blinker=1.4 17 | - brotlipy=0.7.0 18 | - c-ares=1.18.1 19 | - ca-certificates=2022.3.29 20 | - cffi=1.15.0 21 | - click=8.0.4 22 | - cryptography=3.4.8 23 | - dataclasses=0.8 24 | - frozenlist=1.2.0 25 | - gast=0.4.0 26 | - google-pasta=0.2.0 27 | - hdf5=1.10.6 28 | - idna=3.3 29 | - importlib-metadata=4.11.3 30 | - intel-openmp=2021.4.0 31 | - keras-preprocessing=1.1.2 32 | - ld_impl_linux-64=2.35.1 33 | - libffi=3.3 34 | - libgcc-ng=9.3.0 35 | - libgfortran-ng=7.5.0 36 | - libgfortran4=7.5.0 37 | - libgomp=9.3.0 38 | - libprotobuf=3.19.1 39 
| - libstdcxx-ng=9.3.0 40 | - mkl=2021.4.0 41 | - mkl-service=2.4.0 42 | - mkl_fft=1.3.1 43 | - mkl_random=1.2.2 44 | - multidict=5.2.0 45 | - ncurses=6.3 46 | - oauthlib=3.2.0 47 | - openssl=1.1.1n 48 | - opt_einsum=3.3.0 49 | - pip=21.2.4 50 | - pyasn1=0.4.8 51 | - pycparser=2.21 52 | - pyjwt=2.1.0 53 | - pyopenssl=21.0.0 54 | - pysocks=1.7.1 55 | - python=3.9.12 56 | - python-flatbuffers=2.0 57 | - readline=8.1.2 58 | - requests=2.27.1 59 | - six=1.16.0 60 | - sqlite=3.38.2 61 | - tk=8.6.11 62 | - typing_extensions 63 | - tzdata=2022a 64 | - urllib3=1.26.9 65 | - xz=5.2.5 66 | - yarl=1.6.3 67 | - zlib=1.2.12 68 | - pip: 69 | - absl-py==1.0.0 70 | - antlr4-python3-runtime==4.8 71 | - cachetools==5.0.0 72 | - certifi==2021.10.8 73 | - charset-normalizer==2.0.12 74 | - cloudpickle==2.0.0 75 | - cycler==0.11.0 76 | - cython==0.29.28 77 | - docker-pycreds==0.4.0 78 | - fasteners==0.17.3 79 | - filelock==3.7.0 80 | - flatbuffers==1.12 81 | - fonttools==4.33.3 82 | - future==0.18.2 83 | - gitdb==4.0.9 84 | - gitpython==3.1.27 85 | - glfw==2.5.3 86 | - google-auth==2.6.6 87 | - google-auth-oauthlib==0.4.6 88 | - grpcio==1.44.0 89 | - wheel==0.38.0 90 | - setuptools==65.5.0 91 | - gym-notices==0.0.6 92 | - hydra-core==1.1.2 93 | - imageio==2.18.0 94 | - kiwisolver==1.4.2 95 | - labmaze==1.0.5 96 | - libclang==14.0.1 97 | - lxml==4.8.0 98 | - markdown==3.3.6 99 | - matplotlib==3.5.1 100 | - numpy 101 | - omegaconf==2.1.2 102 | - packaging==21.3 103 | - pandas==1.4.2 104 | - pathtools==0.1.2 105 | - pillow==9.1.0 106 | - promise==2.3 107 | - protobuf==3.20.1 108 | - psutil==5.9.0 109 | - pyasn1-modules==0.2.8 110 | - pybullet==3.2.4 111 | - pyopengl==3.1.6 112 | - pyparsing==2.4.7 113 | - python-dateutil==2.8.2 114 | - pytz==2022.1 115 | - pyyaml==6.0 116 | - regex==2022.4.24 117 | - requests-oauthlib==1.3.1 118 | - rsa==4.8 119 | - scipy==1.8.0 120 | - seaborn==0.11.2 121 | - sentry-sdk==1.5.12 122 | - setproctitle==1.2.3 123 | - shortuuid==1.0.9 124 | - smmap==5.0.0 125 | - termcolor==1.1.0 126 | - tqdm==4.64.0 127 | - typing-extensions 128 | - werkzeug==2.1.2 129 | - wrapt==1.12.1 130 | - zipp==3.8.0 131 | -------------------------------------------------------------------------------- /src/algos/models/extractors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, create_mlp 4 | 5 | 6 | class FlattenExtractorWithMLP(FlattenExtractor): 7 | 8 | def __init__(self, observation_space, net_arch=None): 9 | super().__init__(observation_space) 10 | if net_arch is None: 11 | net_arch = [128, 128] 12 | 13 | mlp = create_mlp(self.features_dim, net_arch[-1], net_arch) 14 | self.mlp = nn.Sequential(*mlp) 15 | self._features_dim = net_arch[-1] 16 | 17 | def forward(self, observations): 18 | return self.mlp(self.flatten(observations)) 19 | 20 | 21 | class TextureFeatureExtractor(BaseFeaturesExtractor): 22 | """ 23 | Textures Feature Extractor for Crafter. 
Textures start at dim 21 of the flattened observation. 24 | """ 25 | def __init__(self, observation_space, features_dim=256, texture_start_dim=21, num_textures=63, 26 | texture_embed_dim=4, textures_shape=(9,7), hidden_dim=192, **kwargs): 27 | super().__init__(observation_space, features_dim=features_dim) 28 | self.texture_start_dim = texture_start_dim 29 | self.texture_emb = nn.Embedding(num_textures + 1, texture_embed_dim) 30 | self.texture_net = nn.Sequential( 31 | nn.Linear(texture_embed_dim * textures_shape[0] * textures_shape[1], hidden_dim), 32 | nn.LeakyReLU(), 33 | nn.Linear(hidden_dim, hidden_dim), 34 | nn.LayerNorm(hidden_dim) 35 | ) 36 | self.out = nn.Linear(texture_start_dim + hidden_dim, self.features_dim) 37 | 38 | def forward(self, observations): 39 | # receives flattened info + textures 40 | info, textures = observations[..., :self.texture_start_dim], observations[..., self.texture_start_dim:].long() 41 | texture_embeds = self.texture_emb(textures).flatten(-2) 42 | texture_features = self.texture_net(texture_embeds) 43 | x = torch.cat((info, texture_features), dim=-1) 44 | return self.out(x) 45 | 46 | 47 | def create_cwnet( 48 | input_dim: int, 49 | output_dim: int, 50 | net_arch=(256,256,256), 51 | # activation_fn=lambda: nn.LeakyReLU(negative_slope=0.2), 52 | activation_fn=nn.LeakyReLU, 53 | squash_output: bool = False, 54 | ): 55 | """ 56 | Creates the same net as described in https://arxiv.org/pdf/2105.10919.pdf, 57 | i.e., it adds LayerNorm + Tanh after the first dense layer. 58 | 59 | :param input_dim: Dimension of the input vector 60 | :param output_dim: Dimension of the output layer 61 | :param net_arch: Architecture of the neural net. 62 | It represents the number of units per layer. 63 | The length of this list is the number of layers. 64 | :param activation_fn: The activation function 65 | to use after each layer.
66 | :param squash_output: Whether to squash the output using a Tanh 67 | activation function 68 | :return: 69 | """ 70 | 71 | if len(net_arch) > 0: 72 | modules = [nn.Linear(input_dim, net_arch[0])] 73 | else: 74 | modules = [] 75 | 76 | modules.append(nn.LayerNorm(net_arch[0])) 77 | modules.append(nn.Tanh()) 78 | 79 | for idx in range(len(net_arch) - 1): 80 | modules.append(nn.Linear(net_arch[idx], net_arch[idx + 1])) 81 | modules.append(activation_fn()) 82 | 83 | if output_dim > 0: 84 | last_layer_dim = net_arch[-1] if len(net_arch) > 0 else input_dim 85 | modules.append(nn.Linear(last_layer_dim, output_dim)) 86 | if squash_output: 87 | modules.append(nn.Tanh()) 88 | return modules 89 | -------------------------------------------------------------------------------- /configs/env_params/dark_keydoor.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-KeyDoor-Dark-Dense-10x10-v0 2 | target_return: [90,5] 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | train_goal_pos: 20 | - [2,0] 21 | - [0,2] 22 | - [1,5] 23 | - [2,2] 24 | - [5,7] 25 | - [9,1] 26 | - [6,9] 27 | - [5,5] 28 | - [1,1] 29 | - [7,9] 30 | - [0,9] 31 | - [3,8] 32 | - [8,5] 33 | - [0,0] 34 | - [8,9] 35 | - [1,3] 36 | - [0,5] 37 | - [0,1] 38 | - [9,5] 39 | - [8,3] 40 | - [4,4] 41 | - [1,2] 42 | - [7,8] 43 | - [9,4] 44 | - [3,7] 45 | - [9,2] 46 | - [9,7] 47 | - [5,6] 48 | - [6,3] 49 | - [4,6] 50 | - [0,8] 51 | - [3,3] 52 | - [4,5] 53 | - [1,9] 54 | - [1,4] 55 | - [9,3] 56 | - [7,3] 57 | - [3,9] 58 | - [2,4] 59 | - [0,6] 60 | - [6,2] 61 | - [2,3] 62 | - [1,8] 63 | - [4,2] 64 | - [8,0] 65 | - [8,6] 66 | - [3,1] 67 | - [6,7] 68 | - [2,7] 69 | - [1,0] 70 | - [4,0] 71 | - [7,0] 72 | - [9,6] 73 | - [8,8] 74 | - [2,6] 75 | - [5,4] 76 | - [7,1] 77 | - [2,9] 78 | - [2,5] 79 | - [4,3] 80 | - [4,1] 81 | - [7,2] 82 | - [0,3] 83 | - [8,1] 84 | - [5,3] 85 | - [7,7] 86 | - [6,1] 87 | - [0,7] 88 | - [2,8] 89 | - [9,9] 90 | - [5,2] 91 | - [4,8] 92 | - [8,2] 93 | - [6,0] 94 | - [7,4] 95 | - [3,2] 96 | - [4,7] 97 | - [7,6] 98 | - [3,6] 99 | - [0,4] 100 | train_key_pos: 101 | - [8,3] 102 | - [5,3] 103 | - [7,0] 104 | - [4,5] 105 | - [4,4] 106 | - [3,9] 107 | - [2,2] 108 | - [8,0] 109 | - [1,0] 110 | - [0,0] 111 | - [1,8] 112 | - [3,0] 113 | - [7,3] 114 | - [3,3] 115 | - [9,0] 116 | - [0,4] 117 | - [7,6] 118 | - [7,7] 119 | - [1,2] 120 | - [3,1] 121 | - [5,5] 122 | - [8,8] 123 | - [2,6] 124 | - [4,2] 125 | - [6,9] 126 | - [1,5] 127 | - [4,0] 128 | - [9,6] 129 | - [0,9] 130 | - [7,2] 131 | - [1,1] 132 | - [4,7] 133 | - [8,5] 134 | - [2,8] 135 | - [9,3] 136 | - [0,5] 137 | - [6,6] 138 | - [6,5] 139 | - [3,5] 140 | - [1,6] 141 | - [4,9] 142 | - [3,4] 143 | - [0,7] 144 | - [9,5] 145 | - [2,7] 146 | - [1,9] 147 | - [8,1] 148 | - [2,5] 149 | - [6,2] 150 | - [1,3] 151 | - [2,4] 152 | - [0,3] 153 | - [1,7] 154 | - [3,8] 155 | - [0,8] 156 | - [7,8] 157 | - [0,6] 158 | - [6,4] 159 | - [3,6] 160 | - [8,9] 161 | - [5,6] 162 | - [9,9] 163 | - [5,4] 164 | - [4,3] 165 | - [5,0] 166 | - [6,7] 167 | - [4,6] 168 | - [6,8] 169 | - [6,1] 170 | - [9,7] 171 | - [7,9] 172 | - [4,1] 173 | - [5,8] 174 | - [4,8] 175 | - [9,8] 176 | - [5,7] 177 | - [7,5] 178 | - [3,2] 179 | - [9,4] 180 | - [5,9] 181 | 182 | # train tasks for sfixed, grand 183 | 
eval_start_pos: [0, 0] 184 | eval_goal_pos: 185 | - [9,0] 186 | - [8,4] 187 | - [3,4] 188 | - [6,5] 189 | - [7,5] 190 | - [5,0] 191 | - [3,5] 192 | - [9,8] 193 | - [8,7] 194 | - [3,0] 195 | - [6,6] 196 | - [5,9] 197 | - [1,7] 198 | - [5,1] 199 | - [1,6] 200 | - [5,8] 201 | - [2,1] 202 | - [4,9] 203 | - [6,4] 204 | - [6,8] 205 | eval_key_pos: 206 | - [6,3] 207 | - [3,1] 208 | - [3,7] 209 | - [2,9] 210 | - [0,1] 211 | - [5,2] 212 | - [2,1] 213 | - [0,2] 214 | - [2,3] 215 | - [8,7] 216 | - [9,1] 217 | - [7,4] 218 | - [8,6] 219 | - [8,2] 220 | - [2,0] 221 | - [6,0] 222 | - [7,1] 223 | - [1,4] 224 | - [9,2] 225 | - [5,1] 226 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | import hydra 4 | import wandb 5 | import omegaconf 6 | from pathlib import Path 7 | from torch.distributed import init_process_group, destroy_process_group 8 | from src.utils import maybe_split, safe_mean 9 | 10 | 11 | def setup_logging(config): 12 | config_dict = omegaconf.OmegaConf.to_container(config, resolve=True, throw_on_missing=True) 13 | config_dict["PID"] = os.getpid() 14 | print(f"PID: {os.getpid()}") 15 | # hydra changes working directories automatically 16 | logdir = str(Path.joinpath(Path(os.getcwd()), config.logdir)) 17 | Path(logdir).mkdir(exist_ok=True, parents=True) 18 | print(f"Logdir: {logdir}") 19 | 20 | run = None 21 | if config.use_wandb: 22 | print("Setting up logging to Weights & Biases.") 23 | # make "wandb" path, otherwise WSL might block writing to dir 24 | wandb_path = Path.joinpath(Path(logdir), "wandb") 25 | wandb_path.mkdir(exist_ok=True, parents=True) 26 | wandb_params = omegaconf.OmegaConf.to_container(config.wandb_params, resolve=True, throw_on_missing=True) 27 | key, host = wandb_params.pop("key", None), wandb_params.pop("host", None) 28 | if key is not None and host is not None: 29 | wandb.login(key=key, host=host) 30 | config.wandb_params.update({"key": None, "host": None}) 31 | run = wandb.init(tags=[config.experiment_name], 32 | config=config_dict, **wandb_params) 33 | print(f"Writing Weights & Biases logs to: {str(wandb_path)}") 34 | run.log_code(hydra.utils.get_original_cwd()) 35 | return run, logdir 36 | 37 | 38 | def setup_ddp(): 39 | init_process_group(backend="nccl") 40 | 41 | 42 | @hydra.main(config_path="configs", config_name="config") 43 | def main(config): 44 | print("Config: \n", omegaconf.OmegaConf.to_yaml(config, resolve=True, sort_keys=True)) 45 | ddp = config.get("ddp", False) 46 | if ddp: 47 | setup_ddp() 48 | # make sure only global rank0 writes to wandb 49 | logdir = None 50 | global_rank = int(os.environ["RANK"]) 51 | if global_rank == 0: 52 | run, logdir = setup_logging(config) 53 | else: 54 | run, logdir = setup_logging(config) 55 | 56 | # imports after initializing ddp to avoid fork/spawn issues 57 | from src.envs import make_env 58 | from src.callbacks import make_callbacks 59 | from src.algos.builder import make_agent 60 | 61 | env, eval_env, train_eval_env = make_env(config, logdir) 62 | agent = make_agent(config, env, logdir) 63 | callbacks = make_callbacks(config, env=env, eval_env=eval_env, logdir=logdir, train_eval_env=train_eval_env) 64 | res, score = None, None 65 | try: 66 | res = agent.learn( 67 | **config.run_params, 68 | eval_env=eval_env, 69 | callback=callbacks 70 | ) 71 | except Exception as e: 72 | print(f"Exception: {e}") 73 | traceback.print_exc() 74 | finally: 75 | print("Finalizing 
run...") 76 | if config.use_wandb: 77 | if config.env_params.record: 78 | env.video_recorder.close() 79 | if not ddp or (ddp and global_rank == 0): 80 | run.finish() 81 | wandb.finish 82 | # return last avg reward for hparam optimization 83 | score = None if res is None else safe_mean([ep_info["r"] for ep_info in res.ep_info_buffer]) 84 | if ddp: 85 | destroy_process_group() 86 | if hasattr(agent, "cache"): 87 | agent.cache.cleanup_cache() 88 | return score 89 | 90 | 91 | if __name__ == "__main__": 92 | omegaconf.OmegaConf.register_new_resolver("maybe_split", maybe_split) 93 | main() 94 | -------------------------------------------------------------------------------- /configs/env_params/dark_keydoor_20x20.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-KeyDoor-Dark-Dense-20x20-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training, 20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | train_goal_pos: 20 | - [2,15] 21 | - [13,11] 22 | - [19,10] 23 | - [5,5] 24 | - [11,3] 25 | - [17,18] 26 | - [1,0] 27 | - [15,9] 28 | - [19,15] 29 | - [6,19] 30 | - [6,18] 31 | - [9,5] 32 | - [4,14] 33 | - [4,2] 34 | - [18,1] 35 | - [19,2] 36 | - [7,3] 37 | - [17,15] 38 | - [16,8] 39 | - [17,9] 40 | - [15,1] 41 | - [3,4] 42 | - [1,7] 43 | - [11,8] 44 | - [19,16] 45 | - [6,16] 46 | - [3,7] 47 | - [12,3] 48 | - [8,19] 49 | - [11,14] 50 | - [16,18] 51 | - [14,4] 52 | - [0,12] 53 | - [11,13] 54 | - [17,1] 55 | - [3,2] 56 | - [11,10] 57 | - [2,19] 58 | - [11,9] 59 | - [5,3] 60 | - [16,2] 61 | - [19,9] 62 | - [0,15] 63 | - [10,7] 64 | - [2,4] 65 | - [12,17] 66 | - [17,5] 67 | - [19,1] 68 | - [8,16] 69 | - [13,3] 70 | - [0,3] 71 | - [10,5] 72 | - [11,1] 73 | - [16,1] 74 | - [0,17] 75 | - [12,7] 76 | - [6,1] 77 | - [8,5] 78 | - [13,1] 79 | - [18,18] 80 | - [16,5] 81 | - [1,9] 82 | - [19,8] 83 | - [1,10] 84 | - [0,0] 85 | - [16,7] 86 | - [9,10] 87 | - [13,6] 88 | - [15,12] 89 | - [1,18] 90 | - [18,4] 91 | - [3,9] 92 | - [15,6] 93 | - [1,16] 94 | - [18,8] 95 | - [6,6] 96 | - [7,0] 97 | - [18,0] 98 | - [14,9] 99 | - [9,3] 100 | train_key_pos: 101 | - [10,9] 102 | - [14,0] 103 | - [1,13] 104 | - [10,10] 105 | - [4,13] 106 | - [4,4] 107 | - [16,9] 108 | - [4,14] 109 | - [13,6] 110 | - [6,6] 111 | - [0,9] 112 | - [18,1] 113 | - [2,16] 114 | - [3,12] 115 | - [6,12] 116 | - [2,2] 117 | - [13,18] 118 | - [18,16] 119 | - [11,11] 120 | - [19,5] 121 | - [3,17] 122 | - [0,15] 123 | - [19,11] 124 | - [13,11] 125 | - [0,0] 126 | - [19,16] 127 | - [5,14] 128 | - [11,5] 129 | - [13,2] 130 | - [5,4] 131 | - [19,15] 132 | - [9,13] 133 | - [13,1] 134 | - [2,17] 135 | - [11,12] 136 | - [5,16] 137 | - [5,13] 138 | - [17,2] 139 | - [7,18] 140 | - [7,1] 141 | - [19,1] 142 | - [2,15] 143 | - [3,16] 144 | - [1,5] 145 | - [4,2] 146 | - [19,2] 147 | - [7,8] 148 | - [9,1] 149 | - [1,2] 150 | - [8,13] 151 | - [2,6] 152 | - [16,1] 153 | - [16,18] 154 | - [3,10] 155 | - [18,14] 156 | - [1,19] 157 | - [11,3] 158 | - [8,12] 159 | - [1,10] 160 | - [7,12] 161 | - [6,4] 162 | - [14,14] 163 | - [12,15] 164 | - [3,18] 165 | - [5,1] 166 | - [1,11] 167 | - [17,12] 168 | - [13,8] 169 | - [19,14] 170 | - [3,13] 171 | - [16,0] 172 | - [7,0] 173 | - [0,5] 174 | - [2,5] 175 | - [19,8] 176 | - [12,6] 177 | - [11,7] 178 | - [18,9] 179 | - [8,16] 180 | - [14,9] 181 | 182 | 
# train tasks for sfixed,grand 183 | eval_start_pos: [0,0] 184 | eval_goal_pos: 185 | - [16,12] 186 | - [9,15] 187 | - [1,6] 188 | - [5,17] 189 | - [7,17] 190 | - [8,17] 191 | - [10,14] 192 | - [9,19] 193 | - [10,10] 194 | - [16,17] 195 | - [8,6] 196 | - [9,13] 197 | - [15,8] 198 | - [15,18] 199 | - [15,15] 200 | - [3,18] 201 | - [17,17] 202 | - [10,8] 203 | - [5,2] 204 | - [0,18] 205 | eval_key_pos: 206 | - [0,3] 207 | - [0,18] 208 | - [10,2] 209 | - [12,10] 210 | - [13,14] 211 | - [3,3] 212 | - [12,8] 213 | - [15,1] 214 | - [5,8] 215 | - [4,10] 216 | - [11,13] 217 | - [16,15] 218 | - [5,18] 219 | - [11,0] 220 | - [9,0] 221 | - [15,14] 222 | - [18,13] 223 | - [19,0] 224 | - [11,19] 225 | - [3,15] -------------------------------------------------------------------------------- /configs/env_params/dark_keydoor_40x20.yaml: -------------------------------------------------------------------------------- 1 | envid: MiniHack-KeyDoor-Dark-Dense-40x20-v0 2 | target_return: 1 3 | num_envs: 1 4 | norm_obs: False 5 | record: False 6 | record_freq: 1000000 7 | record_length: 2000 8 | reward_scale: 1 9 | 10 | # randomly selected positions 11 | # 80 for training,20 for testing 12 | env_kwargs: 13 | random: False 14 | observation_keys: 15 | - tty_cursor 16 | 17 | train_start_pos: 18 | - [0,0] 19 | train_goal_pos: 20 | - [29,6] 21 | - [38,6] 22 | - [36,2] 23 | - [29,0] 24 | - [27,2] 25 | - [32,7] 26 | - [10,8] 27 | - [36,11] 28 | - [37,18] 29 | - [9,8] 30 | - [8,16] 31 | - [10,10] 32 | - [13,8] 33 | - [6,2] 34 | - [4,11] 35 | - [17,17] 36 | - [12,16] 37 | - [34,1] 38 | - [20,16] 39 | - [32,9] 40 | - [19,4] 41 | - [11,10] 42 | - [16,14] 43 | - [19,6] 44 | - [27,15] 45 | - [17,2] 46 | - [28,4] 47 | - [18,14] 48 | - [26,7] 49 | - [30,15] 50 | - [17,8] 51 | - [35,10] 52 | - [15,14] 53 | - [24,0] 54 | - [6,8] 55 | - [19,12] 56 | - [11,17] 57 | - [19,18] 58 | - [26,16] 59 | - [22,13] 60 | - [0,11] 61 | - [0,17] 62 | - [39,2] 63 | - [39,8] 64 | - [19,14] 65 | - [21,8] 66 | - [25,15] 67 | - [4,2] 68 | - [4,14] 69 | - [1,4] 70 | - [34,12] 71 | - [11,8] 72 | - [37,5] 73 | - [18,9] 74 | - [25,4] 75 | - [22,16] 76 | - [20,6] 77 | - [18,0] 78 | - [35,8] 79 | - [9,10] 80 | - [12,1] 81 | - [34,16] 82 | - [38,17] 83 | - [20,18] 84 | - [39,3] 85 | - [35,5] 86 | - [3,7] 87 | - [16,11] 88 | - [31,0] 89 | - [14,17] 90 | - [23,4] 91 | - [12,10] 92 | - [2,7] 93 | - [29,13] 94 | - [25,12] 95 | - [9,16] 96 | - [13,9] 97 | - [20,4] 98 | - [2,17] 99 | - [26,10] 100 | 101 | train_key_pos: 102 | - [34,16] 103 | - [33,7] 104 | - [3,3] 105 | - [26,13] 106 | - [3,6] 107 | - [31,1] 108 | - [17,6] 109 | - [24,10] 110 | - [38,0] 111 | - [22,16] 112 | - [3,5] 113 | - [14,6] 114 | - [31,15] 115 | - [3,7] 116 | - [16,7] 117 | - [19,15] 118 | - [12,4] 119 | - [18,17] 120 | - [30,10] 121 | - [26,6] 122 | - [32,18] 123 | - [26,9] 124 | - [31,2] 125 | - [36,1] 126 | - [18,0] 127 | - [1,10] 128 | - [13,0] 129 | - [31,17] 130 | - [37,6] 131 | - [28,10] 132 | - [10,15] 133 | - [3,18] 134 | - [28,18] 135 | - [19,3] 136 | - [1,19] 137 | - [1,3] 138 | - [34,12] 139 | - [19,18] 140 | - [39,16] 141 | - [37,1] 142 | - [6,19] 143 | - [12,10] 144 | - [8,14] 145 | - [16,3] 146 | - [29,15] 147 | - [26,5] 148 | - [21,3] 149 | - [29,16] 150 | - [30,4] 151 | - [26,14] 152 | - [13,5] 153 | - [5,9] 154 | - [33,6] 155 | - [14,14] 156 | - [16,16] 157 | - [18,8] 158 | - [22,6] 159 | - [16,13] 160 | - [9,18] 161 | - [8,8] 162 | - [24,1] 163 | - [36,3] 164 | - [32,15] 165 | - [15,6] 166 | - [11,11] 167 | - [21,2] 168 | - [2,9] 169 | - [39,7] 170 | - [17,17] 171 | - 
[4,6] 172 | - [27,5] 173 | - [9,12] 174 | - [39,15] 175 | - [1,13] 176 | - [1,11] 177 | - [9,19] 178 | - [18,5] 179 | - [24,6] 180 | - [29,14] 181 | - [28,8] 182 | 183 | # train tasks for sfixed,grand 184 | eval_start_pos: [0,0] 185 | eval_goal_pos: 186 | - [14,13] 187 | - [31,2] 188 | - [5,19] 189 | - [18,5] 190 | - [29,9] 191 | - [5,11] 192 | - [28,8] 193 | - [22,10] 194 | - [1,15] 195 | - [27,11] 196 | - [15,17] 197 | - [5,5] 198 | - [4,8] 199 | - [38,5] 200 | - [21,6] 201 | - [24,15] 202 | - [7,19] 203 | - [13,16] 204 | - [13,1] 205 | - [12,9] 206 | eval_key_pos: 207 | - [21,8] 208 | - [6,17] 209 | - [3,12] 210 | - [3,17] 211 | - [25,12] 212 | - [39,3] 213 | - [39,6] 214 | - [13,15] 215 | - [18,1] 216 | - [4,16] 217 | - [32,1] 218 | - [35,15] 219 | - [14,1] 220 | - [19,13] 221 | - [31,8] 222 | - [30,8] 223 | - [35,5] 224 | - [37,14] 225 | - [2,14] 226 | - [37,10] 227 | -------------------------------------------------------------------------------- /src/algos/discrete_decision_transformer_sb3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .universal_decision_transformer_sb3 import UDT 3 | from .models.model_utils import sample_from_logits 4 | 5 | 6 | class DiscreteDecisionTransformerSb3(UDT): 7 | 8 | def __init__(self, policy, env, loss_fn="ce", rtg_sample_kwargs=None, a_sample_kwargs=None, **kwargs): 9 | super().__init__(policy, env, loss_fn=loss_fn, **kwargs) 10 | self.rtg_sample_kwargs = {} if rtg_sample_kwargs is None else rtg_sample_kwargs 11 | self.a_sample_kwargs = a_sample_kwargs 12 | 13 | def get_action_pred(self, policy, states, actions, rewards, returns_to_go, timesteps, attention_mask, 14 | deterministic, prompt, is_eval=False, task_id=None, env_act_dim=None): 15 | inputs = { 16 | "states": states, 17 | "actions": actions, 18 | "rewards": rewards, 19 | "returns_to_go": returns_to_go, 20 | "timesteps": timesteps, 21 | "attention_mask": attention_mask, 22 | "return_dict": True, 23 | "deterministic": deterministic, 24 | "prompt": prompt, 25 | "task_id": task_id, 26 | "ddp_kwargs": self.ddp_kwargs, 27 | "use_inference_cache": self.use_inference_cache, 28 | "past_key_values": self.past_key_values # None by default 29 | } 30 | 31 | # exper-action inference mechanism 32 | if self.target_return_type == "infer": 33 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 34 | policy_output = policy(**inputs) 35 | return_logits = policy_output.return_preds[:, -1] 36 | return_sample = policy.sample_from_rtg_logits(return_logits, **self.rtg_sample_kwargs) 37 | inputs["returns_to_go"][0, -1] = return_sample 38 | 39 | if not self.policy.tok_a_target_only: 40 | # autoregressive action prediction 41 | # e.g., for discretizes continuous action space need to predict each action dim after another 42 | act_dim = actions.shape[-1] if env_act_dim is None else env_act_dim 43 | for i in range(act_dim): 44 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 45 | policy_output = policy(**inputs) 46 | if not is_eval and self.num_timesteps % 10000 == 0 and self.log_attn_maps: 47 | self._record_attention_maps(policy_output.attentions, step=self.num_timesteps, prefix="rollout") 48 | if policy_output.cross_attentions is not None: 49 | self._record_attention_maps(policy_output.cross_attentions, step=self.num_timesteps + i, 50 | prefix="rollout_cross", lower_triu=False) 51 | if self.a_sample_kwargs is not None: 52 | action_logits = policy_output.action_logits[0, -1, i] 53 | 
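# sample the i-th action token from its logits (how exactly depends on a_sample_kwargs, e.g. temperature or top-k sampling) and write it back into the action buffer, so the next dimension is predicted conditioned on the already sampled ones: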
inputs["actions"][0, -1, i] = sample_from_logits(action_logits, **self.a_sample_kwargs) 54 | else: 55 | inputs["actions"][0, -1, i] = policy_output.action_preds[0, -1, i] 56 | if self.use_inference_cache: 57 | self.past_key_values = policy_output.past_key_values 58 | inputs["past_key_values"] = self.past_key_values 59 | action = inputs["actions"][0, -1] 60 | else: 61 | with torch.autocast(device_type='cuda', dtype=self.amp_dtype, enabled=self.use_amp): 62 | policy_output = policy(**inputs) 63 | action = policy_output.action_preds[0, -1] 64 | if self.use_inference_cache: 65 | self.past_key_values = policy_output.past_key_values 66 | 67 | if env_act_dim is not None: 68 | action = action[:env_act_dim] 69 | return action, inputs["returns_to_go"][0, -1] if self.target_return_type == "infer" else action 70 | -------------------------------------------------------------------------------- /src/callbacks/builder.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import wandb 3 | from wandb.integration.sb3 import WandbCallback 4 | from stable_baselines3.common.callbacks import CallbackList 5 | from ..algos import AGENT_CLASSES 6 | from ..envs.env_names import ID_TO_NAMES, MT50_ENVS, MT50_ENVS_v2, ATARI_ENVS, DM_CONTROL_ENVS, \ 7 | MINIHACK_ENVS, GYM_ENVS, PROCGEN_ENVS, MAZERUNNER_ENVS 8 | 9 | 10 | class CustomWandbCallback(WandbCallback): 11 | 12 | def __init__(self, model_sync_wandb=False, **kwargs): 13 | super().__init__(**kwargs) 14 | self.model_sync_wandb = model_sync_wandb 15 | 16 | def save_model(self) -> None: 17 | print(f"Saving model checkpoint to {self.path}") 18 | self.model.save(self.path) 19 | if self.model_sync_wandb: 20 | wandb.save(self.path, base_path=self.model_save_path) 21 | 22 | 23 | def make_callbacks(config, env=None, eval_env=None, logdir=None, train_eval_env=None): 24 | callbacks = [] 25 | if config.use_wandb and logdir is not None: 26 | model_save_path = None 27 | if config.wandb_callback_params.model_save_path is not None: 28 | model_save_path = f"{logdir}/{config.wandb_callback_params.model_save_path}" 29 | callbacks.append( 30 | CustomWandbCallback( 31 | gradient_save_freq=config.wandb_callback_params.gradient_save_freq, 32 | verbose=config.wandb_callback_params.verbose, model_save_path=model_save_path, 33 | model_sync_wandb=config.wandb_callback_params.get("model_sync_wandb", False), 34 | model_save_freq=config.wandb_callback_params.get("model_save_freq", 0) 35 | ) 36 | ) 37 | if config.eval_params.use_eval_callback: 38 | if config.agent_params.kind in AGENT_CLASSES.keys(): 39 | from .custom_eval_callback import CustomEvalCallback, MultiEnvEvalCallback 40 | if config.env_params.envid not in [*list(ID_TO_NAMES.keys()), *ATARI_ENVS, *MT50_ENVS, *MT50_ENVS_v2, 41 | *DM_CONTROL_ENVS, *MINIHACK_ENVS, *GYM_ENVS, 42 | *PROCGEN_ENVS, *MAZERUNNER_ENVS]: 43 | eval_callback_class = functools.partial(CustomEvalCallback, use_wandb=config.use_wandb) 44 | else: 45 | eval_callback_class = functools.partial(MultiEnvEvalCallback, use_wandb=config.use_wandb) 46 | else: 47 | from stable_baselines3.common.callbacks import EvalCallback 48 | eval_callback_class = EvalCallback 49 | if config.eval_params.max_no_improvement_evals > 0: 50 | from stable_baselines3.common.callbacks import StopTrainingOnNoModelImprovement 51 | stop_training_callback = StopTrainingOnNoModelImprovement( 52 | max_no_improvement_evals=config.eval_params.max_no_improvement_evals, verbose=1) 53 | else: 54 | stop_training_callback = None 55 | eval_callback_kwargs 
= { 56 | "n_eval_episodes": config.eval_params.n_eval_episodes, "eval_freq": config.eval_params.eval_freq, 57 | "callback_after_eval": stop_training_callback, "deterministic": config.eval_params.deterministic, 58 | "first_step": config.eval_params.get("first_step", True), 59 | "log_eval_trj": config.eval_params.get("log_eval_trj", False) 60 | } 61 | if config.eval_params.get("eval_on_train", False): 62 | train_eval_callback = eval_callback_class(eval_env=env, prefix="train_eval", **eval_callback_kwargs) 63 | callbacks.append(train_eval_callback) 64 | if train_eval_env is not None: 65 | train_eval_seeds_callback = eval_callback_class(eval_env=train_eval_env, prefix="train_eval_seeds", 66 | **eval_callback_kwargs) 67 | callbacks.append(train_eval_seeds_callback) 68 | 69 | eval_callback = eval_callback_class(eval_env=eval_env, **eval_callback_kwargs) 70 | callbacks.append(eval_callback) 71 | 72 | return CallbackList(callbacks) 73 | -------------------------------------------------------------------------------- /src/schedulers/visualize_schedulers.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import matplotlib.pyplot as plt 4 | import seaborn as sns 5 | from pathlib import Path 6 | from torch.optim.lr_scheduler import SequentialLR 7 | from lr_schedulers import make_lr_scheduler 8 | 9 | 10 | def visualize_lr_scheduler(lr_scheduler, total_steps=1e6, title="", save_dir=None): 11 | total_steps = int(total_steps) 12 | lrs = [] 13 | for _ in range(total_steps): 14 | lrs.append(lr_scheduler.get_last_lr()) 15 | lr_scheduler.step() 16 | plt.plot(lrs) 17 | plt.title(title) 18 | plt.xlabel("Steps") 19 | plt.ylabel("Learning rate") 20 | if save_dir is not None: 21 | save_dir = Path(save_dir) 22 | save_dir.mkdir(parents=True, exist_ok=True) 23 | plt.savefig(save_dir / f"{title}.png", bbox_inches='tight') 24 | plt.show() 25 | 26 | 27 | if __name__ == "__main__": 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--save_dir", type=str, default="./figures") 30 | args = parser.parse_args() 31 | 32 | sns.set_style("whitegrid") 33 | net = torch.nn.Linear(10, 10) 34 | total_steps = int(1e6) 35 | lr = 0.0001 36 | eta_min = 0.000001 37 | warmup_steps = 50000 38 | 39 | # cosine 40 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 41 | sched_kwargs = {"eta_min": eta_min, "T_max": total_steps} 42 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=sched_kwargs) 43 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 44 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 45 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine", save_dir=args.save_dir) 46 | 47 | # cosine 48 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 49 | sched_kwargs = {"eta_min": eta_min, "T_max": total_steps / 5} 50 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine", sched_kwargs=sched_kwargs) 51 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 52 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 53 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_2", save_dir=args.save_dir) 54 | 55 | # cosine_restart 56 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 57 | sched_kwargs = {"eta_min": eta_min, "T_0": int(total_steps / 5)} 58 | lr_scheduler = 
make_lr_scheduler(optimizer, kind="cosine_restart", sched_kwargs=sched_kwargs) 59 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 60 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 61 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_restart", save_dir=args.save_dir) 62 | 63 | # cosine_restart 64 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 65 | sched_kwargs = {"eta_min": eta_min, "T_0": int(total_steps / 5), "T_mult": 2} 66 | lr_scheduler = make_lr_scheduler(optimizer, kind="cosine_restart", sched_kwargs=sched_kwargs) 67 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 68 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 69 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cosine_restart_2", save_dir=args.save_dir) 70 | 71 | # cyclic 72 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 73 | sched_kwargs = {"base_lr": lr, "max_lr": 0.001, "step_size_up": 1e5} 74 | lr_scheduler = make_lr_scheduler(optimizer, kind="cyclic", sched_kwargs=sched_kwargs) 75 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 76 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 77 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="cyclic", save_dir=args.save_dir) 78 | 79 | # step 80 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 81 | sched_kwargs = {"gamma": 0.1, "step_size": 2e5} 82 | lr_scheduler = make_lr_scheduler(optimizer, kind="step", sched_kwargs=sched_kwargs) 83 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 84 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 85 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="step", save_dir=args.save_dir) 86 | 87 | # exponential 88 | optimizer = torch.optim.AdamW(net.parameters(), lr=lr) 89 | sched_kwargs = {"gamma": eta_min} 90 | lr_scheduler = make_lr_scheduler(optimizer, kind="exp", sched_kwargs=sched_kwargs) 91 | warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)) 92 | lr_scheduler = SequentialLR(optimizer, [warmup, lr_scheduler], milestones=[warmup_steps]) 93 | visualize_lr_scheduler(lr_scheduler, total_steps=total_steps, title="exponential", save_dir=args.save_dir) 94 | -------------------------------------------------------------------------------- /src/envs/hn_scores.py: -------------------------------------------------------------------------------- 1 | # Adjusted from: 2 | # - https://github.com/deepmind/dqn_zoo/blob/807379c19f8819e407329ac1b95dcaccb9d536c3/dqn_zoo/atari_data.py 3 | # - https://github.com/etaoxing/multigame-dt/blob/master/atari_data.py 4 | 5 | # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 
9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # ============================================================================== 19 | """Utilities to compute human-normalized Atari scores. 20 | The data used in this module is human and random performance data on Atari-57. 21 | It comprises of evaluation scores (undiscounted returns), each averaged 22 | over at least 3 episode runs, on each of the 57 Atari games. Each episode begins 23 | with the environment already stepped with a uniform random number (between 1 and 24 | 30 inclusive) of noop actions. 25 | The two agents are: 26 | * 'random' (agent choosing its actions uniformly randomly on each step) 27 | * 'human' (professional human game tester) 28 | Scores are obtained by averaging returns over the episodes played by each agent, 29 | with episode length capped to 108,000 frames (i.e. timeout after 30 minutes). 30 | The term 'human-normalized' here means a linear per-game transformation of 31 | a game score in such a way that 0 corresponds to random performance and 1 32 | corresponds to human performance. 33 | """ 34 | 35 | import math 36 | from .env_names import ATARI_NAME_TO_ENVID 37 | 38 | 39 | # Game: score-tuple dictionary. Each score tuple contains 40 | # 0: score random (float) and 1: score human (float). 41 | ENVID_TO_HNS = { 42 | "alien": (227.8, 7127.7), 43 | "amidar": (5.8, 1719.5), 44 | "assault": (222.4, 742.0), 45 | "asterix": (210.0, 8503.3), 46 | "asteroids": (719.1, 47388.7), 47 | "atlantis": (12850.0, 29028.1), 48 | "bank-heist": (14.2, 753.1), 49 | "battle-zone": (2360.0, 37187.5), 50 | "beam-rider": (363.9, 16926.5), 51 | "berzerk": (123.7, 2630.4), 52 | "bowling": (23.1, 160.7), 53 | "boxing": (0.1, 12.1), 54 | "breakout": (1.7, 30.5), 55 | "centipede": (2090.9, 12017.0), 56 | "chopper-command": (811.0, 7387.8), 57 | "crazy-climber": (10780.5, 35829.4), 58 | "defender": (2874.5, 18688.9), 59 | "demon-attack": (152.1, 1971.0), 60 | "double-dunk": (-18.6, -16.4), 61 | "enduro": (0.0, 860.5), 62 | "fishing-derby": (-91.7, -38.7), 63 | "freeway": (0.0, 29.6), 64 | "frostbite": (65.2, 4334.7), 65 | "gopher": (257.6, 2412.5), 66 | "gravitar": (173.0, 3351.4), 67 | "hero": (1027.0, 30826.4), 68 | "ice-hockey": (-11.2, 0.9), 69 | "jamesbond": (29.0, 302.8), 70 | "kangaroo": (52.0, 3035.0), 71 | "krull": (1598.0, 2665.5), 72 | "kung-fu-master": (258.5, 22736.3), 73 | "montezuma-revenge": (0.0, 4753.3), 74 | "ms-pacman": (307.3, 6951.6), 75 | "name-this-game": (2292.3, 8049.0), 76 | "phoenix": (761.4, 7242.6), 77 | "pitfall": (-229.4, 6463.7), 78 | "pong": (-20.7, 14.6), 79 | "private-eye": (24.9, 69571.3), 80 | "qbert": (163.9, 13455.0), 81 | "riverraid": (1338.5, 17118.0), 82 | "road-runner": (11.5, 7845.0), 83 | "robotank": (2.2, 11.9), 84 | "seaquest": (68.4, 42054.7), 85 | "skiing": (-17098.1, -4336.9), 86 | "solaris": (1236.3, 12326.7), 87 | "space-invaders": (148.0, 1668.7), 88 | "star-gunner": (664.0, 10250.0), 89 | "surround": (-10.0, 6.5), 90 | "tennis": (-23.8, -8.3), 91 | "time-pilot": (3568.0, 5229.2), 92 | "tutankham": (11.4, 167.6), 93 | "up-n-down": (533.4, 11693.2), 94 | "venture": (0.0, 1187.5), 95 | # Note 
the random agent score on Video Pinball is sometimes greater than the 96 | # human score under other evaluation methods. 97 | "video-pinball": (16256.9, 17667.9), 98 | "wizard-of-wor": (563.5, 4756.5), 99 | "yars-revenge": (3092.9, 54576.9), 100 | "zaxxon": (32.5, 9173.3), 101 | # extracted from procgen paper 102 | "bigfish": (1,40), 103 | "bossfight": (0.5,13), 104 | "caveflyer": (3.5,12), 105 | "chaser": (0.5,13), 106 | "climber": (2,12.6), 107 | "coinrun": (5,10), 108 | "dodgeball": (1.5,19), 109 | "fruitbot": (-1.5,32.4), 110 | "heist": (3.5,10), 111 | "jumper": (3,10), 112 | "leaper": (3,10), 113 | "maze": (5,10), 114 | "miner": (1.5,13), 115 | "ninja": (3.5,10), 116 | "plunder": (4.5,30), 117 | "starpilot": (2.5,64), 118 | } 119 | 120 | # add scores for actual env ids 121 | keys = list(ENVID_TO_HNS.keys()) 122 | for k in keys: 123 | if k not in ATARI_NAME_TO_ENVID: 124 | continue 125 | envid = ATARI_NAME_TO_ENVID[k] 126 | ENVID_TO_HNS[envid] = ENVID_TO_HNS[k] 127 | 128 | 129 | def get_human_normalized_score(game: str, raw_score: float, random_col=0, human_col=1) -> float: 130 | """Converts game score to human-normalized score.""" 131 | game_scores = ENVID_TO_HNS.get(game, (math.nan, math.nan)) 132 | random, human = game_scores[random_col], game_scores[human_col] 133 | return (raw_score - random) / (human - random) 134 | -------------------------------------------------------------------------------- /src/envs/procgen_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import omegaconf 4 | from copy import deepcopy 5 | from procgen import ProcgenEnv, ProcgenGym3Env 6 | from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor, VecTransposeImage, VecNormalize 7 | from stable_baselines3.common.env_util import DummyVecEnv 8 | from procgen.env import ToBaselinesVecEnv 9 | 10 | 11 | class StickyBaselinesEnv(ToBaselinesVecEnv): 12 | """ 13 | Sticky action wrapper for Procgen. 
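With probability p_sticky the previously executed action is repeated instead of the newly issued one (see step_async below); reset() clears the sticky actions. A minimal construction sketch (hypothetical usage, assuming a local Procgen installation): env = make_sticky_procgen_env("coinrun", num_envs=1, p_sticky=0.1)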
14 | """ 15 | def __init__(self, env, p_sticky: float = 0.1, **kwargs): 16 | super().__init__(env, **kwargs) 17 | self.p_sticky = p_sticky 18 | self.sticky_actions = np.zeros(self.num_envs, dtype=int) 19 | 20 | def reset(self): 21 | self.sticky_actions.fill(0) 22 | return super().reset() 23 | 24 | def step_async(self, ac): 25 | repeat_action_mask = np.random.rand(self.num_envs) < self.p_sticky 26 | self.sticky_actions[~repeat_action_mask] = ac[~repeat_action_mask] 27 | return super().step_async(self.sticky_actions) 28 | 29 | 30 | def make_sticky_procgen_env(env_name, num_envs, p_sticky=0, **kwargs): 31 | return StickyBaselinesEnv(ProcgenGym3Env(num=num_envs, env_name=env_name, **kwargs), p_sticky=p_sticky) 32 | 33 | 34 | def get_procgen_constructor(envid, distribution_mode="easy", time_limit=None, env_kwargs=None): 35 | env_kwargs = env_kwargs if env_kwargs is not None else {} 36 | if isinstance(env_kwargs, omegaconf.DictConfig): 37 | env_kwargs = omegaconf.OmegaConf.to_container(env_kwargs, resolve=True, throw_on_missing=True) 38 | p_sticky = env_kwargs.pop("p_sticky", 0) 39 | num_envs = env_kwargs.pop("num_envs", 1) 40 | norm_reward = env_kwargs.pop("norm_reward", False) 41 | def make(): 42 | if p_sticky > 0: 43 | env = make_sticky_procgen_env(env_name=envid, num_envs=num_envs, distribution_mode=distribution_mode, 44 | p_sticky=p_sticky, **env_kwargs) 45 | else: 46 | env = ProcgenEnv(env_name=envid, num_envs=num_envs, 47 | distribution_mode=distribution_mode, **env_kwargs) 48 | # monitor to obtain ep_rew_mean, ep_rew_len + extract rgb images from dict states 49 | env = CustomVecMonitor(VecExtractDictObs(env, 'rgb'), time_limit=time_limit) 50 | env = VecTransposeImage(env) 51 | if norm_reward: 52 | env = VecNormalize(env, norm_obs=False, norm_reward=True) 53 | env.name = envid 54 | return env 55 | return make 56 | 57 | 58 | class CustomVecMonitor(VecMonitor): 59 | """ 60 | Custom version of VecMonitor that allows for a timelimit. 61 | Once, timelimit is hit, we also need to reset the environment. 62 | We can however, not save the reset state there. 
63 | """ 64 | def __init__( 65 | self, 66 | venv, 67 | filename=None, 68 | info_keywords=(), 69 | time_limit=None 70 | ): 71 | super().__init__(venv, filename, info_keywords) 72 | self.time_limit = time_limit 73 | 74 | def step_wait(self): 75 | obs, rewards, dones, infos = self.venv.step_wait() 76 | self.episode_returns += rewards 77 | self.episode_lengths += 1 78 | new_infos = list(infos[:]) 79 | for i in range(len(dones)): 80 | if self.time_limit is not None and self.episode_lengths[i] >= self.time_limit: 81 | dones[i] = True 82 | # send action -1 to reset ProcgenEnv: https://github.com/openai/procgen/issues/40#issuecomment-633720234 83 | if self.num_envs > 1: 84 | raise NotImplementedError("Resetting ProcgenEnv with multiple environments is not supported.") 85 | self.venv.step(np.ones((1,), dtype=int) * -1) 86 | if dones[i]: 87 | info = infos[i].copy() 88 | episode_return = self.episode_returns[i] 89 | episode_length = self.episode_lengths[i] 90 | episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} 91 | for key in self.info_keywords: 92 | episode_info[key] = info[key] 93 | info["episode"] = episode_info 94 | self.episode_count += 1 95 | self.episode_returns[i] = 0 96 | self.episode_lengths[i] = 0 97 | if self.results_writer: 98 | self.results_writer.write_row(episode_info) 99 | new_infos[i] = info 100 | return obs, rewards, dones, new_infos 101 | 102 | 103 | class CustomDummyVecEnv(DummyVecEnv): 104 | """ 105 | Custom version of DummyVecEnv that allows wrapping ProcgenEnvs. 106 | By default, ProcgenEnvs are vectorized already. 107 | Therefore wrapping different tasks in a single DummyVecEnv fails, due to returning of vectorized infor buffers. 108 | """ 109 | def step_wait(self): 110 | for env_idx in range(self.num_envs): 111 | obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( 112 | self.actions[env_idx] 113 | ) 114 | if self.buf_dones[env_idx]: 115 | # save final observation where user can get it, then reset 116 | # self.buf_infos[env_idx]terminal_observation"] = obs 117 | self.buf_infos[env_idx][0]["terminal_observation"] = obs 118 | obs = self.envs[env_idx].reset() 119 | self._save_obs(env_idx, obs) 120 | return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) 121 | -------------------------------------------------------------------------------- /src/algos/models/universal_decision_transformer_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from stable_baselines3.common.distributions import SquashedDiagGaussianDistribution, DiagGaussianDistribution 3 | from stable_baselines3.common import preprocessing 4 | from .online_decision_transformer_model import OnlineDecisionTransformerModel, OnlineDecisionTransformerOutput 5 | from .image_encoders import make_image_encoder 6 | 7 | 8 | class DummyUDTModel(OnlineDecisionTransformerModel): 9 | 10 | def __init__(self, config, observation_space, action_space, **kwargs): 11 | """ 12 | Class for testing purposes. 13 | Replaces transformer policy by regular feedforward net. Keeps the function headers the same in order 14 | to avoid unnecessary overhead. 
15 | 16 | Args: 17 | config: 18 | action_space: 19 | **kwargs: 20 | """ 21 | super().__init__(config, observation_space, action_space, **kwargs) 22 | 23 | def setup_policy(self): 24 | features_dim = self.config.hidden_size 25 | act_dim = self.config.act_dim if not self.is_discrete else self.action_space.n 26 | if self.is_image_space: 27 | self.encoder = make_image_encoder(observation_space=self.observation_space, 28 | features_dim=self.config.hidden_size, encoder_kwargs=self.encoder_kwargs) 29 | else: 30 | self.encoder = nn.Sequential( 31 | nn.Linear(self.config.state_dim, features_dim), 32 | nn.LayerNorm(features_dim), 33 | nn.Tanh(), 34 | nn.Linear(features_dim, features_dim), 35 | nn.LeakyReLU(), 36 | nn.Linear(features_dim, features_dim), 37 | nn.LeakyReLU(), 38 | ) 39 | if self.stochastic_policy: 40 | self.mu = nn.Linear(features_dim, act_dim) 41 | self.log_std = nn.Linear(features_dim, act_dim) 42 | self.action_dist = SquashedDiagGaussianDistribution(act_dim) if self.config.action_tanh \ 43 | else DiagGaussianDistribution(act_dim) 44 | else: 45 | self.action_pred = nn.Sequential( 46 | *([nn.Linear(features_dim, act_dim)] + ([nn.Tanh()] if self.config.action_tanh else [])) 47 | ) 48 | del self.embed_timestep 49 | del self.embed_return 50 | del self.embed_state 51 | del self.embed_action 52 | del self.embed_ln 53 | del self.predict_state 54 | del self.predict_return 55 | 56 | def forward( 57 | self, 58 | states=None, 59 | actions=None, 60 | rewards=None, 61 | returns_to_go=None, 62 | timesteps=None, 63 | attention_mask=None, 64 | output_hidden_states=None, 65 | output_attentions=None, 66 | return_dict=None, 67 | deterministic=True, 68 | with_log_probs=False, 69 | prompt=None, 70 | task_id=None, 71 | ): 72 | """ 73 | Overwrites the original forward, as Transformer steps are not required for this model. 74 | Just takes states and predicts actions. 
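Returns-to-go, rewards, timesteps and the attention mask are accepted for interface compatibility but ignored; the encoder output is passed directly to the action (and optional reward) heads.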
75 | """ 76 | if self.is_image_space: 77 | states = preprocessing.preprocess_obs(states, observation_space=self.observation_space, normalize_images=True) 78 | 79 | if self.is_image_space and len(states.shape) > 4: 80 | batch_size, seq_len = states.shape[0], states.shape[1] 81 | state = states.reshape(-1, *self.observation_space.shape) 82 | x = self.encoder(state).reshape(batch_size, seq_len, self.config.hidden_size) 83 | else: 84 | x = self.encoder(states) 85 | state_preds, action_preds, action_log_probs, return_preds, reward_preds, action_logits, entropy = self.get_predictions( 86 | x, with_log_probs=with_log_probs, deterministic=deterministic 87 | ) 88 | return OnlineDecisionTransformerOutput( 89 | last_hidden_state=None, 90 | state_preds=state_preds, 91 | action_preds=action_preds, 92 | return_preds=return_preds, 93 | hidden_states=None, 94 | attentions=None, 95 | action_log_probs=action_log_probs, 96 | reward_preds=reward_preds, 97 | action_logits=action_logits, 98 | entropy=entropy, 99 | last_encoder_output=x 100 | ) 101 | 102 | def get_predictions(self, x, with_log_probs=False, deterministic=False, task_id=None): 103 | action_log_probs, reward_preds, action_logits, entropy = None, None, None, None 104 | if with_log_probs: 105 | action_preds, action_log_probs = self.action_log_prob(x, task_id=task_id) 106 | else: 107 | action_preds = self.predict_action(x, deterministic=deterministic, task_id=task_id) 108 | if self.reward_condition: 109 | reward_preds = self.predict_reward(x) 110 | return None, action_preds, action_log_probs, None, reward_preds, action_logits, entropy 111 | 112 | def compute_hidden_states( 113 | self, 114 | states=None, 115 | actions=None, 116 | rewards=None, 117 | returns_to_go=None, 118 | timesteps=None, 119 | attention_mask=None, 120 | output_hidden_states=None, 121 | output_attentions=None, 122 | return_dict=None, 123 | prompt=None, 124 | task_id=None, 125 | past_key_values=None, 126 | use_inference_cache=False 127 | ): 128 | return self.encoder(states), None, None 129 | 130 | def _init_weights(self, module): 131 | # use default initialization for dummy net 132 | pass 133 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import pandas as pd 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | import joblib 7 | from tqdm import tqdm 8 | 9 | 10 | def maybe_split(dir_name: str) -> str: 11 | """ 12 | Recursively splits a given dir_name at half, once it exceeds max folder size of 255. 13 | """ 14 | if len(dir_name) > 255: 15 | half = len(dir_name) // 2 16 | dir_name = maybe_split(dir_name[:half]) + "/" + maybe_split(dir_name[half:]) 17 | return dir_name 18 | 19 | def safe_mean(arr): 20 | return np.nan if len(arr) == 0 else np.mean(arr) 21 | 22 | 23 | def set_frozen_to_eval(module): 24 | requires_grad = [] 25 | for p in module.parameters(): 26 | requires_grad.append(p.requires_grad) 27 | if not any(requires_grad): 28 | module.eval() 29 | 30 | 31 | def make_attention_maps(attention_scores, step, lower_triu=True, vmin=None, vmax=None): 32 | """ 33 | attention_scores: Tuple of `torch.FloatTensor` (one for each layer) of shape 34 | `(batch_size, num_heads, sequence_length,sequence_length)`. 35 | step: Int. 
Current timestep 36 | 37 | """ 38 | figures = {} 39 | mask = None 40 | for i, scores in enumerate(attention_scores): 41 | # first attention head 42 | if scores is None: 43 | print(f"Attention scores for layer {i} are None. Skipping") 44 | continue 45 | scores = scores.float().detach().cpu().numpy() 46 | h0_scores = scores[-1, 0] 47 | fig, ax = plt.subplots() 48 | if lower_triu: 49 | mask = np.triu(np.ones_like(h0_scores, dtype=bool)) 50 | np.fill_diagonal(mask, False) 51 | sns.heatmap(h0_scores, cmap="rocket_r", mask=mask, ax=ax, vmin=vmin, vmax=vmax) 52 | ax.set_title(f"Timestep: {step}, Layer: {i}, Head: 0") 53 | figures[f"layer{i}_head0"] = fig 54 | # avg over all heads 55 | avg_scores = scores[-1].mean(0) 56 | fig, ax = plt.subplots() 57 | if lower_triu: 58 | mask = np.triu(np.ones_like(avg_scores, dtype=bool)) 59 | np.fill_diagonal(mask, False) 60 | sns.heatmap(avg_scores, cmap="rocket_r", mask=mask, ax=ax, vmin=vmin, vmax=vmax) 61 | ax.set_title(f"Timestep: {step}, Layer: {i}, Head: all") 62 | figures[f"layer{i}_allheads"] = fig 63 | return figures 64 | 65 | 66 | def make_qk_dist_plot(key, query, step): 67 | key, query = key.squeeze(), query.squeeze() 68 | df_key = pd.DataFrame(key.T, columns=[f"k{i}" for i in range(key.shape[0])]) 69 | df_query = pd.DataFrame(query.T, columns=[f"q{i}" for i in range(query.shape[0])]) 70 | df = pd.concat([df_key, df_query], axis=1).T 71 | fig, ax = plt.subplots() 72 | sns.heatmap(df, cmap="rocket_r", ax=ax) 73 | ax.set_title(f"Timestep: {str(step)}") 74 | ax.set_xlabel("Feature dimension") 75 | ax.set_ylabel("Q-K index") 76 | return fig 77 | 78 | 79 | def make_sim_plot(sim, step, max_samples=5): 80 | """ 81 | Make heatmap from given similarity matrix. 82 | Args: 83 | sim: np.ndarray of shape (batch_size x pool_size) 84 | step: Int. 85 | max_samples: Int. Max samples to use (across batch size). Matrix becomes unreadable for more than 10 samples. 86 | 87 | Returns: Matplotlib figure. 88 | 89 | """ 90 | fig, ax = plt.subplots(figsize=(max_samples, sim.shape[1] * 0.3)) 91 | if sim.shape[0] > max_samples: 92 | sim = sim[:max_samples] 93 | sns.heatmap(sim.T, cmap="rocket_r", ax=ax, annot=True) 94 | ax.set_title(f"Timestep: {str(step)}") 95 | ax.set_xlabel("Batch idx") 96 | ax.set_ylabel("Pool idx") 97 | return fig 98 | 99 | 100 | def make_retrieved_states_plot(state, action, states_retrieved, actions_retrieved, step): 101 | """ 102 | Plots retrieved states next to current state in addition to the performed actions as title. 103 | 104 | Args: 105 | state: np.ndarray of shape (H, W, C) 106 | states_retrieved: np.ndarray of shape (B, H, W, C) 107 | action: np.ndarray of shape (1, act_dim) 108 | actions_retrieved: np.ndarray of shape (B, act_dim) 109 | step: int 110 | 111 | Returns: Matplotlib figure. 
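Note: the current state is drawn in the first panel, the retrieved states in the remaining panels, and the corresponding actions are shown in the subplot titles.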
112 | 113 | """ 114 | num_retrieved = len(states_retrieved) 115 | fig, axs = plt.subplots(1, num_retrieved + 1, figsize=(4 * (num_retrieved + 1), 4)) 116 | state, states_retrieved = state.cpu().numpy(), states_retrieved.cpu().numpy() 117 | action, actions_retrieved = action.cpu().numpy(), actions_retrieved.cpu().numpy() 118 | if len(state.shape) == 3: 119 | state = state.squeeze(0) 120 | 121 | # Plot current state in first subplot 122 | axs[0].imshow(state) 123 | axs[0].set_title(f"Current state | Action: {action}") 124 | 125 | # Plot retrieved states in remaining subplots 126 | for i, ret in enumerate(states_retrieved): 127 | if len(ret.shape) == 3: 128 | ret = ret.squeeze(0) 129 | axs[i + 1].imshow(ret) 130 | axs[i + 1].set_title(f"Retrieved {i+1} | Action: {actions_retrieved[i]}") 131 | 132 | fig.suptitle(f"Evaluation Step: {step}") 133 | 134 | return fig 135 | 136 | 137 | class ProgressParallel(joblib.Parallel): 138 | # from: https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib 139 | def __init__(self, use_tqdm=True, total=None, *args, **kwargs): 140 | self._use_tqdm = use_tqdm 141 | self._total = total 142 | super().__init__(*args, **kwargs) 143 | 144 | def __call__(self, *args, **kwargs): 145 | with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar: 146 | return joblib.Parallel.__call__(self, *args, **kwargs) 147 | 148 | def print_progress(self): 149 | if self._total is None: 150 | self._pbar.total = self.n_dispatched_tasks 151 | self._pbar.n = self.n_completed_tasks 152 | self._pbar.refresh() 153 | -------------------------------------------------------------------------------- /src/envs/dmcontrol_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import gym 4 | import dmc2gym_custom 5 | from gym import spaces 6 | from stable_baselines3.common.monitor import Monitor 7 | from dm_control import suite 8 | 9 | 10 | def extract_obs_dims(exclude_domains=[]): 11 | obstype_to_dim = {} 12 | for domain_name, task_name in suite.BENCHMARKING: 13 | env = suite.load(domain_name, task_name) 14 | time_step = env.reset() 15 | print(f"{domain_name}-{task_name}", 16 | {k: v.shape for k, v in time_step.observation.items()}, 17 | env.action_spec().shape, 18 | "\n") 19 | if any(domain in domain_name for domain in exclude_domains): 20 | continue 21 | for k, v in time_step.observation.items(): 22 | v = np.array([v]) if np.isscalar(v) else v.ravel() 23 | obstype_to_dim[k] = max(obstype_to_dim.get(k, 0), v.shape[0]) 24 | return obstype_to_dim 25 | 26 | def extract_obstype_to_startidx(obstype_to_dim): 27 | cum_dim = 0 28 | obstype_to_start_idx = {} 29 | for k, v in obstype_to_dim.items(): 30 | obstype_to_start_idx[k] = cum_dim 31 | cum_dim += v 32 | return obstype_to_start_idx 33 | 34 | 35 | DMC_OBSTYPE_TO_DIM = { 36 | 'orientations': 14, 'velocity': 27, 'position': 8, 'touch': 5, 'target_position': 2, 'dist_to_target': 1, 37 | 'joint_angles': 21, 'upright': 1, 'target': 3, 'head_height': 1, 'extremities': 12, 'torso_vertical': 3, 38 | 'com_velocity': 3, 'arm_pos': 16, 'arm_vel': 8, 'hand_pos': 4, 'object_pos': 4, 'object_vel': 3, 'target_pos': 4, 39 | 'orientation': 2, 'to_target': 2, 'joints': 14, 'body_velocities': 45, 'height': 1 40 | } 41 | 42 | DMC_FULL_OBS_DIM = sum(DMC_OBSTYPE_TO_DIM.values()) 43 | 44 | DMC_OBSTYPE_TO_STARTIDX = { 45 | 'orientations': 0, 'velocity': 14, 'position': 41, 'touch': 49, 'target_position': 54, 'dist_to_target': 56, 46 | 
'joint_angles': 57, 'upright': 78, 'target': 79, 'head_height': 82, 'extremities': 83, 'torso_vertical': 95, 47 | 'com_velocity': 98, 'arm_pos': 101, 'arm_vel': 117, 'hand_pos': 125, 'object_pos': 129, 'object_vel': 133, 48 | 'target_pos': 136, 'orientation': 140, 'to_target': 142, 'joints': 144, 'body_velocities': 158, 'height': 203 49 | } 50 | 51 | 52 | def map_obs_to_full_space(obs, hide_goal=False): 53 | dtype = obs.dtype if hasattr(obs, "dtype") else np.float32 54 | full_obs = np.zeros(DMC_FULL_OBS_DIM, dtype=dtype) 55 | for k, v in obs.items(): 56 | start_idx = DMC_OBSTYPE_TO_STARTIDX[k] 57 | v = np.array([v]) if np.isscalar(v) else v.ravel() 58 | full_obs[start_idx: start_idx + v.shape[0]] = v 59 | if hide_goal and "target" in k: 60 | full_obs[start_idx: start_idx + v.shape[0]] = 0 61 | return full_obs 62 | 63 | 64 | def map_flattened_obs_to_full_space(obs, obs_spec): 65 | if not isinstance(obs, np.ndarray): 66 | obs = np.array(obs) 67 | is_one_dim = len(obs.shape) == 1 68 | if is_one_dim: 69 | obs = np.expand_dims(obs, axis=0) 70 | full_obs = np.zeros((*obs.shape[:-1], DMC_FULL_OBS_DIM)) 71 | flat_start_idx = 0 72 | for k, v in obs_spec.items(): 73 | dim = np.prod(v.shape) if len(v.shape) > 0 else 1 74 | full_start_idx = DMC_OBSTYPE_TO_STARTIDX[k] 75 | full_obs[..., full_start_idx: full_start_idx + dim] = obs[..., flat_start_idx: flat_start_idx + dim] 76 | flat_start_idx += dim 77 | if is_one_dim: 78 | full_obs = full_obs.ravel() 79 | return full_obs 80 | 81 | 82 | class DmcFullObsWrapper(gym.ObservationWrapper): 83 | """ 84 | Converts a given state observation to the full observation space of all DMControl environments. 85 | 86 | Unfortunately, dmc2gym always flattens the observation by default. Therefore, this wrapper should 87 | always be used with dmc2gym_custom, which makes flattening the observation optional. 88 | 89 | Args: 90 | env: Gym environment. 91 | """ 92 | 93 | def __init__(self, env: gym.Env, hide_goal=False): 94 | gym.ObservationWrapper.__init__(self, env) 95 | self.hide_goal = hide_goal 96 | low, high = np.array([-float("inf")] * DMC_FULL_OBS_DIM), np.array([float("inf")] * DMC_FULL_OBS_DIM) 97 | self.observation_space = spaces.Box( 98 | low=low, high=high, dtype=np.float32 99 | ) 100 | 101 | def observation(self, obs): 102 | return map_obs_to_full_space(obs, hide_goal=self.hide_goal) 103 | 104 | 105 | class GrayscaleWrapper(gym.ObservationWrapper): 106 | """ 107 | Converts a given frame to grayscale. The given frame must be channel first. 108 | 109 | Args: 110 | env: Gym environment. 111 | """ 112 | 113 | def __init__(self, env: gym.Env): 114 | gym.ObservationWrapper.__init__(self, env) 115 | channels, height, width = env.observation_space.shape 116 | assert channels != 1, "Image is grayscale already."
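# keep the channel-first layout, but with a single channel: (1, H, W)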
117 | self.observation_space = spaces.Box( 118 | low=0, high=255, shape=(1, height, width), dtype=env.observation_space.dtype 119 | ) 120 | 121 | def observation(self, frame): 122 | frame = cv2.cvtColor(frame.transpose(1, 2, 0), cv2.COLOR_RGB2GRAY) 123 | return np.expand_dims(frame, 0) 124 | 125 | 126 | def get_dmcontrol_constructor(envid, env_kwargs=None, hide_goal=False): 127 | env_kwargs = dict(env_kwargs) if env_kwargs is not None else {} 128 | render_mode = env_kwargs.pop("render_mode", None) 129 | def make(): 130 | domain_name, task_name = envid.split("-") 131 | env = dmc2gym_custom.make(domain_name=domain_name, task_name=task_name, **env_kwargs) 132 | # change envid to make more readable than default in dmc2gym_custom 133 | env.spec.id = f"{domain_name}-{task_name}" 134 | if env_kwargs.get("from_pixels", False): 135 | env = GrayscaleWrapper(env) 136 | if not env_kwargs.get("flatten_obs", True): 137 | env = DmcFullObsWrapper(env, hide_goal=hide_goal) 138 | if render_mode is not None: 139 | env.metadata.update({"render.modes": [render_mode]}) 140 | return Monitor(env) 141 | return make 142 | -------------------------------------------------------------------------------- /src/algos/models/rope.py: -------------------------------------------------------------------------------- 1 | # copied from modellin_llama.py in transformers 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class LlamaRotaryEmbedding(nn.Module): 7 | def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): 8 | super().__init__() 9 | self.scaling_factor = scaling_factor 10 | self.dim = dim 11 | self.max_position_embeddings = max_position_embeddings 12 | self.base = base 13 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) 14 | self.register_buffer("inv_freq", inv_freq, persistent=False) 15 | # For BC we register cos and sin cached 16 | self.max_seq_len_cached = max_position_embeddings 17 | t = torch.arange(self.max_seq_len_cached, dtype=torch.int64).type_as(self.inv_freq) 18 | t = t / self.scaling_factor 19 | freqs = torch.outer(t, self.inv_freq) 20 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 21 | emb = torch.cat((freqs, freqs), dim=-1) 22 | self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False) 23 | self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False) 24 | 25 | @property 26 | def sin_cached(self): 27 | print( 28 | "The sin_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " 29 | "the forward method of RoPE from now on instead. It is not used in the `LlamaAttention` class" 30 | ) 31 | return self._sin_cached 32 | 33 | @property 34 | def cos_cached(self): 35 | print( 36 | "The cos_cached attribute will be removed in 4.39. Bear in mind that its contents changed in v4.38. Use " 37 | "the forward method of RoPE from now on instead. 
It is not used in the `LlamaAttention` class" 38 | ) 39 | return self._cos_cached 40 | 41 | @torch.no_grad() 42 | def forward(self, x, position_ids): 43 | # x: [bs, num_attention_heads, seq_len, head_size] 44 | inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) 45 | position_ids_expanded = position_ids[:, None, :].float() 46 | # Force float32 since bfloat16 loses precision on long contexts 47 | # See https://github.com/huggingface/transformers/pull/29285 48 | # device_type = x.device.type 49 | # device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" 50 | # with torch.autocast(device_type=device_type, enabled=False): 51 | # does not work for some reason, if we compile 52 | freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) 53 | emb = torch.cat((freqs, freqs), dim=-1) 54 | cos = emb.cos() 55 | sin = emb.sin() 56 | return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) 57 | 58 | 59 | class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): 60 | """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" 61 | 62 | def forward(self, x, position_ids): 63 | # difference to the original RoPE: a scaling factor is aplied to the position ids 64 | position_ids = position_ids.float() / self.scaling_factor 65 | cos, sin = super().forward(x, position_ids) 66 | return cos, sin 67 | 68 | 69 | class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): 70 | """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" 71 | 72 | def forward(self, x, position_ids): 73 | # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length 74 | seq_len = torch.max(position_ids) + 1 75 | if seq_len > self.max_position_embeddings: 76 | base = self.base * ( 77 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 78 | ) ** (self.dim / (self.dim - 2)) 79 | inv_freq = 1.0 / ( 80 | base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) 81 | ) 82 | self.register_buffer("inv_freq", inv_freq, persistent=False) 83 | 84 | cos, sin = super().forward(x, position_ids) 85 | return cos, sin 86 | 87 | 88 | def rotate_half(x): 89 | """Rotates half the hidden dims of the input.""" 90 | x1 = x[..., : x.shape[-1] // 2] 91 | x2 = x[..., x.shape[-1] // 2 :] 92 | return torch.cat((-x2, x1), dim=-1) 93 | 94 | 95 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): 96 | """Applies Rotary Position Embedding to the query and key tensors. 97 | 98 | Args: 99 | q (`torch.Tensor`): The query tensor. 100 | k (`torch.Tensor`): The key tensor. 101 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 102 | sin (`torch.Tensor`): The sine part of the rotary embedding. 103 | position_ids (`torch.Tensor`, *optional*): 104 | Deprecated and unused. 105 | unsqueeze_dim (`int`, *optional*, defaults to 1): 106 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 107 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 108 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and 109 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 110 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 111 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 112 | Returns: 113 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 114 | """ 115 | cos = cos.unsqueeze(unsqueeze_dim) 116 | sin = sin.unsqueeze(unsqueeze_dim) 117 | q_embed = (q * cos) + (rotate_half(q) * sin) 118 | k_embed = (k * cos) + (rotate_half(k) * sin) 119 | return q_embed, k_embed 120 | -------------------------------------------------------------------------------- /src/envs/dn_scores.py: -------------------------------------------------------------------------------- 1 | """ 2 | For DMControl and Gym MuJoCo, there are no human-normalized scores. 3 | Therefore, we normalize the scores based on the performance the expert agent reaches at the end of training. 4 | --> Data normalized 5 | 6 | Random and expert refernence scores for D4RL are available here: https://github.com/Farama-Foundation/D4RL/blob/master/d4rl/infos.py 7 | 8 | """ 9 | import math 10 | import dmc2gym 11 | import numpy as np 12 | import pandas as pd 13 | from .env_names import DM_CONTROL_ENVS 14 | 15 | 16 | # Task: score-tuple dictionary. Each score tuple contains 17 | # 0: score random (float) and 1: mean scores in the datasets (float). 18 | ENVID_TO_DNS = { 19 | 'acrobot-swingup': (8.351, 4.877), 20 | 'ball_in_cup-catch': (0.0, 926.719), 21 | 'cartpole-balance': (350.391, 938.506), 22 | 'cartpole-swingup': (27.414, 766.15), 23 | 'cheetah-run': (3.207, 324.045), 24 | 'finger-spin': (0.2, 834.629), 25 | 'finger-turn_easy': (57.8, 800.645), 26 | 'finger-turn_hard': (40.6, 676.144), 27 | 'fish-swim': (67.675, 78.212), 28 | 'fish-upright': (229.406, 547.962), 29 | 'hopper-hop': (0.076, 62.794), 30 | 'hopper-stand': (1.296, 266.783), 31 | 'humanoid-run': (0.741, 0.794), 32 | 'humanoid-stand': (4.327, 5.053), 33 | 'humanoid-walk': (0.913, 1.194), 34 | 'manipulator-bring_ball': (0.0, 0.429), 35 | 'manipulator-insert_ball': (0.0, 43.307), 36 | 'manipulator-insert_peg': (0.235, 78.477), 37 | 'pendulum-swingup': (0.0, 614.491), 38 | 'point_mass-easy': (1.341, 779.273), 39 | 'reacher-easy': (33.0, 849.241), 40 | 'reacher-hard': (8.0, 779.947), 41 | 'swimmer-swimmer15': (78.817, 152.297), 42 | 'swimmer-swimmer6': (229.834, 167.082), 43 | 'walker-run': (23.427, 344.794), 44 | 'walker-stand': (134.701, 816.322), 45 | 'walker-walk': (30.193, 773.174), 46 | "HalfCheetah-v3":(-280.178953, 12135.0), 47 | "Walker2d-v3": (1.629008, 4592.3), 48 | "Hopper-v3": (-20.272305, 3234.3), 49 | "HalfCheetah-v2":(-280.178953, 12135.0), 50 | "Walker2d-v2": (1.629008, 4592.3), 51 | "Hopper-v2": (-20.272305, 3234.3), 52 | # extracted from 25M data 53 | "bigfish": (0.0, 5.9107), 54 | "bossfight": (0.0, 2.179), 55 | "caveflyer": (0.0, 7.6341), 56 | "chaser": (0.0, 3.4349), 57 | "climber": (0.0, 9.1516), 58 | "coinrun": (0.0, 9.6781), 59 | "dodgeball": (0.0, 3.1873), 60 | "fruitbot": (-7.0, 16.9643), 61 | "heist": (0.0, 7.9555), 62 | "jumper": (0.0, 8.7396), 63 | "leaper": (0.0, 4.9065), 64 | "maze": (0.0, 9.4536), 65 | "miner": (0.0, 11.6814), 66 | "ninja": (0.0, 7.7674), 67 | "plunder": (0.0, 4.9095), 68 | "starpilot": (0.0, 17.3367), 69 | # metaworld - random scores, + mean scores in dataset 70 | 'reach-v2': (333.01, 1842.41), 71 | 'push-v2': (13.79, 1311.81), 72 | # for pick-place, we 
use the max score in the datasets, as the mean is ~random 73 | 'pick-place-v2': (4.63, 1300.69), 74 | 'door-open-v2': (96.88, 1484.42), 75 | 'drawer-open-v2': (258.2, 1594.88), 76 | 'drawer-close-v2': (297.9, 1823.24), 77 | 'button-press-topdown-v2': (68.1, 1255.89), 78 | 'peg-insert-side-v2': (3.84, 1418.33), 79 | 'window-open-v2': (86.62, 1480.98), 80 | 'window-close-v2': (117.4, 1395.35), 81 | 'door-close-v2': (12.93, 1523.22), 82 | 'reach-wall-v2': (268.53, 1817.09), 83 | 'pick-place-wall-v2': (0.0, 260.63), 84 | 'push-wall-v2': (24.91, 1361.09), 85 | 'button-press-v2': (68.25, 1366.96), 86 | 'button-press-topdown-wall-v2': (73.82, 1264.01), 87 | 'button-press-wall-v2': (21.23, 1438.6), 88 | 'peg-unplug-side-v2': (8.2, 1044.82), 89 | 'disassemble-v2': (82.99, 1139.05), 90 | 'hammer-v2': (202.6, 1444.97), 91 | 'plate-slide-v2': (120.86, 1634.72), 92 | 'plate-slide-side-v2': (34.02, 1598.42), 93 | 'plate-slide-back-v2': (61.46, 1725.26), 94 | 'plate-slide-back-side-v2': (76.74, 1716.65), 95 | 'handle-press-v2': (130.96, 1824.9), 96 | 'handle-pull-v2': (16.98, 1519.89), 97 | 'handle-press-side-v2': (170.76, 1808.85), 98 | 'handle-pull-side-v2': (5.09, 1484.7), 99 | 'stick-push-v2': (5.53, 925.04), 100 | 'stick-pull-v2': (4.6, 977.54), 101 | 'basketball-v2': (5.08, 702.4), 102 | 'soccer-v2': (13.44, 255.1), 103 | 'faucet-open-v2': (516.4, 1712.81), 104 | 'faucet-close-v2': (508.07, 1741.73), 105 | 'coffee-push-v2': (8.62, 253.72), 106 | 'coffee-pull-v2': (8.92, 973.07), 107 | 'coffee-button-v2': (60.16, 1461.26), 108 | 'sweep-v2': (21.74, 1024.95), 109 | 'sweep-into-v2': (28.63, 1638.45), 110 | 'pick-out-of-hole-v2': (2.73, 891.61), 111 | 'assembly-v2': (92.02, 908.84), 112 | 'shelf-place-v2': (0.0, 997.24), 113 | 'push-back-v2': (2.33, 645.29), 114 | 'lever-pull-v2': (99.86, 731.48), 115 | 'dial-turn-v2': (49.74, 1556.25), 116 | 'bin-picking-v2': (4.27, 249.41), 117 | 'box-close-v2': (100.89, 550.25), 118 | 'hand-insert-v2': (5.35, 1196.92), 119 | 'door-lock-v2': (226.23, 1589.09), 120 | 'door-unlock-v2': (192.42, 1662.54), 121 | } 122 | 123 | 124 | def get_data_normalized_score(task: str, raw_score: float, random_col=0, data_col=1) -> float: 125 | """Converts task score to data-normalized score.""" 126 | scores = ENVID_TO_DNS.get(task, (math.nan, math.nan)) 127 | random, data = scores[random_col], scores[data_col] 128 | return (raw_score - random) / (data - random) 129 | 130 | 131 | def compute_random_dmcontrol_scores(): 132 | random_scores = {} 133 | for envid in DM_CONTROL_ENVS: 134 | domain_name, task_name = envid.split("-") 135 | print(f"Computing random scores for {envid} ...") 136 | env = dmc2gym.make(domain_name=domain_name, task_name=task_name) 137 | random_scores[envid] = evaluate_random_policy(env) 138 | return random_scores 139 | 140 | 141 | def evaluate_random_policy(env, n_eval_episodes=10): 142 | returns = [] 143 | for _ in range(n_eval_episodes): 144 | _ = env.reset() 145 | done = False 146 | episode_return = 0 147 | while not done: 148 | action = env.action_space.sample() 149 | _, reward, done, _ = env.step(action) 150 | episode_return += reward 151 | returns.append(episode_return) 152 | return np.mean(returns) 153 | -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/wrappers.py: -------------------------------------------------------------------------------- 1 | from gym import core, spaces 2 | from dm_control import suite 3 | from dm_env import specs 4 | import numpy as np 5 | 6 | 7 | def _spec_to_box(spec, 
dtype): 8 | def extract_min_max(s): 9 | assert s.dtype == np.float64 or s.dtype == np.float32 10 | dim = np.int(np.prod(s.shape)) 11 | if type(s) == specs.Array: 12 | bound = np.inf * np.ones(dim, dtype=np.float32) 13 | return -bound, bound 14 | elif type(s) == specs.BoundedArray: 15 | zeros = np.zeros(dim, dtype=np.float32) 16 | return s.minimum + zeros, s.maximum + zeros 17 | 18 | mins, maxs = [], [] 19 | for s in spec: 20 | mn, mx = extract_min_max(s) 21 | mins.append(mn) 22 | maxs.append(mx) 23 | low = np.concatenate(mins, axis=0).astype(dtype) 24 | high = np.concatenate(maxs, axis=0).astype(dtype) 25 | assert low.shape == high.shape 26 | return spaces.Box(low, high, dtype=dtype) 27 | 28 | 29 | def _flatten_obs(obs): 30 | obs_pieces = [] 31 | for v in obs.values(): 32 | flat = np.array([v]) if np.isscalar(v) else v.ravel() 33 | obs_pieces.append(flat) 34 | return np.concatenate(obs_pieces, axis=0) 35 | 36 | 37 | class DMCWrapper(core.Env): 38 | def __init__( 39 | self, 40 | domain_name, 41 | task_name, 42 | task_kwargs=None, 43 | visualize_reward={}, 44 | from_pixels=False, 45 | height=84, 46 | width=84, 47 | camera_id=0, 48 | frame_skip=1, 49 | environment_kwargs=None, 50 | channels_first=True, 51 | flatten_obs=True, 52 | deterministic=False 53 | ): 54 | assert 'random' in task_kwargs, 'please specify a seed, for deterministic behaviour' 55 | self._from_pixels = from_pixels 56 | self._height = height 57 | self._width = width 58 | self._camera_id = camera_id 59 | self._frame_skip = frame_skip 60 | self._channels_first = channels_first 61 | self._flatten_obs = flatten_obs 62 | self.deterministic = deterministic 63 | 64 | # create task 65 | self._env = suite.load( 66 | domain_name=domain_name, 67 | task_name=task_name, 68 | task_kwargs=task_kwargs, 69 | visualize_reward=visualize_reward, 70 | environment_kwargs=environment_kwargs 71 | ) 72 | 73 | # true and normalized action spaces 74 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) 75 | self._norm_action_space = spaces.Box( 76 | low=-1.0, 77 | high=1.0, 78 | shape=self._true_action_space.shape, 79 | dtype=np.float32 80 | ) 81 | 82 | # create observation space 83 | if from_pixels: 84 | shape = [3, height, width] if channels_first else [height, width, 3] 85 | self._observation_space = spaces.Box( 86 | low=0, high=255, shape=shape, dtype=np.uint8 87 | ) 88 | else: 89 | self._observation_space = _spec_to_box( 90 | self._env.observation_spec().values(), 91 | np.float64 92 | ) 93 | 94 | self._state_space = _spec_to_box( 95 | self._env.observation_spec().values(), 96 | np.float64 97 | ) 98 | 99 | self.current_state = None 100 | 101 | # set seed 102 | self._seed = task_kwargs.get('random', 1) 103 | self.seed(seed=self._seed) 104 | 105 | def __getattr__(self, name): 106 | return getattr(self._env, name) 107 | 108 | def _get_obs(self, time_step): 109 | if self._from_pixels: 110 | obs = self.render( 111 | height=self._height, 112 | width=self._width, 113 | camera_id=self._camera_id 114 | ) 115 | if self._channels_first: 116 | obs = obs.transpose(2, 0, 1).copy() 117 | else: 118 | obs = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 119 | return obs 120 | 121 | def _convert_action(self, action): 122 | action = action.astype(np.float64) 123 | true_delta = self._true_action_space.high - self._true_action_space.low 124 | norm_delta = self._norm_action_space.high - self._norm_action_space.low 125 | action = (action - self._norm_action_space.low) / norm_delta 126 | action = action * 
true_delta + self._true_action_space.low 127 | action = action.astype(np.float32) 128 | return action 129 | 130 | @property 131 | def observation_space(self): 132 | return self._observation_space 133 | 134 | @property 135 | def state_space(self): 136 | return self._state_space 137 | 138 | @property 139 | def action_space(self): 140 | return self._norm_action_space 141 | 142 | @property 143 | def reward_range(self): 144 | return 0, self._frame_skip 145 | 146 | def seed(self, seed): 147 | self._true_action_space.seed(seed) 148 | self._norm_action_space.seed(seed) 149 | self._observation_space.seed(seed) 150 | 151 | def step(self, action): 152 | assert self._norm_action_space.contains(action) 153 | action = self._convert_action(action) 154 | assert self._true_action_space.contains(action) 155 | reward = 0 156 | extra = {'internal_state': self._env.physics.get_state().copy()} 157 | 158 | for _ in range(self._frame_skip): 159 | time_step = self._env.step(action) 160 | reward += time_step.reward or 0 161 | done = time_step.last() 162 | if done: 163 | break 164 | obs = self._get_obs(time_step) 165 | self.current_state = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 166 | extra['discount'] = time_step.discount 167 | return obs, reward, done, extra 168 | 169 | def reset(self): 170 | if self.deterministic: 171 | self._env.task._random = np.random.RandomState(self._seed) 172 | time_step = self._env.reset() 173 | self.current_state = _flatten_obs(time_step.observation) if self._flatten_obs else time_step.observation 174 | obs = self._get_obs(time_step) 175 | return obs 176 | 177 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 178 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 179 | height = height or self._height 180 | width = width or self._width 181 | camera_id = camera_id or self._camera_id 182 | return self._env.physics.render( 183 | height=height, width=width, camera_id=camera_id 184 | ) -------------------------------------------------------------------------------- /src/buffers/cache_dataset.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import torch 3 | import numpy as np 4 | from pathlib import Path 5 | from .trajectory_dataset import TrajectoryDataset 6 | from .buffer_utils import compute_start_end_context_idx 7 | 8 | 9 | class CacheDataset(TrajectoryDataset): 10 | def __init__(self, trajectories, env, context_len, action_pad, cache_steps=1, rand_first_chunk=False, **kwargs): 11 | super().__init__(trajectories, env, context_len, action_pad, **kwargs) 12 | self.trj_idx_to_count = collections.defaultdict(int) 13 | self.cache_steps = cache_steps 14 | self.rand_first_chunk = rand_first_chunk 15 | 16 | def get_single_sample_from_memory(self, trj, full_trj=False, idx=None): 17 | assert idx is not None, "CacheDataset requires idx." 
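        # trj_idx_to_count[idx] tracks how far into trajectory `idx` sampling has progressed;
        # each call moves the end index forward by cache_steps, so successive samples walk
        # through the trajectory sequentially instead of drawing independent random windows.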
18 | # handle first step 19 | # does not make sense to sample the very first step of trajectory alone, as action not included there 20 | end_idx = max(self.trj_idx_to_count[idx], 2) 21 | self.trj_idx_to_count[idx] += self.cache_steps 22 | # self.trj_idx_to_count[idx] += self.cache_steps 23 | # end_idx = min(self.trj_idx_to_count[idx], len(trj)) 24 | if self.rand_first_chunk and end_idx < self.cache_steps: 25 | end_idx = np.random.randint(2, min(len(trj), self.cache_steps)) 26 | s, s1, a, r, togo, t, done, task_id, trj_id, trj_seed = trj.sample(self.context_len, end_idx=end_idx) 27 | return s, s1, a, r, togo, t, done, task_id, trj_id, None, trj_seed 28 | 29 | def get_single_sample_from_disk(self, path, idx): 30 | # directly load subset of trajectory from disk making use of trj_lengths 31 | end_idx = max(self.trj_idx_to_count[idx], 2) 32 | end_idx = min(end_idx, self.trj_lengths[path]) 33 | self.trj_idx_to_count[idx] += self.cache_steps 34 | start_idx = max(0, end_idx - self.context_len) 35 | s, a, r, done_flag, togo = self.load_trj(path, start_idx=start_idx, end_idx=end_idx) 36 | assert togo is not None, "RTGs must be stored in trj file." 37 | r = r.astype(np.float32) 38 | if len(a.shape) == 1: 39 | a = np.expand_dims(a, -1) 40 | if isinstance(done_flag, (list, tuple, np.ndarray)): 41 | done_flag = done_flag[..., -1] 42 | done = np.zeros(len(s)) 43 | done[-1] = done_flag 44 | s1 = np.zeros_like(s) 45 | t = np.arange(start_idx, end_idx) 46 | task_id, trj_id, trj_seed = 0, idx, 0 47 | return s, s1, a, r, togo, t, done, task_id, trj_id, None, trj_seed 48 | 49 | 50 | class CacheWithContextDataset(CacheDataset): 51 | def __init__(self, trajectories, env, context_len, action_pad, cache_context_len, 52 | future_context_len, full_context_len=True, dynamic_context_len=False, **kwargs): 53 | super().__init__(trajectories, env, context_len, action_pad, **kwargs) 54 | self.cache_context_len = cache_context_len 55 | self.future_context_len = future_context_len 56 | self.full_context_len = full_context_len 57 | self.dynamic_context_len = dynamic_context_len 58 | 59 | def __getitem__(self, idx): 60 | s, a, s1, r, togo, t, mask, done, task_id, trj_id, action_mask, total_return, trj_seed = super().__getitem__(idx) 61 | # get context trjs 62 | c_s, c_a, _, c_r, c_togo, c_t, c_mask, _, _ = self.extract_context_trjs(idx) 63 | return s, a, s1, r, togo, t, mask, done, task_id, trj_id, \ 64 | action_mask, total_return, trj_seed, c_s, c_a, c_r, c_togo, c_t, c_mask, 65 | 66 | def extract_context_trjs(self, idx): 67 | trj = self.trajectories[idx] 68 | if isinstance(trj, (str, Path)): 69 | # load from disk 70 | path = str(trj) 71 | s, s1, a, r, togo, t, done, action_mask = self.get_context_from_disk(path, idx) 72 | else: 73 | # samples stored in memory, load from there 74 | s, s1, a, r, togo, t, done, action_mask = self.get_context_from_memory(trj, idx) 75 | 76 | # postprocess states, actions 77 | if len(s.shape) == 4 and self.to_rgb: 78 | # convert to "RGB" by repeating the gray-scale channel 79 | s = np.repeat(s, 3, axis=1) 80 | s1 = np.repeat(s1, 3, axis=1) 81 | if self.env is not None: 82 | s = self.env.normalize_obs(s) 83 | s1 = self.env.normalize_obs(s1) 84 | 85 | padding = max(0, (self.cache_context_len + self.future_context_len) - s.shape[0]) 86 | mask = self.make_attention_mask(padding, s.shape[0]) 87 | action_mask = np.ones_like(a, dtype=np.int32) if action_mask is None else action_mask 88 | if self.max_act_dim is not None and a.dtype.kind == "f": 89 | a, action_mask = self.pad_actions(a) 90 | if 
self.max_state_dim is not None and len(s.shape) == 2: 91 | s, s1 = self.pad_states(s, s1) 92 | if padding: 93 | s, s1, a, r, togo, t, done, action_mask = self.pad_sequences(s, s1, a, r, togo, t, 94 | done, action_mask, padding) 95 | if len(s.shape) == 4 and self.transforms is not None: 96 | # perform image augmentations 97 | s = self.transforms(torch.from_numpy(s).float()) 98 | s1 = self.transforms(torch.from_numpy(s1).float()) 99 | 100 | return s, a, s1, np.expand_dims(r, axis=1), np.expand_dims(togo, axis=1), \ 101 | t, mask, done, action_mask 102 | 103 | def get_context_from_memory(self, trj, idx): 104 | # return s, s1, a, r, togo, t, done, task_id, trj_id, None 105 | assert idx is not None, "CacheDataset requires idx." 106 | cur_trj_idx = self.trj_idx_to_count[idx] - self.cache_steps 107 | # cur_trj_idx = self.trj_idx_to_count[idx] 108 | start, end = compute_start_end_context_idx(cur_trj_idx, len(trj), 109 | self.cache_context_len, self.future_context_len, 110 | full_context_len=self.full_context_len, 111 | dynamic_len=self.dynamic_context_len) 112 | s, s1, a, r, togo, t, done, _, _, _ = trj._get_samples(start, end) 113 | return s, s1, a, r, togo, t, done, None 114 | 115 | def get_context_from_disk(self, path, idx): 116 | # directly load subset of trajectory from disk making use of trj_lengths 117 | cur_trj_idx = self.trj_idx_to_count[idx] - self.cache_steps 118 | start, end = compute_start_end_context_idx(cur_trj_idx, self.trj_lengths[path], 119 | self.cache_context_len, self.future_context_len, 120 | full_context_len=self.full_context_len, 121 | dynamic_len=self.dynamic_context_len) 122 | s, a, r, done_flag, togo = self.load_trj(path, start_idx=start, end_idx=end) 123 | assert togo is not None, "RTGs must be stored in trj file." 124 | r = r.astype(np.float32) 125 | if len(a.shape) == 1: 126 | a = np.expand_dims(a, -1) 127 | if isinstance(done_flag, (list, tuple, np.ndarray)): 128 | done_flag = done_flag[..., -1] 129 | done = np.zeros(len(s)) 130 | done[-1] = done_flag 131 | s1 = np.zeros_like(s) 132 | t = np.arange(start, end) 133 | return s, s1, a, r, togo, t, done, None 134 | -------------------------------------------------------------------------------- /src/buffers/buffer_utils.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import collections 6 | from pathlib import Path 7 | 8 | 9 | def discount_cumsum(x, gamma): 10 | new_x = np.zeros_like(x) 11 | new_x[-1] = x[-1] 12 | for t in reversed(range(x.shape[0] - 1)): 13 | new_x[t] = x[t] + gamma * new_x[t + 1] 14 | return new_x 15 | 16 | def discount_cumsum_np(x, gamma): 17 | # much faster version of the above 18 | new_x = np.zeros_like(x) 19 | rev_cumsum = np.cumsum(np.flip(x, 0)) 20 | new_x = np.flip(rev_cumsum * gamma ** np.arange(0, x.shape[0]), 0) 21 | new_x = np.ascontiguousarray(new_x).astype(np.float32) 22 | return new_x 23 | 24 | 25 | def discount_cumsum_torch(x, gamma): 26 | new_x = torch.zeros_like(x) 27 | rev_cumsum = torch.cumsum(torch.flip(x, [0]), 0) 28 | new_x = torch.flip(rev_cumsum * gamma ** torch.arange(0, x.shape[0], device=x.device), [0]) 29 | new_x = new_x.contiguous().to(dtype=torch.float32) 30 | return new_x 31 | 32 | 33 | def compute_rtg_from_target(x, target_return): 34 | new_x = np.zeros_like(x) 35 | new_x[0] = target_return 36 | for i in range(1, x.shape[0]): 37 | new_x[i] = min(new_x[i - 1] - x[i - 1], target_return) 38 | return new_x 39 | 40 | 41 | def 
filter_top_p_trajectories(trajectories, top_p=1, epname_to_return=None): 42 | start = len(trajectories) - int(len(trajectories) * top_p) 43 | if epname_to_return is None: 44 | if hasattr(trajectories[0], "rewards"): 45 | def sort_fn(x): return np.array(x.rewards).sum() 46 | else: 47 | def sort_fn(x): return np.array(x.get("rewards")).sum() 48 | else: 49 | def sort_fn(x): return epname_to_return[x.stem] 50 | sorted_trajectories = sorted(trajectories, key=sort_fn) 51 | return sorted_trajectories[start:] 52 | 53 | 54 | def filter_trajectories_uniform(trajectories, p=1): 55 | # sample uniformly with trj len weights 56 | trj_lens = [len(t["observations"]) for t in trajectories] 57 | total_samples = sum(trj_lens) 58 | trajectory_probs = [l / total_samples for l in trj_lens] 59 | idx = np.random.choice(len(trajectories), size=int(len(trajectories) * p), p=trajectory_probs, replace=False) 60 | return [trajectories[i] for i in idx] 61 | 62 | 63 | def filter_trajectories_first(trajectories, p=1): 64 | return trajectories[:int(len(trajectories) * p)] 65 | 66 | 67 | def load_npz(path, start_idx=None, end_idx=None): 68 | returns_to_go = None 69 | # trj = np.load(path, mmap_mode="r" if start_idx and end_idx else None) 70 | with np.load(path, mmap_mode="r" if start_idx and end_idx else None) as trj: 71 | if start_idx is not None and end_idx is not None: 72 | # subtrajectory only 73 | observations, actions, rewards = trj["states"][start_idx: end_idx].astype(np.float32), \ 74 | trj["actions"][start_idx: end_idx].astype(np.float32), trj["rewards"][start_idx: end_idx].astype(np.float32) 75 | if "returns_to_go" in trj: 76 | returns_to_go = trj["returns_to_go"][start_idx: end_idx].astype(np.float32) 77 | else: 78 | # fully trajectory 79 | observations, actions, rewards = trj["states"], trj["actions"], trj["rewards"], 80 | if "returns_to_go" in trj: 81 | returns_to_go = trj["returns_to_go"].astype(np.float32) 82 | dones = np.array([trj["dones"]]) 83 | return observations, actions, rewards, dones, returns_to_go 84 | 85 | 86 | def load_hdf5(path, start_idx=None, end_idx=None, img_is_encoded=False): 87 | returns_to_go, dones = None, None 88 | with h5py.File(path, "r") as f: 89 | if start_idx is not None and end_idx is not None: 90 | # subtrajectory only 91 | if img_is_encoded: 92 | observations = f['states_encoded'][start_idx: end_idx] 93 | else: 94 | observations = f['states'][start_idx: end_idx] 95 | actions = f['actions'][start_idx: end_idx] 96 | rewards = f['rewards'][start_idx: end_idx] 97 | if "returns_to_go" in f: 98 | returns_to_go = f["returns_to_go"][start_idx: end_idx] 99 | if "dones" in f: 100 | try: 101 | dones = f['dones'][start_idx: end_idx] 102 | except Exception as e: 103 | pass 104 | else: 105 | # fully trajectory 106 | if img_is_encoded: 107 | observations = f['states_encoded'][:] 108 | else: 109 | observations = f['states'][:] 110 | actions = f['actions'][:] 111 | rewards = f['rewards'][:] 112 | if "returns_to_go" in f: 113 | returns_to_go = f["returns_to_go"][:] 114 | if "dones" in f: 115 | try: 116 | dones = f['dones'][:] 117 | except Exception as e: 118 | pass 119 | if dones is None: 120 | dones = np.array([f['dones'][()]]) 121 | return observations, actions, rewards, dones, returns_to_go 122 | 123 | 124 | def append_to_hdf5(path, new_vals, compress_kwargs=None): 125 | compress_kwargs = {"compression": "gzip", "compression_opts": 1} if compress_kwargs is None \ 126 | else compress_kwargs 127 | # open in append mode, add new vals 128 | with h5py.File(str(path), 'a') as f: 129 | for k, v in 
new_vals.items(): 130 | if k in f: 131 | del f[k] 132 | f.create_dataset(k, data=v, **compress_kwargs) 133 | 134 | 135 | def load_pkl(path, start_idx=None, end_idx=None): 136 | returns_to_go = None 137 | with open(path, "rb") as f: 138 | trj = pickle.load(f) 139 | if start_idx is not None and end_idx is not None: 140 | # subtrajectory only 141 | observations, actions, rewards = trj["states"][start_idx: end_idx], \ 142 | trj["actions"][start_idx: end_idx], trj["rewards"][start_idx: end_idx] 143 | if "returns_to_go" in trj: 144 | returns_to_go = trj["returns_to_go"][start_idx: end_idx] 145 | else: 146 | # fully trajectory 147 | observations, actions, rewards = trj["states"], trj["actions"], trj["rewards"], 148 | if "returns_to_go" in trj: 149 | returns_to_go = trj["returns_to_go"] 150 | dones = np.array([trj["dones"]]) 151 | return observations, actions, rewards, dones, returns_to_go 152 | 153 | 154 | def compute_start_end_context_idx(idx, seq_len, cache_len, future_cache_len, full_context_len=True, dynamic_len=False): 155 | start = max(0, idx - cache_len) 156 | end = min(seq_len, idx + future_cache_len) 157 | if dynamic_len: 158 | start = np.random.randint(start, idx + 1) 159 | end = np.random.randint(idx, end + 1) 160 | elif full_context_len: 161 | total_cache_len = cache_len + future_cache_len 162 | if end - start < total_cache_len: 163 | if start > 0: 164 | start -= total_cache_len - (end - start) 165 | else: 166 | end += total_cache_len - (end - start) 167 | start = max(0, start) 168 | end = min(seq_len, end) 169 | return start, end 170 | 171 | 172 | def dump_retrieval(query, distances, idx, values, save_dir, batch_idx=None): 173 | save_dir = Path(save_dir) 174 | if batch_idx is not None: 175 | save_dir = save_dir / str(batch_idx) 176 | save_dir.mkdir(parents=True, exist_ok=True) 177 | vals = [] 178 | for row in idx: 179 | vals.append({k: v[row] for k, v in values.items()}) 180 | with open(str(save_dir / "values.pkl"), "wb") as f: 181 | pickle.dump(vals, f) 182 | with open(str(save_dir / "query.pkl"), "wb") as f: 183 | pickle.dump(query, f) 184 | with open(str(save_dir / "distances.pkl"), "wb") as f: 185 | pickle.dump(distances, f) 186 | with open(str(save_dir / "idx.pkl"), "wb") as f: 187 | pickle.dump(idx, f) 188 | 189 | 190 | def load_retrieval(path): 191 | path = Path(path) 192 | with open(str(path / "values.pkl"), "rb") as f: 193 | values = pickle.load(f) 194 | with open(str(path / "query.pkl"), "rb") as f: 195 | query = pickle.load(f) 196 | with open(str(path / "distances.pkl"), "rb") as f: 197 | distances = pickle.load(f) 198 | with open(str(path / "idx.pkl"), "rb") as f: 199 | idx = pickle.load(f) 200 | return query, distances, idx, values 201 | -------------------------------------------------------------------------------- /src/buffers/trajectory.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .buffer_utils import discount_cumsum_np, compute_rtg_from_target 3 | 4 | 5 | class Trajectory: 6 | 7 | def __init__(self, obs_shape, action_dim, max_len=1024, task_id=0, trj_id=0, trj_seed=0, 8 | relative_pos_embds=False, handle_timeout_termination=True, 9 | sample_full_seqs_only=False, last_seq_only=False, init_trj_buffers=True, 10 | episodic=False): 11 | self.obs_shape = obs_shape 12 | self.action_dim = action_dim 13 | self.max_len = max_len 14 | self.relative_pos_embds = relative_pos_embds 15 | self.handle_timeout_termination = handle_timeout_termination 16 | self.sample_full_seqs_only = sample_full_seqs_only 17 | 
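        # identifiers attached to every sample drawn from this trajectory (returned by _get_samples),
        # so that mixed multi-task / multi-seed buffers can trace a sample back to its source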
self.task_id = task_id 18 | self.trj_id = trj_id 19 | self.trj_seed = trj_seed 20 | self.last_seq_only = last_seq_only 21 | self.init_trj_buffers = init_trj_buffers 22 | self.episodic = episodic 23 | if self.init_trj_buffers: 24 | assert obs_shape is not None and action_dim is not None, "obs_shape and action_dim must be provided" 25 | self.observations = np.zeros((self.max_len, ) + self.obs_shape, dtype=np.float32) 26 | self.next_observations = np.zeros((self.max_len, ) + self.obs_shape, dtype=np.float32) 27 | self.actions = np.zeros((self.max_len, self.action_dim), dtype=np.float32) 28 | self.rewards = np.zeros((self.max_len), dtype=np.float32) 29 | self.timesteps = np.zeros((self.max_len), dtype=np.float32) 30 | self.timeouts = np.zeros((self.max_len), dtype=np.float32) 31 | self.pos = 0 32 | self.full = False 33 | self.returns_to_go = None 34 | 35 | def add(self, obs, next_obs, action, reward, done, infos=None): 36 | self.observations[self.pos] = np.array(obs).copy() 37 | self.next_observations[self.pos] = np.array(next_obs).copy() 38 | self.actions[self.pos] = np.array(action).copy() 39 | self.rewards[self.pos] = np.array(reward).copy() 40 | self.timesteps[self.pos] = np.array(self.pos).copy() 41 | if self.handle_timeout_termination: 42 | # we assume there is only one environment --> index 0 of infos 43 | self.timeouts[self.pos] = np.array(infos[0].get("TimeLimit.truncated", False)) 44 | self.pos += 1 45 | if done: 46 | self.add_dones() 47 | if self.pos == self.max_len: 48 | self.full = True 49 | return self.full 50 | 51 | def add_full_trj(self, obs, next_obs, action, reward, done, task_id, trj_id, returns_to_go=None, trj_seed=None): 52 | self.pos = len(obs) 53 | if self.episodic: 54 | total_reward = np.sum(reward) 55 | reward = np.zeros_like(reward) 56 | reward[-1] = total_reward 57 | if self.init_trj_buffers: 58 | # buffers already, exist populate them with given trajectory 59 | self.observations[:self.pos] = obs 60 | if next_obs is not None: 61 | self.next_observations[:self.pos] = next_obs 62 | self.actions[:self.pos] = action 63 | self.rewards[:self.pos] = reward 64 | self.timesteps[:self.pos] = np.arange(0, self.pos) 65 | else: 66 | # buffers do not exist, assign using the given trajectory 67 | self.observations = obs 68 | self.next_observations = next_obs 69 | self.actions = action 70 | self.rewards = reward 71 | self.timesteps = np.arange(0, self.pos) 72 | self.timeouts = np.zeros((self.pos), dtype=np.float32) 73 | self.add_dones(is_done=done[-1]) 74 | self.full = True 75 | self.task_id = task_id 76 | self.trj_id = trj_id 77 | self.trj_seed = trj_seed 78 | self.returns_to_go = returns_to_go 79 | 80 | def sample(self, context_len=1, full_trj=False, end_idx=None): 81 | """ 82 | Samples a trajectory from the buffer. 83 | 84 | It is important to sample the end_idx first. 85 | Otherwise we may get trajectories that can't actually happen during evaluation: 86 | E.g., for context len=20, could have start_idx=5, end_idx=10 --> can't happen 87 | If we sampled end_idx first: end_idx=10 and start_idx=min(10 - context_len, 0), which could be a valid trj. 88 | 89 | Args: 90 | context_len (int, optional): The length of the context to include in the trajectory. Defaults to 1. 91 | full_trj (bool, optional): Whether to sample the full trajectory. Defaults to False. 92 | end_idx (int, optional): The index of the end of the trajectory. Defaults to None. 93 | 94 | Returns: 95 | tuple: A tuple containing the sampled trajectory. 
96 | """ 97 | if full_trj: 98 | start, end = 0, self.pos 99 | elif end_idx is not None: 100 | end = min(end_idx, self.pos) 101 | start = max(0, end - context_len) 102 | else: 103 | # start = np.random.randint(0, self.pos, size=1)[0] 104 | # end = min(start + context_len, self.pos) 105 | end = np.random.randint(1, self.pos, size=1)[0] if self.pos > 1 else 1 106 | start = max(0, end - context_len) 107 | if self.last_seq_only and start < context_len: 108 | # ensure that, the agent also has the possibility to see the first n steps of a trajectory 109 | end = np.random.randint(start + 1, min(start + context_len, self.pos)) 110 | if self.sample_full_seqs_only and (end - start) < context_len: 111 | # ensure that only full sequences are sampled 112 | residual = context_len - (end - start) 113 | if context_len > self.pos: 114 | start = 0 115 | end = self.pos 116 | elif start - residual >= 0: 117 | start -= residual 118 | elif end + residual < self.pos: 119 | end += residual 120 | return self._get_samples(start, end) 121 | 122 | def _get_samples(self, start, end): 123 | timesteps = self.timesteps[start: end] 124 | dones = self.dones[start: end] 125 | if self.relative_pos_embds: 126 | timesteps = np.arange(len(timesteps)) 127 | if self.handle_timeout_termination: 128 | dones = (dones * (1 - self.timeouts[start: end])) 129 | obs = self.observations[start: end] 130 | return obs, \ 131 | self.next_observations[start: end] if self.next_observations is not None else np.zeros_like(obs), \ 132 | self.actions[start: end], self.rewards[start: end], \ 133 | self.returns_to_go[start: end], \ 134 | timesteps, dones, self.task_id, self.trj_id, self.trj_seed 135 | 136 | def prune_trajectory(self): 137 | # to avoid OOM issues. 138 | self.observations = self.observations[:self.pos] 139 | self.next_observations = self.next_observations[:self.pos] if self.next_observations is not None else None 140 | self.actions = self.actions[:self.pos] 141 | self.rewards = self.rewards[:self.pos] 142 | self.timesteps = self.timesteps[:self.pos] 143 | self.timeouts = self.timeouts[:self.pos] 144 | 145 | def setup_final_trj(self, target_return=None, compute_stats=True): 146 | self.compute_returns_to_go(target_return=target_return) 147 | if compute_stats: 148 | self.compute_mean_reward() 149 | self.compute_std_reward() 150 | self.prune_trajectory() 151 | 152 | def compute_returns_to_go(self, target_return=None): 153 | if self.returns_to_go is not None: 154 | # was already initialized when adding full trajectory 155 | return 156 | if target_return is not None: 157 | self.returns_to_go = compute_rtg_from_target(self.rewards, target_return) 158 | self.total_return = self.rewards.sum() 159 | else: 160 | self.returns_to_go = discount_cumsum_np(self.rewards[:self.pos], 1) 161 | self.total_return = self.returns_to_go[0] 162 | 163 | def compute_mean_reward(self): 164 | self.mean_reward = self.rewards[:self.pos].mean() 165 | 166 | def compute_std_reward(self): 167 | self.std_reward = self.rewards[:self.pos].std() 168 | 169 | def add_dones(self, is_done=True): 170 | self.dones = np.zeros(self.pos) 171 | if is_done: 172 | self.dones[-1] = 1 173 | 174 | def size(self): 175 | return self.pos 176 | 177 | def __len__(self): 178 | return self.pos 179 | --------------------------------------------------------------------------------