├── ss_baselines ├── __init__.py ├── common │ ├── __init__.py │ ├── baseline_registry.py │ ├── tensorboard_utils.py │ ├── benchmark.py │ ├── environments.py │ ├── env_utils.py │ ├── simple_agents.py │ ├── rollout_storage.py │ ├── utils.py │ ├── base_trainer.py │ └── sync_vector_env.py └── av_nav │ ├── models │ ├── __init__.py │ ├── audio_cnn.py │ ├── rnn_state_encoder.py │ └── visual_cnn.py │ └── __init__.py ├── soundspaces ├── datasets │ ├── __init__.py │ └── audionav_dataset.py ├── tasks │ ├── __init__.py │ └── audionav_task.py ├── visualizations │ ├── __init__.py │ ├── utils.py │ └── maps.py ├── __init__.py ├── utils.py ├── action_space.py └── simulator.py ├── saavn.png ├── simulator └── __init__.py ├── trainer ├── __init__.py └── ppo │ ├── __init__.py │ ├── ppo.py │ └── policy.py ├── envs ├── __init__.py └── environments.py ├── storage ├── __init__.py └── rollout_storage.py ├── LICENSE ├── readme.md ├── .gitignore ├── dataset.md ├── main.py └── README.md /ss_baselines/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces/tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ss_baselines/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces/visualizations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ss_baselines/av_nav/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /saavn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yyf17/SAAVN/HEAD/saavn.png -------------------------------------------------------------------------------- /ss_baselines/av_nav/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /simulator/__init__.py: -------------------------------------------------------------------------------- 1 | from simulator.simulator import SoundSpaces 2 | 3 | -------------------------------------------------------------------------------- /trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import git 3 | import sys 4 | repo = git.Repo(".", search_parent_directories=True) 5 | if f"{repo.working_tree_dir}" not in sys.path: 6 | sys.path.append(f"{repo.working_tree_dir}") 7 | print("add") 8 | 9 | 10 | from trainer.ppo.ppo_trainer import PPOTrainer 11 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import git 4 | import sys 5 | repo = git.Repo(".", 
search_parent_directories=True) 6 | if f"{repo.working_tree_dir}" not in sys.path: 7 | sys.path.append(f"{repo.working_tree_dir}") 8 | print("add") 9 | 10 | 11 | 12 | from envs.environments import NavRLEnv 13 | -------------------------------------------------------------------------------- /trainer/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | 4 | import os 5 | import git 6 | import sys 7 | repo = git.Repo(".", search_parent_directories=True) 8 | if f"{repo.working_tree_dir}" not in sys.path: 9 | sys.path.append(f"{repo.working_tree_dir}") 10 | print("add") 11 | 12 | from trainer.ppo.ppo_trainer import PPOTrainer 13 | -------------------------------------------------------------------------------- /storage/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import git 3 | import sys 4 | repo = git.Repo(".", search_parent_directories=True) 5 | if f"{repo.working_tree_dir}" not in sys.path: 6 | sys.path.append(f"{repo.working_tree_dir}") 7 | print("add") 8 | 9 | from .rollout_storage import RolloutStorage, RolloutStorageHybrid, RolloutStorageMA 10 | 11 | __all__ =[ 12 | "RolloutStorage", 13 | "RolloutStorageHybrid", 14 | "RolloutStorageMA", 15 | ] -------------------------------------------------------------------------------- /soundspaces/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from soundspaces.action_space import MoveOnlySpaceConfiguration 3 | from soundspaces.simulator import SoundSpaces 4 | from soundspaces.datasets.audionav_dataset import AudioNavDataset 5 | from soundspaces.tasks.audionav_task import AudioNavigationTask 6 | from soundspaces.tasks.audionav_task import AudioGoalSensor 7 | from soundspaces.tasks.audionav_task import SpectrogramSensor 8 | from soundspaces.tasks.audionav_task import Collision 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 YinfengYu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /soundspaces/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import pickle 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def load_metadata(parent_folder): 9 | points_file = os.path.join(parent_folder, 'points.txt') 10 | if "replica" in parent_folder: 11 | graph_file = os.path.join(parent_folder, 'graph.pkl') 12 | points_data = np.loadtxt(points_file, delimiter="\t") 13 | points = list(zip( 14 | points_data[:, 1], 15 | points_data[:, 3] - 1.5528907, 16 | -points_data[:, 2]) 17 | ) 18 | else: 19 | graph_file = os.path.join(parent_folder, 'graph.pkl') 20 | points_data = np.loadtxt(points_file, delimiter="\t") 21 | points = list(zip( 22 | points_data[:, 1], 23 | points_data[:, 3] - 1.5, 24 | -points_data[:, 2]) 25 | ) 26 | if not os.path.exists(graph_file): 27 | raise FileNotFoundError(graph_file + ' does not exist!') 28 | else: 29 | with open(graph_file, 'rb') as fo: 30 | graph = pickle.load(fo) 31 | 32 | return points, graph 33 | 34 | 35 | def _to_tensor(v): 36 | if torch.is_tensor(v): 37 | return v 38 | elif isinstance(v, np.ndarray): 39 | return torch.from_numpy(v) 40 | else: 41 | return torch.tensor(v, dtype=torch.float) 42 | -------------------------------------------------------------------------------- /ss_baselines/common/baseline_registry.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional 4 | 5 | from habitat.core.registry import Registry 6 | 7 | 8 | class BaselineRegistry(Registry): 9 | @classmethod 10 | def register_trainer(cls, to_register=None, *, name: Optional[str] = None): 11 | r"""Register an RL training algorithm to the registry with key 'name'. 12 | 13 | Args: 14 | name: Key with which the trainer will be registered. 15 | If None will use the name of the class. 16 | 17 | """ 18 | from ss_baselines.common.base_trainer import BaseTrainer 19 | 20 | return cls._register_impl( 21 | "trainer", to_register, name, assert_type=BaseTrainer 22 | ) 23 | 24 | @classmethod 25 | def get_trainer(cls, name): 26 | return cls._get_impl("trainer", name) 27 | 28 | @classmethod 29 | def register_env(cls, to_register=None, *, name: Optional[str] = None): 30 | r"""Register an environment to the registry with key 'name'; 31 | currently only supports subclasses of RLEnv. 32 | 33 | Args: 34 | name: Key with which the env will be registered. 35 | If None will use the name of the class. 36 | 37 | """ 38 | 39 | return cls._register_impl("env", to_register, name) 40 | 41 | @classmethod 42 | def get_env(cls, name): 43 | return cls._get_impl("env", name) 44 | 45 | 46 | baseline_registry = BaselineRegistry() 47 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Usage 2 | This repo supports the AudioGoal task on the Replica and Matterport3D datasets. 3 | 4 | Below we show the commands for training and evaluating AudioGoal with a Depth sensor on Replica, 5 | but they apply to the Matterport3D dataset as well. 6 | 1. Training 7 | ``` 8 | python main.py --default av_nav --run-type train --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 9 | ``` 10 | 2. 
Validation (evaluate each checkpoint and generate a validation curve) 11 | ``` 12 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 13 | ``` 14 | 3. Test the best validation checkpoint based on the validation curve 15 | ``` 16 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 17 | ``` 18 | 4. Generate a demo video with audio 19 | ``` 20 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 21 | ``` 22 | 23 | Note: [exp_config_file] is the main parameter configuration file of the experiment, while [tag_config_file] is a special parameter configuration file for ablation experiments. -------------------------------------------------------------------------------- /soundspaces/action_space.py: -------------------------------------------------------------------------------- 1 | 2 | import habitat_sim 3 | from habitat.core.registry import registry 4 | 5 | from habitat.core.simulator import ActionSpaceConfiguration 6 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 7 | 8 | HabitatSimActions.extend_action_space("MOVE_BACKWARD") 9 | HabitatSimActions.extend_action_space("MOVE_LEFT") 10 | HabitatSimActions.extend_action_space("MOVE_RIGHT") 11 | 12 | 13 | @registry.register_action_space_configuration(name="move-all") 14 | class MoveOnlySpaceConfiguration(ActionSpaceConfiguration): 15 | def get(self): 16 | return { 17 | HabitatSimActions.STOP: habitat_sim.ActionSpec("stop"), 18 | HabitatSimActions.MOVE_FORWARD: habitat_sim.ActionSpec( 19 | "move_forward", 20 | habitat_sim.ActuationSpec( 21 | amount=self.config.FORWARD_STEP_SIZE 22 | ), 23 | ), 24 | HabitatSimActions.MOVE_BACKWARD: habitat_sim.ActionSpec( 25 | "move_backward", 26 | habitat_sim.ActuationSpec( 27 | amount=self.config.FORWARD_STEP_SIZE 28 | ), 29 | ), 30 | HabitatSimActions.MOVE_RIGHT: habitat_sim.ActionSpec( 31 | "move_right", 32 | habitat_sim.ActuationSpec( 33 | amount=self.config.FORWARD_STEP_SIZE 34 | ), 35 | ), 36 | HabitatSimActions.MOVE_LEFT: habitat_sim.ActionSpec( 37 | "move_left", 38 | habitat_sim.ActuationSpec( 39 | amount=self.config.FORWARD_STEP_SIZE 40 | ), 41 | ) 42 | } 43 | -------------------------------------------------------------------------------- /ss_baselines/common/tensorboard_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Any 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | class TensorboardWriter: 11 | def __init__(self, log_dir: str, *args: Any, **kwargs: Any): 12 | r"""A wrapper for the tensorboard SummaryWriter. It creates a dummy writer 13 | when log_dir is an empty string or None. It also has functionality that 14 | generates tb video directly from numpy images. 15 | 16 | Args: 17 | log_dir: Save directory location. Will not write to disk if 18 | log_dir is an empty string.
19 | *args: Additional positional args for SummaryWriter 20 | **kwargs: Additional keyword args for SummaryWriter 21 | """ 22 | self.writer = None 23 | if log_dir is not None and len(log_dir) > 0: 24 | self.writer = SummaryWriter(log_dir, *args, **kwargs) 25 | 26 | def __getattr__(self, item): 27 | if self.writer: 28 | return self.writer.__getattribute__(item) 29 | else: 30 | return lambda *args, **kwargs: None 31 | 32 | def __enter__(self): 33 | return self 34 | 35 | def __exit__(self, exc_type, exc_val, exc_tb): 36 | if self.writer: 37 | self.writer.close() 38 | 39 | def add_video_from_np_images( 40 | self, video_name: str, step_idx: int, images: np.ndarray, fps: int = 10 41 | ) -> None: 42 | r"""Write video into tensorboard from images frames. 43 | 44 | Args: 45 | video_name: name of video string. 46 | step_idx: int of checkpoint index to be displayed. 47 | images: list of n frames. Each frame is a np.ndarray of shape. 48 | fps: frame per second for output video. 49 | 50 | Returns: 51 | None. 52 | """ 53 | if not self.writer: 54 | return 55 | # initial shape of np.ndarray list: N * (H, W, 3) 56 | frame_tensors = [ 57 | torch.from_numpy(np_arr).unsqueeze(0) for np_arr in images 58 | ] 59 | video_tensor = torch.cat(tuple(frame_tensors)) 60 | video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0) 61 | # final shape of video tensor: (1, n, 3, H, W) 62 | self.writer.add_video( 63 | video_name, video_tensor, fps=fps, global_step=step_idx 64 | ) 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ss_baselines/av_nav/models/audio_cnn.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ss_baselines.common.utils import Flatten 7 | from ss_baselines.av_nav.models.visual_cnn import conv_output_dim, layer_init 8 | 9 | 10 | class AudioCNN(nn.Module): 11 | r"""A Simple 3-Conv CNN for processing audio spectrogram features 12 | 13 | Args: 14 | observation_space: The observation_space of the agent 15 | output_size: The size of the embedding vector 16 | """ 17 | 18 | def __init__(self, observation_space, output_size, audiogoal_sensor): 19 | super(AudioCNN, self).__init__() 20 | self._n_input_audio = observation_space.spaces[audiogoal_sensor].shape[2] 21 | self._audiogoal_sensor = audiogoal_sensor 22 | 23 | cnn_dims = np.array( 24 | observation_space.spaces[audiogoal_sensor].shape[:2], dtype=np.float32 25 | ) 26 | 27 | if cnn_dims[0] < 30 or cnn_dims[1] < 30: 28 | self._cnn_layers_kernel_size = [(5, 5), (3, 3), (3, 3)] 29 | self._cnn_layers_stride = [(2, 2), (2, 2), (1, 1)] 30 | else: 31 | self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] 32 | self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)] 33 | 34 | for kernel_size, stride in zip( 35 | self._cnn_layers_kernel_size, self._cnn_layers_stride 36 | ): 37 | cnn_dims = conv_output_dim( 38 | dimension=cnn_dims, 39 | padding=np.array([0, 0], dtype=np.float32), 40 | dilation=np.array([1, 1], dtype=np.float32), 41 | kernel_size=np.array(kernel_size, dtype=np.float32), 42 | stride=np.array(stride, dtype=np.float32), 43 | ) 44 | 45 | self.cnn = nn.Sequential( 46 | nn.Conv2d( 47 | in_channels=self._n_input_audio, 48 | out_channels=32, 49 | kernel_size=self._cnn_layers_kernel_size[0], 50 | stride=self._cnn_layers_stride[0], 51 | ), 52 | nn.ReLU(True), 53 | nn.Conv2d( 54 | in_channels=32, 55 | out_channels=64, 56 | kernel_size=self._cnn_layers_kernel_size[1], 57 | stride=self._cnn_layers_stride[1], 58 | ), 59 | nn.ReLU(True), 60 | nn.Conv2d( 61 | in_channels=64, 62 | out_channels=64, 63 | kernel_size=self._cnn_layers_kernel_size[2], 64 | stride=self._cnn_layers_stride[2], 65 | ), 66 | # nn.ReLU(True), 67 | Flatten(), 68 | nn.Linear(64 * cnn_dims[0] * cnn_dims[1], output_size), 69 | nn.ReLU(True), 70 | ) 71 | 72 | layer_init(self.cnn) 73 | 74 | def forward(self, observations): 75 | cnn_input = [] 76 | 77 | audio_observations = observations[self._audiogoal_sensor] 78 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 79 | audio_observations = audio_observations.permute(0, 3, 1, 2) 80 | cnn_input.append(audio_observations) 81 | 82 | cnn_input = torch.cat(cnn_input, dim=1) 83 | 84 | return self.cnn(cnn_input) 85 | -------------------------------------------------------------------------------- /envs/environments.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional, Type 4 | import logging 5 | 6 | import habitat 7 | from habitat import Config, Dataset 8 | 9 | from ss_baselines.common.baseline_registry import baseline_registry 10 | 11 | @baseline_registry.register_env(name="NavRLEnv") 12 | class NavRLEnv(habitat.RLEnv): 13 | def __init__(self, config: Config, dataset: Optional[Dataset] = None): 14 | self._rl_config = config.RL 15 | self._core_env_config = config.TASK_CONFIG 16 | 17 | self._previous_target_distance = None 18 | self._previous_action = None 19 | self._episode_distance_covered = None 20 | self._success_distance = self._core_env_config.TASK.SUCCESS_DISTANCE 21 | super().__init__(self._core_env_config, dataset) 22 | 23 | # @profile 24 | def reset(self): 25 | self._previous_action = None 26 | 27 | observations = super().reset() 28 | logging.debug(super().current_episode) 29 | 30 | self._previous_target_distance = self.habitat_env.current_episode.info[ 31 | "geodesic_distance" 32 | ] 33 | return observations 34 | 35 | def step(self, *args, **kwargs): 36 | self._previous_action = kwargs["action"] 37 | return super().step(*args, **kwargs) 38 | 39 | def get_reward_range(self): 40 | return ( 41 | self._rl_config.SLACK_REWARD - 1.0, 42 | self._rl_config.SUCCESS_REWARD + 1.0, 43 | ) 44 | 45 | def get_reward(self, observations): 46 | reward = 0 47 | 48 | if self._rl_config.WITH_TIME_PENALTY: 49 | reward += self._rl_config.SLACK_REWARD 50 | 51 | if self._rl_config.WITH_DISTANCE_REWARD: 52 | current_target_distance = self._distance_target() 53 | # if current_target_distance < self._previous_target_distance: 54 | reward += (self._previous_target_distance - current_target_distance) * self._rl_config.DISTANCE_REWARD_SCALE 55 | self._previous_target_distance = current_target_distance 56 | 57 | if self._episode_success(): 58 | reward += self._rl_config.SUCCESS_REWARD 59 | logging.debug('Reaching goal!') 60 | 61 | return reward 62 | 63 | def _distance_target(self): 64 | current_position = self._env.sim.get_agent_state().position.tolist() 65 | target_position = self._env.current_episode.goals[0].position 66 | distance = self._env.sim.geodesic_distance( 67 | current_position, target_position 68 | ) 69 | return distance 70 | 71 | def _episode_success(self): 72 | if ( 73 | self._env.task.is_stop_called 74 | # and self._distance_target() < self._success_distance 75 | and self._env.sim.reaching_goal 76 | ): 77 | return True 78 | return False 79 | 80 | def get_done(self, observations): 81 | done = False 82 | if self._env.episode_over or self._episode_success(): 83 | done = True 84 | return done 85 | 86 | def get_info(self, observations): 87 | return self.habitat_env.get_metrics() 88 | 89 | # for data collection 90 | def get_current_episode_id(self): 91 | return self.habitat_env.current_episode.episode_id 92 | -------------------------------------------------------------------------------- /dataset.md: -------------------------------------------------------------------------------- 1 | # SoundSpaces Dataset 2 | 3 | ## Overview 4 | The SoundSpaces dataset includes audio renderings (room impulse responses) for two datasets, metadata of each scene, episode datasets and mono sound files. 5 | 6 | 7 | ## Download 8 | 0. Create a folder named "data" under root directory 9 | 1. Run the commands below in the **data** directory to download partial binaural RIRs (867G), metadata (1M), datasets (77M) and sound files (13M). 
Note that these partial binaural RIRs only contain renderings for nodes accessible by the agent on the navigation graph. 10 | ``` 11 | wget http://dl.fbaipublicfiles.com/SoundSpaces/binaural_rirs.tar && tar xvf binaural_rirs.tar 12 | wget http://dl.fbaipublicfiles.com/SoundSpaces/metadata.tar.xz && tar xvf metadata.tar.xz 13 | wget http://dl.fbaipublicfiles.com/SoundSpaces/sounds.tar.xz && tar xvf sounds.tar.xz 14 | wget http://dl.fbaipublicfiles.com/SoundSpaces/datasets.tar.xz && tar xvf datasets.tar.xz 15 | wget http://dl.fbaipublicfiles.com/SoundSpaces/pretrained_weights.tar.xz && tar xvf pretrained_weights.tar.xz 16 | ``` 17 | 2. Download [Replica-Dataset](https://github.com/facebookresearch/Replica-Dataset) and [Matterport3D](https://niessner.github.io/Matterport). 18 | 3. Run the command below in the root directory to cache observations for the two datasets 19 | ``` 20 | python scripts/cache_observations.py 21 | ``` 22 | 4. (Optional) Download the full ambisonic (3.6T for Matterport) and binaural (682G for Matterport and 81G for Replica) RIR data by running the following script in the root directory. Remember to first back up the downloaded binaural RIR data. 23 | ``` 24 | python scripts/download_data.py --dataset mp3d --rir-type binaural_rirs 25 | python scripts/download_data.py --dataset replica --rir-type binaural_rirs 26 | ``` 27 | 28 | 29 | ## Data Folder Structure 30 | ``` 31 | . 32 | ├── ... 33 | ├── metadata # stores metadata of environments 34 | │ └── [dataset] 35 | │ └── [scene] 36 | │ ├── points.txt # coordinates of all points in mesh coordinates 37 | │ ├── graph.pkl # points are pruned to a connectivity graph 38 | ├── binaural_rirs # binaural RIRs of 2 channels 39 | │ └── [dataset] 40 | │ └── [scene] 41 | │ └── [angle] # azimuth angle of agent's heading in mesh coordinates 42 | │ └── [receiver]-[source].wav 43 | ├── datasets # stores datasets of episodes of different splits 44 | │ └── [dataset] 45 | │ └── [version] 46 | │ └── [split] 47 | │ ├── [split].json.gz 48 | │ └── content 49 | │ └── [scene].json.gz 50 | ├── sounds # stores all 102 copyright-free sounds 51 | │ └── 1s_all 52 | │ └── [sound].wav 53 | ├── scene_datasets # scene_datasets 54 | │ └── [dataset] 55 | │ └── [scene] 56 | │ └── [scene].house (habitat/mesh_semantic.glb) 57 | └── scene_observations # pre-rendered scene observations 58 | │ └── [dataset] 59 | │ └── [scene].pkl # dictionary is in the format of {(receiver, rotation): sim_obs} 60 | ``` 61 | -------------------------------------------------------------------------------- /ss_baselines/common/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import defaultdict 4 | from typing import Dict, Optional 5 | import logging 6 | 7 | from tqdm import tqdm 8 | 9 | from habitat import Config 10 | from habitat.core.agent import Agent 11 | # from habitat.core.env import Env 12 | from ss_baselines.common.environments import NavRLEnv 13 | from habitat.datasets import make_dataset 14 | 15 | 16 | class Benchmark: 17 | r"""Benchmark for evaluating agents in environments. 18 | """ 19 | 20 | def __init__(self, task_config: Optional[Config] = None) -> None: 21 | r""".. 
22 | 23 | :param task_config: config to be used for creating the environment 24 | """ 25 | dummy_config = Config() 26 | dummy_config.RL = Config() 27 | dummy_config.RL.SLACK_REWARD = -0.01 28 | dummy_config.RL.SUCCESS_REWARD = 10 29 | dummy_config.RL.WITH_TIME_PENALTY = True 30 | dummy_config.RL.DISTANCE_REWARD_SCALE = 1 31 | dummy_config.RL.WITH_DISTANCE_REWARD = True 32 | dummy_config.RL.defrost() 33 | dummy_config.TASK_CONFIG = task_config 34 | dummy_config.freeze() 35 | 36 | dataset = make_dataset(id_dataset=task_config.DATASET.TYPE, config=task_config.DATASET) 37 | self._env = NavRLEnv(config=dummy_config, dataset=dataset) 38 | 39 | def evaluate( 40 | self, agent: Agent, num_episodes: Optional[int] = None 41 | ) -> Dict[str, float]: 42 | r""".. 43 | 44 | :param agent: agent to be evaluated in environment. 45 | :param num_episodes: count of number of episodes for which the 46 | evaluation should be run. 47 | :return: dict containing metrics tracked by environment. 48 | """ 49 | 50 | if num_episodes is None: 51 | num_episodes = len(self._env.episodes) 52 | else: 53 | assert num_episodes <= len(self._env.episodes), ( 54 | "num_episodes({}) is larger than number of episodes " 55 | "in environment ({})".format( 56 | num_episodes, len(self._env.episodes) 57 | ) 58 | ) 59 | 60 | assert num_episodes > 0, "num_episodes should be greater than 0" 61 | 62 | agg_metrics: Dict = defaultdict(float) 63 | 64 | count_episodes = 0 65 | reward_episodes = 0 66 | step_episodes = 0 67 | success_count = 0 68 | for count_episodes in tqdm(range(1, num_episodes + 1)): # 1-based so the averages below divide by the number of completed episodes 69 | agent.reset() 70 | observations = self._env.reset() 71 | episode_reward = 0 72 | 73 | while not self._env.habitat_env.episode_over: 74 | action = agent.act(observations) 75 | observations, reward, done, info = self._env.step(**action) 76 | logging.debug("Reward: {}".format(reward)) 77 | episode_reward += reward # accumulate before logging so the final reward is included 78 | if done: 79 | logging.debug('Episode reward: {}'.format(episode_reward)) 80 | step_episodes += 1 81 | 82 | metrics = self._env.habitat_env.get_metrics() 83 | for m, v in metrics.items(): 84 | agg_metrics[m] += v 85 | reward_episodes += episode_reward 86 | success_count += metrics['spl'] > 0 87 | 88 | avg_metrics = {k: v / count_episodes for k, v in agg_metrics.items()} 89 | logging.info("Average reward: {} in {} episodes".format(reward_episodes / count_episodes, count_episodes)) 90 | logging.info("Average episode steps: {}".format(step_episodes / count_episodes)) 91 | logging.info('Success rate: {}'.format(success_count / num_episodes)) 92 | 93 | return avg_metrics 94 | -------------------------------------------------------------------------------- /ss_baselines/common/environments.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional, Type 4 | import logging 5 | 6 | import habitat 7 | from habitat import Config, Dataset 8 | from ss_baselines.common.baseline_registry import baseline_registry 9 | 10 | 11 | def get_env_class(env_name: str) -> Type[habitat.RLEnv]: 12 | r"""Return environment class based on name. 13 | 14 | Args: 15 | env_name: name of the environment. 16 | 17 | Returns: 18 | Type[habitat.RLEnv]: env class.
19 | """ 20 | return baseline_registry.get_env(env_name) 21 | 22 | 23 | @baseline_registry.register_env(name="NavRLEnv") 24 | class NavRLEnv(habitat.RLEnv): 25 | def __init__(self, config: Config, dataset: Optional[Dataset] = None): 26 | self._rl_config = config.RL 27 | self._core_env_config = config.TASK_CONFIG 28 | 29 | self._previous_target_distance = None 30 | self._previous_action = None 31 | self._episode_distance_covered = None 32 | self._success_distance = self._core_env_config.TASK.SUCCESS_DISTANCE 33 | super().__init__(self._core_env_config, dataset) 34 | 35 | def reset(self): 36 | self._previous_action = None 37 | 38 | observations = super().reset() 39 | logging.debug(super().current_episode) 40 | 41 | self._previous_target_distance = self.habitat_env.current_episode.info[ 42 | "geodesic_distance" 43 | ] 44 | return observations 45 | 46 | def step(self, *args, **kwargs): 47 | self._previous_action = kwargs["action"] 48 | return super().step(*args, **kwargs) 49 | 50 | def get_reward_range(self): 51 | return ( 52 | self._rl_config.SLACK_REWARD - 1.0, 53 | self._rl_config.SUCCESS_REWARD + 1.0, 54 | ) 55 | 56 | def get_reward(self, observations): 57 | reward = 0 58 | 59 | if self._rl_config.WITH_TIME_PENALTY: 60 | reward += self._rl_config.SLACK_REWARD 61 | 62 | if self._rl_config.WITH_DISTANCE_REWARD: 63 | current_target_distance = self._distance_target() 64 | # if current_target_distance < self._previous_target_distance: 65 | reward += (self._previous_target_distance - current_target_distance) * self._rl_config.DISTANCE_REWARD_SCALE 66 | self._previous_target_distance = current_target_distance 67 | 68 | if self._episode_success(): 69 | reward += self._rl_config.SUCCESS_REWARD 70 | logging.debug('Reaching goal!') 71 | 72 | return reward 73 | 74 | def _distance_target(self): 75 | current_position = self._env.sim.get_agent_state().position.tolist() 76 | target_positions = [goal.position for goal in self._env.current_episode.goals] 77 | distance = self._env.sim.geodesic_distance( 78 | current_position, target_positions 79 | ) 80 | return distance 81 | 82 | def _episode_success(self): 83 | if ( 84 | self._env.task.is_stop_called 85 | # and self._distance_target() < self._success_distance 86 | and self._env.sim.reaching_goal 87 | ): 88 | return True 89 | return False 90 | 91 | def get_done(self, observations): 92 | done = False 93 | if self._env.episode_over or self._episode_success(): 94 | done = True 95 | return done 96 | 97 | def get_info(self, observations): 98 | return self.habitat_env.get_metrics() 99 | 100 | # for data collection 101 | def get_current_episode_id(self): 102 | return self.habitat_env.current_episode.episode_id 103 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import git 4 | import sys 5 | repo = git.Repo(".", search_parent_directories=True) 6 | # print(repo.working_tree_dir) 7 | if f"{repo.working_tree_dir}" not in sys.path: 8 | sys.path.append(f"{repo.working_tree_dir}") 9 | print("add") 10 | 11 | import argparse 12 | import logging 13 | 14 | import warnings 15 | warnings.filterwarnings('ignore', category=FutureWarning) 16 | warnings.filterwarnings('ignore', category=UserWarning) 17 | import tensorflow as tf 18 | import torch 19 | 20 | 21 | from ss_baselines.common.baseline_registry import baseline_registry 22 | from habitat.core.registry import registry 23 | 24 | from contextlib import 
redirect_stdout 25 | 26 | from simulator import get_simulator_class 27 | from envs import get_env_class 28 | from trainer import get_trainer_class 29 | 30 | 31 | from copy import deepcopy 32 | import sys 33 | def get_default_config_by_arg(arg_name="--default"): 34 | 35 | cmd_params = deepcopy(sys.argv) 36 | 37 | print("sys.argv:", sys.argv) 38 | 39 | default_config_str = None 40 | for _i, _v in enumerate(cmd_params): 41 | if _v.split(" ")[0] == arg_name: 42 | default_config_str = _v.split(" ")[1] 43 | del cmd_params[_i] 44 | break 45 | 46 | assert default_config_str is not None, "default config is not specified" 47 | return default_config_str 48 | 49 | 50 | 51 | DEFAULT_CONFIG_DIR = "configs/" 52 | 53 | def main(): 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument( 56 | "--default", 57 | choices=["ma", "av_nav"], # a bare string such as "ma" would be treated by argparse as the characters 'm' and 'a' 58 | # required=True, 59 | default="ma", 60 | help="default config of the experiment passed to the get_config function", 61 | ) 62 | parser.add_argument( 63 | "--run-type", 64 | choices=["train", "eval"], 65 | # required=True, 66 | default='train', 67 | help="run type of the experiment (train or eval)", 68 | ) 69 | parser.add_argument( 70 | "--exp-config", 71 | type=str, 72 | # required=True, 73 | default='av_nav/config/pointgoal_rgb.yaml', 74 | help="path to config yaml containing info about experiment", 75 | ) 76 | parser.add_argument( 77 | "--tag-config", 78 | type=str, 79 | # required=True, 80 | default='', 81 | help="path to config yaml containing info about experiment with tag", 82 | ) 83 | parser.add_argument( 84 | "--model-dir", 85 | default=None, 86 | help="directory in which checkpoints and logs are saved", 87 | ) 88 | parser.add_argument( 89 | "--eval-interval", 90 | type=int, 91 | default=1, 92 | help="Evaluation interval of checkpoints", 93 | ) 94 | parser.add_argument( 95 | "--overwrite", 96 | default=False, 97 | action='store_true', 98 | help="overwrite an existing model directory" 99 | ) 100 | parser.add_argument( 101 | "--prev-ckpt-ind", 102 | type=int, 103 | default=-1, 104 | help="index of the last checkpoint that was already evaluated", 105 | ) 106 | parser.add_argument( 107 | "--eval-best", 108 | default=False, 109 | help="evaluate the best checkpoint according to the validation curve" 110 | ) 111 | parser.add_argument( 112 | "opts", 113 | default=None, 114 | nargs=argparse.REMAINDER, 115 | help="Modify config options from command line", 116 | ) 117 | args = parser.parse_args() 118 | 119 | ckpt_msg = "" 120 | if args.eval_best: 121 | best_ckpt_idx = find_best_ckpt_idx(os.path.join(args.model_dir, 'tb')) 122 | best_ckpt_path = os.path.join(args.model_dir, 'data', f'ckpt.{best_ckpt_idx}.pth') 123 | ckpt_msg = f"best spl ckpt:{best_ckpt_path}" 124 | args.opts += ['EVAL_CKPT_PATH_DIR', best_ckpt_path] 125 | 126 | root = getattr(args, 'default') 127 | 128 | delattr(args, 'default') 129 | delattr(args, 'eval_best') 130 | 131 | 132 | config = get_config(get_config_dict[root]["_C"], get_config_dict[root]["_TC"], args.tag_config, args.exp_config, args.opts, args.model_dir, args.run_type, args.overwrite) 133 | config.defrost() 134 | config.TASK_CONFIG.SIMULATOR.TYPE = config.SIM_NAME 135 | config.freeze() 136 | 137 | 138 | trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME) 139 | 140 | assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported" 141 | trainer = trainer_init(config) 142 | torch.set_num_threads(1) 143 | 144 | level = logging.DEBUG if config.DEBUG else logging.INFO 145 | logging.basicConfig(level=level, format='%(asctime)s, %(levelname)s: %(message)s', 146 | datefmt="%Y-%m-%d %H:%M:%S") 147 | 148 
| logging.info(ckpt_msg) 149 | if args.run_type == "train": 150 | trainer.train() 151 | elif args.run_type == "eval": 152 | trainer.eval(args.eval_interval, args.prev_ckpt_ind, config.USE_LAST_CKPT) 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAAVN 2 | Code release for the paper "Sound Adversarial Audio-Visual Navigation" (ICLR 2022), in PyTorch. 3 | 4 | Contributions of SoundSpaces: 5 | - Built an audio simulation platform, SoundSpaces[1], to enable audio-visual navigation for two visually realistic 3D environments: Replica[2] and Matterport3D[3]. 6 | - Proposed the AudioGoal navigation task: this task requires a robot equipped with a camera and microphones to interact with the environment and navigate to a sounding object. 7 | - SoundSpaces dataset: SoundSpaces is a first-of-its-kind dataset of audio renderings based on geometrical acoustic simulations for two sets of publicly available 3D environments: Replica[2] and Matterport3D[3]. 8 | ### Characteristics of SoundSpaces 9 | 10 | Summary: SoundSpaces focuses on the audio-visual navigation problem in acoustically clean or simple environments: 11 | - The number of target sound sources is one. 12 | - The position of the target sound source is fixed in an episode of a scene. 13 | - The volume of the target sound source is the same in all episodes of all scenes, and never changes. 14 | 15 | All in all, the sound in the setting of SoundSpaces is acoustically clean or simple. 16 | 17 | ## What we do 18 | 19 | ### Motivation and Challenge 20 | 21 | However, there are many situations that differ from the setting of SoundSpaces, in which there are non-target sounding objects in the scene: for example, a kettle in the kitchen beeps to tell the robot that the water is boiling, and the robot in the living room needs to navigate to the kitchen and turn off the stove; meanwhile, in the living room, two children are playing a game, chuckling loudly from time to time. 22 | 23 | #### Challenge 1: 24 | Can an agent still find its way to the destination without being distracted by all the non-target sounds around it? 25 | 26 | Non-target sounding objects may be: 27 | - not deliberately embarrassing the robot: someone walking and chatting past the robot 28 | - deliberately embarrassing the robot: someone blocking the robot's way forward 29 | 30 | #### Challenge 2: 31 | 32 | How can non-target sounding objects be modeled in a simulator or in reality? No such setting exists yet! 33 | 34 | ### Solution policy 35 | 36 | - Worst-case strategy: regard non-target sounding objects as deliberately embarrassing the robot; we call them sound attackers. 37 | - Simplification: only consider the simplest situation, a single sound attacker. 38 | - Zero-sum game: one agent, one sound attacker. 39 | 40 | 41 | 42 | ![SAAVN](saavn.png) 43 | --------------------------------------------------------------------------------------------------- 44 | 45 | ## This code is still being cleaned up! Some bugs may occur; please tell me if you have any trouble. 46 | 47 | ## Thanks 48 | 49 | This code is based on the [SoundSpaces](https://github.com/facebookresearch/sound-spaces) code base. 50 | 51 | ## Usage 52 | This repo supports the AudioGoal task on the Replica and Matterport3D datasets.
54 | 55 | Below we show the commands for training and evaluating AudioGoal with a Depth sensor on Replica, 56 | but they apply to the Matterport3D dataset as well. 57 | 1. Training 58 | ``` 59 | python main.py --default av_nav --run-type train --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 60 | ``` 61 | 2. Validation (evaluate each checkpoint and generate a validation curve) 62 | ``` 63 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 64 | ``` 65 | 3. Test the best validation checkpoint based on the validation curve 66 | ``` 67 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 68 | ``` 69 | 4. Generate a demo video with audio 70 | ``` 71 | python main.py --default av_nav --run-type eval --exp-config [exp_config_file] --model-dir data/models/replica/av_nav/e0000/audiogoal_depth --tag-config [tag_config_file] TORCH_GPU_ID 0 SIMULATOR_GPU_ID 0 72 | ``` 73 | 74 | Note: [exp_config_file] is the main parameter configuration file of the experiment, while [tag_config_file] is a special parameter configuration file for ablation experiments. 75 | 76 | ## Citation 77 | If you use this model in your research, please cite the following paper: 78 | ``` 79 | @inproceedings{YinfengICLR2022saavn, 80 | author = {Yinfeng Yu and 81 | Wenbing Huang and 82 | Fuchun Sun and 83 | Changan Chen and 84 | Yikai Wang and 85 | Xiaohong Liu}, 86 | title = {Sound Adversarial Audio-Visual Navigation}, 87 | booktitle = {The Tenth International Conference on Learning Representations, {ICLR} 88 | 2022, Virtual Event, April 25-29, 2022}, 89 | publisher = {OpenReview.net}, 90 | year = {2022}, 91 | url = {https://openreview.net/forum?id=NkZq4OEYN-}, 92 | timestamp = {Thu, 18 Aug 2022 18:42:35 +0200}, 93 | biburl = {https://dblp.org/rec/conf/iclr/Yu00C0L22.bib}, 94 | bibsource = {dblp computer science bibliography, https://dblp.org} 95 | } 96 | ``` 97 | -------------------------------------------------------------------------------- /ss_baselines/common/env_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Type, Union 3 | import logging 4 | import copy 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | 10 | import habitat 11 | from habitat import Config, Env, RLEnv, VectorEnv 12 | from habitat.datasets import make_dataset 13 | from ss_baselines.common.sync_vector_env import SyncVectorEnv 14 | 15 | SCENES = ['apartment_0', 'apartment_1', 'apartment_2', 'frl_apartment_0', 'frl_apartment_1', 'frl_apartment_2', 16 | 'frl_apartment_3', 'frl_apartment_4', 'frl_apartment_5', 'office_0', 'office_1', 'office_2', 17 | 'office_3', 'office_4', 'hotel_0', 'room_0', 'room_1', 'room_2'] 18 | 19 | 20 | def construct_envs( 21 | config: Config, env_class: Type[Union[Env, RLEnv]], auto_reset_done=True 22 | ) -> VectorEnv: 23 | r"""Create VectorEnv object with specified config and env class type. 24 | To allow better performance, datasets are split into smaller ones for 25 | each individual env, grouped by scene. 26 | 27 | Args: 28 | config: configs that contain num_processes as well as information 29 | necessary to create individual environments.
30 | env_class: class type of the envs to be created 31 | auto_reset_done: automatically reset environments when done 32 | Returns: 33 | VectorEnv object created according to specification. 34 | """ 35 | 36 | num_processes = config.NUM_PROCESSES 37 | configs = [] 38 | env_classes = [env_class for _ in range(num_processes)] 39 | dataset = make_dataset(config.TASK_CONFIG.DATASET.TYPE) 40 | scenes = dataset.get_scenes_to_load(config.TASK_CONFIG.DATASET) 41 | 42 | # if len(scenes) > 0: 43 | # # random.shuffle(scenes) 44 | # assert len(scenes) >= num_processes, ( 45 | # "reduce the number of processes as there " 46 | # "aren't enough number of scenes" 47 | # ) 48 | if len(scenes) >= num_processes: 49 | # rearrange scenes in the order of data scale since there is a severe imbalance of data size 50 | scenes_new = list() 51 | for scene in SCENES: 52 | if scene in scenes: 53 | scenes_new.append(scene) 54 | scenes = scenes_new 55 | 56 | scene_splits = [[] for _ in range(num_processes)] 57 | for idx, scene in enumerate(scenes): 58 | scene_splits[idx % len(scene_splits)].append(scene) 59 | assert sum(map(len, scene_splits)) == len(scenes) 60 | else: 61 | scene_splits = [copy.deepcopy(scenes) for _ in range(num_processes)] 62 | for split in scene_splits: 63 | random.shuffle(split) 64 | 65 | for i in range(num_processes): 66 | task_config = config.TASK_CONFIG.clone() 67 | task_config.defrost() 68 | if len(scenes) > 0: 69 | task_config.DATASET.CONTENT_SCENES = scene_splits[i] 70 | logging.debug('All scenes: {}'.format(','.join(scene_splits[i]))) 71 | 72 | # overwrite the task config with top-level config file 73 | task_config.SIMULATOR.HABITAT_SIM_V0.GPU_DEVICE_ID = ( 74 | config.SIMULATOR_GPU_ID 75 | ) 76 | task_config.SIMULATOR.AGENT_0.SENSORS = config.SENSORS 77 | task_config.freeze() 78 | 79 | config.defrost() 80 | config.TASK_CONFIG = task_config 81 | config.freeze() 82 | configs.append(config.clone()) 83 | 84 | # use VectorEnv for the best performance and ThreadedVectorEnv for debugging 85 | if config.USE_SYNC_VECENV: 86 | env_launcher = SyncVectorEnv 87 | logging.info('Using SyncVectorEnv') 88 | elif config.USE_VECENV: 89 | env_launcher = habitat.VectorEnv 90 | logging.info('Using VectorEnv') 91 | else: 92 | env_launcher = habitat.ThreadedVectorEnv 93 | logging.info('Using ThreadedVectorEnv') 94 | 95 | envs = env_launcher( 96 | make_env_fn=make_env_fn, 97 | env_fn_args=tuple( 98 | tuple(zip(configs, env_classes, range(num_processes)))), 99 | auto_reset_done=auto_reset_done 100 | ) 101 | return envs 102 | 103 | 104 | def make_env_fn( 105 | config: Config, env_class: Type[Union[Env, RLEnv]], rank: int 106 | ) -> Union[Env, RLEnv]: 107 | r"""Creates an env of type env_class with specified config and rank. 108 | This is to be passed in as an argument when creating VectorEnv. 109 | Args: 110 | config: root exp config that has core env config node as well as 111 | env-specific config node. 112 | env_class: class type of the env to be created. 113 | rank: rank of env to be created (for seeding). 114 | Returns: 115 | env object created according to specification. 
116 | """ 117 | if not config.USE_SYNC_VECENV: 118 | level = logging.DEBUG if config.DEBUG else logging.INFO 119 | logging.basicConfig(level=level, format='%(asctime)s, %(levelname)s: %(message)s', 120 | datefmt="%Y-%m-%d %H:%M:%S") 121 | random.seed(rank) 122 | np.random.seed(rank) 123 | torch.manual_seed(rank) 124 | 125 | dataset = make_dataset( 126 | config.TASK_CONFIG.DATASET.TYPE, config=config.TASK_CONFIG.DATASET 127 | ) 128 | env = env_class(config=config, dataset=dataset) 129 | env.seed(rank) 130 | return env 131 | -------------------------------------------------------------------------------- /ss_baselines/av_nav/models/rnn_state_encoder.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class RNNStateEncoder(nn.Module): 7 | def __init__( 8 | self, 9 | input_size: int, 10 | hidden_size: int, 11 | num_layers: int = 1, 12 | rnn_type: str = "GRU", 13 | ): 14 | r"""An RNN for encoding the state in RL. 15 | 16 | Supports masking the hidden state during various timesteps in the forward pass 17 | 18 | Args: 19 | input_size: The input size of the RNN 20 | hidden_size: The hidden size 21 | num_layers: The number of recurrent layers 22 | rnn_type: The RNN cell type. Must be GRU or LSTM 23 | """ 24 | 25 | super().__init__() 26 | self._num_recurrent_layers = num_layers 27 | self._rnn_type = rnn_type 28 | 29 | self.rnn = getattr(nn, rnn_type)( 30 | input_size=input_size, 31 | hidden_size=hidden_size, 32 | num_layers=num_layers, 33 | ) 34 | 35 | self.layer_init() 36 | 37 | def layer_init(self): 38 | for name, param in self.rnn.named_parameters(): 39 | if "weight" in name: 40 | nn.init.orthogonal_(param) 41 | elif "bias" in name: 42 | nn.init.constant_(param, 0) 43 | 44 | @property 45 | def num_recurrent_layers(self): 46 | return self._num_recurrent_layers * ( 47 | 2 if "LSTM" in self._rnn_type else 1 48 | ) 49 | 50 | def _pack_hidden(self, hidden_states): 51 | if "LSTM" in self._rnn_type: 52 | hidden_states = torch.cat( 53 | [hidden_states[0], hidden_states[1]], dim=0 54 | ) 55 | 56 | return hidden_states 57 | 58 | def _unpack_hidden(self, hidden_states): 59 | if "LSTM" in self._rnn_type: 60 | hidden_states = ( 61 | hidden_states[0 : self._num_recurrent_layers], 62 | hidden_states[self._num_recurrent_layers :], 63 | ) 64 | 65 | return hidden_states 66 | 67 | def _mask_hidden(self, hidden_states, masks): 68 | if isinstance(hidden_states, tuple): 69 | hidden_states = tuple(v * masks for v in hidden_states) 70 | else: 71 | hidden_states = masks * hidden_states 72 | 73 | return hidden_states 74 | 75 | def single_forward(self, x, hidden_states, masks): 76 | r"""Forward for a non-sequence input 77 | """ 78 | hidden_states = self._unpack_hidden(hidden_states) 79 | x, hidden_states = self.rnn( 80 | x.unsqueeze(0), 81 | self._mask_hidden(hidden_states, masks.unsqueeze(0)), 82 | ) 83 | x = x.squeeze(0) 84 | hidden_states = self._pack_hidden(hidden_states) 85 | return x, hidden_states 86 | 87 | def seq_forward(self, x, hidden_states, masks): 88 | r"""Forward for a sequence of length T 89 | 90 | Args: 91 | x: (T, N, -1) Tensor that has been flattened to (T * N, -1) 92 | hidden_states: The starting hidden state. 93 | masks: The masks to be applied to hidden state at every timestep.
94 | A (T, N) tensor flattened to (T * N) 95 | """ 96 | # x is a (T, N, -1) tensor flattened to (T * N, -1) 97 | n = hidden_states.size(1) 98 | t = int(x.size(0) / n) 99 | 100 | # unflatten 101 | # x arrives flattened as (T * N, -1); 102 | # restore the time and batch dimensions 103 | x = x.view(t, n, x.size(1)) 104 | # masks likewise arrive flattened as (T * N); 105 | # restore them to (T, N) 106 | 107 | masks = masks.view(t, n) 108 | 109 | # steps in sequence which have zero for any agent. Assume t=0 has 110 | # a zero in it. 111 | has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu() 112 | 113 | # +1 to correct the masks[1:] 114 | if has_zeros.dim() == 0: 115 | has_zeros = [has_zeros.item() + 1] # handle scalar 116 | else: 117 | has_zeros = (has_zeros + 1).numpy().tolist() 118 | 119 | # add t=0 and t=T to the list 120 | has_zeros = [0] + has_zeros + [t] 121 | 122 | hidden_states = self._unpack_hidden(hidden_states) 123 | outputs = [] 124 | for i in range(len(has_zeros) - 1): 125 | # process steps that don't have any zeros in masks together 126 | start_idx = has_zeros[i] 127 | end_idx = has_zeros[i + 1] 128 | 129 | rnn_scores, hidden_states = self.rnn( 130 | x[start_idx:end_idx], 131 | self._mask_hidden( 132 | hidden_states, masks[start_idx].view(1, -1, 1) 133 | ), 134 | ) 135 | 136 | outputs.append(rnn_scores) 137 | 138 | # x is a (T, N, -1) tensor 139 | x = torch.cat(outputs, dim=0) 140 | x = x.view(t * n, -1) # flatten 141 | 142 | hidden_states = self._pack_hidden(hidden_states) 143 | return x, hidden_states 144 | 145 | def forward(self, x, hidden_states, masks): 146 | if x.size(0) == hidden_states.size(1): 147 | return self.single_forward(x, hidden_states, masks) 148 | else: 149 | return self.seq_forward(x, hidden_states, masks) 150 | -------------------------------------------------------------------------------- /ss_baselines/common/simple_agents.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import argparse 4 | from math import pi 5 | import logging 6 | 7 | import numpy as np 8 | 9 | import habitat 10 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 11 | # from habitat.config.default import get_config 12 | from ss_baselines.common.benchmark import Benchmark 13 | from ss_baselines.tools.config_tool import get_task_config 14 | # from trainer.ppo_avnav.policy import PointNavBaselinePolicy as avnNet  # unused; trainer/ppo_avnav is not in this repo 15 | 16 | 17 | class RandomAgent(habitat.Agent): 18 | def __init__(self, success_distance, goal_sensor_uuid): 19 | self.dist_threshold_to_stop = success_distance 20 | self.goal_sensor_uuid = goal_sensor_uuid 21 | 22 | def reset(self): 23 | pass 24 | 25 | def is_goal_reached(self, observations): 26 | # the goal sensor reading is in polar coordinates 27 | dist = observations[self.goal_sensor_uuid][0] 28 | return dist <= self.dist_threshold_to_stop 29 | 30 | def act(self, observations): 31 | if self.is_goal_reached(observations): 32 | action = HabitatSimActions.STOP 33 | else: 34 | action = np.random.choice( 35 | [ 36 | HabitatSimActions.MOVE_FORWARD, 37 | HabitatSimActions.TURN_LEFT, 38 | HabitatSimActions.TURN_RIGHT, 39 | ] 40 | ) 41 | return {"action": action} 42 | 43 | 44 | class ForwardOnlyAgent(RandomAgent): 45 | def act(self, observations): 46 | if self.is_goal_reached(observations): 47 | action = HabitatSimActions.STOP 48 | else: 49 | action = HabitatSimActions.MOVE_FORWARD 50 | return {"action": action} 51 | 52 | 53 | class 
RandomForwardAgent(RandomAgent): 54 | def __init__(self, success_distance, goal_sensor_uuid): 55 | super().__init__(success_distance, goal_sensor_uuid) 56 | self.FORWARD_PROBABILITY = 0.8 57 | 58 | def act(self, observations): 59 | if self.is_goal_reached(observations): 60 | action = HabitatSimActions.STOP 61 | else: 62 | if np.random.uniform(0, 1, 1) < self.FORWARD_PROBABILITY: 63 | action = HabitatSimActions.MOVE_FORWARD 64 | else: 65 | action = np.random.choice( 66 | [HabitatSimActions.TURN_LEFT, HabitatSimActions.TURN_RIGHT] 67 | ) 68 | 69 | return {"action": action} 70 | 71 | 72 | class GoalFollower(RandomAgent): 73 | def __init__(self, success_distance, goal_sensor_uuid): 74 | super().__init__(success_distance, goal_sensor_uuid) 75 | self.pos_th = self.dist_threshold_to_stop 76 | self.angle_th = float(np.deg2rad(15)) 77 | self.random_prob = 0 78 | 79 | def normalize_angle(self, angle): 80 | if angle < -pi: 81 | angle = 2.0 * pi + angle 82 | if angle > pi: 83 | angle = -2.0 * pi + angle 84 | return angle 85 | 86 | def turn_towards_goal(self, angle_to_goal): 87 | if angle_to_goal > pi or ( 88 | (angle_to_goal < 0) and (angle_to_goal > -pi) 89 | ): 90 | action = HabitatSimActions.TURN_RIGHT 91 | else: 92 | action = HabitatSimActions.TURN_LEFT 93 | return action 94 | 95 | def act(self, observations): 96 | if self.is_goal_reached(observations): 97 | action = HabitatSimActions.STOP 98 | else: 99 | angle_to_goal = self.normalize_angle( 100 | np.array(observations[self.goal_sensor_uuid][1]) 101 | ) 102 | if abs(angle_to_goal) < self.angle_th: 103 | action = HabitatSimActions.MOVE_FORWARD 104 | else: 105 | action = self.turn_towards_goal(angle_to_goal) 106 | 107 | return {"action": action} 108 | 109 | 110 | def get_all_subclasses(cls): 111 | return set(cls.__subclasses__()).union( 112 | [s for c in cls.__subclasses__() for s in get_all_subclasses(c)] 113 | ) 114 | 115 | 116 | def get_agent_cls(agent_class_name): 117 | sub_classes = [ 118 | sub_class 119 | for sub_class in get_all_subclasses(habitat.Agent) 120 | if sub_class.__name__ == agent_class_name 121 | ] 122 | return sub_classes[0] 123 | 124 | 125 | def main(): 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--success-distance", type=float, default=0.2) 128 | parser.add_argument( 129 | "--task-config", type=str, default="configs/tasks/pointnav.yaml" 130 | ) 131 | parser.add_argument("--agent-class", type=str, default="RandomAgent") 132 | parser.add_argument("--debug", default=False, action="store_true") 133 | args = parser.parse_args() 134 | 135 | level = logging.DEBUG if args.debug else logging.INFO 136 | logging.basicConfig(level=level, format='%(asctime)s, %(levelname)s: %(message)s', 137 | datefmt="%Y-%m-%d %H:%M:%S") 138 | 139 | task_config = get_task_config(args.task_config) 140 | task_config.defrost() 141 | task_config.DATASET.SPLIT = 'test_telephone' 142 | task_config.freeze() 143 | 144 | agent = get_agent_cls(args.agent_class)( 145 | success_distance=args.success_distance, 146 | goal_sensor_uuid=task_config.TASK.GOAL_SENSOR_UUID, 147 | ) 148 | benchmark = Benchmark(task_config) 149 | metrics = benchmark.evaluate(agent) 150 | 151 | for k, v in metrics.items(): 152 | habitat.logger.info("{}: {:.3f}".format(k, v)) 153 | 154 | 155 | if __name__ == "__main__": 156 | main() 157 | -------------------------------------------------------------------------------- /ss_baselines/av_nav/models/visual_cnn.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import numpy as np 4 
| import torch 5 | import torch.nn as nn 6 | 7 | from ss_baselines.common.utils import Flatten 8 | 9 | 10 | def conv_output_dim(dimension, padding, dilation, kernel_size, stride 11 | ): 12 | r"""Calculates the output height and width based on the input 13 | height and width to the convolution layer. 14 | 15 | ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d 16 | """ 17 | assert len(dimension) == 2 18 | out_dimension = [] 19 | for i in range(len(dimension)): 20 | out_dimension.append( 21 | int( 22 | np.floor( 23 | ( 24 | ( 25 | dimension[i] 26 | + 2 * padding[i] 27 | - dilation[i] * (kernel_size[i] - 1) 28 | - 1 29 | ) 30 | / stride[i] 31 | ) 32 | + 1 33 | ) 34 | ) 35 | ) 36 | return tuple(out_dimension) 37 | 38 | 39 | def layer_init(cnn): 40 | for layer in cnn: 41 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 42 | nn.init.kaiming_normal_( 43 | layer.weight, nn.init.calculate_gain("relu") 44 | ) 45 | if layer.bias is not None: 46 | nn.init.constant_(layer.bias, val=0) 47 | 48 | 49 | class VisualCNN(nn.Module): 50 | r"""A Simple 3-Conv CNN followed by a fully connected layer 51 | 52 | Takes in observations and produces an embedding of the rgb and/or depth components 53 | 54 | Args: 55 | observation_space: The observation_space of the agent 56 | output_size: The size of the embedding vector 57 | """ 58 | 59 | def __init__(self, observation_space, output_size, extra_rgb): 60 | super().__init__() 61 | if "rgb" in observation_space.spaces and not extra_rgb: 62 | self._n_input_rgb = observation_space.spaces["rgb"].shape[2] 63 | else: 64 | self._n_input_rgb = 0 65 | 66 | if "depth" in observation_space.spaces: 67 | self._n_input_depth = observation_space.spaces["depth"].shape[2] 68 | else: 69 | self._n_input_depth = 0 70 | 71 | # kernel size for different CNN layers 72 | self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] 73 | 74 | # strides for different CNN layers 75 | self._cnn_layers_stride = [(4, 4), (2, 2), (2, 2)] 76 | 77 | if self._n_input_rgb > 0: 78 | cnn_dims = np.array( 79 | observation_space.spaces["rgb"].shape[:2], dtype=np.float32 80 | ) 81 | elif self._n_input_depth > 0: 82 | cnn_dims = np.array( 83 | observation_space.spaces["depth"].shape[:2], dtype=np.float32 84 | ) 85 | 86 | if self.is_blind: 87 | self.cnn = nn.Sequential() 88 | else: 89 | for kernel_size, stride in zip( 90 | self._cnn_layers_kernel_size, self._cnn_layers_stride 91 | ): 92 | cnn_dims = conv_output_dim( 93 | dimension=cnn_dims, 94 | padding=np.array([0, 0], dtype=np.float32), 95 | dilation=np.array([1, 1], dtype=np.float32), 96 | kernel_size=np.array(kernel_size, dtype=np.float32), 97 | stride=np.array(stride, dtype=np.float32), 98 | ) 99 | 100 | self.cnn = nn.Sequential( 101 | nn.Conv2d( 102 | in_channels=self._n_input_rgb + self._n_input_depth, 103 | out_channels=32, 104 | kernel_size=self._cnn_layers_kernel_size[0], 105 | stride=self._cnn_layers_stride[0], 106 | ), 107 | nn.ReLU(True), 108 | nn.Conv2d( 109 | in_channels=32, 110 | out_channels=64, 111 | kernel_size=self._cnn_layers_kernel_size[1], 112 | stride=self._cnn_layers_stride[1], 113 | ), 114 | nn.ReLU(True), 115 | nn.Conv2d( 116 | in_channels=64, 117 | out_channels=64, 118 | kernel_size=self._cnn_layers_kernel_size[2], 119 | stride=self._cnn_layers_stride[2], 120 | ), 121 | # nn.ReLU(True), 122 | Flatten(), 123 | nn.Linear(64 * cnn_dims[0] * cnn_dims[1], output_size), 124 | nn.ReLU(True), 125 | ) 126 | 127 | layer_init(self.cnn) 128 | 129 | @property 130 | def is_blind(self): 131 | return self._n_input_rgb + self._n_input_depth == 0 
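# Worked example (illustrative; assumes a hypothetical 128x128 RGB observation):
# with the kernel sizes (8,8),(4,4),(3,3) and strides (4,4),(2,2),(2,2) above, and
# zero padding / unit dilation, conv_output_dim reduces to floor((dim - kernel) / stride) + 1:
#   layer 1: floor((128 - 8) / 4) + 1 = 31
#   layer 2: floor((31 - 4) / 2) + 1 = 14
#   layer 3: floor((14 - 3) / 2) + 1 = 6
# so Flatten() sees 64 * 6 * 6 = 2304 features, which becomes the in_features of the
# final nn.Linear(64 * cnn_dims[0] * cnn_dims[1], output_size).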
132 | 133 | def forward(self, observations): 134 | cnn_input = [] 135 | if self._n_input_rgb > 0: 136 | rgb_observations = observations["rgb"] 137 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 138 | rgb_observations = rgb_observations.permute(0, 3, 1, 2) 139 | rgb_observations = rgb_observations / 255.0 # normalize RGB 140 | cnn_input.append(rgb_observations) 141 | 142 | if self._n_input_depth > 0: 143 | depth_observations = observations["depth"] 144 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 145 | depth_observations = depth_observations.permute(0, 3, 1, 2) 146 | cnn_input.append(depth_observations) 147 | 148 | cnn_input = torch.cat(cnn_input, dim=1) 149 | 150 | return self.cnn(cnn_input) 151 | -------------------------------------------------------------------------------- /trainer/ppo/ppo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import git 4 | import sys 5 | repo = git.Repo(".", search_parent_directories=True) 6 | if f"{repo.working_tree_dir}" not in sys.path: 7 | sys.path.append(f"{repo.working_tree_dir}") 8 | print("add") 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | 14 | EPS_PPO = 1e-5 15 | 16 | 17 | class PPO(nn.Module): 18 | def __init__( 19 | self, 20 | actor_critic, 21 | clip_param, 22 | ppo_epoch, 23 | num_mini_batch, 24 | value_loss_coef, 25 | entropy_coef, 26 | lr=None, 27 | eps=None, 28 | max_grad_norm=None, 29 | use_clipped_value_loss=True, 30 | use_normalized_advantage=True, 31 | ): 32 | 33 | super().__init__() 34 | 35 | self.actor_critic = actor_critic 36 | 37 | self.clip_param = clip_param 38 | self.ppo_epoch = ppo_epoch 39 | self.num_mini_batch = num_mini_batch 40 | 41 | self.value_loss_coef = value_loss_coef 42 | self.entropy_coef = entropy_coef 43 | 44 | self.max_grad_norm = max_grad_norm 45 | self.use_clipped_value_loss = use_clipped_value_loss 46 | 47 | self.optimizer = optim.Adam(actor_critic.parameters(), lr=lr, eps=eps) 48 | self.device = next(actor_critic.parameters()).device 49 | self.use_normalized_advantage = use_normalized_advantage 50 | 51 | def forward(self, *x): 52 | raise NotImplementedError 53 | 54 | def get_advantages(self, rollouts): 55 | advantages = rollouts.returns[:-1] - rollouts.value_preds[:-1] 56 | if not self.use_normalized_advantage: 57 | return advantages 58 | 59 | return (advantages - advantages.mean()) / (advantages.std() + EPS_PPO) 60 | 61 | def update(self, rollouts): 62 | advantages = self.get_advantages(rollouts) 63 | 64 | value_loss_epoch = 0 65 | action_loss_epoch = 0 66 | dist_entropy_epoch = 0 67 | total_loss_epoch = 0 68 | 69 | 70 | for e in range(self.ppo_epoch): 71 | data_generator = rollouts.recurrent_generator( 72 | advantages, self.num_mini_batch 73 | ) 74 | 75 | for sample in data_generator: 76 | ( 77 | obs_batch, 78 | recurrent_hidden_states_batch, 79 | actions_batch, 80 | prev_actions_batch, 81 | value_preds_batch, 82 | return_batch, 83 | masks_batch, 84 | old_action_log_probs_batch, 85 | adv_targ, 86 | ) = sample 87 | 88 | # Reshape to do in a single forward pass for all steps 89 | ( 90 | values, 91 | action_log_probs, 92 | dist_entropy, 93 | _, 94 | ) = self.actor_critic.evaluate_actions( 95 | obs_batch, 96 | recurrent_hidden_states_batch, 97 | prev_actions_batch, 98 | masks_batch, 99 | actions_batch, 100 | ) 101 | 102 | ratio = torch.exp( 103 | action_log_probs - old_action_log_probs_batch 104 | ) 105 | surr1 = ratio * adv_targ 106 | surr2 = ( 107 | 
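# PPO's clipped surrogate: with probability ratio r = exp(log pi_new - log pi_old),
# the objective takes min(r * A, clip(r, 1 - eps, 1 + eps) * A) per sample.
# Illustrative numbers (eps = 0.2, values assumed for the example): if A = 1 and
# r = 1.5, the clipped branch gives 1.2 and the min keeps 1.2, so the update gains
# nothing from pushing r beyond 1.2; if A = -1 and r = 0.5, min(-0.5, -0.8) = -0.8
# keeps the pessimistic (clipped) value, discouraging overly large policy steps.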
torch.clamp( 108 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param 109 | ) 110 | * adv_targ 111 | ) 112 | action_loss = -torch.min(surr1, surr2).mean() 113 | 114 | if self.use_clipped_value_loss: 115 | value_pred_clipped = value_preds_batch + ( 116 | values - value_preds_batch 117 | ).clamp(-self.clip_param, self.clip_param) 118 | value_losses = (values - return_batch).pow(2) 119 | value_losses_clipped = ( 120 | value_pred_clipped - return_batch 121 | ).pow(2) 122 | value_loss = ( 123 | 0.5 124 | * torch.max(value_losses, value_losses_clipped).mean() 125 | ) 126 | else: 127 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 128 | 129 | self.optimizer.zero_grad() 130 | total_loss = ( 131 | value_loss * self.value_loss_coef 132 | + action_loss 133 | - dist_entropy * self.entropy_coef 134 | ) 135 | 136 | self.before_backward(total_loss) 137 | total_loss.backward() 138 | self.after_backward(total_loss) 139 | 140 | self.before_step() 141 | self.optimizer.step() 142 | self.after_step() 143 | 144 | value_loss_epoch += value_loss.item() 145 | action_loss_epoch += action_loss.item() 146 | dist_entropy_epoch += dist_entropy.item() 147 | total_loss_epoch += total_loss.item() 148 | 149 | num_updates = self.ppo_epoch * self.num_mini_batch 150 | 151 | value_loss_epoch /= num_updates 152 | action_loss_epoch /= num_updates 153 | dist_entropy_epoch /= num_updates 154 | total_loss_epoch /= num_updates 155 | 156 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch, total_loss_epoch 157 | 158 | def before_backward(self, loss): 159 | pass 160 | 161 | def after_backward(self, loss): 162 | pass 163 | 164 | def before_step(self): 165 | nn.utils.clip_grad_norm_( 166 | self.actor_critic.parameters(), self.max_grad_norm 167 | ) 168 | 169 | def after_step(self): 170 | pass 171 | -------------------------------------------------------------------------------- /soundspaces/datasets/audionav_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import gzip 8 | import json 9 | import os 10 | import logging 11 | from typing import List, Optional 12 | 13 | from habitat.config import Config 14 | from habitat.core.dataset import Dataset 15 | from habitat.core.registry import registry 16 | 17 | from habitat.tasks.nav.nav import ( 18 | NavigationEpisode, 19 | NavigationGoal, 20 | ShortestPathPoint, 21 | ) 22 | 23 | 24 | ALL_SCENES_MASK = "*" 25 | CONTENT_SCENES_PATH_FIELD = "content_scenes_path" 26 | DEFAULT_SCENE_PATH_PREFIX = "data/scene_dataset/" 27 | 28 | 29 | @registry.register_dataset(name="AudioNav") 30 | class AudioNavDataset(Dataset): 31 | r"""Class inherited from Dataset that loads Audio Navigation dataset. 32 | """ 33 | 34 | episodes: List[NavigationEpisode] 35 | content_scenes_path: str = "{data_path}/content/{scene}.json.gz" 36 | 37 | @staticmethod 38 | def check_config_paths_exist(config: Config) -> bool: 39 | return os.path.exists( 40 | config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT) 41 | ) and os.path.exists(config.SCENES_DIR) 42 | 43 | @staticmethod 44 | def get_scenes_to_load(config: Config) -> List[str]: 45 | r"""Return list of scene ids for which dataset has separate files with 46 | episodes. 
47 | """ 48 | assert AudioNavDataset.check_config_paths_exist(config), \ 49 | (config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT), config.SCENES_DIR) 50 | dataset_dir = os.path.dirname( 51 | config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT) 52 | ) 53 | 54 | cfg = config.clone() 55 | cfg.defrost() 56 | cfg.CONTENT_SCENES = [] 57 | dataset = AudioNavDataset(cfg) 58 | return AudioNavDataset._get_scenes_from_folder( 59 | content_scenes_path=dataset.content_scenes_path, 60 | dataset_dir=dataset_dir, 61 | ) 62 | 63 | @staticmethod 64 | def _get_scenes_from_folder(content_scenes_path, dataset_dir): 65 | scenes = [] 66 | content_dir = content_scenes_path.split("{scene}")[0] 67 | scene_dataset_ext = content_scenes_path.split("{scene}")[1] 68 | content_dir = content_dir.format(data_path=dataset_dir) 69 | if not os.path.exists(content_dir): 70 | return scenes 71 | 72 | for filename in os.listdir(content_dir): 73 | if filename.endswith(scene_dataset_ext): 74 | scene = filename[: -len(scene_dataset_ext)] 75 | scenes.append(scene) 76 | scenes.sort() 77 | return scenes 78 | 79 | def __init__(self, config: Optional[Config] = None) -> None: 80 | self.episodes = [] 81 | self._config = config 82 | 83 | if config is None: 84 | return 85 | 86 | datasetfile_path = config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT) 87 | with gzip.open(datasetfile_path, "rt") as f: 88 | self.from_json(f.read(), scenes_dir=config.SCENES_DIR, scene_filename=datasetfile_path) 89 | 90 | # Read separate file for each scene 91 | dataset_dir = os.path.dirname(datasetfile_path) 92 | scenes = config.CONTENT_SCENES 93 | if ALL_SCENES_MASK in scenes: 94 | scenes = AudioNavDataset._get_scenes_from_folder( 95 | content_scenes_path=self.content_scenes_path, 96 | dataset_dir=dataset_dir, 97 | ) 98 | 99 | last_episode_cnt = 0 100 | for scene in scenes: 101 | scene_filename = self.content_scenes_path.format( 102 | data_path=dataset_dir, scene=scene 103 | ) 104 | with gzip.open(scene_filename, "rt") as f: 105 | self.from_json(f.read(), scenes_dir=config.SCENES_DIR, scene_filename=scene_filename) 106 | 107 | num_episode = len(self.episodes) - last_episode_cnt 108 | last_episode_cnt = len(self.episodes) 109 | logging.info('Sampled {} from {}'.format(num_episode, scene)) 110 | 111 | def filter_by_ids(self, scene_ids): 112 | episodes_to_keep = list() 113 | 114 | for episode in self.episodes: 115 | for scene_id in scene_ids: 116 | scene, ep_id = scene_id.split(',') 117 | if scene in episode.scene_id and ep_id == episode.episode_id: 118 | episodes_to_keep.append(episode) 119 | 120 | self.episodes = episodes_to_keep 121 | 122 | # filter by scenes for data collection 123 | def filter_by_scenes(self, scene): 124 | episodes_to_keep = list() 125 | 126 | for episode in self.episodes: 127 | episode_scene = episode.scene_id.split("/")[3] 128 | if scene == episode_scene: 129 | episodes_to_keep.append(episode) 130 | 131 | self.episodes = episodes_to_keep 132 | 133 | def from_json( 134 | self, json_str: str, scenes_dir: Optional[str] = None, scene_filename: Optional[str] = None 135 | ) -> None: 136 | deserialized = json.loads(json_str) 137 | if CONTENT_SCENES_PATH_FIELD in deserialized: 138 | self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD] 139 | 140 | episode_cnt = 0 141 | for episode in deserialized["episodes"]: 142 | episode = NavigationEpisode(**episode) 143 | 144 | if scenes_dir is not None: 145 | if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX): 146 | episode.scene_id = 
episode.scene_id[ 147 | len(DEFAULT_SCENE_PATH_PREFIX): 148 | ] 149 | 150 | episode.scene_id = os.path.join(scenes_dir, episode.scene_id) 151 | 152 | for g_index, goal in enumerate(episode.goals): 153 | episode.goals[g_index] = NavigationGoal(**goal) 154 | if episode.shortest_paths is not None: 155 | for path in episode.shortest_paths: 156 | for p_index, point in enumerate(path): 157 | path[p_index] = ShortestPathPoint(**point) 158 | self.episodes.append(episode) 159 | episode_cnt += 1 160 | -------------------------------------------------------------------------------- /trainer/ppo/policy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import git 4 | import sys 5 | repo = git.Repo(".", search_parent_directories=True) 6 | if f"{repo.working_tree_dir}" not in sys.path: 7 | sys.path.append(f"{repo.working_tree_dir}") 8 | print("add") 9 | import abc 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torchsummary import summary 14 | 15 | from ss_baselines.common.utils import CategoricalNet 16 | from ss_baselines.av_nav.models.rnn_state_encoder import RNNStateEncoder 17 | from ss_baselines.av_nav.models.visual_cnn import VisualCNN 18 | from ss_baselines.av_nav.models.audio_cnn import AudioCNN 19 | 20 | DUAL_GOAL_DELIMITER = ',' 21 | 22 | 23 | class Policy(nn.Module): 24 | def __init__(self, net, dim_actions): 25 | super().__init__() 26 | self.net = net 27 | self.dim_actions = dim_actions 28 | 29 | self.action_distribution = CategoricalNet( 30 | self.net.output_size, self.dim_actions 31 | ) 32 | self.critic = CriticHead(self.net.output_size) 33 | 34 | def forward(self, *x): 35 | raise NotImplementedError 36 | 37 | def act( 38 | self, 39 | observations, 40 | rnn_hidden_states, 41 | prev_actions, 42 | masks, 43 | deterministic=False, 44 | ): 45 | features, rnn_hidden_states = self.net( 46 | observations, rnn_hidden_states, prev_actions, masks 47 | ) 48 | # print('Features: ', features.cpu().numpy()) 49 | distribution = self.action_distribution(features) 50 | # print('Distribution: ', distribution.logits.cpu().numpy()) 51 | value = self.critic(features) 52 | # print('Value: ', value.item()) 53 | 54 | if deterministic: 55 | action = distribution.mode() 56 | # print('Deterministic action: ', action.item()) 57 | else: 58 | action = distribution.sample() 59 | # print('Sample action: ', action.item()) 60 | 61 | action_log_probs = distribution.log_probs(action) 62 | 63 | return value, action, action_log_probs, rnn_hidden_states 64 | 65 | def get_value(self, observations, rnn_hidden_states, prev_actions, masks): 66 | features, _ = self.net( 67 | observations, rnn_hidden_states, prev_actions, masks 68 | ) 69 | return self.critic(features) 70 | 71 | def evaluate_actions( 72 | self, observations, rnn_hidden_states, prev_actions, masks, action 73 | ): 74 | features, rnn_hidden_states = self.net( 75 | observations, rnn_hidden_states, prev_actions, masks 76 | ) 77 | distribution = self.action_distribution(features) 78 | value = self.critic(features) 79 | 80 | action_log_probs = distribution.log_probs(action) 81 | distribution_entropy = distribution.entropy().mean() 82 | 83 | return value, action_log_probs, distribution_entropy, rnn_hidden_states 84 | 85 | 86 | class CriticHead(nn.Module): 87 | def __init__(self, input_size): 88 | super().__init__() 89 | self.fc = nn.Linear(input_size, 1) 90 | nn.init.orthogonal_(self.fc.weight) 91 | nn.init.constant_(self.fc.bias, 0) 92 | 93 | def forward(self, x): 94 | return self.fc(x) 
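# Usage sketch (commented out; shapes here are assumptions for illustration, not
# part of this file). For a batch of N parallel envs and an RNN with L layers of
# hidden size H:
#
#   value, action, action_log_probs, rnn_hidden_states = policy.act(
#       observations,          # dict of (N, ...) tensors, e.g. from batch_obs()
#       rnn_hidden_states,     # (L, N, H), zeros at episode start
#       prev_actions,          # (N, 1) long tensor
#       masks,                 # (N, 1), 0.0 where an episode just reset
#       deterministic=False,   # sample during training, mode() at evaluation
#   )
#
# value is (N, 1) from CriticHead; action and action_log_probs are (N, 1).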
95 | 96 | 97 | class PointNavBaselinePolicy(Policy): 98 | def __init__( 99 | self, 100 | observation_space, 101 | action_space, 102 | goal_sensor_uuid, 103 | hidden_size=512, 104 | extra_rgb=False 105 | ): 106 | super().__init__( 107 | PointNavBaselineNet( 108 | observation_space=observation_space, 109 | hidden_size=hidden_size, 110 | goal_sensor_uuid=goal_sensor_uuid, 111 | extra_rgb=extra_rgb 112 | ), 113 | action_space.n, 114 | ) 115 | 116 | 117 | class Net(nn.Module, metaclass=abc.ABCMeta): 118 | @abc.abstractmethod 119 | def forward(self, observations, rnn_hidden_states, prev_actions, masks): 120 | pass 121 | 122 | @property 123 | @abc.abstractmethod 124 | def output_size(self): 125 | pass 126 | 127 | @property 128 | @abc.abstractmethod 129 | def num_recurrent_layers(self): 130 | pass 131 | 132 | @property 133 | @abc.abstractmethod 134 | def is_blind(self): 135 | pass 136 | 137 | 138 | class PointNavBaselineNet(Net): 139 | r"""Network which passes the input image through CNN and concatenates 140 | goal vector with CNN's output and passes that through RNN. 141 | """ 142 | 143 | def __init__(self, observation_space, hidden_size, goal_sensor_uuid, extra_rgb=False): 144 | super().__init__() 145 | self.goal_sensor_uuid = goal_sensor_uuid 146 | self._hidden_size = hidden_size 147 | self._audiogoal = False 148 | self._pointgoal = False 149 | self._n_pointgoal = 0 150 | 151 | if DUAL_GOAL_DELIMITER in self.goal_sensor_uuid: 152 | goal1_uuid, goal2_uuid = self.goal_sensor_uuid.split(DUAL_GOAL_DELIMITER) 153 | self._audiogoal = self._pointgoal = True 154 | self._n_pointgoal = observation_space.spaces[goal1_uuid].shape[0] 155 | else: 156 | if 'pointgoal_with_gps_compass' == self.goal_sensor_uuid: 157 | self._pointgoal = True 158 | self._n_pointgoal = observation_space.spaces[self.goal_sensor_uuid].shape[0] 159 | else: 160 | self._audiogoal = True 161 | 162 | self.visual_encoder = VisualCNN(observation_space, hidden_size, extra_rgb) 163 | if self._audiogoal: 164 | if 'audiogoal' in self.goal_sensor_uuid: 165 | audiogoal_sensor = 'audiogoal' 166 | elif 'spectrogram' in self.goal_sensor_uuid: 167 | audiogoal_sensor = 'spectrogram' 168 | self.audio_encoder = AudioCNN(observation_space, hidden_size, audiogoal_sensor) 169 | 170 | rnn_input_size = (0 if self.is_blind else self._hidden_size) + \ 171 | (self._n_pointgoal if self._pointgoal else 0) + (self._hidden_size if self._audiogoal else 0) 172 | self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size) 173 | 174 | if 'rgb' in observation_space.spaces and not extra_rgb: 175 | rgb_shape = observation_space.spaces['rgb'].shape 176 | summary(self.visual_encoder.cnn, (rgb_shape[2], rgb_shape[0], rgb_shape[1]), device='cpu') 177 | if 'depth' in observation_space.spaces: 178 | depth_shape = observation_space.spaces['depth'].shape 179 | summary(self.visual_encoder.cnn, (depth_shape[2], depth_shape[0], depth_shape[1]), device='cpu') 180 | if self._audiogoal: 181 | audio_shape = observation_space.spaces[audiogoal_sensor].shape 182 | summary(self.audio_encoder.cnn, (audio_shape[2], audio_shape[0], audio_shape[1]), device='cpu') 183 | 184 | self.train() 185 | 186 | @property 187 | def output_size(self): 188 | return self._hidden_size 189 | 190 | @property 191 | def is_blind(self): 192 | return self.visual_encoder.is_blind 193 | 194 | @property 195 | def num_recurrent_layers(self): 196 | return self.state_encoder.num_recurrent_layers 197 | 198 | def forward(self, observations, rnn_hidden_states, prev_actions, masks): 199 | x = [] 200 | 201 | if 
self._pointgoal: 202 | x.append(observations[self.goal_sensor_uuid.split(DUAL_GOAL_DELIMITER)[0]]) 203 | if self._audiogoal: 204 | x.append(self.audio_encoder(observations)) 205 | if not self.is_blind: 206 | x.append(self.visual_encoder(observations)) 207 | 208 | x1 = torch.cat(x, dim=1) 209 | x2, rnn_hidden_states1 = self.state_encoder(x1, rnn_hidden_states, masks) 210 | 211 | assert not torch.isnan(x2).any().item() 212 | 213 | return x2, rnn_hidden_states1 214 | -------------------------------------------------------------------------------- /ss_baselines/common/rollout_storage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | 8 | class RolloutStorage: 9 | r"""Class for storing rollout information for RL trainers. 10 | 11 | """ 12 | 13 | def __init__( 14 | self, 15 | num_steps, 16 | num_envs, 17 | observation_space, 18 | action_space, 19 | recurrent_hidden_state_size, 20 | num_recurrent_layers=1, 21 | ): 22 | self.observations = {} 23 | 24 | for sensor in observation_space.spaces: 25 | self.observations[sensor] = torch.zeros( 26 | num_steps + 1, 27 | num_envs, 28 | *observation_space.spaces[sensor].shape 29 | ) 30 | 31 | self.recurrent_hidden_states = torch.zeros( 32 | num_steps + 1, 33 | num_recurrent_layers, 34 | num_envs, 35 | recurrent_hidden_state_size, 36 | ) 37 | 38 | self.rewards = torch.zeros(num_steps, num_envs, 1) 39 | self.value_preds = torch.zeros(num_steps + 1, num_envs, 1) 40 | self.returns = torch.zeros(num_steps + 1, num_envs, 1) 41 | 42 | self.action_log_probs = torch.zeros(num_steps, num_envs, 1) 43 | if action_space.__class__.__name__ == "ActionSpace": 44 | action_shape = 1 45 | else: 46 | action_shape = action_space.shape[0] 47 | 48 | self.actions = torch.zeros(num_steps, num_envs, action_shape) 49 | self.prev_actions = torch.zeros(num_steps + 1, num_envs, action_shape) 50 | if action_space.__class__.__name__ == "ActionSpace": 51 | self.actions = self.actions.long() 52 | self.prev_actions = self.prev_actions.long() 53 | 54 | self.masks = torch.ones(num_steps + 1, num_envs, 1) 55 | 56 | self.num_steps = num_steps 57 | self.step = 0 58 | 59 | def to(self, device): 60 | for sensor in self.observations: 61 | self.observations[sensor] = self.observations[sensor].to(device) 62 | 63 | self.recurrent_hidden_states = self.recurrent_hidden_states.to(device) 64 | self.rewards = self.rewards.to(device) 65 | self.value_preds = self.value_preds.to(device) 66 | self.returns = self.returns.to(device) 67 | self.action_log_probs = self.action_log_probs.to(device) 68 | self.actions = self.actions.to(device) 69 | self.prev_actions = self.prev_actions.to(device) 70 | self.masks = self.masks.to(device) 71 | 72 | def insert( 73 | self, 74 | observations, 75 | recurrent_hidden_states, 76 | actions, 77 | action_log_probs, 78 | value_preds, 79 | rewards, 80 | masks, 81 | ): 82 | for sensor in observations: 83 | self.observations[sensor][self.step + 1].copy_( 84 | observations[sensor] 85 | ) 86 | self.recurrent_hidden_states[self.step + 1].copy_( 87 | recurrent_hidden_states 88 | ) 89 | self.actions[self.step].copy_(actions) 90 | self.prev_actions[self.step + 1].copy_(actions) 91 | self.action_log_probs[self.step].copy_(action_log_probs) 92 | self.value_preds[self.step].copy_(value_preds) 93 | self.rewards[self.step].copy_(rewards) 94 | self.masks[self.step + 1].copy_(masks) 95 | 96 | self.step = (self.step + 1) % self.num_steps 97 | 98 | def 
after_update(self): 99 | for sensor in self.observations: 100 | self.observations[sensor][0].copy_(self.observations[sensor][-1]) 101 | 102 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 103 | self.masks[0].copy_(self.masks[-1]) 104 | self.prev_actions[0].copy_(self.prev_actions[-1]) 105 | 106 | def compute_returns(self, next_value, use_gae, gamma, tau): 107 | if use_gae: 108 | self.value_preds[-1] = next_value 109 | gae = 0 110 | for step in reversed(range(self.rewards.size(0))): 111 | delta = ( 112 | self.rewards[step] 113 | + gamma * self.value_preds[step + 1] * self.masks[step + 1] 114 | - self.value_preds[step] 115 | ) 116 | gae = delta + gamma * tau * self.masks[step + 1] * gae 117 | self.returns[step] = gae + self.value_preds[step] 118 | else: 119 | self.returns[-1] = next_value 120 | for step in reversed(range(self.rewards.size(0))): 121 | self.returns[step] = ( 122 | self.returns[step + 1] * gamma * self.masks[step + 1] 123 | + self.rewards[step] 124 | ) 125 | 126 | def recurrent_generator(self, advantages, num_mini_batch): 127 | num_processes = self.rewards.size(1) 128 | assert num_processes >= num_mini_batch, ( 129 | "Trainer requires the number of processes ({}) " 130 | "to be greater than or equal to the number of " 131 | "trainer mini batches ({}).".format(num_processes, num_mini_batch) 132 | ) 133 | num_envs_per_batch = num_processes // num_mini_batch 134 | perm = torch.randperm(num_processes) 135 | for start_ind in range(0, num_processes, num_envs_per_batch): 136 | observations_batch = defaultdict(list) 137 | 138 | recurrent_hidden_states_batch = [] 139 | actions_batch = [] 140 | prev_actions_batch = [] 141 | value_preds_batch = [] 142 | return_batch = [] 143 | masks_batch = [] 144 | old_action_log_probs_batch = [] 145 | adv_targ = [] 146 | 147 | for offset in range(num_envs_per_batch): 148 | ind = perm[start_ind + offset] 149 | 150 | for sensor in self.observations: 151 | observations_batch[sensor].append( 152 | self.observations[sensor][:-1, ind] 153 | ) 154 | 155 | recurrent_hidden_states_batch.append( 156 | self.recurrent_hidden_states[0, :, ind] 157 | ) 158 | 159 | actions_batch.append(self.actions[:, ind]) 160 | prev_actions_batch.append(self.prev_actions[:-1, ind]) 161 | value_preds_batch.append(self.value_preds[:-1, ind]) 162 | return_batch.append(self.returns[:-1, ind]) 163 | masks_batch.append(self.masks[:-1, ind]) 164 | old_action_log_probs_batch.append( 165 | self.action_log_probs[:, ind] 166 | ) 167 | 168 | adv_targ.append(advantages[:, ind]) 169 | 170 | T, N = self.num_steps, num_envs_per_batch 171 | 172 | # These are all tensors of size (T, N, -1) 173 | for sensor in observations_batch: 174 | observations_batch[sensor] = torch.stack( 175 | observations_batch[sensor], 1 176 | ) 177 | 178 | actions_batch = torch.stack(actions_batch, 1) 179 | prev_actions_batch = torch.stack(prev_actions_batch, 1) 180 | value_preds_batch = torch.stack(value_preds_batch, 1) 181 | return_batch = torch.stack(return_batch, 1) 182 | masks_batch = torch.stack(masks_batch, 1) 183 | old_action_log_probs_batch = torch.stack( 184 | old_action_log_probs_batch, 1 185 | ) 186 | adv_targ = torch.stack(adv_targ, 1) 187 | 188 | # States is just a (num_recurrent_layers, N, -1) tensor 189 | recurrent_hidden_states_batch = torch.stack( 190 | recurrent_hidden_states_batch, 1 191 | ) 192 | 193 | # Flatten the (T, N, ...) tensors to (T * N, ...) 
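# Flattening keeps time-major order: tensor.view(T * N, ...) lays out rows as
# (t0,n0), (t0,n1), ..., (t0,n(N-1)), (t1,n0), ..., so e.g. T=2, N=3 yields 6 rows
# in that order. The policy can then evaluate all T * N transitions in a single
# forward pass, while the stored (num_layers, N, H) hidden states seed the RNN.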
194 | for sensor in observations_batch: 195 | observations_batch[sensor] = self._flatten_helper( 196 | T, N, observations_batch[sensor] 197 | ) 198 | 199 | actions_batch = self._flatten_helper(T, N, actions_batch) 200 | prev_actions_batch = self._flatten_helper(T, N, prev_actions_batch) 201 | value_preds_batch = self._flatten_helper(T, N, value_preds_batch) 202 | return_batch = self._flatten_helper(T, N, return_batch) 203 | masks_batch = self._flatten_helper(T, N, masks_batch) 204 | old_action_log_probs_batch = self._flatten_helper( 205 | T, N, old_action_log_probs_batch 206 | ) 207 | adv_targ = self._flatten_helper(T, N, adv_targ) 208 | 209 | yield ( 210 | observations_batch, 211 | recurrent_hidden_states_batch, 212 | actions_batch, 213 | prev_actions_batch, 214 | value_preds_batch, 215 | return_batch, 216 | masks_batch, 217 | old_action_log_probs_batch, 218 | adv_targ, 219 | ) 220 | 221 | @staticmethod 222 | def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor: 223 | r"""Given a tensor of size (t, n, ..), flatten it to size (t*n, ...). 224 | 225 | Args: 226 | t: first dimension of tensor. 227 | n: second dimension of tensor. 228 | tensor: target tensor to be flattened. 229 | 230 | Returns: 231 | flattened tensor of size (t*n, ...) 232 | """ 233 | return tensor.view(t * n, *tensor.size()[2:]) 234 | -------------------------------------------------------------------------------- /soundspaces/tasks/audionav_task.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Any, Type, Union 8 | import logging 9 | 10 | import numpy as np 11 | import librosa 12 | from gym import spaces 13 | from skimage.measure import block_reduce 14 | 15 | from habitat.config import Config 16 | from habitat.core.dataset import Episode 17 | 18 | from habitat.tasks.nav.nav import NavigationTask, Measure, EmbodiedTask 19 | from habitat.core.registry import registry 20 | 21 | from habitat.core.simulator import ( 22 | Sensor, 23 | SensorTypes, 24 | Simulator, 25 | ) 26 | 27 | 28 | @registry.register_task(name="AudioNav") 29 | class AudioNavigationTask(NavigationTask): 30 | def overwrite_sim_config( 31 | self, sim_config: Any, episode: Type[Episode] 32 | ) -> Any: 33 | return merge_sim_episode_config(sim_config, episode) 34 | 35 | 36 | def merge_sim_episode_config( 37 | sim_config: Config, episode: Type[Episode] 38 | ) -> Any: 39 | sim_config.defrost() 40 | # here's where the scene update happens, extract the scene name out of the path 41 | sim_config.SCENE = episode.scene_id 42 | sim_config.freeze() 43 | if ( 44 | episode.start_position is not None 45 | and episode.start_rotation is not None 46 | ): 47 | agent_name = sim_config.AGENTS[sim_config.DEFAULT_AGENT_ID] 48 | agent_cfg = getattr(sim_config, agent_name) 49 | agent_cfg.defrost() 50 | agent_cfg.START_POSITION = episode.start_position 51 | agent_cfg.START_ROTATION = episode.start_rotation 52 | agent_cfg.GOAL_POSITION = episode.goals[0].position 53 | agent_cfg.SOUND = episode.info['sound'] 54 | agent_cfg.IS_SET_START_STATE = True 55 | agent_cfg.freeze() 56 | return sim_config 57 | 58 | 59 | @registry.register_sensor 60 | class AudioGoalSensor(Sensor): 61 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 62 | self._sim = sim 63 | 
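# The audiogoal observation is the raw binaural waveform rendered at the agent's
# pose: shape (2, RIR_SAMPLING_RATE), i.e. one second of two-channel audio per
# step (e.g. 44100 samples per channel at a 44.1 kHz rate; the actual rate comes
# from the AUDIO config, so that figure is illustrative).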
super().__init__(config=config) 64 | 65 | def _get_uuid(self, *args: Any, **kwargs: Any): 66 | return "audiogoal" 67 | 68 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 69 | return SensorTypes.PATH 70 | 71 | def _get_observation_space(self, *args: Any, **kwargs: Any): 72 | sensor_shape = (2, self._sim.config.AUDIO.RIR_SAMPLING_RATE) 73 | 74 | return spaces.Box( 75 | low=np.finfo(np.float32).min, 76 | high=np.finfo(np.float32).max, 77 | shape=sensor_shape, 78 | dtype=np.float32, 79 | ) 80 | 81 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 82 | return self._sim.get_current_audiogoal_observation() 83 | 84 | 85 | @registry.register_sensor 86 | class SpectrogramSensor(Sensor): 87 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 88 | self._sim = sim 89 | super().__init__(config=config) 90 | 91 | def _get_uuid(self, *args: Any, **kwargs: Any): 92 | return "spectrogram" 93 | 94 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 95 | return SensorTypes.PATH 96 | 97 | def _get_observation_space(self, *args: Any, **kwargs: Any): 98 | spectrogram = self.compute_spectrogram(np.ones((2, self._sim.config.AUDIO.RIR_SAMPLING_RATE))) 99 | 100 | return spaces.Box( 101 | low=np.finfo(np.float32).min, 102 | high=np.finfo(np.float32).max, 103 | shape=spectrogram.shape, 104 | dtype=np.float32, 105 | ) 106 | 107 | @staticmethod 108 | def compute_spectrogram(audio_data): 109 | def compute_stft(signal): 110 | n_fft = 512 111 | hop_length = 160 112 | win_length = 400 113 | stft = np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length, win_length=win_length)) 114 | stft = block_reduce(stft, block_size=(4, 4), func=np.mean) 115 | return stft 116 | 117 | channel1_magnitude = np.log1p(compute_stft(audio_data[0])) 118 | channel2_magnitude = np.log1p(compute_stft(audio_data[1])) 119 | spectrogram = np.stack([channel1_magnitude, channel2_magnitude], axis=-1) 120 | 121 | return spectrogram 122 | 123 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 124 | spectrogram = self._sim.get_current_spectrogram_observation(self.compute_spectrogram) 125 | 126 | return spectrogram 127 | 128 | 129 | @registry.register_measure 130 | class DistanceToGoal(Measure): 131 | r""" Distance to goal when the episode ends 132 | """ 133 | 134 | def __init__( 135 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 136 | ): 137 | self._start_end_episode_distance = None 138 | self._sim = sim 139 | self._config = config 140 | 141 | super().__init__() 142 | 143 | def _get_uuid(self, *args: Any, **kwargs: Any): 144 | return "distance_to_goal" 145 | 146 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 147 | self._start_end_episode_distance = episode.info["geodesic_distance"] 148 | self._metric = None 149 | self.update_metric(episode=episode, *args, **kwargs) 150 | 151 | def update_metric( 152 | self, *args: Any, episode, **kwargs: Any 153 | ): 154 | current_position = self._sim.get_agent_state().position.tolist() 155 | 156 | distance_to_target = self._sim.geodesic_distance( 157 | current_position, episode.goals[0].position 158 | ) 159 | 160 | self._metric = distance_to_target 161 | 162 | 163 | @registry.register_measure 164 | class NormalizedDistanceToGoal(Measure): 165 | r""" Distance to goal when the episode ends, normalized by the start-to-goal geodesic distance 166 | """ 167 | 168 | def __init__( 169 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 170 | ): 171 | self._start_end_episode_distance = None 172 | self._sim = sim 173 | 
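# This measure reports geodesic(current, goal) / geodesic(start, goal):
# 1.0 at the episode start, approaching 0 as the agent nears the goal, and
# above 1.0 if the agent ends up geodesically farther away than it started.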
self._config = config 174 | 175 | super().__init__() 176 | 177 | def _get_uuid(self, *args: Any, **kwargs: Any): 178 | return "normalized_distance_to_goal" 179 | 180 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 181 | self._start_end_episode_distance = episode.info["geodesic_distance"] 182 | self._metric = None 183 | 184 | def update_metric( 185 | self, *args: Any, episode, action, task: EmbodiedTask, **kwargs: Any 186 | ): 187 | current_position = self._sim.get_agent_state().position.tolist() 188 | 189 | distance_to_target = self._sim.geodesic_distance( 190 | current_position, episode.goals[0].position 191 | ) 192 | 193 | self._metric = distance_to_target / self._start_end_episode_distance 194 | 195 | 196 | @registry.register_sensor(name="Collision") 197 | class Collision(Sensor): 198 | def __init__( 199 | self, sim: Union[Simulator, Config], config: Config, *args: Any, **kwargs: Any 200 | ): 201 | super().__init__(config=config) 202 | self._sim = sim 203 | 204 | def _get_uuid(self, *args: Any, **kwargs: Any): 205 | return "collision" 206 | 207 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 208 | return SensorTypes.COLOR 209 | 210 | def _get_observation_space(self, *args: Any, **kwargs: Any): 211 | return spaces.Box( 212 | low=0, 213 | high=1, 214 | shape=(1,), 215 | dtype=bool 216 | ) 217 | 218 | def get_observation( 219 | self, *args: Any, observations, episode: Episode, **kwargs: Any 220 | ) -> object: 221 | return [self._sim.previous_step_collided] 222 | 223 | 224 | @registry.register_measure 225 | class SNA(Measure): 226 | r"""SNA (Success weighted by Number of Actions) 227 | 228 | ref: On Evaluation of Embodied Agents - Anderson et al. 229 | https://arxiv.org/pdf/1807.06757.pdf 230 | """ 231 | 232 | def __init__( 233 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 234 | ): 235 | self._start_end_num_action = None 236 | self._agent_num_action = None 237 | self._sim = sim 238 | self._config = config 239 | 240 | super().__init__() 241 | 242 | def _get_uuid(self, *args: Any, **kwargs: Any): 243 | return "sna" 244 | 245 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 246 | self._start_end_num_action = episode.info["num_action"] 247 | self._agent_num_action = 0 248 | self._metric = None 249 | 250 | def update_metric( 251 | self, *args: Any, episode, action, task: EmbodiedTask, **kwargs: Any 252 | ): 253 | ep_success = 0 254 | current_position = self._sim.get_agent_state().position.tolist() 255 | 256 | distance_to_target = self._sim.geodesic_distance( 257 | current_position, episode.goals[0].position 258 | ) 259 | 260 | if ( 261 | hasattr(task, "is_stop_called") 262 | and task.is_stop_called 263 | and distance_to_target < 0.25 264 | ): 265 | ep_success = 1 266 | 267 | self._agent_num_action += 1 268 | 269 | self._metric = ep_success * ( 270 | self._start_end_num_action 271 | / max( 272 | self._start_end_num_action, self._agent_num_action 273 | ) 274 | ) 275 | 276 | 277 | @registry.register_measure 278 | class NA(Measure): 279 | r""" Number of actions 280 | 281 | ref: On Evaluation of Embodied Agents - Anderson et al.
282 | https://arxiv.org/pdf/1807.06757.pdf 283 | """ 284 | 285 | def __init__( 286 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 287 | ): 288 | self._agent_num_action = None 289 | self._sim = sim 290 | self._config = config 291 | 292 | super().__init__() 293 | 294 | def _get_uuid(self, *args: Any, **kwargs: Any): 295 | return "na" 296 | 297 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 298 | self._agent_num_action = 0 299 | self._metric = None 300 | 301 | def update_metric( 302 | self, *args: Any, episode, action, task: EmbodiedTask, **kwargs: Any 303 | ): 304 | self._agent_num_action += 1 305 | self._metric = self._agent_num_action 306 | -------------------------------------------------------------------------------- /ss_baselines/common/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import glob 4 | import os 5 | import datetime 6 | from collections import defaultdict 7 | from typing import Dict, List, Optional 8 | import random 9 | 10 | import numpy as np 11 | import cv2 12 | from scipy.io import wavfile 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as f 16 | import moviepy.editor as mpy 17 | from moviepy.audio.AudioClip import CompositeAudioClip 18 | 19 | from habitat.utils.visualizations.utils import images_to_video 20 | from ss_baselines.common.tensorboard_utils import TensorboardWriter 21 | from habitat.utils.visualizations import maps 22 | 23 | 24 | class Flatten(nn.Module): 25 | def forward(self, x): 26 | return x.reshape(x.size(0), -1) 27 | 28 | 29 | class CustomFixedCategorical(torch.distributions.Categorical): 30 | def sample(self, sample_shape=torch.Size()): 31 | return super().sample(sample_shape).unsqueeze(-1) 32 | 33 | def log_probs(self, actions): 34 | return ( 35 | super() 36 | .log_prob(actions.squeeze(-1)) 37 | .view(actions.size(0), -1) 38 | .sum(-1) 39 | .unsqueeze(-1) 40 | ) 41 | 42 | def mode(self): 43 | return self.probs.argmax(dim=-1, keepdim=True) 44 | 45 | 46 | class CategoricalNet(nn.Module): 47 | def __init__(self, num_inputs, num_outputs): 48 | super().__init__() 49 | 50 | self.linear = nn.Linear(num_inputs, num_outputs) 51 | 52 | nn.init.orthogonal_(self.linear.weight, gain=0.01) 53 | nn.init.constant_(self.linear.bias, 0) 54 | 55 | def forward(self, x): 56 | x = self.linear(x) 57 | return CustomFixedCategorical(logits=x) 58 | 59 | 60 | class CategoricalNetWithMask(nn.Module): 61 | def __init__(self, num_inputs, num_outputs, masking): 62 | super().__init__() 63 | self.masking = masking 64 | 65 | self.linear = nn.Linear(num_inputs, num_outputs) 66 | 67 | nn.init.orthogonal_(self.linear.weight, gain=0.01) 68 | nn.init.constant_(self.linear.bias, 0) 69 | 70 | def forward(self, features, action_maps): 71 | probs = f.softmax(self.linear(features), dim=-1) 72 | if self.masking: 73 | probs = probs * torch.reshape(action_maps, (action_maps.shape[0], -1)).float() 74 | 75 | return CustomFixedCategorical(probs=probs) 76 | 77 | 78 | def linear_decay(epoch: int, total_num_updates: int) -> float: 79 | r"""Returns a multiplicative factor for linear value decay 80 | 81 | Args: 82 | epoch: current epoch number 83 | total_num_updates: total number of epochs 84 | 85 | Returns: 86 | multiplicative factor that decreases param value linearly 87 | """ 88 | return 1 - (epoch / float(total_num_updates)) 89 | 90 | 91 | def exponential_decay(epoch: int, total_num_updates: int, decay_lambda: float): 92 | r"""Returns a multiplicative factor for
exponential value decay 93 | 94 | Args: 95 | epoch: current epoch number 96 | total_num_updates: total number of epochs 97 | decay_lambda: decay lambda 98 | 99 | Returns: 100 | multiplicative factor that decreases param value exponentially 101 | """ 102 | return np.exp(-decay_lambda * (epoch / float(total_num_updates))) 103 | 104 | 105 | def to_tensor(v): 106 | if torch.is_tensor(v): 107 | return v 108 | elif isinstance(v, np.ndarray): 109 | return torch.from_numpy(v) 110 | else: 111 | return torch.tensor(v, dtype=torch.float) 112 | 113 | 114 | def batch_obs( 115 | observations: List[Dict], device: Optional[torch.device] = None 116 | ):  # -> Dict[str, torch.Tensor] 117 | r"""Transpose a batch of observation dicts to a dict of batched 118 | observations. 119 | 120 | Args: 121 | observations: list of dicts of observations. 122 | device: The torch.device to put the resulting tensors on. 123 | Will not move the tensors if None 124 | 125 | Returns: 126 | transposed dict of batched observation tensors. 127 | """ 128 | batch = defaultdict(list) 129 | 130 | for obs in observations: 131 | for sensor in obs: 132 | batch[sensor].append(to_tensor(obs[sensor])) 133 | 134 | for sensor in batch: 135 | batch[sensor] = torch.stack(batch[sensor], dim=0).to( 136 | device=device, dtype=torch.float 137 | ) 138 | 139 | return batch 140 | 141 | 142 | def poll_checkpoint_folder( 143 | checkpoint_folder: str, previous_ckpt_ind: int, eval_interval: int 144 | ) -> Optional[str]: 145 | r""" Return the (previous_ckpt_ind + eval_interval)th checkpoint in the 146 | checkpoint folder (ordered by checkpoint index). 147 | 148 | Args: 149 | checkpoint_folder: directory to look for checkpoints. 150 | previous_ckpt_ind: index of checkpoint last returned. 151 | eval_interval: number of checkpoints between two evaluations 152 | 153 | Returns: 154 | return checkpoint path if the next checkpoint is found 155 | else return None. 156 | """ 157 | assert os.path.isdir(checkpoint_folder), ( 158 | f"invalid checkpoint folder " f"path {checkpoint_folder}" 159 | ) 160 | models_paths = [] 161 | for i in range(800): 162 | model_path = f"{checkpoint_folder}/ckpt.{i}.pth" 163 | assert os.path.isfile(model_path), ( 164 | f"invalid checkpoint " f"path {model_path}" 165 | ) 166 | models_paths.append(model_path) 167 | #------------------------------------------ 168 | ind = previous_ckpt_ind + eval_interval 169 | if ind < len(models_paths): 170 | return models_paths[ind] 171 | return None 172 | 173 | 174 | def generate_video( 175 | video_option: List[str], 176 | video_dir: Optional[str], 177 | images: List[np.ndarray], 178 | scene_name: str, 179 | sound: str, 180 | sr: int, 181 | episode_id: int, 182 | checkpoint_idx: int, 183 | metric_name: str, 184 | metric_value: float, 185 | tb_writer: TensorboardWriter, 186 | fps: int = 10, 187 | audios: List[str] = None 188 | ) -> None: 189 | r"""Generate video according to specified information. 190 | 191 | Args: 192 | video_option: string list of "tensorboard" or "disk" or both. 193 | video_dir: path to target video directory. 194 | images: list of images to be converted to video. 195 | episode_id: episode id for video naming. 196 | checkpoint_idx: checkpoint index for video naming. 197 | metric_name: name of the performance metric, e.g. "spl". 198 | metric_value: value of metric. 199 | tb_writer: tensorboard writer object for uploading video. 200 | fps: fps for generated video.
201 | audios: raw audio files 202 | Returns: 203 | None 204 | """ 205 | if len(images) < 1: 206 | return 207 | 208 | ct =datetime.datetime.now() 209 | video_name = f"{checkpoint_idx}_{scene_name}_{episode_id}_{sound}_{metric_name}{metric_value:.2f}_{ct}" 210 | if "disk" in video_option: 211 | assert video_dir is not None 212 | if audios is None: 213 | images_to_video(images, video_dir, video_name) 214 | else: 215 | images_to_video_with_audio(images, video_dir, video_name, audios, sr, fps=fps) 216 | if "tensorboard" in video_option: 217 | tb_writer.add_video_from_np_images( 218 | f"episode{episode_id}", checkpoint_idx, images, fps=fps 219 | ) 220 | 221 | 222 | def plot_top_down_map(info, dataset='replica'): 223 | top_down_map = info["top_down_map"]["map"] 224 | top_down_map = maps.colorize_topdown_map( 225 | top_down_map, info["top_down_map"]["fog_of_war_mask"] 226 | ) 227 | map_agent_pos = info["top_down_map"]["agent_map_coord"] 228 | if dataset == 'replica': 229 | agent_radius_px = top_down_map.shape[0] // 16 230 | else: 231 | agent_radius_px = top_down_map.shape[0] // 50 232 | top_down_map = maps.draw_agent( 233 | image=top_down_map, 234 | agent_center_coord=map_agent_pos, 235 | agent_rotation=info["top_down_map"]["agent_angle"], 236 | agent_radius_px=agent_radius_px 237 | ) 238 | 239 | if top_down_map.shape[0] > top_down_map.shape[1]: 240 | top_down_map = np.rot90(top_down_map, 1) 241 | return top_down_map 242 | 243 | def images_to_video_with_audio( 244 | images: List[np.ndarray], 245 | output_dir: str, 246 | video_name: str, 247 | audios: List[str], 248 | sr: int, 249 | fps: int = 1, 250 | quality: Optional[float] = 5, 251 | **kwargs 252 | ): 253 | r"""Calls imageio to run FFMPEG on a list of images. For more info on 254 | parameters, see https://imageio.readthedocs.io/en/stable/format_ffmpeg.html 255 | Args: 256 | images: The list of images. Images should be HxWx3 in RGB order. 257 | output_dir: The folder to put the video in. 258 | video_name: The name for the video. 259 | audios: raw audio files 260 | fps: Frames per second for the video. Not all values work with FFMPEG, 261 | use at your own risk. 262 | quality: Default is 5. Uses variable bit rate. Highest quality is 10, 263 | lowest is 0. Set to None to prevent variable bitrate flags to 264 | FFMPEG so you can manually specify them using output_params 265 | instead. Specifying a fixed bitrate using ‘bitrate’ disables 266 | this parameter. 
267 | """ 268 | assert 0 <= quality <= 10 269 | if not os.path.exists(output_dir): 270 | os.makedirs(output_dir) 271 | video_name = video_name.replace(" ", "_").replace("\n", "_") + ".mp4" 272 | 273 | 274 | assert len(images) == len(audios) * fps 275 | audio_clips = [] 276 | temp_file_name = '/tmp/{}.wav'.format(random.randint(0, 100000)) 277 | # use amplitude scaling factor to reduce the volume of sounds 278 | amplitude_scaling_factor = 100 279 | for i, audio in enumerate(audios): 280 | # def f(t): 281 | # return audio[0, t], audio[1: t] 282 | # 283 | # audio_clip = mpy.AudioClip(f, duration=1, fps=audio.shape[1]) 284 | wavfile.write(temp_file_name, sr, audio.T / amplitude_scaling_factor) 285 | audio_clip = mpy.AudioFileClip(temp_file_name) 286 | audio_clip = audio_clip.set_duration(1) 287 | audio_clip = audio_clip.set_start(i) 288 | audio_clips.append(audio_clip) 289 | composite_audio_clip = CompositeAudioClip(audio_clips) 290 | video_clip = mpy.ImageSequenceClip(images, fps=fps) 291 | video_with_new_audio = video_clip.set_audio(composite_audio_clip) 292 | video_with_new_audio.write_videofile(os.path.join(output_dir, video_name)) 293 | os.remove(temp_file_name) 294 | 295 | 296 | def resize_observation(observations, model_resolution): 297 | for observation in observations: 298 | observation['rgb'] = cv2.resize(observation['rgb'], (model_resolution, model_resolution)) 299 | observation['depth'] = np.expand_dims(cv2.resize(observation['depth'], (model_resolution, model_resolution)), 300 | axis=-1) 301 | -------------------------------------------------------------------------------- /soundspaces/visualizations/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | 5 | 6 | 7 | import textwrap 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import imageio 11 | import numpy as np 12 | import tqdm 13 | 14 | from habitat.core.logging import logger 15 | 16 | #from habitat.utils.visualizations import maps 17 | from soundspaces.visualizations import maps 18 | 19 | from habitat.core.utils import try_cv2_import 20 | 21 | cv2 = try_cv2_import() 22 | 23 | 24 | def paste_overlapping_image( 25 | background: np.ndarray, 26 | foreground: np.ndarray, 27 | location: Tuple[int, int], 28 | mask: Optional[np.ndarray] = None, 29 | ): 30 | r"""Composites the foreground onto the background dealing with edge 31 | boundaries. 32 | Args: 33 | background: the background image to paste on. 34 | foreground: the image to paste. Can be RGB or RGBA. If using alpha 35 | blending, values for foreground and background should both be 36 | between 0 and 255. Otherwise behavior is undefined. 37 | location: the image coordinates to paste the foreground. 38 | mask: If not None, a mask for deciding what part of the foreground to 39 | use. Must be the same size as the foreground if provided. 40 | Returns: 41 | The modified background image. This operation is in place.
42 | """ 43 | assert mask is None or mask.shape[:2] == foreground.shape[:2] 44 | foreground_size = foreground.shape[:2] 45 | min_pad = ( 46 | max(0, foreground_size[0] // 2 - location[0]), 47 | max(0, foreground_size[1] // 2 - location[1]), 48 | ) 49 | 50 | max_pad = ( 51 | max( 52 | 0, 53 | (location[0] + (foreground_size[0] - foreground_size[0] // 2)) 54 | - background.shape[0], 55 | ), 56 | max( 57 | 0, 58 | (location[1] + (foreground_size[1] - foreground_size[1] // 2)) 59 | - background.shape[1], 60 | ), 61 | ) 62 | 63 | background_patch = background[ 64 | (location[0] - foreground_size[0] // 2 + min_pad[0]) : ( 65 | location[0] 66 | + (foreground_size[0] - foreground_size[0] // 2) 67 | - max_pad[0] 68 | ), 69 | (location[1] - foreground_size[1] // 2 + min_pad[1]) : ( 70 | location[1] 71 | + (foreground_size[1] - foreground_size[1] // 2) 72 | - max_pad[1] 73 | ), 74 | ] 75 | foreground = foreground[ 76 | min_pad[0] : foreground.shape[0] - max_pad[0], 77 | min_pad[1] : foreground.shape[1] - max_pad[1], 78 | ] 79 | if foreground.size == 0 or background_patch.size == 0: 80 | # Nothing to do, no overlap. 81 | return background 82 | 83 | if mask is not None: 84 | mask = mask[ 85 | min_pad[0] : foreground.shape[0] - max_pad[0], 86 | min_pad[1] : foreground.shape[1] - max_pad[1], 87 | ] 88 | 89 | if foreground.shape[2] == 4: 90 | # Alpha blending 91 | foreground = ( 92 | background_patch.astype(np.int32) * (255 - foreground[:, :, [3]]) 93 | + foreground[:, :, :3].astype(np.int32) * foreground[:, :, [3]] 94 | ) // 255 95 | if mask is not None: 96 | background_patch[mask] = foreground[mask] 97 | else: 98 | background_patch[:] = foreground 99 | return background 100 | 101 | 102 | def images_to_video( 103 | images: List[np.ndarray], 104 | output_dir: str, 105 | video_name: str, 106 | fps: int = 10, 107 | quality: Optional[float] = 5, 108 | **kwargs, 109 | ): 110 | r"""Calls imageio to run FFMPEG on a list of images. For more info on 111 | parameters, see https://imageio.readthedocs.io/en/stable/format_ffmpeg.html 112 | Args: 113 | images: The list of images. Images should be HxWx3 in RGB order. 114 | output_dir: The folder to put the video in. 115 | video_name: The name for the video. 116 | fps: Frames per second for the video. Not all values work with FFMPEG, 117 | use at your own risk. 118 | quality: Default is 5. Uses variable bit rate. Highest quality is 10, 119 | lowest is 0. Set to None to prevent variable bitrate flags to 120 | FFMPEG so you can manually specify them using output_params 121 | instead. Specifying a fixed bitrate using ‘bitrate’ disables 122 | this parameter. 123 | """ 124 | assert 0 <= quality <= 10 125 | if not os.path.exists(output_dir): 126 | os.makedirs(output_dir) 127 | video_name = video_name.replace(" ", "_").replace("\n", "_") + ".mp4" 128 | writer = imageio.get_writer( 129 | os.path.join(output_dir, video_name), 130 | fps=fps, 131 | quality=quality, 132 | **kwargs, 133 | ) 134 | logger.info(f"Video created: {os.path.join(output_dir, video_name)}") 135 | for im in tqdm.tqdm(images): 136 | writer.append_data(im) 137 | writer.close() 138 | 139 | 140 | def draw_collision(view: np.ndarray, alpha: float = 0.4) -> np.ndarray: 141 | r"""Draw translucent red strips on the border of input view to indicate 142 | a collision has taken place. 143 | Args: 144 | view: input view of size HxWx3 in RGB order. 145 | alpha: Opacity of red collision strip. 1 is completely non-transparent. 146 | Returns: 147 | A view with collision effect drawn. 
148 | """ 149 | strip_width = view.shape[0] // 20 150 | mask = np.ones(view.shape) 151 | mask[strip_width:-strip_width, strip_width:-strip_width] = 0 152 | mask = mask == 1 153 | view[mask] = (alpha * np.array([255, 0, 0]) + (1.0 - alpha) * view)[mask] 154 | return view 155 | 156 | 157 | def observations_to_image(observation: Dict, info: Dict) -> np.ndarray: 158 | r"""Generate image of single frame from observation and info 159 | returned from a single environment step(). 160 | 161 | Args: 162 | observation: observation returned from an environment step(). 163 | info: info returned from an environment step(). 164 | 165 | Returns: 166 | generated image of a single frame. 167 | """ 168 | #print(info) 169 | egocentric_view = [] 170 | if "rgb" in observation: 171 | observation_size = observation["rgb"].shape[0] 172 | rgb = observation["rgb"] 173 | if not isinstance(rgb, np.ndarray): 174 | rgb = rgb.cpu().numpy() 175 | 176 | egocentric_view.append(rgb) 177 | 178 | # draw depth map if observation has depth info 179 | if "depth" in observation: 180 | observation_size = observation["depth"].shape[0] 181 | depth_map = observation["depth"].squeeze() * 255.0 182 | if not isinstance(depth_map, np.ndarray): 183 | depth_map = depth_map.cpu().numpy() 184 | 185 | depth_map = depth_map.astype(np.uint8) 186 | depth_map = np.stack([depth_map for _ in range(3)], axis=2) 187 | egocentric_view.append(depth_map) 188 | 189 | assert ( 190 | len(egocentric_view) > 0 191 | ), "Expected at least one visual sensor enabled." 192 | egocentric_view = np.concatenate(egocentric_view, axis=1) 193 | 194 | # draw collision 195 | if "collisions" in info and info["collisions"]["is_collision"]: 196 | egocentric_view = draw_collision(egocentric_view) 197 | 198 | frame = egocentric_view 199 | 200 | if "top_down_map" in info: 201 | top_down_map = info["top_down_map"]["map"] 202 | top_down_map = maps.colorize_topdown_map( 203 | top_down_map, info["top_down_map"]["fog_of_war_mask"] 204 | ) 205 | map_agent_pos = info["top_down_map"]["agent_map_coord"] 206 | top_down_map = maps.draw_agent( 207 | image=top_down_map, 208 | agent_center_coord=map_agent_pos, 209 | agent_rotation=info["top_down_map"]["agent_angle"], 210 | agent_radius_px=top_down_map.shape[0] // 16, 211 | ) 212 | 213 | # ************************************ 214 | # # attack point 215 | #print("print attack point in utils************************:") 216 | # green = (0, 255, 0) 217 | map_attack_pos = info["top_down_map"]["attack_map_coord"] 218 | #print(f"attack_point in map in utils:{map_attack_pos}") 219 | # top_down_map = maps.draw_attack( 220 | # top_down_map, 221 | # map_attack_pos 222 | # ) 223 | # map_attack_pos_x, map_attack_pos_y = map_attack_pos 224 | # top_down_map = cv2.circle( 225 | # top_down_map, 226 | # map_attack_pos, 227 | # 20, 228 | # green, 229 | # thickness=-1, 230 | # ) 231 | # 232 | # cv2.imshow("after top_down_map", top_down_map) 233 | 234 | # added so the green circle does not disappear after the map is rotated 235 | #utils.paste_overlapping_image(top_down_map, map_attack, map_attack_pos) 236 | #************************************ 237 | #print(f"top_down_map.shape (original): {top_down_map.shape}") 238 | #print(top_down_map) 239 | if top_down_map.shape[0] > top_down_map.shape[1]: 240 | #print('grid rotated 90 degrees') 241 | top_down_map = np.rot90(top_down_map, 1) 242 | 243 | # scale top down map to align with rgb view 244 | old_h, old_w, _ = top_down_map.shape 245 | 246 | top_down_height = observation_size 247 | top_down_width = int(float(top_down_height) / old_h * old_w) 248 | # 
cv2 resize (dsize is width first) 249 | top_down_map = cv2.resize( 250 | top_down_map, 251 | (top_down_width, top_down_height), 252 | interpolation=cv2.INTER_CUBIC, 253 | ) 254 | 255 | frame = np.concatenate((egocentric_view, top_down_map), axis=1) 256 | return frame 257 | 258 | 259 | def append_text_to_image(image: np.ndarray, text: str): 260 | r""" Appends text underneath an image of size (height, width, channels). 261 | The returned image has white text on a black background. Uses textwrap to 262 | split long text into multiple lines. 263 | Args: 264 | image: the image to put text underneath 265 | text: a string to display 266 | Returns: 267 | A new image with text inserted underneath the input image 268 | """ 269 | h, w, c = image.shape 270 | font_size = 0.5 271 | font_thickness = 1 272 | font = cv2.FONT_HERSHEY_SIMPLEX 273 | blank_image = np.zeros(image.shape, dtype=np.uint8) 274 | 275 | char_size = cv2.getTextSize(" ", font, font_size, font_thickness)[0] 276 | wrapped_text = textwrap.wrap(text, width=int(w / char_size[0])) 277 | 278 | y = 0 279 | for line in wrapped_text: 280 | textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0] 281 | y += textsize[1] + 10 282 | x = 10 283 | cv2.putText( 284 | blank_image, 285 | line, 286 | (x, y), 287 | font, 288 | font_size, 289 | (255, 255, 255), 290 | font_thickness, 291 | lineType=cv2.LINE_AA, 292 | ) 293 | text_image = blank_image[0 : y + 10, 0:w] 294 | final = np.concatenate((image, text_image), axis=0) 295 | return final 296 | -------------------------------------------------------------------------------- /ss_baselines/common/base_trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import time 5 | from typing import ClassVar, Dict, List 6 | import glob 7 | 8 | import torch 9 | 10 | from habitat import Config, logger 11 | from ss_baselines.common.tensorboard_utils import TensorboardWriter 12 | from ss_baselines.common.utils import poll_checkpoint_folder 13 | 14 | 15 | class BaseTrainer: 16 | r"""Generic trainer class that serves as a base template for more 17 | specific trainer classes like RL trainer, SLAM or imitation learner. 18 | Includes only the most basic functionality. 19 | """ 20 | 21 | supported_tasks: ClassVar[List[str]] 22 | 23 | def train(self) -> None: 24 | raise NotImplementedError 25 | 26 | def eval(self) -> None: 27 | raise NotImplementedError 28 | 29 | def save_checkpoint(self, file_name) -> None: 30 | raise NotImplementedError 31 | 32 | def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict: 33 | raise NotImplementedError 34 | 35 | 36 | class BaseRLTrainer(BaseTrainer): 37 | r"""Base trainer class for RL trainers. Future RL-specific 38 | methods should be hosted here. 39 | """ 40 | device: torch.device 41 | config: Config 42 | video_option: List[str] 43 | _flush_secs: int 44 | 45 | def __init__(self, config: Config): 46 | super().__init__() 47 | assert config is not None, "needs config file to initialize trainer" 48 | self.config = config 49 | self._flush_secs = 30 50 | 51 | @property 52 | def flush_secs(self): 53 | return self._flush_secs 54 | 55 | @flush_secs.setter 56 | def flush_secs(self, value: int): 57 | self._flush_secs = value 58 | 59 | def train(self) -> None: 60 | raise NotImplementedError 61 | 62 | def eval(self, eval_interval=1, prev_ckpt_ind=-1, use_last_ckpt=False) -> None: 63 | r"""Main method of trainer evaluation. 
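        Depending on EVAL_CKPT_PATH_DIR, this evaluates either a single
        checkpoint file or, by polling, every checkpoint in a folder.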
Calls _eval_checkpoint(), which
        is specified in the Trainer class that inherits from BaseRLTrainer.

        Returns:
            None
        """
        self.device = (
            torch.device("cuda", self.config.TORCH_GPU_ID)
            if torch.cuda.is_available()
            else torch.device("cpu")
        )

        if "tensorboard" in self.config.VIDEO_OPTION:
            assert (
                len(self.config.TENSORBOARD_DIR) > 0
            ), "Must specify a tensorboard directory for video display"
        if "disk" in self.config.VIDEO_OPTION:
            assert (
                len(self.config.VIDEO_DIR) > 0
            ), "Must specify a directory for storing videos on disk"

        with TensorboardWriter(
            self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs
        ) as writer:
            # eval the last checkpoint in the folder
            if use_last_ckpt:
                models_paths = list(
                    filter(os.path.isfile, glob.glob(self.config.EVAL_CKPT_PATH_DIR + "/*"))
                )
                models_paths.sort(key=os.path.getmtime)
                self.config.defrost()
                self.config.EVAL_CKPT_PATH_DIR = models_paths[-1]
                self.config.freeze()

            # optionally resume evaluation from a previously interrupted checkpoint index
            if self.config.EVAL_prev_ckpt_index_enable:
                prev_ckpt_ind = int(self.config.EVAL_prev_ckpt_index)

            if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR):
                # evaluate a single checkpoint
                result = self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR, writer)
                return result
            else:
                # evaluate multiple checkpoints in order
                while True:
                    current_ckpt = None
                    while current_ckpt is None:
                        current_ckpt = poll_checkpoint_folder(
                            self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind, eval_interval
                        )
                        # time.sleep(2)  # sleep for 2 secs before polling again
                    logger.info(f"=======current_ckpt: {current_ckpt}=======")
                    prev_ckpt_ind += eval_interval
                    self._eval_checkpoint(
                        checkpoint_path=current_ckpt,
                        writer=writer,
                        checkpoint_index=prev_ckpt_ind
                    )

                    # hard-coded index of the last checkpoint to evaluate
                    if prev_ckpt_ind == 799:
                        break
                logger.info(f"=======Congratulations, you have finished evaluating all checkpoints; the last checkpoint was: {current_ckpt}=======")

    def _setup_eval_config(self, checkpoint_config: Config) -> Config:
        r"""Sets up and returns a merged config for evaluation. Config
        object saved from checkpoint is merged into config file specified
        at evaluation time with the following overwrite priority:
                eval_opts > ckpt_opts > eval_cfg > ckpt_cfg
        If the saved config is outdated, only the eval config is returned.

        Args:
            checkpoint_config: saved config from checkpoint.

        Returns:
            Config: merged config for eval.
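        Note: each successive yacs-style merge below overwrites earlier values,
        so a key given on the evaluation command line (eval_opts) wins over the
        checkpoint's trailing opts, which in turn win over both config files.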
138 | """ 139 | 140 | config = self.config.clone() 141 | 142 | ckpt_cmd_opts = checkpoint_config.CMD_TRAILING_OPTS 143 | eval_cmd_opts = config.CMD_TRAILING_OPTS 144 | 145 | try: 146 | config.merge_from_other_cfg(checkpoint_config) 147 | config.merge_from_other_cfg(self.config) 148 | config.merge_from_list(ckpt_cmd_opts) 149 | config.merge_from_list(eval_cmd_opts) 150 | except KeyError: 151 | logger.info("Saved config is outdated, using solely eval config") 152 | config = self.config.clone() 153 | config.merge_from_list(eval_cmd_opts) 154 | 155 | config.TASK_CONFIG.SIMULATOR.AGENT_0.defrost() 156 | config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = self.config.SENSORS 157 | config.freeze() 158 | 159 | return config 160 | 161 | def _eval_checkpoint( 162 | self, 163 | checkpoint_path: str, 164 | writer: TensorboardWriter, 165 | checkpoint_index: int = 0, 166 | ) -> None: 167 | r"""Evaluates a single checkpoint. Trainer algorithms should 168 | implement this. 169 | 170 | Args: 171 | checkpoint_path: path of checkpoint 172 | writer: tensorboard writer object for logging to tensorboard 173 | checkpoint_index: index of cur checkpoint for logging 174 | 175 | Returns: 176 | None 177 | """ 178 | raise NotImplementedError 179 | 180 | def save_checkpoint(self, file_name) -> None: 181 | raise NotImplementedError 182 | 183 | def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict: 184 | raise NotImplementedError 185 | 186 | @staticmethod 187 | def _pause_envs( 188 | envs_to_pause, 189 | envs, 190 | test_recurrent_hidden_states, 191 | not_done_masks, 192 | current_episode_reward, 193 | prev_actions, 194 | batch, 195 | rgb_frames, 196 | ): 197 | # pausing self.envs with no new episode 198 | if len(envs_to_pause) > 0: 199 | state_index = list(range(envs.num_envs)) 200 | for idx in reversed(envs_to_pause): 201 | state_index.pop(idx) 202 | envs.pause_at(idx) 203 | 204 | # indexing along the batch dimensions 205 | test_recurrent_hidden_states = test_recurrent_hidden_states[ 206 | :, state_index 207 | ] 208 | not_done_masks = not_done_masks[state_index] 209 | current_episode_reward = current_episode_reward[state_index] 210 | prev_actions = prev_actions[state_index] 211 | 212 | for k, v in batch.items(): 213 | batch[k] = v[state_index] 214 | 215 | rgb_frames = [rgb_frames[i] for i in state_index] 216 | 217 | return ( 218 | envs, 219 | test_recurrent_hidden_states, 220 | not_done_masks, 221 | current_episode_reward, 222 | prev_actions, 223 | batch, 224 | rgb_frames, 225 | ) 226 | 227 | @staticmethod 228 | def _pause_envs_agent_attack( 229 | attacker_actions_desc, 230 | envs_to_pause, 231 | envs, 232 | test_recurrent_hidden_states_agent, 233 | test_recurrent_hidden_states_attack, 234 | not_done_masks_agent, 235 | not_done_masks_attack, 236 | current_episode_reward_agent, 237 | current_episode_reward_attack, 238 | prev_actions_agent, 239 | prev_actions_attack, 240 | batch, 241 | rgb_frames, 242 | ): 243 | # pausing self.envs with no new episode 244 | if len(envs_to_pause) > 0: 245 | state_index = list(range(envs.num_envs)) 246 | for idx in reversed(envs_to_pause): 247 | state_index.pop(idx) 248 | envs.pause_at(idx) 249 | 250 | # indexing along the batch dimensions 251 | test_recurrent_hidden_states_agent = test_recurrent_hidden_states_agent[ 252 | :, state_index 253 | ] 254 | test_recurrent_hidden_states_attack = test_recurrent_hidden_states_attack[ 255 | :, state_index 256 | ] 257 | 258 | not_done_masks_agent = not_done_masks_agent[state_index] 259 | not_done_masks_attack = 
not_done_masks_attack[state_index] 260 | 261 | current_episode_reward_agent = current_episode_reward_agent[state_index] 262 | current_episode_reward_attack = current_episode_reward_attack[state_index] 263 | 264 | prev_actions_agent = prev_actions_agent[state_index] 265 | prev_actions_attack_new = {} 266 | if attacker_actions_desc.position.is_action: 267 | prev_actions_attack_new["position"] = prev_actions_attack["position"][state_index] 268 | if attacker_actions_desc.alpha.is_action: 269 | prev_actions_attack_new["alpha"] = prev_actions_attack["alpha"][state_index] 270 | if attacker_actions_desc.category.is_action: 271 | prev_actions_attack_new["category"] = prev_actions_attack["category"][state_index] 272 | 273 | prev_actions_attack = None 274 | prev_actions_attack = prev_actions_attack_new 275 | 276 | for k, v in batch.items(): 277 | batch[k] = v[state_index] 278 | 279 | rgb_frames = [rgb_frames[i] for i in state_index] 280 | 281 | return ( 282 | envs, 283 | test_recurrent_hidden_states_agent, 284 | test_recurrent_hidden_states_attack, 285 | not_done_masks_agent, 286 | not_done_masks_attack, 287 | current_episode_reward_agent, 288 | current_episode_reward_attack, 289 | prev_actions_agent, 290 | prev_actions_attack, 291 | batch, 292 | rgb_frames, 293 | ) 294 | -------------------------------------------------------------------------------- /storage/rollout_storage.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from collections import defaultdict 4 | 5 | import torch 6 | import copy 7 | 8 | from habitat.config import Config as CN 9 | 10 | 11 | class RolloutStorage: 12 | r"""Class for storing rollout information for RL trainers. 13 | 14 | """ 15 | 16 | def __init__( 17 | self, 18 | num_steps, 19 | num_envs, 20 | observation_space, 21 | action_space, 22 | recurrent_hidden_state_size, 23 | device, 24 | num_recurrent_layers=1, 25 | 26 | ): 27 | self.observations = {} 28 | self.device = device 29 | 30 | for sensor in observation_space.spaces: 31 | self.observations[sensor] = torch.zeros( 32 | num_steps + 1, 33 | num_envs, 34 | *observation_space.spaces[sensor].shape, 35 | device = self.device, 36 | ) 37 | 38 | self.recurrent_hidden_states = torch.zeros( 39 | num_steps + 1, 40 | num_recurrent_layers, 41 | num_envs, 42 | recurrent_hidden_state_size, 43 | device = self.device, 44 | ) 45 | 46 | self.rewards = torch.zeros(num_steps, num_envs, 1,device = self.device) 47 | self.value_preds = torch.zeros(num_steps + 1, num_envs, 1,device = self.device) 48 | self.returns = torch.zeros(num_steps + 1, num_envs, 1,device = self.device) 49 | 50 | self.action_log_probs = torch.zeros(num_steps, num_envs, 1,device = self.device) 51 | if action_space.__class__.__name__ == "ActionSpace": 52 | action_shape = 1 53 | else: 54 | action_shape = action_space.shape[0] 55 | 56 | self.actions = torch.zeros(num_steps, num_envs, action_shape,device = self.device) 57 | self.prev_actions = torch.zeros(num_steps + 1, num_envs, action_shape,device = self.device) 58 | if action_space.__class__.__name__ == "ActionSpace": 59 | self.actions = self.actions.long() 60 | self.prev_actions = self.prev_actions.long() 61 | 62 | self.masks = torch.ones(num_steps + 1, num_envs, 1,device = self.device) 63 | 64 | self.num_steps = num_steps 65 | self.step = 0 66 | 67 | def to(self, device): 68 | for sensor in self.observations: 69 | self.observations[sensor] = self.observations[sensor].to(device) 70 | 71 | self.recurrent_hidden_states = 
self.recurrent_hidden_states.to(device) 72 | self.rewards = self.rewards.to(device) 73 | self.value_preds = self.value_preds.to(device) 74 | self.returns = self.returns.to(device) 75 | self.action_log_probs = self.action_log_probs.to(device) 76 | self.actions = self.actions.to(device) 77 | self.prev_actions = self.prev_actions.to(device) 78 | self.masks = self.masks.to(device) 79 | 80 | def insert( 81 | self, 82 | observations, 83 | recurrent_hidden_states, 84 | actions, 85 | action_log_probs, 86 | value_preds, 87 | rewards, 88 | masks, 89 | ): 90 | for sensor in observations: 91 | self.observations[sensor][self.step + 1].copy_( 92 | observations[sensor] 93 | ) 94 | self.recurrent_hidden_states[self.step + 1].copy_( 95 | recurrent_hidden_states 96 | ) 97 | self.actions[self.step].copy_(actions) 98 | self.prev_actions[self.step + 1].copy_(actions) 99 | self.action_log_probs[self.step].copy_(action_log_probs) 100 | self.value_preds[self.step].copy_(value_preds) 101 | self.rewards[self.step].copy_(rewards) 102 | self.masks[self.step + 1].copy_(masks) 103 | 104 | self.step = (self.step + 1) % self.num_steps 105 | 106 | def after_update(self): 107 | for sensor in self.observations: 108 | self.observations[sensor][0].copy_(self.observations[sensor][-1]) 109 | 110 | self.recurrent_hidden_states[0].copy_(self.recurrent_hidden_states[-1]) 111 | self.masks[0].copy_(self.masks[-1]) 112 | self.prev_actions[0].copy_(self.prev_actions[-1]) 113 | 114 | def compute_returns(self, next_value, use_gae, gamma, tau): 115 | if use_gae: 116 | self.value_preds[-1] = next_value 117 | gae = 0 118 | for step in reversed(range(self.rewards.size(0))): 119 | delta = ( 120 | self.rewards[step] 121 | + gamma * self.value_preds[step + 1] * self.masks[step + 1] 122 | - self.value_preds[step] 123 | ) 124 | gae = delta + gamma * tau * self.masks[step + 1] * gae 125 | self.returns[step] = gae + self.value_preds[step] 126 | else: 127 | self.returns[-1] = next_value 128 | for step in reversed(range(self.rewards.size(0))): 129 | self.returns[step] = ( 130 | self.returns[step + 1] * gamma * self.masks[step + 1] 131 | + self.rewards[step] 132 | ) 133 | 134 | def recurrent_generator(self, advantages, num_mini_batch): 135 | #rewards: (num_steps, num_envs, 1) 136 | 137 | num_processes = self.rewards.size(1) 138 | assert num_processes >= num_mini_batch, ( 139 | "Trainer requires the number of processes ({}) " 140 | "to be greater than or equal to the number of " 141 | "trainer mini batches ({}).".format(num_processes, num_mini_batch) 142 | ) 143 | num_envs_per_batch = num_processes // num_mini_batch 144 | perm = torch.randperm(num_processes) 145 | for start_ind in range(0, num_processes, num_envs_per_batch): 146 | observations_batch = defaultdict(list) 147 | 148 | recurrent_hidden_states_batch = [] 149 | actions_batch = [] 150 | prev_actions_batch = [] 151 | value_preds_batch = [] 152 | return_batch = [] 153 | masks_batch = [] 154 | old_action_log_probs_batch = [] 155 | adv_targ = [] 156 | 157 | for offset in range(num_envs_per_batch): 158 | ind = perm[start_ind + offset] 159 | 160 | for sensor in self.observations: 161 | observations_batch[sensor].append( 162 | self.observations[sensor][:-1, ind] 163 | ) 164 | 165 | recurrent_hidden_states_batch.append( 166 | self.recurrent_hidden_states[0, :, ind] 167 | ) 168 | 169 | actions_batch.append(self.actions[:, ind]) 170 | prev_actions_batch.append(self.prev_actions[:-1, ind]) 171 | value_preds_batch.append(self.value_preds[:-1, ind]) 172 | return_batch.append(self.returns[:-1, ind]) 173 | 
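                # the *_batch lists collect one (T, ...) slice per sampled env;
                # they are stacked to (T, N, ...) and flattened to (T*N, ...)
                # before being yielded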
masks_batch.append(self.masks[:-1, ind])
                old_action_log_probs_batch.append(
                    self.action_log_probs[:, ind]
                )

                adv_targ.append(advantages[:, ind])

            T, N = self.num_steps, num_envs_per_batch

            # These are all tensors of size (T, N, -1)
            for sensor in observations_batch:
                observations_batch[sensor] = torch.stack(
                    observations_batch[sensor], 1
                )

            actions_batch = torch.stack(actions_batch, 1)
            prev_actions_batch = torch.stack(prev_actions_batch, 1)
            value_preds_batch = torch.stack(value_preds_batch, 1)
            return_batch = torch.stack(return_batch, 1)
            masks_batch = torch.stack(masks_batch, 1)
            old_action_log_probs_batch = torch.stack(
                old_action_log_probs_batch, 1
            )
            adv_targ = torch.stack(adv_targ, 1)

            # States is just a (num_recurrent_layers, N, -1) tensor
            recurrent_hidden_states_batch = torch.stack(
                recurrent_hidden_states_batch, 1
            )

            # Flatten the (T, N, ...) tensors to (T * N, ...)
            for sensor in observations_batch:
                observations_batch[sensor] = self._flatten_helper(
                    T, N, observations_batch[sensor]
                )

            actions_batch = self._flatten_helper(T, N, actions_batch)
            prev_actions_batch = self._flatten_helper(T, N, prev_actions_batch)
            value_preds_batch = self._flatten_helper(T, N, value_preds_batch)
            return_batch = self._flatten_helper(T, N, return_batch)
            masks_batch = self._flatten_helper(T, N, masks_batch)
            old_action_log_probs_batch = self._flatten_helper(
                T, N, old_action_log_probs_batch
            )
            adv_targ = self._flatten_helper(T, N, adv_targ)

            yield (
                observations_batch,
                recurrent_hidden_states_batch,
                actions_batch,
                prev_actions_batch,
                value_preds_batch,
                return_batch,
                masks_batch,
                old_action_log_probs_batch,
                adv_targ,
            )

    @staticmethod
    def _flatten_helper(t: int, n: int, tensor: torch.Tensor) -> torch.Tensor:
        r"""Given a tensor of size (t, n, ..), flatten it to size (t*n, ...).

        Args:
            t: first dimension of tensor.
            n: second dimension of tensor.
            tensor: target tensor to be flattened.

        Returns:
            flattened tensor of size (t*n, ...)
        """
        return tensor.view(t * n, *tensor.size()[2:])


# NOTE: RolloutStorageHybrid is referenced by the wrappers below but is not
# defined in this file; its annotations are quoted so that they are not
# evaluated at import time.
class RolloutStorageTwoAgentHybrid:
    r"""Class for storing rollout information for two RL trainers
    (attacker and agent).
    """

    def __init__(
        self,
        attack: "RolloutStorageHybrid",
        agent: RolloutStorage,
    ):
        self.attack = attack
        self.agent = agent

    def to(self, device):
        # the underlying storages move their tensors in place and return None,
        # so do not rebind the attributes to the return value of .to()
        self.attack.to(device)
        self.agent.to(device)


class RolloutStorageTwoAgent:
    r"""Class for storing rollout information for two RL trainers
    (attacker and agent).
    """

    def __init__(
        self,
        attack: "RolloutStorageHybrid",
        agent: RolloutStorage,
    ):
        self.attack = attack
        self.agent = agent

    def to(self, device):
        self.attack.to(device)
        self.agent.to(device)


class RolloutStorageMA:
    r"""Class for storing rollout information for two RL trainers
    (attacker and agent).
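    ``recurrent_generator`` zips the two per-policy generators so that each
    yielded sample pairs the agent mini-batch with the attacker mini-batch
    for the same mini-batch index.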
    """

    def __init__(
        self,
        attack: "RolloutStorageHybrid",
        agent: RolloutStorage,
    ):
        self.attack = attack
        self.agent = agent

    def to(self, device):
        self.attack.to(device)
        self.agent.to(device)

    def recurrent_generator(self, advantages, num_mini_batch):
        data_generator_agent = self.agent.recurrent_generator(
            advantages["agent"], num_mini_batch
        )
        data_generator_attack = self.attack.recurrent_generator(
            advantages["attack"], num_mini_batch
        )
        for sample_agent, sample_attack in zip(data_generator_agent, data_generator_attack):
            (
                obs_batch_agent,
                recurrent_hidden_states_batch_agent,
                actions_batch_agent,
                prev_actions_batch_agent,
                value_preds_batch_agent,
                return_batch_agent,
                masks_batch_agent,
                old_action_log_probs_batch_agent,
                adv_targ_agent,
            ) = sample_agent

            (
                obs_batch_attack,
                recurrent_hidden_states_batch_attack,
                actions_batch_attack,
                prev_actions_batch_attack,
                value_preds_batch_attack,
                return_batch_attack,
                masks_batch_attack,
                old_action_log_probs_batch_attack,
                adv_targ_attack,
            ) = sample_attack

            yield (
                obs_batch_agent, obs_batch_attack,
                recurrent_hidden_states_batch_agent, recurrent_hidden_states_batch_attack,
                actions_batch_agent, actions_batch_attack,
                prev_actions_batch_agent, prev_actions_batch_attack,
                value_preds_batch_agent, value_preds_batch_attack,
                return_batch_agent, return_batch_attack,
                masks_batch_agent, masks_batch_attack,
                old_action_log_probs_batch_agent, old_action_log_probs_batch_attack,
                adv_targ_agent, adv_targ_attack,
            )


def get(dg1, dg2):
    # minimal zip-pairing demo used to sanity-check the generator pairing
    # above; note that zip stops at the shorter of the two iterables
    for i, j in zip(dg1, dg2):
        a, b, c = i
        _, d, e = j
        yield (
            a,
            b,
            c,
            d,
            e,
        )


if __name__ == "__main__":
    dg1 = [(1, 2, 3), (2, 4, 5), (3, 6, 7)]
    dg2 = [(1, 12, 13), (2, 14, 15)]
    i = 0
    for sample in get(dg1, dg2):
        print(sample)
        print(i)
        i += 1
-------------------------------------------------------------------------------- /soundspaces/visualizations/maps.py: --------------------------------------------------------------------------------
#!/usr/bin/env python3

import os
from typing import List, Optional, Tuple

import imageio
import numpy as np
import scipy.ndimage

from habitat.core.simulator import Simulator
from habitat.core.utils import try_cv2_import
from habitat.utils.visualizations import utils

cv2 = try_cv2_import()

AGENT_SPRITE = imageio.imread(
    os.path.join(
        os.path.dirname(__file__),
        "assets",
        "maps_topdown_agent_sprite",
        "100x100.png",
    )
)
AGENT_SPRITE = np.ascontiguousarray(np.flipud(AGENT_SPRITE))
COORDINATE_EPSILON = 1e-6
COORDINATE_MIN = -62.3241 - COORDINATE_EPSILON
COORDINATE_MAX = 90.0399 + COORDINATE_EPSILON

MAP_INVALID_POINT = 0
MAP_VALID_POINT = 1
MAP_BORDER_INDICATOR = 2
MAP_ATTACK_POINT_INDICATOR = 3
MAP_SOURCE_POINT_INDICATOR = 4
MAP_TARGET_POINT_INDICATOR = 6
MAP_SHORTEST_PATH_COLOR = 7
MAP_VIEW_POINT_INDICATOR = 8
MAP_TARGET_BOUNDING_BOX = 9
TOP_DOWN_MAP_COLORS = np.full((256, 3), 150, dtype=np.uint8)
TOP_DOWN_MAP_COLORS[10:] = cv2.applyColorMap(
    np.arange(246, dtype=np.uint8), cv2.COLORMAP_JET
).squeeze(1)[:, ::-1]
TOP_DOWN_MAP_COLORS[MAP_ATTACK_POINT_INDICATOR] = [255, 255, 0]  # Yellow

TOP_DOWN_MAP_COLORS[MAP_INVALID_POINT] = [255, 255, 255]  # White
TOP_DOWN_MAP_COLORS[MAP_VALID_POINT] = [150, 150, 150]  # Light Grey
TOP_DOWN_MAP_COLORS[MAP_BORDER_INDICATOR] = [50, 50, 50]  # Grey
TOP_DOWN_MAP_COLORS[MAP_SOURCE_POINT_INDICATOR] = [0, 0, 200]  # Blue
TOP_DOWN_MAP_COLORS[MAP_TARGET_POINT_INDICATOR] = [200, 0, 0]  # Red
TOP_DOWN_MAP_COLORS[MAP_SHORTEST_PATH_COLOR] = [0, 200, 0]  # Green
TOP_DOWN_MAP_COLORS[MAP_VIEW_POINT_INDICATOR] = [245, 150, 150]  # Light Red
TOP_DOWN_MAP_COLORS[MAP_TARGET_BOUNDING_BOX] = [0, 175, 0]  # Green


def draw_agent(
    image: np.ndarray,
    agent_center_coord: Tuple[int, int],
    agent_rotation: float,
    agent_radius_px: int = 5,
) -> np.ndarray:
    r"""Return an image with the agent image composited onto it.
    Args:
        image: the image onto which to put the agent.
        agent_center_coord: the image coordinates where to paste the agent.
        agent_rotation: the agent's current rotation in radians.
        agent_radius_px: 1/2 number of pixels the agent will be resized to.
    Returns:
        The modified background image. This operation is in place.
    """

    # Rotate before resize to keep good resolution.
    rotated_agent = scipy.ndimage.rotate(
        AGENT_SPRITE, agent_rotation * 180 / np.pi
    )
    # Rescale because rotation may result in larger image than original, but
    # the agent sprite size should stay the same.
    initial_agent_size = AGENT_SPRITE.shape[0]
    new_size = rotated_agent.shape[0]
    agent_size_px = max(
        1, int(agent_radius_px * 2 * new_size / initial_agent_size)
    )
    resized_agent = cv2.resize(
        rotated_agent,
        (agent_size_px, agent_size_px),
        interpolation=cv2.INTER_LINEAR,
    )
    utils.paste_overlapping_image(image, resized_agent, agent_center_coord)
    return image


def pointnav_draw_target_birdseye_view(
    agent_position: np.ndarray,
    agent_heading: float,
    goal_position: np.ndarray,
    resolution_px: int = 800,
    goal_radius: float = 0.2,
    agent_radius_px: int = 20,
    target_band_radii: Optional[List[float]] = None,
    target_band_colors: Optional[List[Tuple[int, int, int]]] = None,
) -> np.ndarray:
    r"""Return an image of agent w.r.t. centered target location for pointnav
    tasks.

    Args:
        agent_position: the agent's current position.
        agent_heading: the agent's current rotation in radians. This can be
            found using the HeadingSensor.
        goal_position: the pointnav task goal position.
        resolution_px: number of pixels for the output image width and height.
        goal_radius: how near the agent needs to be to be successful for the
            pointnav task.
        agent_radius_px: 1/2 number of pixels the agent will be resized to.
        target_band_radii: distance in meters to the outer-radius of each band
            in the target image.
        target_band_colors: colors in RGB 0-255 for the bands in the target.
    Returns:
        Image centered on the goal with the agent's current relative position
        and rotation represented by an arrow.
To make the rotations align
        visually with habitat, positive-z is up, positive-x is left and a
        rotation of 0 points upwards in the output image and rotates clockwise.
    """
    if target_band_radii is None:
        target_band_radii = [20, 10, 5, 2.5, 1]
    if target_band_colors is None:
        target_band_colors = [
            (47, 19, 122),
            (22, 99, 170),
            (92, 177, 0),
            (226, 169, 0),
            (226, 12, 29),
        ]

    assert len(target_band_radii) == len(
        target_band_colors
    ), "There must be an equal number of scales and colors."

    goal_agent_dist = np.linalg.norm(agent_position - goal_position, 2)

    goal_distance_padding = max(
        2, 2 ** np.ceil(np.log(max(1e-6, goal_agent_dist)) / np.log(2))
    )
    movement_scale = 1.0 / goal_distance_padding
    half_res = resolution_px // 2
    im_position = np.full(
        (resolution_px, resolution_px, 3), 255, dtype=np.uint8
    )

    # Draw bands:
    for scale, color in zip(target_band_radii, target_band_colors):
        if goal_distance_padding * 4 > scale:
            cv2.circle(
                im_position,
                (half_res, half_res),
                max(2, int(half_res * scale * movement_scale)),
                color,
                thickness=-1,
            )

    # Draw such that the agent being inside the radius is the circles
    # overlapping.
    cv2.circle(
        im_position,
        (half_res, half_res),
        max(2, int(half_res * goal_radius * movement_scale)),
        (127, 0, 0),
        thickness=-1,
    )

    relative_position = agent_position - goal_position
    # swap x and z, remove y for (x,y,z) -> image coordinates.
    relative_position = relative_position[[2, 0]]
    relative_position *= half_res * movement_scale
    relative_position += half_res
    relative_position = np.round(relative_position).astype(np.int32)

    # Draw the agent
    draw_agent(im_position, relative_position, agent_heading, agent_radius_px)

    # Rotate twice to fix coordinate system to upwards being positive-z.
    # Rotate instead of flip to keep agent rotations in sync with egocentric
    # view.
    im_position = np.rot90(im_position, 2)
    return im_position


def to_grid(
    realworld_x: float,
    realworld_y: float,
    coordinate_min: float,
    coordinate_max: float,
    grid_resolution: Tuple[int, int],
) -> Tuple[int, int]:
    r"""Return gridworld index of realworld coordinates assuming top-left corner
    is the origin. The real world coordinates of lower left corner are
    (coordinate_min, coordinate_min) and of top right corner are
    (coordinate_max, coordinate_max)
    """
    grid_size = (
        (coordinate_max - coordinate_min) / grid_resolution[0],
        (coordinate_max - coordinate_min) / grid_resolution[1],
    )
    grid_x = int((coordinate_max - realworld_x) / grid_size[0])
    grid_y = int((realworld_y - coordinate_min) / grid_size[1])
    return grid_x, grid_y


def from_grid(
    grid_x: int,
    grid_y: int,
    coordinate_min: float,
    coordinate_max: float,
    grid_resolution: Tuple[int, int],
) -> Tuple[float, float]:
    r"""Inverse of to_grid. Return real world coordinate from
    gridworld assuming top-left corner is the origin.
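    Together with ``to_grid`` this gives an approximate round trip: mapping a
    world point to a grid cell and back recovers it only up to one cell of
    quantization error, since ``to_grid`` truncates to integer indices.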
The real world 218 | coordinates of lower left corner are (coordinate_min, coordinate_min) and 219 | of top right corner are (coordinate_max, coordinate_max) 220 | """ 221 | grid_size = ( 222 | (coordinate_max - coordinate_min) / grid_resolution[0], 223 | (coordinate_max - coordinate_min) / grid_resolution[1], 224 | ) 225 | realworld_x = coordinate_max - grid_x * grid_size[0] 226 | realworld_y = coordinate_min + grid_y * grid_size[1] 227 | return realworld_x, realworld_y 228 | 229 | 230 | def _outline_border(top_down_map): 231 | left_right_block_nav = (top_down_map[:, :-1] == 1) & ( 232 | top_down_map[:, :-1] != top_down_map[:, 1:] 233 | ) 234 | left_right_nav_block = (top_down_map[:, 1:] == 1) & ( 235 | top_down_map[:, :-1] != top_down_map[:, 1:] 236 | ) 237 | 238 | up_down_block_nav = (top_down_map[:-1] == 1) & ( 239 | top_down_map[:-1] != top_down_map[1:] 240 | ) 241 | up_down_nav_block = (top_down_map[1:] == 1) & ( 242 | top_down_map[:-1] != top_down_map[1:] 243 | ) 244 | 245 | top_down_map[:, :-1][left_right_block_nav] = MAP_BORDER_INDICATOR 246 | top_down_map[:, 1:][left_right_nav_block] = MAP_BORDER_INDICATOR 247 | 248 | top_down_map[:-1][up_down_block_nav] = MAP_BORDER_INDICATOR 249 | top_down_map[1:][up_down_nav_block] = MAP_BORDER_INDICATOR 250 | 251 | 252 | def get_topdown_map( 253 | sim: Simulator, 254 | map_resolution: Tuple[int, int] = (1250, 1250), 255 | num_samples: int = 20000, 256 | draw_border: bool = True, 257 | ) -> np.ndarray: 258 | r"""Return a top-down occupancy map for a sim. Note, this only returns valid 259 | values for whatever floor the agent is currently on. 260 | 261 | Args: 262 | sim: The simulator. 263 | map_resolution: The resolution of map which will be computed and 264 | returned. 265 | num_samples: The number of random navigable points which will be 266 | initially 267 | sampled. For large environments it may need to be increased. 268 | draw_border: Whether to outline the border of the occupied spaces. 269 | 270 | Returns: 271 | Image containing 0 if occupied, 1 if unoccupied, and 2 if border (if 272 | the flag is set). 273 | """ 274 | top_down_map = np.zeros(map_resolution, dtype=np.uint8) 275 | border_padding = 3 276 | 277 | start_height = sim.get_agent_state().position[1] 278 | 279 | # Use sampling to find the extrema points that might be navigable. 280 | range_x = (map_resolution[0], 0) 281 | range_y = (map_resolution[1], 0) 282 | for _ in range(num_samples): 283 | point = sim.sample_navigable_point() 284 | # Check if on same level as original 285 | if np.abs(start_height - point[1]) > 0.5: 286 | continue 287 | g_x, g_y = to_grid( 288 | point[0], point[2], COORDINATE_MIN, COORDINATE_MAX, map_resolution 289 | ) 290 | range_x = (min(range_x[0], g_x), max(range_x[1], g_x)) 291 | range_y = (min(range_y[0], g_y), max(range_y[1], g_y)) 292 | 293 | # Pad the range just in case not enough points were sampled to get the true 294 | # extrema. 295 | padding = int(np.ceil(map_resolution[0] / 125)) 296 | range_x = ( 297 | max(range_x[0] - padding, 0), 298 | min(range_x[-1] + padding + 1, top_down_map.shape[0]), 299 | ) 300 | range_y = ( 301 | max(range_y[0] - padding, 0), 302 | min(range_y[-1] + padding + 1, top_down_map.shape[1]), 303 | ) 304 | 305 | # Search over grid for valid points. 
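    # Dense scan: every cell in the padded bounding box found above is mapped
    # back to world coordinates and tested with sim.is_navigable() at the
    # agent's current floor height (start_height).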
306 | for ii in range(range_x[0], range_x[1]): 307 | for jj in range(range_y[0], range_y[1]): 308 | realworld_x, realworld_y = from_grid( 309 | ii, jj, COORDINATE_MIN, COORDINATE_MAX, map_resolution 310 | ) 311 | valid_point = sim.is_navigable( 312 | [realworld_x, start_height, realworld_y] 313 | ) 314 | top_down_map[ii, jj] = ( 315 | MAP_VALID_POINT if valid_point else MAP_INVALID_POINT 316 | ) 317 | 318 | # Draw border if necessary 319 | if draw_border: 320 | # Recompute range in case padding added any more values. 321 | range_x = np.where(np.any(top_down_map, axis=1))[0] 322 | range_y = np.where(np.any(top_down_map, axis=0))[0] 323 | range_x = ( 324 | max(range_x[0] - border_padding, 0), 325 | min(range_x[-1] + border_padding + 1, top_down_map.shape[0]), 326 | ) 327 | range_y = ( 328 | max(range_y[0] - border_padding, 0), 329 | min(range_y[-1] + border_padding + 1, top_down_map.shape[1]), 330 | ) 331 | 332 | _outline_border( 333 | top_down_map[range_x[0]: range_x[1], range_y[0]: range_y[1]] 334 | ) 335 | return top_down_map 336 | 337 | 338 | def colorize_topdown_map( 339 | top_down_map: np.ndarray, 340 | fog_of_war_mask: Optional[np.ndarray] = None, 341 | fog_of_war_desat_amount: float = 0.5, 342 | ) -> np.ndarray: 343 | r"""Convert the top down map to RGB based on the indicator values. 344 | Args: 345 | top_down_map: A non-colored version of the map. 346 | fog_of_war_mask: A mask used to determine which parts of the 347 | top_down_map are visible 348 | Non-visible parts will be desaturated 349 | fog_of_war_desat_amount: Amount to desaturate the color of unexplored areas 350 | Decreasing this value will make unexplored areas darker 351 | Default: 0.5 352 | Returns: 353 | A colored version of the top-down map. 354 | """ 355 | _map = TOP_DOWN_MAP_COLORS[top_down_map] 356 | 357 | if fog_of_war_mask is not None: 358 | fog_of_war_desat_values = np.array([[fog_of_war_desat_amount], [1.0]]) 359 | # Only desaturate things that are valid points as only valid points get revealed 360 | desat_mask = top_down_map != MAP_INVALID_POINT 361 | 362 | _map[desat_mask] = ( 363 | _map * fog_of_war_desat_values[fog_of_war_mask] 364 | ).astype(np.uint8)[desat_mask] 365 | 366 | return _map 367 | 368 | 369 | def draw_path( 370 | top_down_map: np.ndarray, 371 | path_points: List[Tuple], 372 | color: int, 373 | thickness: int = 2, 374 | ) -> None: 375 | r"""Draw path on top_down_map (in place) with specified color. 376 | Args: 377 | top_down_map: A colored version of the map. 378 | color: color code of the path, from TOP_DOWN_MAP_COLORS. 379 | path_points: list of points that specify the path to be drawn 380 | thickness: thickness of the path. 
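    Example (a minimal sketch; path points are (row, col) pairs such as those
    produced by to_grid):
        >>> draw_path(top_down_map, [(10, 10), (20, 25)], MAP_SHORTEST_PATH_COLOR)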
381 | """ 382 | for prev_pt, next_pt in zip(path_points[:-1], path_points[1:]): 383 | # Swapping x y 384 | cv2.line( 385 | top_down_map, 386 | prev_pt[::-1], 387 | next_pt[::-1], 388 | color, 389 | thickness=thickness, 390 | ) 391 | 392 | 393 | # def draw_attack(top_down_map, map_attack_pos) -> np.ndarray: 394 | # yellow = (255, 255, 0) 395 | # top_down_map = cv2.circle( 396 | # top_down_map, 397 | # map_attack_pos, 398 | # 20, 399 | # yellow, 400 | # thickness=-1, 401 | # ) 402 | # return top_down_map 403 | -------------------------------------------------------------------------------- /soundspaces/simulator.py: -------------------------------------------------------------------------------- 1 | 2 | from typing import Any, List 3 | from collections import defaultdict, namedtuple 4 | import logging 5 | import time 6 | import pickle 7 | import os 8 | 9 | import librosa 10 | import scipy 11 | from scipy.io import wavfile 12 | from scipy.signal import fftconvolve 13 | import numpy as np 14 | import networkx as nx 15 | 16 | import habitat_sim 17 | from habitat_sim.utils.common import quat_from_angle_axis, quat_from_coeffs, quat_to_angle_axis 18 | from habitat.sims.habitat_simulator.habitat_simulator import HabitatSim 19 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 20 | from habitat.core.simulator import Config, AgentState, ShortestPathPoint 21 | from soundspaces.utils import load_metadata, _to_tensor 22 | 23 | #from habitat.core.registry import registry 24 | # from ss_baselines.common.registry import registry 25 | 26 | class DummySimulator: 27 | """ 28 | Dummy simulator for avoiding loading the scene meshes when using cached observations. 29 | """ 30 | def __init__(self): 31 | self.position = None 32 | self.rotation = None 33 | self._sim_obs = None 34 | 35 | def seed(self, seed): 36 | pass 37 | 38 | def set_agent_state(self, position, rotation): 39 | self.position = np.array(position, dtype=np.float32) 40 | self.rotation = rotation 41 | 42 | def get_agent_state(self): 43 | class State: 44 | def __init__(self, position, rotation): 45 | self.position = position 46 | self.rotation = rotation 47 | 48 | return State(self.position, self.rotation) 49 | 50 | def set_sensor_observations(self, sim_obs): 51 | self._sim_obs = sim_obs 52 | 53 | def get_sensor_observations(self): 54 | return self._sim_obs 55 | 56 | def close(self): 57 | pass 58 | 59 | 60 | # @registry.register_simulator(name="SoundSpaces") 61 | class SoundSpaces(HabitatSim): 62 | r"""Changes made to simulator wrapper over habitat-sim 63 | 64 | This simulator first loads the graph of current environment and moves the agent among nodes. 65 | Any sounds can be specified in the episode and loaded in this simulator. 66 | Args: 67 | config: configuration for initializing the simulator. 
68 | """ 69 | 70 | def action_space_shortest_path(self, source: AgentState, targets: List[AgentState], agent_id: int = 0) -> List[ 71 | ShortestPathPoint]: 72 | pass 73 | 74 | def __init__(self, config: Config) -> None: 75 | print("2:sim enter--->SoundSpaces") 76 | super().__init__(config) 77 | self._source_position_index = None 78 | self._receiver_position_index = None 79 | self._rotation_angle = None 80 | self._current_sound = None 81 | self._source_sound_dict = dict() 82 | self._sampling_rate = None 83 | self._node2index = None 84 | self._frame_cache = dict() 85 | self._audiogoal_cache = dict() 86 | self._spectrogram_cache = dict() 87 | self._scene_observations = None 88 | self._episode_step_count = None 89 | self._is_episode_active = None 90 | self._position_to_index_mapping = dict() 91 | self._previous_step_collided = False 92 | 93 | self.points, self.graph = load_metadata(self.metadata_dir) 94 | for node in self.graph.nodes(): 95 | self._position_to_index_mapping[self.position_encoding(self.graph.nodes()[node]['point'])] = node 96 | self._load_sound_sources() 97 | logging.info('Current scene: {} and sound: {}'.format(self.current_scene_name, self._current_sound)) 98 | 99 | if self.config.USE_RENDERED_OBSERVATIONS: 100 | self._sim.close() 101 | del self._sim 102 | self._sim = DummySimulator() 103 | with open(self.current_scene_observation_file, 'rb') as fo: 104 | self._frame_cache = pickle.load(fo) 105 | 106 | def get_agent_state(self, agent_id: int = 0) -> habitat_sim.AgentState: 107 | if not self.config.USE_RENDERED_OBSERVATIONS: 108 | agent_state = super().get_agent_state(agent_id) 109 | else: 110 | agent_state = self._sim.get_agent_state() 111 | 112 | return agent_state 113 | 114 | def set_agent_state( 115 | self, 116 | position: List[float], 117 | rotation: List[float], 118 | agent_id: int = 0, 119 | reset_sensors: bool = True, 120 | ) -> bool: 121 | if not self.config.USE_RENDERED_OBSERVATIONS: 122 | super().set_agent_state(position, rotation, agent_id=agent_id, reset_sensors=reset_sensors) 123 | else: 124 | pass 125 | 126 | @property 127 | def binaural_rir_dir(self): 128 | return os.path.join(self.config.AUDIO.BINAURAL_RIR_DIR, self.config.SCENE_DATASET, self.current_scene_name) 129 | 130 | @property 131 | def source_sound_dir(self): 132 | return self.config.AUDIO.SOURCE_SOUND_DIR 133 | 134 | @property 135 | def metadata_dir(self): 136 | return os.path.join(self.config.AUDIO.METADATA_DIR, self.config.SCENE_DATASET, self.current_scene_name) 137 | 138 | @property 139 | def current_scene_name(self): 140 | # config.SCENE (_current_scene) looks like 'data/scene_datasets/replica/office_1/habitat/mesh_semantic.ply' 141 | return self._current_scene.split('/')[3] 142 | 143 | @property 144 | def current_scene_observation_file(self): 145 | return os.path.join(self.config.SCENE_OBSERVATION_DIR, self.config.SCENE_DATASET, 146 | self.current_scene_name + '.pkl') 147 | 148 | @property 149 | def current_source_sound(self): 150 | return self._source_sound_dict[self._current_sound] 151 | 152 | def reconfigure(self, config: Config) -> None: 153 | self.config = config 154 | is_same_sound = config.AGENT_0.SOUND == self._current_sound 155 | if not is_same_sound: 156 | self._current_sound = self.config.AGENT_0.SOUND 157 | 158 | is_same_scene = config.SCENE == self._current_scene 159 | if not is_same_scene: 160 | self._current_scene = config.SCENE 161 | logging.debug('Current scene: {} and sound: {}'.format(self.current_scene_name, self._current_sound)) 162 | 163 | if not 
self.config.USE_RENDERED_OBSERVATIONS: 164 | self._sim.close() 165 | del self._sim 166 | self.sim_config = self.create_sim_config(self._sensor_suite) 167 | self._sim = habitat_sim.Simulator(self.sim_config) 168 | self._update_agents_state() 169 | self._frame_cache = dict() 170 | else: 171 | with open(self.current_scene_observation_file, 'rb') as fo: 172 | self._frame_cache = pickle.load(fo) 173 | logging.debug('Loaded scene {}'.format(self.current_scene_name)) 174 | 175 | self.points, self.graph = load_metadata(self.metadata_dir) 176 | for node in self.graph.nodes(): 177 | self._position_to_index_mapping[self.position_encoding(self.graph.nodes()[node]['point'])] = node 178 | 179 | if not is_same_scene or not is_same_sound: 180 | self._audiogoal_cache = dict() 181 | self._spectrogram_cache = dict() 182 | 183 | self._episode_step_count = 0 184 | 185 | # set agent positions 186 | self._receiver_position_index = self._position_to_index(self.config.AGENT_0.START_POSITION) 187 | self._source_position_index = self._position_to_index(self.config.AGENT_0.GOAL_POSITION) 188 | # the agent rotates about +Y starting from -Z counterclockwise, 189 | # so rotation angle 90 means the agent rotate about +Y 90 degrees 190 | self._rotation_angle = int(np.around(np.rad2deg(quat_to_angle_axis(quat_from_coeffs( 191 | self.config.AGENT_0.START_ROTATION))[0]))) % 360 192 | if not self.config.USE_RENDERED_OBSERVATIONS: 193 | self.set_agent_state(list(self.graph.nodes[self._receiver_position_index]['point']), 194 | self.config.AGENT_0.START_ROTATION) 195 | else: 196 | self._sim.set_agent_state(list(self.graph.nodes[self._receiver_position_index]['point']), 197 | quat_from_coeffs(self.config.AGENT_0.START_ROTATION)) 198 | 199 | logging.debug("Initial source, agent at: {}, {}, orientation: {}". 
format(self._source_position_index, self._receiver_position_index, self.get_orientation()))

    @staticmethod
    def position_encoding(position):
        return '{:.2f}_{:.2f}_{:.2f}'.format(*position)

    def _position_to_index(self, position):
        if self.position_encoding(position) in self._position_to_index_mapping:
            return self._position_to_index_mapping[self.position_encoding(position)]
        else:
            raise ValueError("Position misalignment.")

    def _get_sim_observation(self):
        joint_index = (self._receiver_position_index, self._rotation_angle)
        if joint_index in self._frame_cache:
            return self._frame_cache[joint_index]
        else:
            assert not self.config.USE_RENDERED_OBSERVATIONS
            # cache the raw observations for this (position, rotation) pair
            sim_obs = self._sim.get_sensor_observations()
            self._frame_cache[joint_index] = sim_obs
            return sim_obs

    def reset(self):
        logging.debug('Reset simulation')
        if not self.config.USE_RENDERED_OBSERVATIONS:
            sim_obs = self._sim.reset()
            if self._update_agents_state():
                sim_obs = self._get_sim_observation()
        else:
            sim_obs = self._get_sim_observation()
            self._sim.set_sensor_observations(sim_obs)

        self._is_episode_active = True
        self._prev_sim_obs = sim_obs
        self._previous_step_collided = False
        # Encapsulate data under the Observations class
        observations = self._sensor_suite.get_observations(sim_obs)

        return observations

    def step(self, action, only_allowed=True):
        """
        All angle calculations in this function are w.r.t. the habitat
        coordinate frame, on the X-Z plane, where +Y is upward, -Z is forward
        and +X is rightward.
        Angle 0 corresponds to +X and angle 90 corresponds to +Y; angles are
        taken modulo 360.
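        Movement is discrete on the precomputed scene graph: MOVE_FORWARD
        snaps the agent to the neighbor node whose bearing matches the current
        orientation (no matching neighbor is recorded as a collision), and
        TURN_LEFT / TURN_RIGHT change the rotation angle in 90 degree steps.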

        :param action: action to be taken
        :param only_allowed: if true, then can't step anywhere except allowed locations
        :return:
            Dict of observations
        """
        assert self._is_episode_active, (
            "episode is not active, environment not RESET or "
            "STOP action called previously"
        )

        self._previous_step_collided = False
        # action indices: STOP: 0, MOVE_FORWARD: 1, TURN_LEFT: 2, TURN_RIGHT: 3
        if action == HabitatSimActions.STOP:
            self._is_episode_active = False
        else:
            prev_position_index = self._receiver_position_index
            prev_rotation_angle = self._rotation_angle
            if action == HabitatSimActions.MOVE_FORWARD:
                # the agent initially faces -Z by default
                self._previous_step_collided = True
                for neighbor in self.graph[self._receiver_position_index]:
                    p1 = self.graph.nodes[self._receiver_position_index]['point']
                    p2 = self.graph.nodes[neighbor]['point']
                    direction = int(np.around(np.rad2deg(np.arctan2(p2[2] - p1[2], p2[0] - p1[0])))) % 360
                    if direction == self.get_orientation():
                        self._receiver_position_index = neighbor
                        self._previous_step_collided = False
                        break
            elif action == HabitatSimActions.TURN_LEFT:
                # agent rotates counterclockwise, so turning left means increasing rotation angle by 90
                self._rotation_angle = (self._rotation_angle + 90) % 360
            elif action == HabitatSimActions.TURN_RIGHT:
                self._rotation_angle = (self._rotation_angle - 90) % 360

            if self.config.CONTINUOUS_VIEW_CHANGE:
                intermediate_observations = list()
                fps = self.config.VIEW_CHANGE_FPS
                if action == HabitatSimActions.MOVE_FORWARD:
                    prev_position = np.array(self.graph.nodes[prev_position_index]['point'])
                    current_position = np.array(self.graph.nodes[self._receiver_position_index]['point'])
                    for i in range(1, fps):
                        intermediate_position = prev_position + i / fps * (current_position - prev_position)
                        self.set_agent_state(intermediate_position.tolist(), quat_from_angle_axis(np.deg2rad(
                            self._rotation_angle), np.array([0, 1, 0])))
                        sim_obs = self._sim.get_sensor_observations()
                        observations = self._sensor_suite.get_observations(sim_obs)
                        intermediate_observations.append(observations)
                else:
                    for i in range(1, fps):
                        if action == HabitatSimActions.TURN_LEFT:
                            intermediate_rotation = prev_rotation_angle + i / fps * 90
                        elif action == HabitatSimActions.TURN_RIGHT:
                            intermediate_rotation = prev_rotation_angle - i / fps * 90
                        self.set_agent_state(list(self.graph.nodes[self._receiver_position_index]['point']),
                                             quat_from_angle_axis(np.deg2rad(intermediate_rotation),
                                                                  np.array([0, 1, 0])))
                        sim_obs = self._sim.get_sensor_observations()
                        observations = self._sensor_suite.get_observations(sim_obs)
                        intermediate_observations.append(observations)

            if not self.config.USE_RENDERED_OBSERVATIONS:
                self.set_agent_state(list(self.graph.nodes[self._receiver_position_index]['point']),
                                     quat_from_angle_axis(np.deg2rad(self._rotation_angle), np.array([0, 1, 0])))
            else:
                self._sim.set_agent_state(list(self.graph.nodes[self._receiver_position_index]['point']),
                                          quat_from_angle_axis(np.deg2rad(self._rotation_angle), np.array([0, 1, 0])))
        self._episode_step_count += 1

        # log debugging info
        logging.debug('After taking action {}, s,r: {}, {}, orientation: {}, location: {}'.format(
            action, self._source_position_index,
self._receiver_position_index, 319 | self.get_orientation(), self.graph.nodes[self._receiver_position_index]['point'])) 320 | 321 | sim_obs = self._get_sim_observation() 322 | if self.config.USE_RENDERED_OBSERVATIONS: 323 | self._sim.set_sensor_observations(sim_obs) 324 | self._prev_sim_obs = sim_obs 325 | observations = self._sensor_suite.get_observations(sim_obs) 326 | if self.config.CONTINUOUS_VIEW_CHANGE: 327 | observations['intermediate'] = intermediate_observations 328 | 329 | return observations 330 | 331 | def get_orientation(self): 332 | _base_orientation = 270 333 | return (_base_orientation - self._rotation_angle) % 360 334 | 335 | @property 336 | def azimuth_angle(self): 337 | # this is the angle used to index the binaural audio files 338 | # in mesh coordinate systems, +Y forward, +X rightward, +Z upward 339 | # azimuth is calculated clockwise so +Y is 0 and +X is 90 340 | return -(self._rotation_angle + 0) % 360 341 | 342 | @property 343 | def reaching_goal(self): 344 | return self._source_position_index == self._receiver_position_index 345 | 346 | def _update_observations_with_audio(self, observations): 347 | audio = self.get_current_audio_observation() 348 | observations.update({"audio": audio}) 349 | 350 | def _load_sound_sources(self): 351 | # load all mono files at once 352 | sound_files = os.listdir(self.source_sound_dir) 353 | for sound_file in sound_files: 354 | sound = sound_file.split('.')[0] 355 | sr, audio_data = wavfile.read(os.path.join(self.source_sound_dir, sound_file)) 356 | assert sr == 44100 357 | if sr != self.config.AUDIO.RIR_SAMPLING_RATE: 358 | audio_data = scipy.signal.resample(audio_data, self.config.AUDIO.RIR_SAMPLING_RATE) 359 | self._source_sound_dict[sound] = audio_data 360 | 361 | def _compute_euclidean_distance_between_sr_locations(self): 362 | p1 = self.graph.nodes[self._receiver_position_index]['point'] 363 | p2 = self.graph.nodes[self._source_position_index]['point'] 364 | d = np.sqrt((p1[0] - p2[0])**2 + (p1[2] - p2[2])**2) 365 | return d 366 | 367 | def _compute_audiogoal(self): 368 | binaural_rir_file = os.path.join(self.binaural_rir_dir, str(self.azimuth_angle), '{}_{}.wav'.format( 369 | self._receiver_position_index, self._source_position_index)) 370 | try: 371 | sampling_freq, binaural_rir = wavfile.read(binaural_rir_file) # float32 372 | 373 | except ValueError: 374 | logging.warning("{} file is not readable".format(binaural_rir_file)) 375 | binaural_rir = np.zeros((self.config.AUDIO.RIR_SAMPLING_RATE, 2)).astype(np.float32) 376 | if len(binaural_rir) == 0: 377 | logging.debug("Empty RIR file at {}".format(binaural_rir_file)) 378 | binaural_rir = np.zeros((self.config.AUDIO.RIR_SAMPLING_RATE, 2)).astype(np.float32) 379 | 380 | # by default, convolve in full mode, which preserves the direct sound 381 | binaural_convolved = [fftconvolve(self.current_source_sound, binaural_rir[:, channel] 382 | ) for channel in range(binaural_rir.shape[-1])] 383 | audiogoal = np.array(binaural_convolved)[:, :self.current_source_sound.shape[0]] 384 | 385 | free_range = audiogoal.shape[1] - self.config.AUDIO.RIR_SAMPLING_RATE 386 | if free_range > 0: 387 | slide_k = np.random.randint(0,high=free_range) 388 | audiogoal = audiogoal[:,slide_k:slide_k+self.config.AUDIO.RIR_SAMPLING_RATE] 389 | 390 | return audiogoal 391 | 392 | def get_current_audiogoal_observation(self): 393 | join_index = (self._source_position_index, self._receiver_position_index, self.azimuth_angle) 394 | if join_index not in self._audiogoal_cache: 395 | self._audiogoal_cache[join_index] = 
self._compute_audiogoal() 396 | 397 | return self._audiogoal_cache[join_index] 398 | 399 | def get_current_spectrogram_observation(self, audiogoal2spectrogram): 400 | sr_index = (self._source_position_index, self._receiver_position_index) 401 | join_index = sr_index + (self.azimuth_angle,) 402 | if join_index not in self._spectrogram_cache: 403 | audiogoal = self._compute_audiogoal() 404 | spectrogram = audiogoal2spectrogram(audiogoal) 405 | self._spectrogram_cache[join_index] = spectrogram 406 | 407 | return self._spectrogram_cache[join_index] 408 | 409 | def geodesic_distance(self, position_a, position_b): 410 | index_a = self._position_to_index(position_a) 411 | index_b = self._position_to_index(position_b) 412 | assert index_a is not None and index_b is not None 413 | steps = nx.shortest_path_length(self.graph, index_a, index_b) * self.config.GRID_SIZE 414 | 415 | return steps 416 | 417 | def get_straight_shortest_path_points(self, position_a, position_b): 418 | index_a = self._position_to_index(position_a) 419 | index_b = self._position_to_index(position_b) 420 | assert index_a is not None and index_b is not None 421 | 422 | shortest_path = nx.shortest_path(self.graph, source=index_a, target=index_b) 423 | points = list() 424 | for node in shortest_path: 425 | points.append(self.graph.nodes()[node]['point']) 426 | return points 427 | 428 | @property 429 | def previous_step_collided(self): 430 | return self._previous_step_collided 431 | -------------------------------------------------------------------------------- /ss_baselines/common/sync_vector_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from multiprocessing.connection import Connection 4 | from multiprocessing.context import BaseContext 5 | from queue import Queue 6 | from threading import Thread 7 | from typing import ( 8 | Any, 9 | Callable, 10 | Dict, 11 | List, 12 | Optional, 13 | Sequence, 14 | Set, 15 | Tuple, 16 | Union, 17 | ) 18 | 19 | import gym 20 | import numpy as np 21 | from gym.spaces.dict_space import Dict as SpaceDict 22 | 23 | import habitat 24 | from habitat.config import Config 25 | from habitat.core.env import Env, Observations, RLEnv 26 | from habitat.core.logging import logger 27 | from habitat.core.utils import tile_images 28 | 29 | try: 30 | # Use torch.multiprocessing if we can. 31 | # We have yet to find a reason to not use it and 32 | # you are required to use it when sending a torch.Tensor 33 | # between processes 34 | import torch.multiprocessing as mp 35 | except ImportError: 36 | import multiprocessing as mp 37 | 38 | STEP_COMMAND = "step" 39 | RESET_COMMAND = "reset" 40 | RENDER_COMMAND = "render" 41 | CLOSE_COMMAND = "close" 42 | OBSERVATION_SPACE_COMMAND = "observation_space" 43 | ACTION_SPACE_COMMAND = "action_space" 44 | CALL_COMMAND = "call" 45 | EPISODE_COMMAND = "current_episode" 46 | 47 | 48 | def _make_env_fn( 49 | config: Config, dataset: Optional[habitat.Dataset] = None, rank: int = 0 50 | ) -> Env: 51 | """Constructor for default habitat `env.Env`. 52 | 53 | :param config: configuration for environment. 54 | :param dataset: dataset for environment. 
55 | :param rank: rank for setting seed of environment 56 | :return: `env.Env` / `env.RLEnv` object 57 | """ 58 | habitat_env = Env(config=config, dataset=dataset) 59 | habitat_env.seed(config.SEED + rank) 60 | return habitat_env 61 | 62 | 63 | class WorkerEnv: 64 | def __init__(self, env_fn, env_fn_arg, auto_reset_done): 65 | self._env = env_fn(*env_fn_arg) 66 | self._auto_reset_done = auto_reset_done 67 | 68 | def __call__(self, command, data): 69 | while command != CLOSE_COMMAND: 70 | if command == STEP_COMMAND: 71 | # different step methods for habitat.RLEnv and habitat.Env 72 | if isinstance(self._env, habitat.RLEnv) or isinstance( 73 | self._env, gym.Env 74 | ): 75 | # habitat.RLEnv 76 | observations, reward, done, info = self._env.step(**data) 77 | if self._auto_reset_done and done: 78 | observations = self._env.reset() 79 | return observations, reward, done, info 80 | elif isinstance(self._env, habitat.Env): 81 | # habitat.Env 82 | observations = self._env.step(**data) 83 | if self._auto_reset_done and self._env.episode_over: 84 | observations = self._env.reset() 85 | return observations 86 | else: 87 | raise NotImplementedError 88 | 89 | elif command == RESET_COMMAND: 90 | observations = self._env.reset() 91 | return observations 92 | 93 | elif command == RENDER_COMMAND: 94 | return self._env.render(*data[0], **data[1]) 95 | 96 | elif ( 97 | command == OBSERVATION_SPACE_COMMAND 98 | or command == ACTION_SPACE_COMMAND 99 | ): 100 | if isinstance(command, str): 101 | return getattr(self._env, command) 102 | 103 | elif command == CALL_COMMAND: 104 | function_name, function_args = data 105 | if function_args is None or len(function_args) == 0: 106 | result = getattr(self._env, function_name)() 107 | else: 108 | result = getattr(self._env, function_name)(**function_args) 109 | return result 110 | 111 | # TODO: update CALL_COMMAND for getting attribute like this 112 | elif command == EPISODE_COMMAND: 113 | return self._env.current_episode 114 | else: 115 | raise NotImplementedError 116 | 117 | 118 | class SyncVectorEnv: 119 | r"""Vectorized environment which creates multiple processes where each 120 | process runs its own environment. Main class for parallelization of 121 | training and evaluation. 122 | 123 | 124 | All the environments are synchronized on step and reset methods. 125 | """ 126 | 127 | observation_spaces: List[SpaceDict] 128 | action_spaces: List[SpaceDict] 129 | _workers: List[Union[mp.Process, Thread]] 130 | _is_waiting: bool 131 | _num_envs: int 132 | _auto_reset_done: bool 133 | _mp_ctx: BaseContext 134 | _connection_read_fns: List[Callable[[], Any]] 135 | _connection_write_fns: List[Callable[[Any], None]] 136 | 137 | def __init__( 138 | self, 139 | make_env_fn: Callable[..., Union[Env, RLEnv]] = _make_env_fn, 140 | env_fn_args: Sequence[Tuple] = None, 141 | auto_reset_done: bool = True, 142 | multiprocessing_start_method: str = "forkserver", 143 | ) -> None: 144 | """.. 145 | 146 | :param make_env_fn: function which creates a single environment. An 147 | environment can be of type `env.Env` or `env.RLEnv` 148 | :param env_fn_args: tuple of tuple of args to pass to the 149 | `_make_env_fn`. 150 | :param auto_reset_done: automatically reset the environment when 151 | done. This functionality is provided for seamless training 152 | of vectorized environments. 153 | :param multiprocessing_start_method: the multiprocessing method used to 154 | spawn worker processes. 
class SyncVectorEnv:
    r"""Vectorized environment that steps a set of environments
    synchronously, one after the other, inside the driver process. It keeps
    the command interface of the multiprocessing vector env, so it can serve
    as a drop-in replacement for parallelized training and evaluation.

    All the environments are synchronized on step and reset methods.
    """

    observation_spaces: List[SpaceDict]
    action_spaces: List[SpaceDict]
    workers: List[WorkerEnv]
    _num_envs: int
    _auto_reset_done: bool

    def __init__(
        self,
        make_env_fn: Callable[..., Union[Env, RLEnv]] = _make_env_fn,
        env_fn_args: Sequence[Tuple] = None,
        auto_reset_done: bool = True,
        multiprocessing_start_method: str = "forkserver",
    ) -> None:
        """..

        :param make_env_fn: function which creates a single environment. An
            environment can be of type `env.Env` or `env.RLEnv`.
        :param env_fn_args: tuple of tuples of args to pass to
            `make_env_fn`, one tuple per environment.
        :param auto_reset_done: automatically reset the environment when
            done. This functionality is provided for seamless training
            of vectorized environments.
        :param multiprocessing_start_method: kept for interface compatibility
            with the multiprocessing vector env; unused here, since all
            environments run in the current process.
        """
        self._is_waiting = False
        self._is_closed = True

        assert (
            env_fn_args is not None and len(env_fn_args) > 0
        ), "number of environments to be created should be greater than 0"

        self._num_envs = len(env_fn_args)
        self._auto_reset_done = auto_reset_done

        self.workers = []
        for env_fn_arg in env_fn_args:
            worker = WorkerEnv(
                make_env_fn, env_fn_arg, auto_reset_done=auto_reset_done
            )
            self.workers.append(worker)
        self._is_closed = False

        self.observation_spaces = [
            worker(OBSERVATION_SPACE_COMMAND, None) for worker in self.workers
        ]
        self.action_spaces = [
            worker(ACTION_SPACE_COMMAND, None) for worker in self.workers
        ]
        self._paused = []
        # buffer for results produced eagerly by `async_step`
        self._async_results = []

    @property
    def num_envs(self):
        r"""number of individual environments."""
        return self._num_envs - len(self._paused)

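    # Construction sketch: `env_fn_args` holds one tuple of positional
    # arguments per environment, matching the `_make_env_fn(config, dataset,
    # rank)` signature (config path hypothetical, as above):
    #
    #     env_fn_args = tuple(
    #         (habitat.get_config("configs/tasks/pointnav.yaml"), None, rank)
    #         for rank in range(2)
    #     )
    #     envs = SyncVectorEnv(_make_env_fn, env_fn_args, auto_reset_done=True)
    #     assert envs.num_envs == 2
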
    def current_episodes(self):
        r"""Current episode information for each environment.

        :return: list of current episode specs, one per env.
        """
        results = [worker(EPISODE_COMMAND, None) for worker in self.workers]
        return results

    def reset(self):
        r"""Reset all the vectorized environments.

        :return: list of outputs from the reset method of envs.
        """
        results = [worker(RESET_COMMAND, None) for worker in self.workers]
        return results

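    # The batched calls mirror the single-env API, e.g. (continuing the
    # construction sketch above):
    #
    #     observations = envs.reset()         # one observation dict per env
    #     episodes = envs.current_episodes()  # one episode spec per env
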
    def reset_at(self, index_env: int):
        r"""Reset the index_env environment in the vector.

        :param index_env: index of the environment to be reset
        :return: list containing the output of reset method of indexed env.
        """
        results = [self.workers[index_env](RESET_COMMAND, None)]
        return results

    def step_at(self, index_env: int, action: Dict[str, Any]):
        r"""Step the index_env environment in the vector.

        :param index_env: index of the environment to be stepped
        :param action: action to be taken
        :return: list containing the output of step method of indexed env.
        """
        results = [self.workers[index_env](STEP_COMMAND, action)]
        return results

    def async_step(self, data: List[Union[int, str, Dict[str, Any]]]) -> None:
        r"""Step in the environments; in this synchronous implementation the
        steps run immediately and the results are buffered until
        `wait_step` is called.

        :param data: list of size _num_envs containing keyword arguments to
            pass to `step` method for each Environment. For example,
            :py:`[{"action": "TURN_LEFT", "action_args": {...}}, ...]`.
        """
        # Backward compatibility
        if isinstance(data[0], (int, np.integer, str)):
            data = [{"action": {"action": action}} for action in data]

        self._is_waiting = True
        self._async_results = [
            worker(STEP_COMMAND, args)
            for worker, args in zip(self.workers, data)
        ]

    def wait_step(self) -> List[Observations]:
        r"""Return the buffered results of the last `async_step` call."""
        observations = self._async_results
        self._is_waiting = False
        return observations

    def step(self, data: List[Union[int, str, Dict[str, Any]]]) -> List[Any]:
        r"""Perform actions in the vectorized environments.

        :param data: list of size _num_envs containing keyword arguments to
            pass to `step` method for each Environment. For example,
            :py:`[{"action": "TURN_LEFT", "action_args": {...}}, ...]`.
        :return: list of outputs from the step method of envs.
        """
        # Backward compatibility
        if isinstance(data[0], (int, np.integer, str)):
            data = [{"action": {"action": action}} for action in data]
        results = [
            worker(STEP_COMMAND, args)
            for worker, args in zip(self.workers, data)
        ]
        return results

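    # Stepping sketch: thanks to the backward-compatibility branch, callers
    # may pass bare action names (or ints) instead of full keyword dicts:
    #
    #     outputs = envs.step(["MOVE_FORWARD"] * envs.num_envs)
    #     # equivalent via the async interface, which here runs eagerly:
    #     envs.async_step(["MOVE_FORWARD"] * envs.num_envs)
    #     outputs = envs.wait_step()
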
    def close(self) -> None:
        if self._is_closed:
            return

        for worker in self.workers:
            worker(CLOSE_COMMAND, None)

        for _, worker in self._paused:
            worker(CLOSE_COMMAND, None)

        self._is_closed = True

    def pause_at(self, index: int) -> None:
        r"""Pauses computation on this env without destroying the env.

        :param index: which env to pause. All indexes after this one will be
            shifted down by one.

        This is useful for not needing to call steps on all environments when
        only some are active (for example during the last episodes of an
        evaluation run).
        """
        worker = self.workers.pop(index)
        self._paused.append((index, worker))

    def resume_all(self) -> None:
        r"""Resumes any paused envs."""
        for index, worker in reversed(self._paused):
            self.workers.insert(index, worker)
        self._paused = []

    def call_at(
        self,
        index: int,
        function_name: str,
        function_args: Optional[Dict[str, Any]] = None,
    ) -> Any:
        r"""Calls a function (which is passed by name) on the selected env and
        returns the result.

        :param index: which env to call the function on.
        :param function_name: the name of the function to call on the env.
        :param function_args: optional function args.
        :return: result of calling the function.
        """
        result = self.workers[index](
            CALL_COMMAND, (function_name, function_args)
        )
        return result

    def call(
        self,
        function_names: List[str],
        function_args_list: Optional[List[Any]] = None,
    ) -> List[Any]:
        r"""Calls a list of functions (which are passed by name) on the
        corresponding env (by index).

        :param function_names: the name of the functions to call on the envs.
        :param function_args_list: list of function args for each function.
            If provided, :py:`len(function_args_list)` should be as long as
            :py:`len(function_names)`.
        :return: results of calling the functions.
        """
        if function_args_list is None:
            function_args_list = [None] * len(function_names)
        assert len(function_names) == len(function_args_list)
        func_args = zip(function_names, function_args_list)
        results = [
            worker(CALL_COMMAND, func_args_on)
            for worker, func_args_on in zip(self.workers, func_args)
        ]
        return results

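    # Eval-time sketch: drop a finished env from the stepping rotation, query
    # another env by method name (assuming the wrapped env exposes habitat's
    # `get_metrics`), then restore the original ordering:
    #
    #     envs.pause_at(0)                   # env 0 is skipped by step()/reset()
    #     metrics = envs.call_at(0, "get_metrics")  # NOTE: indexes shifted down
    #     envs.resume_all()                  # env 0 returns to its old index
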
    def render(
        self, mode: str = "human", *args, **kwargs
    ) -> Union[np.ndarray, None]:
        r"""Render observations from all environments in a tiled image."""
        images = [
            worker(RENDER_COMMAND, (args, {"mode": "rgb", **kwargs}))
            for worker in self.workers
        ]
        tile = tile_images(images)
        if mode == "human":
            from habitat.core.utils import try_cv2_import

            cv2 = try_cv2_import()

            cv2.imshow("vecenv", tile[:, :, ::-1])
            cv2.waitKey(1)
            return None
        elif mode == "rgb_array":
            return tile
        else:
            raise NotImplementedError

    @property
    def _valid_start_methods(self) -> Set[str]:
        return {"forkserver", "spawn", "fork"}

    def __del__(self):
        self.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
--------------------------------------------------------------------------------