├── relic ├── __init__.py ├── configs │ ├── __init__.py │ ├── baseline │ │ ├── relic_HSSD.yaml │ │ ├── baselines_comparison │ │ │ ├── relic_HSSD.yaml │ │ │ ├── relic_HSSD_no_iea.yaml │ │ │ ├── single_episode_HSSD.yaml │ │ │ ├── trxl_HSSD.yaml │ │ │ └── rl2.yaml │ │ ├── relic_HSSD_64k.yaml │ │ ├── relic_replicaCAD.yaml │ │ ├── partial_updates_ablation │ │ │ ├── relic_replicaCAD.yaml │ │ │ └── relic_replicaCAD_no_partial_updates.yaml │ │ ├── sink_ablation │ │ │ ├── relic_replicaCAD_sink_kv.yaml │ │ │ ├── relic_replicaCAD_no_sink.yaml │ │ │ ├── relic_replicaCAD_sink_k0v.yaml │ │ │ ├── relic_replicaCAD_sink_k0v0.yaml │ │ │ ├── relic_replicaCAD_sink_kv0.yaml │ │ │ └── relic_replicaCAD_sink_token.yaml │ │ ├── custom_task │ │ │ ├── darkroom_relic.yaml │ │ │ └── darkroom_single_episode.yaml │ │ └── relic_base.yaml │ ├── backbone │ │ ├── vc1.yaml │ │ ├── blind.yaml │ │ ├── resnet18.yaml │ │ └── vc1_smallObjs.yaml │ ├── policy │ │ ├── lstm.yaml │ │ ├── transformer_tiny.yaml │ │ ├── transformer_large.yaml │ │ └── transformer_small.yaml │ ├── vc1 │ │ └── vc1_vitb_ft_cls_e15.yaml │ ├── tasks │ │ ├── ExtObjNav.yaml │ │ └── ExtObjNav_replicaCAD.yaml │ └── pddl │ │ ├── hssd_domain.yaml │ │ └── domain.yaml ├── tasks │ ├── __init__.py │ ├── utils.py │ ├── sensors.py │ └── actions.py ├── policies │ ├── __init__.py │ ├── visual_encoders.py │ ├── llamarl │ │ └── configuration_llamarl.py │ ├── transformer_wrappers.py │ └── transformerxl │ │ └── modeling_transformerxl.py ├── evaluator │ ├── __init__.py │ ├── process_eval_data.py │ └── habitat_evalutor.py ├── trainer │ ├── __init__.py │ ├── envs.py │ ├── rl2_storage.py │ ├── transformers_agent_access_mgr.py │ └── datasets.py ├── download_datasets.sh ├── run.py ├── run.sh ├── envs │ ├── custom_envs │ │ ├── darkroom.py │ │ └── example.py │ ├── train_env_factory.py │ └── train_il_env_factory.py └── monkey_patch_eai_vc.py ├── .pre-commit-config.yaml ├── setup.py ├── requirements.txt ├── .gitignore ├── launch.yaml └── README.md /relic/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /relic/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # nothing 2 | -------------------------------------------------------------------------------- /relic/configs/baseline/relic_HSSD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - relic_base 5 | - /tasks: ExtObjNav 6 | - _self_ 7 | -------------------------------------------------------------------------------- /relic/configs/baseline/baselines_comparison/relic_HSSD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_HSSD 5 | - _self_ 6 | -------------------------------------------------------------------------------- /relic/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | from .measures import PddlTaskSuccess 2 | from .pddl_multi_task import PddlMultiTask 3 | from .sensors import OneHotTargetSensor 4 | from .actions import * 5 | -------------------------------------------------------------------------------- /relic/configs/backbone/vc1.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | habitat_baselines: 4 | force_blind_policy: False 5 | rl: 6 | 
ddppo: 7 | train_encoder: False 8 | backbone: vc1 9 | -------------------------------------------------------------------------------- /relic/configs/backbone/blind.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | habitat_baselines: 4 | force_blind_policy: True 5 | rl: 6 | ddppo: 7 | train_encoder: True 8 | backbone: resnet18 9 | -------------------------------------------------------------------------------- /relic/configs/backbone/resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | habitat_baselines: 4 | force_blind_policy: False 5 | rl: 6 | ddppo: 7 | train_encoder: True 8 | backbone: resnet18 9 | -------------------------------------------------------------------------------- /relic/policies/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointnav import PointNavResNetTransformerPolicy 2 | from .pointnav_lstm import PointNavResNetLstmPolicy 3 | 4 | __all__ = ["PointNavResNetTransformerPolicy", "PointNavResNetLstmPolicy"] 5 | -------------------------------------------------------------------------------- /relic/configs/backbone/vc1_smallObjs.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | habitat_baselines: 4 | force_blind_policy: False 5 | rl: 6 | ddppo: 7 | train_encoder: False 8 | backbone: vc1_configs/vc1/vc1_vitb_ft_cls_e15.yaml 9 | -------------------------------------------------------------------------------- /relic/configs/baseline/relic_HSSD_64k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_HSSD 5 | - _self_ 6 | 7 | 8 | habitat_baselines: 9 | rl: 10 | ppo: 11 | updates_per_rollout: 256 12 | num_steps: 65536 13 | -------------------------------------------------------------------------------- /relic/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from .repEps_evaluator import TransformersRepEpsHabitatEvaluator 2 | from .habitat_evalutor import TransformersHabitatEvaluator 3 | from .rl2_evaluator import RL2Evaluator 4 | from .process_eval_data import read_csvs, extract_episodes_data_from_df 5 | -------------------------------------------------------------------------------- /relic/configs/baseline/relic_replicaCAD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - relic_base 5 | - /tasks: ExtObjNav_replicaCAD 6 | - _self_ 7 | 8 | habitat_baselines: 9 | rl: 10 | ppo: 11 | full_updates_per_rollout: 1 12 | updates_per_rollout: 4 13 | -------------------------------------------------------------------------------- /relic/configs/baseline/partial_updates_ablation/relic_replicaCAD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | ppo: 10 | full_updates_per_rollout: 1 11 | updates_per_rollout: 4 12 | -------------------------------------------------------------------------------- /relic/configs/baseline/partial_updates_ablation/relic_replicaCAD_no_partial_updates.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - 
/baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | ppo: 10 | full_updates_per_rollout: 1 11 | updates_per_rollout: 1 12 | -------------------------------------------------------------------------------- /relic/configs/policy/lstm.yaml: -------------------------------------------------------------------------------- 1 | name: PointNavResNetLstmPolicy 2 | action_distribution_type: categorical 3 | action_dist: 4 | use_log_std: True 5 | use_std_param: True 6 | std_init: -1 7 | vc1_config: 8 | is_2d_output: False 9 | avg_pool_size: 2 10 | training_precision_config: 11 | visual_encoder: float16 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 22.10.0 10 | hooks: 11 | - id: black 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | with open("requirements.txt") as f: 4 | REQUIREMENTS = [x.strip() for x in f.readlines()] 5 | 6 | setup( 7 | name="relic", 8 | packages=["relic"], 9 | install_requires=REQUIREMENTS, 10 | extras_require={"dev": ["pre-commit"]}, 11 | version="0.1", 12 | ) 13 | -------------------------------------------------------------------------------- /relic/configs/baseline/baselines_comparison/relic_HSSD_no_iea.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_HSSD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | inter_episodes_attention: False 13 | reset_position_index: True 14 | add_sequence_idx_embed: False 15 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_sink_kv.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: False 13 | add_sink_kv: True 14 | is_sink_v_trainable: True 15 | is_sink_k_trainable: True 16 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_no_sink.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: False 13 | add_sink_kv: False 14 | is_sink_v_trainable: False 15 | is_sink_k_trainable: False 16 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_sink_k0v.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: 
False 13 | add_sink_kv: True 14 | is_sink_v_trainable: True 15 | is_sink_k_trainable: False 16 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_sink_k0v0.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: False 13 | add_sink_kv: True 14 | is_sink_v_trainable: False 15 | is_sink_k_trainable: False 16 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_sink_kv0.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: False 13 | add_sink_kv: True 14 | is_sink_v_trainable: False 15 | is_sink_k_trainable: True 16 | -------------------------------------------------------------------------------- /relic/configs/baseline/sink_ablation/relic_replicaCAD_sink_token.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_replicaCAD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | rl: 9 | policy: 10 | main_agent: 11 | transformer_config: 12 | add_sink_tokens: True 13 | add_sink_kv: False 14 | is_sink_v_trainable: False 15 | is_sink_k_trainable: False 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.35.0 2 | einops==0.7.0 3 | accelerate==0.24.1 4 | torch~=2.1.0 5 | pandas 6 | flash-attn==2.4.1 7 | vc_models @ git+https://github.com/facebookresearch/eai-vc.git@76fe35e87b1937168f1ec4b236e863451883eaf3#subdirectory=vc_models 8 | habitat-lab @ git+https://github.com/facebookresearch/habitat-lab.git@v0.3.0#subdirectory=habitat-lab 9 | habitat-baselines @ git+https://github.com/facebookresearch/habitat-lab.git@v0.3.0#subdirectory=habitat-baselines 10 | -------------------------------------------------------------------------------- /relic/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import ( 2 | RearrangeDatasetTransformersV0, 3 | ) 4 | from .envs import CustomGymHabitatEnv 5 | from .transformer_ppo import TransformerPPO 6 | from .transformer_storage import ( 7 | TransformerRolloutStorage, 8 | ) 9 | from .rl2_storage import RL2RolloutStorage 10 | from .transformers_agent_access_mgr import ( 11 | TransformerSingleAgentAccessMgr, 12 | ) 13 | from .transformers_trainer import ( 14 | TransformersTrainer, 15 | ) 16 | 17 | from .rl2_trainer import ( 18 | RL2Trainer, 19 | ) 20 | -------------------------------------------------------------------------------- /relic/configs/vc1/vc1_vitb_ft_cls_e15.yaml: -------------------------------------------------------------------------------- 1 | _target_: vc_models.models.load_model 2 | model: 3 | _target_: vc_models.models.vit.vit.load_mae_encoder 4 | checkpoint_path: model_ckpts/cls/full_ft_e15_l0.463_ac0.895.pth 5 | model: 6 | _target_: vc_models.models.vit.vit.vit_base_patch16 7 | img_size: 224 8 | use_cls: True 9 | 
drop_path_rate: 0.0 10 | transform: 11 | _target_: vc_models.transforms.vit_transforms 12 | metadata: 13 | algo: mae 14 | model: vit_base_patch16 15 | data: 16 | - ego 17 | - imagenet 18 | - inav 19 | comment: 182_epochs 20 | -------------------------------------------------------------------------------- /relic/configs/baseline/baselines_comparison/single_episode_HSSD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_HSSD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | reset_envs_after_update: False 9 | call_after_update_env: True 10 | rl: 11 | policy: 12 | main_agent: 13 | transformer_config: 14 | inter_episodes_attention: False 15 | reset_position_index: True 16 | add_sequence_idx_embed: False 17 | context_len: 512 18 | ppo: 19 | num_steps: 256 20 | full_updates_per_rollout: 0 21 | updates_per_rollout: 1 22 | shuffle_old_episodes: False 23 | -------------------------------------------------------------------------------- /relic/configs/baseline/baselines_comparison/trxl_HSSD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_HSSD 5 | - _self_ 6 | 7 | habitat_baselines: 8 | reset_envs_after_update: False 9 | call_after_update_env: True 10 | rl: 11 | policy: 12 | main_agent: 13 | transformer_config: 14 | model_name: transformerxl 15 | inter_episodes_attention: False 16 | reset_position_index: True 17 | add_sequence_idx_embed: False 18 | context_len: 0 19 | mem_len: 256 20 | ppo: 21 | num_steps: 256 22 | full_updates_per_rollout: 1 23 | updates_per_rollout: 1 24 | shuffle_old_episodes: False 25 | force_env_reset_every: 4096 26 | shift_scene_every: 4096 27 | -------------------------------------------------------------------------------- /relic/configs/policy/transformer_tiny.yaml: -------------------------------------------------------------------------------- 1 | name: PointNavResNetTransformerPolicy 2 | action_distribution_type: categorical 3 | transformer_config: 4 | model_name: "llamarl" 5 | n_layers: 2 6 | n_heads: 8 7 | n_hidden: 64 8 | n_mlp_hidden: 256 9 | kv_size: 32 10 | activation: "gelu_new" 11 | inter_episodes_attention: True 12 | reset_position_index: True 13 | add_sequence_idx_embed: True 14 | position_embed_type: learnable 15 | sequence_embed_type: rope 16 | gated_residual: False 17 | context_len: 0 18 | banded_attention: False 19 | orphan_steps_attention: True 20 | add_context_loss: False 21 | depth_dropout_p: 0.0 22 | max_position_embeddings: 2048 23 | add_sink_kv: True 24 | mul_factor_for_sink_attn: False 25 | is_sink_v_trainable: False 26 | vc1_config: 27 | is_2d_output: False 28 | avg_pool_size: 2 29 | training_precision_config: 30 | visual_encoder: float16 31 | -------------------------------------------------------------------------------- /relic/configs/policy/transformer_large.yaml: -------------------------------------------------------------------------------- 1 | name: PointNavResNetTransformerPolicy 2 | action_distribution_type: categorical 3 | transformer_config: 4 | model_name: "llamarl" 5 | n_layers: 8 6 | n_heads: 24 7 | n_hidden: 768 8 | n_mlp_hidden: 3072 9 | kv_size: 32 10 | activation: "gelu_new" 11 | inter_episodes_attention: True 12 | reset_position_index: True 13 | add_sequence_idx_embed: True 14 | position_embed_type: learnable 15 | sequence_embed_type: rope 16 | gated_residual: False 17 | context_len: 0 18 | banded_attention: False 19 | 
orphan_steps_attention: True 20 | add_context_loss: False 21 | depth_dropout_p: 0.1 22 | max_position_embeddings: 32768 23 | add_sink_kv: True 24 | mul_factor_for_sink_attn: False 25 | is_sink_v_trainable: False 26 | vc1_config: 27 | is_2d_output: False 28 | avg_pool_size: 2 29 | training_precision_config: 30 | visual_encoder: float16 31 | -------------------------------------------------------------------------------- /relic/configs/policy/transformer_small.yaml: -------------------------------------------------------------------------------- 1 | name: PointNavResNetTransformerPolicy 2 | action_distribution_type: categorical 3 | transformer_config: 4 | model_name: "llamarl" 5 | n_layers: 4 6 | n_heads: 8 7 | n_hidden: 256 8 | n_mlp_hidden: 1024 9 | kv_size: 32 10 | activation: "gelu_new" 11 | inter_episodes_attention: True 12 | reset_position_index: True 13 | add_sequence_idx_embed: True 14 | position_embed_type: learnable 15 | sequence_embed_type: rope 16 | gated_residual: False 17 | context_len: 0 18 | banded_attention: False 19 | orphan_steps_attention: True 20 | add_context_loss: False 21 | depth_dropout_p: 0.1 22 | max_position_embeddings: 32768 23 | add_sink_kv: True 24 | mul_factor_for_sink_attn: False 25 | is_sink_v_trainable: False 26 | vc1_config: 27 | is_2d_output: False 28 | avg_pool_size: 2 29 | training_precision_config: 30 | visual_encoder: float16 31 | -------------------------------------------------------------------------------- /relic/configs/baseline/custom_task/darkroom_relic.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_base 5 | - /tasks: ExtObjNav_replicaCAD 6 | - override /backbone: blind 7 | - override /policy@habitat_baselines.rl.policy.main_agent: transformer_tiny 8 | - _self_ 9 | 10 | habitat: 11 | task: 12 | make_env_fn: "relic.envs.custom_envs.example.make_env" 13 | 14 | habitat_baselines: 15 | num_environments: 8 16 | total_num_steps: 5e6 17 | num_checkpoints: 10 18 | rl: 19 | policy: 20 | main_agent: 21 | transformer_config: 22 | is_sink_v_trainable: True 23 | is_sink_k_trainable: True 24 | ppo: 25 | num_mini_batch: 8 26 | grad_accum_mini_batches: 4 27 | entropy_coef: 0.01 28 | value_loss_coef: 0.5 29 | updates_per_rollout: 16 30 | num_steps: 4096 31 | full_updates_per_rollout: 1 32 | lr: 2e-4 33 | warmup_total_iters: 10000 34 | lrsched_T_max: 5000000 # Typically the same value as habitat_baselines.total_num_steps. 35 | -------------------------------------------------------------------------------- /relic/download_datasets.sh: -------------------------------------------------------------------------------- 1 | # ReplicaCAD scenes 2 | echo "Downloading ReplicaCAD scenes..." 3 | python -m habitat_sim.utils.datasets_download --uids rearrange_task_assets 4 | 5 | # HSSD scenes 6 | if [ -d "data/scene_datasets/hssd-hab" ]; then 7 | echo "HSSD exists. Skipping HSSD download." 8 | else 9 | echo "Downloading HSSD scenes..." 10 | mkdir -p data/scene_datasets/ 11 | pushd data/scene_datasets/ 12 | 13 | GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/datasets/hssd/hssd-hab 14 | cd hssd-hab 15 | git checkout 34b5f48c343cd7eb65d46e45402c85a004d77c92 16 | git lfs pull 17 | git lfs prune -f 18 | 19 | popd 20 | fi 21 | 22 | # ExtObjNav 23 | echo "Downloading ExtObjNav..." 
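# Note (added): the commands below use the `huggingface-cli` tool that ships with
# the `huggingface_hub` package; install it first (e.g. `pip install -U huggingface_hub`) if it is missing.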
24 | huggingface-cli download aielawady/ExtObjNav --local-dir-use-symlinks False --repo-type dataset --local-dir data/datasets/ExtObjNav_HSSD 25 | huggingface-cli download aielawady/ExtObjNav_ReplicaCAD --local-dir-use-symlinks False --repo-type dataset --local-dir data/datasets/ExtObjNav_replicaCAD 26 | 27 | # VC1 finetuned 28 | echo "Downloading VC1 finetuned checkpoint..." 29 | huggingface-cli download aielawady/vc1-smallObj --local-dir-use-symlinks False --local-dir model_ckpts/cls 30 | -------------------------------------------------------------------------------- /relic/configs/baseline/custom_task/darkroom_single_episode.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /baseline/relic_base 5 | - /tasks: ExtObjNav_replicaCAD 6 | - override /backbone: blind 7 | - override /policy@habitat_baselines.rl.policy.main_agent: transformer_tiny 8 | - _self_ 9 | 10 | habitat: 11 | task: 12 | make_env_fn: "relic.envs.custom_envs.example.make_env" 13 | 14 | habitat_baselines: 15 | num_environments: 8 16 | total_num_steps: 5e6 17 | num_checkpoints: 10 18 | rl: 19 | policy: 20 | main_agent: 21 | transformer_config: 22 | is_sink_v_trainable: True 23 | is_sink_k_trainable: True 24 | inter_episodes_attention: False 25 | reset_position_index: True 26 | add_sequence_idx_embed: False 27 | context_len: 128 # This should be larger than the max number of steps in the task. 28 | ppo: 29 | num_mini_batch: 8 30 | grad_accum_mini_batches: 4 31 | entropy_coef: 0.01 32 | value_loss_coef: 0.5 33 | updates_per_rollout: 1 34 | num_steps: 256 # This controls the frequency of policy updates. 35 | full_updates_per_rollout: 0 36 | lr: 2e-4 37 | warmup_total_iters: 10000 38 | lrsched_T_max: 5000000 # Typically the same value as habitat_baselines.total_num_steps. 39 | shuffle_old_episodes: False 40 | -------------------------------------------------------------------------------- /relic/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to launch the Habitat Baselines trainer. 
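Typical invocation (mirroring run.sh; the config name and override shown here are illustrative):

    python run.py --config-name baseline/relic_HSSD \
        habitat_baselines.checkpoint_folder=exp_data/checkpoints/relic/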
3 | """ 4 | 5 | import random 6 | 7 | import gym 8 | import hydra 9 | import numpy as np 10 | import torch 11 | from habitat.config.default import patch_config 12 | from habitat.config.default_structured_configs import register_hydra_plugin 13 | from habitat_baselines.common.baseline_registry import baseline_registry 14 | from habitat_baselines.config.default_structured_configs import ( 15 | HabitatBaselinesConfigPlugin, 16 | ) 17 | 18 | import relic.policies 19 | import relic.tasks 20 | import relic.trainer 21 | from relic import default_structured_configs 22 | 23 | from relic import monkey_patch_eai_vc 24 | 25 | gym.logger.set_level(40) 26 | 27 | 28 | @hydra.main( 29 | version_base=None, 30 | config_path="configs", 31 | config_name="pointnav/ppo_pointnav_example", 32 | ) 33 | def main(cfg): 34 | cfg = patch_config(cfg) 35 | random.seed(cfg.habitat.seed) 36 | np.random.seed(cfg.habitat.seed) 37 | torch.manual_seed(cfg.habitat.seed) 38 | 39 | if cfg.habitat_baselines.force_torch_single_threaded and torch.cuda.is_available(): 40 | torch.set_num_threads(1) 41 | 42 | trainer_init = baseline_registry.get_trainer(cfg.habitat_baselines.trainer_name) 43 | trainer = trainer_init(cfg) 44 | 45 | if cfg.habitat_baselines.evaluate: 46 | trainer.eval() 47 | else: 48 | trainer.train() 49 | 50 | 51 | if __name__ == "__main__": 52 | register_hydra_plugin(HabitatBaselinesConfigPlugin) 53 | main() 54 | -------------------------------------------------------------------------------- /relic/run.sh: -------------------------------------------------------------------------------- 1 | # Change these values 2 | 3 | CONFIG_NAME=baseline/relic_HSSD 4 | 5 | WB_ENTITY= 6 | WB_JOB_ID="relic_ExtObjNav" 7 | WB_PROJECT_NAME="relic" 8 | 9 | DATA_DIR="exp_data" 10 | 11 | N_GPUS=1 12 | 13 | ################################ EVAL PARAMS ################################## 14 | 15 | 16 | EVAL_N_STEPS=8200 17 | EVAL_N_DEMOS=-1 18 | EVAL_FIX_TARGET_IN_TRIAL=False 19 | EVAL_MAX_NUM_START_POS=-1 20 | EVAL_MAX_NUM_EPISODES=-1 21 | CKPT_NAME=latest.pth # Or ckpt.#.pth where # is the checkpoint number 22 | 23 | ############################################################################### 24 | CHECKPOINT_FOLDER=$DATA_DIR/checkpoints/$WB_PROJECT_NAME/$WB_JOB_ID/ 25 | VIDEO_DIR=$DATA_DIR/vids/$WB_PROJECT_NAME/$WB_JOB_ID/ 26 | LOG_FILE=$DATA_DIR/logs/$WB_PROJECT_NAME/$WB_JOB_ID.log 27 | EVAL_DATA_DIR=$DATA_DIR/eval_data/$WB_PROJECT_NAME/$WB_JOB_ID/$CKPT_NAME/ 28 | 29 | IS_EVAL=${1:-0} 30 | 31 | if [ "$IS_EVAL" = "--eval" ] ; then 32 | echo "here $IS_EVAL" 33 | EVAL_ARGS=( 34 | habitat_baselines.eval_ckpt_path_dir=$DATA_DIR/checkpoints/$WB_PROJECT_NAME/$WB_JOB_ID/$CKPT_NAME \ 35 | habitat_baselines.evaluate=True 36 | habitat_baselines.eval.video_option="[]" 37 | habitat_baselines.writer_type="tb" 38 | +habitat_baselines.evaluation_config.n_steps=$EVAL_N_STEPS 39 | +habitat_baselines.evaluation_config.n_demos=$EVAL_N_DEMOS 40 | +habitat_baselines.evaluation_config.fix_target_same_episode=$EVAL_FIX_TARGET_IN_TRIAL 41 | +habitat_baselines.evaluation_config.max_num_start_pos=$EVAL_MAX_NUM_START_POS 42 | +habitat_baselines.evaluation_config.max_n_eps=$EVAL_MAX_NUM_EPISODES 43 | ) 44 | 45 | else 46 | EVAL_ARGS=() 47 | fi 48 | echo ${EVAL_ARGS[@]} 49 | 50 | ############################################################################### 51 | 52 | export GLOG_minloglevel=2 53 | export MAGNUM_LOG=quiet 54 | export HABITAT_SIM_LOG=quiet 55 | 56 | set -x 57 | 58 | python run.py --config-name $CONFIG_NAME \ 59 | habitat_baselines.wb.entity=$WB_ENTITY 
\ 60 | habitat_baselines.wb.run_name=$WB_JOB_ID \ 61 | habitat_baselines.wb.project_name=$WB_PROJECT_NAME \ 62 | habitat_baselines.checkpoint_folder=$CHECKPOINT_FOLDER \ 63 | habitat_baselines.video_dir=$VIDEO_DIR \ 64 | habitat_baselines.log_file=$LOG_FILE \ 65 | habitat_baselines.eval_data_dir=$EVAL_DATA_DIR \ 66 | habitat_baselines.writer_type=wb \ 67 | ${EVAL_ARGS[@]} 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | rdevon/utils_package/build 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | project_diffusion_rl/.DS_Store 132 | .DS_Store 133 | 134 | data/ 135 | 136 | *.log 137 | *.err 138 | *.out 139 | scripts 140 | exp_data 141 | outputs 142 | wandb 143 | -------------------------------------------------------------------------------- /relic/envs/custom_envs/darkroom.py: -------------------------------------------------------------------------------- 1 | # The code from https://github.com/jon--lee/decision-pretrained-transformer/blob/15172df758a04bb112fa949dd62dcb9765780ec0/envs/darkroom_env.py#L12-L82 2 | # With change in reset method to accept `goal` argument. 3 | 4 | import gym 5 | import numpy as np 6 | 7 | 8 | class DarkroomEnv(gym.Env): 9 | def __init__(self, dim, goal, horizon): 10 | self.dim = dim 11 | self.goal = np.array(goal) 12 | self.horizon = horizon 13 | self.state_dim = 2 14 | self.action_dim = 5 15 | self.observation_space = gym.spaces.Box( 16 | low=0, high=dim - 1, shape=(self.state_dim,) 17 | ) 18 | self.action_space = gym.spaces.Discrete(self.action_dim) 19 | 20 | def sample_state(self): 21 | return np.random.randint(0, self.dim, 2) 22 | 23 | def sample_action(self): 24 | i = np.random.randint(0, 5) 25 | a = np.zeros(self.action_space.n) 26 | a[i] = 1 27 | return a 28 | 29 | def reset(self, goal=None): 30 | if goal is not None: 31 | self.goal = goal 32 | 33 | self.current_step = 0 34 | self.state = np.array([0, 0]) 35 | return self.state 36 | 37 | def transit(self, state, action): 38 | action = np.argmax(action) 39 | assert action in np.arange(self.action_space.n) 40 | state = np.array(state) 41 | if action == 0: 42 | state[0] += 1 43 | elif action == 1: 44 | state[0] -= 1 45 | elif action == 2: 46 | state[1] += 1 47 | elif action == 3: 48 | state[1] -= 1 49 | state = np.clip(state, 0, self.dim - 1) 50 | 51 | if np.all(state == self.goal): 52 | reward = 1 53 | else: 54 | reward = 0 55 | return state, reward 56 | 57 | def step(self, action): 58 | if self.current_step >= self.horizon: 59 | raise ValueError("Episode has already ended") 60 | 61 | self.state, r = self.transit(self.state, action) 62 | self.current_step += 1 63 | done = self.current_step >= self.horizon 64 | return self.state.copy(), r, done, {} 65 | 66 | def get_obs(self): 67 | return self.state.copy() 68 | 69 | def opt_action(self, state): 70 | if state[0] < self.goal[0]: 71 | action = 0 72 | elif state[0] > self.goal[0]: 73 | action = 1 74 | elif state[1] < self.goal[1]: 75 | action = 2 76 | elif state[1] > self.goal[1]: 77 | action = 3 78 | else: 79 | action = 4 80 | zeros = np.zeros(self.action_space.n) 81 | zeros[action] = 1 82 | return zeros 83 | -------------------------------------------------------------------------------- /relic/evaluator/process_eval_data.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | from tqdm.auto import tqdm 4 | import pandas as pd 5 | import numpy as np 6 | from collections import defaultdict 7 | 
import re 8 | import multiprocessing as mp 9 | from functools import partial 10 | 11 | 12 | def get_1st_index(files, rank): 13 | last_ch_ind = 100000000 14 | output_ind = len(files) 15 | for i, f in enumerate(files): 16 | if f"rank{rank}" in f and int(f.split("/")[-1].split("_")[1]) < last_ch_ind: 17 | output_ind = i 18 | last_ch_ind = int(f.split("/")[-1].split("_")[1]) 19 | return output_ind 20 | 21 | 22 | def read_csvs(path, max_length=100000000): 23 | files = glob(path) 24 | files.sort(key=lambda x: os.path.getmtime(x), reverse=True) 25 | index = max([get_1st_index(files, i) for i in range(4)]) + 1 26 | files = files[:index] 27 | 28 | if len(files) > 1: 29 | files = sorted( 30 | files, 31 | key=lambda x: int(x.split("_")[-1].split(".")[0]) 32 | if x.split("_")[-1].split(".")[0].isdigit() 33 | else 0, 34 | ) 35 | data = [] 36 | for f in files: 37 | 38 | df = pd.read_csv(f) 39 | data.append(df) 40 | if sum([len(x) for x in data]) >= max_length: 41 | break 42 | df = pd.concat(data, ignore_index=True) 43 | return df 44 | 45 | 46 | def process_row(data, max_eps): 47 | _, row = data 48 | obs_data = row["obs"] 49 | tmp = {} 50 | for k, v in row.items(): 51 | if k == "obs": 52 | continue 53 | try: 54 | tmp[k] = np.asarray(eval(v)) 55 | except Exception: 56 | pass 57 | dones = np.nonzero(tmp["done"])[0][:max_eps] 58 | return tmp, dones, obs_data, [*tmp.keys(), "obs"] 59 | 60 | 61 | def extract_episodes_data_from_df(df, max_eps=2000): 62 | if "Unnamed: 0" in df.columns: 63 | df = df.drop(columns=["Unnamed: 0"]) 64 | n_cols = len(df.keys()) 65 | n_rows = len(df) 66 | data = defaultdict(lambda: np.full((n_rows, max_eps), np.nan)) 67 | counts = np.zeros(max_eps) 68 | 69 | with mp.Pool(8) as p: 70 | func = partial(process_row, max_eps=max_eps) 71 | for ep_i, (tmp, dones, obs_data, row_keys) in tqdm( 72 | enumerate(p.imap(func, df.iterrows(), chunksize=16)), total=len(df) 73 | ): 74 | counts[: len(dones)] += 1 75 | for k in row_keys: 76 | if k == "obs": 77 | if "one_hot_target_sensor" in obs_data[0]: 78 | data["target"][ep_i, np.arange(len(dones))] = [ 79 | np.argmax(obs_data[i]["one_hot_target_sensor"]) 80 | for i in dones 81 | ] 82 | else: 83 | data[k][ep_i, np.arange(len(dones))] = tmp[k][dones] 84 | data["counts"] = counts 85 | return data 86 | -------------------------------------------------------------------------------- /relic/trainer/envs.py: -------------------------------------------------------------------------------- 1 | from typing import ( 2 | TYPE_CHECKING, 3 | Any, 4 | Dict, 5 | List, 6 | Optional, 7 | Tuple, 8 | cast, 9 | ) 10 | 11 | import gym 12 | from habitat.core.simulator import Observations 13 | from habitat.utils import profiling_wrapper 14 | import numpy as np 15 | 16 | if TYPE_CHECKING: 17 | from omegaconf import DictConfig 18 | import gym 19 | import habitat 20 | from habitat import Dataset 21 | from habitat.core.environments import RLTaskEnv 22 | from habitat.gym.gym_wrapper import HabGymWrapper 23 | 24 | 25 | class CustomRLTaskEnv(RLTaskEnv): 26 | def after_update(self): 27 | self._env.episode_iterator.after_update() 28 | task = self._env.task 29 | if hasattr(task, "after_update"): 30 | task.after_update() 31 | 32 | 33 | @habitat.registry.register_env(name="CustomGymHabitatEnv") 34 | class CustomGymHabitatEnv(gym.Wrapper): 35 | """ 36 | A registered environment that wraps a RLTaskEnv with the HabGymWrapper 37 | to use the default gym API. 
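    The wrapped CustomRLTaskEnv additionally exposes after_update(), which the
    trainer can call between rollouts (see reset_envs_after_update and
    call_after_update_env in the baseline configs) to advance the episode
    iterator and notify the task.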
38 | """ 39 | 40 | def __init__(self, config: "DictConfig", dataset: Optional[Dataset] = None): 41 | base_env = CustomRLTaskEnv(config=config, dataset=dataset) 42 | env = HabGymWrapper(env=base_env) 43 | super().__init__(env) 44 | 45 | 46 | from habitat.core.registry import registry 47 | from habitat.tasks.rearrange.rearrange_sim import RearrangeSim 48 | import magnum as mn 49 | from habitat.datasets.rearrange.samplers.receptacle import ( 50 | AABBReceptacle, 51 | find_receptacles, 52 | ) 53 | 54 | 55 | @registry.register_simulator(name="CustomRearrangeSim-v0") 56 | class CustomRearrangeSim(RearrangeSim): 57 | def _create_recep_info( 58 | self, scene_id: str, ignore_handles: List[str] 59 | ) -> Dict[str, mn.Range3D]: 60 | if scene_id not in self._receptacles_cache: 61 | receps = {} 62 | all_receps = find_receptacles( 63 | self, 64 | ignore_handles=ignore_handles, 65 | ) 66 | for recep in all_receps: 67 | recep = cast(AABBReceptacle, recep) 68 | local_bounds = recep.bounds 69 | global_T = recep.get_global_transform(self) 70 | # Some coordinates may be flipped by the global transformation, 71 | # mixing the minimum and maximum bound coordinates. 72 | bounds = np.stack( 73 | [ 74 | global_T.transform_point(local_bounds.min), 75 | global_T.transform_point(local_bounds.max), 76 | ], 77 | axis=0, 78 | ) 79 | receps[recep.unique_name.split("|")[0]] = mn.Range3D( 80 | np.min(bounds, axis=0), np.max(bounds, axis=0) 81 | ) 82 | self._receptacles_cache[scene_id] = receps 83 | return self._receptacles_cache[scene_id] 84 | -------------------------------------------------------------------------------- /relic/policies/visual_encoders.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import torch 3 | import torch.nn as nn 4 | from vc_models.models.vit import model_utils 5 | from torch.nn import functional as F 6 | 7 | 8 | class Vc1Wrapper(nn.Module): 9 | """ 10 | Wrapper for the VC1 visual encoder. This will automatically download the model if it's not already. 11 | """ 12 | 13 | def __init__(self, im_obs_space, model_id=None, vc1_config=None): 14 | super().__init__() 15 | assert vc1_config is not None, "Make sure you pass vc1_config to Vc1Wrapper." 
16 | self.vc1_config = vc1_config 17 | 18 | if model_id is None: 19 | model_id = model_utils.VC1_BASE_NAME 20 | print(f"loading {model_id}.") 21 | ( 22 | self.net, 23 | self.embd_size, 24 | self.model_transforms, 25 | model_info, 26 | ) = model_utils.load_model(model_id) 27 | self._image_obs_keys = [k for k in im_obs_space.spaces.keys() if k != "depth"] 28 | 29 | # Count total # of channels 30 | self._n_input_channels = sum( 31 | im_obs_space.spaces[k].shape[2] for k in self._image_obs_keys 32 | ) 33 | if self.vc1_config.is_2d_output and self.vc1_config.avg_pool_size: 34 | self.postprocess = nn.AvgPool2d( 35 | self.vc1_config.avg_pool_size, ceil_mode=True 36 | ) 37 | self.out_dim = int(ceil(14 / self.vc1_config.avg_pool_size)) 38 | else: 39 | self.postprocess = nn.Identity() 40 | self.out_dim = 1 41 | 42 | @property 43 | def is_blind(self): 44 | return self._n_input_channels == 0 45 | 46 | @torch.autocast("cuda") 47 | def forward(self, obs): 48 | # Extract tensors that are shape [batch_size, img_width, img_height, img_channels] 49 | feats = [] 50 | imgs = [v for k, v in obs.items() if k in self._image_obs_keys] 51 | for img in imgs: 52 | if img.shape[-1] != 3: 53 | img = torch.concat([img] * 3, dim=-1) 54 | scale_factor = 1.0 55 | else: 56 | scale_factor = 255.0 57 | 58 | img = self.model_transforms( 59 | img.permute(0, 3, 1, 2).contiguous() / scale_factor 60 | ) 61 | 62 | feats.append(self.net(img)) 63 | 64 | if len(feats) == 2: 65 | # feats = (feats[0] + feats[1])/2 66 | feats = torch.concat(feats, dim=-1) 67 | else: 68 | feats = feats[0] 69 | 70 | return self.postprocess(feats).flatten(1) 71 | 72 | @property 73 | def output_shape(self): 74 | return ( 75 | self.out_dim * self.out_dim * self.embd_size * len(self._image_obs_keys), 76 | ) 77 | 78 | @property 79 | def feats_size(self): 80 | return self.embd_size * len(self._image_obs_keys) 81 | 82 | def set_grad_checkpointing(self): 83 | return self.net.set_grad_checkpointing() 84 | -------------------------------------------------------------------------------- /relic/monkey_patch_eai_vc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import vc_models.models.vit.model_utils as model_utils 3 | import vc_models.models.vit.vit as vit 4 | from vc_models.models.vit.model_utils import _EAI_VC1_BASE_URL, _download_url 5 | from vc_models.models.vit.vit import resize_pos_embed 6 | import os 7 | import vc_models 8 | import hydra 9 | import omegaconf 10 | import logging 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.warn("Monkey patching vc_models package to load model by config path.") 14 | 15 | 16 | def load_model(model_name): 17 | """ 18 | Loads a model from the vc_models package. 
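    With this patched version, model_name may also be a filesystem path to a
    model config (see the body below). A hypothetical call using the config
    shipped in this repo:

        model, embedding_dim, transform, metadata = load_model(
            "relic/configs/vc1/vc1_vitb_ft_cls_e15.yaml"
        )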
19 | Args: 20 | model_name (str): name of the model to load 21 | Returns: 22 | model (torch.nn.Module): the model 23 | embedding_dim (int): the dimension of the embedding 24 | transform (torchvision.transforms): the transform to apply to the image 25 | metadata (dict): the metadata of the model 26 | """ 27 | if os.path.exists(model_name): 28 | cfg_path = model_name 29 | else: 30 | models_filepath = os.path.dirname(os.path.abspath(vc_models.__file__)) 31 | cfg_path = os.path.join(models_filepath, "conf", "model", f"{model_name}.yaml") 32 | 33 | model_cfg = omegaconf.OmegaConf.load(cfg_path) 34 | # returns tuple of model, embedding_dim, transform, metadata 35 | return hydra.utils.call(model_cfg) 36 | 37 | 38 | def download_model_if_needed(ckpt_file): 39 | if not os.path.exists(ckpt_file): 40 | model_base_dir = os.path.join( 41 | os.path.dirname(os.path.abspath(__file__)), "..", "..", ".." 42 | ) 43 | ckpt_file = os.path.join(model_base_dir, ckpt_file) 44 | 45 | if not os.path.exists(ckpt_file): 46 | os.makedirs(os.path.dirname(ckpt_file), exist_ok=True) 47 | 48 | model_name = ckpt_file.split("/")[-1] 49 | model_url = _EAI_VC1_BASE_URL + model_name 50 | _download_url(model_url, ckpt_file) 51 | 52 | 53 | def load_mae_encoder(model, checkpoint_path=None): 54 | if checkpoint_path is None: 55 | return model 56 | else: 57 | model_utils.download_model_if_needed(checkpoint_path) 58 | 59 | if not os.path.exists(checkpoint_path): 60 | model_base_dir = os.path.join( 61 | os.path.dirname(os.path.abspath(__file__)), "..", "..", ".." 62 | ) 63 | checkpoint_path = os.path.join(model_base_dir, checkpoint_path) 64 | 65 | state_dict = torch.load(checkpoint_path, map_location="cpu")["model"] 66 | if state_dict["pos_embed"].shape != model.pos_embed.shape: 67 | state_dict["pos_embed"] = resize_pos_embed( 68 | state_dict["pos_embed"], 69 | model.pos_embed, 70 | getattr(model, "num_tokens", 1), 71 | model.patch_embed.grid_size, 72 | ) 73 | 74 | # filter out keys with name decoder or mask_token 75 | state_dict = { 76 | k: v 77 | for k, v in state_dict.items() 78 | if "decoder" not in k and "mask_token" not in k 79 | } 80 | 81 | if model.classifier_feature == "global_pool": 82 | # remove layer that start with norm 83 | state_dict = {k: v for k, v in state_dict.items() if not k.startswith("norm")} 84 | # add fc_norm in the state dict from the model 85 | state_dict["fc_norm.weight"] = model.fc_norm.weight 86 | state_dict["fc_norm.bias"] = model.fc_norm.bias 87 | 88 | model.load_state_dict(state_dict) 89 | return model 90 | 91 | 92 | model_utils.load_model = load_model 93 | model_utils.download_model_if_needed = download_model_if_needed 94 | vit.load_mae_encoder = load_mae_encoder 95 | -------------------------------------------------------------------------------- /relic/configs/baseline/baselines_comparison/rl2.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /tasks: ExtObjNav 5 | - /habitat_baselines: habitat_baselines_rl2_config_base 6 | - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor 7 | - /policy@habitat_baselines.rl.policy.main_agent: lstm 8 | - _self_ 9 | 10 | habitat_baselines: 11 | verbose: False 12 | trainer_name: "ddppo" 13 | torch_gpu_id: 0 14 | tensorboard_dir: "tb" 15 | video_dir: "video_dir" 16 | video_fps: 30 17 | test_episode_count: -1 18 | eval_ckpt_path_dir: "data/new_checkpoints" 19 | num_environments: 20 20 | writer_type: 'tb' 21 | checkpoint_folder: 
"data/new_checkpoints" 22 | num_updates: -1 23 | total_num_steps: 1.0e9 24 | log_interval: 10 25 | num_checkpoints: 200 26 | reset_envs_after_update: True 27 | call_after_update_env: True 28 | # Force PyTorch to be single threaded as 29 | # this improves performance considerably 30 | force_torch_single_threaded: True 31 | eval_keys_to_include_in_name: ['reward', 'force', 'success'] 32 | 33 | separate_envs_and_policy: False 34 | separate_rollout_and_policy: False 35 | 36 | vector_env_factory: 37 | _target_: "relic.envs.train_env_factory.HabitatVectorEnvFactory" 38 | eval: 39 | video_option: ["disk"] 40 | 41 | rl: 42 | ppo: 43 | # ppo params 44 | clip_param: 0.2 45 | ppo_epoch: 4 46 | num_mini_batch: 2 47 | grad_accum_mini_batches: 1 48 | value_loss_coef: 0.5 49 | entropy_coef: 0.01 50 | lr: 2.5e-4 51 | optimizer_name: adam 52 | adamw_weight_decay: 0.1 53 | warmup: True 54 | warmup_total_iters: 100000 55 | warmup_start_factor: 1e-3 56 | lr_scheduler: cosine_decay 57 | lrsched_T_max: 1000000000 58 | lrsched_eta_min: 1e-5 59 | 60 | eps: 1e-5 61 | max_grad_norm: 0.2 62 | num_steps: 64 63 | use_gae: True 64 | gamma: 0.99 65 | tau: 0.95 66 | use_linear_clip_decay: False 67 | use_linear_lr_decay: False 68 | reward_window_size: 50 69 | 70 | use_normalized_advantage: False 71 | 72 | hidden_size: 512 73 | 74 | # Use double buffered sampling, typically helps 75 | # when environment time is similar or larger than 76 | # policy inference time during rollout generation 77 | use_double_buffered_sampler: True 78 | update_stale_kv: True 79 | update_stale_values: True 80 | full_updates_per_rollout: 1 81 | updates_per_rollout: 16 82 | ignore_old_obs_grad: False 83 | storage_low_precision: True 84 | gradient_checkpointing: False 85 | slice_in_partial_update: False 86 | percent_envs_update: 1 87 | shuffle_old_episodes: True 88 | shift_scene_every: 100000 89 | 90 | ddppo: 91 | sync_frac: 0.6 92 | # The PyTorch distributed backend to use 93 | distrib_backend: NCCL 94 | # Visual encoder backbone 95 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 96 | # Initialize with pretrained weights 97 | pretrained: False 98 | # Initialize just the visual encoder backbone with pretrained weights 99 | pretrained_encoder: False 100 | # Whether the visual encoder backbone will be trained. 
101 | train_encoder: False 102 | # Whether to reset the critic linear layer 103 | reset_critic: True 104 | 105 | # Model parameters 106 | backbone: vc1_vc1_vitb_ft_cls_e15 107 | rnn_type: LSTM 108 | num_recurrent_layers: 2 109 | -------------------------------------------------------------------------------- /relic/configs/baseline/relic_base.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat/simulator/sim_sensors@habitat_baselines.eval.extra_sim_sensors.third_rgb_sensor: third_rgb_sensor 6 | - /policy@habitat_baselines.rl.policy.main_agent: transformer_small 7 | - /backbone: vc1_smallObjs 8 | - _self_ 9 | 10 | habitat_baselines: 11 | verbose: False 12 | trainer_name: "transformers" 13 | torch_gpu_id: 0 14 | tensorboard_dir: "tb" 15 | video_dir: "video_dir" 16 | video_fps: 30 17 | test_episode_count: -1 18 | eval_ckpt_path_dir: "data/new_checkpoints" 19 | num_environments: 20 20 | writer_type: 'tb' 21 | checkpoint_folder: "data/new_checkpoints" 22 | num_updates: -1 23 | total_num_steps: 1.0e9 24 | log_interval: 10 25 | num_checkpoints: 200 26 | reset_envs_after_update: True 27 | call_after_update_env: True 28 | # Force PyTorch to be single threaded as 29 | # this improves performance considerably 30 | force_torch_single_threaded: True 31 | eval_keys_to_include_in_name: ['reward', 'force', 'success'] 32 | rollout_storage_name: TransformerRolloutStorage 33 | updater_name: "TransformerPPO" 34 | distrib_updater_name: "DistributedTransformerPPO" 35 | 36 | separate_envs_and_policy: False 37 | separate_rollout_and_policy: False 38 | 39 | evaluator: 40 | _target_: "relic.evaluator.TransformersRepEpsHabitatEvaluator" 41 | vector_env_factory: 42 | _target_: "relic.envs.train_env_factory.HabitatVectorEnvFactory" 43 | eval: 44 | video_option: ["disk"] 45 | 46 | rl: 47 | agent: 48 | type: TransformerSingleAgentAccessMgr 49 | 50 | ppo: 51 | # ppo params 52 | clip_param: 0.2 53 | ppo_epoch: 4 54 | num_mini_batch: 20 55 | grad_accum_mini_batches: 10 56 | value_loss_coef: 0.5 57 | entropy_coef: 0.01 58 | lr: 2e-4 59 | optimizer_name: adam 60 | adamw_weight_decay: 0.1 61 | warmup: True 62 | warmup_total_iters: 100000 63 | warmup_start_factor: 1e-3 64 | lr_scheduler: cosine_decay 65 | lrsched_T_max: 1000000000 66 | lrsched_eta_min: 1e-5 67 | 68 | eps: 1e-5 69 | max_grad_norm: 0.2 70 | num_steps: 4096 71 | use_gae: True 72 | gamma: 0.99 73 | tau: 0.95 74 | use_linear_clip_decay: False 75 | use_linear_lr_decay: False 76 | reward_window_size: 50 77 | 78 | use_normalized_advantage: False 79 | 80 | hidden_size: 512 81 | 82 | # Use double buffered sampling, typically helps 83 | # when environment time is similar or larger than 84 | # policy inference time during rollout generation 85 | use_double_buffered_sampler: True 86 | update_stale_kv: True 87 | update_stale_values: True 88 | full_updates_per_rollout: 1 89 | updates_per_rollout: 16 90 | ignore_old_obs_grad: False 91 | storage_low_precision: True 92 | gradient_checkpointing: False 93 | slice_in_partial_update: False 94 | percent_envs_update: 1 95 | shuffle_old_episodes: True 96 | shift_scene_every: -1 97 | 98 | ddppo: 99 | sync_frac: 0.6 100 | # The PyTorch distributed backend to use 101 | distrib_backend: NCCL 102 | # Visual encoder backbone 103 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 104 | # Initialize with pretrained weights 105 | pretrained: False 106 | # Initialize just the visual 
encoder backbone with pretrained weights 107 | pretrained_encoder: False 108 | # Whether to reset the critic linear layer 109 | reset_critic: True 110 | 111 | # Model parameters 112 | rnn_type: transformer 113 | num_recurrent_layers: 2 114 | -------------------------------------------------------------------------------- /launch.yaml: -------------------------------------------------------------------------------- 1 | # Andrew's settings for launching experiments 2 | base_data_dir: "/srv/share/aszot3/habitat2" 3 | proj_name: "hab_trans" 4 | wb_entity: "aszot" 5 | ckpt_cfg_key: "CHECKPOINT_FOLDER" 6 | add_env_vars: 7 | - "MAGNUM_LOG=quiet" 8 | - "HABITAT_SIM_LOG=quiet" 9 | proj_dat_add_env_vars: 10 | debug: "HABITAT_ENV_DEBUG=1" 11 | debug_eval: "HABITAT_ENV_DEBUG=1" 12 | conda_env: "hab_trans" 13 | slurm_ignore_nodes: ["spd-13"] 14 | add_all: "habitat_baselines.wb.entity=$WB_ENTITY habitat_baselines.wb.run_name=$SLURM_ID habitat_baselines.wb.project_name=$PROJECT_NAME habitat_baselines.checkpoint_folder=$DATA_DIR/checkpoints/$SLURM_ID/ habitat_baselines.video_dir=$DATA_DIR/vids/$SLURM_ID/ habitat_baselines.log_file=$DATA_DIR/logs/$SLURM_ID.log habitat_baselines.tensorboard_dir=$DATA_DIR/tb/$SLURM_ID/ habitat_baselines.writer_type=wb" 15 | eval_sys: 16 | ckpt_load_k: "habitat_baselines.eval_ckpt_path_dir" 17 | ckpt_search_dir: "checkpoints" 18 | run_id_k: "habitat_baselines.wb.run_name" 19 | sep: "=" 20 | add_eval_to_vals: 21 | - "habitat_baselines.checkpoint_folder" 22 | - "habitat_baselines.log_file" 23 | - "habitat_baselines.wb.run_name" 24 | change_vals: 25 | "--run-type": "eval" 26 | proj_data: 27 | # Eval settings 28 | eval: "habitat_baselines.writer_type tb habitat_baselines.num_environments 1 habitat_baselines.load_resume_state_config=False" 29 | eval10proc: "habitat_baselines.num_environments=10 habitat_baselines.load_resume_state_config=False" 30 | norender: "habitat_baselines.eval.video_option=\"[]\"" 31 | video: "habitat_baselines.test_episode_count=5 habitat_baselines.writer_type=tb habitat_baselines.num_environments=1" 32 | 33 | # Debug settings. 34 | debug: "habitat_baselines.num_environments=1 habitat_baselines.writer_type=tb habitat_baselines.log_interval=1 habitat_baselines.rl.ppo.num_mini_batch=1 habitat_baselines.video_dir=$DATA_DIR/vids/debug/ habitat_baselines.trainer_name=ppo" 35 | verdebug: "habitat_baselines.num_environments=1 habitat_baselines.writer_type=tb habitat_baselines.log_interval=1 habitat_baselines.rl.ppo.num_mini_batch=1 habitat_baselines.video_dir=$DATA_DIR/vids/debug/" 36 | procdebug: "habitat_baselines.writer_type=tb habitat_baselines.log_interval=1 habitat_baselines.video_dir=$DATA_DIR/vids/debug/" 37 | 38 | # Dataset settings. 39 | ppo: "habitat_baselines.trainer_name=ppo" 40 | ddppo: "habitat_baselines.trainer_name=ddppo" 41 | minival: "habitat_baselines.eval.split=minival habitat.dataset.split=minival" 42 | testep: "habitat.dataset.data_path=\"data/datasets/replica_cad/rearrange/v1/train/rearrange_easy_test.json.gz\"" 43 | train: "eval.split=train" # Evaluate on the train dataset. 
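  # Each proj_data entry is a space-separated string of Hydra overrides that is
  # presumably appended to the launch command when that shorthand is requested.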
44 | 45 | # Task settings 46 | kin: "habitat.simulator.kinematic_mode=True habitat.simulator.ac_freq_ratio=1 habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0" 47 | hl: "habitat.task.measurements.composite_success.must_call_stop=False habitat.environment.max_episode_steps=20" 48 | noforceterm: "habitat.task.measurements.force_terminate.max_accum_force=-1.0 habitat.task.measurements.force_terminate.max_instant_force=-1.0" 49 | 50 | # GPU adjustments 51 | # A40 depth input 52 | a40d: "habitat_baselines.num_environments=48" 53 | 54 | # Skills 55 | place: "benchmark/rearrange=place" 56 | pick: "benchmark/rearrange=pick" 57 | nav: "benchmark/rearrange=nav_to_obj" 58 | opencab: "benchmark/rearrange=open_cab" 59 | openfridge: "benchmark/rearrange=open_fridge" 60 | # Add at the end for TP-SRL mode. 61 | tpsrl: "habitat/task/rearrange/agents@habitat.task.habitat.task.rearrange.agents=fetch_arm habitat.task.spawn_max_dist_to_obj=-1.0 habitat.dataset.split=minival habitat.task.base_angle_noise=0.1 habitat.task.num_spawn_attempts=1" 62 | 63 | slurm: 64 | small: 65 | c: 7 66 | partition: short 67 | large: 68 | c: 16 69 | partition: short 70 | constraint: 'a40' 71 | -------------------------------------------------------------------------------- /relic/configs/tasks/ExtObjNav.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: custom_task_config_base 6 | - /habitat/simulator: rearrange_sim 7 | - /habitat/simulator/sensor_setups@habitat.simulator.agents.main_agent: rgb_head_agent 8 | - /habitat/simulator/sim_sensors@habitat.simulator.agents.main_agent.sim_sensors.head_panoptic_sensor: head_panoptic_sensor 9 | - /habitat/simulator/agents@habitat.simulator.agents.main_agent: fetch_suction 10 | - /habitat/dataset/rearrangement: hssd 11 | - /habitat/task/actions: 12 | - rearrange_stop 13 | - move_forward_custom 14 | - turn_left_custom 15 | - turn_right_custom 16 | - rearrange_look_up 17 | - rearrange_look_down 18 | - /habitat/task/measurements: 19 | - articulated_agent_force 20 | - force_terminate 21 | - articulated_agent_colls 22 | - does_want_terminate 23 | - num_steps 24 | - geo_disc_distance 25 | - l2_distance 26 | - rot_dist_to_closest_goal 27 | - extobjnav_success 28 | - spl_geodisc 29 | - soft_spl_geodisc 30 | - bad_called_terminate 31 | - named_nav_to_obj_reward 32 | - /habitat/task/lab_sensors: 33 | - joint_sensor 34 | - one_hot_target_sensor 35 | - gps_sensor 36 | - _self_ 37 | 38 | 39 | habitat: 40 | env_task: "CustomGymHabitatEnv" 41 | simulator: 42 | type: CustomRearrangeSim-v0 43 | kinematic_mode: True 44 | ac_freq_ratio: 1 45 | step_physics: False 46 | turn_angle: 30 47 | habitat_sim_v0: 48 | allow_sliding: True 49 | agents: 50 | main_agent: 51 | joint_start_noise: 0 52 | joint_start_override: [1.573,1.573,-3.14,-1.573,0,-1.573,0] 53 | gym: 54 | obs_keys: 55 | - head_rgb 56 | - one_hot_target_sensor 57 | - gps 58 | task: 59 | type: PddlMultiTask-v0 60 | start_template: [] 61 | goal_template: 62 | expr_type: AND 63 | sub_exprs: 64 | - "robot_at(obj, robot_0)" 65 | sample_entities: 66 | "obj": 67 | "type": "movable_entity_type" 68 | reward_measure: named_nav_to_obj_reward 69 | task_spec_base_path: 'configs/pddl' 70 | pddl_domain_def: "hssd_domain" 71 | success_measure: extobjnav_success 72 | success_reward: 2 73 | slack_reward: -0.001 74 | end_on_success: True 75 | 
constraint_violation_ends_episode: False 76 | constraint_violation_drops_object: True 77 | measurements: 78 | custom_predicate_task_success: 79 | must_call_stop: True 80 | max_angle: 10000 # inf 81 | must_see_object: True 82 | sees_vertical_margin: 20 # not used 83 | sees_horizontal_margin: 50 # not used 84 | ignore_objects: False # not used 85 | ignore_receptacles: False # not used 86 | ignore_non_negative: False # not used 87 | custom_predicate_task_reward: 88 | dist_reward: 1.0 89 | should_reward_turn: False 90 | angle_dist_reward: 0.01 91 | constraint_violate_pen: 0.0 92 | force_pen: 0.0 93 | max_force_pen: 0.0 94 | force_end_pen: 0.0 95 | bad_term_pen: 0.0 96 | end_on_bad_termination: True 97 | use_max_dist: False 98 | force_terminate: 99 | max_accum_force: -1 100 | max_instant_force: -1 101 | geo_disc_distance: 102 | lock_closest_object: False 103 | fix_position_same_episode: False 104 | fix_target_same_episode: False 105 | fix_instance_index: False 106 | target_type: object_type 107 | target_sampling_strategy: object_type 108 | cleanup_nav_points: False 109 | one_receptacle: False 110 | is_large_objs: False 111 | actions: 112 | turn_left_custom: 113 | ang_speed: 30 114 | turn_right_custom: 115 | ang_speed: -30 116 | rearrange_look_down: 117 | tilt_angle: 30 118 | rearrange_look_up: 119 | tilt_angle: 30 120 | move_forward_custom: 121 | lin_speed: 0.25 122 | 123 | environment: 124 | max_episode_steps: 500 125 | dataset: 126 | type: "RearrangeDatasetTransformers-v0" 127 | data_path: "data/datasets/ExtObjNav_HSSD/rearrange/{split}/rearrange_ep_dataset.json.gz" 128 | -------------------------------------------------------------------------------- /relic/configs/tasks/ExtObjNav_replicaCAD.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: custom_task_config_base 6 | - /habitat/simulator: rearrange_sim 7 | - /habitat/simulator/sensor_setups@habitat.simulator.agents.main_agent: rgb_head_agent 8 | - /habitat/simulator/sim_sensors@habitat.simulator.agents.main_agent.sim_sensors.head_panoptic_sensor: head_panoptic_sensor 9 | - /habitat/simulator/agents@habitat.simulator.agents.main_agent: fetch_suction 10 | - /habitat/dataset/rearrangement: replica_cad 11 | - /habitat/task/actions: 12 | - rearrange_stop 13 | - move_forward_custom 14 | - turn_left_custom 15 | - turn_right_custom 16 | - rearrange_look_up 17 | - rearrange_look_down 18 | - /habitat/task/measurements: 19 | - articulated_agent_force 20 | - force_terminate 21 | - articulated_agent_colls 22 | - does_want_terminate 23 | - num_steps 24 | - geo_disc_distance 25 | - l2_distance 26 | - rot_dist_to_closest_goal 27 | - extobjnav_success 28 | - spl_geodisc 29 | - soft_spl_geodisc 30 | - bad_called_terminate 31 | - named_nav_to_obj_reward 32 | - /habitat/task/lab_sensors: 33 | - joint_sensor 34 | - one_hot_target_sensor 35 | - gps_sensor 36 | - _self_ 37 | 38 | 39 | habitat: 40 | env_task: "CustomGymHabitatEnv" 41 | simulator: 42 | type: CustomRearrangeSim-v0 43 | kinematic_mode: True 44 | ac_freq_ratio: 1 45 | step_physics: False 46 | turn_angle: 30 47 | habitat_sim_v0: 48 | allow_sliding: True 49 | agents: 50 | main_agent: 51 | joint_start_noise: 0 52 | joint_start_override: [1.573,1.573,-3.14,-1.573,0,-1.573,0] 53 | gym: 54 | obs_keys: 55 | - head_rgb 56 | - one_hot_target_sensor 57 | - gps 58 | task: 59 | type: PddlMultiTask-v0 60 | start_template: [] 61 | goal_template: 62 | expr_type: AND 63 | sub_exprs: 
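        # Goal predicate: the agent must reach the sampled target object "obj"
        # (bound under sample_entities below to any movable_entity_type).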
64 |         - "robot_at(obj, robot_0)"
65 |     sample_entities:
66 |       "obj":
67 |         "type": "movable_entity_type"
68 |     reward_measure: named_nav_to_obj_reward
69 |     task_spec_base_path: 'configs/pddl'
70 |     pddl_domain_def: "domain"
71 |     success_measure: extobjnav_success
72 |     success_reward: 2
73 |     slack_reward: -0.001
74 |     end_on_success: True
75 |     constraint_violation_ends_episode: False
76 |     constraint_violation_drops_object: True
77 |     measurements:
78 |       custom_predicate_task_success:
79 |         must_call_stop: True
80 |         max_angle: 10000 # inf
81 |         must_see_object: True
82 |         sees_vertical_margin: 20 # not used
83 |         sees_horizontal_margin: 50 # not used
84 |         ignore_objects: False # not used
85 |         ignore_receptacles: False # not used
86 |         ignore_non_negative: False # not used
87 |       custom_predicate_task_reward:
88 |         dist_reward: 1.0
89 |         should_reward_turn: False
90 |         angle_dist_reward: 0.01
91 |         constraint_violate_pen: 0.0
92 |         force_pen: 0.0
93 |         max_force_pen: 0.0
94 |         force_end_pen: 0.0
95 |         bad_term_pen: 0.0
96 |         end_on_bad_termination: True
97 |         use_max_dist: False
98 |       force_terminate:
99 |         max_accum_force: -1
100 |         max_instant_force: -1
101 |       geo_disc_distance:
102 |         lock_closest_object: False
103 |     fix_position_same_episode: False
104 |     fix_target_same_episode: False
105 |     fix_instance_index: False
106 |     target_type: object_type
107 |     target_sampling_strategy: object_type
108 |     cleanup_nav_points: False
109 |     one_receptacle: False
110 |     is_large_objs: False
111 |     actions:
112 |       turn_left_custom:
113 |         ang_speed: 30
114 |       turn_right_custom:
115 |         ang_speed: -30
116 |       rearrange_look_down:
117 |         tilt_angle: 30
118 |       rearrange_look_up:
119 |         tilt_angle: 30
120 |       move_forward_custom:
121 |         lin_speed: 0.25
122 | 
123 |   environment:
124 |     max_episode_steps: 500
125 |   dataset:
126 |     type: "RearrangeDatasetTransformers-v0"
127 |     data_path: data/datasets/ExtObjNav_replicaCAD/rearrange/v1/{split}/rearrange_easy_clean.json.gz
128 | 
--------------------------------------------------------------------------------
/relic/envs/custom_envs/example.py:
--------------------------------------------------------------------------------
1 | import gym
2 | from gym import spaces
3 | import numpy as np
4 | from habitat.core.dataset import BaseEpisode
5 | 
6 | from relic.envs.custom_envs.darkroom import DarkroomEnv
7 | 
8 | try:
9 |     import gymnasium
10 | 
11 |     IS_GYMNASIUM_AVAILABLE = True
12 | except ImportError:
13 |     IS_GYMNASIUM_AVAILABLE = False
14 | 
15 | 
16 | def b2b(src: "gymnasium.spaces.Box"):
17 |     """Converts gymnasium.spaces.Box to gym.spaces.Box"""
18 |     try:
19 |         return spaces.Box(low=src.low, high=src.high, dtype=src.dtype, shape=src.shape)
20 |     except Exception:
21 |         return spaces.Box(low=src.low, high=src.high, shape=src.shape)
22 | 
23 | 
24 | def d2d(src: "gymnasium.spaces.Discrete"):
25 |     """Converts gymnasium.spaces.Discrete to gym.spaces.Discrete"""
26 |     return spaces.Discrete(src.n, start=src.start)
27 | 
28 | 
29 | def convert2gym_space(src):
30 |     """Converts from gymnasium spaces to gym spaces."""
31 |     if isinstance(src, (spaces.Box, spaces.Discrete)):
32 |         return src
33 |     elif IS_GYMNASIUM_AVAILABLE:
34 |         if isinstance(src, gymnasium.spaces.Box):
35 |             return b2b(src)
36 |         elif isinstance(src, gymnasium.spaces.Discrete):
37 |             return d2d(src)
38 |         else:
39 |             raise TypeError(f"The conversion is not implemented for type {type(src)}.")
40 |     else:
41 |         raise TypeError(
42 |             f"The conversion is not implemented for type {type(src)}. If this is a gymnasium class, make sure the package is installed."
43 |         )
44 | 
45 | 
46 | class NewEnv(gym.Wrapper):
47 |     def __init__(self, *args, **kwargs):
48 |         """Create env wrapper compatible with ReLIC."""
49 | 
50 |         # Each element of self.goals is a different task in the darkroom env.
51 |         # This can be replaced by a seed in gym/gymnasium envs, which is passed to reset calls.
52 |         self.goals = kwargs.pop("goals")
53 |         super().__init__(*args, **kwargs)
54 | 
55 |         # ReLIC supports gym spaces. This is how to convert from gymnasium spaces to gym spaces.
56 |         self.observation_space = spaces.Dict(
57 |             {
58 |                 "state": convert2gym_space(self.env.observation_space),
59 |                 "reward_input": spaces.Box(-np.inf, np.inf, shape=(1,)),
60 |             }
61 |         )
62 | 
63 |         # ReLIC supports gym spaces. This is how to convert from gymnasium spaces to gym spaces.
64 |         self.action_space = convert2gym_space(self.env.action_space)
65 | 
66 |         # self.current_goal_idx tracks the trial's goal so that resetting the env within a trial
67 |         # doesn't change the goal.
68 |         self.current_goal_idx = 0
69 |         self.after_update()
70 | 
71 |         # Required information for the evaluator.
72 |         self.episodes = [str(x) for x in self.goals]
73 |         self.number_of_episodes = len(self.episodes)
74 |         self._has_number_episode = True
75 | 
76 |     def step(self, action):
77 |         # ReLIC passes an int for categorical actions but darkroom expects a one-hot encoding.
78 |         new_action = np.zeros(self.action_space.n)
79 |         new_action[action] = 1
80 | 
81 |         obs, reward, done, info = super().step(new_action)
82 |         obs = {
83 |             "state": obs,
84 |             "reward_input": [reward],
85 |         }
86 |         return obs, reward, done, info
87 | 
88 |     def reset(self, *args, **kwargs):
89 |         """Reset the env.
90 | 
91 |         This method should reset the env but not change the task. This can be
92 |         controlled by using the same goal in the darkroom env or the same seed
93 |         in gym/gymnasium envs. The method responsible for changing the task is
94 |         `after_update`.
95 |         """
96 |         obs = super().reset(goal=self.goals[self.current_goal_idx])
97 |         obs = {
98 |             "state": obs,
99 |             "reward_input": [0],
100 |         }
101 |         return obs
102 | 
103 |     def after_update(self):
104 |         """Change the task. This method is called after the trial ends to change the task."""
105 |         self.current_goal_idx += 1
106 |         self.current_goal_idx %= len(self.goals)
107 | 
108 |     def current_episode(self, *args):
109 |         """Return episode identifier."""
110 |         goal = self.goals[self.current_goal_idx]
111 |         return BaseEpisode(
112 |             episode_id=str(goal),
113 |             scene_id=str(goal),
114 |         )
115 | 
116 |     @property
117 |     def original_action_space(self) -> spaces.Space:
118 |         """Return the action space."""
119 |         return self.action_space
120 | 
121 | 
122 | def make_env(config, dataset=None):
123 |     env = DarkroomEnv(10, (0, 0), 100)
124 | 
125 |     # Create train/validation splits
126 |     is_eval = config.habitat_baselines.evaluate
127 |     goals = np.array([[(j, i) for i in range(10)] for j in range(10)]).reshape(-1, 2)
128 |     np.random.RandomState(seed=0).shuffle(goals)
129 |     train_test_split = int(0.8 * len(goals))
130 |     if is_eval:
131 |         goals = goals[train_test_split:]
132 |     else:
133 |         goals = goals[:train_test_split]
134 | 
135 |     # Shuffle the tasks. config.habitat.seed is different for each env worker.
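    # The split above uses a fixed RandomState(seed=0), so every worker derives
    # the same 80/20 train/eval partition; the reshuffle below only changes the
    # per-worker ordering of tasks within that partition.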
136 | np.random.RandomState(config.habitat.seed).shuffle(goals) 137 | 138 | return NewEnv(env, goals=goals) 139 | -------------------------------------------------------------------------------- /relic/envs/train_env_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 4 | 5 | from math import ceil 6 | import os 7 | import random 8 | from typing import TYPE_CHECKING, Any, List, Type 9 | 10 | from habitat import ThreadedVectorEnv, VectorEnv, logger, make_dataset 11 | from habitat.config import read_write 12 | from habitat.gym import make_gym_from_config 13 | from habitat_baselines.common.env_factory import VectorEnvFactory 14 | from habitat_baselines.rl.ddppo.ddp_utils import get_distrib_size 15 | import torch 16 | import importlib 17 | 18 | if TYPE_CHECKING: 19 | from omegaconf import DictConfig 20 | 21 | 22 | def get_make_env_func_by_name(name): 23 | if "." in name: 24 | module_name, func_name = name.rsplit(".", 1) 25 | module = importlib.import_module(module_name) 26 | func = getattr(module, func_name) 27 | else: 28 | func = globals()[name] 29 | return func 30 | 31 | 32 | class HabitatVectorEnvFactory(VectorEnvFactory): 33 | def construct_envs( 34 | self, 35 | config: "DictConfig", 36 | workers_ignore_signals: bool = False, 37 | enforce_scenes_greater_eq_environments: bool = False, 38 | is_first_rank: bool = True, 39 | distribute_envs_across_gpus=None, 40 | ) -> VectorEnv: 41 | r"""Create VectorEnv object with specified config and env class type. 42 | To allow better performance, dataset are split into small ones for 43 | each individual env, grouped by scenes. 44 | """ 45 | 46 | if distribute_envs_across_gpus is None: 47 | distribute_envs_across_gpus = enforce_scenes_greater_eq_environments 48 | 49 | num_environments = config.habitat_baselines.num_environments 50 | configs = [] 51 | make_env_func_name = config.habitat.task.get( 52 | "make_env_fn", "make_gym_from_config" 53 | ) 54 | if make_env_func_name == "make_gym_from_config": 55 | dataset = make_dataset(config.habitat.dataset.type) 56 | scenes = list(config.habitat.dataset.content_scenes) 57 | if "*" in config.habitat.dataset.content_scenes: 58 | scenes = dataset.get_scenes_to_load(config.habitat.dataset) 59 | scenes = sorted(scenes) 60 | local_rank, world_rank, world_size = get_distrib_size() 61 | split_size = ceil(len(scenes) / world_size) 62 | orig_size = len(scenes) 63 | scenes = scenes[world_rank * split_size : (world_rank + 1) * split_size] 64 | scenes_ids = list(range(orig_size))[ 65 | world_rank * split_size : (world_rank + 1) * split_size 66 | ] 67 | logger.warn(f"Loading {len(scenes)}/{orig_size}. 
IDs: {scenes_ids}") 68 | 69 | if num_environments < 1: 70 | raise RuntimeError("num_environments must be strictly positive") 71 | 72 | if len(scenes) == 0: 73 | raise RuntimeError( 74 | "No scenes to load, multiple process logic relies on being able to split scenes uniquely between processes" 75 | ) 76 | 77 | random.shuffle(scenes) 78 | 79 | scene_splits: List[List[str]] = [[] for _ in range(num_environments)] 80 | for idx in range(max(len(scene_splits), len(scenes))): 81 | scene_splits[idx % len(scene_splits)].append(scenes[idx % len(scenes)]) 82 | 83 | logger.warn(f"Scene splits: {scene_splits}.") 84 | assert all(scene_splits) 85 | else: 86 | scenes = [] 87 | 88 | for env_index in range(num_environments): 89 | proc_config = config.copy() 90 | with read_write(proc_config): 91 | if distribute_envs_across_gpus: 92 | proc_config.habitat.simulator.habitat_sim_v0.gpu_device_id = ( 93 | env_index % torch.cuda.device_count() 94 | ) 95 | 96 | task_config = proc_config.habitat 97 | task_config.seed = task_config.seed + env_index 98 | remove_measure_names = [] 99 | if not is_first_rank: 100 | # Filter out non rank0_measure from the task config if we are not on rank0. 101 | remove_measure_names.extend(task_config.task.rank0_measure_names) 102 | if (env_index != 0) or not is_first_rank: 103 | # Filter out non-rank0_env0 measures from the task config if we 104 | # are not on rank0 env0. 105 | remove_measure_names.extend( 106 | task_config.task.rank0_env0_measure_names 107 | ) 108 | 109 | task_config.task.measurements = { 110 | k: v 111 | for k, v in task_config.task.measurements.items() 112 | if k not in remove_measure_names 113 | } 114 | 115 | if len(scenes) > 0: 116 | task_config.dataset.content_scenes = scene_splits[env_index] 117 | 118 | configs.append(proc_config) 119 | 120 | vector_env_cls: Type[Any] 121 | if int(os.environ.get("HABITAT_ENV_DEBUG", 0)): 122 | logger.warn( 123 | "Using the debug Vector environment interface. Expect slower performance." 124 | ) 125 | vector_env_cls = ThreadedVectorEnv 126 | else: 127 | vector_env_cls = VectorEnv 128 | 129 | envs = vector_env_cls( 130 | make_env_fn=get_make_env_func_by_name(make_env_func_name), 131 | env_fn_args=tuple((c,) for c in configs), 132 | workers_ignore_signals=workers_ignore_signals, 133 | ) 134 | 135 | if config.habitat.simulator.renderer.enable_batch_renderer: 136 | envs.initialize_batch_renderer(config) 137 | 138 | return envs 139 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ReLIC: A recipe for 64k steps In-Context Reinforcement Learning for Embodied AI 2 | 3 | This is the official implementation for "ReLIC: A recipe for 64k steps In-Context Reinforcement Learning for Embodied AI". 4 | 5 | 6 | **Abstract**: Intelligent embodied agents need to quickly adapt to new scenarios by integrating long histories of experience into decision-making. For instance, a robot in an unfamiliar house initially wouldn't know the locations of objects needed for tasks and might perform inefficiently. However, as it gathers more experience, it should learn the layout and remember where objects are, allowing it to complete new tasks more efficiently. To enable such rapid adaptation to new tasks, we present ReLIC, a new approach for in-context reinforcement learning (RL) for embodied agents. 
With ReLIC, agents are capable of adapting to new environments using 64,000 steps of in-context experience with a full attention mechanism while being trained through self-generated experience via RL. We achieve this by proposing a novel policy update scheme for on-policy RL called "partial updates" as well as a Sink-KV mechanism which enables effective utilization of long observation history for embodied agents. Our method outperforms a variety of meta-RL baselines in adapting to unseen houses in an embodied multi-object navigation task. In addition, we find that ReLIC is capable of few-shot imitation learning despite never being trained with expert demonstrations. We also provide a comprehensive analysis of ReLIC, highlighting that the combination of large-scale RL training, the proposed partial updates scheme, and the Sink-KV are essential for effective in-context learning.
7 | 
8 | 
9 | # Getting Started
10 | 
11 | * Clone the repo.
12 | ```bash
13 | git clone https://github.com/aielawady/relic.git
14 | cd relic
15 | ```
16 | * Install relic.
17 | ```bash
18 | conda create -n relic -y python=3.9
19 | conda activate relic
20 | conda install -y habitat-sim==0.3.0 withbullet headless -c conda-forge -c aihabitat
21 | 
22 | pip install packaging ninja
23 | pip install -e .
24 | ```
25 | * Change the dir to `relic`.
26 | ```bash
27 | cd relic
28 | ```
29 | * Download the ExtObjNav data (ExtObjNav dataset, scenes dataset, and the VC1 fine-tuned checkpoint).
30 | ```bash
31 | bash download_datasets.sh
32 | ```
33 | * To run the training, update the variables in `run.sh` to configure wandb, then run it.
34 | ```bash
35 | bash run.sh
36 | ```
37 | * To run the evaluation, update the evaluation variables in `run.sh`, then run it with the `--eval` flag.
38 | ```bash
39 | bash run.sh --eval
40 | ```
41 | * To process the evaluation results, follow these steps.
42 | ```python
43 | from relic.evaluator import read_csvs, extract_episodes_data_from_df
44 | import matplotlib.pyplot as plt
45 | 
46 | # Each row in the df is a trial.
47 | df = read_csvs("exp_data/eval_data/relic/relic_ExtObjNav/latest_*.csv")
48 | 
49 | # `data` is a dict. The keys are the metrics names.
50 | # The values are 2D arrays with shape (# Trials x # Episodes).
51 | # data["counts"] is a 1D array with size (# Episodes).
52 | # data["counts"][i] is the number of trials that include the ith episode.
53 | data = extract_episodes_data_from_df(df)
54 | 
55 | # To plot the data, create a mask to exclude the episodes that aren't
56 | # represented in all trials.
57 | mask = data["counts"] == data["counts"].max()
58 | 
59 | # Then plot the data.
60 | plt.plot(data["extobjnav_success"][:, mask].mean(axis=0))
61 | plt.xlabel("# of In-context Episodes")
62 | plt.ylabel("Success")
63 | ```
64 | 
65 | # Paper Experiments
66 | 
67 | You can change `CONFIG_NAME` in `run.sh` to run other experiments. The table below describes the available configs.
68 | 
69 | | Description | Config name | Figure | Task-Scene |
70 | | -- | -- | -- | -- |
71 | | ReLIC (64k context length) | baseline/relic_HSSD_64k | 3.b | ExtObjNav-HSSD |
72 | | ReLIC | baseline/baselines_comparison/relic_HSSD | 1, 2.b, 3.a | ExtObjNav-HSSD |
73 | | Single episode transformer | baseline/baselines_comparison/single_episode_HSSD | 1 | ExtObjNav-HSSD |
74 | | Our method without inter-episode attention | baseline/baselines_comparison/relic_HSSD_no_iea | 1 | ExtObjNav-HSSD |
75 | | Transformer-XL | baseline/baselines_comparison/trxl_HSSD | 1 | ExtObjNav-HSSD |
76 | | RL2 | baseline/baselines_comparison/rl2 | 1 | ExtObjNav-HSSD |
77 | | ReLIC | baseline/partial_updates_ablation/relic_replicaCAD | 2.a | ExtObjNav-ReplicaCAD |
78 | | ReLIC without partial updates | baseline/partial_updates_ablation/relic_replicaCAD_no_partial_updates | 2.a | ExtObjNav-ReplicaCAD |
79 | | ReLIC with Sink KV | baseline/sink_ablation/relic_replicaCAD_sink_kv | 11, 2.c | ExtObjNav-ReplicaCAD |
80 | | ReLIC with Sink Token | baseline/sink_ablation/relic_replicaCAD_sink_token | 2.c | ExtObjNav-ReplicaCAD |
81 | | ReLIC with Sink KV (only K is trainable) | baseline/sink_ablation/relic_replicaCAD_sink_kv0 | 11 | ExtObjNav-ReplicaCAD |
82 | | ReLIC with Sink KV (only V is trainable) | baseline/sink_ablation/relic_replicaCAD_sink_k0v | 11 | ExtObjNav-ReplicaCAD |
83 | | ReLIC with Sink KV (both are not trainable) | baseline/sink_ablation/relic_replicaCAD_sink_k0v0 | 11, 2.c | ExtObjNav-ReplicaCAD |
84 | | ReLIC without Sink mechanism | baseline/sink_ablation/relic_replicaCAD_no_sink | 2.c | ExtObjNav-ReplicaCAD |
85 | 
86 | # Use ReLIC on your own task
87 | 
88 | To run ReLIC on the [Darkroom](https://github.com/jon--lee/decision-pretrained-transformer) task:
89 | 
90 | * Change `CONFIG_NAME` to `baseline/custom_task/darkroom_relic` in `run.sh`.
91 | * Then run it: `bash run.sh`.
92 | 
93 | ReLIC works with `gym`-compatible environments. Adding a new task requires two files: the environment definition and the configuration.
94 | 
95 | We provide a `darkroom` environment wrapper with detailed comments in [example.py](relic/envs/custom_envs/example.py). We also provide two configs: one for ReLIC (multi-episode trials), [baseline/custom_task/darkroom_relic](relic/configs/baseline/custom_task/darkroom_relic.yaml), and one for single-episode training, [baseline/custom_task/darkroom_single_episode](relic/configs/baseline/custom_task/darkroom_single_episode.yaml).
96 | 
97 | Once you define the two required files, you can follow the steps in the getting started section.
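
As a quick reference, here is a condensed, self-contained sketch of the trial interface that the wrapper in [example.py](relic/envs/custom_envs/example.py) implements. It is illustrative only: `ToyBanditEnv` and `MinimalTrialEnv` are made-up names rather than part of the repo, and the `Dict` observation with a `reward_input` key used in `example.py` is omitted for brevity.

```python
import gym
from gym import spaces
import numpy as np
from habitat.core.dataset import BaseEpisode


class ToyBanditEnv(gym.Env):
    """Illustrative two-armed bandit; which arm pays out is the task."""

    def __init__(self):
        self.observation_space = spaces.Box(-np.inf, np.inf, shape=(1,))
        self.action_space = spaces.Discrete(2)
        self.good_arm = 0

    def reset(self):
        return np.zeros(1, dtype=np.float32)

    def step(self, action):
        reward = float(action == self.good_arm)
        # One-step episodes keep the sketch short.
        return np.zeros(1, dtype=np.float32), reward, True, {}


class MinimalTrialEnv(gym.Wrapper):
    """The minimal trial interface sketched from example.py."""

    def __init__(self, env, tasks):
        super().__init__(env)
        self.tasks = tasks
        self.current_task_idx = -1
        self.after_update()  # select the first task
        # Episode bookkeeping used by the evaluator.
        self.episodes = [str(t) for t in self.tasks]
        self.number_of_episodes = len(self.episodes)
        self._has_number_episode = True

    def reset(self, *args, **kwargs):
        # Resets must NOT change the task: all episodes in a trial share it.
        self.env.good_arm = self.tasks[self.current_task_idx]
        return self.env.reset()

    def after_update(self):
        # The only place the task changes; called between trials.
        self.current_task_idx = (self.current_task_idx + 1) % len(self.tasks)

    def current_episode(self, *args):
        task = self.tasks[self.current_task_idx]
        return BaseEpisode(episode_id=str(task), scene_id=str(task))

    @property
    def original_action_space(self):
        return self.action_space
```

The contract mirrors `example.py`: `reset` keeps the current task fixed, `after_update` is the only method that advances it, and `episodes`/`current_episode` give the evaluator stable identifiers.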
98 | -------------------------------------------------------------------------------- /relic/configs/pddl/hssd_domain.yaml: -------------------------------------------------------------------------------- 1 | types: 2 | static_obj_type: 3 | - art_receptacle_entity_type 4 | - obj_type 5 | - static_receptacle_entity_type 6 | obj_type: 7 | - movable_entity_type 8 | - goal_entity_type 9 | art_receptacle_entity_type: 10 | - cab_type 11 | - fridge_type 12 | 13 | constants: [] 14 | 15 | predicates: 16 | - name: in 17 | args: 18 | - name: obj 19 | expr_type: obj_type 20 | - name: receptacle 21 | expr_type: art_receptacle_entity_type 22 | set_state: 23 | obj_states: 24 | obj: receptacle 25 | 26 | - name: holding 27 | args: 28 | - name: obj 29 | expr_type: movable_entity_type 30 | - name: robot_id 31 | expr_type: robot_entity_type 32 | set_state: 33 | robot_states: 34 | robot_id: 35 | holding: obj 36 | 37 | - name: not_holding 38 | args: 39 | - name: robot_id 40 | expr_type: robot_entity_type 41 | set_state: 42 | robot_states: 43 | robot_id: 44 | should_drop: True 45 | 46 | - name: opened_cab 47 | args: 48 | - name: cab_id 49 | expr_type: cab_type 50 | set_state: 51 | art_states: 52 | cab_id: 53 | value: 0.45 54 | cmp: 'greater' 55 | override_thresh: 0.1 56 | 57 | - name: closed_cab 58 | args: 59 | - name: cab_id 60 | expr_type: cab_type 61 | set_state: 62 | arg_spec: 63 | name_match: "cab" 64 | art_states: 65 | cab_id: 66 | value: 0.0 67 | cmp: 'close' 68 | 69 | 70 | - name: opened_fridge 71 | args: 72 | - name: fridge_id 73 | expr_type: fridge_type 74 | set_state: 75 | art_states: 76 | fridge_id: 77 | value: 1.22 78 | cmp: 'greater' 79 | 80 | - name: closed_fridge 81 | args: 82 | - name: fridge_id 83 | expr_type: fridge_type 84 | set_state: 85 | art_states: 86 | fridge_id: 87 | value: 0.0 88 | cmp: 'close' 89 | 90 | - name: robot_at 91 | args: 92 | - name: Y 93 | expr_type: static_obj_type 94 | - name: robot_id 95 | expr_type: robot_entity_type 96 | set_state: 97 | robot_states: 98 | robot_id: 99 | pos: Y 100 | 101 | - name: at 102 | args: 103 | - name: obj 104 | expr_type: movable_entity_type 105 | - name: at_entity 106 | expr_type: static_obj_type 107 | set_state: 108 | obj_states: 109 | obj: at_entity 110 | 111 | actions: 112 | - name: nav 113 | parameters: 114 | - name: obj 115 | expr_type: obj_type 116 | - name: robot 117 | expr_type: robot_entity_type 118 | precondition: null 119 | postcondition: 120 | - robot_at(obj, robot) 121 | 122 | - name: nav_to_receptacle 123 | parameters: 124 | - name: marker 125 | expr_type: art_receptacle_entity_type 126 | - name: obj 127 | expr_type: obj_type 128 | - name: robot 129 | expr_type: robot_entity_type 130 | precondition: 131 | expr_type: AND 132 | sub_exprs: 133 | - in(obj, marker) 134 | postcondition: 135 | - robot_at(marker, robot) 136 | 137 | - name: pick 138 | parameters: 139 | - name: obj 140 | expr_type: movable_entity_type 141 | - name: robot 142 | expr_type: robot_entity_type 143 | precondition: 144 | expr_type: AND 145 | sub_exprs: 146 | - robot_at(obj, robot) 147 | - quantifier: FORALL 148 | inputs: 149 | - name: recep 150 | expr_type: cab_type 151 | expr_type: NAND 152 | sub_exprs: 153 | - in(obj, recep) 154 | - closed_cab(recep) 155 | postcondition: 156 | - holding(obj, robot) 157 | 158 | - name: place 159 | parameters: 160 | - name: place_obj 161 | expr_type: movable_entity_type 162 | - name: obj 163 | expr_type: goal_entity_type 164 | - name: robot 165 | expr_type: robot_entity_type 166 | precondition: 167 | expr_type: AND 168 | 
sub_exprs: 169 | - holding(place_obj, robot) 170 | - robot_at(obj, robot) 171 | postcondition: 172 | - not_holding(robot) 173 | - at(place_obj, obj) 174 | 175 | - name: open_fridge 176 | parameters: 177 | - name: fridge_id 178 | expr_type: fridge_type 179 | - name: obj 180 | expr_type: obj_type 181 | - name: robot 182 | expr_type: robot_entity_type 183 | precondition: 184 | expr_type: AND 185 | sub_exprs: 186 | - robot_at(fridge_id, robot) 187 | - closed_fridge(fridge_id) 188 | - in(obj,fridge_id) 189 | postcondition: 190 | - opened_fridge(fridge_id) 191 | 192 | - name: close_fridge 193 | parameters: 194 | - name: fridge_id 195 | expr_type: fridge_type 196 | - name: obj 197 | expr_type: obj_type 198 | - name: robot 199 | expr_type: robot_entity_type 200 | precondition: 201 | expr_type: AND 202 | sub_exprs: 203 | - robot_at(fridge_id, robot) 204 | - opened_fridge(fridge_id) 205 | - in(obj,fridge_id) 206 | postcondition: 207 | - closed_fridge(fridge_id) 208 | 209 | - name: open_cab 210 | parameters: 211 | - name: marker 212 | expr_type: cab_type 213 | - name: obj 214 | expr_type: obj_type 215 | - name: robot 216 | expr_type: robot_entity_type 217 | precondition: 218 | expr_type: AND 219 | sub_exprs: 220 | - robot_at(marker, robot) 221 | - closed_cab(marker) 222 | - in(obj,marker) 223 | postcondition: 224 | - opened_cab(marker) 225 | 226 | - name: close_cab 227 | parameters: 228 | - name: marker 229 | expr_type: cab_type 230 | - name: obj 231 | expr_type: obj_type 232 | - name: robot 233 | expr_type: robot_entity_type 234 | precondition: 235 | expr_type: AND 236 | sub_exprs: 237 | - robot_at(marker, robot) 238 | - opened_cab(marker) 239 | - in(obj,marker) 240 | postcondition: 241 | - closed_cab(marker) 242 | -------------------------------------------------------------------------------- /relic/trainer/rl2_storage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
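# RL^2-style rollout storage. The key difference from the base RolloutStorage
# shows up in data_generator below: with `change_done_masks` set (and dones not
# already forced to False during the rollout), episode-boundary dones are
# zeroed before packing sequences, so the recurrent state is carried across
# episodes within a trial.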
6 | import warnings 7 | from typing import Iterator, Optional 8 | 9 | import numpy as np 10 | import torch 11 | 12 | from habitat_baselines.common.baseline_registry import baseline_registry 13 | from habitat_baselines.common.rollout_storage import RolloutStorage 14 | from habitat_baselines.common.tensor_dict import DictTree, TensorDict 15 | from habitat_baselines.rl.models.rnn_state_encoder import ( 16 | build_pack_info_from_dones, 17 | build_rnn_build_seq_info, 18 | ) 19 | from habitat_baselines.utils.common import get_action_space_info 20 | 21 | 22 | @baseline_registry.register_storage 23 | class RL2RolloutStorage(RolloutStorage): 24 | def __init__( 25 | self, 26 | numsteps, 27 | num_envs, 28 | observation_space, 29 | action_space, 30 | actor_critic, 31 | is_double_buffered: bool = False, 32 | change_done_masks: bool = True, 33 | set_done_to_false_during_rollout: bool = False, 34 | ): 35 | self.change_done_masks = change_done_masks 36 | self.set_done_to_false_during_rollout = set_done_to_false_during_rollout 37 | action_shape, discrete_actions = get_action_space_info(action_space) 38 | 39 | self.buffers = TensorDict() 40 | self.buffers["observations"] = TensorDict() 41 | 42 | for sensor in observation_space.spaces: 43 | self.buffers["observations"][sensor] = torch.from_numpy( 44 | np.zeros( 45 | ( 46 | numsteps + 1, 47 | num_envs, 48 | *observation_space.spaces[sensor].shape, 49 | ), 50 | dtype=observation_space.spaces[sensor].dtype, 51 | ) 52 | ) 53 | 54 | self.buffers["recurrent_hidden_states"] = torch.zeros( 55 | numsteps + 1, 56 | num_envs, 57 | actor_critic.num_recurrent_layers, 58 | actor_critic.recurrent_hidden_size, 59 | ) 60 | 61 | self.buffers["rewards"] = torch.zeros(numsteps + 1, num_envs, 1) 62 | self.buffers["value_preds"] = torch.zeros(numsteps + 1, num_envs, 1) 63 | self.buffers["returns"] = torch.zeros(numsteps + 1, num_envs, 1) 64 | 65 | self.buffers["action_log_probs"] = torch.zeros(numsteps + 1, num_envs, 1) 66 | 67 | if action_shape is None: 68 | action_shape = action_space.shape 69 | 70 | self.buffers["actions"] = torch.zeros(numsteps + 1, num_envs, *action_shape) 71 | self.buffers["prev_actions"] = torch.zeros( 72 | numsteps + 1, num_envs, *action_shape 73 | ) 74 | if discrete_actions: 75 | assert isinstance(self.buffers["actions"], torch.Tensor) 76 | assert isinstance(self.buffers["prev_actions"], torch.Tensor) 77 | self.buffers["actions"] = self.buffers["actions"].long() 78 | self.buffers["prev_actions"] = self.buffers["prev_actions"].long() 79 | 80 | self.buffers["masks"] = torch.zeros(numsteps + 1, num_envs, 1, dtype=torch.bool) 81 | 82 | self.is_double_buffered = is_double_buffered 83 | self._nbuffers = 2 if is_double_buffered else 1 84 | self._num_envs = num_envs 85 | 86 | assert (self._num_envs % self._nbuffers) == 0 87 | 88 | self.num_steps = numsteps 89 | self.current_rollout_step_idxs = [0 for _ in range(self._nbuffers)] 90 | 91 | # The default device to torch is the CPU, so everything is on the CPU. 
92 | self.device = torch.device("cpu") 93 | 94 | def reset_recurrent_hidden_states(self): 95 | self.buffers[0]["recurrent_hidden_states"] = torch.zeros_like( 96 | self.buffers[0]["recurrent_hidden_states"] 97 | ) 98 | 99 | def data_generator( 100 | self, 101 | advantages: Optional[torch.Tensor], 102 | num_mini_batch: int, 103 | ) -> Iterator[DictTree]: 104 | assert isinstance(self.buffers["returns"], torch.Tensor) 105 | num_environments = self.buffers["returns"].size(1) 106 | assert num_environments >= num_mini_batch, ( 107 | "Trainer requires the number of environments ({}) " 108 | "to be greater than or equal to the number of " 109 | "trainer mini batches ({}).".format(num_environments, num_mini_batch) 110 | ) 111 | if num_environments % num_mini_batch != 0: 112 | warnings.warn( 113 | "Number of environments ({}) is not a multiple of the" 114 | " number of mini batches ({}). This results in mini batches" 115 | " of different sizes, which can harm training performance.".format( 116 | num_environments, num_mini_batch 117 | ) 118 | ) 119 | 120 | dones_cpu = ( 121 | torch.logical_not(self.buffers["masks"]) 122 | .cpu() 123 | .view(-1, self._num_envs) 124 | .numpy() 125 | ) 126 | if self.change_done_masks and not self.set_done_to_false_during_rollout: 127 | dones_cpu = np.zeros_like(dones_cpu, dtype=bool) 128 | 129 | for inds in torch.randperm(num_environments).chunk(num_mini_batch): 130 | curr_slice = (slice(0, self.current_rollout_step_idx), inds) 131 | 132 | batch = self.buffers[curr_slice] 133 | if advantages is not None: 134 | batch["advantages"] = advantages[curr_slice] 135 | batch["recurrent_hidden_states"] = batch["recurrent_hidden_states"][0:1] 136 | 137 | batch.map_in_place(lambda v: v.flatten(0, 1)) 138 | 139 | batch["rnn_build_seq_info"] = build_rnn_build_seq_info( 140 | device=self.device, 141 | build_fn_result=build_pack_info_from_dones( 142 | dones_cpu[0 : self.current_rollout_step_idx, inds.numpy()].reshape( 143 | -1, len(inds) 144 | ), 145 | ), 146 | ) 147 | 148 | yield batch.to_tree() 149 | -------------------------------------------------------------------------------- /relic/trainer/transformers_agent_access_mgr.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Callable, Dict, Optional 2 | from habitat_baselines.common.baseline_registry import baseline_registry 3 | from habitat_baselines.rl.ppo.single_agent_access_mgr import ( 4 | SingleAgentAccessMgr, 5 | get_rollout_obs_space, 6 | ) 7 | import gym.spaces as spaces 8 | from habitat_baselines.common.env_spec import EnvironmentSpec 9 | from habitat_baselines.common.storage import Storage 10 | from habitat_baselines.rl.ppo.policy import NetPolicy 11 | from habitat import logger 12 | 13 | from torch import optim 14 | from bisect import bisect_right 15 | import torch 16 | from habitat_baselines.rl.ppo.ppo import PPO 17 | 18 | if TYPE_CHECKING: 19 | from omegaconf import DictConfig 20 | 21 | 22 | class CustomSequentialLR(optim.lr_scheduler.SequentialLR): 23 | def step(self, epoch=None): 24 | if epoch is not None: 25 | self.last_epoch = epoch 26 | else: 27 | self.last_epoch += 1 28 | idx = bisect_right(self._milestones, self.last_epoch) 29 | scheduler = self._schedulers[idx] 30 | if idx > 0: 31 | scheduler.step(self.last_epoch - self._milestones[idx - 1]) 32 | else: 33 | scheduler.step(self.last_epoch) 34 | self._last_lr = scheduler.get_last_lr() 35 | 36 | 37 | @baseline_registry.register_agent_access_mgr 38 | class 
TransformerSingleAgentAccessMgr(SingleAgentAccessMgr): 39 | def _init_policy_and_updater(self, lr_schedule_fn, resume_state): 40 | self._actor_critic = self._create_policy() 41 | 42 | self._updater = self._create_updater(self._actor_critic) 43 | 44 | if self._updater.optimizer is None: 45 | self._lr_scheduler = None 46 | else: 47 | scheds = [] 48 | milestones = [] 49 | if self._ppo_cfg.warmup: 50 | scheds.append( 51 | optim.lr_scheduler.LinearLR( 52 | self._updater.optimizer, 53 | start_factor=self._ppo_cfg.warmup_start_factor, 54 | total_iters=self._ppo_cfg.warmup_total_iters, 55 | end_factor=self._ppo_cfg.warmup_end_factor, 56 | ) 57 | ) 58 | milestones.append(self._ppo_cfg.warmup_total_iters) 59 | 60 | if self._ppo_cfg.lr_scheduler == "cosine_decay": 61 | scheds.append( 62 | optim.lr_scheduler.CosineAnnealingLR( 63 | self._updater.optimizer, 64 | T_max=self._ppo_cfg.lrsched_T_max, 65 | eta_min=self._ppo_cfg.lrsched_eta_min, 66 | ) 67 | ) 68 | elif self._ppo_cfg.lr_scheduler == "cosine_annealing_warm_restarts": 69 | scheds.append( 70 | optim.lr_scheduler.CosineAnnealingWarmRestarts( 71 | self._updater.optimizer, 72 | T_0=self._ppo_cfg.lrsched_T_0, 73 | eta_min=self._ppo_cfg.lrsched_eta_min, 74 | ) 75 | ) 76 | elif not self._ppo_cfg.lr_scheduler: 77 | pass 78 | else: 79 | raise ValueError 80 | 81 | if scheds and len(scheds) > 1: 82 | self._lr_scheduler = CustomSequentialLR( 83 | self._updater.optimizer, scheds, milestones 84 | ) 85 | elif scheds: 86 | self._lr_scheduler = scheds[0] 87 | else: 88 | self._lr_scheduler = None 89 | 90 | if resume_state is not None: 91 | self._updater.load_state_dict(resume_state["state_dict"]) 92 | self._updater.load_state_dict( 93 | {"actor_critic." + k: v for k, v, in resume_state["state_dict"].items()} 94 | ) 95 | self._policy_action_space = self._env_spec.action_space 96 | 97 | def load_state_dict(self, state: Dict, strict=True) -> None: 98 | self._actor_critic.load_state_dict(state["state_dict"], strict=strict) 99 | if self._updater is not None: 100 | self._updater.load_state_dict(state) 101 | if "lr_sched_state" in state: 102 | self._lr_scheduler.load_state_dict(state["lr_sched_state"]) 103 | 104 | def _create_storage( 105 | self, 106 | num_envs: int, 107 | env_spec: EnvironmentSpec, 108 | actor_critic: NetPolicy, 109 | policy_action_space: spaces.Space, 110 | config: "DictConfig", 111 | device, 112 | ) -> Storage: 113 | """ 114 | Default behavior for setting up and initializing the rollout storage. 
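
        This variant additionally threads through the transformer-specific
        options configured below: optional float16 rollout buffers
        (`storage_low_precision`), rollouts kept on the CPU, separate
        rollout/policy copies, the acting context length, and an `is_memory`
        flag when the policy model is a Transformer-XL.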
115 | """ 116 | 117 | obs_space = get_rollout_obs_space( 118 | env_spec.observation_space, actor_critic, config 119 | ) 120 | ppo_cfg = config.habitat_baselines.rl.ppo 121 | dtype = ( 122 | torch.float16 123 | if config.habitat_baselines.rl.ppo.storage_low_precision 124 | else torch.float32 125 | ) 126 | separate_rollout_and_policy = ( 127 | config.habitat_baselines.separate_rollout_and_policy 128 | ) 129 | rollout_on_cpu = config.habitat_baselines.rollout_on_cpu 130 | rollouts = baseline_registry.get_storage( 131 | config.habitat_baselines.rollout_storage_name 132 | )( 133 | numsteps=ppo_cfg.num_steps, 134 | num_envs=num_envs, 135 | observation_space=obs_space, 136 | action_space=policy_action_space, 137 | actor_critic=actor_critic, 138 | is_double_buffered=ppo_cfg.use_double_buffered_sampler, 139 | device=device, 140 | separate_rollout_and_policy=separate_rollout_and_policy, 141 | dtype=dtype, 142 | freeze_visual_feats=not config.habitat_baselines.rl.ddppo.train_encoder, 143 | on_cpu=rollout_on_cpu, 144 | acting_context=ppo_cfg.acting_context, 145 | is_memory=config.habitat_baselines.rl.policy.main_agent.transformer_config.model_name 146 | == "transformerxl", 147 | ) 148 | return rollouts 149 | 150 | def after_update(self): 151 | # if ( 152 | # self._lr_scheduler is not None 153 | # ): 154 | # self._lr_scheduler.step() # type: ignore 155 | self._updater.after_update() 156 | 157 | def _step_sched(self, steps): 158 | if self._lr_scheduler: 159 | self._lr_scheduler.step(steps) 160 | -------------------------------------------------------------------------------- /relic/configs/pddl/domain.yaml: -------------------------------------------------------------------------------- 1 | types: 2 | static_obj_type: 3 | - art_receptacle_entity_type 4 | - obj_type 5 | obj_type: 6 | - movable_entity_type 7 | - goal_entity_type 8 | art_receptacle_entity_type: 9 | - cab_type 10 | - fridge_type 11 | 12 | 13 | constants: 14 | - name: cab_push_point_7 15 | expr_type: cab_type 16 | - name: cab_push_point_6 17 | expr_type: cab_type 18 | - name: cab_push_point_5 19 | expr_type: cab_type 20 | - name: cab_push_point_4 21 | expr_type: cab_type 22 | - name: fridge_push_point 23 | expr_type: fridge_type 24 | 25 | predicates: 26 | - name: in 27 | args: 28 | - name: obj 29 | expr_type: obj_type 30 | - name: receptacle 31 | expr_type: art_receptacle_entity_type 32 | set_state: 33 | obj_states: 34 | obj: receptacle 35 | 36 | - name: holding 37 | args: 38 | - name: obj 39 | expr_type: movable_entity_type 40 | - name: robot_id 41 | expr_type: robot_entity_type 42 | set_state: 43 | robot_states: 44 | robot_id: 45 | holding: obj 46 | 47 | - name: not_holding 48 | args: 49 | - name: robot_id 50 | expr_type: robot_entity_type 51 | set_state: 52 | robot_states: 53 | robot_id: 54 | should_drop: True 55 | 56 | - name: opened_cab 57 | args: 58 | - name: cab_id 59 | expr_type: cab_type 60 | set_state: 61 | art_states: 62 | cab_id: 63 | value: 0.45 64 | cmp: 'greater' 65 | override_thresh: 0.1 66 | 67 | - name: closed_cab 68 | args: 69 | - name: cab_id 70 | expr_type: cab_type 71 | set_state: 72 | arg_spec: 73 | name_match: "cab" 74 | art_states: 75 | cab_id: 76 | value: 0.0 77 | cmp: 'close' 78 | 79 | 80 | - name: opened_fridge 81 | args: 82 | - name: fridge_id 83 | expr_type: fridge_type 84 | set_state: 85 | art_states: 86 | fridge_id: 87 | value: 1.22 88 | cmp: 'greater' 89 | 90 | - name: closed_fridge 91 | args: 92 | - name: fridge_id 93 | expr_type: fridge_type 94 | set_state: 95 | art_states: 96 | fridge_id: 97 | value: 0.0 
98 | cmp: 'close' 99 | 100 | - name: robot_at 101 | args: 102 | - name: Y 103 | expr_type: static_obj_type 104 | - name: robot_id 105 | expr_type: robot_entity_type 106 | set_state: 107 | robot_states: 108 | robot_id: 109 | pos: Y 110 | 111 | - name: at 112 | args: 113 | - name: obj 114 | expr_type: movable_entity_type 115 | - name: at_entity 116 | expr_type: static_obj_type 117 | set_state: 118 | obj_states: 119 | obj: at_entity 120 | 121 | actions: 122 | - name: nav 123 | parameters: 124 | - name: obj 125 | expr_type: obj_type 126 | - name: robot 127 | expr_type: robot_entity_type 128 | precondition: null 129 | postcondition: 130 | - robot_at(obj, robot) 131 | 132 | - name: nav_to_receptacle 133 | parameters: 134 | - name: marker 135 | expr_type: art_receptacle_entity_type 136 | - name: obj 137 | expr_type: obj_type 138 | - name: robot 139 | expr_type: robot_entity_type 140 | precondition: 141 | expr_type: AND 142 | sub_exprs: 143 | - in(obj, marker) 144 | postcondition: 145 | - robot_at(marker, robot) 146 | 147 | - name: pick 148 | parameters: 149 | - name: obj 150 | expr_type: movable_entity_type 151 | - name: robot 152 | expr_type: robot_entity_type 153 | precondition: 154 | expr_type: AND 155 | sub_exprs: 156 | - robot_at(obj, robot) 157 | - quantifier: FORALL 158 | inputs: 159 | - name: recep 160 | expr_type: cab_type 161 | expr_type: NAND 162 | sub_exprs: 163 | - in(obj, recep) 164 | - closed_cab(recep) 165 | postcondition: 166 | - holding(obj, robot) 167 | 168 | - name: place 169 | parameters: 170 | - name: place_obj 171 | expr_type: movable_entity_type 172 | - name: obj 173 | expr_type: goal_entity_type 174 | - name: robot 175 | expr_type: robot_entity_type 176 | precondition: 177 | expr_type: AND 178 | sub_exprs: 179 | - holding(place_obj, robot) 180 | - robot_at(obj, robot) 181 | postcondition: 182 | - not_holding(robot) 183 | - at(place_obj, obj) 184 | 185 | - name: open_fridge 186 | parameters: 187 | - name: fridge_id 188 | expr_type: fridge_type 189 | - name: obj 190 | expr_type: obj_type 191 | - name: robot 192 | expr_type: robot_entity_type 193 | precondition: 194 | expr_type: AND 195 | sub_exprs: 196 | - robot_at(fridge_id, robot) 197 | - closed_fridge(fridge_id) 198 | - in(obj,fridge_id) 199 | postcondition: 200 | - opened_fridge(fridge_id) 201 | 202 | - name: close_fridge 203 | parameters: 204 | - name: fridge_id 205 | expr_type: fridge_type 206 | - name: obj 207 | expr_type: obj_type 208 | - name: robot 209 | expr_type: robot_entity_type 210 | precondition: 211 | expr_type: AND 212 | sub_exprs: 213 | - robot_at(fridge_id, robot) 214 | - opened_fridge(fridge_id) 215 | - in(obj,fridge_id) 216 | postcondition: 217 | - closed_fridge(fridge_id) 218 | 219 | - name: open_cab 220 | parameters: 221 | - name: marker 222 | expr_type: cab_type 223 | - name: obj 224 | expr_type: obj_type 225 | - name: robot 226 | expr_type: robot_entity_type 227 | precondition: 228 | expr_type: AND 229 | sub_exprs: 230 | - robot_at(marker, robot) 231 | - closed_cab(marker) 232 | - in(obj,marker) 233 | postcondition: 234 | - opened_cab(marker) 235 | 236 | - name: close_cab 237 | parameters: 238 | - name: marker 239 | expr_type: cab_type 240 | - name: obj 241 | expr_type: obj_type 242 | - name: robot 243 | expr_type: robot_entity_type 244 | precondition: 245 | expr_type: AND 246 | sub_exprs: 247 | - robot_at(marker, robot) 248 | - opened_cab(marker) 249 | - in(obj,marker) 250 | postcondition: 251 | - closed_cab(marker) 252 | 253 | ######################################################### 254 | # 
Receptacle name only based variants of the receptacle skills. This does not 255 | # require any information about knowing which objects the receptacle 256 | # contains. 257 | - name: nav_to_receptacle_by_name 258 | parameters: 259 | - name: marker 260 | expr_type: art_receptacle_entity_type 261 | - name: robot 262 | expr_type: robot_entity_type 263 | precondition: null 264 | postcondition: 265 | - robot_at(marker, robot) 266 | 267 | - name: open_fridge_by_name 268 | parameters: 269 | - name: fridge_id 270 | expr_type: fridge_type 271 | - name: robot 272 | expr_type: robot_entity_type 273 | precondition: 274 | expr_type: AND 275 | sub_exprs: 276 | - robot_at(fridge_id, robot) 277 | - closed_fridge(fridge_id) 278 | postcondition: 279 | - opened_fridge(fridge_id) 280 | 281 | - name: close_fridge_by_name 282 | parameters: 283 | - name: fridge_id 284 | expr_type: fridge_type 285 | - name: robot 286 | expr_type: robot_entity_type 287 | precondition: 288 | expr_type: AND 289 | sub_exprs: 290 | - robot_at(fridge_id, robot) 291 | - opened_fridge(fridge_id) 292 | postcondition: 293 | - closed_fridge(fridge_id) 294 | 295 | - name: open_cab_by_name 296 | parameters: 297 | - name: marker 298 | expr_type: cab_type 299 | - name: robot 300 | expr_type: robot_entity_type 301 | precondition: 302 | expr_type: AND 303 | sub_exprs: 304 | - robot_at(marker, robot) 305 | - closed_cab(marker) 306 | postcondition: 307 | - opened_cab(marker) 308 | 309 | - name: close_cab_by_name 310 | parameters: 311 | - name: marker 312 | expr_type: cab_type 313 | - name: robot 314 | expr_type: robot_entity_type 315 | precondition: 316 | expr_type: AND 317 | sub_exprs: 318 | - robot_at(marker, robot) 319 | - opened_cab(marker) 320 | postcondition: 321 | - closed_cab(marker) 322 | -------------------------------------------------------------------------------- /relic/tasks/utils.py: -------------------------------------------------------------------------------- 1 | import magnum as mn 2 | import habitat_sim 3 | import numpy as np 4 | 5 | # https://github.com/facebookresearch/habitat-sim/blob/366294cadd914791e57d7d70e61ae67026386f0b/examples/tutorials/colabs/ECCV_2020_Advanced_Features.ipynb 6 | def get_2d_point(sim, sensor_name, point_3d): 7 | # get the scene render camera and sensor object 8 | render_camera = sim._sensors[sensor_name]._sensor_object.render_camera 9 | 10 | # use the camera and projection matrices to transform the point onto the near plane 11 | projected_point_3d = render_camera.projection_matrix.transform_point( 12 | render_camera.camera_matrix.transform_point(point_3d) 13 | ) 14 | 15 | if projected_point_3d[2] < 0: 16 | return None 17 | 18 | # convert the 3D near plane point to integer pixel space 19 | point_2d = mn.Vector2(projected_point_3d[0], -projected_point_3d[1]) 20 | point_2d = point_2d / render_camera.projection_size()[0] 21 | point_2d += mn.Vector2(0.5) 22 | point_2d *= render_camera.viewport 23 | return mn.Vector2i(point_2d) 24 | 25 | 26 | def is_target_occluded( 27 | target: mn.Vector3, 28 | agent_position: mn.Vector3, 29 | agent_height: float, 30 | task, 31 | ignore_object_ids: set = set(), 32 | ignore_objects: bool = False, 33 | ignore_receptacles: bool = False, 34 | ignore_non_negative: bool = False, 35 | ) -> bool: 36 | 37 | sim = task._sim 38 | ray = habitat_sim.geo.Ray() 39 | ray.origin = agent_position + mn.Vector3(0, agent_height, 0) 40 | ray.direction = target - ray.origin 41 | raycast_results = sim.cast_ray(ray) 42 | 43 | hits = [ 44 | x.object_id 45 | for x in raycast_results.hits 46 | if 
x.ray_distance < 1 and x.object_id not in ignore_object_ids 47 | ] 48 | if ignore_objects: 49 | hits = [x for x in hits if x not in task.all_object_ids] 50 | if ignore_receptacles: 51 | hits = [x for x in hits if x not in task.all_receptacles_ids] 52 | if ignore_non_negative: 53 | hits = [x for x in hits if x < 0] 54 | 55 | if hits: 56 | return True 57 | return False 58 | 59 | 60 | def is_connected(l2_dist, dist, angle_diff, threshold=0.1): 61 | # print(angle_diff) 62 | return np.abs(l2_dist / dist - 1) < 0.15 and np.abs(angle_diff) < 20 / 180 * np.pi 63 | 64 | 65 | def get_dists(c, points, task): 66 | path = habitat_sim.MultiGoalShortestPath() 67 | path.requested_start = c 68 | path.requested_ends = points 69 | did_find_a_path = task._sim.pathfinder.find_path(path) 70 | if not did_find_a_path: 71 | return None 72 | dist = path.geodesic_distance 73 | l2_dist = np.linalg.norm(points[path.closest_end_point_index] - c) 74 | return l2_dist, dist, path.closest_end_point_index 75 | 76 | 77 | def get_angle(a, b): 78 | return np.arctan2(*np.asarray(b - a)[[2, 0]]) 79 | 80 | 81 | def remove_closest(c, points, task, angle=None): 82 | points = sorted(points, key=lambda x: np.linalg.norm(x - c), reverse=False) 83 | out = get_dists(c, points, task) 84 | if out is None: 85 | return points 86 | l2_dist, dist, closest_end_point_index = out 87 | angle_new = get_angle(c, points[0]) 88 | if angle is not None: 89 | angle_diff = angle_new - angle 90 | else: 91 | angle_diff = 0 92 | if is_connected(l2_dist, dist, angle_diff=angle_diff): 93 | return remove_closest( 94 | points[closest_end_point_index], 95 | [x for i, x in enumerate(points) if i != closest_end_point_index], 96 | angle=angle_new, 97 | task=task, 98 | ) 99 | else: 100 | return points 101 | 102 | 103 | def process_navigation_points(c, points, task): 104 | new_points = [] 105 | for i in range(len(points)): 106 | for j in range(i + 1, len(points)): 107 | if points[j] == points[i]: 108 | break 109 | else: 110 | new_points.append(points[i]) 111 | points = new_points 112 | 113 | to_keep = [] 114 | while points: 115 | last_len = len(points) 116 | points = sorted(points, key=lambda x: np.linalg.norm(x - c), reverse=False) 117 | to_keep.append(points[0]) 118 | pot_points = [] 119 | last_len = -1 120 | while len(pot_points) != last_len: 121 | last_len = len(pot_points) 122 | pot_points = remove_closest(points[0], points[1:], task=task) 123 | points = [points[0], *pot_points] 124 | points = points[1:] 125 | return to_keep 126 | 127 | 128 | def get_navigation_points( 129 | c, 130 | task, 131 | r=2, 132 | n=60, 133 | eps=0.1, 134 | angle_tol=5, 135 | max_num_trials=20, 136 | height=1.2, 137 | target_object_id=None, 138 | ignore_objects: bool = False, 139 | ignore_receptacles: bool = False, 140 | ignore_non_negative: bool = False, 141 | cleanup_nav_points: bool = False, 142 | ): 143 | c = np.asarray(c) 144 | output = [] 145 | for i in np.linspace(0, 2 * np.pi, n): 146 | for trial_i in range(1, max_num_trials + 1): 147 | shift = [(trial_i * eps) * np.cos(i), 0, (trial_i * eps) * np.sin(i)] 148 | projection = task._sim.pathfinder.snap_point( 149 | c + shift, task._sim.largest_island_idx 150 | ) 151 | dst = np.linalg.norm(projection - c) 152 | if ( 153 | dst <= r 154 | and np.abs(np.arctan2(*np.asarray(projection - c)[[2, 0]]) - i) 155 | < angle_tol / 180 * np.pi 156 | ): 157 | break 158 | 159 | if dst <= r and not is_target_occluded( 160 | c, 161 | projection, 162 | height, 163 | task, 164 | ignore_object_ids=[target_object_id], 165 | 
ignore_objects=ignore_objects, 166 | ignore_receptacles=ignore_receptacles, 167 | ignore_non_negative=ignore_non_negative, 168 | ): 169 | output.append(projection) 170 | 171 | if cleanup_nav_points: 172 | output = process_navigation_points(c, output, task) 173 | return output 174 | 175 | 176 | def get_navigation_points_grid( 177 | c, 178 | task, 179 | r=2, 180 | n=60, 181 | eps=0.1, 182 | angle_tol=5, 183 | max_num_trials=20, 184 | height=1.2, 185 | target_object_id=None, 186 | ignore_objects: bool = False, 187 | ignore_receptacles: bool = False, 188 | ignore_non_negative: bool = False, 189 | cleanup_nav_points: bool = False, 190 | ): 191 | for k in range(max_num_trials): 192 | c = np.asarray(c) 193 | island_y = task._sim.pathfinder.get_random_navigable_point_near( 194 | c, 3, island_index=task._sim.largest_island_idx 195 | )[1] 196 | points = np.meshgrid(np.arange(-r, r, eps), np.arange(-r, r, eps)) 197 | mask = points[0] ** 2 + points[1] ** 2 < r**2 - eps 198 | xs = points[0][mask] + c[0] 199 | ys = np.zeros_like(xs) + island_y 200 | zs = points[1][mask] + c[2] 201 | filtered_points = [ 202 | (x, y, z) 203 | for x, y, z in zip(xs, ys, zs) 204 | if task._sim.is_navigable([x, y, z]) 205 | ] 206 | output = [ 207 | np.asarray(x) 208 | for x in filtered_points 209 | if not is_target_occluded( 210 | c, 211 | np.asarray(x), 212 | height + eps * k, 213 | task, 214 | ignore_object_ids=[target_object_id], 215 | ignore_objects=ignore_objects, 216 | ignore_receptacles=ignore_receptacles, 217 | ignore_non_negative=ignore_non_negative, 218 | ) 219 | ] 220 | if cleanup_nav_points: 221 | output = process_navigation_points(c, output, task) 222 | if output: 223 | break 224 | return output 225 | 226 | 227 | def get_obj_pixel_counts(task, margin=None, strict=True): 228 | sim_obs = task._sim.get_sensor_observations() 229 | observations = task._sim._sensor_suite.sensors["head_panoptic"].get_observation( 230 | sim_obs 231 | ) 232 | objs_ids, objs_count = np.unique(observations, return_counts=True) 233 | objs_ids -= task._sim.habitat_config.object_ids_start 234 | id2count = dict(zip(objs_ids, objs_count)) 235 | if margin is not None: 236 | observations[margin:-margin, margin:-margin] = -10000 237 | margin_objs_ids, margin_objs_count = np.unique(observations, return_counts=True) 238 | margin_objs_ids -= task._sim.habitat_config.object_ids_start 239 | margin_id2count = dict(zip(margin_objs_ids, margin_objs_count)) 240 | for k in margin_id2count: 241 | if k in id2count and margin_id2count[k] > (0 if strict else id2count[k]): 242 | del id2count[k] 243 | return id2count 244 | -------------------------------------------------------------------------------- /relic/policies/llamarl/configuration_llamarl.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 5 | # and OPT implementations in this library. It has been modified from its 6 | # original forms to accommodate minor architectural differences compared 7 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | """ LLaMARL model configuration""" 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | 25 | logger = logging.get_logger(__name__) 26 | 27 | LLAMARL_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 28 | 29 | 30 | class LlamaRLConfig(PretrainedConfig): 31 | r""" 32 | This is the configuration class to store the configuration of a [`LlamaRLModel`]. It is used to instantiate an LLaMARL 33 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 34 | defaults will yield a similar configuration to that of the LLaMARL-7B. 35 | 36 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 37 | documentation from [`PretrainedConfig`] for more information. 38 | 39 | 40 | Args: 41 | vocab_size (`int`, *optional*, defaults to 32000): 42 | Vocabulary size of the LLaMARL model. Defines the number of different tokens that can be represented by the 43 | `inputs_ids` passed when calling [`LlamaRLModel`] 44 | hidden_size (`int`, *optional*, defaults to 4096): 45 | Dimension of the hidden representations. 46 | intermediate_size (`int`, *optional*, defaults to 11008): 47 | Dimension of the MLP representations. 48 | num_hidden_layers (`int`, *optional*, defaults to 32): 49 | Number of hidden layers in the Transformer encoder. 50 | num_attention_heads (`int`, *optional*, defaults to 32): 51 | Number of attention heads for each attention layer in the Transformer encoder. 52 | num_key_value_heads (`int`, *optional*): 53 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 54 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 55 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 56 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 57 | by meanpooling all the original heads within that group. For more details checkout [this 58 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 59 | `num_attention_heads`. 60 | pretraining_tp (`int`, *optional*, defaults to `1`): 61 | Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this 62 | document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is 63 | necessary to ensure exact reproducibility of the pretraining results. Please refer to [this 64 | issue](https://github.com/pytorch/pytorch/issues/76232). 65 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 66 | The non-linear activation function (function or string) in the decoder. 67 | max_position_embeddings (`int`, *optional*, defaults to 2048): 68 | The maximum sequence length that this model might ever be used with. Typically set this to something large 69 | just in case (e.g., 512 or 1024 or 2048). 
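        inter_episodes_attention (`bool`, *optional*, defaults to `False`):
            ReLIC-specific; description inferred from usage: whether attention is
            allowed to cross episode boundaries within a trial.
        add_sink_kv (`bool`, *optional*, defaults to `False`):
            ReLIC-specific; inferred: prepend learnable sink key/value vectors to
            each attention layer (the Sink-KV mechanism).
        banded_attention (`bool`, *optional*, defaults to `False`):
            ReLIC-specific; inferred: restrict each query to a band of at most
            `context_len` recent steps.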
70 | initializer_range (`float`, *optional*, defaults to 0.02): 71 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 72 | rms_norm_eps (`float`, *optional*, defaults to 1e-12): 73 | The epsilon used by the rms normalization layers. 74 | use_cache (`bool`, *optional*, defaults to `True`): 75 | Whether or not the model should return the last key/values attentions (not used by all models). Only 76 | relevant if `config.is_decoder=True`. 77 | tie_word_embeddings(`bool`, *optional*, defaults to `False`): 78 | Whether to tie weight embeddings 79 | rope_scaling (`Dict`, *optional*): 80 | Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling 81 | strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format 82 | is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update 83 | `max_position_embeddings` to the expected new maximum. See the following thread for more information on how 84 | these scaling strategies behave: 85 | https://www.reddit.com/r/LocalLLaMARL/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an 86 | experimental feature, subject to breaking API changes in future versions. 87 | 88 | Example: 89 | 90 | ```python 91 | >>> from transformers import LlamaRLModel, LlamaRLConfig 92 | 93 | >>> # Initializing a LLaMARL llamaRL-7b style configuration 94 | >>> configuration = LlamaRLConfig() 95 | 96 | >>> # Initializing a model from the llamaRL-7b style configuration 97 | >>> model = LlamaRLModel(configuration) 98 | 99 | >>> # Accessing the model configuration 100 | >>> configuration = model.config 101 | ```""" 102 | model_type = "llamaRL" 103 | keys_to_ignore_at_inference = ["past_key_values"] 104 | 105 | def __init__( 106 | self, 107 | max_num_sequence=10000, 108 | hidden_size=4096, 109 | intermediate_size=11008, 110 | num_hidden_layers=32, 111 | num_attention_heads=32, 112 | num_key_value_heads=None, 113 | hidden_act="silu", 114 | max_position_embeddings=2048, 115 | initializer_range=0.02, 116 | rms_norm_eps=1e-6, 117 | use_cache=True, 118 | pad_token_id=None, 119 | bos_token_id=1, 120 | eos_token_id=2, 121 | pretraining_tp=1, 122 | tie_word_embeddings=False, 123 | position_embed_type="learnable", 124 | rope_scaling=None, 125 | depth_dropout_p: float = 0, 126 | inter_episodes_attention: bool = False, 127 | reset_position_index: bool = True, 128 | add_sequence_idx_embed: bool = False, 129 | sequence_embed_type: str = "learnable", 130 | gated_residual: bool = False, 131 | context_len: bool = 0, 132 | banded_attention: bool = False, 133 | orphan_steps_attention: bool = True, 134 | add_sink_kv: bool = False, 135 | add_sink_tokens: bool = False, 136 | num_sink_tokens: int = 1, 137 | mul_factor_for_sink_attn: bool = True, 138 | is_causal=True, 139 | is_sink_v_trainable=True, 140 | is_sink_k_trainable=True, 141 | **kwargs, 142 | ): 143 | self.max_num_sequence = max_num_sequence 144 | self.max_position_embeddings = max_position_embeddings 145 | self.hidden_size = hidden_size 146 | self.intermediate_size = intermediate_size 147 | self.num_hidden_layers = num_hidden_layers 148 | self.num_attention_heads = num_attention_heads 149 | 150 | # for backward compatibility 151 | if num_key_value_heads is None: 152 | num_key_value_heads = num_attention_heads 153 | 154 | self.num_key_value_heads = num_key_value_heads 155 | self.hidden_act = hidden_act 156 | self.initializer_range = initializer_range 157 | 
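        # Illustrative note (not in the original source): the `rope_scaling`
        # dict assigned below is checked by `_rope_scaling_validation`, which
        # accepts only dicts of the exact form
        #     {"type": "linear" or "dynamic", "factor": <float strictly > 1.0>}
        # e.g. LlamaRLConfig(rope_scaling={"type": "dynamic", "factor": 2.0}).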
157 |         self.rms_norm_eps = rms_norm_eps
158 |         self.pretraining_tp = pretraining_tp
159 |         self.use_cache = use_cache
160 |         self.position_embed_type = position_embed_type
161 |         self.rope_scaling = rope_scaling
162 |         self._rope_scaling_validation()
163 |         self.inter_episodes_attention = inter_episodes_attention
164 |         self.reset_position_index = reset_position_index
165 |         self.add_sequence_idx_embed = add_sequence_idx_embed
166 |         self.sequence_embed_type = sequence_embed_type
167 |         self.gated_residual = gated_residual
168 |         self.context_len = context_len
169 |         self.banded_attention = banded_attention
170 |         self.orphan_steps_attention = orphan_steps_attention
171 |         self.add_sink_kv = add_sink_kv
172 |         self.is_sink_v_trainable = is_sink_v_trainable
173 |         self.is_sink_k_trainable = is_sink_k_trainable
174 |         self.mul_factor_for_sink_attn = mul_factor_for_sink_attn
175 |         self.add_sink_tokens = add_sink_tokens
176 |         self.num_sink_tokens = num_sink_tokens
177 |         self.depth_dropout_p = depth_dropout_p
178 |         self.is_causal = is_causal
179 |         super().__init__(
180 |             pad_token_id=pad_token_id,
181 |             bos_token_id=bos_token_id,
182 |             eos_token_id=eos_token_id,
183 |             tie_word_embeddings=tie_word_embeddings,
184 |             **kwargs,
185 |         )
186 | 
187 |     def _rope_scaling_validation(self):
188 |         """
189 |         Validate the `rope_scaling` configuration.
190 |         """
191 |         if self.rope_scaling is None:
192 |             return
193 | 
194 |         if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
195 |             raise ValueError(
196 |                 "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
197 |                 f"got {self.rope_scaling}"
198 |             )
199 |         rope_scaling_type = self.rope_scaling.get("type", None)
200 |         rope_scaling_factor = self.rope_scaling.get("factor", None)
201 |         if rope_scaling_type is None or rope_scaling_type not in [
202 |             "linear",
203 |             "dynamic",
204 |         ]:
205 |             raise ValueError(
206 |                 f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
207 |             )
208 |         if (
209 |             rope_scaling_factor is None
210 |             or not isinstance(rope_scaling_factor, float)
211 |             or rope_scaling_factor <= 1.0
212 |         ):
213 |             raise ValueError(
214 |                 f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}"
215 |             )
216 | 
-------------------------------------------------------------------------------- /relic/trainer/datasets.py: --------------------------------------------------------------------------------
1 | from collections import deque
2 | from itertools import groupby
3 | import json
4 | import random
5 | from typing import (
6 |     Any,
7 |     Iterator,
8 |     List,
9 |     Optional,
10 |     Sequence,
11 |     Union,
12 | )
13 | 
14 | import numpy as np
15 | from habitat.core.dataset import Episode, EpisodeIterator, T
16 | from habitat.core.registry import registry
17 | from habitat.datasets.rearrange.rearrange_dataset import (
18 |     RearrangeDatasetV0,
19 |     RearrangeEpisode,
20 | )
21 | from habitat.datasets.object_nav.object_nav_dataset import ObjectNavDatasetV1
22 | from numpy import ndarray
23 | 
24 | 
25 | class EpisodeIteratorRepeat(EpisodeIterator):
26 |     def __init__(self, *args, **kwargs):
27 |         super().__init__(*args, **kwargs)
28 |         self._next_episode = next(self._iterator)
29 |         if len(self.episodes) < 10:
30 |             print([x.episode_id for x in self.episodes])
31 | 
32 |     def after_update(self):
33 |         self._forced_scene_switch_if()
34 | 
35 |         self._next_episode = next(self._iterator, None)
36 |         if self._next_episode is None:
37 |             if not self.cycle:
38 |                 raise StopIteration
39 | 
40 |             self._iterator = iter(self.episodes)
41 | 
42 |             if self.shuffle:
43 |                 self._shuffle()
44 | 
45 |             self._next_episode = next(self._iterator)
46 | 
47 |         if (
48 |             self._prev_scene_id != self._next_episode.scene_id
49 |             and self._prev_scene_id is not None
50 |         ):
51 |             self._rep_count = 0
52 |             self._step_count = 0
53 | 
54 |         self._prev_scene_id = self._next_episode.scene_id
55 | 
56 |     def __next__(self) -> Episode:
57 |         return self._next_episode
58 | 
59 | 
60 | # ==========================================================================================
61 | 
62 | 
63 | class ObjNavEpisodeIterator(Iterator[T]):
64 |     r"""Episode Iterator class that gives options for how a list of episodes
65 |     should be iterated.
66 | 
67 |     Some of these options help the internal simulator achieve higher
68 |     performance. More context: the simulator incurs overhead when switching
69 |     between scenes, so episodes of the same scene should be loaded
70 |     consecutively. However, if too many consecutive episodes from the same
71 |     scene are fed into the RL model, the model risks overfitting to that
72 |     scene. It is therefore better to load episodes of the same scene
73 |     consecutively and switch scenes once a number threshold is reached.
74 | 
75 |     Currently supports the following features:
76 | 
77 |     Cycling:
78 |         when all episodes are iterated, cycle back to start instead of throwing
79 |         StopIteration.
80 |     Cycling with shuffle:
81 |         when cycling back, shuffle episode groups grouped by scene.
82 |     Group by scene:
83 |         episodes of the same scene will be grouped and loaded consecutively.
84 |     Set max scene repeat:
85 |         set a number threshold on how many episodes from the same scene can be
86 |         loaded consecutively.
87 |     Sample episodes:
88 |         sample the specified number of episodes.
89 |     """
90 | 
91 |     def __init__(
92 |         self,
93 |         episodes: Sequence[T],
94 |         cycle: bool = True,
95 |         shuffle: bool = False,
96 |         group_by_scene: bool = True,
97 |         max_scene_repeat_episodes: int = -1,
98 |         max_scene_repeat_steps: int = -1,
99 |         num_episode_sample: int = -1,
100 |         step_repetition_range: float = 0.2,
101 |         seed: Optional[int] = None,
102 |     ) -> None:
103 |         r"""..
104 | 
105 |         :param episodes: list of episodes.
106 |         :param cycle: if :py:`True`, cycle back to first episodes when
107 |             StopIteration.
108 |         :param shuffle: if :py:`True`, shuffle scene groups when cycle. No
109 |             effect if cycle is set to :py:`False`. Will shuffle grouped scenes
110 |             if :py:`group_by_scene` is :py:`True`.
111 |         :param group_by_scene: if :py:`True`, group episodes from same scene.
112 |         :param max_scene_repeat_episodes: threshold of how many episodes from the same
113 |             scene can be loaded consecutively. :py:`-1` for no limit
114 |         :param max_scene_repeat_steps: threshold of how many steps from the same
115 |             scene can be taken consecutively. :py:`-1` for no limit
116 |         :param num_episode_sample: number of episodes to be sampled. :py:`-1`
117 |             for no sampling.
118 |         :param step_repetition_range: The maximum number of steps within each scene is
119 |             uniformly drawn from
120 |             [1 - step_repetition_range, 1 + step_repetition_range] * max_scene_repeat_steps
121 |             on each scene switch.
This stops all workers from swapping scenes at 122 | the same time 123 | """ 124 | assert group_by_scene 125 | if seed: 126 | random.seed(seed) 127 | np.random.seed(seed) 128 | 129 | # sample episodes 130 | if num_episode_sample >= 0: 131 | episodes = np.random.choice( # type: ignore[assignment] 132 | episodes, num_episode_sample, replace=False # type: ignore[arg-type] 133 | ) 134 | 135 | if not isinstance(episodes, list): 136 | episodes = list(episodes) 137 | 138 | print(f"Scenes: {set(x.scene_id for x in episodes)}") 139 | 140 | self.episodes = episodes 141 | self.cycle = cycle 142 | self.group_by_scene = group_by_scene 143 | self.shuffle = shuffle 144 | 145 | if shuffle: 146 | random.shuffle(self.episodes) 147 | 148 | self.episodes = self._group_scenes(self.episodes) 149 | 150 | self.max_scene_repetition_episodes = max_scene_repeat_episodes 151 | self.max_scene_repetition_steps = max_scene_repeat_steps 152 | 153 | self._rep_count = -1 # 0 corresponds to first episode already returned 154 | self._step_count = 0 155 | self._prev_scene_id: Optional[str] = None 156 | self._iterator = iter(self.episodes[0]) 157 | 158 | self.step_repetition_range = step_repetition_range 159 | self._set_shuffle_intervals() 160 | self.should_switch_scene = False 161 | self.switches_count = 1 162 | 163 | def __iter__(self) -> "EpisodeIterator": 164 | return self 165 | 166 | def __next__(self) -> Episode: 167 | r"""The main logic for handling how episodes will be iterated. 168 | 169 | :return: next episode. 170 | """ 171 | # self._forced_scene_switch_if() 172 | 173 | next_episode = next(self._iterator, None) 174 | if next_episode is None: 175 | self.should_switch_scene = True 176 | if not self.cycle: 177 | raise StopIteration 178 | 179 | self._iterator = iter(self.episodes[0]) 180 | 181 | if self.shuffle: 182 | self._shuffle() 183 | 184 | next_episode = next(self._iterator) 185 | 186 | if ( 187 | self._prev_scene_id != next_episode.scene_id 188 | and self._prev_scene_id is not None 189 | ): 190 | self._rep_count = 0 191 | self._step_count = 0 192 | 193 | self._prev_scene_id = next_episode.scene_id 194 | return next_episode 195 | 196 | def _forced_scene_switch(self) -> None: 197 | r"""Internal method to switch the scene. Moves remaining episodes 198 | from current scene to the end and switch to next scene episodes. 199 | """ 200 | 201 | self.episodes.rotate(-1) 202 | self._iterator = iter(self.episodes[0]) 203 | 204 | def _shuffle(self) -> None: 205 | r"""Internal method that shuffles the remaining episodes. 206 | If self.group_by_scene is true, then shuffle groups of scenes. 207 | """ 208 | assert self.shuffle 209 | # random.shuffle(episodes) 210 | for e in self.episodes: 211 | random.shuffle(e) 212 | 213 | def _group_scenes( 214 | self, episodes: Union[Sequence[Episode], List[Episode], ndarray] 215 | ) -> List[T]: 216 | r"""Internal method that groups episodes by scene 217 | Groups will be ordered by the order the first episode of a given 218 | scene is in the list of episodes 219 | 220 | So if the episodes list shuffled before calling this method, 221 | the scenes will be in a random order 222 | """ 223 | assert self.group_by_scene 224 | episodes = sorted(episodes, key=lambda x: (x.scene_id, x.object_category)) 225 | print(f"There are {len(set(x.scene_id for x in episodes))} scenes.") 226 | print( 227 | f"There are {len(set((x.scene_id, x.object_category) for x in episodes))} scenes and targets." 
228 | ) 229 | groups = deque( 230 | [ 231 | list(x[1]) 232 | for x in groupby(episodes, lambda x: (x.scene_id, x.object_category)) 233 | ] 234 | ) 235 | print(f"Found {len(groups)} groups.") 236 | return groups 237 | 238 | def step_taken(self) -> None: 239 | self._step_count += 1 240 | 241 | @staticmethod 242 | def _randomize_value(value: int, value_range: float) -> int: 243 | return random.randint( 244 | int(value * (1 - value_range)), int(value * (1 + value_range)) 245 | ) 246 | 247 | def _set_shuffle_intervals(self) -> None: 248 | if self.max_scene_repetition_episodes > 0: 249 | self._max_rep_episode = self.max_scene_repetition_episodes 250 | else: 251 | self._max_rep_episode = None 252 | 253 | if self.max_scene_repetition_steps > 0: 254 | self._max_rep_step = self._randomize_value( 255 | self.max_scene_repetition_steps, self.step_repetition_range 256 | ) 257 | else: 258 | self._max_rep_step = None 259 | 260 | def _forced_scene_switch_if(self) -> None: 261 | do_switch = False 262 | self._rep_count += 1 263 | 264 | # Shuffle if a scene has been selected more than _max_rep_episode times in a row 265 | if ( 266 | self._max_rep_episode is not None 267 | and self._rep_count >= self._max_rep_episode 268 | ): 269 | do_switch = True 270 | 271 | # Shuffle if a scene has been used for more than _max_rep_step steps in a row 272 | if self._max_rep_step is not None and self._step_count >= self._max_rep_step: 273 | do_switch = True 274 | 275 | if do_switch or self.should_switch_scene: 276 | self._forced_scene_switch() 277 | self._set_shuffle_intervals() 278 | self.should_switch_scene = False 279 | self.switches_count += 1 280 | 281 | def after_update(self): 282 | self._forced_scene_switch_if() 283 | 284 | 285 | # ================================================= 286 | 287 | 288 | @registry.register_dataset(name="RearrangeDatasetTransformers-v0") 289 | class RearrangeDatasetTransformersV0(RearrangeDatasetV0): 290 | def get_episode_iterator(self, *args: Any, **kwargs: Any) -> Iterator[T]: 291 | return EpisodeIteratorRepeat(self.episodes, *args, **kwargs) 292 | 293 | def from_json( 294 | self, json_str: str, scenes_dir: Optional[str] = None, reset_episode_ids=True 295 | ) -> None: 296 | deserialized = json.loads(json_str) 297 | 298 | for i, episode in enumerate(deserialized["episodes"]): 299 | rearrangement_episode = RearrangeEpisode(**episode) 300 | if reset_episode_ids: 301 | rearrangement_episode.episode_id = str(i) 302 | 303 | self.episodes.append(rearrangement_episode) 304 | 305 | 306 | @registry.register_dataset(name="ObjectNavTransformers-v1") 307 | class ObjectNavTransformersV1(ObjectNavDatasetV1): 308 | def get_episode_iterator(self, *args: Any, **kwargs: Any) -> Iterator[T]: 309 | return ObjNavEpisodeIterator(self.episodes, *args, **kwargs) 310 | -------------------------------------------------------------------------------- /relic/policies/transformer_wrappers.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import transformers 8 | from habitat import logger 9 | from habitat_baselines.rl.ddppo.policy import resnet 10 | from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder 11 | 12 | from relic.policies.llamarl.configuration_llamarl import ( 13 | LlamaRLConfig, 14 | ) 15 | from relic.policies.llamarl.modeling_llamarl import LlamaRLModel 16 | 17 | # from transformers import TransfoXLModel, TransfoXLConfig 18 | from 
transformers import TransfoXLConfig 19 | from relic.policies.transformerxl.modeling_transformerxl import TransfoXLModel 20 | 21 | 22 | class TransformerWrapper(nn.Module): 23 | def __init__(self, input_size: int, config): 24 | super().__init__() 25 | self.model_name = config.model_name 26 | self.inter_episodes_attention = config.inter_episodes_attention 27 | self.reset_position_index = config.reset_position_index 28 | self.add_sequence_idx_embed = config.add_sequence_idx_embed 29 | self.n_layers = config.n_layers 30 | self.n_embed = config.n_hidden 31 | self.n_mlp_hidden = config.n_mlp_hidden 32 | self.n_head = config.n_heads 33 | self.activation = config.activation 34 | self.position_embed_type = config.position_embed_type 35 | self.sequence_embed_type = config.sequence_embed_type 36 | self.depth_dropout_p = config.depth_dropout_p 37 | self.gated_residual = config.gated_residual 38 | self.add_sink_kv = config.add_sink_kv 39 | self.mul_factor_for_sink_attn = config.mul_factor_for_sink_attn 40 | self.is_sink_v_trainable = config.is_sink_v_trainable 41 | self.is_sink_k_trainable = config.get("is_sink_k_trainable", True) 42 | self.add_sink_tokens = config.add_sink_tokens 43 | self.num_sink_tokens = config.num_sink_tokens 44 | 45 | self.context_len = config.context_len 46 | self.mem_len = config.get("mem_len", -1) 47 | self.banded_attention = config.banded_attention 48 | self.orphan_steps_attention = config.orphan_steps_attention 49 | self.add_context_loss = config.add_context_loss 50 | self.max_position_embeddings = config.max_position_embeddings 51 | self.feats_proj = nn.Linear(input_size, self.n_embed) 52 | self.feats_out = nn.Linear(self.n_embed, self.n_embed) 53 | 54 | if self.model_name == "gpt": 55 | self.hf_config = transformers.GPT2Config( 56 | vocab_size=0, 57 | n_embd=self.n_embed, 58 | n_layer=self.n_layers, 59 | n_head=self.n_head, 60 | ) 61 | 62 | self.model = transformers.GPT2Model(self.hf_config) 63 | self.model.wte.weight.requires_grad_(False) 64 | elif self.model_name == "llamarl": 65 | self.hf_config = LlamaRLConfig( 66 | hidden_size=self.n_embed, 67 | intermediate_size=self.n_mlp_hidden, 68 | num_hidden_layers=self.n_layers, 69 | num_attention_heads=self.n_head, 70 | hidden_act=self.activation, 71 | inter_episodes_attention=self.inter_episodes_attention, 72 | reset_position_index=self.reset_position_index, 73 | add_sequence_idx_embed=self.add_sequence_idx_embed, 74 | position_embed_type=self.position_embed_type, 75 | gated_residual=self.gated_residual, 76 | context_len=self.context_len, 77 | banded_attention=self.banded_attention, 78 | orphan_steps_attention=self.orphan_steps_attention, 79 | depth_dropout_p=self.depth_dropout_p, 80 | max_position_embeddings=self.max_position_embeddings, 81 | add_sink_kv=self.add_sink_kv, 82 | mul_factor_for_sink_attn=self.mul_factor_for_sink_attn, 83 | is_sink_v_trainable=self.is_sink_v_trainable, 84 | is_sink_k_trainable=self.is_sink_k_trainable, 85 | add_sink_tokens=self.add_sink_tokens, 86 | num_sink_tokens=self.num_sink_tokens, 87 | sequence_embed_type=self.sequence_embed_type, 88 | ) 89 | 90 | self.model = LlamaRLModel(self.hf_config) 91 | elif self.model_name == "transformerxl": 92 | self.hf_config = TransfoXLConfig( 93 | d_model=self.n_embed, 94 | d_embed=self.n_embed, 95 | n_head=self.n_head, 96 | d_inner=self.n_mlp_hidden, 97 | pre_lnorm=True, 98 | n_layer=self.n_layers, 99 | mem_len=self.mem_len, 100 | dropout=self.depth_dropout_p, 101 | ) 102 | self.model = TransfoXLModel(self.hf_config) 103 | else: 104 | raise 
ValueError(f"Unrecognized {self.model_name}") 105 | 106 | logger.info(f"Done loading llama") 107 | 108 | def postprocess_past_key_value(self, past_key_values, full_rnn_state=False): 109 | # past_key_values.shape -> [nL, 2(k and v), bs(nE), nH, nS, nE/nH] 110 | past_key_values = torch.stack([torch.stack(x) for x in past_key_values]) 111 | if not full_rnn_state: 112 | return past_key_values.permute(2, 0, 1, 3, 4, 5)[..., -1, :].flatten(2, 4) 113 | else: 114 | return past_key_values.permute(4, 2, 0, 1, 3, 5).flatten(3, 5) 115 | 116 | def stack_past_key_values(self, past_key_values, last_step=False): 117 | if past_key_values is None: 118 | return None 119 | if self.model_name == "transformerxl": 120 | return torch.stack(past_key_values, dim=0) 121 | if last_step: 122 | past_key_values = torch.stack( 123 | [torch.stack([y[..., -1, :] for y in x]) for x in past_key_values] 124 | ) 125 | else: 126 | past_key_values = torch.stack([torch.stack(x) for x in past_key_values]) 127 | return past_key_values 128 | 129 | def preprocess_past_key_value(self, past_key_values): 130 | # past_key_values.shape -> [nS, bs, nL, 2*nH*nE/nH] 131 | bs, nS, nL, _ = past_key_values.shape 132 | nH = self.n_head 133 | nE = self.n_embed 134 | return past_key_values.reshape(bs, nS, nL, 2, nH, nE // nH).permute( 135 | 2, 3, 0, 4, 1, 5 136 | ) 137 | 138 | def forward( 139 | self, 140 | feats, 141 | rnn_hidden_states, 142 | masks, 143 | rnn_build_seq_info, 144 | full_rnn_state=False, 145 | **kwargs, 146 | ): 147 | if rnn_build_seq_info is None: 148 | past_key_values = ( 149 | rnn_hidden_states if np.prod(rnn_hidden_states.shape) > 0 else None 150 | ) 151 | n_envs = rnn_hidden_states.shape[2] 152 | seq_len = 1 153 | masks = masks.squeeze(-1).float() 154 | stop_grad_steps = 0 155 | use_cache = True 156 | else: 157 | n_envs, seq_len = rnn_build_seq_info["dims"] 158 | if self.model_name == "transformerxl": 159 | past_key_values = rnn_hidden_states 160 | past_key_values = ( 161 | past_key_values.squeeze(-1) 162 | .unflatten(0, (self.num_recurrent_layers, self.memory_size)) 163 | .float() 164 | ) 165 | use_cache = False 166 | else: 167 | past_key_values = None 168 | use_cache = full_rnn_state 169 | masks = masks.squeeze(-1).unflatten(0, (n_envs, seq_len)).float() 170 | if "stop_grad_steps" in rnn_build_seq_info: 171 | stop_grad_steps = rnn_build_seq_info["stop_grad_steps"] 172 | else: 173 | stop_grad_steps = 0 174 | 175 | feats = feats.unflatten(0, (n_envs, seq_len)) 176 | 177 | if rnn_build_seq_info is not None: 178 | old_context_length = rnn_build_seq_info["old_context_length"] 179 | else: 180 | old_context_length = 0 181 | 182 | feats = torch.concat( 183 | [ 184 | feats[:, :old_context_length].detach(), 185 | feats[:, old_context_length:], 186 | ], 187 | dim=1, 188 | ) 189 | 190 | if ( 191 | rnn_build_seq_info is not None 192 | and not rnn_build_seq_info["is_first"] 193 | and not self.add_context_loss 194 | ): 195 | feats = torch.concat( 196 | [ 197 | feats[:, : rnn_build_seq_info["old_context_length"]].detach(), 198 | feats[:, rnn_build_seq_info["old_context_length"] :], 199 | ], 200 | dim=1, 201 | ) 202 | 203 | if stop_grad_steps: 204 | feats_ = feats[:, :stop_grad_steps].detach() 205 | masks_ = masks[:, :stop_grad_steps].detach() 206 | feats = feats[:, stop_grad_steps:] 207 | 208 | # TODO check why torch.no_grad doesn't work. 
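            # These first `stop_grad_steps` timesteps form a frozen context
            # pass: the inputs and masks were detached above, and the hidden
            # states and KV cache produced here are detached again below, so
            # no gradient flows through them. The detached cache then seeds
            # `past_key_values` for the gradient-tracked main pass that follows.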
209 |             feats_ = self.feats_proj(feats_)
210 |             output_ = self.model(
211 |                 inputs_embeds=feats_,
212 |                 past_key_values=None,
213 |                 attention_mask=masks_,
214 |             )
215 |             feats_ = output_.last_hidden_state
216 |             feats_ = self.feats_out(feats_)
217 | 
218 |             past_key_values = output_.past_key_values
219 | 
220 |             feats_ = feats_.detach()
221 |             past_key_values = self.stack_past_key_values(past_key_values).detach()
222 | 
223 |         feats = self.feats_proj(feats)
224 |         if self.model_name == "transformerxl":
225 |             output = self.model(
226 |                 inputs_embeds=feats,
227 |                 mems=past_key_values,
228 |                 # attention_mask=masks,
229 |                 # use_cache=use_cache,
230 |                 # **kwargs
231 |             )
232 |             output.past_key_values = output.mems
233 |         else:
234 |             output = self.model(
235 |                 inputs_embeds=feats,
236 |                 past_key_values=past_key_values,
237 |                 attention_mask=masks,
238 |                 use_cache=use_cache,
239 |                 **kwargs,
240 |             )
241 | 
242 |         feats = output.last_hidden_state
243 |         feats = self.feats_out(feats)
244 | 
245 |         if (
246 |             rnn_build_seq_info is not None
247 |             and not rnn_build_seq_info["is_first"]
248 |             and not self.add_context_loss
249 |         ):
250 |             feats = feats[:, rnn_build_seq_info["old_context_length"] :]
251 | 
252 |         if stop_grad_steps:
253 |             feats = torch.concat([feats_, feats], dim=1)
254 |         feats = feats.flatten(0, 1)
255 |         if kwargs:
256 |             return (
257 |                 feats,
258 |                 self.stack_past_key_values(
259 |                     output.past_key_values,
260 |                     last_step=self.model_name != "transformerxl" and not full_rnn_state,
261 |                 ),
262 |                 output,
263 |             )
264 |         else:
265 |             return feats, self.stack_past_key_values(
266 |                 output.past_key_values,
267 |                 last_step=self.model_name != "transformerxl" and not full_rnn_state,
268 |             )
269 | 
270 |     def get_trainable_params(self):
271 |         return chain(
272 |             self.feats_proj.parameters(),
273 |             self.model.parameters(),
274 |             self.feats_out.parameters(),
275 |         )
276 | 
277 |     @property
278 |     def num_recurrent_layers(self):
279 |         return self.n_layers
280 | 
281 |     @property
282 |     def recurrent_hidden_size(self):
283 |         return self.n_embed
284 | 
285 |     @property
286 |     def memory_size(self):
287 |         return self.mem_len
288 | 
289 |     def gradient_checkpointing_enable(self):
290 |         return self.model.gradient_checkpointing_enable()
291 | 
-------------------------------------------------------------------------------- /relic/tasks/sensors.py: --------------------------------------------------------------------------------
1 | import gym.spaces as spaces
2 | import numpy as np
3 | from habitat.core.registry import registry
4 | from habitat.core.simulator import Sensor, SensorTypes
5 | 
6 | 
7 | @registry.register_sensor
8 | class RelativeLocalizationSensor(Sensor):
9 |     cls_uuid = "rel_localization_sensor"
10 | 
11 |     def __init__(self, sim, config, *args, **kwargs):
12 |         super().__init__(config=config)
13 |         self._sim = sim
14 | 
15 |     def _get_uuid(self, *args, **kwargs):
16 |         return RelativeLocalizationSensor.cls_uuid
17 | 
18 |     def _get_sensor_type(self, *args, **kwargs):
19 |         return SensorTypes.TENSOR
20 | 
21 |     def _get_observation_space(self, *args, **kwargs):
22 |         return spaces.Box(
23 |             shape=(3,),
24 |             low=np.finfo(np.float32).min,
25 |             high=np.finfo(np.float32).max,
26 |             dtype=np.float32,
27 |         )
28 | 
29 |     def get_observation(self, *args, task, **kwargs):
30 |         agent = self._sim.get_agent_data(self.agent_id).articulated_agent
31 | 
32 |         rel_base_pos = agent.base_pos - task.agent_start.translation
33 |         return np.array(rel_base_pos, dtype=np.float32)
34 | 
35 | 
36 | @registry.register_sensor
37 | class OneHotTargetSensor(Sensor):
38 |     def __init__(self, *args, task, **kwargs):
39 |         self._task = task
40 |         # TODO: Hard-coded for the ycb objects. Change to work with any object
41 |         # set.
42 |         if self._task._config.get("is_large_objs", False):
43 |             self._all_objs = ["bed", "couch", "chair", "tv", "plant", "toilet"]
44 |         else:
45 |             self._all_objs = [
46 |                 "002_master_chef_can",  # 0
47 |                 "003_cracker_box",  # 1
48 |                 "004_sugar_box",  # 2
49 |                 "005_tomato_soup_can",  # 3
50 |                 "007_tuna_fish_can",  # 4
51 |                 "008_pudding_box",  # 5
52 |                 "009_gelatin_box",  # 6
53 |                 "010_potted_meat_can",  # 7
54 |                 "011_banana",  # 8
55 |                 "012_strawberry",  # 9
56 |                 "013_apple",  # 10
57 |                 "014_lemon",  # 11
58 |                 "015_peach",  # 12
59 |                 "016_pear",  # 13
60 |                 "017_orange",  # 14
61 |                 "018_plum",  # 15
62 |                 "021_bleach_cleanser",  # 16
63 |                 "024_bowl",  # 17
64 |                 "025_mug",  # 18
65 |                 "026_sponge",  # 19
66 |             ]
67 |         self._n_cls = len(self._all_objs)
68 | 
69 |         super().__init__(*args, **kwargs)
70 | 
71 |     def _get_uuid(self, *args, **kwargs):
72 |         return "one_hot_target_sensor"
73 | 
74 |     def _get_sensor_type(self, *args, **kwargs):
75 |         return SensorTypes.TENSOR
76 | 
77 |     def _get_observation_space(self, *args, config, **kwargs):
78 |         return spaces.Box(shape=(self._n_cls,), low=0, high=1, dtype=np.float32)
79 | 
80 |     def get_observation(self, *args, **kwargs):
81 |         cur_target = self._task.get_sampled()[0]
82 | 
83 |         # For receptacles the name will not be a class but the name directly.
84 |         use_name = cur_target.expr_type.name
85 |         if cur_target.name in self._all_objs:
86 |             use_name = cur_target.name
87 | 
88 |         if use_name not in self._all_objs:
89 |             raise ValueError(
90 |                 f"Object not found given {use_name}, {cur_target}, {self._task.get_sampled()}"
91 |             )
92 |         set_i = self._all_objs.index(use_name)
93 | 
94 |         obs = np.zeros((self._n_cls,))
97 |         obs[set_i] = 1.0
98 | 
99 |         return obs
100 | 
101 | 
102 | @registry.register_sensor
103 | class OneHotReceptacleSensor(Sensor):
104 |     def __init__(self, *args, task, **kwargs):
105 |         self._task = task
106 |         # TODO: Hard-coded for the ycb objects. Change to work with any object
107 |         # set.
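        # Illustrative example (not in the original source): with the
        # vocabulary below, rec_name == "table" maps to index 3, so
        # get_observation() returns a length-20 one-hot vector whose only
        # nonzero entry is obs[3] == 1.0.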
108 |         self._all_objs = [
109 |             None,
110 |             "couch",
111 |             "unknown",
112 |             "table",
113 |             "cabinet",
114 |             "chair",
115 |             "bed",
116 |             "chest_of_drawers",
117 |             "shelves",
118 |             "stool",
119 |             "toilet",
120 |             "washer_dryer",
121 |             "bench",
122 |             "stand",
123 |             "counter",
124 |             "bathtub",
125 |             "car",
126 |             "wardrobe",
127 |             "sink",
128 |             "shower",
129 |         ]
130 |         self._n_cls = len(self._all_objs)
131 | 
132 |         super().__init__(*args, **kwargs)
133 | 
134 |     def _get_uuid(self, *args, **kwargs):
135 |         return "one_hot_receptacle_sensor"
136 | 
137 |     def _get_sensor_type(self, *args, **kwargs):
138 |         return SensorTypes.TENSOR
139 | 
140 |     def _get_observation_space(self, *args, config, **kwargs):
141 |         return spaces.Box(shape=(self._n_cls,), low=0, high=1, dtype=np.float32)
142 | 
143 |     def get_observation(self, *args, **kwargs):
144 |         rec_name = self._task._receptacle_name
145 | 
146 |         if rec_name not in self._all_objs:
147 |             raise ValueError(f"Receptacle not found: {rec_name}.")
148 |         set_i = self._all_objs.index(rec_name)
149 | 
150 |         obs = np.zeros((self._n_cls,))
153 |         obs[set_i] = 1.0
154 | 
155 |         return obs
156 | 
157 | 
158 | from habitat.core.simulator import (
159 |     SemanticSensor,
160 |     Sensor,
161 |     VisualObservation,
162 | )
163 | import habitat_sim
164 | from omegaconf import DictConfig
165 | from typing import (
166 |     TYPE_CHECKING,
167 |     Any,
168 |     Callable,
169 |     Dict,
170 |     List,
171 |     Optional,
172 |     Sequence,
173 |     Set,
174 |     Union,
175 |     cast,
176 | )
177 | 
178 | if TYPE_CHECKING:
179 |     from torch import Tensor
180 | 
181 | from habitat.sims.habitat_simulator.habitat_simulator import (
182 |     HabitatSimSensor,
183 |     check_sim_obs,
184 |     HabitatSimSemanticSensor,
185 | )
186 | 
187 | 
188 | @registry.register_sensor
189 | class MaskedSemanticSensor(Sensor):
190 |     def _get_sensor_type(self, *args, **kwargs):
191 |         return SensorTypes.TENSOR
192 | 
193 |     def _get_uuid(self, *args, **kwargs):
194 |         return "masked_semantic_sensor"
195 | 
196 |     def __init__(self, *args, task, **kwargs) -> None:
197 |         self._task = task
198 |         # TODO: Hard-coded for the ycb objects. Change to work with any object
199 |         # set.
200 | self._all_objs = [ 201 | "002_master_chef_can", 202 | "003_cracker_box", 203 | "004_sugar_box", 204 | "005_tomato_soup_can", 205 | "007_tuna_fish_can", 206 | "008_pudding_box", 207 | "009_gelatin_box", 208 | "010_potted_meat_can", 209 | "011_banana", 210 | "012_strawberry", 211 | "013_apple", 212 | "014_lemon", 213 | "015_peach", 214 | "016_pear", 215 | "017_orange", 216 | "018_plum", 217 | "021_bleach_cleanser", 218 | "024_bowl", 219 | "025_mug", 220 | "026_sponge", 221 | ] 222 | self._n_cls = len(self._all_objs) 223 | 224 | super().__init__(*args, **kwargs) 225 | 226 | def _get_observation_space(self, *args: Any, **kwargs: Any): 227 | return spaces.Box( 228 | low=0, 229 | high=1, 230 | shape=(self.config.height, self.config.width, 1), 231 | dtype=np.bool_, 232 | ) 233 | 234 | def get_observation(self, *args, observations, **kwargs) -> VisualObservation: 235 | obs = observations["head_semantic"] 236 | ids = [] 237 | for obj in self._task.new_entities.values(): 238 | ids.append(self._task.object_handle2id[obj.name]) 239 | 240 | obs_mask = np.zeros_like(obs, dtype="bool") 241 | for id_ in ids: 242 | obs_mask |= obs == id_ + self._task._sim.habitat_config.object_ids_start 243 | 244 | return obs_mask 245 | 246 | 247 | @registry.register_sensor 248 | class MaskedFlattenedSemanticSensor(Sensor): 249 | def _get_sensor_type(self, *args, **kwargs): 250 | return SensorTypes.TENSOR 251 | 252 | def _get_uuid(self, *args, **kwargs): 253 | return "masked_flattened_semantic_sensor" 254 | 255 | def __init__(self, *args, task, **kwargs) -> None: 256 | self._task = task 257 | # TODO: Hard-coded for the ycb objects. Change to work with any object 258 | # set. 259 | self._all_objs = [ 260 | "002_master_chef_can", # 0 261 | "003_cracker_box", # 1 262 | "004_sugar_box", # 2 263 | "005_tomato_soup_can", # 3 264 | "007_tuna_fish_can", # 4 265 | "008_pudding_box", # 5 266 | "009_gelatin_box", # 6 267 | "010_potted_meat_can", # 7 268 | "011_banana", # 8 269 | "012_strawberry", # 9 270 | "013_apple", # 10 271 | "014_lemon", # 11 272 | "015_peach", # 12 273 | "016_pear", # 13 274 | "017_orange", # 14 275 | "018_plum", # 15 276 | "021_bleach_cleanser", # 16 277 | "024_bowl", # 17 278 | "025_mug", # 18 279 | "026_sponge", # 19 280 | ] 281 | self._n_cls = len(self._all_objs) 282 | 283 | super().__init__(*args, **kwargs) 284 | 285 | def _get_observation_space(self, *args: Any, **kwargs: Any): 286 | return spaces.Box( 287 | low=0, 288 | high=1, 289 | shape=(self.config.n_cells * self.config.n_cells, 1), 290 | dtype=np.float32, 291 | ) 292 | 293 | def get_observation(self, *args, observations, **kwargs) -> VisualObservation: 294 | obs = observations["head_semantic"] 295 | ids = [] 296 | for obj in self._task.new_entities.values(): 297 | ids.append(self._task.object_handle2id[obj.name]) 298 | 299 | obs_mask = np.zeros_like(obs, dtype="bool") 300 | for id_ in ids: 301 | obs_mask |= obs == id_ + self._task._sim.habitat_config.object_ids_start 302 | 303 | obs_mask = obs_mask.astype("float32") 304 | 305 | STEP = obs.shape[0] // self.config.n_cells 306 | feats = np.zeros((self.config.n_cells * self.config.n_cells, 1)) 307 | for i in range(self.config.n_cells): 308 | for j in range(self.config.n_cells): 309 | feats[i * self.config.n_cells + j] = obs_mask[ 310 | i * STEP : (i + 1) * STEP, j * STEP : (j + 1) * STEP 311 | ].mean() 312 | 313 | return feats 314 | 315 | 316 | from habitat.tasks.nav.nav import PointGoalSensor 317 | 318 | 319 | @registry.register_sensor(name="PointGoalWithGPSCompassSensorV3") 320 | class 
IntegratedPointGoalGPSAndCompassSensorV3(PointGoalSensor):
321 |     r"""Sensor that integrates PointGoals observations (which are used in PointGoal Navigation) and GPS+Compass.
322 | 
323 |     For the agent in simulator the forward direction is along negative-z.
324 |     In polar coordinate format the angle returned is azimuth to the goal.
325 | 
326 |     Args:
327 |         sim: reference to the simulator for calculating task observations.
328 |         config: config for the PointGoal sensor. Can contain field for
329 |             `goal_format` which can be used to specify the format in which
330 |             the pointgoal is specified. Current options for goal format are
331 |             cartesian and polar.
332 | 
333 |             Also contains a `dimensionality` field which specifies the number
334 |             of dimensions used to specify the goal, must be in [2, 3]
335 | 
336 |     Attributes:
337 |         _goal_format: format for specifying the goal which can be done
338 |             in cartesian or polar coordinates.
339 |         _dimensionality: number of dimensions used to specify the goal
340 |     """
341 |     cls_uuid: str = "pointgoal_with_gps_compass_v3"
342 | 
343 |     def __init__(self, *args: Any, task, **kwargs: Any) -> None:
344 |         self._task = task
345 |         super().__init__(*args, **kwargs)
346 |         self.counter = 0
347 |         self.last_goal = None
348 |         self.noise = 0
349 | 
350 |     def _get_uuid(self, *args: Any, **kwargs: Any) -> str:
351 |         return self.cls_uuid
352 | 
353 |     def get_observation(self, observations, episode, *args: Any, **kwargs: Any):
354 |         agent_state = self._sim.get_agent_state()
355 |         agent_position = agent_state.position
356 |         rotation_world_agent = agent_state.rotation
357 |         if (
358 |             self.last_goal is None
359 |             or (self.last_goal != self._task.random_obj_pos).any()
360 |         ):
361 |             self.last_goal = self._task.random_obj_pos
362 |             self.noise = (
363 |                 np.random.randn(len(self._task.random_obj_pos))
364 |                 * self.counter
365 |                 / 1_000_000
366 |                 * self.config.get("std_noise_1m", 0)
367 |             )
368 | 
369 |         goal_position = self._task.random_obj_pos + self.noise
370 |         self.counter += 1
371 |         return self._compute_pointgoal(
372 |             agent_position, rotation_world_agent, goal_position
373 |         )
374 | 
-------------------------------------------------------------------------------- /relic/tasks/actions.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Meta Platforms, Inc. and its affiliates.
4 | # This source code is licensed under the MIT license found in the
5 | # LICENSE file in the root directory of this source tree.
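# This module registers the discrete task actions used by the navigation
# tasks: strafe-based move/turn/stop actions that set the agent state
# directly, velocity-based discrete move/turn actions built on
# BaseVelAction, and camera look-up/look-down/zoom actions.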
6 | 
7 | # TODO, lots of typing errors in here
8 | 
9 | from copy import deepcopy
10 | from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union
11 | 
12 | import attr
13 | import numpy as np
14 | import quaternion
15 | from gym import spaces
16 | 
17 | from habitat.config import read_write
18 | from habitat.config.default import get_agent_config
19 | from habitat.core.dataset import Dataset, Episode
20 | from habitat.core.embodied_task import (
21 |     EmbodiedTask,
22 | )
23 | from habitat.tasks.rearrange.actions.articulated_agent_action import (
24 |     ArticulatedAgentAction,
25 | )
26 | from habitat.core.logging import logger
27 | from habitat.core.registry import registry
28 | from habitat.core.simulator import (
29 |     AgentState,
30 |     RGBSensor,
31 |     Sensor,
32 |     SensorTypes,
33 |     ShortestPathPoint,
34 |     Simulator,
35 | )
36 | from habitat.core.spaces import ActionSpace, EmptySpace
37 | from habitat.core.utils import not_none_validator, try_cv2_import
38 | from habitat.sims.habitat_simulator.actions import HabitatSimActions
39 | from habitat.tasks.utils import cartesian_to_polar
40 | from habitat.utils.geometry_utils import (
41 |     quaternion_from_coeff,
42 |     quaternion_rotate_vector,
43 | )
44 | from habitat.utils.visualizations import fog_of_war, maps
45 | 
46 | try:
47 |     from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim
48 |     from habitat_sim import RigidState
49 |     from habitat_sim.physics import VelocityControl
50 | except ImportError:
51 |     pass
52 | 
53 | try:
54 |     import magnum as mn
55 | except ImportError:
56 |     pass
57 | 
58 | if TYPE_CHECKING:
59 |     from omegaconf import DictConfig
60 | 
67 | from typing import Optional, cast
68 | 
73 | import habitat_sim
74 | from habitat.core.embodied_task import SimulatorTaskAction
80 | 
81 | # flake8: noqa
82 | # These actions need to be imported since there is a Python evaluation
83 | # statement which dynamically creates the desired grip controller.
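# (For example, a grip controller class named in the config, such as
# "MagicGraspAction", can only be resolved by eval() if it is present in
# this module's namespace, so the imports below must stay even though they
# appear unused.)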
84 | from habitat.tasks.rearrange.actions.grip_actions import ( 85 | GazeGraspAction, 86 | GripSimulatorTaskAction, 87 | MagicGraspAction, 88 | SuctionGraspAction, 89 | ) 90 | from habitat.tasks.rearrange.rearrange_sim import RearrangeSim 91 | from habitat.tasks.rearrange.utils import rearrange_collision, rearrange_logger 92 | 93 | from habitat.tasks.rearrange.actions.actions import BaseVelAction 94 | 95 | cv2 = try_cv2_import() 96 | 97 | 98 | def _strafe_body( 99 | sim, 100 | move_amount: float, 101 | strafe_angle_deg: float, 102 | noise_amount: float, 103 | ): 104 | # Get the state of the agent 105 | agent_state = sim.get_agent_state() 106 | # Convert from np.quaternion (quaternion.quaternion) to mn.Quaternion 107 | normalized_quaternion = agent_state.rotation 108 | agent_mn_quat = mn.Quaternion( 109 | normalized_quaternion.imag, normalized_quaternion.real 110 | ) 111 | forward = agent_mn_quat.transform_vector(-mn.Vector3.z_axis()) 112 | strafe_angle = np.random.uniform( 113 | (1 - noise_amount) * strafe_angle_deg, 114 | (1 + noise_amount) * strafe_angle_deg, 115 | ) 116 | strafe_angle = mn.Deg(strafe_angle) 117 | rotation = mn.Quaternion.rotation(strafe_angle, mn.Vector3.y_axis()) 118 | move_amount = np.random.uniform( 119 | (1 - noise_amount) * move_amount, (1 + noise_amount) * move_amount 120 | ) 121 | delta_position = rotation.transform_vector(forward) * move_amount 122 | final_position = sim.pathfinder.try_step( # type: ignore 123 | agent_state.position, agent_state.position + delta_position 124 | ) 125 | sim.set_agent_state( 126 | final_position, 127 | [*rotation.vector, rotation.scalar], 128 | reset_sensors=False, 129 | ) 130 | 131 | 132 | class NavigationMovementAgentAction(ArticulatedAgentAction): 133 | def __init__(self, *args, config, sim, **kwargs): 134 | super().__init__(*args, config=config, sim=sim, **kwargs) 135 | self._sim = sim 136 | self._tilt_angle = config.tilt_angle 137 | self.max_angle = 100 138 | self.min_angle = 20 139 | 140 | def _move_camera_vertical(self, amount: float): 141 | agent_data = self._sim.get_agent_data(0) 142 | for cam_name in agent_data.articulated_agent.params.cameras: 143 | *curr_xy, _ = agent_data.articulated_agent.params.cameras[ 144 | cam_name 145 | ].cam_look_at_pos 146 | norm = np.linalg.norm(curr_xy) 147 | 148 | angle = np.arctan2(*curr_xy) - amount / 180 * np.pi 149 | angle = min( 150 | max(angle, self.min_angle / 180 * np.pi), self.max_angle / 180 * np.pi 151 | ) 152 | x, y = np.sin(angle) * norm, np.cos(angle) * norm 153 | 154 | agent_data.articulated_agent.params.cameras[cam_name].cam_look_at_pos[0] = x 155 | agent_data.articulated_agent.params.cameras[cam_name].cam_look_at_pos[1] = y 156 | 157 | 158 | @registry.register_task_action 159 | class RearrangeMoveForwardAction(ArticulatedAgentAction): 160 | name: str = "rearrange_move_forward" 161 | 162 | def step(self, *args: Any, **kwargs: Any): 163 | r"""Update ``_metric``, this method is called from ``Env`` on each 164 | ``step``. 165 | """ 166 | # return self._sim.step(HabitatSimActions.move_forward) 167 | _strafe_body(self._sim, 0.25, 0, 0) 168 | 169 | 170 | @registry.register_task_action 171 | class RearrangeTurnLeftAction(ArticulatedAgentAction): 172 | name: str = "rearrange_turn_left" 173 | 174 | def step(self, *args: Any, **kwargs: Any): 175 | r"""Update ``_metric``, this method is called from ``Env`` on each 176 | ``step``. 
177 | """ 178 | _strafe_body(self._sim, 0.0, 30, 0) 179 | 180 | 181 | @registry.register_task_action 182 | class RearrangeTurnRightAction(ArticulatedAgentAction): 183 | name: str = "rearrange_turn_right" 184 | 185 | def step(self, *args: Any, **kwargs: Any): 186 | r"""Update ``_metric``, this method is called from ``Env`` on each 187 | ``step``. 188 | """ 189 | _strafe_body(self._sim, 0.0, -30, 0) 190 | 191 | 192 | @registry.register_task_action 193 | class RearrangeStopAction(ArticulatedAgentAction): 194 | name: str = "rearrange_stop" 195 | 196 | def reset(self, task: EmbodiedTask, *args: Any, **kwargs: Any): 197 | self.does_want_terminate = False # type: ignore 198 | 199 | def step(self, task: EmbodiedTask, *args: Any, **kwargs: Any): 200 | r"""Update ``_metric``, this method is called from ``Env`` on each 201 | ``step``. 202 | """ 203 | self.does_want_terminate = True # type: ignore 204 | 205 | 206 | @registry.register_task_action 207 | class RearrangeLookUpAction(NavigationMovementAgentAction): 208 | name: str = "rearrange_look_up" 209 | 210 | def __init__(self, *args: Any, **kwargs: Any): 211 | super().__init__(*args, **kwargs) 212 | 213 | agent_data = self._sim.get_agent_data(0) 214 | self.cams_init_pos = {} 215 | 216 | for cam_name in agent_data.articulated_agent.params.cameras: 217 | self.cams_init_pos[cam_name] = list( 218 | agent_data.articulated_agent.params.cameras[cam_name].cam_look_at_pos 219 | ) 220 | 221 | def step(self, *args: Any, **kwargs: Any): 222 | r"""Update ``_metric``, this method is called from ``Env`` on each 223 | ``step``. 224 | """ 225 | self._move_camera_vertical(self._tilt_angle) 226 | 227 | def reset(self, *args, **kwargs): 228 | super().reset(*args, **kwargs) 229 | 230 | agent_data = self._sim.get_agent_data(0) 231 | for cam_name in agent_data.articulated_agent.params.cameras: 232 | agent_data.articulated_agent.params.cameras[ 233 | cam_name 234 | ].cam_look_at_pos = self.cams_init_pos[cam_name].copy() 235 | 236 | 237 | @registry.register_task_action 238 | class RearrangeLookDownAction(NavigationMovementAgentAction): 239 | name: str = "rearrange_look_down" 240 | 241 | def __init__(self, *args: Any, **kwargs: Any): 242 | super().__init__(*args, **kwargs) 243 | 244 | agent_data = self._sim.get_agent_data(0) 245 | self.cams_init_pos = {} 246 | 247 | for cam_name in agent_data.articulated_agent.params.cameras: 248 | self.cams_init_pos[cam_name] = list( 249 | agent_data.articulated_agent.params.cameras[cam_name].cam_look_at_pos 250 | ) 251 | 252 | def step(self, *args: Any, **kwargs: Any): 253 | r"""Update ``_metric``, this method is called from ``Env`` on each 254 | ``step``. 255 | """ 256 | self._move_camera_vertical(-self._tilt_angle) 257 | 258 | def reset(self, *args, **kwargs): 259 | super().reset(*args, **kwargs) 260 | 261 | agent_data = self._sim.get_agent_data(0) 262 | for cam_name in agent_data.articulated_agent.params.cameras: 263 | agent_data.articulated_agent.params.cameras[ 264 | cam_name 265 | ].cam_look_at_pos = self.cams_init_pos[cam_name].copy() 266 | 267 | 268 | @registry.register_task_action 269 | class DiscreteMoveForward(BaseVelAction): 270 | """ 271 | The articulated agent base motion is constrained to the NavMesh and controlled with velocity commands integrated with the VelocityControl interface. 
272 | 
273 |     Optionally cull states with active collisions if config parameter `allow_dyn_slide` is True
274 |     """
275 | 
276 |     def step(self, *args, **kwargs):
277 |         lin_vel = 0.25
278 |         ang_vel = 0
279 | 
280 |         if not self._allow_back:
281 |             lin_vel = np.maximum(lin_vel, 0)
282 | 
283 |         self.base_vel_ctrl.linear_velocity = mn.Vector3(lin_vel, 0, 0)
284 |         self.base_vel_ctrl.angular_velocity = mn.Vector3(0, ang_vel, 0)
285 | 
286 |         if lin_vel != 0.0 or ang_vel != 0.0:
287 |             self.update_base()
288 | 
289 | 
290 | @registry.register_task_action
291 | class DiscreteTurnLeft(BaseVelAction):
292 |     """
293 |     The articulated agent base motion is constrained to the NavMesh and controlled with velocity commands integrated with the VelocityControl interface.
294 | 
295 |     Optionally cull states with active collisions if config parameter `allow_dyn_slide` is True
296 |     """
297 | 
298 |     def step(self, *args, **kwargs):
299 |         lin_vel = 0
300 |         ang_vel = np.pi / 180 * 30
301 | 
302 |         if not self._allow_back:
303 |             lin_vel = np.maximum(lin_vel, 0)
304 | 
305 |         self.base_vel_ctrl.linear_velocity = mn.Vector3(lin_vel, 0, 0)
306 |         self.base_vel_ctrl.angular_velocity = mn.Vector3(0, ang_vel, 0)
307 | 
308 |         if lin_vel != 0.0 or ang_vel != 0.0:
309 |             self.update_base()
310 | 
311 | 
312 | @registry.register_task_action
313 | class DiscreteTurnRight(BaseVelAction):
314 |     """
315 |     The articulated agent base motion is constrained to the NavMesh and controlled with velocity commands integrated with the VelocityControl interface.
316 | 
317 |     Optionally cull states with active collisions if config parameter `allow_dyn_slide` is True
318 |     """
319 | 
320 |     def step(self, *args, **kwargs):
321 |         lin_vel = 0
322 |         ang_vel = -np.pi / 180 * 30
323 | 
324 |         if not self._allow_back:
325 |             lin_vel = np.maximum(lin_vel, 0)
326 | 
327 |         self.base_vel_ctrl.linear_velocity = mn.Vector3(lin_vel, 0, 0)
328 |         self.base_vel_ctrl.angular_velocity = mn.Vector3(0, ang_vel, 0)
329 | 
330 |         if lin_vel != 0.0 or ang_vel != 0.0:
331 |             self.update_base()
332 | 
333 | 
356 | @registry.register_task_action
357 | class DiscreteMoveGeneric(BaseVelAction):
358 |     """
359 |     The articulated agent base motion is constrained to the NavMesh and controlled with velocity commands integrated with the VelocityControl interface.
360 | 361 | Optionally cull states with active collisions if config parameter `allow_dyn_slide` is True 362 | """ 363 | 364 | @property 365 | def action_space(self): 366 | return EmptySpace() 367 | 368 | def step(self, *args, **kwargs): 369 | lin_vel = self._lin_speed * 30 / 0.25 370 | ang_vel = self._ang_speed 371 | 372 | if not self._allow_back: 373 | lin_vel = np.maximum(lin_vel, 0) 374 | 375 | self.base_vel_ctrl.linear_velocity = mn.Vector3(lin_vel, 0, 0) 376 | self.base_vel_ctrl.angular_velocity = mn.Vector3(0, ang_vel, 0) 377 | 378 | if lin_vel != 0.0 or ang_vel != 0.0: 379 | self.update_base() 380 | 381 | 382 | @registry.register_task_action 383 | class RearrangeCameraZoom(ArticulatedAgentAction): 384 | name: str = "rearrange_zoom" 385 | 386 | def __init__(self, *args, config, sim, **kwargs): 387 | super().__init__(*args, config=config, sim=sim, **kwargs) 388 | self._sim = sim 389 | self._zoom_amount = config.zoom_amount 390 | 391 | def _zoom(self, amount: Optional[float] = None): 392 | sensors_info = self._sim.agents[0]._sensors 393 | for cam in sensors_info.values(): 394 | if amount is None: 395 | cam.reset_zoom() 396 | else: 397 | cam.zoom(amount) 398 | 399 | def step(self, *args: Any, **kwargs: Any): 400 | r"""Update ``_metric``, this method is called from ``Env`` on each 401 | ``step``. 402 | """ 403 | self._zoom(self._zoom_amount) 404 | -------------------------------------------------------------------------------- /relic/evaluator/habitat_evalutor.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | from typing import Any, Dict, List 4 | 5 | import numpy as np 6 | import torch 7 | import tqdm 8 | 9 | from habitat import logger 10 | from habitat.tasks.rearrange.rearrange_sensors import GfxReplayMeasure 11 | from habitat.tasks.rearrange.utils import write_gfx_replay 12 | from habitat.utils.visualizations.utils import ( 13 | observations_to_image, 14 | overlay_frame, 15 | ) 16 | from habitat_baselines.common.obs_transformers import ( 17 | apply_obs_transforms_batch, 18 | ) 19 | from habitat_baselines.rl.ppo.evaluator import Evaluator 20 | from relic.evaluator.repEps_evaluator import pause_envs 21 | from habitat_baselines.utils.common import ( 22 | batch_obs, 23 | generate_video, 24 | get_action_space_info, 25 | inference_mode, 26 | is_continuous_action_space, 27 | ) 28 | from habitat_baselines.utils.info_dict import extract_scalars_from_info 29 | 30 | 31 | class TransformersHabitatEvaluator(Evaluator): 32 | """ 33 | Evaluator for Habitat environments. 
34 | """ 35 | 36 | def evaluate_agent( 37 | self, 38 | agent, 39 | envs, 40 | config, 41 | checkpoint_index, 42 | step_id, 43 | writer, 44 | device, 45 | obs_transforms, 46 | env_spec, 47 | rank0_keys, 48 | ): 49 | observations = envs.reset() 50 | observations = envs.post_step(observations) 51 | batch = batch_obs(observations, device=device) 52 | batch = apply_obs_transforms_batch(batch, obs_transforms) # type: ignore 53 | 54 | action_shape, discrete_actions = get_action_space_info( 55 | agent._policy_action_space 56 | ) 57 | 58 | current_episode_reward = torch.zeros(envs.num_envs, 1, device="cpu") 59 | 60 | test_recurrent_hidden_states = torch.zeros( 61 | ( 62 | agent.actor_critic.num_recurrent_layers, 63 | 2, 64 | envs.num_envs, 65 | agent.actor_critic.num_heads, 66 | 1001, 67 | agent.actor_critic.recurrent_hidden_size 68 | // agent.actor_critic.num_heads, 69 | ), 70 | device=device, 71 | ) 72 | should_update_recurrent_hidden_states = ( 73 | np.prod(test_recurrent_hidden_states.shape) != 0 74 | ) 75 | prev_actions = torch.zeros( 76 | envs.num_envs, 77 | *action_shape, 78 | device=device, 79 | dtype=torch.long if discrete_actions else torch.float, 80 | ) 81 | not_done_masks = torch.zeros( 82 | envs.num_envs, 83 | 1001, 84 | 1, 85 | device=device, 86 | dtype=torch.bool, 87 | ) 88 | 89 | stats_episodes: Dict[ 90 | Any, Any 91 | ] = {} # dict of dicts that stores stats per episode 92 | ep_eval_count: Dict[Any, int] = defaultdict(lambda: 0) 93 | 94 | if len(config.habitat_baselines.eval.video_option) > 0: 95 | # Add the first frame of the episode to the video. 96 | rgb_frames: List[List[np.ndarray]] = [ 97 | [observations_to_image({k: v[env_idx] for k, v in batch.items()}, {})] 98 | for env_idx in range(envs.num_envs) 99 | ] 100 | else: 101 | rgb_frames = None 102 | 103 | if len(config.habitat_baselines.eval.video_option) > 0: 104 | os.makedirs(config.habitat_baselines.video_dir, exist_ok=True) 105 | 106 | number_of_eval_episodes = config.habitat_baselines.test_episode_count 107 | evals_per_ep = config.habitat_baselines.eval.evals_per_ep 108 | if number_of_eval_episodes == -1: 109 | number_of_eval_episodes = sum(envs.number_of_episodes) 110 | else: 111 | total_num_eps = sum(envs.number_of_episodes) 112 | # if total_num_eps is negative, it means the number of evaluation episodes is unknown 113 | if total_num_eps < number_of_eval_episodes and total_num_eps > 1: 114 | logger.warn( 115 | f"Config specified {number_of_eval_episodes} eval episodes" 116 | ", dataset only has {total_num_eps}." 
117 | ) 118 | logger.warn(f"Evaluating with {total_num_eps} instead.") 119 | number_of_eval_episodes = total_num_eps 120 | else: 121 | assert evals_per_ep == 1 122 | assert ( 123 | number_of_eval_episodes > 0 124 | ), "You must specify a number of evaluation episodes with test_episode_count" 125 | 126 | pbar = tqdm.tqdm(total=number_of_eval_episodes * evals_per_ep) 127 | agent.eval() 128 | _step_num_ae = 0 129 | last_step_done = [0] * envs.num_envs 130 | while ( 131 | len(stats_episodes) < (number_of_eval_episodes * evals_per_ep) 132 | and envs.num_envs > 0 133 | ): 134 | current_episodes_info = envs.current_episodes() 135 | shift_n_steps = min(last_step_done) 136 | 137 | if shift_n_steps > 500: 138 | test_recurrent_hidden_states[ 139 | ..., : _step_num_ae - shift_n_steps + 1, : 140 | ] = test_recurrent_hidden_states[ 141 | ..., shift_n_steps : _step_num_ae + 1, : 142 | ] 143 | not_done_masks[:, : _step_num_ae - shift_n_steps + 1] = not_done_masks[ 144 | :, shift_n_steps : _step_num_ae + 1 145 | ] 146 | _step_num_ae -= shift_n_steps 147 | last_step_done = [v - shift_n_steps for v in last_step_done] 148 | 149 | with inference_mode(): 150 | action_data, outputs = agent.actor_critic.act( 151 | batch, 152 | test_recurrent_hidden_states[..., 1 : _step_num_ae + 1, :], 153 | prev_actions, 154 | not_done_masks[:, : _step_num_ae + 1], 155 | deterministic=False, 156 | output_attentions=True, 157 | ) 158 | outputs.past_key_values = None 159 | if action_data.should_inserts is None: 160 | test_recurrent_hidden_states[ 161 | ..., _step_num_ae + 1, : 162 | ] = action_data.rnn_hidden_states 163 | prev_actions.copy_(action_data.actions) # type: ignore 164 | else: 165 | for i, should_insert in enumerate(action_data.should_inserts): 166 | if not should_insert.item(): 167 | continue 168 | if should_update_recurrent_hidden_states: 169 | test_recurrent_hidden_states[ 170 | :, :, i, ..., _step_num_ae + 1, : 171 | ] = action_data.rnn_hidden_states[i] 172 | prev_actions[i].copy_(action_data.actions[i]) # type: ignore 173 | 174 | # NB: Move actions to CPU. If CUDA tensors are 175 | # sent in to env.step(), that will create CUDA contexts 176 | # in the subprocesses. 177 | if is_continuous_action_space(env_spec.action_space): 178 | # Clipping actions to the specified limits 179 | step_data = [ 180 | np.clip( 181 | a.numpy(), 182 | env_spec.action_space.low, 183 | env_spec.action_space.high, 184 | ) 185 | for a in action_data.env_actions.cpu() 186 | ] 187 | else: 188 | step_data = [a.item() for a in action_data.env_actions.cpu()] 189 | 190 | outputs = envs.step(step_data) 191 | 192 | observations, rewards_l, dones, infos = [list(x) for x in zip(*outputs)] 193 | # Note that `policy_infos` represents the information about the 194 | # action BEFORE `observations` (the action used to transition to 195 | # `observations`). 
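            # Concretely: infos[i] returned by envs.step() describes the
            # outcome of the action chosen above, so the policy's extra
            # per-step stats are merged into that same step's info dict.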
196 | policy_infos = agent.actor_critic.get_extra(action_data, infos, dones) 197 | for i in range(len(policy_infos)): 198 | infos[i].update(policy_infos[i]) 199 | 200 | observations = envs.post_step(observations) 201 | batch = batch_obs( # type: ignore 202 | observations, 203 | device=device, 204 | ) 205 | batch = apply_obs_transforms_batch(batch, obs_transforms) # type: ignore 206 | 207 | not_done_masks[:, _step_num_ae + 1] = torch.tensor( 208 | [[not done] for done in dones], 209 | dtype=torch.bool, 210 | device="cpu", 211 | ) 212 | 213 | last_step_done = [ 214 | v if done else _step_num_ae for v, done in zip(last_step_done, dones) 215 | ] 216 | 217 | rewards = torch.tensor( 218 | rewards_l, dtype=torch.float, device="cpu" 219 | ).unsqueeze(1) 220 | current_episode_reward += rewards 221 | next_episodes_info = envs.current_episodes() 222 | envs_to_pause = [] 223 | n_envs = envs.num_envs 224 | for i in range(n_envs): 225 | if ( 226 | ep_eval_count[ 227 | ( 228 | next_episodes_info[i].scene_id, 229 | next_episodes_info[i].episode_id, 230 | ) 231 | ] 232 | == evals_per_ep 233 | ): 234 | envs_to_pause.append(i) 235 | 236 | # Exclude the keys in `rank0_keys` from being displayed in the video 237 | disp_info = {k: v for k, v in infos[i].items() if k not in rank0_keys} 238 | 239 | if len(config.habitat_baselines.eval.video_option) > 0: 240 | # TODO move normalization / channel changing out of the policy and undo it here 241 | frame = observations_to_image( 242 | {k: v[i] for k, v in batch.items()}, disp_info 243 | ) 244 | if not not_done_masks[i, _step_num_ae + 1].item(): 245 | # The last frame corresponds to the first frame of the next episode, 246 | # but the info is for the episode that just finished. So we use a blacked-out frame 247 | final_frame = observations_to_image( 248 | {k: v[i] * 0.0 for k, v in batch.items()}, 249 | disp_info, 250 | ) 251 | final_frame = overlay_frame(final_frame, disp_info) 252 | rgb_frames[i].append(final_frame) 253 | # The starting frame of the next episode will be the final element. 254 | rgb_frames[i].append(frame) 255 | else: 256 | frame = overlay_frame(frame, disp_info) 257 | rgb_frames[i].append(frame) 258 | 259 | # episode ended 260 | if not not_done_masks[i, _step_num_ae + 1].item(): 261 | pbar.update() 262 | episode_stats = {"reward": current_episode_reward[i].item()} 263 | episode_stats.update(extract_scalars_from_info(infos[i])) 264 | current_episode_reward[i] = 0 265 | k = ( 266 | current_episodes_info[i].scene_id, 267 | current_episodes_info[i].episode_id, 268 | ) 269 | ep_eval_count[k] += 1 270 | # use scene_id + episode_id as unique id for storing stats 271 | stats_episodes[(k, ep_eval_count[k])] = episode_stats 272 | 273 | if len(config.habitat_baselines.eval.video_option) > 0: 274 | generate_video( 275 | video_option=config.habitat_baselines.eval.video_option, 276 | video_dir=config.habitat_baselines.video_dir, 277 | # Drop the final frame, since it is the start frame of the next episode. 278 | images=rgb_frames[i][:-1], 279 | episode_id=f"{current_episodes_info[i].episode_id}_{ep_eval_count[k]}", 280 | checkpoint_idx=checkpoint_index, 281 | metrics=extract_scalars_from_info(disp_info), 282 | fps=config.habitat_baselines.video_fps, 283 | tb_writer=writer, 284 | keys_to_include_in_name=config.habitat_baselines.eval_keys_to_include_in_name, 285 | ) 286 | 287 | # Keep only the final frame; it is the starting frame of the next episode.
288 | rgb_frames[i] = rgb_frames[i][-1:] 289 | 290 | gfx_str = infos[i].get(GfxReplayMeasure.cls_uuid, "") 291 | if gfx_str != "": 292 | write_gfx_replay( 293 | gfx_str, 294 | config.habitat.task, 295 | current_episodes_info[i].episode_id, 296 | ) 297 | 298 | not_done_masks = not_done_masks.to(device=device) 299 | ( 300 | envs, 301 | test_recurrent_hidden_states, 302 | not_done_masks, 303 | current_episode_reward, 304 | prev_actions, 305 | batch, 306 | rgb_frames, 307 | ) = pause_envs( 308 | envs_to_pause, 309 | envs, 310 | test_recurrent_hidden_states, 311 | not_done_masks, 312 | current_episode_reward, 313 | prev_actions, 314 | batch, 315 | rgb_frames, 316 | ) 317 | _step_num_ae += 1 318 | 319 | pbar.close() 320 | """ 321 | assert ( 322 | len(ep_eval_count) >= number_of_eval_episodes 323 | ), f"Expected {number_of_eval_episodes} episodes, got {len(ep_eval_count)}." 324 | """ 325 | aggregated_stats = {} 326 | all_ks = set() 327 | for ep in stats_episodes.values(): 328 | all_ks.update(ep.keys()) 329 | for stat_key in all_ks: 330 | aggregated_stats[stat_key] = np.mean( 331 | [v[stat_key] for v in stats_episodes.values() if stat_key in v] 332 | ) 333 | 334 | for k, v in aggregated_stats.items(): 335 | logger.info(f"Average episode {k}: {v:.4f}") 336 | 337 | writer.add_scalar( 338 | "eval_reward/average_reward", aggregated_stats["reward"], step_id 339 | ) 340 | 341 | metrics = {k: v for k, v in aggregated_stats.items() if k != "reward"} 342 | for k, v in metrics.items(): 343 | writer.add_scalar(f"eval_metrics/{k}", v, step_id) 344 | -------------------------------------------------------------------------------- /relic/policies/transformerxl/modeling_transformerxl.py: -------------------------------------------------------------------------------- 1 | from transformers.models.transfo_xl.modeling_transfo_xl import * 2 | from transformers.models.transfo_xl.modeling_transfo_xl import ( 3 | _CHECKPOINT_FOR_DOC, 4 | _CONFIG_FOR_DOC, 5 | ) 6 | 7 | from transformers.models.llama.modeling_llama import LlamaMLP 8 | 9 | 10 | class RelPartialLearnableDecoderLayer(nn.Module): 11 | def __init__( 12 | self, 13 | n_head, 14 | d_model, 15 | d_head, 16 | d_inner, 17 | dropout, 18 | layer_norm_epsilon=1e-5, 19 | **kwargs 20 | ): 21 | super().__init__() 22 | 23 | self.dec_attn = RelPartialLearnableMultiHeadAttn( 24 | n_head, 25 | d_model, 26 | d_head, 27 | dropout, 28 | layer_norm_epsilon=layer_norm_epsilon, 29 | **kwargs 30 | ) 31 | 32 | assert kwargs.get("pre_lnorm") 33 | 34 | self.layer_norm = nn.LayerNorm(d_model, eps=layer_norm_epsilon) 35 | 36 | class MLPConfig: 37 | hidden_size = d_model 38 | intermediate_size = d_inner 39 | hidden_act = "gelu_new" 40 | pretraining_tp = 1 41 | 42 | self.pos_ff = LlamaMLP(MLPConfig) 43 | 44 | def forward( 45 | self, 46 | dec_inp, 47 | r, 48 | dec_attn_mask=None, 49 | mems=None, 50 | head_mask=None, 51 | output_attentions=False, 52 | ): 53 | attn_outputs = self.dec_attn( 54 | dec_inp, 55 | r, 56 | attn_mask=dec_attn_mask, 57 | mems=mems, 58 | head_mask=head_mask, 59 | output_attentions=output_attentions, 60 | ) 61 | 62 | hidden_states = attn_outputs[0] 63 | residual = hidden_states 64 | hidden_states = self.layer_norm(hidden_states) 65 | hidden_states = self.pos_ff(hidden_states) 66 | ff_output = residual + hidden_states 67 | 68 | outputs = [ff_output] + attn_outputs[1:] 69 | 70 | return outputs 71 | 72 | 73 | TRANSFO_XL_START_DOCSTRING = r""" 74 | 75 | This model inherits from [`PreTrainedModel`]. 
Check the superclass documentation for the generic methods the 76 | library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads 77 | etc.) 78 | 79 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 80 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage 81 | and behavior. 82 | 83 | Parameters: 84 | config ([`TransfoXLConfig`]): Model configuration class with all the parameters of the model. 85 | Initializing with a config file does not load the weights associated with the model, only the 86 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 87 | """ 88 | 89 | TRANSFO_XL_INPUTS_DOCSTRING = r""" 90 | Args: 91 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 92 | Indices of input sequence tokens in the vocabulary. 93 | 94 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 95 | [`PreTrainedTokenizer.__call__`] for details. 96 | 97 | [What are input IDs?](../glossary#input-ids) 98 | mems (`List[torch.FloatTensor]` of length `config.n_layers`): 99 | Contains pre-computed hidden-states (keys and values in the attention blocks) as computed by the model (see 100 | `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems 101 | given to this model should not be passed as `input_ids` as they have already been computed. 102 | head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): 103 | Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: 104 | 105 | - 1 indicates the head is **not masked**, 106 | - 0 indicates the head is **masked**. 107 | 108 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 109 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 110 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 111 | model's internal embedding lookup matrix. 112 | output_attentions (`bool`, *optional*): 113 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 114 | tensors for more detail. 115 | output_hidden_states (`bool`, *optional*): 116 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 117 | more detail. 118 | return_dict (`bool`, *optional*): 119 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
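    Example (an illustrative sketch added for clarity, not part of the upstream docstring — the small
    config below is an assumption; this RL variant removes the word-embedding table, so features are
    passed in through `inputs_embeds`):

    ```python
    import torch
    from transformers import TransfoXLConfig

    # pre_lnorm=True is required by this file's RelPartialLearnableDecoderLayer
    config = TransfoXLConfig(
        d_model=64, d_embed=64, n_head=2, d_head=32, d_inner=128,
        n_layer=2, mem_len=32, pre_lnorm=True,
    )
    model = TransfoXLModel(config)
    feats = torch.randn(4, 16, config.d_model)          # [batch, seq_len, d_model]
    out = model(inputs_embeds=feats, return_dict=True)  # mems are initialized internally
    out = model(inputs_embeds=feats, mems=out.mems, return_dict=True)  # reuse cached states
    ```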
120 | """ 121 | 122 | 123 | @add_start_docstrings( 124 | "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.", 125 | TRANSFO_XL_START_DOCSTRING, 126 | ) 127 | class TransfoXLModel(TransfoXLPreTrainedModel): 128 | def __init__(self, config): 129 | super().__init__(config) 130 | 131 | self.n_token = config.vocab_size 132 | 133 | self.d_embed = config.d_embed 134 | self.d_model = config.d_model 135 | self.n_head = config.n_head 136 | self.d_head = config.d_head 137 | 138 | # self.word_emb = AdaptiveEmbedding( 139 | # config.vocab_size, config.d_embed, config.d_model, config.cutoffs, div_val=config.div_val 140 | # ) 141 | 142 | self.drop = nn.Dropout(config.dropout) 143 | 144 | self.n_layer = config.n_layer 145 | self.mem_len = config.mem_len 146 | self.attn_type = config.attn_type 147 | 148 | if not config.untie_r: 149 | self.r_w_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) 150 | self.r_r_bias = nn.Parameter(torch.FloatTensor(self.n_head, self.d_head)) 151 | 152 | self.layers = nn.ModuleList() 153 | if config.attn_type == 0: # the default attention 154 | for i in range(config.n_layer): 155 | self.layers.append( 156 | RelPartialLearnableDecoderLayer( 157 | config.n_head, 158 | config.d_model, 159 | config.d_head, 160 | config.d_inner, 161 | config.dropout, 162 | dropatt=config.dropatt, 163 | pre_lnorm=config.pre_lnorm, 164 | r_w_bias=None if config.untie_r else self.r_w_bias, 165 | r_r_bias=None if config.untie_r else self.r_r_bias, 166 | layer_norm_epsilon=config.layer_norm_epsilon, 167 | ) 168 | ) 169 | else: # learnable embeddings and absolute embeddings are not used in our pretrained checkpoints 170 | raise NotImplementedError # Removed them to avoid maintaining dead code 171 | 172 | self.same_length = config.same_length 173 | self.clamp_len = config.clamp_len 174 | 175 | if self.attn_type == 0: # default attention 176 | self.pos_emb = PositionalEmbedding(self.d_model) 177 | else: # learnable embeddings and absolute embeddings 178 | raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint 179 | 180 | # Initialize weights and apply final processing 181 | self.post_init() 182 | 183 | def get_input_embeddings(self): 184 | return self.word_emb 185 | 186 | def set_input_embeddings(self, new_embeddings): 187 | self.word_emb = new_embeddings 188 | 189 | def backward_compatible(self): 190 | self.sample_softmax = -1 191 | 192 | def reset_memory_length(self, mem_len): 193 | self.mem_len = mem_len 194 | 195 | def _prune_heads(self, heads): 196 | logger.info("Head pruning is not implemented for Transformer-XL model") 197 | pass 198 | 199 | def init_mems(self, bsz): 200 | if self.mem_len > 0: 201 | mems = [] 202 | param = next(self.parameters()) 203 | for i in range(self.n_layer): 204 | empty = torch.zeros( 205 | self.mem_len, 206 | bsz, 207 | self.config.d_model, 208 | dtype=param.dtype, 209 | device=param.device, 210 | ) 211 | mems.append(empty) 212 | 213 | return mems 214 | else: 215 | return None 216 | 217 | def _update_mems(self, hids, mems, mlen, qlen): 218 | # does not deal with None 219 | if mems is None: 220 | return None 221 | 222 | # mems is not None 223 | assert len(hids) == len(mems), "len(hids) != len(mems)" 224 | 225 | # There are `mlen + qlen` steps that can be cached into mems 226 | with torch.no_grad(): 227 | new_mems = [] 228 | end_idx = mlen + max(0, qlen) 229 | beg_idx = max(0, end_idx - self.mem_len) 230 | for i in range(len(hids)): 231 | cat = 
torch.cat([mems[i], hids[i]], dim=0) 232 | new_mems.append(cat[beg_idx:end_idx].detach()) 233 | 234 | return new_mems 235 | 236 | @add_start_docstrings_to_model_forward(TRANSFO_XL_INPUTS_DOCSTRING) 237 | @add_code_sample_docstrings( 238 | checkpoint=_CHECKPOINT_FOR_DOC, 239 | output_type=TransfoXLModelOutput, 240 | config_class=_CONFIG_FOR_DOC, 241 | ) 242 | def forward( 243 | self, 244 | input_ids: Optional[torch.LongTensor] = None, 245 | mems: Optional[List[torch.FloatTensor]] = None, 246 | head_mask: Optional[torch.FloatTensor] = None, 247 | inputs_embeds: Optional[torch.FloatTensor] = None, 248 | output_attentions: Optional[bool] = None, 249 | output_hidden_states: Optional[bool] = None, 250 | return_dict: Optional[bool] = None, 251 | ) -> Union[Tuple, TransfoXLModelOutput]: 252 | output_attentions = ( 253 | output_attentions 254 | if output_attentions is not None 255 | else self.config.output_attentions 256 | ) 257 | output_hidden_states = ( 258 | output_hidden_states 259 | if output_hidden_states is not None 260 | else self.config.output_hidden_states 261 | ) 262 | return_dict = ( 263 | return_dict if return_dict is not None else self.config.use_return_dict 264 | ) 265 | 266 | # the original code for Transformer-XL used shapes [len, bsz] but we want a unified interface in the library 267 | # so we transpose here from shape [bsz, len] to shape [len, bsz] 268 | if input_ids is not None and inputs_embeds is not None: 269 | raise ValueError( 270 | "You cannot specify both input_ids and inputs_embeds at the same time" 271 | ) 272 | elif input_ids is not None: 273 | input_ids = input_ids.transpose(0, 1).contiguous() 274 | qlen, bsz = input_ids.size() 275 | elif inputs_embeds is not None: 276 | inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() 277 | qlen, bsz = inputs_embeds.shape[0], inputs_embeds.shape[1] 278 | else: 279 | raise ValueError("You have to specify either input_ids or inputs_embeds") 280 | 281 | if mems is None: 282 | mems = self.init_mems(bsz) 283 | 284 | # Prepare head mask if needed 285 | # 1.0 in head_mask indicate we keep the head 286 | # attention_probs has shape bsz x n_heads x N x N 287 | # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] (a head_mask for each layer) 288 | # and head_mask is converted to shape [num_hidden_layers x qlen x klen x bsz x n_head] 289 | if head_mask is not None: 290 | if head_mask.dim() == 1: 291 | head_mask = ( 292 | head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0) 293 | ) 294 | head_mask = head_mask.expand(self.n_layer, -1, -1, -1, -1) 295 | elif head_mask.dim() == 2: 296 | head_mask = head_mask.unsqueeze(1).unsqueeze(1).unsqueeze(1) 297 | head_mask = head_mask.to( 298 | dtype=next(self.parameters()).dtype 299 | ) # switch to float if need + fp16 compatibility 300 | else: 301 | head_mask = [None] * self.n_layer 302 | 303 | if inputs_embeds is not None: 304 | word_emb = inputs_embeds 305 | else: 306 | word_emb = self.word_emb(input_ids) 307 | 308 | mlen = mems[0].size(0) if mems is not None else 0 309 | klen = mlen + qlen 310 | if self.same_length: 311 | all_ones = word_emb.new_ones((qlen, klen), dtype=torch.bool) 312 | mask_len = klen - self.mem_len 313 | if mask_len > 0: 314 | mask_shift_len = qlen - mask_len 315 | else: 316 | mask_shift_len = qlen 317 | dec_attn_mask = ( 318 | torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones, -mask_shift_len) 319 | )[ 320 | :, :, None 321 | ] # -1 322 | else: 323 | dec_attn_mask = torch.triu( 324 | word_emb.new_ones((qlen, klen), dtype=torch.bool), 
diagonal=1 + mlen 325 | )[:, :, None] 326 | 327 | hids = [] 328 | attentions = [] if output_attentions else None 329 | if self.attn_type == 0: # default 330 | pos_seq = torch.arange( 331 | klen - 1, -1, -1.0, device=word_emb.device, dtype=word_emb.dtype 332 | ) 333 | if self.clamp_len > 0: 334 | pos_seq.clamp_(max=self.clamp_len) 335 | pos_emb = self.pos_emb(pos_seq) 336 | 337 | core_out = self.drop(word_emb) 338 | pos_emb = self.drop(pos_emb) 339 | 340 | for i, layer in enumerate(self.layers): 341 | hids.append(core_out) 342 | mems_i = None if mems is None else mems[i] 343 | layer_outputs = layer( 344 | core_out, 345 | pos_emb, 346 | dec_attn_mask=dec_attn_mask, 347 | mems=mems_i, 348 | head_mask=head_mask[i], 349 | output_attentions=output_attentions, 350 | ) 351 | core_out = layer_outputs[0] 352 | if output_attentions: 353 | attentions.append(layer_outputs[1]) 354 | else: # learnable embeddings and absolute embeddings 355 | raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint 356 | 357 | core_out = self.drop(core_out) 358 | 359 | new_mems = self._update_mems(hids, mems, mlen, qlen) 360 | 361 | if output_hidden_states: 362 | # Add last layer and transpose to library standard shape [bsz, len, hidden_dim] 363 | hids.append(core_out) 364 | hids = tuple(t.transpose(0, 1).contiguous() for t in hids) 365 | else: 366 | hids = None 367 | if output_attentions: 368 | # Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len] 369 | attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions) 370 | # We transpose back here to shape [bsz, len, hidden_dim] 371 | core_out = core_out.transpose(0, 1).contiguous() 372 | 373 | if not return_dict: 374 | return tuple( 375 | v for v in [core_out, new_mems, hids, attentions] if v is not None 376 | ) 377 | 378 | return TransfoXLModelOutput( 379 | last_hidden_state=core_out, 380 | mems=new_mems, 381 | hidden_states=hids, 382 | attentions=attentions, 383 | ) 384 | -------------------------------------------------------------------------------- /relic/envs/train_il_env_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 2 | # This source code is licensed under the MIT license found in the 3 | # LICENSE file in the root directory of this source tree. 
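# This module provides the shortest-path-expert environment plumbing:
#   - Hab3ShortestPathFollower: a scripted expert that follows the shortest path to the
#     goal, then scans in place with discrete turn/camera actions until the target
#     object occupies enough pixels to emit `stop`.
#   - ShortestPathVectorEnv: a VectorEnv whose workers answer a custom `best_action`
#     call by querying the follower for the next expert action.
#   - HabitatVectorEnvFactory: constructs the vector envs, splitting dataset scenes
#     across environments (and across ranks when training is distributed).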
4 | 5 | from math import ceil 6 | import os 7 | import random 8 | import signal 9 | from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Type 10 | 11 | from habitat import ThreadedVectorEnv, VectorEnv, logger, make_dataset 12 | from habitat.config import read_write 13 | from habitat.gym import make_gym_from_config 14 | from habitat_baselines.common.env_factory import VectorEnvFactory 15 | from habitat_baselines.rl.ddppo.ddp_utils import get_distrib_size 16 | import torch 17 | from multiprocessing.connection import Connection 18 | from habitat.core.logging import logger 19 | from habitat.gym.gym_env_episode_count_wrapper import EnvCountEpisodeWrapper 20 | from habitat.gym.gym_env_obs_dict_wrapper import EnvObsDictWrapper 21 | 22 | from habitat.core.vector_env import ( 23 | CALL_COMMAND, 24 | CLOSE_COMMAND, 25 | COUNT_EPISODES_COMMAND, 26 | RENDER_COMMAND, 27 | RESET_COMMAND, 28 | STEP_COMMAND, 29 | ) 30 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 31 | 32 | from relic.tasks.utils import get_obj_pixel_counts 33 | 34 | if TYPE_CHECKING: 35 | from omegaconf import DictConfig 36 | 37 | from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower 38 | import numpy as np 39 | import importlib 40 | 41 | 42 | def get_make_env_func_by_name(name): 43 | if "." in name: 44 | module_name, func_name = name.rsplit(".", 1) 45 | module = importlib.import_module(module_name) 46 | func = getattr(module, func_name) 47 | else: 48 | func = globals()[name] 49 | return func 50 | 51 | 52 | class DummyAgent: 53 | def __init__(self, sim, agent_id=0): 54 | self.sim = sim 55 | self.agent_id = agent_id 56 | self.agent = self.sim.get_agent(agent_id) 57 | 58 | def __getattr__(self, attr): 59 | return getattr(self.agent, attr) 60 | 61 | @property 62 | def state(self): 63 | return self.sim.get_agent_state(self.agent_id) 64 | 65 | 66 | class Hab3ShortestPathFollower(ShortestPathFollower): 67 | def __init__(self, *args, task, **kwargs): 68 | super().__init__(*args, **kwargs) 69 | self.task = task 70 | self._init_state() 71 | 72 | def _build_follower(self, *args, **kwargs): 73 | super()._build_follower(*args, **kwargs) 74 | self._follower.agent = DummyAgent(self._sim) 75 | 76 | def _init_state(self): 77 | self.is_around_the_obj = False 78 | self.state = -1 79 | self.inside_state_counter = 0 80 | self.last_angle = None 81 | 82 | def after_update(self): 83 | self._current_scene = None 84 | self._init_state() 85 | 86 | def get_next_action(self, goal_pos): 87 | # l2_distance = self.task.measurements.get_metrics()["l2_distance"] 88 | if not self.is_around_the_obj: 89 | next_action = super().get_next_action(goal_pos) 90 | if next_action == HabitatSimActions.stop or next_action is None: 91 | self.is_around_the_obj = True 92 | self.state = 0 93 | else: 94 | return next_action 95 | 96 | rot_dist_to_closest_goal = self.task.measurements.get_metrics()[ 97 | "rot_dist_to_closest_goal" 98 | ] 99 | if self.state == 0: 100 | next_action = HabitatSimActions.turn_left 101 | self.inside_state_counter += 1 102 | if self.last_angle is None: 103 | self.last_angle = rot_dist_to_closest_goal 104 | elif (rot_dist_to_closest_goal > self.last_angle) and ( 105 | rot_dist_to_closest_goal < 1 106 | ): 107 | self.state += 1 108 | self.inside_state_counter = 0 109 | self.last_angle = None 110 | elif self.state == 1: 111 | next_action = HabitatSimActions.turn_right 112 | self.inside_state_counter += 1 113 | if self.last_angle is None: 114 | self.last_angle = rot_dist_to_closest_goal 115 | if 
rot_dist_to_closest_goal > self.last_angle: 116 | self.state += 1 117 | self.inside_state_counter = 0 118 | elif self.state == 2: 119 | next_action = 4 120 | self.inside_state_counter += 1 121 | if self.inside_state_counter > 5: 122 | self.state += 1 123 | self.inside_state_counter = 0 124 | elif self.state == 3: 125 | next_action = 5 126 | self.inside_state_counter += 1 127 | if self.inside_state_counter > 5: 128 | self.state += 1 129 | self.inside_state_counter = 0 130 | 131 | closest_object_id = self.task.measurements.measures[ 132 | "rot_dist_to_closest_goal" 133 | ].closest_object_index 134 | if self.state >= 4 or ( 135 | self.state >= 2 136 | and get_obj_pixel_counts(self.task, margin=5, strict=False).get( 137 | closest_object_id, 0 138 | ) 139 | > 10 140 | ): 141 | next_action = HabitatSimActions.stop 142 | self.last_angle = rot_dist_to_closest_goal 143 | # print(self.state, get_obj_pixel_counts(self.task, margin=5, strict=False).get(closest_object_id, 0), next_action) 144 | return next_action 145 | 146 | 147 | class ShortestPathVectorEnv(VectorEnv): 148 | @staticmethod 149 | def _worker_env( 150 | connection_read_fn: Callable, 151 | connection_write_fn: Callable, 152 | env_fn: Callable, 153 | env_fn_args: Tuple[Any], 154 | auto_reset_done: bool, 155 | mask_signals: bool = False, 156 | child_pipe: Optional[Connection] = None, 157 | parent_pipe: Optional[Connection] = None, 158 | ) -> None: 159 | r"""process worker for creating and interacting with the environment.""" 160 | if mask_signals: 161 | signal.signal(signal.SIGINT, signal.SIG_IGN) 162 | signal.signal(signal.SIGTERM, signal.SIG_IGN) 163 | 164 | signal.signal(signal.SIGUSR1, signal.SIG_IGN) 165 | signal.signal(signal.SIGUSR2, signal.SIG_IGN) 166 | 167 | env = EnvCountEpisodeWrapper(EnvObsDictWrapper(env_fn(*env_fn_args))) 168 | try: 169 | follower = Hab3ShortestPathFollower( 170 | env.env.env.habitat_env.sim, 171 | 0.25, 172 | False, 173 | task=env.env.env.habitat_env._task, 174 | ) 175 | except Exception: 176 | logger.warn("Couldn't create shortest path follower.") 177 | follower = None 178 | 179 | if parent_pipe is not None: 180 | parent_pipe.close() 181 | try: 182 | command, data = connection_read_fn() 183 | while command != CLOSE_COMMAND: 184 | if command == STEP_COMMAND: 185 | observations, reward, done, info = env.step(data) 186 | 187 | if auto_reset_done and done: 188 | observations = env.reset() 189 | if follower is not None: 190 | follower._init_state() 191 | 192 | connection_write_fn((observations, reward, done, info)) 193 | 194 | elif command == RESET_COMMAND: 195 | observations = env.reset() 196 | connection_write_fn(observations) 197 | 198 | elif command == RENDER_COMMAND: 199 | connection_write_fn(env.render(*data[0], **data[1])) 200 | 201 | elif command == CALL_COMMAND: 202 | function_name, function_args = data 203 | if function_name == "best_action": 204 | task = env.env.env.habitat_env._task 205 | closest_index = task.measurements.measures[ 206 | "rot_dist_to_closest_goal" 207 | ].closest_object_index 208 | if closest_index == 0: 209 | start_index = 0 210 | end_index = task.all_snapped_obj_pos_sizes[closest_index] 211 | else: 212 | start_index = task.all_snapped_obj_pos_sizes[ 213 | closest_index - 1 214 | ] 215 | end_index = task.all_snapped_obj_pos_sizes[closest_index] 216 | snapped_points = np.asarray( 217 | task.all_snapped_obj_pos[start_index:end_index] 218 | ) 219 | agent_norms = np.linalg.norm( 220 | ( 221 | snapped_points 222 | - np.asarray(task.all_obj_pos[closest_index]) 223 | )[:, [0, 2]], 224 | 
axis=1, 225 | ) 226 | closest_point = agent_norms.argmin() 227 | 228 | if (agent_norms < 1.5).any(): 229 | snapped_points = snapped_points[agent_norms < 1.5] 230 | goal_point = snapped_points[ 231 | np.random.choice(len(snapped_points)) 232 | ] 233 | else: 234 | goal_point = snapped_points[closest_point] 235 | action = follower.get_next_action(goal_point) 236 | connection_write_fn(action) 237 | elif function_name == "episodes" and hasattr( 238 | env.env.env, "habitat_env" 239 | ): 240 | iterator_ = env.env.env.habitat_env.episode_iterator 241 | connection_write_fn( 242 | [(x.scene_id, x.episode_id) for x in iterator_.episodes] 243 | ) 244 | else: 245 | if function_name == "after_update" and follower is not None: 246 | follower.after_update() 247 | 248 | if function_args is None: 249 | function_args = {} 250 | 251 | result_or_fn = getattr(env, function_name) 252 | 253 | if len(function_args) > 0 or callable(result_or_fn): 254 | result = result_or_fn(**function_args) 255 | else: 256 | result = result_or_fn 257 | 258 | connection_write_fn(result) 259 | 260 | elif command == COUNT_EPISODES_COMMAND: 261 | connection_write_fn(len(env.episodes)) 262 | 263 | else: 264 | raise NotImplementedError(f"Unknown command {command}") 265 | 266 | command, data = connection_read_fn() 267 | 268 | except KeyboardInterrupt: 269 | logger.info("Worker KeyboardInterrupt") 270 | finally: 271 | if child_pipe is not None: 272 | child_pipe.close() 273 | env.close() 274 | 275 | 276 | class HabitatVectorEnvFactory(VectorEnvFactory): 277 | def construct_envs( 278 | self, 279 | config: "DictConfig", 280 | workers_ignore_signals: bool = False, 281 | enforce_scenes_greater_eq_environments: bool = False, 282 | is_first_rank: bool = True, 283 | distribute_envs_across_gpus=None, 284 | ) -> VectorEnv: 285 | r"""Create VectorEnv object with specified config and env class type. 286 | To allow better performance, dataset are split into small ones for 287 | each individual env, grouped by scenes. 288 | """ 289 | if distribute_envs_across_gpus is None: 290 | distribute_envs_across_gpus = enforce_scenes_greater_eq_environments 291 | 292 | num_environments = config.habitat_baselines.num_environments 293 | configs = [] 294 | make_env_func_name = config.habitat.task.get( 295 | "make_env_fn", "make_gym_from_config" 296 | ) 297 | if make_env_func_name == "make_gym_from_config": 298 | dataset = make_dataset(config.habitat.dataset.type) 299 | scenes = list(config.habitat.dataset.content_scenes) 300 | if "*" in config.habitat.dataset.content_scenes: 301 | scenes = dataset.get_scenes_to_load(config.habitat.dataset) 302 | scenes = sorted(scenes) 303 | local_rank, world_rank, world_size = get_distrib_size() 304 | split_size = ceil(len(scenes) / world_size) 305 | orig_size = len(scenes) 306 | scenes = scenes[world_rank * split_size : (world_rank + 1) * split_size] 307 | scenes_ids = list(range(orig_size))[ 308 | world_rank * split_size : (world_rank + 1) * split_size 309 | ] 310 | logger.warn(f"Loading {len(scenes)}/{orig_size}. 
IDs: {scenes_ids}") 311 | 312 | if num_environments < 1: 313 | raise RuntimeError("num_environments must be strictly positive") 314 | 315 | if len(scenes) == 0: 316 | raise RuntimeError( 317 | "No scenes to load, multiple process logic relies on being able to split scenes uniquely between processes" 318 | ) 319 | 320 | random.shuffle(scenes) 321 | 322 | scene_splits: List[List[str]] = [[] for _ in range(num_environments)] 323 | for idx in range(max(len(scene_splits), len(scenes))): 324 | scene_splits[idx % len(scene_splits)].append(scenes[idx % len(scenes)]) 325 | 326 | logger.warn(f"Scene splits: {scene_splits}.") 327 | assert all(scene_splits) 328 | else: 329 | scenes = [] 330 | 331 | for env_index in range(num_environments): 332 | proc_config = config.copy() 333 | with read_write(proc_config): 334 | if distribute_envs_across_gpus: 335 | proc_config.habitat.simulator.habitat_sim_v0.gpu_device_id = ( 336 | env_index % torch.cuda.device_count() 337 | ) 338 | 339 | task_config = proc_config.habitat 340 | task_config.seed = task_config.seed + env_index 341 | remove_measure_names = [] 342 | if not is_first_rank: 343 | # Filter out non rank0_measure from the task config if we are not on rank0. 344 | remove_measure_names.extend(task_config.task.rank0_measure_names) 345 | if (env_index != 0) or not is_first_rank: 346 | # Filter out non-rank0_env0 measures from the task config if we 347 | # are not on rank0 env0. 348 | remove_measure_names.extend( 349 | task_config.task.rank0_env0_measure_names 350 | ) 351 | 352 | task_config.task.measurements = { 353 | k: v 354 | for k, v in task_config.task.measurements.items() 355 | if k not in remove_measure_names 356 | } 357 | 358 | if len(scenes) > 0: 359 | task_config.dataset.content_scenes = scene_splits[env_index] 360 | 361 | configs.append(proc_config) 362 | 363 | vector_env_cls: Type[Any] 364 | if int(os.environ.get("HABITAT_ENV_DEBUG", 0)): 365 | logger.warn( 366 | "Using the debug Vector environment interface. Expect slower performance." 367 | ) 368 | vector_env_cls = ThreadedVectorEnv 369 | else: 370 | vector_env_cls = ShortestPathVectorEnv 371 | 372 | envs = vector_env_cls( 373 | make_env_fn=get_make_env_func_by_name(make_env_func_name), 374 | env_fn_args=tuple((c,) for c in configs), 375 | workers_ignore_signals=workers_ignore_signals, 376 | ) 377 | 378 | if config.habitat.simulator.renderer.enable_batch_renderer: 379 | envs.initialize_batch_renderer(config) 380 | 381 | return envs 382 | --------------------------------------------------------------------------------
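For reference, a minimal sketch of the scene-splitting step performed in `HabitatVectorEnvFactory.construct_envs` above: after an optional per-rank slice, scenes are shuffled and dealt round-robin across environments, wrapping around whichever list is shorter so that every environment receives at least one scene. The helper below is illustrative only (the name `split_scenes` is hypothetical, not part of the repository):

```python
import random
from typing import List

def split_scenes(scenes: List[str], num_environments: int) -> List[List[str]]:
    """Round-robin scene assignment mirroring construct_envs' splitting logic."""
    random.shuffle(scenes)
    scene_splits: List[List[str]] = [[] for _ in range(num_environments)]
    # Iterate enough times that every env gets a scene even when there are
    # fewer scenes than environments (scenes then repeat across envs).
    for idx in range(max(len(scene_splits), len(scenes))):
        scene_splits[idx % len(scene_splits)].append(scenes[idx % len(scenes)])
    assert all(scene_splits)  # every environment received at least one scene
    return scene_splits

# One possible result (order depends on the shuffle):
# split_scenes(["a", "b", "c"], 2) -> [["a", "c"], ["b"]]
```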