├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── agent_params │ ├── darkroom.yaml │ ├── data_paths │ │ ├── atari.yaml │ │ ├── composuite240.yaml │ │ ├── d4rl.yaml │ │ ├── dark_keydoor_10x10.yaml │ │ ├── dark_room_10x10.yaml │ │ ├── dmcontrol11.yaml │ │ ├── mimicgen83.yaml │ │ ├── mt45_v2.yaml │ │ ├── mt45v2_dmc11.yaml │ │ ├── mt45v2_dmc11_pg12.yaml │ │ ├── mt45v2_dmc11_pg12_atari41.yaml │ │ ├── mt45v2_dmc11_pg12_atari41_cs240.yaml │ │ ├── mt45v2_dmc11_pg12_atari41_cs240_mg83.yaml │ │ ├── names │ │ │ ├── atari41.yaml │ │ │ ├── dmcontrol11.yaml │ │ │ └── mt45_v2.yaml │ │ └── procgen12.yaml │ ├── huggingface │ │ ├── dt_huge.yaml │ │ ├── dt_hugeplus.yaml │ │ ├── dt_large.yaml │ │ ├── dt_large_64.yaml │ │ ├── dt_largeplus_64.yaml │ │ ├── dt_larger.yaml │ │ ├── dt_medium.yaml │ │ ├── dt_medium_64.yaml │ │ ├── dt_mediumplus_64.yaml │ │ ├── mamba_huge.yaml │ │ ├── mamba_huge_half.yaml │ │ ├── mamba_hugeplus.yaml │ │ ├── mamba_large.yaml │ │ ├── mamba_medium.yaml │ │ ├── mamba_mediumplus.yaml │ │ ├── xlstm_huge.yaml │ │ ├── xlstm_huge_half.yaml │ │ ├── xlstm_hugeplus.yaml │ │ ├── xlstm_large.yaml │ │ ├── xlstm_large_half.yaml │ │ ├── xlstm_medium.yaml │ │ ├── xlstm_medium_half.yaml │ │ ├── xlstm_mediumplus.yaml │ │ ├── xlstm_mediumplus_half.yaml │ │ └── xlstm_ms_mediumplus.yaml │ ├── lr_sched_kwargs │ │ └── cosine.yaml │ ├── model_kwargs │ │ ├── atari.yaml │ │ ├── dark_room.yaml │ │ ├── default.yaml │ │ ├── dmcontrol.yaml │ │ ├── mt_disc.yaml │ │ ├── multi_domain.yaml │ │ └── procgen.yaml │ ├── multi_domain.yaml │ └── replay_buffer_kwargs │ │ ├── multi_domain_mtdmccs.yaml │ │ └── single_domain_disc.yaml ├── config.yaml ├── env_params │ ├── atari.yaml │ ├── atari_freeway.yaml │ ├── composuite.yaml │ ├── dark_keydoor.yaml │ ├── dark_room.yaml │ ├── dmcontrol_icl.yaml │ ├── mimicgen.yaml │ ├── mt45.yaml │ ├── mt_dmc_procgen.yaml │ ├── mt_dmc_procgen_atari.yaml │ ├── mt_dmc_procgen_atari_cs.yaml │ ├── mt_dmc_procgen_atari_cs_mg.yaml │ ├── mujoco_gym.yaml │ └── procgen.yaml ├── eval_params │ ├── base.yaml │ ├── finetune.yaml │ ├── pretrain.yaml │ └── pretrain_icl.yaml ├── run_params │ ├── base.yaml │ ├── evaluate.yaml │ ├── finetune.yaml │ └── pretrain.yaml └── wandb_callback_params │ └── pretrain.yaml ├── dmc2gym_custom ├── README.md ├── dmc2gym_custom │ ├── __init__.py │ └── wrappers.py └── setup.py ├── environment.yaml ├── evaluate.py ├── figures └── lram.png ├── main.py ├── requirements.txt └── src ├── __init__.py ├── algos ├── __init__.py ├── agent_utils.py ├── builder.py ├── decision_mamba.py ├── decision_transformer_sb3.py ├── decision_xlstm.py ├── discrete_decision_transformer_sb3.py ├── models │ ├── __init__.py │ ├── custom_critic.py │ ├── custom_dt_model.py │ ├── decision_mamba.py │ ├── decision_xlstm.py │ ├── discrete_decision_transformer_model.py │ ├── extractors.py │ ├── image_encoders.py │ ├── model_utils.py │ ├── multi_domain_discrete_dt_model.py │ ├── online_decision_transformer_model.py │ ├── rms_norm.py │ ├── rope.py │ └── token_learner.py ├── ppo_with_buffer.py └── universal_decision_transformer_sb3.py ├── augmentations ├── __init__.py └── augs.py ├── buffers ├── __init__.py ├── buffer_utils.py ├── dataloaders.py ├── multi_domain_buffer.py ├── samplers.py ├── trajectory.py ├── trajectory_buffer.py └── trajectory_dataset.py ├── callbacks ├── __init__.py ├── builder.py ├── custom_eval_callback.py ├── evaluation.py └── validation_callback.py ├── data ├── __init__.py ├── atari │ ├── README.md │ ├── download_atari_datasets.py │ └── requirements.txt ├── composuite │ ├── README.md │ └── prepare_data.py ├── data_stats_extractor.py ├── mimicgen │ ├── README.md │ └── prepare_data.py ├── parallel_copy.py ├── procgen │ └── prepare_data.py └── untar_files.sh ├── envs ├── __init__.py ├── atari_utils.py ├── builder.py ├── compatibility_wrapper.py ├── composuite_utils.py ├── cw_utils.py ├── dmcontrol_utils.py ├── dn_scores.py ├── dummy_env_utils.py ├── env_names.py ├── env_utils.py ├── hn_scores.py ├── mimicgen_utils.py ├── minihack_utils.py ├── procgen_utils.py └── target_returns.py ├── optimizers └── __init__.py ├── schedulers ├── __init__.py ├── lr_schedulers.py └── visualize_schedulers.py ├── tokenizers_custom ├── __init__.py ├── base_tokenizer.py ├── minmax_tokenizer.py └── mu_law_tokenizer.py └── utils ├── __init__.py ├── debug.py └── misc.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/.gitignore -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/.gitmodules -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/LICENSE -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/README.md -------------------------------------------------------------------------------- /configs/agent_params/darkroom.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/darkroom.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/atari.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/atari.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/composuite240.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/composuite240.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/d4rl.yaml: -------------------------------------------------------------------------------- 1 | base: ${DATA_DIR}/d4rl 2 | names: hopper-medium-v2.pkl 3 | 4 | -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_keydoor_10x10.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/dark_keydoor_10x10.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dark_room_10x10.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/dark_room_10x10.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/dmcontrol11.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/dmcontrol11.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mimicgen83.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mimicgen83.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45_v2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45_v2.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45v2_dmc11.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45v2_dmc11_pg12.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41_cs240.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41_cs240.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41_cs240_mg83.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/mt45v2_dmc11_pg12_atari41_cs240_mg83.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/atari41.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/names/atari41.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/dmcontrol11.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/names/dmcontrol11.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/names/mt45_v2.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/names/mt45_v2.yaml -------------------------------------------------------------------------------- /configs/agent_params/data_paths/procgen12.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/data_paths/procgen12.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_huge.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_huge.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_hugeplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_hugeplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_large.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_large_64.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_large_64.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_largeplus_64.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_largeplus_64.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_larger.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_larger.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_medium.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_medium_64.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_medium_64.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/dt_mediumplus_64.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/dt_mediumplus_64.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_huge.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_huge.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_huge_half.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_huge_half.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_hugeplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_hugeplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_large.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_large.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_medium.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_medium.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/mamba_mediumplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/mamba_mediumplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_huge.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_huge.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_huge_half.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_huge_half.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_hugeplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_hugeplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_large.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_large.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_large_half.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_large_half.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_medium.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_medium.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_medium_half.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_medium_half.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_mediumplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_mediumplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_mediumplus_half.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_mediumplus_half.yaml -------------------------------------------------------------------------------- /configs/agent_params/huggingface/xlstm_ms_mediumplus.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/huggingface/xlstm_ms_mediumplus.yaml -------------------------------------------------------------------------------- /configs/agent_params/lr_sched_kwargs/cosine.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/lr_sched_kwargs/cosine.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/atari.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/atari.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dark_room.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/dark_room.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/default.yaml: -------------------------------------------------------------------------------- 1 | reward_condition: True 2 | relative_pos_embds: False 3 | -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/dmcontrol.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/dmcontrol.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/mt_disc.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/mt_disc.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/multi_domain.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/multi_domain.yaml -------------------------------------------------------------------------------- /configs/agent_params/model_kwargs/procgen.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/model_kwargs/procgen.yaml -------------------------------------------------------------------------------- /configs/agent_params/multi_domain.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/multi_domain.yaml -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/multi_domain_mtdmccs.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/replay_buffer_kwargs/multi_domain_mtdmccs.yaml -------------------------------------------------------------------------------- /configs/agent_params/replay_buffer_kwargs/single_domain_disc.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/agent_params/replay_buffer_kwargs/single_domain_disc.yaml -------------------------------------------------------------------------------- /configs/config.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/config.yaml -------------------------------------------------------------------------------- /configs/env_params/atari.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/atari.yaml -------------------------------------------------------------------------------- /configs/env_params/atari_freeway.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/atari_freeway.yaml -------------------------------------------------------------------------------- /configs/env_params/composuite.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/composuite.yaml -------------------------------------------------------------------------------- /configs/env_params/dark_keydoor.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/dark_keydoor.yaml -------------------------------------------------------------------------------- /configs/env_params/dark_room.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/dark_room.yaml -------------------------------------------------------------------------------- /configs/env_params/dmcontrol_icl.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/dmcontrol_icl.yaml -------------------------------------------------------------------------------- /configs/env_params/mimicgen.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mimicgen.yaml -------------------------------------------------------------------------------- /configs/env_params/mt45.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mt45.yaml -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mt_dmc_procgen.yaml -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mt_dmc_procgen_atari.yaml -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari_cs.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mt_dmc_procgen_atari_cs.yaml -------------------------------------------------------------------------------- /configs/env_params/mt_dmc_procgen_atari_cs_mg.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mt_dmc_procgen_atari_cs_mg.yaml -------------------------------------------------------------------------------- /configs/env_params/mujoco_gym.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/mujoco_gym.yaml -------------------------------------------------------------------------------- /configs/env_params/procgen.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/env_params/procgen.yaml -------------------------------------------------------------------------------- /configs/eval_params/base.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/eval_params/base.yaml -------------------------------------------------------------------------------- /configs/eval_params/finetune.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/eval_params/finetune.yaml -------------------------------------------------------------------------------- /configs/eval_params/pretrain.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/eval_params/pretrain.yaml -------------------------------------------------------------------------------- /configs/eval_params/pretrain_icl.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/eval_params/pretrain_icl.yaml -------------------------------------------------------------------------------- /configs/run_params/base.yaml: -------------------------------------------------------------------------------- 1 | total_timesteps: 1e6 2 | log_interval: 10 -------------------------------------------------------------------------------- /configs/run_params/evaluate.yaml: -------------------------------------------------------------------------------- 1 | log_interval: 1 2 | total_timesteps: 0 3 | -------------------------------------------------------------------------------- /configs/run_params/finetune.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/run_params/finetune.yaml -------------------------------------------------------------------------------- /configs/run_params/pretrain.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/run_params/pretrain.yaml -------------------------------------------------------------------------------- /configs/wandb_callback_params/pretrain.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/configs/wandb_callback_params/pretrain.yaml -------------------------------------------------------------------------------- /dmc2gym_custom/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/dmc2gym_custom/README.md -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/dmc2gym_custom/dmc2gym_custom/__init__.py -------------------------------------------------------------------------------- /dmc2gym_custom/dmc2gym_custom/wrappers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/dmc2gym_custom/dmc2gym_custom/wrappers.py -------------------------------------------------------------------------------- /dmc2gym_custom/setup.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/dmc2gym_custom/setup.py -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/environment.yaml -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/evaluate.py -------------------------------------------------------------------------------- /figures/lram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/figures/lram.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/main.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/requirements.txt -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/__init__.py -------------------------------------------------------------------------------- /src/algos/agent_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/agent_utils.py -------------------------------------------------------------------------------- /src/algos/builder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/builder.py -------------------------------------------------------------------------------- /src/algos/decision_mamba.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/decision_mamba.py -------------------------------------------------------------------------------- /src/algos/decision_transformer_sb3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/decision_transformer_sb3.py -------------------------------------------------------------------------------- /src/algos/decision_xlstm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/decision_xlstm.py -------------------------------------------------------------------------------- /src/algos/discrete_decision_transformer_sb3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/discrete_decision_transformer_sb3.py -------------------------------------------------------------------------------- /src/algos/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/__init__.py -------------------------------------------------------------------------------- /src/algos/models/custom_critic.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/custom_critic.py -------------------------------------------------------------------------------- /src/algos/models/custom_dt_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/custom_dt_model.py -------------------------------------------------------------------------------- /src/algos/models/decision_mamba.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/decision_mamba.py -------------------------------------------------------------------------------- /src/algos/models/decision_xlstm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/decision_xlstm.py -------------------------------------------------------------------------------- /src/algos/models/discrete_decision_transformer_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/discrete_decision_transformer_model.py -------------------------------------------------------------------------------- /src/algos/models/extractors.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/extractors.py -------------------------------------------------------------------------------- /src/algos/models/image_encoders.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/image_encoders.py -------------------------------------------------------------------------------- /src/algos/models/model_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/model_utils.py -------------------------------------------------------------------------------- /src/algos/models/multi_domain_discrete_dt_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/multi_domain_discrete_dt_model.py -------------------------------------------------------------------------------- /src/algos/models/online_decision_transformer_model.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/online_decision_transformer_model.py -------------------------------------------------------------------------------- /src/algos/models/rms_norm.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/rms_norm.py -------------------------------------------------------------------------------- /src/algos/models/rope.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/rope.py -------------------------------------------------------------------------------- /src/algos/models/token_learner.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/models/token_learner.py -------------------------------------------------------------------------------- /src/algos/ppo_with_buffer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/ppo_with_buffer.py -------------------------------------------------------------------------------- /src/algos/universal_decision_transformer_sb3.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/algos/universal_decision_transformer_sb3.py -------------------------------------------------------------------------------- /src/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from .augs import make_augmentations -------------------------------------------------------------------------------- /src/augmentations/augs.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/augmentations/augs.py -------------------------------------------------------------------------------- /src/buffers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/__init__.py -------------------------------------------------------------------------------- /src/buffers/buffer_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/buffer_utils.py -------------------------------------------------------------------------------- /src/buffers/dataloaders.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/dataloaders.py -------------------------------------------------------------------------------- /src/buffers/multi_domain_buffer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/multi_domain_buffer.py -------------------------------------------------------------------------------- /src/buffers/samplers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/samplers.py -------------------------------------------------------------------------------- /src/buffers/trajectory.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/trajectory.py -------------------------------------------------------------------------------- /src/buffers/trajectory_buffer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/trajectory_buffer.py -------------------------------------------------------------------------------- /src/buffers/trajectory_dataset.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/buffers/trajectory_dataset.py -------------------------------------------------------------------------------- /src/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_callbacks 2 | -------------------------------------------------------------------------------- /src/callbacks/builder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/callbacks/builder.py -------------------------------------------------------------------------------- /src/callbacks/custom_eval_callback.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/callbacks/custom_eval_callback.py -------------------------------------------------------------------------------- /src/callbacks/evaluation.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/callbacks/evaluation.py -------------------------------------------------------------------------------- /src/callbacks/validation_callback.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/callbacks/validation_callback.py -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/data/atari/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/atari/README.md -------------------------------------------------------------------------------- /src/data/atari/download_atari_datasets.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/atari/download_atari_datasets.py -------------------------------------------------------------------------------- /src/data/atari/requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/atari/requirements.txt -------------------------------------------------------------------------------- /src/data/composuite/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/composuite/README.md -------------------------------------------------------------------------------- /src/data/composuite/prepare_data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/composuite/prepare_data.py -------------------------------------------------------------------------------- /src/data/data_stats_extractor.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/data_stats_extractor.py -------------------------------------------------------------------------------- /src/data/mimicgen/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/mimicgen/README.md -------------------------------------------------------------------------------- /src/data/mimicgen/prepare_data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/mimicgen/prepare_data.py -------------------------------------------------------------------------------- /src/data/parallel_copy.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/parallel_copy.py -------------------------------------------------------------------------------- /src/data/procgen/prepare_data.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/procgen/prepare_data.py -------------------------------------------------------------------------------- /src/data/untar_files.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/data/untar_files.sh -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import make_env 2 | -------------------------------------------------------------------------------- /src/envs/atari_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/atari_utils.py -------------------------------------------------------------------------------- /src/envs/builder.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/builder.py -------------------------------------------------------------------------------- /src/envs/compatibility_wrapper.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/compatibility_wrapper.py -------------------------------------------------------------------------------- /src/envs/composuite_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/composuite_utils.py -------------------------------------------------------------------------------- /src/envs/cw_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/cw_utils.py -------------------------------------------------------------------------------- /src/envs/dmcontrol_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/dmcontrol_utils.py -------------------------------------------------------------------------------- /src/envs/dn_scores.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/dn_scores.py -------------------------------------------------------------------------------- /src/envs/dummy_env_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/dummy_env_utils.py -------------------------------------------------------------------------------- /src/envs/env_names.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/env_names.py -------------------------------------------------------------------------------- /src/envs/env_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/env_utils.py -------------------------------------------------------------------------------- /src/envs/hn_scores.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/hn_scores.py -------------------------------------------------------------------------------- /src/envs/mimicgen_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/mimicgen_utils.py -------------------------------------------------------------------------------- /src/envs/minihack_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/minihack_utils.py -------------------------------------------------------------------------------- /src/envs/procgen_utils.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/procgen_utils.py -------------------------------------------------------------------------------- /src/envs/target_returns.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/envs/target_returns.py -------------------------------------------------------------------------------- /src/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/optimizers/__init__.py -------------------------------------------------------------------------------- /src/schedulers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/schedulers/__init__.py -------------------------------------------------------------------------------- /src/schedulers/lr_schedulers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/schedulers/lr_schedulers.py -------------------------------------------------------------------------------- /src/schedulers/visualize_schedulers.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/schedulers/visualize_schedulers.py -------------------------------------------------------------------------------- /src/tokenizers_custom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/tokenizers_custom/__init__.py -------------------------------------------------------------------------------- /src/tokenizers_custom/base_tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/tokenizers_custom/base_tokenizer.py -------------------------------------------------------------------------------- /src/tokenizers_custom/minmax_tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/tokenizers_custom/minmax_tokenizer.py -------------------------------------------------------------------------------- /src/tokenizers_custom/mu_law_tokenizer.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/tokenizers_custom/mu_law_tokenizer.py -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import maybe_split, safe_mean, multiply 2 | -------------------------------------------------------------------------------- /src/utils/debug.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/utils/debug.py -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ml-jku/LRAM/HEAD/src/utils/misc.py --------------------------------------------------------------------------------