├── ovon
    ├── models
    │   ├── pointnav.py
    │   ├── encoders
    │   │   ├── habitat_resnet.py
    │   │   ├── vc1_encoder.py
    │   │   ├── siglip_encoder.py
    │   │   ├── dinov2_encoder.py
    │   │   ├── make_encoder.py
    │   │   ├── cma_xattn.py
    │   │   ├── depth_encoder.py
    │   │   ├── visual_encoder.py
    │   │   ├── cross_attention.py
    │   │   ├── vit.py
    │   │   └── visual_encoder_v2.py
    │   ├── visual_encoders.py
    │   ├── transformer_policy.py
    │   └── transforms.py
    ├── task
    │   ├── __init__.py
    │   ├── simulator.py
    │   └── rewards.py
    ├── measurements
    │   ├── __init__.py
    │   ├── collision_penalty.py
    │   ├── sum_reward.py
    │   └── imagenav.py
    ├── utils
    │   ├── analysis
    │   │   ├── My_HM3D_Missing_Tags_Taxonomy.csv
    │   │   ├── fixes_to_provided_mappings.ods
    │   │   ├── fixes_to_provided_mappings.csv
    │   │   ├── postprocess_meta.py
    │   │   └── Issues_with_provided_mappings.csv
    │   ├── test_dataset.py
    │   ├── plot_tsne.py
    │   ├── rollout_storage_no_2d.py
    │   ├── visualize_policy_embeddings.py
    │   ├── visualize
    │   │   ├── statistics.py
    │   │   ├── viz.py
    │   │   └── semantic_nav_analysis.py
    │   ├── dagger_environment_worker.py
    │   ├── shuffle_episodes.py
    │   ├── utils.py
    │   ├── cache_clip_embeddings.py
    │   ├── lr_scheduler.py
    │   ├── visualize_trajectories.py
    │   └── sample_episodes.py
    ├── dataset
    │   ├── __init__.py
    │   └── ovon_dataset.py
    ├── trainers
    │   ├── envs.py
    │   ├── inference_worker_with_kv.py
    │   ├── ppo_trainer_no_2d.py
    │   ├── ver_rollout_storage_with_kv.py
    │   └── dagger_trainer.py
    ├── __init__.py
    ├── test.py
    └── obs_transformers
    │   ├── relabel_imagegoal.py
    │   ├── relabel_teacher_actions.py
    │   └── resize.py
├── setup.py
├── docs
    └── ovon_task.jpg
├── .gitmodules
├── .gitignore
├── config
    ├── tasks
    │   ├── objectnav_locobot_hm3d.yaml
    │   ├── objectnav_stretch_hm3d.yaml
    │   └── imagenav_stretch_hm3d.yaml
    └── experiments
    │   ├── rnn_rl.yaml
    │   ├── transformer_rl.yaml
    │   ├── rnn_rl_finetune.yaml
    │   ├── rnn_dagger.yaml
    │   ├── transformer_rl_finetune.yaml
    │   └── transformer_dagger.yaml
├── scripts
    ├── dataset
    │   ├── generate_objectnav_dataset.sh
    │   ├── submit_language_annotation_jobs.sh
    │   ├── clean_full_dataset.sh
    │   ├── clean_scene_dataset.sh
    │   ├── submit_data_gen_jobs.sh
    │   ├── generate_val_dataset.sh
    │   ├── generate_dataset.sh
    │   ├── generate_languagenav_dataset.sh
    │   └── generate_val_episode_counts.py
    ├── prepare_dataset.sh
    ├── train_ddppo.sh
    ├── train_ver.sh
    ├── train
    │   ├── 2-imagenav-ver.sh
    │   ├── 1-ovon-ver.sh
    │   ├── 3-ovon-ver-scratch.sh
    │   └── 4-objectnav-transformer.sh
    └── eval
    │   ├── 4-objectnav-transformer.sh
    │   └── 3-ovon-ver-eval-analysis.sh
├── setup.sh
├── .pre-commit-config.yaml
├── pyproject.toml
└── README.md
/ovon/models/pointnav.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ovon/task/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/ovon/measurements/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup()
4 |
--------------------------------------------------------------------------------
/ovon/utils/analysis/My_HM3D_Missing_Tags_Taxonomy.csv:
--------------------------------------------------------------------------------
1 | unknown"
2 |
--------------------------------------------------------------------------------
/docs/ovon_task.jpg:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naokiyokoyama/ovon/HEAD/docs/ovon_task.jpg -------------------------------------------------------------------------------- /ovon/utils/analysis/fixes_to_provided_mappings.ods: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naokiyokoyama/ovon/HEAD/ovon/utils/analysis/fixes_to_provided_mappings.ods -------------------------------------------------------------------------------- /ovon/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from ovon.dataset import ( 2 | generate_viewpoints, 3 | objectnav_generator, 4 | pose_sampler, 5 | semantic_utils, 6 | visualization, 7 | ) 8 | from ovon.dataset.ovon_dataset import OVONDatasetV1 9 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "distributed_dagger"] 2 | path = distributed_dagger 3 | url = git@github.com:naokiyokoyama/distributed_dagger 4 | [submodule "frontier_exploration"] 5 | path = frontier_exploration 6 | url = git@github.com:naokiyokoyama/frontier_exploration 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # python 2 | __pycache__/ 3 | *.egg-info/ 4 | 5 | # environment 6 | variables.sh 7 | habitat-lab 8 | 9 | # development 10 | .idea 11 | .vscode 12 | 13 | # data 14 | data/ 15 | 16 | # logs 17 | tb/ 18 | train.log 19 | slurm_logs/ 20 | 21 | # scripts 22 | scripts/ 23 | -------------------------------------------------------------------------------- /ovon/models/encoders/habitat_resnet.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder 7 | 8 | 9 | class HabitatResNetEncoder(nn.Module): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.backbone = ResNetEncoder(**kwargs) 13 | self.output_size = np.prod(self.backbone.output_shape) 14 | self.output_shape = self.backbone.output_shape 15 | 16 | def forward(self, observations: Dict, *args, **kwargs) -> torch.Tensor: 17 | return self.backbone(observations) 18 | -------------------------------------------------------------------------------- /config/tasks/objectnav_locobot_hm3d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: objectnav 6 | - /habitat/simulator/agents@habitat.simulator.agents.main_agent: rgb_agent 7 | - /habitat/dataset/objectnav: hm3d 8 | - _self_ 9 | 10 | habitat: 11 | environment: 12 | max_episode_steps: 500 13 | 14 | simulator: 15 | turn_angle: 30 16 | tilt_angle: 30 17 | action_space_config: "v1" 18 | agents: 19 | main_agent: 20 | sim_sensors: 21 | rgb_sensor: 22 | width: 640 23 | height: 480 24 | hfov: 79 25 | position: [0, 0.88, 0] 26 | height: 0.88 27 | radius: 0.18 28 | habitat_sim_v0: 29 | gpu_device_id: 0 30 | allow_sliding: False 31 | -------------------------------------------------------------------------------- /config/tasks/objectnav_stretch_hm3d.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: objectnav 6 | - /habitat/simulator/agents@habitat.simulator.agents.main_agent: rgb_agent 7 | - /habitat/dataset/objectnav: hm3d 8 | - _self_ 9 | 10 | habitat: 11 | environment: 12 | max_episode_steps: 500 13 | 14 | simulator: 15 | turn_angle: 30 16 | tilt_angle: 30 17 | action_space_config: "v1" 18 | agents: 19 | main_agent: 20 | sim_sensors: 21 | rgb_sensor: 22 | width: 360 23 | height: 640 24 | hfov: 42 25 | position: [0, 1.31, 0] 26 | height: 1.41 27 | radius: 0.17 28 | habitat_sim_v0: 29 | gpu_device_id: 0 30 | allow_sliding: False 31 | -------------------------------------------------------------------------------- /scripts/dataset/generate_objectnav_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/dataset-%j.out 4 | #SBATCH --error=slurm_logs/dataset-%j.err 5 | #SBATCH --gpus 4 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --exclude=conroy 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | srun python ovon/dataset/objectnav_generator.py \ 23 | --split train \ 24 | --tasks-per-gpu 12 \ 25 | --start-poses-per-object 8000 \ 26 | --use-v1-scenes \ 27 | --multiprocessing 28 | -------------------------------------------------------------------------------- /scripts/dataset/submit_language_annotation_jobs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | path=$1 4 | split=$2 5 | model=$3 6 | 7 | export OPENAI_APIKEY= 8 | 9 | count=0 10 | prompt_meta_files=`ls ${path}/${split}/content/*_meta.json` 11 | for i in ${prompt_meta_files[@]} 12 | do 13 | scene_id=`basename $i` 14 | base=${scene_id%.*} # remove .gz 15 | base=${base%.*} # remove .json 16 | 17 | if [ -f "${path}/${split}/content/${base}_annotated.json" ]; then 18 | echo "Skipping ${base}" 19 | continue 20 | fi 21 | 22 | meta_path="${path}/${split}/content/${base}.json" 23 | meta_output_path="${path}/${split}/content/${base}_annotated.json" 24 | 25 | echo "Submitting ${base} - ${count}" 26 | python ovon/dataset/generate_captions.py --path $meta_path --output-path $meta_output_path --model $model 27 | count=$((count + 1)) 28 | done 29 | -------------------------------------------------------------------------------- /scripts/dataset/clean_full_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | split=$1 4 | content_dir="data/datasets/ovon/hm3d/v3_shuffled/${split}/content" 5 | output_path="data/datasets/ovon/hm3d/v3_shuffled_cleaned/${split}/content" 6 | 7 | gz_files=`ls ${content_dir}/*json.gz` 8 | for i in ${gz_files[@]} 9 | do 10 | scene_id=`basename $i` 11 | base=${scene_id%.*} # remove .gz 12 | base=${base%.*} # remove .json 13 | 14 | if [ -f "${output_path}/${base}.json.gz" ]; then 15 | echo "Skipping ${base}" 16 | continue 17 | fi 18 | echo "Submitting ${base}" 19 | 20 | sbatch --job-name=clean-${split}-${base} \ 21 | --output=slurm_logs/dataset/clean-${split}-${base}.out \ 22 | 
--error=slurm_logs/dataset/clean-${split}-${base}.err \ 23 | --export=ALL,scene_path=$i,output_path=$output_path \ 24 | scripts/dataset/clean_scene_dataset.sh 25 | done 26 | 27 | -------------------------------------------------------------------------------- /scripts/dataset/clean_scene_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --gpus 1 3 | #SBATCH --nodes 1 4 | #SBATCH --cpus-per-task 7 5 | #SBATCH --ntasks-per-node 1 6 | #SBATCH --signal=USR1@100 7 | #SBATCH --requeue 8 | #SBATCH --constraint="a40|rtx_6000|2080_ti" 9 | #SBATCH --partition=short 10 | #SBATCH --exclude calculon,alexa,cortana,bmo,c3po,ripl-s1,t1000,hal,irona,fiona 11 | 12 | export GLOG_minloglevel=2 13 | export HABITAT_SIM_LOG=quiet 14 | export MAGNUM_LOG=quiet 15 | 16 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 17 | export MAIN_ADDR 18 | 19 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 20 | conda deactivate 21 | conda activate ovon 22 | 23 | echo "\n" 24 | echo $scene_path 25 | echo $(which python) 26 | echo "ola" 27 | 28 | srun python ovon/dataset/clean_episodes.py --path $scene_path --output-path $output_path -------------------------------------------------------------------------------- /config/tasks/imagenav_stretch_hm3d.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat: habitat_config_base 5 | - /habitat/task: imagenav 6 | - /habitat/simulator/agents@habitat.simulator.agents.main_agent: rgb_agent 7 | - _self_ 8 | 9 | habitat: 10 | environment: 11 | max_episode_steps: 500 12 | 13 | simulator: 14 | turn_angle: 30 15 | tilt_angle: 30 16 | action_space_config: "v1" 17 | agents: 18 | main_agent: 19 | sim_sensors: 20 | rgb_sensor: 21 | width: 360 22 | height: 640 23 | hfov: 42 24 | position: [0, 1.31, 0] 25 | height: 1.41 26 | radius: 0.17 27 | habitat_sim_v0: 28 | gpu_device_id: 0 29 | allow_sliding: False 30 | 31 | dataset: 32 | type: PointNav-v1 33 | split: train 34 | data_path: data/datasets/imagenav/hm3d/v1_stretch/{split}/{split}.json.gz 35 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | # Record of how the environment was set up 2 | # Create conda environment. Mamba is recommended for faster installation. 3 | conda_env_name=ovon 4 | mamba create -n $conda_env_name python=3.7 cmake=3.14.0 -y 5 | mamba install -n $conda_env_name \ 6 | habitat-sim=0.2.3 headless pytorch=1.12.1 cudatoolkit=11.3 \ 7 | -c pytorch -c nvidia -c conda-forge -c aihabitat -y 8 | 9 | # Install this repo as a package 10 | mamba activate $conda_env_name 11 | pip install -e . 12 | 13 | # Install frontier_exploration 14 | cd frontier_exploration && pip install -e . && cd .. 
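# Optional sanity check (added illustration, not part of the original setup
# record): confirm the editable install above is visible to pip. The
# distribution name "frontier_exploration" is assumed to match the repo name.
pip show frontier_exploration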
15 | 16 | # Install habitat-lab 17 | git clone --branch v0.2.3 git@github.com:facebookresearch/habitat-lab.git 18 | cd habitat-lab 19 | pip install -e habitat-lab 20 | pip install -e habitat-baselines 21 | 22 | pip install ftfy regex tqdm GPUtil trimesh seaborn timm scikit-learn einops transformers 23 | pip install git+https://github.com/openai/CLIP.git 24 | -------------------------------------------------------------------------------- /ovon/trainers/envs.py: -------------------------------------------------------------------------------- 1 | # from typing import Optional 2 | 3 | # import gym 4 | # import habitat 5 | # import numpy as np 6 | # from habitat import Dataset 7 | # from habitat.core.environments import RLTaskEnv 8 | # from habitat.gym.gym_wrapper import HabGymWrapper 9 | 10 | 11 | # class CustomRLTaskEnv(RLTaskEnv): 12 | # def after_update(self): 13 | # self._env.episode_iterator.after_update() 14 | 15 | 16 | # @habitat.registry.register_env(name="CustomGymHabitatEnv") 17 | # class CustomGymHabitatEnv(gym.Wrapper): 18 | # """ 19 | # A registered environment that wraps a RLTaskEnv with the HabGymWrapper 20 | # to use the default gym API. 21 | # """ 22 | 23 | # def __init__( 24 | # self, config: "DictConfig", dataset: Optional[Dataset] = None 25 | # ): 26 | # base_env = CustomRLTaskEnv(config=config, dataset=dataset) 27 | # env = HabGymWrapper(env=base_env) 28 | # super().__init__(env) 29 | -------------------------------------------------------------------------------- /scripts/dataset/submit_data_gen_jobs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | scenes_dir=$1 4 | split=$2 5 | output_path=$3 6 | start_poses_per_object=2000 7 | episodes_per_object=50 8 | 9 | glb_files=`ls ${scenes_dir}/*/*.semantic.glb` 10 | for i in ${glb_files[@]} 11 | do 12 | scene_id=`basename $i` 13 | base=${scene_id%.*} # remove .gz 14 | base=${base%.*} # remove .json 15 | 16 | if [ -f "${output_path}/${split}/content/${base}.json.gz" ]; then 17 | echo "Skipping ${base}" 18 | continue 19 | fi 20 | 21 | echo "Submitting ${base}" 22 | sbatch --job-name=$split-${base} \ 23 | --output=slurm_logs/dataset/lnav-$split-${base}.out \ 24 | --error=slurm_logs/dataset/lnav-$split-${base}.err \ 25 | --gpus 1 \ 26 | --cpus-per-task 6 \ 27 | --export=ALL,scene=$i,num_tasks=1,split=$split,num_scenes=1,output_path=$output_path,start_poses_per_object=$start_poses_per_object,episodes_per_object=$episodes_per_object \ 28 | scripts/dataset/generate_languagenav_dataset.sh 29 | done 30 | 31 | -------------------------------------------------------------------------------- /scripts/prepare_dataset.sh: -------------------------------------------------------------------------------- 1 | valid=true 2 | 3 | # Check for two args 4 | if [ -z "$2" ] 5 | then 6 | echo "Please provide two paths to hm3d-train-glb-v0.2 and hm3d-val-glb-v0.2 as args" 7 | valid=false 8 | fi 9 | 10 | # Check for valid paths 11 | if [ ! -d "$1" ] 12 | then 13 | echo "$1 is not a valid directory!" 14 | echo "Please provide a valid path for hm3d-train-glb-v0.2" 15 | valid=false 16 | fi 17 | 18 | if [ ! -d "$2" ] 19 | then 20 | echo "$2 is not a valid directory!" 
21 | echo "Please provide a valid path for hm3d-val-glb-v0.2" 22 | valid=false 23 | fi 24 | 25 | if [ $valid = true ] 26 | then 27 | mkdir -p habitat-lab/data/scene_datasets/hm3d && 28 | echo "Creating symlink from $1 to habitat-lab/data/scene_datasets/hm3d" && 29 | ln -s `realpath $1` habitat-lab/data/scene_datasets/hm3d/train && 30 | echo "Creating symlink from $2 to habitat-lab/data/scene_datasets/hm3d" && 31 | ln -s `realpath $2` habitat-lab/data/scene_datasets/hm3d/val && 32 | echo "Done" 33 | fi 34 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/charliermarsh/ruff-pre-commit 3 | rev: 'v0.0.263' 4 | hooks: 5 | - id: ruff 6 | args: ['--fix', '--config', 'pyproject.toml'] 7 | 8 | - repo: https://github.com/pre-commit/pre-commit-hooks 9 | rev: v4.4.0 10 | hooks: 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | - id: check-yaml 14 | - id: check-added-large-files 15 | - id: check-toml 16 | 17 | - repo: https://github.com/psf/black 18 | rev: 23.3.0 19 | hooks: 20 | - id: black 21 | language_version: python3.7 22 | args: ['--config', 'pyproject.toml'] 23 | verbose: true 24 | 25 | # - repo: https://github.com/pre-commit/mirrors-mypy 26 | # rev: v1.2.0 27 | # hooks: 28 | # - id: mypy 29 | # pass_filenames: false 30 | # additional_dependencies: 31 | # - types-protobuf 32 | # - types-requests 33 | # - types-simplejson 34 | # - types-ujson 35 | # - types-PyYAML 36 | # - types-toml 37 | # - types-six 38 | -------------------------------------------------------------------------------- /scripts/dataset/generate_val_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon-dgen 3 | #SBATCH --output=slurm_logs/dataset-%j.out 4 | #SBATCH --error=slurm_logs/dataset-%j.err 5 | #SBATCH --gpus 1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 16 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --partition=short 10 | #SBATCH --exclude=conroy,ig-88 11 | #SBATCH --signal=USR1@100 12 | #SBATCH --requeue 13 | 14 | export GLOG_minloglevel=2 15 | export HABITAT_SIM_LOG=quiet 16 | export MAGNUM_LOG=quiet 17 | 18 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 19 | export MAIN_ADDR 20 | 21 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 22 | conda deactivate 23 | conda activate ovon 24 | 25 | SPLIT="val" 26 | NUM_TASKS=$1 27 | OUTPUT_PATH=$2 28 | NUM_SCENES=-1 29 | 30 | srun python ovon/dataset/objectnav_generator.py \ 31 | --split $SPLIT \ 32 | --num-scenes $NUM_SCENES \ 33 | --tasks-per-gpu $NUM_TASKS \ 34 | --start-poses-per-object 200 \ 35 | --episodes-per-object 10 \ 36 | --output-path $OUTPUT_PATH \ 37 | --multiprocessing \ 38 | --disable-euc-geo-ratio-check 39 | -------------------------------------------------------------------------------- /scripts/train_ddppo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/ovon-ddppo-%j.out 4 | #SBATCH --error=slurm_logs/ovon-ddppo-%j.err 5 | #SBATCH --gpus 4 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 4 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --signal=USR1@100 12 | #SBATCH --requeue 13 | 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export 
MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | TENSORBOARD_DIR="tb/objectnav/ddppo/resnetclip/seed_1" 23 | CHECKPOINT_DIR="data/new_checkpoints/objectnav/ddppo/resnetclip/seed_1" 24 | 25 | srun python -um ovon.run \ 26 | --run-type train \ 27 | --exp-config config/experiments/ver_objectnav.yaml \ 28 | habitat_baselines.rl.policy.name=PointNavResNetCLIPPolicy \ 29 | habitat_baselines.rl.ddppo.train_encoder=False \ 30 | habitat_baselines.rl.ddppo.backbone=resnet50_clip_avgpool \ 31 | habitat_baselines.tensorboard_dir=${TENSORBOARD_DIR} \ 32 | habitat_baselines.checkpoint_folder=${CHECKPOINT_DIR} \ 33 | -------------------------------------------------------------------------------- /scripts/dataset/generate_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon-dgen 3 | #SBATCH --output=slurm_logs/dataset-%j.out 4 | #SBATCH --error=slurm_logs/dataset-%j.err 5 | #SBATCH --gpus 1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 6 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --constraint="a40|rtx_6000|2080_ti" 10 | #SBATCH --partition=short 11 | #SBATCH --exclude calculon,alexa,cortana,bmo,c3po,ripl-s1,t1000,hal,irona,fiona 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 23 | conda deactivate 24 | conda activate ovon 25 | 26 | # SPLIT=$1 27 | # NUM_TASKS=$2 28 | # OUTPUT_PATH=$3 29 | # NUM_SCENES=-1 30 | 31 | srun python ovon/dataset/objectnav_generator.py \ 32 | --scene $scene \ 33 | --split $split \ 34 | --num-scenes $num_scenes \ 35 | --tasks-per-gpu $num_tasks \ 36 | --output-path $output_path \ 37 | --start-poses-per-object $start_poses_per_object \ 38 | --episodes-per-object $episodes_per_object \ 39 | --disable-euc-geo-ratio-check 40 | -------------------------------------------------------------------------------- /ovon/__init__.py: -------------------------------------------------------------------------------- 1 | from ovon import config 2 | from ovon.dataset import ovon_dataset 3 | from ovon.measurements import collision_penalty, nav, sum_reward 4 | from ovon.models import ( 5 | clip_policy, 6 | objaverse_clip_policy, 7 | ovrl_policy, 8 | transformer_policy, 9 | ) 10 | from ovon.obs_transformers import ( 11 | image_goal_encoder, 12 | relabel_imagegoal, 13 | relabel_teacher_actions, 14 | resize, 15 | ) 16 | from ovon.task import rewards, sensors, simulator 17 | from ovon.trainers import dagger_trainer, ppo_trainer_no_2d, ver_trainer 18 | from ovon.utils import visualize_trajectories 19 | 20 | try: 21 | import frontier_exploration 22 | except ModuleNotFoundError as e: 23 | # If the error was due to the frontier_exploration package not being installed, then 24 | # pass, but warn. Do not pass if it was due to another package being missing. 25 | if e.name != "frontier_exploration": 26 | raise e 27 | else: 28 | print( 29 | "Warning: frontier_exploration package not installed. Things may not work. " 30 | "To install:\n" 31 | "git clone git@github.com:naokiyokoyama/frontier_exploration.git &&\n" 32 | "cd frontier_exploration && pip install -e ." 
33 | ) 34 | -------------------------------------------------------------------------------- /ovon/test.py: -------------------------------------------------------------------------------- 1 | import habitat_sim 2 | 3 | GOAL_CATEGORIES = [ 4 | "chair", 5 | "bed", 6 | "plant", 7 | "toilet", 8 | "tv_monitor", 9 | "sofa", 10 | ] 11 | 12 | backend_cfg = habitat_sim.SimulatorConfiguration() 13 | backend_cfg.scene_id = ( 14 | "data/scene_datasets/hm3d/minival/00800-TEEsavR23oF/TEEsavR23oF.basis.glb" 15 | ) 16 | backend_cfg.scene_dataset_config_file = ( 17 | "data/scene_datasets/hm3d/hm3d_annotated_basis.scene_dataset_config.json" 18 | ) 19 | 20 | rgb_cfg = habitat_sim.CameraSensorSpec() 21 | rgb_cfg.uuid = "rgb" 22 | rgb_cfg.sensor_type = habitat_sim.SensorType.COLOR 23 | 24 | 25 | sem_cfg = habitat_sim.CameraSensorSpec() 26 | sem_cfg.uuid = "semantic" 27 | sem_cfg.sensor_type = habitat_sim.SensorType.SEMANTIC 28 | 29 | agent_cfg = habitat_sim.agent.AgentConfiguration() 30 | agent_cfg.sensor_specifications = [rgb_cfg, sem_cfg] 31 | 32 | sim_cfg = habitat_sim.Configuration(backend_cfg, [agent_cfg]) 33 | sim = habitat_sim.Simulator(sim_cfg) 34 | 35 | total_objects = len(sim.semantic_scene.objects) 36 | print("Total objects - {}".format(total_objects)) 37 | 38 | for obj in sim.semantic_scene.objects: 39 | if obj.category.name("") in GOAL_CATEGORIES: 40 | print(obj.category.name(""), obj.obb.center, obj.aabb.center, obj.obb.sizes) 41 | -------------------------------------------------------------------------------- /scripts/train_ver.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/ovon-ver-%j.out 4 | #SBATCH --error=slurm_logs/ovon-ver-%j.err 5 | #SBATCH --gpus 2 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 2 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --signal=USR1@100 12 | #SBATCH --requeue 13 | 14 | export GLOG_minloglevel=2 15 | export HABITAT_SIM_LOG=quiet 16 | export MAGNUM_LOG=quiet 17 | 18 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 19 | export MAIN_ADDR 20 | 21 | TENSORBOARD_DIR="tb/objectnav/ver/resnetclip/hm3d_v0.2_22_cat/seed_1" 22 | CHECKPOINT_DIR="data/new_checkpoints/objectnav/ver/resnetclip/hm3d_v0.2_22_cat/seed_1" 23 | DATA_PATH="data/datasets/objectnav/hm3d_semantic_v0.2/v1" 24 | 25 | srun python -um ovon.run \ 26 | --run-type train \ 27 | --exp-config config/experiments/ver_objectnav.yaml \ 28 | habitat_baselines.trainer_name="ver" \ 29 | habitat_baselines.rl.policy.name=PointNavResNetCLIPPolicy \ 30 | habitat_baselines.rl.ddppo.train_encoder=False \ 31 | habitat_baselines.rl.ddppo.backbone=resnet50_clip_avgpool \ 32 | habitat_baselines.tensorboard_dir=${TENSORBOARD_DIR} \ 33 | habitat_baselines.checkpoint_folder=${CHECKPOINT_DIR} \ 34 | habitat.dataset.data_path=${DATA_PATH}/train/train.json.gz 35 | -------------------------------------------------------------------------------- /scripts/dataset/generate_languagenav_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon-dgen 3 | #SBATCH --output=slurm_logs/dataset-%j.out 4 | #SBATCH --error=slurm_logs/dataset-%j.err 5 | #SBATCH --gpus 1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 6 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --constraint="a40|rtx_6000|2080_ti" 10 | #SBATCH --partition=short 11 | #SBATCH --exclude 
calculon,alexa,cortana,bmo,c3po,ripl-s1,t1000,hal,irona,fiona 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 23 | conda deactivate 24 | conda activate ovon 25 | 26 | # SPLIT=$1 27 | # NUM_TASKS=$2 28 | # OUTPUT_PATH=$3 29 | # NUM_SCENES=-1 30 | export PYTHONPATH=/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 31 | export HOME=/srv/flash1/rramrakhya3/summer_2023/ 32 | 33 | srun python ovon/dataset/languagenav_generator.py \ 34 | --scene $scene \ 35 | --split $split \ 36 | --num-scenes $num_scenes \ 37 | --tasks-per-gpu $num_tasks \ 38 | --output-path $output_path \ 39 | --start-poses-per-object $start_poses_per_object \ 40 | --episodes-per-object $episodes_per_object \ 41 | --with-start-poses 42 | -------------------------------------------------------------------------------- /ovon/models/encoders/vc1_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from habitat_baselines.common.tensor_dict import TensorDict 4 | 5 | from ovon.obs_transformers.resize import image_resize 6 | 7 | 8 | class VC1Encoder(nn.Module): 9 | def __init__(self): 10 | from vc_models.models.vit import model_utils 11 | 12 | super().__init__() 13 | ( 14 | self.model, 15 | self.output_size, 16 | self.model_transforms, 17 | model_info, 18 | ) = model_utils.load_model(model_utils.VC1_BASE_NAME) 19 | self.output_shape = (self.output_size,) 20 | 21 | def forward(self, observations: "TensorDict", *args, **kwargs) -> torch.Tensor: 22 | rgb = observations["rgb"] 23 | if rgb.dtype == torch.uint8: 24 | rgb = rgb.float() / 255.0 25 | 26 | if rgb.shape[1] != 224 and rgb.shape[2] != 224: 27 | # The img loaded should be Bx3x250x250 28 | rgb = image_resize(rgb, size=(250, 250), channels_last=True) 29 | 30 | assert rgb.shape[1] == 224 and rgb.shape[2] == 224 31 | # Change the channels to be first 32 | rgb = rgb.permute(0, 3, 1, 2) 33 | with torch.inference_mode(): 34 | # Output will be of size Bx3x224x224 35 | x = self.model_transforms(rgb) 36 | # Embedding will be 1x768 37 | x = self.model(x) 38 | 39 | return x 40 | -------------------------------------------------------------------------------- /ovon/utils/test_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | 5 | from ovon.utils.utils import count_episodes, load_dataset 6 | 7 | 8 | def test_dataset(path): 9 | print(path) 10 | files = glob.glob(os.path.join(path, "*.json.gz")) 11 | dataset = load_dataset(files[0]) 12 | 13 | print("Total # of episodes: {}".format(count_episodes(dataset))) 14 | 15 | for ep in dataset["episodes"]: 16 | if len(ep["children_object_categories"]) > 0: 17 | print("Found episode with children object categories") 18 | break 19 | 20 | print("Object goal: {}".format(ep["object_category"])) 21 | for children in ep["children_object_categories"]: 22 | print(children) 23 | scene_id = ep["scene_id"].split("/")[-1] 24 | goal_key = f"{scene_id}_{children}" 25 | 26 | # Ignore if there are no valid viewpoints for goal 27 | if goal_key not in dataset["goals_by_category"]: 28 | print("No valid viewpoints for child: {}".format(children)) 29 | continue 30 | print( 31 | "Viewpoints: {} for child: 
{}".format( 32 | len(dataset["goals_by_category"][goal_key]), children 33 | ) 34 | ) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument("--path", type=str, default="data/ovon_dataset.json") 40 | args = parser.parse_args() 41 | 42 | test_dataset(args.path) 43 | -------------------------------------------------------------------------------- /ovon/trainers/inference_worker_with_kv.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.context import BaseContext 2 | 3 | from habitat_baselines.common.tensor_dict import TensorDict 4 | from habitat_baselines.rl.ver.inference_worker import ( 5 | InferenceWorker, 6 | InferenceWorkerProcess, 7 | ) 8 | from habitat_baselines.rl.ver.worker_common import WorkerBase 9 | 10 | 11 | class InferenceWorkerWithKVProcess(InferenceWorkerProcess): 12 | def _update_storage_no_ver( 13 | self, prev_step: TensorDict, current_step: TensorDict, current_steps 14 | ): 15 | current_step.pop("recurrent_hidden_states") 16 | super()._update_storage_no_ver(prev_step, current_step, current_steps) 17 | 18 | def _update_storage_ver( 19 | self, prev_step: TensorDict, current_step: TensorDict, my_slice 20 | ): 21 | current_step.pop("recurrent_hidden_states") 22 | super()._update_storage_ver(prev_step, current_step, my_slice) 23 | 24 | 25 | class InferenceWorkerWithKV(InferenceWorker): 26 | def __init__(self, mp_ctx: BaseContext, use_kv: bool = True, *args, **kwargs): 27 | if use_kv: 28 | self.setup_queue = mp_ctx.SimpleQueue() 29 | WorkerBase.__init__( 30 | self, 31 | mp_ctx, 32 | InferenceWorkerWithKVProcess, 33 | self.setup_queue, 34 | *args, 35 | **kwargs, 36 | ) 37 | else: 38 | super().__init__(mp_ctx, *args, **kwargs) 39 | -------------------------------------------------------------------------------- /scripts/train/2-imagenav-ver.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=imagenav 3 | #SBATCH --output=slurm_logs/imagenav-ver-%j.out 4 | #SBATCH --error=slurm_logs/imagenav-ver-%j.err 5 | #SBATCH --gpus 4 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 4 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --exclude=cheetah,samantha,xaea-12,kitt 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 23 | conda deactivate 24 | conda activate ovon 25 | 26 | export PYTHONPATH=/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 27 | 28 | TENSORBOARD_DIR="tb/imagenav/ver/resnetclip_avgattnpool/seed_1/" 29 | CHECKPOINT_DIR="data/new_checkpoints/imagenav/ver/resnetclip_avgattnpool/seed_1/" 30 | DATA_PATH="data/datasets/imagenav/hm3d/v1_stretch" 31 | 32 | srun python -um ovon.run \ 33 | --run-type train \ 34 | --exp-config config/experiments/ver_imagenav.yaml \ 35 | habitat_baselines.num_environments=32 \ 36 | habitat_baselines.tensorboard_dir=${TENSORBOARD_DIR} \ 37 | habitat_baselines.checkpoint_folder=${CHECKPOINT_DIR} \ 38 | habitat.dataset.data_path=${DATA_PATH}/train/train.json.gz \ 39 | habitat.simulator.type="OVONSim-v0" \ 40 | habitat_baselines.log_interval=20 \ 41 | -------------------------------------------------------------------------------- 
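Aside on the dataset layout exercised by ovon/utils/test_dataset.py above: the fields that script reads imply roughly the structure sketched below. This is a toy illustration inferred from that script alone (the scene id, categories, and viewpoint entries are made up); it is not the full OVON dataset schema.

# Toy dict mirroring only the fields test_dataset.py accesses; values are
# illustrative, not real OVON data.
toy_dataset = {
    "episodes": [
        {
            "scene_id": "hm3d/train/00001-AAAA/AAAA.basis.glb",
            "object_category": "chair",
            "children_object_categories": ["armchair", "office chair"],
        }
    ],
    "goals_by_category": {
        # Keyed as f"{scene_file}_{category}", matching the lookup in the script.
        "AAAA.basis.glb_armchair": ["viewpoint_0", "viewpoint_1"],
    },
}

for ep in toy_dataset["episodes"]:
    scene_file = ep["scene_id"].split("/")[-1]
    for child in ep["children_object_categories"]:
        key = f"{scene_file}_{child}"
        viewpoints = toy_dataset["goals_by_category"].get(key, [])
        print(f"{key}: {len(viewpoints)} viewpoints")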
/ovon/models/encoders/siglip_encoder.py: -------------------------------------------------------------------------------- 1 | import timm 2 | import torch 3 | import torch.nn as nn 4 | from habitat_baselines.common.tensor_dict import TensorDict 5 | from torchvision import transforms 6 | 7 | 8 | class SigLIPEncoder(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.model = timm.create_model( 12 | model_name="vit_base_patch16_siglip_256", 13 | pretrained=True, 14 | num_classes=0, 15 | ) 16 | self.model = self.model.eval() 17 | self.transforms = transforms.Compose( 18 | [ 19 | transforms.Resize( 20 | size=(256, 256), interpolation=transforms.InterpolationMode.BICUBIC 21 | ), 22 | transforms.Normalize( 23 | mean=torch.tensor([0.5000, 0.5000, 0.5000]), 24 | std=torch.tensor([0.5000, 0.5000, 0.5000]), 25 | ), 26 | ] 27 | ) 28 | self.output_size = 768 29 | self.output_shape = (self.output_size,) 30 | 31 | def forward(self, observations: TensorDict, *args, **kwargs) -> torch.Tensor: 32 | rgb = observations["rgb"] 33 | rgb = rgb.permute(0, 3, 1, 2) # NHWC -> NCHW 34 | if rgb.dtype == torch.uint8: 35 | rgb = rgb.float() / 255.0 36 | else: 37 | assert (rgb >= 0.0).all() and (rgb <= 1.0).all() 38 | with torch.inference_mode(): 39 | # Output will be of size Bx3x224x224 40 | x = self.transforms(rgb) 41 | # Embedding will be 1x768 42 | x = self.model(x) 43 | 44 | return x 45 | -------------------------------------------------------------------------------- /ovon/models/encoders/dinov2_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.hub 3 | import torch.nn as nn 4 | from torchvision import transforms 5 | 6 | 7 | class DINOV2Encoder(nn.Module): 8 | def __init__(self, backbone_name="dinov2_vitb14", output_size=768): 9 | super().__init__() 10 | self.model = torch.hub.load( 11 | repo_or_dir="facebookresearch/dinov2", model=backbone_name 12 | ) 13 | self.model.eval() 14 | self.model_transforms = self.make_depth_transform() 15 | self.output_size = output_size 16 | self.output_shape = (self.output_size,) 17 | 18 | def make_depth_transform(self): 19 | return transforms.Compose( 20 | [ 21 | transforms.Normalize( 22 | mean=(123.675, 116.28, 103.53), 23 | std=(58.395, 57.12, 57.375), 24 | ), 25 | ] 26 | ) 27 | 28 | def forward(self, observations: "TensorDict", *args, **kwargs) -> torch.Tensor: 29 | # rgb is a tensor of shape (batch_size, height, width, channels) 30 | rgb = observations["rgb"] 31 | 32 | # Assert that the rgb images are of type uint8 33 | assert rgb.dtype == torch.uint8 34 | 35 | # Assert that the height and width are both 224 36 | assert rgb.shape[1] == 224 and rgb.shape[2] == 224 37 | 38 | rgb = rgb.float() 39 | 40 | # PyTorch models expect the input in (batch, channels, height, width) format 41 | rgb = rgb.permute(0, 3, 1, 2) 42 | 43 | with torch.inference_mode(): 44 | x = self.model(self.model_transforms(rgb)) 45 | 46 | return x 47 | -------------------------------------------------------------------------------- /ovon/task/simulator.py: -------------------------------------------------------------------------------- 1 | import habitat_sim 2 | from habitat.core.registry import registry 3 | from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim 4 | from omegaconf import DictConfig 5 | 6 | 7 | @registry.register_simulator(name="OVONSim-v0") 8 | class OVONSim(HabitatSim): 9 | def __init__(self, config: DictConfig) -> None: 10 | super().__init__(config) 11 | self.navmesh_settings = 
self.load_navmesh_settings() 12 | self.recompute_navmesh( 13 | self.pathfinder, 14 | self.navmesh_settings, 15 | include_static_objects=False, 16 | ) 17 | self.curr_scene_goals = {} 18 | 19 | def load_navmesh_settings(self): 20 | agent_cfg = self.habitat_config.agents.main_agent 21 | navmesh_settings = habitat_sim.NavMeshSettings() 22 | navmesh_settings.set_defaults() 23 | navmesh_settings.agent_height = agent_cfg.height 24 | navmesh_settings.agent_radius = agent_cfg.radius 25 | navmesh_settings.agent_max_climb = ( 26 | self.habitat_config.navmesh_settings.agent_max_climb 27 | ) 28 | navmesh_settings.cell_height = self.habitat_config.navmesh_settings.cell_height 29 | return navmesh_settings 30 | 31 | def reconfigure( 32 | self, 33 | habitat_config: DictConfig, 34 | should_close_on_new_scene: bool = True, 35 | ): 36 | is_same_scene = habitat_config.scene == self._current_scene 37 | super().reconfigure(habitat_config, should_close_on_new_scene) 38 | if not is_same_scene: 39 | self.recompute_navmesh( 40 | self.pathfinder, 41 | self.navmesh_settings, 42 | include_static_objects=False, 43 | ) 44 | self.curr_scene_goals = {} 45 | -------------------------------------------------------------------------------- /ovon/models/encoders/make_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from gym import spaces 4 | 5 | from ovon.models.encoders.clip_encoder import ResNetCLIPEncoder 6 | from ovon.models.encoders.dinov2_encoder import DINOV2Encoder 7 | from ovon.models.encoders.siglip_encoder import SigLIPEncoder 8 | from ovon.models.encoders.vc1_encoder import VC1Encoder 9 | 10 | POSSIBLE_ENCODERS = [ 11 | "clip_attnpool", 12 | "clip_avgpool", 13 | "clip_avgattnpool", 14 | "vc1", 15 | "dinov2", 16 | "resnet", 17 | "siglip", 18 | ] 19 | 20 | 21 | def make_encoder(backbone: str, observation_space: spaces.Dict) -> Any: 22 | if backbone == "resnet50_clip_avgpool": 23 | backbone = "clip_avgpool" 24 | print("WARNING: resnet50_clip_avgpool is deprecated. Use clip_avgpool instead.") 25 | 26 | assert ( 27 | backbone in POSSIBLE_ENCODERS 28 | ), f"Backbone {backbone} not found. 
Possible encoders: {POSSIBLE_ENCODERS}" 29 | 30 | if "clip" in backbone: 31 | backbone_type = backbone.split("_")[1] 32 | return ResNetCLIPEncoder( 33 | observation_space, 34 | backbone_type=backbone_type, 35 | clip_model="RN50", 36 | ) 37 | elif backbone == "vc1": 38 | return VC1Encoder() 39 | elif backbone == "dinov2": 40 | return DINOV2Encoder() 41 | elif backbone == "siglip": 42 | return SigLIPEncoder() 43 | elif backbone == "resnet": 44 | resnet_baseplanes = 32 45 | from habitat_baselines.rl.ddppo.policy import resnet 46 | from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder 47 | 48 | return ResNetEncoder( 49 | observation_space=observation_space, 50 | baseplanes=resnet_baseplanes, 51 | ngroups=resnet_baseplanes // 2, 52 | make_backbone=resnet.resnet50, 53 | ) 54 | -------------------------------------------------------------------------------- /ovon/utils/plot_tsne.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pandas as pd 6 | import seaborn as sns 7 | from sklearn.manifold import TSNE 8 | 9 | from ovon.utils.utils import load_pickle 10 | 11 | 12 | def plot_tsne(embeddings, output_path): 13 | features = list(embeddings.values()) 14 | categories = list(embeddings.keys()) 15 | 16 | tsne = TSNE(n_components=2, verbose=1, perplexity=20, n_iter=500) 17 | tsne_results = tsne.fit_transform(features) 18 | 19 | df = pd.DataFrame( 20 | { 21 | "tsne-2d-one": tsne_results[:, 0], 22 | "tsne-2d-two": tsne_results[:, 1], 23 | "labels": categories, 24 | } 25 | ) 26 | 27 | pallete_size = np.unique(categories).shape[0] 28 | 29 | colors = sns.color_palette("hls", pallete_size) 30 | color_map = {} 31 | for i in range(len(categories)): 32 | color_map[categories[i]] = colors[i] 33 | 34 | print(color_map) 35 | 36 | plt.figure(figsize=(16, 10)) 37 | sns.scatterplot( 38 | data=df, 39 | x="tsne-2d-one", 40 | y="tsne-2d-two", 41 | hue="labels", 42 | palette=color_map, 43 | legend="full", 44 | s=120, 45 | ) 46 | 47 | plt.savefig(output_path, bbox_inches="tight", dpi=300) 48 | return tsne_results 49 | 50 | 51 | def plot_embeddings(path, output_path): 52 | clip_embeddings = load_pickle(path) 53 | plot_tsne(clip_embeddings, output_path) 54 | 55 | 56 | def main(): 57 | parser = argparse.ArgumentParser() 58 | parser.add_argument("--path", type=str, required=True, help="path to the dataset") 59 | parser.add_argument( 60 | "--output-path", type=str, required=True, help="path to the dataset" 61 | ) 62 | 63 | args = parser.parse_args() 64 | 65 | plot_embeddings(args.path, args.output_path) 66 | 67 | 68 | if __name__ == "__main__": 69 | main() 70 | -------------------------------------------------------------------------------- /ovon/measurements/collision_penalty.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any 3 | 4 | from habitat import EmbodiedTask, Measure, Simulator, registry 5 | from habitat.config.default_structured_configs import MeasurementConfig 6 | from hydra.core.config_store import ConfigStore 7 | from omegaconf import DictConfig 8 | 9 | 10 | @registry.register_measure 11 | class CollisionPenalty(Measure): 12 | """ 13 | Returns a penalty value if the robot has collided. 
14 | """ 15 | 16 | cls_uuid: str = "collision_penalty" 17 | 18 | def __init__(self, sim: Simulator, config: "DictConfig", *args: Any, **kwargs: Any): 19 | self._sim = sim 20 | self._config = config 21 | self._collision_penalty = config.collision_penalty 22 | super().__init__() 23 | 24 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 25 | return self.cls_uuid 26 | 27 | def reset_metric(self, episode, task, *args: Any, **kwargs: Any): 28 | task.measurements.check_measure_dependencies(self.uuid, ["collisions"]) 29 | self.update_metric(episode=episode, task=task, *args, **kwargs) # type: ignore 30 | 31 | def update_metric(self, episode, task: EmbodiedTask, *args: Any, **kwargs: Any): 32 | collisions = task.measurements.measures["collisions"].get_metric() 33 | collided = collisions is not None and collisions["is_collision"] 34 | if collided: 35 | self._metric = -self._collision_penalty 36 | else: 37 | self._metric = 0 38 | 39 | 40 | @dataclass 41 | class CollisionPenaltyMeasurementConfig(MeasurementConfig): 42 | type: str = CollisionPenalty.__name__ 43 | collision_penalty: float = 0.003 44 | 45 | 46 | cs = ConfigStore.instance() 47 | cs.store( 48 | package=f"habitat.task.measurements.{CollisionPenalty.cls_uuid}", 49 | group="habitat/task/measurements", 50 | name=f"{CollisionPenalty.cls_uuid}", 51 | node=CollisionPenaltyMeasurementConfig, 52 | ) 53 | -------------------------------------------------------------------------------- /ovon/obs_transformers/relabel_imagegoal.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | from typing import Dict 4 | 5 | import torch 6 | from gym import spaces 7 | from habitat_baselines.common.baseline_registry import baseline_registry 8 | from habitat_baselines.common.obs_transformers import ObservationTransformer 9 | from habitat_baselines.config.default_structured_configs import ( 10 | ObsTransformConfig, 11 | ) 12 | from hydra.core.config_store import ConfigStore 13 | from omegaconf import DictConfig 14 | 15 | from ovon.task.sensors import ClipImageGoalSensor, ImageGoalRotationSensor 16 | 17 | 18 | @baseline_registry.register_obs_transformer() 19 | class RelabelImageGoal(ObservationTransformer): 20 | """Renames ImageGoalRotationSensor to ClipImageGoalSensor""" 21 | 22 | def transform_observation_space(self, observation_space: spaces.Dict, **kwargs): 23 | assert ImageGoalRotationSensor.cls_uuid in observation_space.spaces 24 | observation_space = copy.deepcopy(observation_space) 25 | observation_space.spaces[ClipImageGoalSensor.cls_uuid] = ( 26 | observation_space.spaces.pop(ImageGoalRotationSensor.cls_uuid) 27 | ) 28 | return observation_space 29 | 30 | @classmethod 31 | def from_config(cls, config: DictConfig): 32 | return cls() 33 | 34 | def forward(self, observations: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 35 | observations[ClipImageGoalSensor.cls_uuid] = observations.pop( 36 | ImageGoalRotationSensor.cls_uuid 37 | ) 38 | return observations 39 | 40 | 41 | @dataclass 42 | class RelabelImageGoalConfig(ObsTransformConfig): 43 | type: str = RelabelImageGoal.__name__ 44 | 45 | 46 | cs = ConfigStore.instance() 47 | 48 | cs.store( 49 | package="habitat_baselines.rl.policy.obs_transforms.relabel_image_goal", 50 | group="habitat_baselines/rl/policy/obs_transforms", 51 | name="relabel_image_goal", 52 | node=RelabelImageGoalConfig, 53 | ) 54 | -------------------------------------------------------------------------------- 
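Looking back at ovon/models/encoders/make_encoder.py above, the snippet below is a minimal usage sketch of that factory. The observation space here is an assumption (a single 224x224 RGB camera), and each pretrained backbone still needs its own dependencies and weights (timm, torch.hub, vc_models, CLIP) available at runtime.

import numpy as np
from gym import spaces

from ovon.models.encoders.make_encoder import make_encoder

# Assumed observation space: one RGB camera. Per-encoder resolution
# requirements are handled (or asserted) inside the individual encoder classes.
obs_space = spaces.Dict(
    {"rgb": spaces.Box(low=0, high=255, shape=(224, 224, 3), dtype=np.uint8)}
)

encoder = make_encoder("siglip", obs_space)  # or "vc1", "dinov2", "clip_avgpool", ...
print(encoder.output_shape)  # e.g. (768,) for the SigLIP encoder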
/ovon/trainers/ppo_trainer_no_2d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from gym import spaces 4 | from habitat_baselines import PPOTrainer 5 | from habitat_baselines.common.baseline_registry import baseline_registry 6 | from habitat_baselines.rl.ddppo.policy import PointNavResNetNet 7 | 8 | from ovon.utils.rollout_storage_no_2d import RolloutStorageNo2D 9 | 10 | 11 | @baseline_registry.register_trainer(name="ddppo_no_2d") 12 | @baseline_registry.register_trainer(name="ppo_no_2d") 13 | class PPONo2DTrainer(PPOTrainer): 14 | def _init_train(self, *args, **kwargs): 15 | super()._init_train(*args, **kwargs) 16 | # Hacky overwriting of existing RolloutStorage with a new one 17 | ppo_cfg = self.config.habitat_baselines.rl.ppo 18 | action_shape = self.rollouts.buffers["actions"].shape[2:] 19 | discrete_actions = self.rollouts.buffers["actions"].dtype == torch.long 20 | batch = self.rollouts.buffers["observations"][0] 21 | 22 | obs_space = spaces.Dict( 23 | { 24 | PointNavResNetNet.PRETRAINED_VISUAL_FEATURES_KEY: spaces.Box( 25 | low=np.finfo(np.float32).min, 26 | high=np.finfo(np.float32).max, 27 | shape=self._encoder.output_shape, 28 | dtype=np.float32, 29 | ), 30 | **self.obs_space.spaces, 31 | } 32 | ) 33 | 34 | self.rollouts = RolloutStorageNo2D( 35 | self.actor_critic.net.visual_encoder, 36 | batch, 37 | ppo_cfg.num_steps, 38 | self.envs.num_envs, 39 | obs_space, 40 | self.policy_action_space, 41 | ppo_cfg.hidden_size, 42 | num_recurrent_layers=self.actor_critic.net.num_recurrent_layers, 43 | is_double_buffered=ppo_cfg.use_double_buffered_sampler, 44 | action_shape=action_shape, 45 | discrete_actions=discrete_actions, 46 | ) 47 | self.rollouts.to(self.device) 48 | -------------------------------------------------------------------------------- /ovon/models/visual_encoders.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from vc_models.models.vit import model_utils 4 | 5 | 6 | class Vc1Wrapper(nn.Module): 7 | """ 8 | Wrapper for the VC1 visual encoder. This will automatically download the model if it's not already. 
9 | """ 10 | 11 | def __init__(self, im_obs_space, model_id=None): 12 | super().__init__() 13 | if model_id is None: 14 | model_id = model_utils.VC1_BASE_NAME 15 | print(f"loading {model_id}.") 16 | ( 17 | self.net, 18 | self.embd_size, 19 | self.model_transforms, 20 | model_info, 21 | ) = model_utils.load_model(model_id) 22 | self._image_obs_keys = im_obs_space.spaces.keys() 23 | 24 | # Count total # of channels 25 | self._n_input_channels = sum( 26 | im_obs_space.spaces[k].shape[2] for k in self._image_obs_keys 27 | ) 28 | 29 | @property 30 | def is_blind(self): 31 | return self._n_input_channels == 0 32 | 33 | @torch.autocast("cuda") 34 | def forward(self, obs): 35 | # Extract tensors that are shape [batch_size, img_width, img_height, img_channels] 36 | feats = [] 37 | imgs = [v for k, v in obs.items() if k in self._image_obs_keys] 38 | for img in imgs: 39 | if img.shape[-1] != 3: 40 | img = torch.concat([img] * 3, dim=-1) 41 | scale_factor = 1.0 42 | else: 43 | scale_factor = 255.0 44 | 45 | img = self.model_transforms( 46 | img.permute(0, 3, 1, 2).contiguous() / scale_factor 47 | ) 48 | 49 | feats.append(self.net(img)) 50 | 51 | if len(feats) == 2: 52 | # feats = (feats[0] + feats[1])/2 53 | feats = torch.concat(feats, dim=-1) 54 | else: 55 | feats = feats[0] 56 | return feats 57 | 58 | @property 59 | def output_shape(self): 60 | return (self.embd_size * len(self._image_obs_keys),) 61 | -------------------------------------------------------------------------------- /scripts/train/1-ovon-ver.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/ovon-ver-%j.out 4 | #SBATCH --error=slurm_logs/ovon-ver-%j.err 5 | #SBATCH --gpus 4 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 4 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --exclude=cheetah,samantha,xaea-12,kitt 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 23 | conda deactivate 24 | conda activate ovon 25 | 26 | export PYTHONPATH=/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 27 | 28 | TENSORBOARD_DIR="tb/ovon/ver/resnetclip_rgb_text/seed_2/" 29 | CHECKPOINT_DIR="data/new_checkpoints/ovon/ver/resnetclip_rgb_text/seed_2/" 30 | DATA_PATH="data/datasets/ovon/hm3d/v2" 31 | 32 | srun python -um ovon.run \ 33 | --run-type train \ 34 | --exp-config config/experiments/ver_objectnav.yaml \ 35 | habitat_baselines.trainer_name="ver" \ 36 | habitat_baselines.num_environments=32 \ 37 | habitat_baselines.rl.policy.name=PointNavResNetCLIPPolicy \ 38 | habitat_baselines.rl.ddppo.train_encoder=False \ 39 | habitat_baselines.rl.ddppo.backbone=resnet50_clip_avgpool \ 40 | habitat_baselines.tensorboard_dir=${TENSORBOARD_DIR} \ 41 | habitat_baselines.checkpoint_folder=${CHECKPOINT_DIR} \ 42 | habitat.dataset.data_path=${DATA_PATH}/train/train.json.gz \ 43 | +habitat/task/lab_sensors@habitat.task.lab_sensors.clip_objectgoal_sensor=clip_objectgoal_sensor \ 44 | ~habitat.task.lab_sensors.objectgoal_sensor \ 45 | habitat.task.lab_sensors.clip_objectgoal_sensor.cache=data/clip_embeddings/ovon_stretch_final_cache.pkl \ 46 | habitat.task.measurements.success.success_distance=0.25 \ 47 | 
habitat.dataset.type="OVON-v1" \ 48 | habitat.task.measurements.distance_to_goal.type=OVONDistanceToGoal \ 49 | habitat.simulator.type="OVONSim-v0" 50 | -------------------------------------------------------------------------------- /ovon/measurements/sum_reward.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Any, List 3 | 4 | from habitat import EmbodiedTask, Measure, registry 5 | from habitat.config.default_structured_configs import MeasurementConfig 6 | from hydra.core.config_store import ConfigStore 7 | from omegaconf import DictConfig 8 | 9 | 10 | @registry.register_measure 11 | class SumReward(Measure): 12 | """ 13 | Sums various reward measures. 14 | """ 15 | 16 | cls_uuid: str = "sum_reward" 17 | 18 | def __init__(self, config: "DictConfig", *args: Any, **kwargs: Any): 19 | self._config = config 20 | self._reward_terms = config.reward_terms 21 | self._reward_coefficients = [float(i) for i in config.reward_coefficients] 22 | super().__init__() 23 | 24 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 25 | return self.cls_uuid 26 | 27 | def reset_metric(self, episode, task, *args: Any, **kwargs: Any): 28 | task.measurements.check_measure_dependencies(self.uuid, self._reward_terms) 29 | self.update_metric(episode=episode, task=task, *args, **kwargs) # type: ignore 30 | 31 | def update_metric(self, episode, task: EmbodiedTask, *args: Any, **kwargs: Any): 32 | self._metric = 0 33 | for term, coefficient in zip(self._reward_terms, self._reward_coefficients): 34 | self._metric += coefficient * task.measurements.measures[term].get_metric() 35 | 36 | 37 | @dataclass 38 | class SumRewardMeasurementConfig(MeasurementConfig): 39 | type: str = SumReward.__name__ 40 | reward_terms: List[str] = field( 41 | # available options are "disk" and "tensorboard" 42 | default_factory=list 43 | ) 44 | reward_coefficients: List[str] = field( 45 | # available options are "disk" and "tensorboard" 46 | default_factory=list 47 | ) 48 | 49 | 50 | cs = ConfigStore.instance() 51 | cs.store( 52 | package=f"habitat.task.measurements.{SumReward.cls_uuid}", 53 | group="habitat/task/measurements", 54 | name=f"{SumReward.cls_uuid}", 55 | node=SumRewardMeasurementConfig, 56 | ) 57 | -------------------------------------------------------------------------------- /scripts/train/3-ovon-ver-scratch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/ovon-ver-%j.out 4 | #SBATCH --error=slurm_logs/ovon-ver-%j.err 5 | #SBATCH --gpus 1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --constraint=a40 10 | #SBATCH --partition=short 11 | #SBATCH --exclude=cheetah,samantha,xaea-12,kitt 12 | #SBATCH --signal=USR1@100 13 | #SBATCH --requeue 14 | 15 | export GLOG_minloglevel=2 16 | export HABITAT_SIM_LOG=quiet 17 | export MAGNUM_LOG=quiet 18 | 19 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 20 | export MAIN_ADDR 21 | 22 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 23 | conda deactivate 24 | conda activate ovon 25 | 26 | export PYTHONPATH=/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 27 | 28 | TENSORBOARD_DIR="tb/ovon/ver/resnet_scratch_clip_goal/seed_1/" 29 | CHECKPOINT_DIR="data/new_checkpoints/ovon/ver/resnet_scratch_clip_goal/seed_1/" 30 | DATA_PATH="data/datasets/ovon/hm3d/v5_final" 31 | 32 | srun 
python -um ovon.run \ 33 | --run-type train \ 34 | --exp-config config/experiments/ver_objectnav.yaml \ 35 | habitat_baselines.trainer_name="ver" \ 36 | habitat_baselines.num_environments=32 \ 37 | habitat_baselines.rl.policy.name=OVRLPolicy \ 38 | habitat_baselines.rl.ddppo.train_encoder=True \ 39 | habitat_baselines.rl.policy.backbone=resnet50 \ 40 | habitat_baselines.rl.policy.freeze_backbone=False \ 41 | habitat_baselines.tensorboard_dir=${TENSORBOARD_DIR} \ 42 | habitat_baselines.checkpoint_folder=${CHECKPOINT_DIR} \ 43 | habitat.dataset.data_path=${DATA_PATH}/train/train.json.gz \ 44 | +habitat/task/lab_sensors@habitat.task.lab_sensors.clip_objectgoal_sensor=clip_objectgoal_sensor \ 45 | ~habitat.task.lab_sensors.objectgoal_sensor \ 46 | habitat.task.lab_sensors.clip_objectgoal_sensor.cache=data/clip_embeddings/ovon_stretch_final_cache.pkl \ 47 | habitat.task.measurements.success.success_distance=0.25 \ 48 | habitat.dataset.type="OVON-v1" \ 49 | habitat.task.measurements.distance_to_goal.type=OVONDistanceToGoal \ 50 | habitat.simulator.type="OVONSim-v0" 51 | -------------------------------------------------------------------------------- /scripts/train/4-objectnav-transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon-tf 3 | #SBATCH --output=slurm_logs/%x-%j.out 4 | #SBATCH --error=slurm_logs/%x-%j.err 5 | #SBATCH --gpus a40:4 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 16 8 | #SBATCH --ntasks-per-node 4 9 | #SBATCH --time=48:00:00 10 | #SBATCH --signal=USR1@90 11 | #SBATCH --exclude=crushinator,major,chappie,deebot,xaea-12 12 | #SBATCH --requeue 13 | #SBATCH --partition=cvmlp-lab,overcap 14 | #SBATCH --qos=short 15 | 16 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 17 | # This source code is licensed under the MIT license found in the 18 | # LICENSE file in the root directory of this source tree. 
19 | 20 | export GLOG_minloglevel=2 21 | export MAGNUM_LOG=quiet 22 | export HABITAT_SIM_LOG=quiet 23 | 24 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 25 | export MAIN_ADDR 26 | 27 | source /srv/flash1/rramrakhya3/summer_2023/miniconda3/etc/profile.d/conda.sh 28 | conda deactivate 29 | conda activate ovon-v2 30 | 31 | cd /srv/flash1/rramrakhya3/spring_2023/ovon 32 | 33 | export PYTHONPATH= 34 | 35 | TENSORBOARD_DIR="tb/objectnav/ddppo/vc1_llama/seed_5/" 36 | CHECKPOINT_DIR="data/new_checkpoints/objectnav/ddppo/vc1_llama/seed_5/" 37 | DATA_PATH="data/datasets/objectnav/hm3d/v2" 38 | 39 | srun python ovon/run.py --exp-config config/experiments/rl_transformer_hm3d.yaml \ 40 | --run-type train \ 41 | habitat_baselines.checkpoint_folder=$CHECKPOINT_DIR/ \ 42 | habitat_baselines.tensorboard_dir=$TENSORBOARD_DIR \ 43 | habitat_baselines.num_environments=24 \ 44 | habitat_baselines.rl.policy.transformer_config.inter_episodes_attention=False \ 45 | habitat_baselines.rl.policy.transformer_config.add_sequence_idx_embed=False \ 46 | habitat_baselines.rl.policy.transformer_config.reset_position_index=False \ 47 | habitat_baselines.rl.policy.transformer_config.max_position_embeddings=2000 \ 48 | habitat_baselines.rl.policy.transformer_config.n_hidden=1024 \ 49 | habitat.environment.max_episode_steps=500 \ 50 | habitat.dataset.data_path=${DATA_PATH}/train/train.json.gz \ 51 | habitat_baselines.rl.ppo.training_precision="float32" 52 | -------------------------------------------------------------------------------- /ovon/obs_transformers/relabel_teacher_actions.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | from typing import Dict 4 | 5 | import torch 6 | from gym import spaces 7 | from habitat_baselines.common.baseline_registry import baseline_registry 8 | from habitat_baselines.common.obs_transformers import ObservationTransformer 9 | from habitat_baselines.config.default_structured_configs import ObsTransformConfig 10 | from hydra.core.config_store import ConfigStore 11 | from omegaconf import DictConfig 12 | 13 | 14 | @baseline_registry.register_obs_transformer() 15 | class RelabelTeacherActions(ObservationTransformer): 16 | """Renames the entry corresponding to the given key string within the observations 17 | dict to the TEACHER_LABEL key ('teacher_label')""" 18 | 19 | TEACHER_LABEL: str = "teacher_label" 20 | 21 | def __init__(self, teacher_label: str): 22 | super().__init__() 23 | self.teacher_label = teacher_label 24 | 25 | def transform_observation_space(self, observation_space: spaces.Dict, **kwargs): 26 | assert ( 27 | self.teacher_label in observation_space.spaces 28 | ), f"Teacher action key {self.teacher_label} not in observation space!" 
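        # A deep copy is made before renaming so the observation space passed in by
        # the environment is not mutated in place.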
29 | observation_space = copy.deepcopy(observation_space) 30 | observation_space.spaces[self.TEACHER_LABEL] = observation_space.spaces.pop( 31 | self.teacher_label 32 | ) 33 | return observation_space 34 | 35 | @classmethod 36 | def from_config(cls, config: DictConfig): 37 | return cls(config.teacher_label) 38 | 39 | def forward(self, observations: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 40 | observations[self.TEACHER_LABEL] = observations.pop(self.teacher_label) 41 | return observations 42 | 43 | 44 | @dataclass 45 | class RelabelTeacherActionsConfig(ObsTransformConfig): 46 | type: str = RelabelTeacherActions.__name__ 47 | teacher_label: str = "" 48 | 49 | 50 | cs = ConfigStore.instance() 51 | 52 | cs.store( 53 | package="habitat_baselines.rl.policy.obs_transforms.relabel_teacher_actions", 54 | group="habitat_baselines/rl/policy/obs_transforms", 55 | name="relabel_teacher_actions", 56 | node=RelabelTeacherActionsConfig, 57 | ) 58 | -------------------------------------------------------------------------------- /scripts/eval/4-objectnav-transformer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/eval/ovon-tf-%j.out 4 | #SBATCH --error=slurm_logs/eval/ovon-tf-%j.err 5 | #SBATCH --gpus a40:1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 6 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --partition=cvmlp-lab,overcap 10 | #SBATCH --qos=short 11 | #SBATCH --signal=USR1@100 12 | 13 | export GLOG_minloglevel=2 14 | export HABITAT_SIM_LOG=quiet 15 | export MAGNUM_LOG=quiet 16 | 17 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 18 | export MAIN_ADDR 19 | 20 | source /srv/flash1/rramrakhya3/summer_2023/miniconda3/etc/profile.d/conda.sh 21 | conda deactivate 22 | conda activate ovon-v2 23 | 24 | export PYTHONPATH= #/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 25 | 26 | DATA_PATH="data/datasets/objectnav/hm3d/v2/" 27 | eval_ckpt_path_dir="data/new_checkpoints/objectnav/ddppo/vc1_llama/seed_3/ckpt.0.pth" 28 | tensorboard_dir="tb/objectnav/ddppo/vc1_llama/seed_2/eval_debug/" 29 | split="val" 30 | 31 | echo "Evaluating ckpt: ${eval_ckpt_path_dir}" 32 | echo "Data path: ${DATA_PATH}/${split}/${split}.json.gz" 33 | 34 | python -um ovon.run \ 35 | --run-type eval \ 36 | --exp-config config/experiments/rl_transformer_hm3d.yaml \ 37 | -cvt \ 38 | habitat_baselines.trainer_name=transformer_ddppo \ 39 | habitat_baselines.tensorboard_dir=$tensorboard_dir \ 40 | habitat_baselines.eval_ckpt_path_dir=$eval_ckpt_path_dir \ 41 | habitat_baselines.checkpoint_folder=$eval_ckpt_path_dir \ 42 | habitat.dataset.data_path="${DATA_PATH}/${split}/${split}.json.gz" \ 43 | habitat_baselines.num_environments=2 \ 44 | habitat_baselines.rl.ppo.training_precision="float32" \ 45 | habitat_baselines.rl.policy.transformer_config.inter_episodes_attention=False \ 46 | habitat_baselines.rl.policy.transformer_config.add_sequence_idx_embed=False \ 47 | habitat_baselines.rl.policy.transformer_config.reset_position_index=False \ 48 | habitat_baselines.rl.policy.transformer_config.max_position_embeddings=2000 \ 49 | habitat.environment.max_episode_steps=500 \ 50 | habitat.dataset.data_path="data/datasets/objectnav/hm3d/v2/val/val.json.gz" \ 51 | habitat.dataset.split=val \ 52 | habitat_baselines.eval.use_ckpt_config=False \ 53 | habitat_baselines.load_resume_state_config=False \ 54 | habitat.simulator.habitat_sim_v0.allow_sliding=False \ 55 | 
habitat_baselines.eval.split=$split 56 | -------------------------------------------------------------------------------- /ovon/utils/rollout_storage_no_2d.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from habitat_baselines import RolloutStorage 4 | from habitat_baselines.common.tensor_dict import TensorDict 5 | from habitat_baselines.rl.ddppo.policy import PointNavResNetNet 6 | from torch import nn 7 | 8 | 9 | class RolloutStorageNo2D(RolloutStorage): 10 | """RolloutStorage variant that will store visual features extracted using a given 11 | visual encoder instead of raw images to save space.""" 12 | 13 | buffers: TensorDict 14 | visual_encoder: nn.Module 15 | 16 | def __init__( 17 | self, visual_encoder, initial_obs: Optional[TensorDict] = None, *args, **kwargs 18 | ): 19 | super().__init__(*args, **kwargs) 20 | # Remove any 2D observations from the rollout storage buffer 21 | delete_keys = [] 22 | for sensor in self.buffers["observations"]: 23 | if self.buffers["observations"][sensor].dim() >= 4: # NCHW -> 4 dims 24 | delete_keys.append(sensor) 25 | for k in delete_keys: 26 | del self.buffers["observations"][k] 27 | self.visual_encoder = visual_encoder 28 | if initial_obs is not None: 29 | self.buffers["observations"][0] = self.filter_obs(initial_obs) 30 | 31 | def filter_obs(self, obs: TensorDict) -> TensorDict: 32 | filtered_obs = TensorDict() 33 | for sensor in obs: 34 | # Filter out 2D observations 35 | if obs[sensor].dim() < 4: 36 | filtered_obs[sensor] = obs[sensor] 37 | # Extract visual features from 2D observations 38 | filtered_obs[PointNavResNetNet.PRETRAINED_VISUAL_FEATURES_KEY] = ( 39 | self.visual_encoder(obs) 40 | ) 41 | return filtered_obs 42 | 43 | def insert( 44 | self, 45 | next_observations=None, 46 | next_recurrent_hidden_states=None, 47 | actions=None, 48 | action_log_probs=None, 49 | value_preds=None, 50 | rewards=None, 51 | next_masks=None, 52 | buffer_index: int = 0, 53 | ): 54 | if next_observations is not None: 55 | filtered_next_observations = self.filter_obs(next_observations) 56 | else: 57 | filtered_next_observations = None 58 | super().insert( # noqa 59 | filtered_next_observations, 60 | next_recurrent_hidden_states, 61 | actions, 62 | action_log_probs, 63 | value_preds, 64 | rewards, 65 | next_masks, 66 | buffer_index, 67 | ) 68 | -------------------------------------------------------------------------------- /ovon/models/encoders/cma_xattn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | 7 | class CrossModalAttention(nn.Module): 8 | def __init__( 9 | self, text_embedding_dim: int, rgb_embedding_dim: int, hidden_size: int = 512 10 | ) -> None: 11 | super().__init__() 12 | 13 | # Linear transformation for the text query and RGB key-value pairs 14 | self.text_q = nn.Linear(text_embedding_dim, hidden_size // 2) 15 | self.rgb_kv = nn.Conv1d(rgb_embedding_dim, hidden_size, 1) 16 | 17 | # Scale for attention 18 | self.register_buffer("_scale", torch.tensor(1.0 / ((hidden_size // 2) ** 0.5))) 19 | 20 | self._hidden_size = hidden_size 21 | 22 | @property 23 | def output_size(self) -> int: 24 | return self._hidden_size // 2 25 | 26 | def _attn(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 27 | logits = torch.einsum("nc, nci -> ni", q, k) 28 | attn = F.softmax(logits * self._scale, dim=1) 29 | return torch.einsum("ni, nci -> nc", 
attn, v) 30 | 31 | def forward(self, text_embedding: Tensor, rgb_embedding: Tensor) -> Tensor: 32 | """ 33 | Args: 34 | text_embedding: [batch_size, text_embedding_dim] tensor (language) 35 | rgb_embedding: [batch_size, rgb_embedding_dim] tensor (visual) 36 | 37 | Returns: 38 | [batch_size, embed_dim] tensor 39 | """ 40 | # Reshape rgb_embedding tensor to [batch_size, rgb_embedding_dim, 1] 41 | rgb_embedding_reshaped = rgb_embedding.unsqueeze(2) 42 | rgb_kv = self.rgb_kv(rgb_embedding_reshaped) 43 | rgb_k, rgb_v = torch.split(rgb_kv, self.text_q.out_features, dim=1) 44 | text_q = self.text_q(text_embedding) 45 | rgb_embedding = self._attn(text_q, rgb_k, rgb_v) 46 | 47 | return rgb_embedding 48 | 49 | 50 | if __name__ == "__main__": 51 | # Define embedding dimensions 52 | text_embedding_dim = 1024 53 | rgb_embedding_dim = 1024 54 | 55 | # Instantiate the model 56 | model = CrossModalAttention(text_embedding_dim, rgb_embedding_dim, hidden_size=512) 57 | print(f"Model: \n{model}\n") 58 | 59 | # Generate random embeddings 60 | batch_size = 1 61 | text_embedding = torch.rand(batch_size, text_embedding_dim) 62 | rgb_embedding = torch.rand(batch_size, rgb_embedding_dim) 63 | 64 | # Pass embeddings through the model 65 | output = model(text_embedding, rgb_embedding) 66 | 67 | # Print output 68 | print(f"Output Shape: {output.shape}") 69 | -------------------------------------------------------------------------------- /ovon/utils/visualize_policy_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from gym import spaces 6 | 7 | from ovon.models.clip_policy import PointNavResNetCLIPPolicy 8 | from ovon.utils.plot_tsne import plot_tsne 9 | from ovon.utils.utils import save_pickle 10 | 11 | 12 | def visualize_policy_embeddings(checkpoint_path, output_path): 13 | checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) 14 | print(checkpoint.keys()) 15 | 16 | OBJECT_MAPPING = { 17 | "chair": 0, 18 | "bed": 1, 19 | "plant": 2, 20 | "toilet": 3, 21 | "tv_monitor": 4, 22 | "sofa": 5, 23 | } 24 | 25 | h, w = ( 26 | 480, 27 | 640, 28 | ) 29 | 30 | observation_space = { 31 | "compass": spaces.Box(low=-np.pi, high=np.pi, shape=(1,), dtype=np.float32), 32 | "gps": spaces.Box( 33 | low=np.finfo(np.float32).min, 34 | high=np.finfo(np.float32).max, 35 | shape=(2,), 36 | dtype=np.float32, 37 | ), 38 | "rgb": spaces.Box( 39 | low=0, 40 | high=255, 41 | shape=(h, w, 3), 42 | dtype=np.uint8, 43 | ), 44 | "objectgoal": spaces.Box( 45 | low=0, high=len(OBJECT_MAPPING) - 1, shape=(1,), dtype=np.int64 46 | ), 47 | } 48 | 49 | observation_space = spaces.Dict(observation_space) 50 | 51 | action_space = spaces.Discrete(6) 52 | 53 | policy = PointNavResNetCLIPPolicy.from_config( 54 | checkpoint["config"], observation_space, action_space 55 | ) 56 | policy.load_state_dict( 57 | {k.replace("actor_critic.", ""): v for k, v in checkpoint["state_dict"].items()} 58 | ) 59 | 60 | policy.eval() 61 | print("Policy initialized....") 62 | 63 | embeddings = {} 64 | for cat, cat_id in OBJECT_MAPPING.items(): 65 | out = policy.net.obj_categories_embedding(torch.tensor(cat_id)) 66 | embeddings[cat] = out.detach().numpy() 67 | 68 | save_pickle(embeddings, "data/clip_embeddings/hm3d_onehot.pkl") 69 | plot_tsne(embeddings, output_path) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument( 75 | "--checkpoint", 76 | type=str, 77 | required=True, 78 | help="path to 
the habitat baselines file", 79 | ) 80 | parser.add_argument( 81 | "--output-path", type=str, required=True, help="path to the TSNE plot" 82 | ) 83 | 84 | args = parser.parse_args() 85 | visualize_policy_embeddings(args.checkpoint, args.output_path) 86 | -------------------------------------------------------------------------------- /scripts/eval/3-ovon-ver-eval-analysis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=ovon 3 | #SBATCH --output=slurm_logs/eval/ovon-ver-%j.out 4 | #SBATCH --error=slurm_logs/eval/ovon-ver-%j.err 5 | #SBATCH --gpus 1 6 | #SBATCH --nodes 1 7 | #SBATCH --cpus-per-task 10 8 | #SBATCH --ntasks-per-node 1 9 | #SBATCH --constraint="a40|rtx_6000" 10 | #SBATCH --partition=short 11 | #SBATCH --exclude=cheetah,samantha,xaea-12,kitt,calculon,vicki,neo,kipp,ripl-s1,tars 12 | #SBATCH --signal=USR1@100 13 | 14 | export GLOG_minloglevel=2 15 | export HABITAT_SIM_LOG=quiet 16 | export MAGNUM_LOG=quiet 17 | 18 | MAIN_ADDR=$(scontrol show hostnames "${SLURM_JOB_NODELIST}" | head -n 1) 19 | export MAIN_ADDR 20 | 21 | source /srv/flash1/rramrakhya3/miniconda3/etc/profile.d/conda.sh 22 | conda deactivate 23 | conda activate ovon 24 | 25 | export PYTHONPATH=/srv/flash1/rramrakhya3/spring_2023/habitat-sim/src_python/ 26 | 27 | DATA_PATH="data/datasets/ovon/hm3d/v5_final/" 28 | eval_ckpt_path_dir="data/new_checkpoints/ovon/ver/resnetclip_rgb_text/seed_1/ckpt.80.pth" 29 | tensorboard_dir="tb/ovon/ver/resnetclip_rgb_text/seed_1/eval_val_seen_debug/" 30 | split="val_seen" 31 | OVON_EPISODES_JSON="data/analysis/episode_metrics/rl_ckpt_80_additional_metrics_${split}.json" 32 | 33 | echo "Evaluating ckpt: ${eval_ckpt_path_dir}" 34 | echo "Data path: ${DATA_PATH}/${split}/${split}.json.gz" 35 | 36 | srun python -um ovon.run \ 37 | --run-type eval \ 38 | --exp-config config/experiments/ver_objectnav.yaml \ 39 | -cvt \ 40 | habitat_baselines.num_environments=2 \ 41 | habitat_baselines.test_episode_count=4 \ 42 | habitat_baselines.trainer_name=ver_pirlnav \ 43 | habitat_baselines.rl.policy.name=PointNavResNetCLIPPolicy \ 44 | habitat_baselines.tensorboard_dir=$tensorboard_dir \ 45 | habitat_baselines.eval_ckpt_path_dir=$eval_ckpt_path_dir \ 46 | habitat_baselines.checkpoint_folder=$eval_ckpt_path_dir \ 47 | habitat.dataset.data_path="${DATA_PATH}/${split}/${split}.json.gz" \ 48 | +habitat/task/lab_sensors@habitat.task.lab_sensors.clip_objectgoal_sensor=clip_objectgoal_sensor \ 49 | habitat.task.lab_sensors.clip_objectgoal_sensor.cache=data/clip_embeddings/ovon_stretch_cache.pkl \ 50 | habitat.task.measurements.success.success_distance=0.25 \ 51 | habitat.dataset.type="OVON-v1" \ 52 | habitat.task.measurements.distance_to_goal.type=OVONDistanceToGoal \ 53 | +habitat/task/measurements@habitat.task.measurements.ovon_object_goal_id=ovon_object_goal_id \ 54 | +habitat/task/measurements@habitat.task.measurements.failure_modes=failure_modes \ 55 | habitat.simulator.type="OVONSim-v0" \ 56 | habitat_baselines.eval.use_ckpt_config=False \ 57 | habitat_baselines.load_resume_state_config=False \ 58 | habitat.simulator.habitat_sim_v0.allow_sliding=False \ 59 | habitat_baselines.eval.split=$split 60 | 61 | touch $checkpoint_counter 62 | -------------------------------------------------------------------------------- /ovon/models/encoders/depth_encoder.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from gym import spaces 4 | from gym.spaces import Dict 
as SpaceDict 5 | from habitat_baselines.rl.ddppo.policy.resnet import resnet18 6 | from habitat_baselines.rl.ddppo.policy.resnet_policy import ResNetEncoder 7 | from torch import nn as nn 8 | 9 | 10 | class ResNet18DepthEncoder(nn.Module): 11 | def __init__(self, depth_encoder, visual_fc): 12 | super().__init__() 13 | self.encoder = depth_encoder 14 | self.visual_fc = visual_fc 15 | 16 | def forward(self, x): 17 | x = self.encoder(x) 18 | x = self.visual_fc(x) 19 | return x 20 | 21 | def load_state_dict(self, state_dict, strict: bool = True): 22 | # TODO: allow dicts trained with both attn and avg pool to be loaded 23 | ignore_attnpool = False 24 | if ignore_attnpool: 25 | pass 26 | return super().load_state_dict(state_dict, strict=strict) 27 | 28 | 29 | def copy_depth_encoder(depth_ckpt): 30 | """ 31 | Returns an encoder that stacks the encoder and visual_fc of the provided 32 | depth checkpoint 33 | :param depth_ckpt: path to a resnet18 depth pointnav policy 34 | :return: nn.Module representing the backbone of the depth policy 35 | """ 36 | # Initialize encoder and fc layers 37 | base_planes = 32 38 | ngroups = 32 39 | spatial_size = 128 40 | 41 | observation_space = SpaceDict( 42 | { 43 | "depth": spaces.Box( 44 | low=0.0, high=1.0, shape=(256, 256, 1), dtype=np.float32 45 | ), 46 | } 47 | ) 48 | depth_encoder = ResNetEncoder( 49 | observation_space, 50 | base_planes, 51 | ngroups, 52 | spatial_size, 53 | make_backbone=resnet18, 54 | ) 55 | 56 | flat_output_shape = 2048 57 | hidden_size = 512 58 | visual_fc = nn.Sequential( 59 | nn.Flatten(), 60 | nn.Linear(flat_output_shape, hidden_size), 61 | nn.ReLU(True), 62 | ) 63 | 64 | pretrained_state = torch.load(depth_ckpt, map_location="cpu") 65 | 66 | # Load weights into depth encoder 67 | depth_encoder_state_dict = { 68 | k[len("actor_critic.net.visual_encoder.") :]: v 69 | for k, v in pretrained_state["state_dict"].items() 70 | if k.startswith("actor_critic.net.visual_encoder.") 71 | } 72 | depth_encoder.load_state_dict(depth_encoder_state_dict) 73 | 74 | # Load weights in fc layers 75 | visual_fc_state_dict = { 76 | k[len("actor_critic.net.visual_fc.") :]: v 77 | for k, v in pretrained_state["state_dict"].items() 78 | if k.startswith("actor_critic.net.visual_fc.") 79 | } 80 | visual_fc.load_state_dict(visual_fc_state_dict) 81 | 82 | modified_depth_encoder = ResNet18DepthEncoder(depth_encoder, visual_fc) 83 | 84 | return modified_depth_encoder 85 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ovon" 7 | version = "0.1" 8 | description = "open-vocab-objectnav" 9 | authors = [ 10 | {name = "Naoki Yokoyama", email = "naokiyokoyama@github.com"}, 11 | ] 12 | readme = "README.md" 13 | requires-python = ">=3.7" 14 | 15 | [project.optional-dependencies] 16 | dev = [ 17 | "pre-commit >= 2.21.0", 18 | ] 19 | 20 | [project.urls] 21 | "Homepage" = "github.com" 22 | "GitHub" = "https://github.com/naokiyokoyama/ovon" 23 | 24 | [tool.setuptools] 25 | packages = ["ovon"] 26 | 27 | [tool.ruff] 28 | # Enable pycodestyle (`E`), Pyflakes (`F`), and import sorting (`I`) 29 | select = ["E", "F", "I"] 30 | ignore = [] 31 | 32 | # Allow autofix for all enabled rules (when `--fix`) is provided. 
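# `fixable` lists the rule prefixes ruff may rewrite when `--fix` is passed; anything
# listed in `unfixable` is reported but never auto-fixed.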
33 | fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"] 34 | unfixable = [] 35 | 36 | # Exclude a variety of commonly ignored directories. 37 | exclude = [ 38 | ".bzr", 39 | ".direnv", 40 | ".eggs", 41 | ".git", 42 | ".hg", 43 | ".mypy_cache", 44 | ".nox", 45 | ".pants.d", 46 | ".pytype", 47 | ".ruff_cache", 48 | ".svn", 49 | ".tox", 50 | ".venv", 51 | "__pypackages__", 52 | "_build", 53 | "buck-out", 54 | "build", 55 | "dist", 56 | "node_modules", 57 | "venv", 58 | "docker/ros", 59 | ] 60 | 61 | # Same as Black. 62 | line-length = 88 63 | 64 | # Allow unused variables when underscore-prefixed. 65 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 66 | 67 | # Assume Python 3.9. 68 | target-version = "py39" 69 | 70 | [tool.ruff.per-file-ignores] 71 | "__init__.py" = ["F401"] 72 | 73 | [tool.ruff.mccabe] 74 | # Unlike Flake8, default to a complexity level of 10. 75 | max-complexity = 10 76 | 77 | [tool.black] 78 | line-length = 88 79 | target-version = ['py39'] 80 | include = '\.pyi?$' 81 | # `extend-exclude` is not honored when `black` is passed a file path explicitly, 82 | # as is typical when `black` is invoked via `pre-commit`. 83 | force-exclude = ''' 84 | /( 85 | docker/ros/.* 86 | )/ 87 | ''' 88 | 89 | preview = true 90 | 91 | # mypy configuration 92 | [tool.mypy] 93 | python_version = "3.9" 94 | disallow_untyped_defs = true 95 | ignore_missing_imports = true 96 | explicit_package_bases = true 97 | check_untyped_defs = true 98 | strict_equality = true 99 | warn_unreachable = true 100 | warn_redundant_casts = true 101 | no_implicit_optional = true 102 | files = ['ovon'] 103 | exclude = '^(docker|.*external|.*thirdparty|.*install|.*build|.*_experimental)/' 104 | -------------------------------------------------------------------------------- /ovon/trainers/ver_rollout_storage_with_kv.py: -------------------------------------------------------------------------------- 1 | from typing import Iterator, Optional 2 | 3 | import torch 4 | from habitat_baselines.common.tensor_dict import DictTree, TensorDict 5 | from habitat_baselines.rl.ver.ver_rollout_storage import ( 6 | VERRolloutStorage, 7 | generate_ver_mini_batches, 8 | ) 9 | 10 | 11 | def build_rnn_build_seq_info( 12 | device: torch.device, 13 | episode_ids, 14 | ) -> TensorDict: 15 | rnn_build_seq_info = TensorDict() 16 | rnn_build_seq_info["episode_ids"] = ( 17 | torch.from_numpy(episode_ids).to(device=device).reshape(-1, 1) 18 | ) 19 | 20 | return rnn_build_seq_info 21 | 22 | 23 | class VERRolloutStorageWithKVCache(VERRolloutStorage): 24 | def __init__( 25 | self, 26 | num_layers: int, 27 | num_heads: int, 28 | max_context_length: int, 29 | head_dim: int, 30 | *args, 31 | **kwargs, 32 | ): 33 | super().__init__(*args, **kwargs) 34 | self._aux_buffers["next_hidden_states"] = torch.zeros( 35 | self._num_envs, 36 | num_layers, 37 | 2, # key, value 38 | num_heads, 39 | max_context_length - 1, 40 | head_dim, 41 | device=self.buffers["recurrent_hidden_states"].device, 42 | ) 43 | 44 | self._set_aux_buffers() 45 | 46 | def recurrent_generator( 47 | self, 48 | advantages: Optional[torch.Tensor], 49 | num_mini_batch: int, 50 | ) -> Iterator[DictTree]: 51 | """ 52 | An exact copy of the recurrent_generator method from the parent class, the only 53 | difference in behaviour is: 54 | - The new 
definition of build_rnn_build_seq_info (above) is used instead. 55 | - "recurrent_hidden_states" is not set 56 | """ 57 | if not self.variable_experience: 58 | yield from super().recurrent_generator(advantages, num_mini_batch) 59 | else: 60 | for mb_inds in generate_ver_mini_batches( 61 | num_mini_batch, 62 | self.sequence_lengths, 63 | self.num_seqs_at_step, 64 | self.select_inds, 65 | self.last_sequence_in_batch_mask, 66 | self.episode_ids_cpu, 67 | ): 68 | mb_inds_cpu = torch.from_numpy(mb_inds) 69 | mb_inds = mb_inds_cpu.to(device=self.device) 70 | 71 | if not self.variable_experience: 72 | batch = self.buffers.map(lambda t: t.flatten(0, 1))[mb_inds] 73 | if advantages is not None: 74 | batch["advantages"] = advantages.flatten(0, 1)[mb_inds] 75 | else: 76 | batch = self.buffers[mb_inds] 77 | if advantages is not None: 78 | batch["advantages"] = advantages[mb_inds] 79 | 80 | batch["rnn_build_seq_info"] = build_rnn_build_seq_info( 81 | device=self.device, 82 | episode_ids=self.episode_ids_cpu[mb_inds_cpu], 83 | ) 84 | 85 | yield batch.to_tree() 86 | -------------------------------------------------------------------------------- /config/experiments/rnn_rl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - objectnav_stretch_hm3d 8 | - override /habitat/task/lab_sensors: 9 | - clip_objectgoal_sensor 10 | - _self_ 11 | 12 | habitat: 13 | environment: 14 | iterator_options: 15 | max_scene_repeat_steps: 50000 16 | task: 17 | success_reward: 2.5 18 | slack_reward: -1e-3 19 | lab_sensors: 20 | clip_objectgoal_sensor: 21 | cache: data/text_embeddings/siglip.pkl 22 | measurements: 23 | success: 24 | success_distance: 0.25 25 | distance_to_goal: 26 | type: OVONDistanceToGoal 27 | dataset: 28 | type: "OVON-v1" 29 | split: train 30 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 31 | simulator: 32 | type: "OVONSim-v0" 33 | navmesh_settings: 34 | agent_max_climb: 0.1 35 | cell_height: 0.05 36 | 37 | habitat_baselines: 38 | torch_gpu_id: 0 39 | tensorboard_dir: "tb" 40 | video_dir: "video_dir" 41 | test_episode_count: -1 42 | eval_ckpt_path_dir: "data/new_checkpoints" 43 | num_environments: 32 44 | checkpoint_folder: "data/new_checkpoints" 45 | trainer_name: "ver" 46 | num_updates: -1 47 | total_num_steps: 150000000 48 | log_interval: 10 49 | num_checkpoints: 50 50 | # Force PyTorch to be single threaded as 51 | # this improves performance considerably 52 | force_torch_single_threaded: True 53 | 54 | eval: 55 | split: "val" 56 | 57 | rl: 58 | 59 | policy: 60 | name: "PointNavResNetCLIPPolicy" 61 | backbone: "siglip" 62 | fusion_type: "concat" 63 | use_vis_query: True 64 | use_residual: True 65 | residual_vision: True 66 | rgb_only: False 67 | 68 | ppo: 69 | # ppo params 70 | clip_param: 0.2 71 | ppo_epoch: 1 72 | num_mini_batch: 2 73 | value_loss_coef: 0.5 74 | entropy_coef: 0.01 75 | lr: 2.5e-4 76 | eps: 1e-5 77 | max_grad_norm: 0.2 78 | num_steps: 100 79 | use_gae: True 80 | gamma: 0.99 81 | tau: 0.95 82 | use_linear_clip_decay: False 83 | use_linear_lr_decay: False 84 | reward_window_size: 50 85 | 86 | use_normalized_advantage: False 87 | 88 | hidden_size: 1024 89 | 90 | ddppo: 91 | sync_frac: 0.6 92 | # The PyTorch distributed backend to use 93 | distrib_backend: NCCL 94 | # Visual encoder backbone 95 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 96 | # 
Initialize with pretrained weights 97 | pretrained: False 98 | # Initialize just the visual encoder backbone with pretrained weights 99 | pretrained_encoder: False 100 | # Whether or not the visual encoder backbone will be trained. 101 | train_encoder: False 102 | # Whether or not to reset the critic linear layer 103 | reset_critic: True 104 | 105 | # Model parameters 106 | rnn_type: LSTM 107 | num_recurrent_layers: 4 108 | -------------------------------------------------------------------------------- /ovon/utils/visualize/statistics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import math 4 | import os 5 | 6 | import matplotlib.pyplot as plt 7 | import pandas as pd 8 | import seaborn as sns 9 | from tqdm import tqdm 10 | 11 | from ovon.utils.utils import load_dataset 12 | 13 | 14 | def plot_statistics(path, output_path, split="train"): 15 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 16 | 17 | geodesic_distances = [] 18 | euclidean_distances = [] 19 | files = glob.glob(os.path.join(path, "*json.gz")) 20 | 21 | categories = {} 22 | for f in tqdm(files): 23 | dataset = load_dataset(f) 24 | for ep in dataset["episodes"]: 25 | geodesic_distances.append(ep["info"]["geodesic_distance"]) 26 | euclidean_distances.append(ep["info"]["euclidean_distance"]) 27 | categories[ep["object_category"]] = ( 28 | categories.get(ep["object_category"], 0) + 1 29 | ) 30 | 31 | # Plot distances for visualization 32 | plt.figure(figsize=(8, 8)) 33 | hist_data = list(filter(math.isfinite, geodesic_distances)) 34 | hist_data = pd.DataFrame.from_dict({"Geodesic distance": hist_data}) 35 | ax = sns.histplot(data=hist_data, x="Geodesic distance") 36 | 37 | ax.set_xticks(range(0, 32, 2)) 38 | 39 | plt.title("Geodesic distance to closest goal") 40 | plt.tight_layout() 41 | plt.savefig( 42 | os.path.join(output_path, "ovon_{}_geodesic_distances.png".format(split)) 43 | ) 44 | 45 | plt.figure(figsize=(8, 8)) 46 | hist_data = list( 47 | filter( 48 | math.isfinite, 49 | [g / e for g, e in zip(geodesic_distances, euclidean_distances)], 50 | ) 51 | ) 52 | hist_data = pd.DataFrame.from_dict({"Euc Geo ratio": hist_data}) 53 | ax = sns.histplot(data=hist_data, x="Euc Geo ratio") 54 | 55 | ax.set_ylim(0, 5000) 56 | ax.set_xlim(0, 5) 57 | ax.set_xticks(range(0, 5, 1)) 58 | 59 | plt.title("Euc Geo ratio") 60 | plt.tight_layout() 61 | plt.savefig(os.path.join(output_path, "ovon_{}_euc_geo_ratio.png".format(split))) 62 | 63 | categories = { 64 | "objects": list(categories.keys()), 65 | "frequency": list(categories.values()), 66 | } 67 | 68 | df = pd.DataFrame.from_dict(categories) 69 | df.sort_values(by="frequency", inplace=True, ascending=False) 70 | print(df.columns) 71 | 72 | fig, axs = plt.subplots(1, 1, figsize=(8, 50)) 73 | 74 | sns.barplot(data=df, x="frequency", y="objects", ax=axs) 75 | 76 | fig.savefig( 77 | os.path.join(output_path, "ovon_{}_categories.png".format(split)), 78 | dpi=100, 79 | bbox_inches="tight", 80 | pad_inches=0.1, 81 | transparent=False, 82 | ) 83 | 84 | 85 | if __name__ == "__main__": 86 | parser = argparse.ArgumentParser() 87 | parser.add_argument( 88 | "--path", 89 | type=str, 90 | default="data/datasets/ovon/hm3d/v4_stretch/val_seen/content/", 91 | ) 92 | parser.add_argument("--output-path", type=str, default="val_unseen.png") 93 | parser.add_argument("--split", type=str, default="train") 94 | args = parser.parse_args() 95 | 96 | plot_statistics(args.path, args.output_path, args.split) 97 | 
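As a quick illustration of the episode metadata consumed by plot_statistics above, the following minimal sketch (hypothetical values; only the fields the script actually reads) computes the same geodesic-to-euclidean ratio that the second histogram plots:

# Hypothetical episode record mirroring the fields read by plot_statistics().
episode = {
    "object_category": "chair",
    "info": {"geodesic_distance": 4.2, "euclidean_distance": 3.0},
}

# The ratio histogrammed as "Euc Geo ratio"; it is >= 1.0 whenever the goal is
# reachable, and larger values indicate goals that require more detouring.
ratio = episode["info"]["geodesic_distance"] / episode["info"]["euclidean_distance"]
print(f"{episode['object_category']}: geodesic/euclidean = {ratio:.2f}")  # -> 1.40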
-------------------------------------------------------------------------------- /ovon/models/encoders/visual_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import torch 5 | from habitat_baselines.rl.ddppo.policy.running_mean_and_var import ( 6 | RunningMeanAndVar, 7 | ) 8 | from torch import nn as nn 9 | from torch.nn import functional as F 10 | 11 | from ovon.models.encoders import resnet_gn as resnet 12 | 13 | 14 | class VisualEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | image_size: int, 18 | backbone: str, 19 | input_channels: int = 3, 20 | resnet_baseplanes: int = 32, 21 | resnet_ngroups: int = 32, 22 | normalize_visual_inputs: bool = True, 23 | avgpooled_image: bool = False, 24 | drop_path_rate: float = 0.0, 25 | visual_transform: Any = None, 26 | num_environments: int = 1, 27 | ): 28 | super().__init__() 29 | self.avgpooled_image = avgpooled_image 30 | self.is_blind = False 31 | self.visual_transform = visual_transform 32 | self.num_environments = num_environments 33 | 34 | if normalize_visual_inputs: 35 | self.running_mean_and_var: nn.Module = RunningMeanAndVar(input_channels) 36 | else: 37 | self.running_mean_and_var = nn.Sequential() 38 | 39 | if "resnet" in backbone: 40 | make_backbone = getattr(resnet, backbone) 41 | self.backbone = make_backbone( 42 | input_channels, resnet_baseplanes, resnet_ngroups 43 | ) 44 | 45 | spatial_size = image_size 46 | if self.avgpooled_image: 47 | spatial_size = image_size // 2 48 | 49 | final_spatial = int(spatial_size * self.backbone.final_spatial_compress) 50 | after_compression_flat_size = 2048 51 | num_compression_channels = int( 52 | round(after_compression_flat_size / (final_spatial**2)) 53 | ) 54 | self.compression = nn.Sequential( 55 | nn.Conv2d( 56 | self.backbone.final_channels, 57 | num_compression_channels, 58 | kernel_size=3, 59 | padding=1, 60 | bias=False, 61 | ), 62 | nn.GroupNorm(1, num_compression_channels), 63 | nn.ReLU(True), 64 | ) 65 | 66 | output_shape = ( 67 | num_compression_channels, 68 | final_spatial, 69 | final_spatial, 70 | ) 71 | self.output_shape = output_shape 72 | self.output_size = np.prod(output_shape) 73 | else: 74 | raise ValueError("unknown backbone {}".format(backbone)) 75 | 76 | def forward(self, observations: torch.Tensor, N: int = None) -> torch.Tensor: # type: ignore 77 | num_environments = self.num_environments 78 | if N is not None: 79 | num_environments = N 80 | 81 | rgb = observations["rgb"] 82 | x = self.visual_transform(rgb, num_environments) 83 | 84 | if ( 85 | self.avgpooled_image 86 | ): # For compatibility with the habitat_baselines implementation 87 | x = F.avg_pool2d(x, 2) 88 | x = self.running_mean_and_var(x) 89 | x = self.backbone(x) 90 | x = self.compression(x) 91 | return x 92 | -------------------------------------------------------------------------------- /config/experiments/transformer_rl.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - objectnav_stretch_hm3d 8 | - override /habitat/task/lab_sensors: 9 | - clip_objectgoal_sensor 10 | - step_id_sensor 11 | - _self_ 12 | 13 | habitat: 14 | environment: 15 | iterator_options: 16 | max_scene_repeat_steps: 50000 17 | task: 18 | success_reward: 2.5 19 | slack_reward: -1e-3 20 | lab_sensors: 21 | clip_objectgoal_sensor: 22 
| cache: data/text_embeddings/siglip.pkl 23 | measurements: 24 | success: 25 | success_distance: 0.25 26 | distance_to_goal: 27 | type: OVONDistanceToGoal 28 | dataset: 29 | type: "OVON-v1" 30 | split: train 31 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 32 | simulator: 33 | type: "OVONSim-v0" 34 | navmesh_settings: 35 | agent_max_climb: 0.1 36 | cell_height: 0.05 37 | 38 | habitat_baselines: 39 | torch_gpu_id: 0 40 | tensorboard_dir: "tb" 41 | video_dir: "video_dir" 42 | test_episode_count: -1 43 | eval_ckpt_path_dir: "data/new_checkpoints" 44 | num_environments: 32 45 | checkpoint_folder: "data/new_checkpoints" 46 | trainer_name: "ver_transformer" 47 | num_updates: -1 48 | total_num_steps: 150000000 49 | log_interval: 10 50 | num_checkpoints: 50 51 | # Force PyTorch to be single threaded as 52 | # this improves performance considerably 53 | force_torch_single_threaded: True 54 | 55 | eval: 56 | split: "val" 57 | 58 | rl: 59 | 60 | policy: 61 | name: "OVONTransformerPolicy" 62 | backbone: "siglip" 63 | fusion_type: "concat" 64 | use_vis_query: True 65 | use_residual: True 66 | residual_vision: True 67 | rgb_only: False 68 | transformer_config: 69 | model_name: "llama" 70 | n_layers: 4 71 | n_heads: 8 72 | n_hidden: 512 73 | n_mlp_hidden: 1024 74 | max_context_length: 100 75 | shuffle_pos_id_for_update: True 76 | 77 | ppo: 78 | # ppo params 79 | clip_param: 0.2 80 | ppo_epoch: 1 81 | num_mini_batch: 2 82 | value_loss_coef: 0.5 83 | entropy_coef: 0.01 84 | lr: 2.5e-4 85 | eps: 1e-5 86 | max_grad_norm: 0.2 87 | num_steps: 100 88 | use_gae: True 89 | gamma: 0.99 90 | tau: 0.95 91 | use_linear_clip_decay: False 92 | use_linear_lr_decay: False 93 | reward_window_size: 50 94 | 95 | use_normalized_advantage: False 96 | 97 | hidden_size: 512 98 | 99 | ddppo: 100 | sync_frac: 0.6 101 | # The PyTorch distributed backend to use 102 | distrib_backend: NCCL 103 | # Visual encoder backbone 104 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 105 | # Initialize with pretrained weights 106 | pretrained: False 107 | # Initialize just the visual encoder backbone with pretrained weights 108 | pretrained_encoder: False 109 | # Whether or not the visual encoder backbone will be trained. 
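      # Left False here because the SigLIP backbone is used frozen; the from-scratch
      # ResNet run in scripts/train/3-ovon-ver-scratch.sh overrides this to True.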
110 | train_encoder: False 111 | # Whether or not to reset the critic linear layer 112 | reset_critic: True 113 | -------------------------------------------------------------------------------- /ovon/task/rewards.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any, Optional 2 | 3 | import numpy as np 4 | from habitat.core.embodied_task import EmbodiedTask, Measure 5 | from habitat.core.registry import registry 6 | from habitat.core.simulator import Simulator 7 | from habitat.tasks.nav.nav import DistanceToGoal, Success 8 | 9 | from ovon.measurements.imagenav import AngleSuccess, AngleToGoal 10 | 11 | if TYPE_CHECKING: 12 | from omegaconf import DictConfig 13 | 14 | 15 | @registry.register_measure 16 | class ImageNavReward(Measure): 17 | cls_uuid: str = "imagenav_reward" 18 | 19 | def __init__(self, *args: Any, sim: Simulator, config: "DictConfig", **kwargs: Any): 20 | super().__init__(**kwargs) 21 | self._sim = sim 22 | self._config = config 23 | self._previous_dtg: Optional[float] = None 24 | self._previous_atg: Optional[float] = None 25 | 26 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 27 | return self.cls_uuid 28 | 29 | def reset_metric( 30 | self, 31 | *args: Any, 32 | task: EmbodiedTask, 33 | **kwargs: Any, 34 | ): 35 | task.measurements.check_measure_dependencies( 36 | self.uuid, 37 | [ 38 | Success.cls_uuid, 39 | DistanceToGoal.cls_uuid, 40 | AngleToGoal.cls_uuid, 41 | AngleSuccess.cls_uuid, 42 | ], 43 | ) 44 | self._metric = None 45 | self._previous_dtg = None 46 | self._previous_atg = None 47 | self.update_metric(task=task) 48 | 49 | def update_metric(self, *args: Any, task: EmbodiedTask, **kwargs: Any): 50 | # success reward 51 | success = task.measurements.measures[Success.cls_uuid].get_metric() 52 | success_reward = self._config.success_reward if success else 0.0 53 | 54 | # distance-to-goal reward 55 | dtg = task.measurements.measures[DistanceToGoal.cls_uuid].get_metric() 56 | if self._previous_dtg is None: 57 | self._previous_dtg = dtg 58 | add_dtg = self._config.use_dtg_reward 59 | dtg_reward = self._previous_dtg - dtg if add_dtg else 0.0 60 | self._previous_dtg = dtg 61 | 62 | # angle-to-goal reward 63 | atg = task.measurements.measures[AngleToGoal.cls_uuid].get_metric() 64 | add_atg = self._config.use_atg_reward 65 | if self._config.use_atg_fix: 66 | if dtg > self._config.atg_reward_distance: 67 | atg = np.pi 68 | else: 69 | if dtg > self._config.atg_reward_distance: 70 | add_atg = False 71 | if self._previous_atg is None: 72 | self._previous_atg = atg 73 | angle_reward = self._previous_atg - atg if add_atg else 0.0 74 | self._previous_atg = atg 75 | 76 | # angle success reward 77 | angle_success = task.measurements.measures[AngleSuccess.cls_uuid].get_metric() 78 | angle_success_reward = ( 79 | self._config.angle_success_reward if angle_success else 0.0 80 | ) 81 | 82 | # slack penalty 83 | slack_penalty = self._config.slack_penalty 84 | 85 | # reward 86 | self._metric = ( 87 | success_reward 88 | + dtg_reward 89 | + angle_reward 90 | + angle_success_reward 91 | + slack_penalty 92 | ) 93 | -------------------------------------------------------------------------------- /ovon/utils/dagger_environment_worker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Meta Platforms, Inc. and its affiliates. 
4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from multiprocessing.context import BaseContext 9 | from typing import ( 10 | TYPE_CHECKING, 11 | List, 12 | Optional, 13 | ) 14 | 15 | import attr 16 | from habitat_baselines.rl.ver.environment_worker import ( 17 | EnvironmentWorker, 18 | EnvironmentWorkerProcess, 19 | _create_worker_configs, 20 | ) 21 | from habitat_baselines.rl.ver.worker_common import WorkerBase, WorkerQueues 22 | 23 | if TYPE_CHECKING: 24 | from omegaconf import DictConfig 25 | 26 | 27 | @attr.s(auto_attribs=True, auto_detect=True) 28 | class ILEnvironmentWorkerProcess(EnvironmentWorkerProcess): 29 | teacher_label: Optional[str] = None 30 | 31 | def _step_env(self, action): 32 | with self.timer.avg_time("step env"): 33 | obs, reward, done, info = self.env.step( 34 | self._last_obs[self.teacher_label].item() 35 | ) 36 | # ^only line different from the original EnvironmentWorkerProcess._step_env 37 | self._step_id += 1 38 | 39 | if not math.isfinite(reward): 40 | reward = -1.0 41 | 42 | with self.timer.avg_time("reset env"): 43 | if done: 44 | self._episode_id += 1 45 | self._step_id = 0 46 | if self.auto_reset_done: 47 | obs = self.env.reset() 48 | 49 | return obs, reward, done, info 50 | 51 | 52 | class ILEnvironmentWorker(EnvironmentWorker): 53 | def __init__( 54 | self, 55 | mp_ctx: BaseContext, 56 | env_idx: int, 57 | env_config, 58 | auto_reset_done, 59 | queues: WorkerQueues, 60 | ): 61 | teacher_label = None 62 | obs_trans_conf = env_config.habitat_baselines.rl.policy.obs_transforms 63 | if hasattr(env_config.habitat_baselines.rl.policy, "obs_transforms"): 64 | for obs_transform_config in obs_trans_conf.values(): 65 | if hasattr(obs_transform_config, "teacher_label"): 66 | teacher_label = obs_transform_config.teacher_label 67 | break 68 | assert teacher_label is not None, "teacher_label not found in config" 69 | WorkerBase.__init__( 70 | self, 71 | mp_ctx, 72 | ILEnvironmentWorkerProcess, 73 | env_idx, 74 | env_config, 75 | auto_reset_done, 76 | queues, 77 | teacher_label=teacher_label, 78 | ) 79 | self.env_worker_queue = queues.environments[env_idx] 80 | 81 | 82 | def _construct_il_environment_workers_impl( 83 | configs, 84 | auto_reset_done, 85 | mp_ctx: BaseContext, 86 | queues: WorkerQueues, 87 | ): 88 | num_environments = len(configs) 89 | workers = [] 90 | for i in range(num_environments): 91 | w = ILEnvironmentWorker(mp_ctx, i, configs[i], auto_reset_done, queues) 92 | workers.append(w) 93 | 94 | return workers 95 | 96 | 97 | def construct_il_environment_workers( 98 | config: "DictConfig", 99 | mp_ctx: BaseContext, 100 | worker_queues: WorkerQueues, 101 | ) -> List[EnvironmentWorker]: 102 | configs = _create_worker_configs(config) 103 | 104 | return _construct_il_environment_workers_impl(configs, True, mp_ctx, worker_queues) 105 | -------------------------------------------------------------------------------- /config/experiments/rnn_rl_finetune.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - objectnav_stretch_hm3d 8 | - override /habitat/task/lab_sensors: 9 | - clip_objectgoal_sensor 10 | - override /habitat/task/measurements: 11 | - distance_to_goal 12 | - success 13 | - spl 14 | - soft_spl 15 | - collisions 16 | - collision_penalty 17 | - 
_self_ 18 | 19 | habitat: 20 | environment: 21 | iterator_options: 22 | max_scene_repeat_steps: 50000 23 | task: 24 | success_reward: 5.0 25 | slack_reward: -1e-3 26 | reward_measure: "collision_penalty" 27 | lab_sensors: 28 | clip_objectgoal_sensor: 29 | cache: data/text_embeddings/siglip.pkl 30 | measurements: 31 | success: 32 | success_distance: 0.25 33 | distance_to_goal: 34 | type: OVONDistanceToGoal 35 | dataset: 36 | type: "OVON-v1" 37 | split: train 38 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 39 | simulator: 40 | type: "OVONSim-v0" 41 | navmesh_settings: 42 | agent_max_climb: 0.1 43 | cell_height: 0.05 44 | 45 | habitat_baselines: 46 | torch_gpu_id: 0 47 | tensorboard_dir: "tb" 48 | video_dir: "video_dir" 49 | test_episode_count: -1 50 | eval_ckpt_path_dir: "data/new_checkpoints" 51 | num_environments: 32 52 | checkpoint_folder: "data/new_checkpoints" 53 | trainer_name: "ver" 54 | num_updates: -1 55 | total_num_steps: 150000000 56 | log_interval: 10 57 | num_checkpoints: 50 58 | # Force PyTorch to be single threaded as 59 | # this improves performance considerably 60 | force_torch_single_threaded: True 61 | 62 | eval: 63 | split: "val" 64 | 65 | rl: 66 | 67 | policy: 68 | name: "PointNavResNetCLIPPolicy" 69 | backbone: "siglip" 70 | fusion_type: "concat" 71 | use_vis_query: True 72 | use_residual: True 73 | residual_vision: True 74 | rgb_only: False 75 | 76 | finetune: 77 | enabled: True 78 | lr: 1.5e-5 79 | start_actor_warmup_at: 750 80 | start_actor_update_at: 1000 81 | start_critic_warmup_at: 500 82 | start_critic_update_at: 1000 83 | 84 | ppo: 85 | # ppo params 86 | clip_param: 0.2 87 | ppo_epoch: 1 88 | num_mini_batch: 2 89 | value_loss_coef: 0.5 90 | entropy_coef: 0.01 91 | lr: 2.5e-4 92 | eps: 1e-5 93 | max_grad_norm: 0.2 94 | num_steps: 100 95 | use_gae: True 96 | gamma: 0.99 97 | tau: 0.95 98 | use_linear_clip_decay: False 99 | use_linear_lr_decay: True 100 | reward_window_size: 50 101 | 102 | use_normalized_advantage: False 103 | 104 | hidden_size: 1024 105 | 106 | ddppo: 107 | sync_frac: 0.6 108 | # The PyTorch distributed backend to use 109 | distrib_backend: NCCL 110 | # Visual encoder backbone 111 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 112 | # Initialize with pretrained weights 113 | pretrained: False 114 | # Initialize just the visual encoder backbone with pretrained weights 115 | pretrained_encoder: False 116 | # Whether or not the visual encoder backbone will be trained. 
117 | train_encoder: False 118 | # Whether or not to reset the critic linear layer 119 | reset_critic: True 120 | 121 | # Model parameters 122 | rnn_type: LSTM 123 | num_recurrent_layers: 4 124 | -------------------------------------------------------------------------------- /ovon/utils/analysis/fixes_to_provided_mappings.csv: -------------------------------------------------------------------------------- 1 | Raw,Corrected 2 | All Raw that map to “lamp”, 3 | lamp,lamp 4 | lamp ,lamp 5 | lamps,lamp 6 | lmap,lamp 7 | unknown/ probably lamp,lamp 8 | łamp,lamp 9 | , 10 | All Raw that map to “table lamp”, 11 | table lamp,lamp 12 | table lamp,lamp 13 | unknown/ probably table lamp,table lamp 14 | , 15 | All Raw that map to “plant”, 16 | dacorative plant,plant 17 | ornament plant,plant 18 | ornamental plant,plant 19 | plant,plant 20 | plants,plant 21 | table plant,plant 22 | twigs in vase,plant 23 | unknown/ probably decorative plant,plant 24 | vegetation,plant 25 | , 26 | All Raw that map to “flower”, 27 | flower,flower 28 | flower ,flower 29 | FLOWER,flower 30 | flower,flower 31 | flowers,flower 32 | ornament flower,flower 33 | vase with flower,flower 34 | , 35 | All Raw that map to “couch”, 36 | couch,couch 37 | , 38 | All Raw that map to “sofa”, 39 | unknown/ probably sofa,sofa 40 | sofa,sofa 41 | , 42 | All Raw that map to “box”, 43 | box,box 44 | tabletop box,box 45 | tea box,box 46 | tin,box 47 | towel box,box 48 | unknown/ probably box,box 49 | unknown/ probably desk or box,box 50 | , 51 | All Raw that map to “boxes”, 52 | boxes,boxes 53 | tea boxes ,boxes 54 | tea boxes,boxes 55 | pile of boxes,boxes 56 | , 57 | All Raw that map to “refrigerator”, 58 | refridgerator,refrigerator 59 | refrigearator,refrigerator 60 | refrigerator,refrigerator 61 | wine refrigerator,refrigerator 62 | , 63 | All Raw that map to “fridge”, 64 | fridge,fridge 65 | unknown/ probably fridge,fridge 66 | , 67 | All Raw that map to “picture”, 68 | art,picture 69 | framed picture,picture 70 | image,picture 71 | pictrure,picture 72 | picture,picture 73 | picture ,picture 74 | pictures,picture 75 | picure,picture 76 | unknown picture/window,picture 77 | unknown/ probably framed picture,picture 78 | wall decoration,picture 79 | wall frame,picture 80 | , 81 | All Raw that map to “poster”, 82 | poster,picture 83 | poster figure,poster 84 | , 85 | All Raw that map to “bedside table”, 86 | nighstand,nightstand 87 | night stand,nightstand 88 | night table,nightstand 89 | nightsand,nightstand 90 | nightstand,nightstand 91 | nigtstand,nightstand 92 | , 93 | All Raw that map to “bedside table”, 94 | beside table,bedside table 95 | bedside table,bedside table 96 | , 97 | All Raw that map to “paper”, 98 | paper,paper 99 | paper ,paper 100 | paper,paper 101 | , 102 | All Raw that map to “papers”, 103 | papers,papers 104 | pile of papers,papers 105 | unknown/ probably letters,papers 106 | , 107 | All Raw that map to “ironing board”, 108 | ironing board,ironing board 109 | ironing board ,ironing board 110 | , 111 | All Raw that map to “iron board”, 112 | iron board,iron board 113 | , 114 | All Raw that map to “cup”, 115 | cup,cup 116 | toothbrush cup,cup 117 | toothbrush holder,cup 118 | unknown /probably cup,cup 119 | unknown/ proably cup,cup 120 | , 121 | All Raw that map to “cups”, 122 | cups,cups 123 | pile of cups,cups 124 | unknown/ probably cups,cups 125 | , 126 | All Raw that map to “drawer”, 127 | drawer,drawer 128 | wardrobe drawer,drawer 129 | , 130 | All Raw that map to “drawers”, 131 | drawers,drawers 132 | office 
drawer,drawers 133 | organizer drawers,drawers 134 | , 135 | All Raw that map to “paper towel”, 136 | paper towel,paper towel 137 | , 138 | All Raw that map to “paper towels”, 139 | unknown/ probably paper towel dispenser,paper towels 140 | unknown/ probably paper towel holder,paper towels 141 | paper towels,paper towel 142 | , 143 | All Raw that map to “weight”, 144 | weight,weight 145 | , 146 | , 147 | All Raw that map to “weights”, 148 | workout weight,weights 149 | weights,weights 150 | -------------------------------------------------------------------------------- /config/experiments/rnn_dagger.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - relabel_teacher_actions 8 | - /habitat/task/measurements: 9 | - frontier_exploration_map 10 | - objectnav_stretch_hm3d 11 | - override /habitat/task/lab_sensors: 12 | - objnav_explorer 13 | - clip_objectgoal_sensor 14 | - _self_ 15 | 16 | habitat: 17 | environment: 18 | iterator_options: 19 | max_scene_repeat_steps: 50000 20 | task: 21 | success_reward: 2.5 22 | slack_reward: -1e-3 23 | lab_sensors: 24 | objnav_explorer: 25 | map_resolution: 128 26 | beeline_dist_thresh: 3.5 27 | visibility_dist: 3.0 28 | area_thresh: 1.5 29 | success_distance: 0.25 30 | fov: 42 31 | clip_objectgoal_sensor: 32 | cache: data/text_embeddings/siglip.pkl 33 | measurements: 34 | success: 35 | success_distance: 0.25 36 | distance_to_goal: 37 | type: OVONDistanceToGoal 38 | dataset: 39 | type: "OVON-v1" 40 | split: train 41 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 42 | simulator: 43 | type: "OVONSim-v0" 44 | navmesh_settings: 45 | agent_max_climb: 0.1 46 | cell_height: 0.05 47 | habitat_sim_v0: 48 | allow_sliding: True 49 | 50 | habitat_baselines: 51 | torch_gpu_id: 0 52 | tensorboard_dir: "tb" 53 | video_dir: "video_dir" 54 | test_episode_count: -1 55 | eval_ckpt_path_dir: "data/new_checkpoints" 56 | num_environments: 32 57 | checkpoint_folder: "data/new_checkpoints" 58 | trainer_name: "ver_dagger" 59 | num_updates: -1 60 | total_num_steps: 150000000 61 | log_interval: 10 62 | num_checkpoints: 50 63 | # Force PyTorch to be single threaded as 64 | # this improves performance considerably 65 | force_torch_single_threaded: True 66 | 67 | eval: 68 | split: "val" 69 | 70 | rl: 71 | 72 | policy: 73 | name: "PointNavResNetCLIPPolicy" 74 | backbone: "siglip" 75 | fusion_type: "concat" 76 | use_vis_query: True 77 | use_residual: True 78 | residual_vision: True 79 | rgb_only: False 80 | obs_transforms: 81 | relabel_teacher_actions: 82 | teacher_label: "objnav_explorer" 83 | 84 | ppo: 85 | # ppo params 86 | clip_param: 0.2 87 | ppo_epoch: 1 88 | num_mini_batch: 2 89 | value_loss_coef: 0.5 90 | entropy_coef: 0.01 91 | lr: 2.5e-4 92 | eps: 1e-5 93 | max_grad_norm: 0.2 94 | num_steps: 100 95 | use_gae: True 96 | gamma: 0.99 97 | tau: 0.95 98 | use_linear_clip_decay: False 99 | use_linear_lr_decay: False 100 | reward_window_size: 50 101 | 102 | use_normalized_advantage: False 103 | 104 | hidden_size: 1024 105 | 106 | ddppo: 107 | sync_frac: 0.6 108 | # The PyTorch distributed backend to use 109 | distrib_backend: NCCL 110 | # Visual encoder backbone 111 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 112 | # Initialize with pretrained weights 113 | pretrained: False 114 | # Initialize just the visual encoder backbone with pretrained 
weights 115 | pretrained_encoder: False 116 | # Whether or not the visual encoder backbone will be trained. 117 | train_encoder: False 118 | # Whether or not to reset the critic linear layer 119 | reset_critic: True 120 | 121 | # Model parameters 122 | rnn_type: LSTM 123 | num_recurrent_layers: 4 124 | -------------------------------------------------------------------------------- /ovon/utils/analysis/postprocess_meta.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from collections import defaultdict 3 | 4 | import clip 5 | import matplotlib.pyplot as plt 6 | import numpy as np 7 | import pandas as pd 8 | import seaborn as sns 9 | import torch 10 | 11 | from ovon.utils.utils import load_json, write_json 12 | 13 | PROMPT = "{category}" 14 | 15 | 16 | def postprocess_meta(input_path, output_path): 17 | df = pd.read_csv(input_path) 18 | 19 | categories_by_region = defaultdict(list) 20 | for idx, row in df.iterrows(): 21 | categories_by_region[row["Region Proposal"]].append(row["Category"]) 22 | 23 | for k, v in categories_by_region.items(): 24 | print(k, len(v)) 25 | 26 | write_json(categories_by_region, output_path) 27 | 28 | 29 | def clip_embeddings(clip_m, prompts): 30 | tokens = [] 31 | for prompt in prompts: 32 | tokens.append(clip.tokenize(prompt, context_length=77).numpy()) 33 | 34 | batch = torch.tensor(np.array(tokens)).cuda() 35 | with torch.no_grad(): 36 | text_embedding = clip_m.encode_text(batch.flatten(0, 1)).float() 37 | return text_embedding 38 | 39 | 40 | def max_similarity(clip_m, category, val_seen_categories): 41 | categories = val_seen_categories.copy() 42 | if category in categories: 43 | categories.remove(category) 44 | 45 | prompt = PROMPT.format(category=category) 46 | text_embedding = clip_embeddings(clip_m, [prompt] + categories) 47 | return ( 48 | torch.cosine_similarity(text_embedding[0].unsqueeze(0), text_embedding[1:]) 49 | .max() 50 | .item() 51 | ) 52 | 53 | 54 | def semantic_failures(input_path): 55 | clip_m, preprocess = clip.load("RN50", device="cuda") 56 | 57 | records = load_json(input_path) 58 | 59 | failures = 0 60 | defaultdict(int) 61 | for k, v in records.items(): 62 | failures += 1 - v["success"] 63 | if not v["success"]: 64 | nearest_objects = v["failure_modes.objects_within_2m"].split(",") 65 | category = v["target"] 66 | if category in nearest_objects: 67 | nearest_objects.remove(category) 68 | 69 | 70 | def failure_metrics(input_path, output_path): 71 | records = load_json(input_path) 72 | 73 | failures = 0 74 | failure_modes = defaultdict(int) 75 | for k, v in records.items(): 76 | failures += 1 - v["success"] 77 | if not v["success"]: 78 | for kk in v.keys(): 79 | if kk in [ 80 | "failure_modes.recognition", 81 | "failure_modes.exploration", 82 | "failure_modes.last_mile_nav", 83 | "failure_modes.stop_failure", 84 | ]: 85 | failure_modes[kk] += v[kk] 86 | 87 | failure_modes = {k: v / failures for k, v in failure_modes.items()} 88 | labels = list(failure_modes.keys()) 89 | metrics = list(failure_modes.values()) 90 | 91 | colors = sns.color_palette("pastel")[0:5] 92 | 93 | # create pie chart 94 | fig, ax = plt.subplots(figsize=(6, 3)) 95 | ax.pie(metrics, labels=labels, colors=colors, autopct="%.0f%%") 96 | fig.savefig(output_path, bbox_inches="tight") 97 | 98 | print(failure_modes) 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument( 104 | "--input-path", type=str, help="Path to the meta csv file", required=True 105 | ) 106 | 
parser.add_argument( 107 | "--output-path", type=str, help="Path to the output file", required=True 108 | ) 109 | args = parser.parse_args() 110 | # postprocess_meta(args.input_path, args.output_path) 111 | failure_metrics(args.input_path, args.output_path) 112 | -------------------------------------------------------------------------------- /config/experiments/transformer_rl_finetune.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - objectnav_stretch_hm3d 8 | - override /habitat/task/lab_sensors: 9 | - clip_objectgoal_sensor 10 | - step_id_sensor 11 | - override /habitat/task/measurements: 12 | - distance_to_goal 13 | - success 14 | - spl 15 | - soft_spl 16 | - collisions 17 | - collision_penalty 18 | - _self_ 19 | 20 | habitat: 21 | environment: 22 | iterator_options: 23 | max_scene_repeat_steps: 50000 24 | task: 25 | success_reward: 5.0 26 | slack_reward: -1e-3 27 | reward_measure: "collision_penalty" 28 | lab_sensors: 29 | clip_objectgoal_sensor: 30 | cache: data/text_embeddings/siglip.pkl 31 | measurements: 32 | success: 33 | success_distance: 0.25 34 | distance_to_goal: 35 | type: OVONDistanceToGoal 36 | dataset: 37 | type: "OVON-v1" 38 | split: train 39 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 40 | simulator: 41 | type: "OVONSim-v0" 42 | navmesh_settings: 43 | agent_max_climb: 0.1 44 | cell_height: 0.05 45 | 46 | habitat_baselines: 47 | torch_gpu_id: 0 48 | tensorboard_dir: "tb" 49 | video_dir: "video_dir" 50 | test_episode_count: -1 51 | eval_ckpt_path_dir: "data/new_checkpoints" 52 | num_environments: 32 53 | checkpoint_folder: "data/new_checkpoints" 54 | trainer_name: "ver_transformer" 55 | num_updates: -1 56 | total_num_steps: 150000000 57 | log_interval: 10 58 | num_checkpoints: 50 59 | # Force PyTorch to be single threaded as 60 | # this improves performance considerably 61 | force_torch_single_threaded: True 62 | 63 | eval: 64 | split: "val" 65 | 66 | rl: 67 | 68 | policy: 69 | name: "OVONTransformerPolicy" 70 | backbone: "siglip" 71 | fusion_type: "concat" 72 | use_vis_query: True 73 | use_residual: True 74 | residual_vision: True 75 | rgb_only: False 76 | transformer_config: 77 | model_name: "llama" 78 | n_layers: 4 79 | n_heads: 8 80 | n_hidden: 512 81 | n_mlp_hidden: 1024 82 | max_context_length: 100 83 | shuffle_pos_id_for_update: True 84 | 85 | finetune: 86 | enabled: True 87 | lr: 1.5e-5 88 | start_actor_warmup_at: 750 89 | start_actor_update_at: 1000 90 | start_critic_warmup_at: 500 91 | start_critic_update_at: 1000 92 | 93 | ppo: 94 | # ppo params 95 | clip_param: 0.2 96 | ppo_epoch: 1 97 | num_mini_batch: 2 98 | value_loss_coef: 0.5 99 | entropy_coef: 0.01 100 | lr: 2.5e-4 101 | eps: 1e-5 102 | max_grad_norm: 0.2 103 | num_steps: 100 104 | use_gae: True 105 | gamma: 0.99 106 | tau: 0.95 107 | use_linear_clip_decay: False 108 | use_linear_lr_decay: True 109 | reward_window_size: 50 110 | 111 | use_normalized_advantage: False 112 | 113 | hidden_size: 512 114 | 115 | ddppo: 116 | sync_frac: 0.6 117 | # The PyTorch distributed backend to use 118 | distrib_backend: NCCL 119 | # Visual encoder backbone 120 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 121 | # Initialize with pretrained weights 122 | pretrained: False 123 | # Initialize just the visual encoder backbone with pretrained weights 124 | pretrained_encoder: False 
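      # Example (placeholder path, not shipped with the repo): to warm-start this
      # finetuning run from an earlier OVON checkpoint, one would typically set
      #   pretrained: True
      #   pretrained_weights: data/new_checkpoints/ckpt.XX.pth
      # while leaving train_encoder: False so the SigLIP backbone stays frozen.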
125 | # Whether or not the visual encoder backbone will be trained. 126 | train_encoder: False 127 | # Whether or not to reset the critic linear layer 128 | reset_critic: True 129 | -------------------------------------------------------------------------------- /config/experiments/transformer_dagger.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - /habitat_baselines: habitat_baselines_rl_config_base 5 | - /habitat_baselines/rl/policy/obs_transforms: 6 | - resize 7 | - relabel_teacher_actions 8 | - /habitat/task/measurements: 9 | - frontier_exploration_map 10 | - objectnav_stretch_hm3d 11 | - override /habitat/task/lab_sensors: 12 | - objnav_explorer 13 | - clip_objectgoal_sensor 14 | - step_id_sensor 15 | - _self_ 16 | 17 | habitat: 18 | environment: 19 | iterator_options: 20 | max_scene_repeat_steps: 50000 21 | task: 22 | success_reward: 2.5 23 | slack_reward: -1e-3 24 | lab_sensors: 25 | objnav_explorer: 26 | map_resolution: 128 27 | beeline_dist_thresh: 3.5 28 | visibility_dist: 3.0 29 | area_thresh: 1.5 30 | success_distance: 0.25 31 | fov: 42 32 | clip_objectgoal_sensor: 33 | cache: data/text_embeddings/siglip.pkl 34 | measurements: 35 | success: 36 | success_distance: 0.25 37 | distance_to_goal: 38 | type: OVONDistanceToGoal 39 | dataset: 40 | type: "OVON-v1" 41 | split: train 42 | data_path: data/datasets/ovon/hm3d/v1/{split}/{split}.json.gz 43 | simulator: 44 | type: "OVONSim-v0" 45 | navmesh_settings: 46 | agent_max_climb: 0.1 47 | cell_height: 0.05 48 | habitat_sim_v0: 49 | allow_sliding: True 50 | 51 | habitat_baselines: 52 | torch_gpu_id: 0 53 | tensorboard_dir: "tb" 54 | video_dir: "video_dir" 55 | test_episode_count: -1 56 | eval_ckpt_path_dir: "data/new_checkpoints" 57 | num_environments: 32 58 | checkpoint_folder: "data/new_checkpoints" 59 | trainer_name: "ver_dagger" 60 | num_updates: -1 61 | total_num_steps: 150000000 62 | log_interval: 10 63 | num_checkpoints: 50 64 | # Force PyTorch to be single threaded as 65 | # this improves performance considerably 66 | force_torch_single_threaded: True 67 | 68 | eval: 69 | split: "val_seen" 70 | 71 | rl: 72 | 73 | policy: 74 | name: "OVONTransformerPolicy" 75 | backbone: "siglip" 76 | fusion_type: "concat" 77 | use_vis_query: True 78 | use_residual: True 79 | residual_vision: True 80 | rgb_only: False 81 | obs_transforms: 82 | relabel_teacher_actions: 83 | teacher_label: "objnav_explorer" 84 | transformer_config: 85 | model_name: "llama" 86 | n_layers: 4 87 | n_heads: 8 88 | n_hidden: 512 89 | n_mlp_hidden: 1024 90 | max_context_length: 100 91 | shuffle_pos_id_for_update: True 92 | 93 | ppo: 94 | # ppo params 95 | clip_param: 0.2 96 | ppo_epoch: 1 97 | num_mini_batch: 2 98 | value_loss_coef: 0.5 99 | entropy_coef: 0.01 100 | lr: 2.5e-4 101 | eps: 1e-5 102 | max_grad_norm: 0.2 103 | num_steps: 100 104 | use_gae: True 105 | gamma: 0.99 106 | tau: 0.95 107 | use_linear_clip_decay: False 108 | use_linear_lr_decay: False 109 | reward_window_size: 50 110 | 111 | use_normalized_advantage: False 112 | 113 | hidden_size: 512 114 | 115 | ddppo: 116 | sync_frac: 0.6 117 | # The PyTorch distributed backend to use 118 | distrib_backend: NCCL 119 | # Visual encoder backbone 120 | pretrained_weights: data/ddppo-models/gibson-2plus-resnet50.pth 121 | # Initialize with pretrained weights 122 | pretrained: False 123 | # Initialize just the visual encoder backbone with pretrained weights 124 | pretrained_encoder: False 125 | # Whether or not the visual encoder 
backbone will be trained. 126 | train_encoder: False 127 | # Whether or not to reset the critic linear layer 128 | reset_critic: True 129 | -------------------------------------------------------------------------------- /ovon/utils/shuffle_episodes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | from collections import defaultdict 5 | 6 | from ovon.utils.utils import load_dataset, load_json, write_dataset 7 | 8 | 9 | def shuffle_episodes( 10 | dataset_path: str, 11 | output_path: str, 12 | meta_path: str = "data/hm3d_meta/val_splits.json", 13 | ): 14 | category_per_splits = load_json(meta_path) 15 | splits = list(category_per_splits.keys()) 16 | print(splits) 17 | 18 | scenes = glob.glob(os.path.join(dataset_path, splits[0], "content/*json.gz")) 19 | 20 | # make directories 21 | for split in splits: 22 | os.makedirs(os.path.join(output_path, split, "content"), exist_ok=True) 23 | 24 | for scene in scenes: 25 | scene_id = scene.split("/")[-1] 26 | scene_name = scene_id.split(".")[0] 27 | 28 | goals_by_category = {} 29 | episodes_by_category = defaultdict(list) 30 | for split in splits: 31 | path = os.path.join(dataset_path, split, "content", scene_id) 32 | dataset = load_dataset(path) 33 | 34 | for goal_key, goal in dataset["goals_by_category"].items(): 35 | goals_by_category[goal_key] = goal 36 | 37 | for episode in dataset["episodes"]: 38 | episodes_by_category[episode["object_category"]].append(episode) 39 | 40 | for split in splits: 41 | path = os.path.join(dataset_path, split, "content", scene_id) 42 | dataset = load_dataset(path) 43 | 44 | goals_before = len(dataset["goals_by_category"].keys()) 45 | episodes_before = len(dataset["episodes"]) 46 | 47 | dataset["goals_by_category"] = {} 48 | dataset["episodes"] = [] 49 | 50 | for key in category_per_splits[split]: 51 | g_key = "{}.basis.glb_{}".format(scene_name, key) 52 | if goals_by_category.get(g_key) is None: 53 | continue 54 | dataset["goals_by_category"][g_key] = goals_by_category[g_key] 55 | dataset["episodes"].extend(episodes_by_category[key]) 56 | print( 57 | "Split: {}, # of goals: {}/{}, # of episodes: {}/{}".format( 58 | split, 59 | len(dataset["goals_by_category"].keys()), 60 | goals_before, 61 | len(dataset["episodes"]), 62 | episodes_before, 63 | ) 64 | ) 65 | 66 | op = os.path.join(output_path, split, "content", scene_id) 67 | print("Output: {}".format(op)) 68 | write_dataset(dataset, op) 69 | 70 | print("\n") 71 | for split in splits: 72 | files = glob.glob(os.path.join(output_path, split, "content/*json.gz")) 73 | 74 | goals = [] 75 | episodes = 0 76 | for f in files: 77 | dataset = load_dataset(f) 78 | episodes += len(dataset["episodes"]) 79 | 80 | goal_keys = [k.split("_")[-1] for k in dataset["goals_by_category"].keys()] 81 | goals.extend(goal_keys) 82 | 83 | diff = set(category_per_splits[split]).difference(set(goals)) 84 | print( 85 | "Validating Split: {}, # of goals: {}, # of episodes: {}, Difference: {}" 86 | .format(split, len(set(goals)), episodes, len(diff)) 87 | ) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument( 93 | "--dataset_path", type=str, default="data/hm3d_meta/val_splits.json" 94 | ) 95 | parser.add_argument( 96 | "--output_path", type=str, default="data/hm3d_meta/val_splits.json" 97 | ) 98 | 99 | args = parser.parse_args() 100 | 101 | shuffle_episodes(args.dataset_path, args.output_path) 102 | 
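A minimal usage sketch for the script above (the dataset and output paths below are placeholders, not paths the repo guarantees to exist): `dataset_path` is expected to contain one sub-directory per validation split, each with `content/*.json.gz` episode files, and the re-partitioned splits are written under `output_path` using the category-to-split assignment in `meta_path`.

```python
# Placeholder paths for illustration; adjust to your local data layout.
from ovon.utils.shuffle_episodes import shuffle_episodes

shuffle_episodes(
    dataset_path="data/datasets/ovon/hm3d/v1",         # root holding val_seen/, val_unseen_easy/, ... splits
    output_path="data/datasets/ovon/hm3d/v1_resplit",  # re-partitioned episodes are written here
    meta_path="data/hm3d_meta/val_splits.json",        # category -> split assignment
)
```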
-------------------------------------------------------------------------------- /ovon/utils/analysis/Issues_with_provided_mappings.csv: -------------------------------------------------------------------------------- 1 | "Raw","Corrected " 2 | "All Raw that map to “lamp”", 3 | "lamp","lamp" 4 | "lamp ","lamp" 5 | "lamps","lamp" 6 | "lmap","lamp" 7 | "table lamp","lamp" 8 | " table lamp","lamp" 9 | "unknown/ probably lamp","lamp" 10 | "łamp","lamp" 11 | , 12 | "All Raw that map to “table lamp”", 13 | "unknown/ probably table lamp","table lamp" 14 | , 15 | "All Raw that map to “plant”", 16 | "dacorative plant","plant" 17 | "flower","plant" 18 | "flower ","plant" 19 | "FLOWER","plant" 20 | " flower","plant" 21 | "flowers","plant" 22 | "ornament plant","plant" 23 | "ornamental plant","plant" 24 | "plant","plant" 25 | "plants","plant" 26 | "table plant","plant" 27 | "twigs in vase","plant" 28 | "unknown/ probably decorative plant","plant" 29 | "vegetation","plant" 30 | , 31 | "All Raw that map to “flower”", 32 | "ornament flower","flower" 33 | "vase with flower","flower" 34 | , 35 | "All Raw that map to “couch”", 36 | "couch","couch" 37 | "sofa","couch" 38 | , 39 | "All Raw that map to “sofa”", 40 | "unknown/ probably sofa","sofa" 41 | , 42 | "All Raw that map to “box”", 43 | "box","box" 44 | "boxes","box" 45 | "tabletop box","box" 46 | "tea boxes ","box" 47 | "tea box","box" 48 | "tea boxes","box" 49 | "tin","box" 50 | "towel box","box" 51 | "unknown/ probably box","box" 52 | "unknown/ probably desk or box","box" 53 | , 54 | "All Raw that map to “boxes”", 55 | "pile of boxes","boxes" 56 | , 57 | "All Raw that map to “refrigerator”", 58 | "fridge","refrigerator" 59 | "refridgerator","refrigerator" 60 | "refrigearator","refrigerator" 61 | "refrigerator","refrigerator" 62 | "wine refrigerator","refrigerator" 63 | , 64 | "All Raw that map to “fridge”", 65 | "unknown/ probably fridge","fridge" 66 | , 67 | "All Raw that map to “picture”", 68 | "art","picture" 69 | "framed picture","picture" 70 | "image","picture" 71 | "pictrure","picture" 72 | "picture","picture" 73 | "picture ","picture" 74 | "pictures","picture" 75 | "picure","picture" 76 | "poster","picture" 77 | "unknown picture/window","picture" 78 | "unknown/ probably framed picture","picture" 79 | "wall decoration","picture" 80 | "wall frame","picture" 81 | , 82 | "All Raw that map to “poster”", 83 | "poster figure","poster" 84 | , 85 | "All Raw that map to “bedside table”", 86 | "bedside table","nightstand" 87 | "nighstand","nightstand" 88 | "night stand","nightstand" 89 | "night table","nightstand" 90 | "nightsand","nightstand" 91 | "nightstand","nightstand" 92 | "nigtstand","nightstand" 93 | , 94 | "All Raw that map to “bedside table”", 95 | "beside table","bedside table" 96 | , 97 | "All Raw that map to “paper”", 98 | "paper","paper" 99 | "paper ","paper" 100 | " paper","paper" 101 | "papers","paper" 102 | , 103 | "All Raw that map to “papers”", 104 | "pile of papers","papers" 105 | "unknown/ probably letters","papers" 106 | , 107 | "All Raw that map to “ironing board”", 108 | "ironing board","ironing board" 109 | "ironing board ","ironing board" 110 | , 111 | "All Raw that map to “iron board”", 112 | "iron board","iron board" 113 | , 114 | "All Raw that map to “cup”", 115 | "cup","cup" 116 | "cups","cup" 117 | "toothbrush cup","cup" 118 | "toothbrush holder","cup" 119 | "unknown /probably cup","cup" 120 | "unknown/ proably cup","cup" 121 | "unknown/ probably cups","cup" 122 | , 123 | "All Raw that map to “cups”", 124 | "pile of 
cups","cups" 125 | , 126 | "All Raw that map to “drawer”", 127 | "drawer","drawer" 128 | "drawers","drawer" 129 | "wardrobe drawer","drawer" 130 | , 131 | "All Raw that map to “drawers”", 132 | "office drawer","drawers" 133 | "organizer drawers","drawers" 134 | , 135 | "All Raw that map to “paper towel”", 136 | "paper towel","paper towel" 137 | "paper towels","paper towel" 138 | , 139 | "All Raw that map to “paper towels”", 140 | "unknown/ probably paper towel dispenser","paper towels" 141 | "unknown/ probably paper towel holder","paper towels" 142 | , 143 | "All Raw that map to “weight”", 144 | "weight","weight" 145 | "weights","weight" 146 | , 147 | "All Raw that map to “weights”", 148 | "workout weight","weights" 149 | -------------------------------------------------------------------------------- /ovon/models/encoders/cross_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class CrossAttention(nn.Module): 6 | def __init__( 7 | self, 8 | x1_dim: int, 9 | x2_dim: int, 10 | num_heads: int, 11 | use_vis_query: bool, 12 | use_residual: bool, 13 | residual_vision: bool, 14 | embed_dim: int = None, 15 | dropout: float = 0.1, 16 | ) -> None: 17 | """ 18 | Meant for fusion of two different modalities, x1 being language embeddings and 19 | x2 being visual embeddings. 20 | 21 | Args: 22 | x1_dim: Dimension of the first input (language) 23 | x2_dim: Dimension of the second input (visual) 24 | num_heads: Number of heads for the multihead attention 25 | use_vis_query: Whether to use visual encoding as the query and value 26 | use_residual: Whether to use the residual connection 27 | embed_dim: Dimension of the embedding space 28 | dropout: Dropout rate for the multihead attention 29 | """ 30 | super(CrossAttention, self).__init__() 31 | 32 | embed_dim = embed_dim or num_heads * 64 33 | 34 | if x1_dim == x2_dim == embed_dim: 35 | self.proj1 = nn.Identity() 36 | self.proj2 = nn.Identity() 37 | else: 38 | # Linear layers to project x1 and x2 into embedding space 39 | self.proj1 = nn.Linear(x1_dim, embed_dim) 40 | self.proj2 = nn.Linear(x2_dim, embed_dim) 41 | 42 | # Initialize with Xavier initialization 43 | nn.init.xavier_uniform_(self.proj1.weight) 44 | nn.init.xavier_uniform_(self.proj2.weight) 45 | nn.init.zeros_(self.proj1.bias) 46 | nn.init.zeros_(self.proj2.bias) 47 | 48 | self.multihead_attn = nn.MultiheadAttention( 49 | embed_dim=embed_dim, num_heads=num_heads, dropout=dropout 50 | ) 51 | 52 | # Initialize weights and biases in MultiheadAttention 53 | for name, param in self.multihead_attn.named_parameters(): 54 | if "weight" in name: 55 | nn.init.xavier_uniform_(param) 56 | elif "bias" in name: 57 | nn.init.zeros_(param) 58 | 59 | self.norm = nn.LayerNorm(embed_dim) 60 | self.use_vis_query = use_vis_query 61 | self.use_residual = use_residual 62 | self.residual_vision = residual_vision 63 | self.output_size = embed_dim 64 | 65 | def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: 66 | """ 67 | Args: 68 | x1: [batch_size, x1_dim] tensor (language) 69 | x2: [batch_size, x2_dim] tensor (visual) 70 | 71 | Returns: 72 | [batch_size, embed_dim] tensor 73 | """ 74 | 75 | # Project x1 and x2 into the embedding space 76 | x1_proj = self.proj1(x1) 77 | x2_proj = self.proj2(x2) 78 | 79 | # Reshape the tensors to be [seq_len, batch_size, embed_dim], where seq_len is 1 80 | x1_proj = x1_proj.unsqueeze(0) 81 | x2_proj = x2_proj.unsqueeze(0) 82 | 83 | # Perform the cross-attention 
calculation based on use_vis_query 84 | if self.use_vis_query: 85 | query = x2_proj 86 | key = x1_proj 87 | value = x1_proj 88 | else: 89 | query = x1_proj 90 | key = x2_proj 91 | value = x2_proj 92 | 93 | # output: [1, batch_size, embed_dim] 94 | output, _ = self.multihead_attn(query=query, key=key, value=value) 95 | 96 | if self.use_residual: 97 | # Add residual connection 98 | if self.residual_vision: 99 | output += x2_proj 100 | else: 101 | output += query 102 | 103 | # Apply Layer Normalization 104 | # output: [1, batch_size, embed_dim] 105 | output = self.norm(output) 106 | 107 | # Squeeze the sequence length dimension 108 | # output: [batch_size, embed_dim] 109 | output = output.squeeze(0) 110 | 111 | return output 112 | -------------------------------------------------------------------------------- /ovon/measurements/imagenav.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any 2 | 3 | import numpy as np 4 | import quaternion 5 | from habitat.core.embodied_task import EmbodiedTask, Measure 6 | from habitat.core.registry import registry 7 | from habitat.core.simulator import Simulator 8 | from habitat.tasks.nav.nav import NavigationEpisode, Success 9 | from habitat.utils.geometry_utils import ( 10 | angle_between_quaternions, 11 | quaternion_from_coeff, 12 | ) 13 | 14 | if TYPE_CHECKING: 15 | from omegaconf import DictConfig 16 | 17 | 18 | @registry.register_measure 19 | class AngleToGoal(Measure): 20 | """The measure calculates an angle towards the goal. Note: this measure is 21 | only valid for single goal tasks (e.g., ImageNav) 22 | """ 23 | 24 | cls_uuid: str = "angle_to_goal" 25 | 26 | def __init__(self, sim: Simulator, *args: Any, **kwargs: Any): 27 | super().__init__() 28 | self._sim = sim 29 | 30 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 31 | return self.cls_uuid 32 | 33 | def reset_metric(self, episode, *args: Any, **kwargs: Any): 34 | self._metric = None 35 | self.update_metric(episode=episode, *args, **kwargs) # type: ignore 36 | 37 | def update_metric(self, episode: NavigationEpisode, *args: Any, **kwargs: Any): 38 | current_angle = self._sim.get_agent_state().rotation 39 | if not isinstance(current_angle, quaternion.quaternion): 40 | current_angle = quaternion_from_coeff(current_angle) 41 | 42 | goal_angle = episode.goals[0].rotation 43 | if not isinstance(goal_angle, quaternion.quaternion): 44 | goal_angle = quaternion_from_coeff(goal_angle) 45 | 46 | self._metric = angle_between_quaternions(current_angle, goal_angle) 47 | 48 | 49 | @registry.register_measure 50 | class AngleSuccess(Measure): 51 | """Weather or not the agent is within an angle tolerance.""" 52 | 53 | cls_uuid: str = "angle_success" 54 | 55 | def __init__(self, config: "DictConfig", *args: Any, **kwargs: Any): 56 | self._config = config 57 | 58 | super().__init__() 59 | 60 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 61 | return self.cls_uuid 62 | 63 | def reset_metric(self, task: EmbodiedTask, *args: Any, **kwargs: Any): 64 | task.measurements.check_measure_dependencies( 65 | self.uuid, [Success.cls_uuid, AngleToGoal.cls_uuid] 66 | ) 67 | self.update_metric(task=task, *args, **kwargs) # type: ignore 68 | 69 | def update_metric(self, task: EmbodiedTask, *args: Any, **kwargs: Any): 70 | success = task.measurements.measures[Success.cls_uuid].get_metric() 71 | angle_to_goal = task.measurements.measures[AngleToGoal.cls_uuid].get_metric() 72 | 73 | if success and np.rad2deg(angle_to_goal) < 
self._config.success_angle: 74 | self._metric = 1.0 75 | else: 76 | self._metric = 0.0 77 | 78 | 79 | @registry.register_measure 80 | class AgentPosition(Measure): 81 | """The measure calculates current position of agent""" 82 | 83 | cls_uuid: str = "agent_position" 84 | 85 | def __init__(self, sim: Simulator, *args: Any, **kwargs: Any): 86 | super().__init__() 87 | self._sim = sim 88 | 89 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 90 | return self.cls_uuid 91 | 92 | def reset_metric(self, *args: Any, **kwargs: Any): 93 | self._metric = None 94 | self.update_metric(*args, **kwargs) # type: ignore 95 | 96 | def update_metric(self, *args: Any, **kwargs: Any): 97 | self._metric = self._sim.get_agent_state().position 98 | 99 | 100 | @registry.register_measure 101 | class AgentRotation(Measure): 102 | """The measure calculates current position of agent""" 103 | 104 | cls_uuid: str = "agent_rotation" 105 | 106 | def __init__(self, sim: Simulator, *args: Any, **kwargs: Any): 107 | super().__init__() 108 | self._sim = sim 109 | 110 | def _get_uuid(self, *args: Any, **kwargs: Any) -> str: 111 | return self.cls_uuid 112 | 113 | def reset_metric(self, *args: Any, **kwargs: Any): 114 | self._metric = None 115 | self.update_metric(*args, **kwargs) # type: ignore 116 | 117 | def update_metric(self, *args: Any, **kwargs: Any): 118 | self._metric = self._sim.get_agent_state().rotation 119 | -------------------------------------------------------------------------------- /ovon/utils/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import json 4 | import os 5 | import pickle 6 | from collections import defaultdict 7 | 8 | import numpy as np 9 | import torch 10 | from habitat.utils.visualizations import maps 11 | from PIL import Image 12 | from tqdm import tqdm 13 | 14 | from ovon.models.encoders.resnet_gn import ResNet 15 | 16 | 17 | def write_json(data, path): 18 | with open(path, "w") as file: 19 | file.write(json.dumps(data)) 20 | 21 | 22 | def load_json(path): 23 | file = open(path, "r") 24 | data = json.loads(file.read()) 25 | return data 26 | 27 | 28 | def write_txt(data, path): 29 | with open(path, "w") as file: 30 | file.write("\n".join(data)) 31 | 32 | 33 | def save_image(img, file_name): 34 | im = Image.fromarray(img) 35 | im.save(file_name) 36 | 37 | 38 | def load_dataset(path): 39 | with gzip.open(path, "rt") as file: 40 | data = json.loads(file.read(), encoding="utf-8") 41 | return data 42 | 43 | 44 | def save_pickle(data, path): 45 | file = open(path, "wb") 46 | data = pickle.dump(data, file) 47 | 48 | 49 | def load_pickle(path): 50 | file = open(path, "rb") 51 | data = pickle.load(file) 52 | return data 53 | 54 | 55 | def write_dataset(data, path): 56 | with gzip.open(path, "wt") as f: 57 | json.dump(data, f) 58 | 59 | 60 | def is_on_same_floor(height, ref_floor_height, ceiling_height=0.5): 61 | return ( 62 | (ref_floor_height - ceiling_height) 63 | <= height 64 | < (ref_floor_height + ceiling_height) 65 | ) 66 | 67 | 68 | def draw_point(sim, top_down_map, position, point_type, point_padding=2): 69 | t_x, t_y = maps.to_grid( 70 | position[2], 71 | position[0], 72 | (top_down_map.shape[0], top_down_map.shape[1]), 73 | sim=sim, 74 | ) 75 | top_down_map[ 76 | t_x - point_padding : t_x + point_padding + 1, 77 | t_y - point_padding : t_y + point_padding + 1, 78 | ] = point_type 79 | return top_down_map 80 | 81 | 82 | def draw_bounding_box( 83 | sim, top_down_map, goal_object_id, ref_floor_height, 
line_thickness=4 84 | ): 85 | sem_scene = sim.semantic_annotations() 86 | object_id = goal_object_id 87 | 88 | sem_obj = None 89 | for object in sem_scene.objects: 90 | if object.id == object_id: 91 | sem_obj = object 92 | break 93 | 94 | center = sem_obj.aabb.center 95 | x_len, _, z_len = sem_obj.aabb.sizes / 2.0 96 | # Nodes to draw rectangle 97 | corners = [ 98 | center + np.array([x, 0, z]) 99 | for x, z in [ 100 | (-x_len, -z_len), 101 | (-x_len, z_len), 102 | (x_len, z_len), 103 | (x_len, -z_len), 104 | (-x_len, -z_len), 105 | ] 106 | if is_on_same_floor(center[1], ref_floor_height=ref_floor_height) 107 | ] 108 | 109 | map_corners = [ 110 | maps.to_grid( 111 | p[2], 112 | p[0], 113 | ( 114 | top_down_map.shape[0], 115 | top_down_map.shape[1], 116 | ), 117 | sim=sim, 118 | ) 119 | for p in corners 120 | ] 121 | 122 | maps.draw_path( 123 | top_down_map, 124 | map_corners, 125 | maps.MAP_TARGET_BOUNDING_BOX, 126 | line_thickness, 127 | ) 128 | return top_down_map 129 | 130 | 131 | def load_encoder(encoder, path): 132 | assert os.path.exists(path) 133 | if isinstance(encoder.backbone, ResNet): 134 | state_dict = torch.load(path, map_location="cpu")["teacher"] 135 | state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} 136 | return encoder.load_state_dict(state_dict=state_dict, strict=False) 137 | else: 138 | raise ValueError("unknown encoder backbone") 139 | 140 | 141 | def count_episodes(path): 142 | files = glob.glob(os.path.join(path, "*.json.gz")) 143 | count = 0 144 | categories = defaultdict(int) 145 | for f in tqdm(files): 146 | dataset = load_dataset(f) 147 | for episode in dataset["episodes"]: 148 | categories[episode["object_category"]] += 1 149 | count += len(dataset["episodes"]) 150 | print("Total episodes: {}".format(count)) 151 | print("Categories: {}".format(categories)) 152 | print("Total categories: {}".format(len(categories))) 153 | return count, categories 154 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 
3 | HM3D-OVON: A Dataset and Benchmark for Open-Vocabulary Object Goal Navigation
4 | 
5 | Naoki Yokoyama,
6 | Ram Ramrakhya,
7 | Abhishek Das,
8 | Dhruv Batra,
9 | Sehoon Ha
10 | 
11 | 
12 | Project Website
13 | 
14 | 

15 | 16 | ## Overview 17 | 18 | We present the Habitat-Matterport 3D Open Vocabulary Object Goal Navigation dataset (HM3D-OVON), a large-scale benchmark that broadens the scope and semantic range of prior Object Goal Navigation (ObjectNav) benchmarks. Leveraging the HM3DSem dataset, HM3D-OVON incorporates over 15k annotated instances of household objects across 379 distinct categories, derived from photo-realistic 3D scans of real-world environments. In contrast to earlier ObjectNav datasets, which limit goal objects to a predefined set of 6-21 categories, HM3D-OVON facilitates the training and evaluation of models with an open set of goals defined through free-form language at test time. Through this open-vocabulary formulation, HM3D-OVON encourages progress towards learning visuo-semantic navigation behaviors that are capable of searching for any object specified by text in an open-vocabulary manner. Additionally, we systematically evaluate and compare several different types of approaches on HM3D-OVON. We find that HM3D-OVON can be used to train an open-vocabulary ObjectNav agent that achieves higher performance and is more robust to localization and actuation noise than the state-of-the-art ObjectNav approach. We hope that our benchmark and baseline results will drive interest in developing embodied agents that can navigate real-world spaces to find household objects specified through free-form language, taking a step towards more flexible and human-like semantic visual navigation. Videos available at: naoki.io/ovon. 19 | 20 | ## :hammer_and_wrench: Installation 21 | 22 | ### Getting Started 23 | 24 | Create the conda environment and install all of the dependencies. Mamba is recommended for faster installation: 25 | ```bash 26 | conda_env_name=ovon 27 | mamba create -n $conda_env_name python=3.7 cmake=3.14.0 -y 28 | mamba install -n $conda_env_name \ 29 | habitat-sim=0.2.3 headless pytorch=1.12.1 cudatoolkit=11.3 \ 30 | -c pytorch -c nvidia -c conda-forge -c aihabitat -y 31 | 32 | # Install this repo as a package 33 | mamba activate $conda_env_name 34 | pip install -e . 35 | 36 | # Install frontier_exploration 37 | cd frontier_exploration && pip install -e . && cd .. 
38 | 39 | # Install habitat-lab 40 | git clone --branch v0.2.3 git@github.com:facebookresearch/habitat-lab.git 41 | cd habitat-lab 42 | pip install -e habitat-lab 43 | pip install -e habitat-baselines 44 | 45 | pip install ftfy regex tqdm GPUtil trimesh seaborn timm scikit-learn einops transformers 46 | pip install git+https://github.com/openai/CLIP.git 47 | ``` 48 | ## :dart: Downloading the datasets 49 | First, set the following variables during installation (don't need to put in .bashrc): 50 | ```bash 51 | MATTERPORT_TOKEN_ID= 52 | MATTERPORT_TOKEN_SECRET= 53 | DATA_DIR= 54 | ``` 55 | 56 | ### Clone and install habitat-lab, then download datasets 57 | ```bash 58 | # Download HM3D 3D scans (scenes_dataset) 59 | python -m habitat_sim.utils.datasets_download \ 60 | --username $MATTERPORT_TOKEN_ID --password $MATTERPORT_TOKEN_SECRET \ 61 | --uids hm3d_train_v0.2 \ 62 | --data-path $DATA_DIR && 63 | python -m habitat_sim.utils.datasets_download \ 64 | --username $MATTERPORT_TOKEN_ID --password $MATTERPORT_TOKEN_SECRET \ 65 | --uids hm3d_val_v0.2 \ 66 | --data-path $DATA_DIR 67 | ``` 68 | 69 | The OVON navigation episodes can be found here: https://huggingface.co/datasets/nyokoyama/hm3d_ovon/ 70 | The tar.gz file should be decompressed in `data/datasets/ovon/`, such that the `hm3d` directory is located at `data/datasets/ovon/hm3d/`. 71 | 72 | ## :weight_lifting: Downloading pre-trained weights 73 | The weights for the DagRL policy can be downloaded from the following link: 74 | - `dagrl.pth`: https://drive.google.com/drive/folders/1U-tnPYQa81JbYHSlW1nyjiXOK8cE2Ki8?usp=sharing 75 | 76 | ## :arrow_forward: Evaluation within Habitat 77 | 78 | Run the following to evaluate: 79 | ```bash 80 | python -m ovon.run \ 81 | --run-type eval \ 82 | --exp-config config/experiments/dagger_objectnav.yaml \ 83 | habitat_baselines.eval_ckpt_path_dir= 84 | ``` 85 | 86 | ## :rocket: Training 87 | 88 | Run the following to train: 89 | ```bash 90 | python -m ovon.run \ 91 | --run-type train \ 92 | --exp-config config/experiments/dagger_objectnav.yaml 93 | ``` 94 | -------------------------------------------------------------------------------- /ovon/trainers/dagger_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from habitat import logger 3 | from habitat.config import read_write 4 | from habitat_baselines.common.baseline_registry import baseline_registry 5 | from habitat_baselines.common.obs_transformers import ( 6 | apply_obs_transforms_obs_space, 7 | get_active_obs_transforms, 8 | ) 9 | from habitat_baselines.rl.ddppo.policy import PointNavResNetNet # noqa: F401. 
10 | from omegaconf import DictConfig, open_dict 11 | 12 | from ovon.algos.dagger import DAgger, DAggerPolicy, DDPDAgger 13 | from ovon.trainers.ver_transformer_trainer import VERTransformerTrainer 14 | 15 | try: 16 | torch.backends.cudnn.allow_tf32 = True 17 | torch.backends.cuda.matmul.allow_tf32 = True 18 | except AttributeError: 19 | pass 20 | 21 | 22 | @baseline_registry.register_trainer(name="ver_dagger") 23 | @baseline_registry.register_trainer(name="ver_il") 24 | class VERDAggerTrainer(VERTransformerTrainer): 25 | def __init__(self, config: DictConfig): 26 | with open_dict(config.habitat_baselines.rl.policy): # allow new keys 27 | with read_write(config.habitat_baselines.rl.policy): # allow write 28 | if hasattr(config.habitat_baselines.rl.policy, "original_name"): 29 | original_name = config.habitat_baselines.rl.policy.original_name 30 | else: 31 | original_name = config.habitat_baselines.rl.policy.name 32 | config.habitat_baselines.rl.policy["original_name"] = ( 33 | # add new key "original_name" 34 | original_name 35 | ) 36 | config.habitat_baselines.rl.policy["teacher_forcing"] = ( 37 | # add new key "teacher_forcing" 38 | config.habitat_baselines.trainer_name 39 | == "ver_il" 40 | ) 41 | config.habitat_baselines.rl.policy.name = DAggerPolicy.__name__ 42 | super().__init__(config) 43 | 44 | def _setup_actor_critic_agent(self, ppo_cfg: "DictConfig") -> None: 45 | r"""Same as PPOTrainer._setup_actor_critic_agent but mixes the policy class with 46 | DAgger mixin so that evaluate_actions induces the correct gradients, and allows 47 | the usage of DAgger agent instead of DDPPO or PPO. Also, the critic is gone, so 48 | we don't need to reset it. 49 | 50 | Args: 51 | ppo_cfg: config node with relevant params 52 | 53 | Returns: 54 | None 55 | """ 56 | logger.add_filehandler(self.config.habitat_baselines.log_file) 57 | 58 | policy = baseline_registry.get_policy( 59 | self.config.habitat_baselines.rl.policy.name 60 | ) 61 | observation_space = self.obs_space 62 | self.obs_transforms = get_active_obs_transforms(self.config) 63 | observation_space = apply_obs_transforms_obs_space( 64 | observation_space, self.obs_transforms 65 | ) 66 | 67 | self.actor_critic = policy.from_config( 68 | self.config, 69 | observation_space, 70 | self.policy_action_space, 71 | orig_action_space=self.orig_policy_action_space, 72 | ) 73 | self.obs_space = observation_space 74 | self.actor_critic.to(self.device) 75 | 76 | if ( 77 | self.config.habitat_baselines.rl.ddppo.pretrained_encoder 78 | or self.config.habitat_baselines.rl.ddppo.pretrained 79 | ): 80 | pretrained_state = torch.load( 81 | self.config.habitat_baselines.rl.ddppo.pretrained_weights, 82 | map_location="cpu", 83 | ) 84 | 85 | if self.config.habitat_baselines.rl.ddppo.pretrained: 86 | self.actor_critic.load_state_dict( 87 | { # type: ignore 88 | k[len("actor_critic.") :]: v 89 | for k, v in pretrained_state["state_dict"].items() 90 | } 91 | ) 92 | elif self.config.habitat_baselines.rl.ddppo.pretrained_encoder: 93 | prefix = "actor_critic.net.visual_encoder." 
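                # The dict comprehension below strips this prefix from matching checkpoint
                # keys (e.g., an illustrative key "actor_critic.net.visual_encoder.backbone.conv1.weight"
                # becomes "backbone.conv1.weight"), so only the visual-encoder weights are loaded.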
94 | self.actor_critic.net.visual_encoder.load_state_dict( 95 | { 96 | k[len(prefix) :]: v 97 | for k, v in pretrained_state["state_dict"].items() 98 | if k.startswith(prefix) 99 | } 100 | ) 101 | 102 | if not self.config.habitat_baselines.rl.ddppo.train_encoder and hasattr( 103 | self.actor_critic, "net" 104 | ): 105 | self._static_encoder = True 106 | for param in self.actor_critic.net.visual_encoder.parameters(): 107 | param.requires_grad_(False) 108 | 109 | self.agent = (DDPDAgger if self._is_distributed else DAgger).from_config( 110 | self.actor_critic, ppo_cfg 111 | ) 112 | -------------------------------------------------------------------------------- /ovon/models/transformer_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import torch 4 | from gym import spaces 5 | from habitat_baselines.common.baseline_registry import baseline_registry 6 | from habitat_baselines.rl.ppo import NetPolicy 7 | from omegaconf import DictConfig 8 | 9 | from ovon.models.clip_policy import OVONNet, PointNavResNetCLIPPolicy 10 | from ovon.models.transformer_encoder import TransformerEncoder 11 | 12 | 13 | @baseline_registry.register_policy 14 | class OVONTransformerPolicy(PointNavResNetCLIPPolicy): 15 | is_transformer = True 16 | 17 | def __init__( 18 | self, 19 | observation_space: spaces.Dict, 20 | action_space, 21 | transformer_config, 22 | hidden_size: int = 512, 23 | num_recurrent_layers: int = 1, 24 | rnn_type: str = "GRU", 25 | backbone: str = "clip_avgpool", 26 | policy_config: DictConfig = None, 27 | aux_loss_config: Optional[DictConfig] = None, 28 | depth_ckpt: str = "", 29 | fusion_type: str = "concat", 30 | attn_heads: int = 3, 31 | use_vis_query: bool = False, 32 | use_residual: bool = True, 33 | residual_vision: bool = False, 34 | unfreeze_xattn: bool = False, 35 | rgb_only: bool = True, 36 | use_prev_action: bool = True, 37 | use_odom: bool = False, 38 | **kwargs, 39 | ): 40 | self.unfreeze_xattn = unfreeze_xattn 41 | if policy_config is not None: 42 | discrete_actions = policy_config.action_distribution_type == "categorical" 43 | self.action_distribution_type = policy_config.action_distribution_type 44 | else: 45 | discrete_actions = True 46 | self.action_distribution_type = "categorical" 47 | 48 | NetPolicy.__init__( 49 | self, 50 | OVONTransformerNet( 51 | observation_space=observation_space, 52 | action_space=action_space, # for previous action 53 | hidden_size=hidden_size, 54 | num_recurrent_layers=num_recurrent_layers, 55 | rnn_type=rnn_type, 56 | backbone=backbone, 57 | discrete_actions=discrete_actions, 58 | depth_ckpt=depth_ckpt, 59 | fusion_type=fusion_type, 60 | attn_heads=attn_heads, 61 | use_vis_query=use_vis_query, 62 | use_residual=use_residual, 63 | residual_vision=residual_vision, 64 | transformer_config=transformer_config, 65 | rgb_only=rgb_only, 66 | use_prev_action=use_prev_action, 67 | use_odom=use_odom, 68 | ), 69 | action_space=action_space, 70 | policy_config=policy_config, 71 | aux_loss_config=aux_loss_config, 72 | ) 73 | 74 | @classmethod 75 | def from_config(cls, config: DictConfig, *args, **kwargs): 76 | tf_cfg = config.habitat_baselines.rl.policy.transformer_config 77 | return super().from_config(config, transformer_config=tf_cfg, *args, **kwargs) 78 | 79 | @property 80 | def num_recurrent_layers(self): 81 | return self.net.state_encoder.n_layers 82 | 83 | @property 84 | def num_heads(self): 85 | return self.net.state_encoder.n_head 86 | 87 | @property 88 | def 
max_context_length(self): 89 | return self.net.state_encoder.max_context_length 90 | 91 | @property 92 | def recurrent_hidden_size(self): 93 | return self.net.state_encoder.n_embed 94 | 95 | 96 | class OVONTransformerNet(OVONNet): 97 | """Same as OVONNet but uses transformer instead of LSTM.""" 98 | 99 | def __init__(self, transformer_config, *args, **kwargs): 100 | self.transformer_config = transformer_config 101 | super().__init__(*args, **kwargs) 102 | 103 | @property 104 | def output_size(self): 105 | return self.transformer_config.n_hidden 106 | 107 | def build_state_encoder(self): 108 | state_encoder = TransformerEncoder( 109 | self.rnn_input_size, config=self.transformer_config 110 | ) 111 | return state_encoder 112 | 113 | def forward( 114 | self, 115 | observations: Dict[str, torch.Tensor], 116 | rnn_hidden_states, 117 | prev_actions, 118 | masks, 119 | rnn_build_seq_info: Optional[Dict[str, torch.Tensor]] = None, 120 | ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]: 121 | if "step_id" in observations: 122 | if rnn_build_seq_info is None: 123 | # Means online inference. Update should already have "episode_ids" key. 124 | rnn_build_seq_info = {} 125 | rnn_build_seq_info["step_id"] = observations["step_id"] 126 | return super().forward( 127 | observations, 128 | rnn_hidden_states, 129 | prev_actions, 130 | masks, 131 | rnn_build_seq_info, 132 | ) 133 | -------------------------------------------------------------------------------- /ovon/utils/visualize/viz.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import cv2 4 | import numpy as np 5 | from habitat_baselines.rl.ddppo.policy import ( # noqa: F401. 6 | PointNavResNetNet, 7 | PointNavResNetPolicy, 8 | ) 9 | 10 | 11 | def append_text_to_image(image: np.ndarray, text: List[str], font_size: float = 0.5): 12 | r"""Appends lines of text on top of an image. First this will render to the 13 | left-hand side of the image, once that column is full, it will render to 14 | the right hand-side of the image. 15 | :param image: the image to put text underneath 16 | :param text: The list of strings which will be rendered, separated by new lines. 17 | :returns: A new image with text inserted underneath the input image 18 | """ 19 | h, w, c = image.shape 20 | font_thickness = 1 21 | font = cv2.FONT_HERSHEY_SIMPLEX 22 | 23 | y = 0 24 | left_aligned = True 25 | for line in text: 26 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0] 27 | y += textsize[1] + 10 28 | if y > h: 29 | left_aligned = False 30 | y = textsize[1] + 10 31 | 32 | if left_aligned: 33 | x = 10 34 | else: 35 | x = w - (textsize[0] + 10) 36 | 37 | cv2.putText( 38 | image, 39 | line, 40 | (x, y), 41 | font, 42 | font_size, 43 | (0, 0, 0), 44 | font_thickness * 2, 45 | lineType=cv2.LINE_AA, 46 | ) 47 | 48 | cv2.putText( 49 | image, 50 | line, 51 | (x, y), 52 | font, 53 | font_size, 54 | (255, 255, 255, 255), 55 | font_thickness, 56 | lineType=cv2.LINE_AA, 57 | ) 58 | 59 | return np.clip(image, 0, 255) 60 | 61 | 62 | def flatten_dict(d: Dict, parent_key: str = "", sep: str = ".") -> Dict: 63 | r"""Flattens nested dict. 64 | 65 | Source: https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys 66 | 67 | :param d: Nested dict. 68 | :param parent_key: Parameter to set parent dict key. 69 | :param sep: Nested keys separator. 70 | :return: Flattened dict. 
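    Example: flatten_dict({"success": {"spl": 0.5}}) returns {"success.spl": 0.5}.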
71 | """ 72 | items: List[Tuple[str, Any]] = [] 73 | for k, v in d.items(): 74 | new_key = parent_key + sep + str(k) if parent_key else str(k) 75 | if isinstance(v, dict): 76 | items.extend(flatten_dict(v, parent_key=new_key).items()) 77 | else: 78 | items.append((new_key, v)) 79 | return dict(items) 80 | 81 | 82 | def overlay_text_to_image(image: np.ndarray, text: List[str], font_size: float = 0.5): 83 | r"""Overlays lines of text on top of an image. 84 | 85 | First this will render to the left-hand side of the image, once that column is full, 86 | it will render to the right hand-side of the image. 87 | 88 | :param image: The image to put text on top. 89 | :param text: The list of strings which will be rendered (separated by new lines). 90 | :param font_size: Font size. 91 | :return: A new image with text overlaid on top. 92 | """ 93 | h, w, c = image.shape 94 | font_thickness = 1 95 | font = cv2.FONT_HERSHEY_SIMPLEX 96 | 97 | y = 0 98 | left_aligned = True 99 | for line in text: 100 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0] 101 | y += textsize[1] + 10 102 | if y > h: 103 | left_aligned = False 104 | y = textsize[1] + 10 105 | 106 | if left_aligned: 107 | x = 10 108 | else: 109 | x = w - (textsize[0] + 10) 110 | 111 | cv2.putText( 112 | image, 113 | line, 114 | (x, y), 115 | font, 116 | font_size, 117 | (0, 0, 0), 118 | font_thickness * 2, 119 | lineType=cv2.LINE_AA, 120 | ) 121 | 122 | cv2.putText( 123 | image, 124 | line, 125 | (x, y), 126 | font, 127 | font_size, 128 | (255, 255, 255, 255), 129 | font_thickness, 130 | lineType=cv2.LINE_AA, 131 | ) 132 | 133 | return np.clip(image, 0, 255) 134 | 135 | 136 | def overlay_frame( 137 | frame: np.ndarray, info: Dict[str, Any], additional: Optional[List[str]] = None 138 | ) -> np.ndarray: 139 | """ 140 | Renders text from the `info` dictionary to the `frame` image. 
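    Nested keys in `info` are flattened with dots; numeric values are rendered to two
    decimal places, string values verbatim, and any strings in `additional` are appended
    as extra lines.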
141 | """ 142 | 143 | lines = [] 144 | flattened_info = flatten_dict(info) 145 | for k, v in flattened_info.items(): 146 | if isinstance(v, str): 147 | lines.append(f"{k}: {v}") 148 | else: 149 | try: 150 | lines.append(f"{k}: {v:.2f}") 151 | except TypeError: 152 | pass 153 | if additional is not None: 154 | lines.extend(additional) 155 | 156 | frame = overlay_text_to_image(frame, lines, font_size=0.25) 157 | 158 | return frame 159 | -------------------------------------------------------------------------------- /scripts/dataset/generate_val_episode_counts.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gzip 3 | import json 4 | import os.path as osp 5 | from collections import defaultdict 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | PLOT = True 10 | 11 | 12 | def main(val_split_dir: str): 13 | # Get basename of val_split_dir, remove trailing slash if present first though 14 | val_split_dir = val_split_dir.rstrip("/") 15 | val_split_dir_basename = osp.basename(val_split_dir) 16 | 17 | json_gz_files = glob.glob(f"{val_split_dir}/content/*.json.gz") 18 | 19 | category_to_scene_count = defaultdict(int) 20 | 21 | for gz in json_gz_files: 22 | with gzip.open(gz) as f: 23 | data = json.load(f) 24 | categories = [k.split("glb_")[-1] for k in data["goals_by_category"].keys()] 25 | for category in categories: 26 | category_to_scene_count[category] += 1 27 | 28 | print("Num categories:", len(category_to_scene_count)) 29 | 30 | # Count area under curve of bar graph 31 | area_under_curve = sum(category_to_scene_count.values()) 32 | 33 | print("Area under curve:", area_under_curve) 34 | category_to_total_count = scale_dict_values(category_to_scene_count, 3000) 35 | category_to_scene_counts = distribute_values( 36 | category_to_total_count, category_to_scene_count 37 | ) 38 | 39 | print("Sum of values:", sum(category_to_total_count.values())) 40 | 41 | # Get sum of values in category_to_scene_counts 42 | sum_of_values = sum([sum(x) for x in category_to_scene_counts.values()]) 43 | print("Sum of values:", sum_of_values) 44 | 45 | scene_id_to_category_to_count = {} 46 | for gz in json_gz_files: 47 | with gzip.open(gz) as f: 48 | data = json.load(f) 49 | 50 | category_to_count = defaultdict(int) 51 | categories = [k.split("glb_")[-1] for k in data["goals_by_category"].keys()] 52 | for category in categories: 53 | category_to_count[category] += category_to_scene_counts[category].pop() 54 | 55 | scene_id = osp.basename(gz).replace(".json.gz", "") 56 | scene_id_to_category_to_count[scene_id] = category_to_count 57 | 58 | # Assert all values of category_to_scene_counts are now [] 59 | for category, counts in category_to_scene_counts.items(): 60 | assert len(counts) == 0, f"category: {category}, counts: {counts}" 61 | 62 | # Save scene_id_to_category_to_count to a json file 63 | with open(f"{val_split_dir_basename}_scene_id_to_category_to_count.json", "w") as f: 64 | json.dump(scene_id_to_category_to_count, f, indent=4) 65 | 66 | if PLOT: 67 | # Plot a bar graph of the number of scenes per category 68 | # with category names as x-axis labels 69 | scale = 2.0 70 | plt.figure(figsize=(8 * scale, 6 * scale)) 71 | plt.xlabel("Category name") 72 | plt.ylabel("Number of scenes") 73 | # Sort bars by height 74 | sorted_category_to_scene_count = sorted( 75 | category_to_scene_count.items(), key=lambda x: -x[1] 76 | ) 77 | plt.bar( 78 | [x[0] for x in sorted_category_to_scene_count], 79 | [category_to_total_count[x[0]] for x in 
sorted_category_to_scene_count], 80 | ) 81 | plt.bar( 82 | [x[0] for x in sorted_category_to_scene_count], 83 | [x[1] for x in sorted_category_to_scene_count], 84 | ) 85 | plt.xticks(rotation=90) 86 | 87 | # Save plot to a file 88 | plt.savefig(f"{val_split_dir_basename}_hist.png") 89 | 90 | 91 | def scale_dict_values(dictionary, target_sum): 92 | original_sum = sum(dictionary.values()) 93 | scaling_factor = target_sum / original_sum 94 | 95 | scaled_dict = {} 96 | scaled_sum = 0 97 | 98 | for key, value in dictionary.items(): 99 | scaled_value = round(value * scaling_factor) 100 | scaled_dict[key] = scaled_value 101 | scaled_sum += scaled_value 102 | 103 | # Adjust for rounding errors 104 | if scaled_sum != target_sum: 105 | diff = target_sum - scaled_sum 106 | sorted_values = sorted( 107 | scaled_dict.items(), 108 | key=lambda x: abs(x[1] - (scaled_sum / len(scaled_dict))), 109 | ) 110 | for i in range(abs(diff)): 111 | key = sorted_values[i][0] 112 | scaled_dict[key] += 1 if diff > 0 else -1 113 | 114 | return scaled_dict 115 | 116 | 117 | def distribute_values(sum_dict, length_dict): 118 | result_dict = {} 119 | 120 | for key in sum_dict: 121 | value_sum = sum_dict[key] 122 | value_length = length_dict[key] 123 | 124 | quotient = value_sum // value_length 125 | remainder = value_sum % value_length 126 | 127 | values = [quotient] * value_length 128 | 129 | for i in range(remainder): 130 | values[i] += 1 131 | 132 | result_dict[key] = values 133 | 134 | return result_dict 135 | 136 | 137 | if __name__ == "__main__": 138 | import argparse 139 | 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("val_split_dir", type=str) 142 | args = parser.parse_args() 143 | 144 | main(args.val_split_dir) 145 | -------------------------------------------------------------------------------- /ovon/models/encoders/vit.py: -------------------------------------------------------------------------------- 1 | # adapted from: https://github.com/facebookresearch/mae/blob/main/models_vit.py 2 | from functools import partial 3 | 4 | import timm.models.vision_transformer 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | # fmt: off 10 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 11 | """ Vision Transformer with support for global average pooling 12 | """ 13 | def __init__(self, use_fc_norm=False, global_pool=False, use_cls=False, mask_ratio=None, **kwargs): 14 | super(VisionTransformer, self).__init__(**kwargs) 15 | assert not (global_pool and use_cls) 16 | 17 | del self.head # don't use prediction head 18 | 19 | self.use_fc_norm = use_fc_norm 20 | if self.use_fc_norm: 21 | norm_layer = kwargs['norm_layer'] 22 | embed_dim = kwargs['embed_dim'] 23 | self.fc_norm = norm_layer(embed_dim) 24 | 25 | del self.norm # remove the original norm 26 | 27 | self.global_pool = global_pool 28 | self.use_cls = use_cls 29 | self.mask_ratio = mask_ratio 30 | 31 | def random_masking(self, x, mask_ratio): 32 | """ 33 | Perform per-sample random masking by per-sample shuffling. 34 | Per-sample shuffling is done by argsort random noise. 
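        mask_ratio: fraction of the L tokens to drop; int(L * (1 - mask_ratio)) tokens are kept per sample.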
35 | x: [N, L, D], sequence 36 | """ 37 | N, L, D = x.shape # batch, length, dim 38 | len_keep = int(L * (1 - mask_ratio)) 39 | 40 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 41 | 42 | # sort noise for each sample 43 | ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove 44 | ids_restore = torch.argsort(ids_shuffle, dim=1) 45 | 46 | # keep the first subset 47 | ids_keep = ids_shuffle[:, :len_keep] 48 | x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 49 | 50 | # generate the binary mask: 0 is keep, 1 is remove 51 | mask = torch.ones([N, L], device=x.device) 52 | mask[:, :len_keep] = 0 53 | # unshuffle to get the binary mask 54 | mask = torch.gather(mask, dim=1, index=ids_restore) 55 | 56 | return x_masked, mask, ids_restore 57 | 58 | def forward_features(self, x): 59 | x = self.patch_embed(x) 60 | 61 | # add pos embed w/o cls token 62 | x = x + self.pos_embed[:, 1:, :] 63 | 64 | # masking: length -> length * mask_ratio 65 | if self.mask_ratio is not None: 66 | x, _, _ = self.random_masking(x, mask_ratio=self.mask_ratio) 67 | 68 | # append cls token 69 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 70 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 71 | x = torch.cat((cls_tokens, x), dim=1) 72 | 73 | # apply Transformer blocks 74 | for blk in self.blocks: 75 | x = blk(x) 76 | if not self.use_fc_norm: 77 | x = self.norm(x) 78 | 79 | # global pooling or remove cls token 80 | if self.global_pool: 81 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 82 | elif self.use_cls: 83 | x = x[:, 0] # use cls token 84 | else: 85 | x = x[:, 1:] # remove cls token 86 | 87 | # use fc_norm layer 88 | if self.use_fc_norm: 89 | x = self.fc_norm(x) 90 | 91 | return x 92 | 93 | def forward(self, x): 94 | return self.forward_features(x) 95 | 96 | 97 | def vit_small_patch16(**kwargs): 98 | """ViT small as defined in the DeiT paper.""" 99 | model = VisionTransformer( 100 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 101 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 102 | return model 103 | 104 | 105 | def vit_base_patch16(**kwargs): 106 | model = VisionTransformer( 107 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 108 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 109 | return model 110 | 111 | 112 | def vit_large_patch16(**kwargs): 113 | model = VisionTransformer( 114 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, 115 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 116 | return model 117 | 118 | 119 | def vit_huge_patch14(**kwargs): 120 | model = VisionTransformer( 121 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True, 122 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 123 | return model 124 | 125 | 126 | def load_ovrl_v2(checkpoint: str, img_size: int) -> nn.Module: 127 | encoder = vit_base_patch16( 128 | img_size=img_size, 129 | use_fc_norm=False, 130 | global_pool=False, 131 | use_cls=False, 132 | mask_ratio=None, 133 | drop_path_rate=0.0, 134 | ) 135 | weights = torch.load(checkpoint, map_location="cpu")["model"] 136 | orig_state_dict = encoder.state_dict() 137 | encoder.load_state_dict( 138 | { 139 | k: v 140 | for k, v in weights.items() 141 | if k in orig_state_dict 142 | } 143 | ) 144 | return encoder 145 | -------------------------------------------------------------------------------- /ovon/utils/cache_clip_embeddings.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import gzip 4 | import json 5 | import os 6 | import pickle 7 | 8 | import clip 9 | import numpy as np 10 | import torch 11 | from open_clip import create_model_from_pretrained, get_tokenizer 12 | from tqdm import tqdm 13 | 14 | # PROMPT = "{category}" 15 | PROMPT = "Seems like there is a {category} ahead." 16 | 17 | 18 | def save_to_disk(text_embedding, goal_categories, output_path): 19 | output = {} 20 | for goal_category, embedding in zip(goal_categories, text_embedding): 21 | output[goal_category] = embedding.detach().cpu().numpy() 22 | save_pickle(output, output_path) 23 | 24 | 25 | def cache_embeddings(goal_categories, output_path, clip_model="RN50"): 26 | model, _ = clip.load(clip_model) 27 | batch = tokenize_and_batch(clip, goal_categories) 28 | 29 | with torch.no_grad(): 30 | print(batch.shape) 31 | text_embedding = model.encode_text(batch.flatten(0, 1)).float() 32 | save_to_disk(text_embedding, goal_categories, output_path) 33 | 34 | 35 | def tokenize_and_batch(clip, goal_categories): 36 | tokens = [] 37 | for category in goal_categories: 38 | prompt = PROMPT.format(category=category) 39 | tokens.append(clip.tokenize(prompt, context_length=77).numpy()) 40 | return torch.tensor(np.array(tokens)).cuda() 41 | 42 | 43 | def tokenize_and_batch_siglip( 44 | goal_categories, model_name="hf-hub:timm/ViT-B-16-SigLIP-256" 45 | ): 46 | tokenizer = get_tokenizer(model_name) 47 | tokens = [] 48 | for category in tqdm(goal_categories): 49 | prompt = category 50 | tokens.append(tokenizer([prompt], context_length=64).numpy()) 51 | return torch.tensor(np.array(tokens)) 52 | 53 | 54 | def cache_embeddings_siglip( 55 | goal_categories, output_path, model_name="hf-hub:timm/ViT-B-16-SigLIP-256" 56 | ): 57 | model, _ = create_model_from_pretrained(model_name) 58 | batch = tokenize_and_batch_siglip(goal_categories, model_name) 59 | 60 | with torch.no_grad(): 61 | print(batch.shape) 62 | text_embedding = model.encode_text(batch.flatten(0, 1)).float() 63 | save_to_disk(text_embedding, goal_categories, output_path) 64 | 65 | 66 | def load_categories_from_dataset(path): 67 | files = glob.glob(os.path.join(path, "*json.gz")) 68 | 69 | categories = [] 70 | for f in tqdm(files): 71 | dataset = load_dataset(f) 72 | for goal_key in dataset["goals_by_category"].keys(): 73 | categories.append(goal_key.split("_")[1]) 74 | return list(set(categories)) 75 | 76 | 77 | def main(dataset_path, output_path, use_siglip): 78 | goal_categories = load_categories_from_dataset(dataset_path) 79 | val_seen_categories = load_categories_from_dataset( 80 | dataset_path.replace("train", "val_seen") 81 | ) 82 | val_unseen_easy_categories = load_categories_from_dataset( 83 | dataset_path.replace("train", "val_unseen_easy") 84 | ) 85 | val_unseen_hard_categories = load_categories_from_dataset( 86 | dataset_path.replace("train", "val_unseen_hard") 87 | ) 88 | 89 | # Print the first 5 categories of each split 90 | print("Total categories: {}".format(len(goal_categories))) 91 | print("First 5 categories:") 92 | print("goal_categories: {}".format(goal_categories[:5])) 93 | print("val_seen_categories: {}".format(val_seen_categories[:5])) 94 | print("val_unseen_easy_categories: {}".format(val_unseen_easy_categories[:5])) 95 | print("val_unseen_hard_categories: {}".format(val_unseen_hard_categories[:5])) 96 | 97 | goal_categories.extend(val_seen_categories) 98 | goal_categories.extend(val_unseen_easy_categories) 99 | 
goal_categories.extend(val_unseen_hard_categories) 100 | 101 | print("Total goal categories: {}".format(len(goal_categories))) 102 | print( 103 | "Train categories: {}, Val seen categories: {}, Val unseen easy categories: {}," 104 | " Val unseen hard categories: {}".format( 105 | len(goal_categories) - len(val_seen_categories) - len(val_unseen_easy_categories) - len(val_unseen_hard_categories), 106 | len(val_seen_categories), 107 | len(val_unseen_easy_categories), 108 | len(val_unseen_hard_categories), 109 | ) 110 | ) 111 | if use_siglip: 112 | cache_embeddings_siglip(goal_categories, output_path) 113 | else: 114 | cache_embeddings(goal_categories, output_path) 115 | 116 | 117 | def load_dataset(path): 118 | with gzip.open(path, "rt") as file: 119 | data = json.loads(file.read()) 120 | return data 121 | 122 | 123 | def save_pickle(data, path): 124 | with open(path, "wb") as file: 125 | pickle.dump(data, file) 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument( 131 | "--dataset-path", 132 | type=str, 133 | required=True, 134 | help="file path of OVON dataset", 135 | ) 136 | parser.add_argument( 137 | "--output-path", 138 | type=str, 139 | required=True, 140 | help="output path of text embeddings", 141 | ) 142 | parser.add_argument( 143 | "--use-siglip", 144 | action="store_true", 145 | help="use siglip model", 146 | ) 147 | args = parser.parse_args() 148 | main(args.dataset_path, args.output_path, args.use_siglip) 149 | -------------------------------------------------------------------------------- /ovon/dataset/ovon_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import json 4 | import os 5 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence 6 | 7 | import attr 8 | from habitat.core.registry import registry 9 | from habitat.core.simulator import AgentState, ShortestPathPoint 10 | from habitat.core.utils import DatasetFloatJSONEncoder 11 | from habitat.datasets.pointnav.pointnav_dataset import ( 12 | CONTENT_SCENES_PATH_FIELD, 13 | DEFAULT_SCENE_PATH_PREFIX, 14 | PointNavDatasetV1, 15 | ) 16 | from habitat.tasks.nav.object_nav_task import ( 17 | ObjectGoal, 18 | ObjectGoalNavEpisode, 19 | ObjectViewLocation, 20 | ) 21 | 22 | if TYPE_CHECKING: 23 | from omegaconf import DictConfig 24 | 25 | 26 | @attr.s(auto_attribs=True) 27 | class OVONObjectViewLocation(ObjectViewLocation): 28 | r"""OVONObjectViewLocation 29 | 30 | Args: 31 | radius: radius of the viewpoint circle 32 | """ 33 | 34 | radius: Optional[float] = None 35 | 36 | 37 | @attr.s(auto_attribs=True, kw_only=True) 38 | class OVONEpisode(ObjectGoalNavEpisode): 39 | r"""OVON Episode 40 | 41 | :param children_object_categories: child categories of the target object category 42 | """ 43 | 44 | children_object_categories: Optional[List[str]] = [] 45 | 46 | 47 | @registry.register_dataset(name="OVON-v1") 48 | class OVONDatasetV1(PointNavDatasetV1): 49 | r""" 50 | Class inherited from PointNavDataset that loads Open-Vocab 51 | Object Navigation dataset.
52 | """ 53 | 54 | episodes: List[OVONEpisode] = [] # type: ignore 55 | content_scenes_path: str = "{data_path}/content/{scene}.json.gz" 56 | goals_by_category: Dict[str, Sequence[ObjectGoal]] 57 | 58 | @staticmethod 59 | def dedup_goals(dataset: Dict[str, Any]) -> Dict[str, Any]: 60 | if len(dataset["episodes"]) == 0: 61 | return dataset 62 | 63 | goals_by_category = {} 64 | for i, ep in enumerate(dataset["episodes"]): 65 | # Get the category from the first goal 66 | dataset["episodes"][i]["object_category"] = ep["goals"][0][ 67 | "object_category" 68 | ] 69 | ep = OVONEpisode(**ep) 70 | 71 | # Store unique goals under their key 72 | goals_key = ep.goals_key 73 | if goals_key not in goals_by_category: 74 | goals_by_category[goals_key] = ep.goals 75 | 76 | # Store a reference to the shared goals 77 | dataset["episodes"][i]["goals"] = [] 78 | 79 | dataset["goals_by_category"] = goals_by_category 80 | 81 | return dataset 82 | 83 | def to_json(self) -> str: 84 | for i in range(len(self.episodes)): 85 | self.episodes[i].goals = [] 86 | 87 | result = DatasetFloatJSONEncoder().encode(self) 88 | 89 | for i in range(len(self.episodes)): 90 | goals = self.goals_by_category[self.episodes[i].goals_key] 91 | if not isinstance(goals, list): 92 | goals = list(goals) 93 | self.episodes[i].goals = goals 94 | 95 | return result 96 | 97 | def __init__(self, config: Optional["DictConfig"] = None) -> None: 98 | self.goals_by_category = {} 99 | super().__init__(config) 100 | self.episodes = list(self.episodes) 101 | 102 | @staticmethod 103 | def __deserialize_goal(serialized_goal: Dict[str, Any]) -> ObjectGoal: 104 | g = ObjectGoal(**serialized_goal) 105 | g.object_id = int(g.object_id.split("_")[-1]) 106 | 107 | for vidx, view in enumerate(g.view_points): 108 | view_location = OVONObjectViewLocation(**view) # type: ignore 109 | view_location.agent_state = AgentState( 110 | **view_location.agent_state # type: ignore 111 | ) 112 | g.view_points[vidx] = view_location 113 | 114 | return g 115 | 116 | def from_json(self, json_str: str, scenes_dir: Optional[str] = None) -> None: 117 | deserialized = json.loads(json_str) 118 | if CONTENT_SCENES_PATH_FIELD in deserialized: 119 | self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD] 120 | 121 | if len(deserialized["episodes"]) == 0: 122 | return 123 | 124 | if "goals_by_category" not in deserialized: 125 | deserialized = self.dedup_goals(deserialized) 126 | 127 | for k, v in deserialized["goals_by_category"].items(): 128 | self.goals_by_category[k] = [self.__deserialize_goal(g) for g in v] 129 | 130 | for i, episode in enumerate(deserialized["episodes"]): 131 | episode = OVONEpisode(**episode) 132 | episode.goals = self.goals_by_category[episode.goals_key] # noqa 133 | 134 | if scenes_dir is not None: 135 | if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX): 136 | episode.scene_id = episode.scene_id[ 137 | len(DEFAULT_SCENE_PATH_PREFIX) : 138 | ] 139 | 140 | episode.scene_id = os.path.join(scenes_dir, episode.scene_id) 141 | 142 | if episode.shortest_paths is not None: 143 | for path in episode.shortest_paths: 144 | for p_index, point in enumerate(path): 145 | if point is None or isinstance(point, (int, str)): 146 | point = { 147 | "action": point, 148 | "rotation": None, 149 | "position": None, 150 | } 151 | 152 | path[p_index] = ShortestPathPoint(**point) 153 | 154 | self.episodes.append(episode) # type: ignore [attr-defined] 155 | -------------------------------------------------------------------------------- /ovon/utils/lr_scheduler.py: 
-------------------------------------------------------------------------------- 1 | from habitat import logger 2 | 3 | 4 | class PIRLNavLRScheduler: 5 | def __init__( 6 | self, 7 | optimizer, 8 | agent, 9 | num_updates, 10 | base_lr, 11 | finetuning_lr, 12 | ppo_eps, 13 | start_actor_update_at, 14 | start_actor_warmup_at, 15 | start_critic_update_at, 16 | start_critic_warmup_at, 17 | ) -> None: 18 | self.optimizer = optimizer 19 | self.agent = agent 20 | self.update = 0 21 | self.num_updates = num_updates 22 | 23 | self.start_actor_update_at = start_actor_update_at 24 | self.start_actor_warmup_at = start_actor_warmup_at 25 | self.start_critic_update_at = start_critic_update_at 26 | self.start_critic_warmup_at = start_critic_warmup_at 27 | 28 | self.ppo_eps = ppo_eps 29 | self.base_lrs = [base_lr, 0, 0] 30 | self.finetuning_lr = finetuning_lr 31 | 32 | self.lr_lambdas = [ 33 | lambda x: self.critic_linear_decay( 34 | x, 35 | start_critic_warmup_at, 36 | start_critic_update_at, 37 | base_lr, 38 | finetuning_lr, 39 | ), 40 | lambda x: self.linear_warmup( 41 | x, 42 | self.start_actor_warmup_at, 43 | self.start_actor_update_at, 44 | 0.0, 45 | finetuning_lr, 46 | ), 47 | lambda x: self.linear_warmup( 48 | x, 49 | self.start_actor_warmup_at, 50 | self.start_actor_update_at, 51 | 0.0, 52 | finetuning_lr, 53 | ), 54 | ] 55 | 56 | def step(self): 57 | self.update += 1 58 | 59 | if self.update == self.start_actor_warmup_at: 60 | logger.info( 61 | "\n\nAgent number of parameters pre unfreeze: {}".format( 62 | sum( 63 | param.numel() if param.requires_grad else 0 64 | for param in self.agent.parameters() 65 | ) 66 | ) 67 | ) 68 | self.agent.actor_critic.unfreeze_actor() 69 | self.agent.actor_critic.unfreeze_state_encoder() 70 | logger.info( 71 | "\n\nAgent number of parameters post unfreeze: {}".format( 72 | sum( 73 | param.numel() if param.requires_grad else 0 74 | for param in self.agent.parameters() 75 | ) 76 | ) 77 | ) 78 | 79 | logger.info("\n\nStart actor finetuning at: {}".format(self.update)) 80 | 81 | start_index = 1 82 | for i, param_group in enumerate( 83 | self.agent.optimizer.param_groups[start_index:] 84 | ): 85 | param_group["eps"] = self.ppo_eps 86 | self.base_lrs[i + start_index] = 1.0 87 | 88 | logger.info("Base LRs: {}".format(self.base_lrs)) 89 | 90 | if self.update == self.start_critic_warmup_at: 91 | self.base_lrs[0] = 1.0 92 | logger.info("\n\nSet critic LR at: {}".format(self.update)) 93 | 94 | lrs = [ 95 | base_lr * lr_lamda(self.update) 96 | for base_lr, lr_lamda in zip(self.base_lrs, self.lr_lambdas) 97 | ] 98 | 99 | # Set LR for each param group 100 | for i, data in enumerate(zip(self.optimizer.param_groups, lrs)): 101 | param_group, lr = data 102 | param_group["lr"] = lr 103 | 104 | def linear_warmup( 105 | self, 106 | update, 107 | start_update: int, 108 | max_updates: int, 109 | start_lr: int, 110 | end_lr: int, 111 | ) -> float: 112 | r""" 113 | Returns a multiplicative factor for linear value warmup 114 | """ 115 | if update < start_update: 116 | return 1.0 117 | 118 | if update >= max_updates: 119 | return end_lr 120 | 121 | if max_updates == start_update: 122 | return end_lr 123 | 124 | pct_step = (update - start_update) / (max_updates - start_update) 125 | step_lr = (end_lr - start_lr) * pct_step + start_lr 126 | if step_lr > end_lr: 127 | step_lr = end_lr 128 | return step_lr 129 | 130 | def critic_linear_decay( 131 | self, 132 | update, 133 | start_update: int, 134 | max_updates: int, 135 | start_lr: int, 136 | end_lr: int, 137 | ) -> float: 138 | r""" 139 | 
Returns a multiplicative factor for linear value decay 140 | """ 141 | if update < start_update: 142 | return 1 143 | 144 | if update >= max_updates: 145 | return end_lr 146 | 147 | if max_updates == start_update: 148 | return end_lr 149 | 150 | pct_step = (update - start_update) / (max_updates - start_update) 151 | step_lr = start_lr - (start_lr - end_lr) * pct_step 152 | if step_lr < end_lr: 153 | step_lr = end_lr 154 | return step_lr 155 | 156 | def load_state_dict(self, state_dict): 157 | self.__dict__.update(state_dict) 158 | 159 | if self.update >= self.start_actor_update_at: 160 | self.agent.actor_critic.unfreeze_actor() 161 | self.agent.actor_critic.unfreeze_state_encoder() 162 | 163 | def state_dict(self): 164 | return { 165 | key: value 166 | for key, value in self.__dict__.items() 167 | if key not in ["optimizer", "agent", "lr_lambdas"] 168 | } 169 | -------------------------------------------------------------------------------- /ovon/obs_transformers/resize.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from dataclasses import dataclass 3 | from typing import Dict, Tuple 4 | 5 | import torch 6 | from gym import spaces 7 | from habitat.core.logging import logger 8 | from habitat_baselines.common.baseline_registry import baseline_registry 9 | from habitat_baselines.common.obs_transformers import ObservationTransformer 10 | from habitat_baselines.config.default_structured_configs import ( 11 | ObsTransformConfig, 12 | ) 13 | from habitat_baselines.utils.common import ( 14 | get_image_height_width, 15 | overwrite_gym_box_shape, 16 | ) 17 | from hydra.core.config_store import ConfigStore 18 | from omegaconf import DictConfig 19 | from torch import Tensor 20 | 21 | from ovon.task.sensors import ClipImageGoalSensor, ImageGoalRotationSensor 22 | 23 | 24 | @baseline_registry.register_obs_transformer() 25 | class Resize(ObservationTransformer): 26 | def __init__( 27 | self, 28 | size: Tuple[int, int], 29 | channels_last: bool = True, 30 | trans_keys: Tuple[str, ...] = ("rgb", "depth", "semantic"), 31 | semantic_key: str = "semantic", 32 | ): 33 | """Args: 34 | size: The size you want to resize the shortest edge to 35 | channels_last: indicates if channels is the last dimension 36 | """ 37 | super(Resize, self).__init__() 38 | self._size: Tuple[int, int] = size 39 | self.channels_last: bool = channels_last 40 | self.trans_keys: Tuple[str, ...] 
= trans_keys 41 | self.semantic_key = semantic_key 42 | 43 | def transform_observation_space(self, observation_space: spaces.Dict): 44 | observation_space = copy.deepcopy(observation_space) 45 | for key in observation_space.spaces: 46 | if key in self.trans_keys: 47 | # In the observation space dict, the channels are always last 48 | h, w = get_image_height_width( 49 | observation_space.spaces[key], channels_last=True 50 | ) 51 | if self._size == (h, w): 52 | continue 53 | logger.info( 54 | "Resizing observation of %s: from %s to %s" 55 | % (key, (h, w), self._size) 56 | ) 57 | observation_space.spaces[key] = overwrite_gym_box_shape( 58 | observation_space.spaces[key], self._size 59 | ) 60 | return observation_space 61 | 62 | def _transform_obs( 63 | self, obs: torch.Tensor, interpolation_mode: str 64 | ) -> torch.Tensor: 65 | return image_resize( 66 | obs, 67 | self._size, 68 | channels_last=self.channels_last, 69 | interpolation_mode=interpolation_mode, 70 | ) 71 | 72 | @torch.no_grad() 73 | def forward(self, observations: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 74 | for sensor in self.trans_keys: 75 | if sensor in observations: 76 | interpolation_mode = "area" 77 | if self.semantic_key in sensor: 78 | interpolation_mode = "nearest" 79 | observations[sensor] = self._transform_obs( 80 | observations[sensor], interpolation_mode 81 | ) 82 | return observations 83 | 84 | @classmethod 85 | def from_config(cls, config: "DictConfig"): 86 | return cls( 87 | tuple(config.size), 88 | config.channels_last, 89 | config.trans_keys, 90 | config.semantic_key, 91 | ) 92 | 93 | 94 | def image_resize( 95 | img: Tensor, 96 | size: Tuple[int, int], 97 | channels_last: bool = False, 98 | interpolation_mode="area", 99 | ) -> torch.Tensor: 100 | """Resizes an img. 101 | 102 | Args: 103 | img: the array object that needs to be resized (HWC) or (NHWC) 104 | size: the size that you want 105 | channels: a boolean that channel is the last dimension 106 | Returns: 107 | The resized array as a torch tensor. 108 | """ 109 | img = torch.as_tensor(img) 110 | no_batch_dim = len(img.shape) == 3 111 | if len(img.shape) < 3 or len(img.shape) > 5: 112 | raise NotImplementedError() 113 | if no_batch_dim: 114 | img = img.unsqueeze(0) # Adds a batch dimension 115 | if channels_last: 116 | if len(img.shape) == 4: 117 | # NHWC -> NCHW 118 | img = img.permute(0, 3, 1, 2) 119 | else: 120 | # NDHWC -> NDCHW 121 | img = img.permute(0, 1, 4, 2, 3) 122 | 123 | img = torch.nn.functional.interpolate( 124 | img.float(), size=size, mode=interpolation_mode 125 | ).to(dtype=img.dtype) 126 | if channels_last: 127 | if len(img.shape) == 4: 128 | # NCHW -> NHWC 129 | img = img.permute(0, 2, 3, 1) 130 | else: 131 | # NDCHW -> NDHWC 132 | img = img.permute(0, 1, 3, 4, 2) 133 | if no_batch_dim: 134 | img = img.squeeze(dim=0) # Removes the batch dimension 135 | return img 136 | 137 | 138 | @dataclass 139 | class ResizeConfig(ObsTransformConfig): 140 | type: str = Resize.__name__ 141 | size: Tuple[int, int] = ( 142 | 224, 143 | 224, 144 | ) 145 | channels_last: bool = True 146 | trans_keys: Tuple[str, ...] 
= ( 147 | "rgb", 148 | "depth", 149 | "semantic", 150 | ImageGoalRotationSensor.cls_uuid, 151 | ClipImageGoalSensor.cls_uuid, 152 | ) 153 | semantic_key: str = "semantic" 154 | 155 | 156 | cs = ConfigStore.instance() 157 | 158 | cs.store( 159 | package="habitat_baselines.rl.policy.obs_transforms.resize", 160 | group="habitat_baselines/rl/policy/obs_transforms", 161 | name="resize", 162 | node=ResizeConfig, 163 | ) 164 | -------------------------------------------------------------------------------- /ovon/models/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.transforms.functional as TF 7 | from torchvision.transforms import ColorJitter, RandomApply 8 | 9 | 10 | class RandomShiftsAug(nn.Module): 11 | def __init__(self, pad): 12 | super().__init__() 13 | self.pad = pad 14 | 15 | def forward(self, x): 16 | n, _, h, w = x.size() 17 | assert h == w 18 | padding = tuple([self.pad] * 4) 19 | x = F.pad(x, padding, "replicate") 20 | eps = 1.0 / (h + 2 * self.pad) 21 | arange = torch.linspace( 22 | -1.0 + eps, 23 | 1.0 - eps, 24 | h + 2 * self.pad, 25 | device=x.device, 26 | dtype=x.dtype, 27 | )[:h] 28 | arange = arange.unsqueeze(0).repeat(h, 1).unsqueeze(2) 29 | base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) 30 | base_grid = base_grid.unsqueeze(0).repeat(n, 1, 1, 1) 31 | 32 | shift = torch.randint( 33 | 0, 34 | 2 * self.pad + 1, 35 | size=(n, 1, 1, 2), 36 | device=x.device, 37 | dtype=x.dtype, 38 | ) 39 | shift *= 2.0 / (h + 2 * self.pad) 40 | 41 | grid = base_grid + shift 42 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 43 | 44 | 45 | class Transform: 46 | randomize_environments: bool = False 47 | 48 | def apply(self, x: torch.Tensor): 49 | raise NotImplementedError 50 | 51 | def __call__( 52 | self, 53 | x: torch.Tensor, 54 | N: Optional[int] = None, 55 | ): 56 | if not self.randomize_environments or N is None: 57 | return self.apply(x) 58 | 59 | # shapes 60 | TN = x.size(0) 61 | T = TN // N 62 | 63 | # apply the same augmentation when t == 1 for speed 64 | # typically, t == 1 during policy rollout 65 | if T == 1: 66 | return self.apply(x) 67 | 68 | # put environment (n) first 69 | _, A, B, C = x.shape 70 | x = torch.einsum("tnabc->ntabc", x.view(T, N, A, B, C)) 71 | 72 | # apply the same transform within each environment 73 | x = torch.cat([self.apply(imgs) for imgs in x]) 74 | 75 | # put timestep (t) first 76 | _, A, B, C = x.shape 77 | x = torch.einsum("ntabc->tnabc", x.view(N, T, A, B, C)).flatten(0, 1) 78 | 79 | return x 80 | 81 | 82 | class ResizeTransform(Transform): 83 | def __init__(self, size): 84 | self.size = size 85 | 86 | def apply(self, x): 87 | x = x.permute(0, 3, 1, 2) 88 | x = TF.resize(x, self.size) 89 | x = TF.center_crop(x, output_size=self.size) 90 | x = x.float() / 255.0 91 | return x 92 | 93 | 94 | class ShiftAndJitterTransform(Transform): 95 | def __init__(self, augmentations_name, size): 96 | self.size = size 97 | self.augmentations_name = augmentations_name 98 | 99 | def apply(self, x): 100 | x = x.permute(0, 3, 1, 2) 101 | x = TF.resize(x, self.size) 102 | x = TF.center_crop(x, output_size=self.size) 103 | x = x.float() / 255.0 104 | if "jitter" in self.augmentations_name: 105 | x = RandomApply([ColorJitter(0.4, 0.4, 0.4, 0.4)], p=1.0)(x) 106 | if "shift" in self.augmentations_name: 107 | x = RandomShiftsAug(16)(x) 108 | return x 109 | 110 
| 111 | class WeakAugmentation(Transform): 112 | is_random: bool = True 113 | 114 | def __init__(self, size): 115 | self.size = size 116 | 117 | def apply(self, x): 118 | x = x.permute(0, 3, 1, 2) 119 | x = TF.resize(x, self.size, interpolation=TF.InterpolationMode.BICUBIC) 120 | x = TF.center_crop(x, output_size=self.size) 121 | x = x.float() / 255.0 122 | x = RandomApply([ColorJitter(0.3, 0.3, 0.3, 0.3)], p=1.0)(x) 123 | x = RandomShiftsAug(4)(x) 124 | return x 125 | 126 | 127 | class CLIPTransform(Transform): 128 | def __init__(self, size): 129 | self.size = size 130 | self.mean = (0.48145466, 0.4578275, 0.40821073) 131 | self.std = (0.26862954, 0.26130258, 0.27577711) 132 | 133 | def apply(self, x): 134 | x = x.permute(0, 3, 1, 2) 135 | x = TF.resize(x, self.size, interpolation=TF.InterpolationMode.BICUBIC) 136 | x = TF.center_crop(x, output_size=self.size) 137 | x = x.float() / 255.0 138 | x = TF.normalize(x, self.mean, self.std) 139 | return x 140 | 141 | 142 | class CLIPWeakTransform(Transform): 143 | is_random: bool = True 144 | 145 | def __init__(self, size): 146 | self.size = size 147 | self.mean = (0.48145466, 0.4578275, 0.40821073) 148 | self.std = (0.26862954, 0.26130258, 0.27577711) 149 | 150 | def apply(self, x): 151 | x = x.permute(0, 3, 1, 2) 152 | x = TF.resize(x, self.size, interpolation=TF.InterpolationMode.BICUBIC) 153 | x = TF.center_crop(x, output_size=self.size) 154 | x = x.float() / 255.0 155 | x = RandomApply([ColorJitter(0.3, 0.3, 0.3, 0.3)], p=1.0)(x) 156 | x = RandomShiftsAug(4)(x) 157 | x = TF.normalize(x, self.mean, self.std) 158 | return x 159 | 160 | 161 | def get_transform(name, size): 162 | if name == "resize": 163 | return ResizeTransform(size) 164 | elif "shift" in name or "jitter" in name: 165 | return ShiftAndJitterTransform(name, size) 166 | elif name == "resize+weak": 167 | return WeakAugmentation(size) 168 | elif name == "clip": 169 | return CLIPTransform(size) 170 | elif name == "clip+weak": 171 | return CLIPWeakTransform(size) 172 | 173 | else: 174 | raise ValueError(f"Unknown transform {name}") 175 | -------------------------------------------------------------------------------- /ovon/models/encoders/visual_encoder_v2.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | import numpy as np 4 | import torch 5 | from habitat_baselines.rl.ddppo.policy.running_mean_and_var import RunningMeanAndVar 6 | from torch import nn as nn 7 | from torch.nn import functional as F 8 | 9 | from ovon.models.encoders import resnet_gn as resnet 10 | from ovon.models.encoders import vit 11 | 12 | 13 | class VisualEncoder(nn.Module): 14 | def __init__( 15 | self, 16 | image_size: int, 17 | backbone: str, 18 | input_channels: int = 3, 19 | resnet_baseplanes: int = 32, 20 | resnet_ngroups: int = 32, 21 | vit_use_fc_norm: bool = False, 22 | vit_global_pool: bool = False, 23 | vit_use_cls: bool = False, 24 | vit_mask_ratio: Optional[float] = None, 25 | normalize_visual_inputs: bool = True, 26 | avgpooled_image: bool = False, 27 | drop_path_rate: float = 0.0, 28 | checkpoint: str = "", 29 | visual_transform: Any = None, 30 | ): 31 | super().__init__() 32 | self.visual_transform = visual_transform 33 | self.avgpooled_image = avgpooled_image 34 | 35 | if normalize_visual_inputs: 36 | self.running_mean_and_var: nn.Module = RunningMeanAndVar(input_channels) 37 | else: 38 | self.running_mean_and_var = nn.Sequential() 39 | 40 | if "resnet" in backbone: 41 | make_backbone = getattr(resnet, backbone) 42 | 
self.backbone = make_backbone( 43 | input_channels, resnet_baseplanes, resnet_ngroups 44 | ) 45 | 46 | spatial_size = image_size 47 | if self.avgpooled_image: 48 | spatial_size = image_size // 2 49 | 50 | final_spatial = int(spatial_size * self.backbone.final_spatial_compress) 51 | after_compression_flat_size = 2048 52 | num_compression_channels = int( 53 | round(after_compression_flat_size / (final_spatial**2)) 54 | ) 55 | self.compression = nn.Sequential( 56 | nn.Conv2d( 57 | self.backbone.final_channels, 58 | num_compression_channels, 59 | kernel_size=3, 60 | padding=1, 61 | bias=False, 62 | ), 63 | nn.GroupNorm(1, num_compression_channels), 64 | nn.ReLU(True), 65 | ) 66 | 67 | output_shape = ( 68 | num_compression_channels, 69 | final_spatial, 70 | final_spatial, 71 | ) 72 | self.output_size = np.prod(output_shape) 73 | elif "vit" in backbone: 74 | if self.avgpooled_image: 75 | image_size = image_size // 2 76 | 77 | if checkpoint == "": 78 | make_backbone = getattr(vit, backbone) 79 | self.backbone = make_backbone( 80 | img_size=image_size, 81 | use_fc_norm=vit_use_fc_norm, 82 | global_pool=vit_global_pool, 83 | use_cls=vit_use_cls, 84 | mask_ratio=vit_mask_ratio, 85 | drop_path_rate=drop_path_rate, 86 | ) 87 | else: 88 | self.backbone = vit.load_ovrl_v2(checkpoint, img_size=image_size) 89 | 90 | if self.backbone.global_pool or self.backbone.use_cls: 91 | self.compression = nn.Identity() 92 | self.output_size = self.backbone.embed_dim 93 | else: 94 | assert self.backbone.mask_ratio is None 95 | final_spatial = int(self.backbone.patch_embed.num_patches**0.5) 96 | after_compression_flat_size = 2048 97 | num_compression_channels = int( 98 | round(after_compression_flat_size / (final_spatial**2)) 99 | ) 100 | self.compression = nn.Sequential( 101 | ViTReshape(), 102 | nn.Conv2d( 103 | self.backbone.embed_dim, 104 | num_compression_channels, 105 | kernel_size=3, 106 | padding=1, 107 | bias=False, 108 | ), 109 | nn.GroupNorm(1, num_compression_channels), 110 | nn.ReLU(True), 111 | ) 112 | 113 | self.output_shape = output_shape = ( 114 | num_compression_channels, 115 | final_spatial, 116 | final_spatial, 117 | ) 118 | self.output_size = np.prod(output_shape) 119 | else: 120 | raise ValueError("unknown backbone {}".format(backbone)) 121 | 122 | def forward( 123 | self, observations: "TensorDict", *args, **kwargs # type: ignore 124 | ) -> torch.Tensor: 125 | rgb = observations["rgb"] 126 | num_environments = rgb.size(0) 127 | x = self.visual_transform(rgb, num_environments) 128 | 129 | if ( 130 | self.avgpooled_image 131 | ): # For compatibility with the habitat_baselines implementation 132 | x = F.avg_pool2d(x, 2) 133 | x = self.running_mean_and_var(x) 134 | x = self.backbone(x) 135 | x = self.compression(x) 136 | return x 137 | 138 | # @property 139 | # def output_shape(self): 140 | # return (self.output_size,) 141 | 142 | 143 | class ViTReshape(nn.Module): 144 | def __init__(self): 145 | super().__init__() 146 | 147 | def forward(self, x): 148 | N, L, D = x.shape 149 | H = W = int(L**0.5) 150 | x = x.reshape(N, H, W, D) 151 | x = torch.einsum("nhwd->ndhw", x) 152 | return x 153 | -------------------------------------------------------------------------------- /ovon/utils/visualize/semantic_nav_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | 5 | import clip 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import seaborn as sns 9 | import torch 10 | 11 | 
from ovon.utils.utils import load_json 12 | 13 | PROMPT = "{category}" 14 | 15 | 16 | def clip_embeddings(clip_m, prompts): 17 | tokens = [] 18 | for prompt in prompts: 19 | tokens.append(clip.tokenize(prompt, context_length=77).numpy()) 20 | 21 | batch = torch.tensor(np.array(tokens)).cuda() 22 | with torch.no_grad(): 23 | text_embedding = clip_m.encode_text(batch.flatten(0, 1)).float() 24 | return text_embedding 25 | 26 | 27 | def max_similarity(clip_m, category, val_seen_categories): 28 | categories = val_seen_categories.copy() 29 | if category in categories: 30 | categories.remove(category) 31 | 32 | prompt = PROMPT.format(category=category) 33 | text_embedding = clip_embeddings(clip_m, [prompt] + categories) 34 | return ( 35 | torch.cosine_similarity(text_embedding[0].unsqueeze(0), text_embedding[1:]) 36 | .max() 37 | .item() 38 | ) 39 | 40 | 41 | def plot_scatterplot(x, y, output_path): 42 | sns.set_theme(style="whitegrid") 43 | 44 | fig, ax = plt.subplots(figsize=(6, 6)) 45 | sns.scatterplot(x=x, y=y, ax=ax) 46 | 47 | fig.savefig(output_path) 48 | 49 | 50 | def plot_barplot(x, y, output_path): 51 | sns.set_theme(style="whitegrid") 52 | 53 | fig, ax = plt.subplots(figsize=(6, 6)) 54 | sns.barplot(x=x, y=y, ax=ax) 55 | 56 | fig.savefig(output_path) 57 | 58 | 59 | def region_analysis(input_path, output_path): 60 | categories_by_region = load_json("data/hm3d_meta/hm3d_categories_by_region.json") 61 | 62 | region_per_category = {} 63 | for region, categories in categories_by_region.items(): 64 | for category in categories: 65 | region_per_category[category] = region 66 | 67 | files = [ 68 | "/coc/testnvme/nyokoyama3/public/ft_dagger_ckpt_30_vue.json", 69 | "/coc/testnvme/nyokoyama3/public/ft_dagger_ckpt_30_vuh.json", 70 | ] 71 | 72 | val_seen_categories = load_json("data/hm3d_meta/ovon_categories.json")["val_seen"] 73 | val_seen_categories = [ 74 | PROMPT.format(category=category) for category in val_seen_categories 75 | ] 76 | 77 | clip_m, preprocess = clip.load("RN50", device="cuda") 78 | 79 | for file in files: 80 | print("Split: {}".format(file.split("/")[-1])) 81 | metrics = load_json(file) 82 | 83 | episodes_per_region = defaultdict(list) 84 | success_per_region = defaultdict(int) 85 | all_categories = [] 86 | for k, meta in metrics.items(): 87 | region = region_per_category.get(meta["target"]) 88 | region = region if region is not None else "NaN" 89 | success_per_region[region] += meta["success"] 90 | episodes_per_region[region].append(meta) 91 | all_categories.append(meta["target"]) 92 | 93 | all_categories = list(set(all_categories)) 94 | 95 | category_to_max_similarity = {} 96 | for category in all_categories: 97 | region = region_per_category.get(category) 98 | if region is None: 99 | continue 100 | region_categories = categories_by_region[region] 101 | 102 | vs_region_categories = list( 103 | set(region_categories).intersection(set(val_seen_categories)) 104 | ) 105 | category_to_max_similarity[category] = round( 106 | max_similarity(clip_m, category, vs_region_categories), 2 107 | ) 108 | 109 | for region, episodes in episodes_per_region.items(): 110 | success_per_similarity = defaultdict(int) 111 | count_per_similarity = defaultdict(int) 112 | 113 | for episode in episodes: 114 | cos_sim = category_to_max_similarity.get(episode["target"]) 115 | if cos_sim is None: 116 | continue 117 | success_per_similarity[cos_sim] += episode["success"] 118 | count_per_similarity[cos_sim] += 1 119 | 120 | success_per_similarity = { 121 | k: v / count_per_similarity[k] 122 | for k, v in 
success_per_similarity.items() 123 | } 124 | 125 | out_path = os.path.join( 126 | output_path, "{}_success_vs_sim.png".format(region.replace("/", "_")) 127 | ) 128 | 129 | plot_scatterplot( 130 | list(count_per_similarity.keys()), 131 | list(success_per_similarity.values()), 132 | out_path, 133 | ) 134 | 135 | # print(category_to_max_similarity) 136 | 137 | # print({k: v / 3000 for k, v in success_per_region.items()}) 138 | # print({k: v / 3000 for k, v in episodes_per_region.items()}) 139 | # # print(list(set(categories))) 140 | # for region, categories in categories_by_region.items(): 141 | # intersection = set(all_categories).intersection(set(categories)) 142 | # if region == "bathroom": 143 | # print("Region: {}, Intersection: {}".format(region, intersection)) 144 | 145 | # print("Categories in NaN: {}".format(set(all_categories).intersection(set(categories_by_region["NaN"])))) 146 | # print("Regions: {}".format(set(categories_by_region.keys()))) 147 | 148 | 149 | def semantic_failures(): 150 | pass 151 | 152 | 153 | if __name__ == "__main__": 154 | parser = argparse.ArgumentParser() 155 | parser.add_argument( 156 | "--input-path", type=str, help="Path to the episode metrics file", required=True 157 | ) 158 | parser.add_argument( 159 | "--output-path", type=str, help="Path to the output file", required=True 160 | ) 161 | args = parser.parse_args() 162 | region_analysis(args.input_path, args.output_path) 163 | -------------------------------------------------------------------------------- /ovon/utils/visualize_trajectories.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import habitat 5 | import habitat_sim 6 | import numpy as np 7 | from habitat import get_config, logger 8 | from habitat.config import read_write 9 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 10 | from habitat.tasks.nav.shortest_path_follower import ShortestPathFollower 11 | from habitat.utils.visualizations.utils import ( 12 | images_to_video, 13 | observations_to_image, 14 | ) 15 | from habitat_sim.utils.common import quat_from_two_vectors 16 | from numpy import ndarray 17 | from tqdm import tqdm 18 | 19 | from ovon.config import ClipObjectGoalSensorConfig, OVONDistanceToGoalConfig 20 | from ovon.utils.visualize.viz import append_text_to_image 21 | 22 | 23 | def _face_object(object_position: np.array, point: ndarray): 24 | EPS_ARRAY = np.array([1e-8, 0.0, 1e-8]) 25 | cam_normal = (object_position - point) + EPS_ARRAY 26 | cam_normal[1] = 0 27 | cam_normal = cam_normal / np.linalg.norm(cam_normal) 28 | return quat_from_two_vectors(habitat_sim.geo.FRONT, cam_normal) 29 | 30 | 31 | def make_videos(observations_list, output_dir, id): 32 | images_to_video(observations_list[0], output_dir=output_dir, video_name=id) 33 | 34 | 35 | def get_nearest_goal(episode, env): 36 | min_dist = 1000.0 37 | sim = env.sim 38 | goal_key = "{}_{}".format(episode.scene_id.split("/")[-1], episode.object_category) 39 | goals = env._dataset.goals_by_category[goal_key] 40 | 41 | goal_location = None 42 | goal_rotation = None 43 | 44 | agent_position = sim.get_agent_state().position 45 | for goal in goals: 46 | for view_point in goal.view_points: 47 | position = view_point.agent_state.position 48 | 49 | dist = sim.geodesic_distance(agent_position, position) 50 | if min_dist > dist: 51 | min_dist = dist 52 | goal_location = position 53 | goal_rotation = view_point.agent_state.rotation 54 | return goal_location, goal_rotation 55 | 56 | 57 | def 
generate_trajectories(cfg, video_dir="", num_episodes=1): 58 | os.makedirs(video_dir, exist_ok=True) 59 | with habitat.Env(cfg) as env: 60 | goal_radius = 0.1 61 | spl = 0 62 | total_success = 0.0 63 | total_episodes = 0.0 64 | scene_id = env._current_episode.scene_id.split("/")[-1].split(".")[0] 65 | 66 | logger.info("Total episodes: {}".format(len(env.episodes))) 67 | num_episodes = min(len(env.episodes), num_episodes) 68 | for episode_id in tqdm(range(num_episodes)): 69 | follower = ShortestPathFollower(env._sim, goal_radius, False) 70 | env.reset() 71 | success = 0 72 | episode = env.current_episode 73 | goal_position, goal_rotation = get_nearest_goal(episode, env) 74 | 75 | info = {} 76 | obs_list = [] 77 | if goal_position is None: 78 | continue 79 | 80 | while not env.episode_over: 81 | best_action = follower.get_next_action(goal_position) 82 | 83 | if ( 84 | "distance_to_goal" in info.keys() 85 | and info["distance_to_goal"] < 0.1 86 | and best_action != HabitatSimActions.stop 87 | ): 88 | best_action = HabitatSimActions.stop 89 | 90 | observations = env.step(best_action) 91 | 92 | if best_action == HabitatSimActions.stop: 93 | position = env.sim.get_agent_state().position 94 | observations = env.sim.get_observations_at( 95 | position, goal_rotation, False 96 | ) 97 | 98 | info = env.get_metrics() 99 | frame = observations_to_image({"rgb": observations["rgb"]}, info) 100 | frame = append_text_to_image( 101 | frame, "Go to {}".format(episode.object_category) 102 | ) 103 | obs_list.append(frame) 104 | 105 | success = info["success"] 106 | 107 | print(info) 108 | total_success += success 109 | spl += info["spl"] 110 | total_episodes += 1 111 | 112 | make_videos([obs_list], video_dir, "{}_{}".format(scene_id, episode_id)) 113 | print("Total episodes: {}".format(total_episodes)) 114 | 115 | print("\n\nEpisode success: {}".format(total_success / total_episodes)) 116 | print("SPL: {}, {}, {}".format(spl / total_episodes, spl, total_episodes)) 117 | print( 118 | "Success: {}, {}, {}".format( 119 | total_success / total_episodes, total_success, total_episodes 120 | ) 121 | ) 122 | 123 | 124 | def main(): 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument("--data", type=str, default="data/episodes/sampled.json.gz") 127 | parser.add_argument("--video-dir", type=str, default="data/video_dir/") 128 | parser.add_argument("--num-episodes", type=int, default=2) 129 | args = parser.parse_args() 130 | 131 | objectnav_config = "config/tasks/objectnav_stretch_hm3d.yaml" 132 | config = get_config(objectnav_config) 133 | with read_write(config): 134 | config.habitat.dataset.type = "OVON-v1" 135 | config.habitat.dataset.split = "train" 136 | config.habitat.dataset.scenes_dir = "data/scene_datasets/" 137 | config.habitat.dataset.content_scenes = ["*"] 138 | config.habitat.dataset.data_path = args.data 139 | del config.habitat.task.lab_sensors["objectgoal_sensor"] 140 | config.habitat.task.lab_sensors["clip_objectgoal_sensor"] = ( 141 | ClipObjectGoalSensorConfig() 142 | ) 143 | config.habitat.task.measurements.distance_to_goal = OVONDistanceToGoalConfig() 144 | config.habitat.task.measurements.success.success_distance = 0.25 145 | 146 | generate_trajectories( 147 | config, video_dir=args.video_dir, num_episodes=args.num_episodes 148 | ) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /ovon/utils/sample_episodes.py: -------------------------------------------------------------------------------- 
1 | import argparse 2 | import glob 3 | import os 4 | import os.path as osp 5 | import random 6 | from collections import defaultdict 7 | 8 | from tqdm import tqdm 9 | 10 | from ovon.utils.utils import count_episodes, load_dataset, load_json, write_dataset 11 | 12 | 13 | def sample_val(input_path, output_path, max_episodes): 14 | files = glob.glob(osp.join(input_path, "*.json.gz")) 15 | len(files) 16 | 17 | count, categories = count_episodes(input_path) 18 | print("Total episodes: {}".format(count)) 19 | 20 | max_episodes_per_category = max_episodes // len(categories) 21 | print("Max episodes per category: {}".format(max_episodes_per_category)) 22 | 23 | episodes_per_category = defaultdict(list) 24 | for f in tqdm(files): 25 | dataset = load_dataset(f) 26 | for episode in dataset["episodes"]: 27 | episodes_per_category[episode["object_category"]].append(episode) 28 | 29 | sampled_buffer = [] 30 | episodes_per_scene = defaultdict(list) 31 | total_episodes = 0 32 | for category, episodes in episodes_per_category.items(): 33 | random.shuffle(episodes) 34 | episodes_per_category[category] = episodes[:max_episodes_per_category] 35 | sampled_buffer.extend(episodes[max_episodes_per_category:]) 36 | total_episodes += len(episodes_per_category[category]) 37 | for episode in episodes_per_category[category]: 38 | episodes_per_scene[episode["scene_id"]].append(episode) 39 | 40 | if total_episodes != max_episodes: 41 | missing_episodes = max_episodes - total_episodes 42 | sampled_episodes = random.sample(sampled_buffer, missing_episodes) 43 | for episode in sampled_episodes: 44 | episodes_per_scene[episode["scene_id"]].append(episode) 45 | 46 | num_added = 0 47 | for idx, file in enumerate(tqdm(files)): 48 | dataset = load_dataset(file) 49 | scene_id = dataset["episodes"][0]["scene_id"] 50 | 51 | dataset["episodes"] = episodes_per_scene[scene_id] 52 | num_added += len(dataset["episodes"]) 53 | 54 | output_file = osp.join(output_path, osp.basename(file)) 55 | print(f"Copied {len(dataset['episodes'])} episodes to {output_file}!") 56 | write_dataset(dataset, output_file) 57 | 58 | print(f"Added {num_added} episodes in total!") 59 | 60 | 61 | def sample_custom(input_path, output_path, episode_meta_file): 62 | files = glob.glob(osp.join(input_path, "*.json.gz")) 63 | 64 | episode_meta = load_json(episode_meta_file) 65 | 66 | os.makedirs(output_path, exist_ok=True) 67 | 68 | num_added = 0 69 | eps_per_scene = [] 70 | for idx, file in enumerate(tqdm(files)): 71 | dataset = load_dataset(file) 72 | scene_id = file.split("/")[-1].split(".")[0] 73 | print(file, scene_id) 74 | 75 | episodes_by_category = defaultdict(list) 76 | for episode in dataset["episodes"]: 77 | episodes_by_category[episode["object_category"]].append(episode) 78 | 79 | dataset["episodes"] = [] 80 | for category in episode_meta[scene_id]: 81 | min_episodes = min( 82 | episode_meta[scene_id][category], len(episodes_by_category[category]) 83 | ) 84 | sampled_episodes = random.sample( 85 | episodes_by_category[category], min_episodes 86 | ) 87 | if min_episodes < episode_meta[scene_id][category]: 88 | print(f"Warning: not enough episodes for {scene_id} {category}!") 89 | dataset["episodes"].extend(sampled_episodes) 90 | 91 | output_file = osp.join(output_path, osp.basename(file)) 92 | print(f"Copied {len(dataset['episodes'])} episodes to {output_file}!") 93 | write_dataset(dataset, output_file) 94 | 95 | num_added += len(dataset["episodes"]) 96 | 97 | print(f"Added {num_added} episodes in total!") 98 | print(f"Episodes per scene: 
{eps_per_scene}") 98 | 99 | 100 | def main(input_path, output_path, max_episodes): 101 | files = glob.glob(osp.join(input_path, "*.json.gz")) 102 | num_gz_files = len(files) 103 | 104 | os.makedirs(output_path, exist_ok=True) 105 | 106 | num_added = 0 107 | eps_per_scene = [] 108 | for idx, file in enumerate(tqdm(files)): 109 | dataset = load_dataset(file) 110 | random.shuffle(dataset["episodes"]) 111 | 112 | num_left = max_episodes - num_added 113 | num_gz_remaining = num_gz_files - idx 114 | num_needed = min(num_left / num_gz_remaining, len(dataset["episodes"])) 115 | eps_per_scene.append(num_needed) 116 | 117 | sampled_episodes = random.sample(dataset["episodes"], int(num_needed)) 118 | num_added += len(sampled_episodes) 119 | 120 | dataset["episodes"] = sampled_episodes 121 | 122 | output_file = osp.join(output_path, osp.basename(file)) 123 | print(f"Copied {len(sampled_episodes)} episodes to {output_file}!") 124 | write_dataset(dataset, output_file) 125 | 126 | print(f"Added {num_added} episodes in total!") 127 | print(f"Episodes per scene: {eps_per_scene}") 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument( 133 | "--input-path", 134 | type=str, 135 | help="Path to episode dir containing content/", 136 | ) 137 | parser.add_argument( 138 | "--output-path", 139 | type=str, 140 | help="Path to episode dir containing content/", 141 | ) 142 | parser.add_argument( 143 | "--episode-meta-file", 144 | type=str, 145 | help="Path to num episode per category per scene meta file", 146 | ) 147 | parser.add_argument("--max-episodes", type=int) 148 | parser.add_argument("--val", dest="is_val", action="store_true") 149 | args = parser.parse_args() 150 | 151 | if args.episode_meta_file is not None: 152 | sample_custom(args.input_path, args.output_path, args.episode_meta_file) 153 | elif args.is_val: 154 | sample_val(args.input_path, args.output_path, args.max_episodes) 155 | else: 156 | main(args.input_path, args.output_path, args.max_episodes) 157 | --------------------------------------------------------------------------------