├── audio_separation ├── common │ ├── __init__.py │ ├── baseline_registry.py │ ├── tensorboard_utils.py │ ├── benchmark.py │ ├── environments.py │ ├── utils.py │ ├── base_trainer.py │ └── sync_vector_env.py ├── config │ ├── __init__.py │ ├── pretrain_passive.yaml │ ├── val │ │ ├── farTarget.yaml │ │ ├── nearTarget.yaml │ │ ├── farTarget_unheard.yaml │ │ └── nearTarget_unheard.yaml │ ├── test │ │ ├── nearTarget.yaml │ │ ├── nearTarget_unheard.yaml │ │ ├── farTarget.yaml │ │ └── farTarget_unheard.yaml │ ├── train │ │ ├── farTarget.yaml │ │ └── nearTarget.yaml │ └── default.py ├── rl │ ├── models │ │ ├── __init__.py │ │ ├── memory_nets.py │ │ ├── rnn_state_encoder.py │ │ ├── audio_cnn.py │ │ ├── visual_cnn.py │ │ └── separator_cnn.py │ ├── ppo │ │ ├── __init__.py │ │ ├── ddppo_utils.py │ │ ├── policy.py │ │ └── ppo.py │ └── __init__.py ├── pretrain │ ├── datasets │ │ ├── __init__.py │ │ └── dataset.py │ ├── passive │ │ ├── __init__.py │ │ ├── passive.py │ │ ├── policy.py │ │ └── passive_trainer.py │ └── __init__.py └── __init__.py ├── gfx └── concept.png ├── .gitignore ├── habitat_audio ├── __init__.py ├── action_space_separation.py ├── utils.py ├── dataset.py └── task.py ├── LICENSE ├── configs └── tasks │ ├── farTarget │ ├── train_farTarget.yaml │ ├── test_farTarget.yaml │ ├── val_farTarget.yaml │ ├── testUnheard_farTarget.yaml │ └── valUnheard_farTarget.yaml │ ├── nearTarget │ ├── train_nearTarget.yaml │ ├── test_nearTarget.yaml │ ├── val_nearTarget.yaml │ ├── testUnheard_nearTarget.yaml │ └── valUnheard_nearTarget.yaml │ └── pretrain_passive.yaml ├── scripts ├── search_for_checkpoint_thru_validation │ ├── link_ckpts_for_val.ipynb │ └── find_bestCkpt_lowestValSTFTLoss.ipynb ├── farTarget_eval │ └── copy_individualCkptsNCfgs_switchPolicyEval.ipynb └── separated_audio_quality │ └── compute_separation_qualtiy.ipynb ├── main.py ├── requirements.txt └── README.md /audio_separation/common/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audio_separation/config/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /audio_separation/rl/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /gfx/concept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAGNIKMJR/move2hear-active-AV-separation/HEAD/gfx/concept.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | *.egg-info 3 | .ipynb_checkpoints/ 4 | __pycache__ 5 | runs 6 | pretrained_models/ 7 | my_config* 8 | my_scripts 9 | -------------------------------------------------------------------------------- /audio_separation/pretrain/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_separation.pretrain.datasets.dataset import PassiveDataset 2 | 3 | __all__ = ["PassiveDataset"] 4 | -------------------------------------------------------------------------------- /audio_separation/pretrain/passive/__init__.py: -------------------------------------------------------------------------------- 1 | from 
audio_separation.pretrain.passive.policy import Policy, Move2HearPassiveWoMemoryPolicy 2 | 3 | __all__ = ["Policy", "Move2HearPassiveWoMemoryPolicy"] 4 | -------------------------------------------------------------------------------- /audio_separation/rl/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_separation.rl.ppo.policy import Net, Policy, Move2HearPolicy 2 | from audio_separation.rl.ppo.ppo import PPO, DDPPO 3 | 4 | __all__ = ["PPO", "DDPPO", "Policy", "Net", "Move2HearPolicy",] 5 | -------------------------------------------------------------------------------- /habitat_audio/__init__.py: -------------------------------------------------------------------------------- 1 | from habitat_audio.action_space_separation import * 2 | from habitat_audio.simulator_train import * 3 | from habitat_audio.simulator_eval import * 4 | from habitat_audio.dataset import * 5 | from habitat_audio.task import * -------------------------------------------------------------------------------- /audio_separation/pretrain/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_separation.common.base_trainer import BaseRLTrainer, BaseTrainer 2 | from audio_separation.pretrain.passive.passive_trainer import PassiveTrainer 3 | 4 | 5 | __all__ = ["BaseTrainer", "BaseRLTrainer", "PassiveTrainer"] 6 | -------------------------------------------------------------------------------- /audio_separation/config/pretrain_passive.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/pretrain_passive.yaml" 2 | SENSORS: [] 3 | 4 | TRAINER_NAME: "passive" 5 | NUM_EPOCHS: 1000 6 | 7 | 8 | Pretrain: 9 | Passive: 10 | ### HYPERPARAMS 11 | lr: 5.0e-4 12 | eps: 1.0e-5 13 | max_grad_norm: 0.8 14 | -------------------------------------------------------------------------------- /audio_separation/rl/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_separation.common.base_trainer import BaseRLTrainer, BaseTrainer 2 | from audio_separation.rl.ppo.ppo_trainer import PPOTrainer, RolloutStoragePol, RolloutStorageSep 3 | 4 | __all__ = ["BaseTrainer", "BaseRLTrainer", "PPOTrainer", "RolloutStoragePol", "RolloutStorageSep"] 5 | -------------------------------------------------------------------------------- /audio_separation/__init__.py: -------------------------------------------------------------------------------- 1 | from audio_separation.rl.ppo.ppo_trainer import PPOTrainer, RolloutStoragePol, RolloutStorageSep 2 | from audio_separation.pretrain.passive.passive_trainer import PassiveTrainer 3 | 4 | 5 | __all__ = ["BaseTrainer", "BaseRLTrainer", "PPOTrainer", "RolloutStoragePol", "RolloutStorageSep", "PassiveTrainer"] 6 | -------------------------------------------------------------------------------- /audio_separation/config/val/farTarget.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/farTarget/val_farTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 100 7 | 8 | EVAL: 9 | SPLIT: "val_farTarget_8scenes_100episodes" 10 | USE_CKPT_CONFIG: True 11 | 12 | RL: 13 | PPO: 14 | deterministic_eval: False 15 | hidden_size: 512 16 | 17 | # needed to turn off bn 18 | use_ddppo: True 19 | 
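For orientation, experiment configs such as audio_separation/config/val/farTarget.yaml above are consumed through the repository's entry point (main.py, reproduced further below). The following is a minimal, hypothetical sketch of that flow, not code from the repository: the run directory is a placeholder, and get_config is called positionally exactly as main.py does.

# Hypothetical sketch (not part of the repository): load an eval experiment config
# and dispatch it to the registered trainer, mirroring the calls in main.py.
from audio_separation.common.baseline_registry import baseline_registry
from audio_separation.config.default import get_config

config = get_config(
    "audio_separation/config/val/farTarget.yaml",  # --exp-config
    None,                                          # opts: no command-line overrides
    "runs/val/farTarget/example_run",              # --model-dir (placeholder path)
    "eval",                                        # --run-type
)
# TRAINER_NAME comes from config/default.py unless the experiment config overrides it.
trainer_cls = baseline_registry.get_trainer(config.TRAINER_NAME)
trainer = trainer_cls(config)
trainer.eval(1, -1)  # eval interval and previous checkpoint index, as in main.py's defaults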
-------------------------------------------------------------------------------- /audio_separation/config/val/nearTarget.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/nearTarget/val_nearTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 100 7 | 8 | EVAL: 9 | SPLIT: "val_nearTarget_8scenes_100episodes" 10 | USE_CKPT_CONFIG: True 11 | 12 | RL: 13 | PPO: 14 | deterministic_eval: False 15 | hidden_size: 512 16 | 17 | # needed to turn off bn 18 | use_ddppo: True 19 | -------------------------------------------------------------------------------- /audio_separation/config/val/farTarget_unheard.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/farTarget/valUnheard_farTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 100 7 | 8 | EVAL: 9 | SPLIT: "valUnheard_farTarget_8scenes_100episodes" 10 | USE_CKPT_CONFIG: True 11 | 12 | RL: 13 | PPO: 14 | deterministic_eval: False 15 | hidden_size: 512 16 | 17 | # needed to turn off bn 18 | use_ddppo: True 19 | -------------------------------------------------------------------------------- /audio_separation/config/val/nearTarget_unheard.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/nearTarget/valUnheard_nearTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 100 7 | 8 | EVAL: 9 | SPLIT: "valUnheard_nearTarget_8scenes_100episodes" 10 | USE_CKPT_CONFIG: True 11 | 12 | RL: 13 | PPO: 14 | deterministic_eval: False 15 | hidden_size: 512 16 | 17 | # needed to turn off bn 18 | use_ddppo: True 19 | -------------------------------------------------------------------------------- /audio_separation/config/test/nearTarget.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/nearTarget/test_nearTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 1000 7 | 8 | COMPUTE_EVAL_METRICS: True 9 | EVAL_METRICS_TO_COMPUTE: ["si_sdr"] 10 | 11 | EVAL: 12 | SPLIT: "test_nearTarget_15scenes_1000episodes" 13 | USE_CKPT_CONFIG: True 14 | 15 | RL: 16 | PPO: 17 | deterministic_eval: False 18 | hidden_size: 512 19 | 20 | # needed to turn off bn 21 | use_ddppo: True 22 | -------------------------------------------------------------------------------- /audio_separation/config/test/nearTarget_unheard.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/nearTarget/testUnheard_nearTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 1000 7 | 8 | COMPUTE_EVAL_METRICS: True 9 | EVAL_METRICS_TO_COMPUTE: ["si_sdr"] 10 | 11 | EVAL: 12 | SPLIT: "testUnheard_nearTarget_15scenes_1000episodes" 13 | USE_CKPT_CONFIG: True 14 | 15 | RL: 16 | PPO: 17 | deterministic_eval: False 18 | hidden_size: 512 19 | 20 | # needed to turn off bn 21 | use_ddppo: True 22 | -------------------------------------------------------------------------------- /audio_separation/config/test/farTarget.yaml: 
-------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/farTarget/test_farTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 1000 7 | 8 | COMPUTE_EVAL_METRICS: True 9 | EVAL_METRICS_TO_COMPUTE: ["si_sdr"] 10 | 11 | EVAL: 12 | SPLIT: "test_farTarget_15scenes_1000episodes" 13 | 14 | RL: 15 | PPO: 16 | deterministic_eval: False 17 | hidden_size: 512 18 | 19 | switch_policy: True 20 | time_thres_for_pol_switch: 80 21 | 22 | # needed to turn off bn 23 | use_ddppo: True 24 | -------------------------------------------------------------------------------- /audio_separation/config/test/farTarget_unheard.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/farTarget/testUnheard_farTarget.yaml" 2 | NUM_PROCESSES: 1 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | EVAL_EPISODE_COUNT: 1000 7 | 8 | COMPUTE_EVAL_METRICS: True 9 | EVAL_METRICS_TO_COMPUTE: ["si_sdr"] 10 | 11 | EVAL: 12 | SPLIT: "testUnheard_farTarget_15scenes_1000episodes" 13 | 14 | RL: 15 | PPO: 16 | deterministic_eval: False 17 | hidden_size: 512 18 | 19 | switch_policy: True 20 | time_thres_for_pol_switch: 80 21 | 22 | # needed to turn off bn 23 | use_ddppo: True 24 | -------------------------------------------------------------------------------- /audio_separation/pretrain/passive/passive.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Passive(nn.Module): 5 | def __init__( 6 | self, 7 | actor_critic, 8 | ): 9 | super().__init__() 10 | self.actor_critic = actor_critic 11 | 12 | def forward(self, *x): 13 | raise NotImplementedError 14 | 15 | def update(self, rollouts): 16 | raise NotImplementedError 17 | 18 | def before_backward(self, loss): 19 | pass 20 | 21 | def after_backward(self, loss): 22 | pass 23 | 24 | def before_step(self): 25 | pass 26 | 27 | def after_step(self): 28 | pass 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Sagnik Majumder 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /configs/tasks/farTarget/train_farTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 80 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "GT_BIN_COMPONENTS_SENSOR", "GT_MONO_COMPONENTS_SENSOR", 27 | "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_farTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/nearTarget/train_nearTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "GT_BIN_COMPONENTS_SENSOR", "GT_MONO_COMPONENTS_SENSOR", 27 | "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_nearTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/farTarget/test_farTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 100 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledEval" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: 
AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_farTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/farTarget/val_farTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 80 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_farTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/farTarget/testUnheard_farTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 100 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/test_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledEval" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_farTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | 
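The DATA_PATH entries in these task configs are templates with {version} and {split} placeholders. Purely as an illustration (plain Python string formatting, with the split name taken from the EVAL.SPLIT of the test config above; the actual substitution happens inside the dataset code):

# Illustration only: how a DATA_PATH template like the ones above expands.
data_path_template = "data/active_datasets/{version}/{split}/{split}.json.gz"
split = "test_farTarget_15scenes_1000episodes"  # e.g. the test-time EVAL.SPLIT
print(data_path_template.format(version="v1", split=split))
# -> data/active_datasets/v1/test_farTarget_15scenes_1000episodes/test_farTarget_15scenes_1000episodes.json.gz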
-------------------------------------------------------------------------------- /configs/tasks/farTarget/valUnheard_farTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 80 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/val_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_farTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/nearTarget/test_nearTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledEval" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_nearTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/nearTarget/val_nearTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | 
ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_nearTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/nearTarget/testUnheard_nearTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/test_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledEval" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_nearTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /configs/tasks/nearTarget/valUnheard_nearTarget.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | MONO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/val_preprocessed" 20 | 21 | TYPE: "HabitatSimAudioEnabledTrain" 22 | ACTION_SPACE_CONFIG: "audio-separation" 23 | 24 | TASK: 25 | TYPE: AAViSS 26 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR", "MIXED_BIN_AUDIO_PHASE_SENSOR", "GT_BIN_COMPONENTS_SENSOR", 27 | "GT_MONO_COMPONENTS_SENSOR", "TARGET_CLASS_SENSOR"] 28 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 29 | MEASUREMENTS: ["GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE", "NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE"] 30 | POSSIBLE_ACTIONS: ["MOVE_FORWARD", "TURN_LEFT", "TURN_RIGHT"] 31 | 32 | 33 | DATASET: 34 | TYPE: "AAViSS" 35 | SPLIT: "train_nearTarget_24scenes_112009episodes" 36 | VERSION: 'v1' 37 | CONTENT_SCENES: ["*"] 38 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 39 | DATA_PATH: 
"data/active_datasets/{version}/{split}/{split}.json.gz" 40 | -------------------------------------------------------------------------------- /habitat_audio/action_space_separation.py: -------------------------------------------------------------------------------- 1 | import habitat_sim 2 | from habitat.core.registry import registry 3 | from habitat.core.simulator import ActionSpaceConfiguration 4 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 5 | 6 | 7 | # swapping PAUSE for STOP 8 | HabitatSimActions.extend_action_space("PAUSE") 9 | temp = HabitatSimActions.STOP 10 | HabitatSimActions._known_actions["STOP"] = HabitatSimActions.PAUSE 11 | HabitatSimActions._known_actions["PAUSE"] = temp 12 | 13 | 14 | @registry.register_action_space_configuration(name="audio-separation") 15 | class AudioSeparationSpaceConfiguration(ActionSpaceConfiguration): 16 | def get(self): 17 | return { 18 | HabitatSimActions.PAUSE: habitat_sim.ActionSpec("pause"), 19 | HabitatSimActions.MOVE_FORWARD: habitat_sim.ActionSpec( 20 | "move_forward", 21 | habitat_sim.ActuationSpec( 22 | amount=self.config.FORWARD_STEP_SIZE 23 | ), 24 | ), 25 | HabitatSimActions.TURN_LEFT: habitat_sim.ActionSpec( 26 | "turn_left", 27 | habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE), 28 | ), 29 | HabitatSimActions.TURN_RIGHT: habitat_sim.ActionSpec( 30 | "turn_right", 31 | habitat_sim.ActuationSpec(amount=self.config.TURN_ANGLE), 32 | ), 33 | } 34 | -------------------------------------------------------------------------------- /configs/tasks/pretrain_passive.yaml: -------------------------------------------------------------------------------- 1 | ENVIRONMENT: 2 | MAX_EPISODE_STEPS: 20 3 | SIMULATOR: 4 | SCENE_DATASET: "mp3d" 5 | GRID_SIZE: 1.0 6 | HABITAT_SIM_V0: 7 | GPU_DEVICE_ID: 0 8 | RGB_SENSOR: 9 | WIDTH: 128 10 | HEIGHT: 128 11 | DEPTH_SENSOR: 12 | WIDTH: 128 13 | HEIGHT: 128 14 | AUDIO: 15 | RIR_SAMPLING_RATE: 16000 16 | NORM_TYPE: "l2" 17 | GT_MONO_MAG_NORM: 1.2 18 | 19 | PASSIVE_DATASET_VERSION: "v1" 20 | SOURCE_AGENT_LOCATION_DATAPOINTS_DIR: "data/passive_datasets/" 21 | PASSIVE_TRAIN_AUDIO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 22 | PASSIVE_NONOVERLAPPING_VAL_AUDIO_DIR: "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/val_preprocessed" 23 | 24 | NUM_WORKER: 60 25 | BATCH_SIZE: 64 # 64, 128 26 | NUM_PASSIVE_DATAPOINTS_PER_SCENE: 30000 27 | NUM_PASSIVE_DATAPOINTS_PER_SCENE_EVAL: 30000 28 | 29 | 30 | TYPE: "HabitatSimAudioEnabledTrain" 31 | ACTION_SPACE_CONFIG: "audio-separation" 32 | 33 | TASK: 34 | TYPE: AAViSS 35 | SENSORS: ["MIXED_BIN_AUDIO_MAG_SENSOR"] 36 | GOAL_SENSOR_UUID: mixed_bin_audio_mag 37 | MEASUREMENTS: [] 38 | POSSIBLE_ACTIONS: [] 39 | 40 | 41 | DATASET: 42 | TYPE: "AAViSS" 43 | SPLIT: "train_nearTarget_20scenes" 44 | VERSION: 'v1' 45 | CONTENT_SCENES: ["*"] 46 | SCENES_DIR: "../sound_spaces/scene_datasets/mp3d" 47 | DATA_PATH: "data/active_datasets/{version}/{split}/{split}.json.gz" 48 | -------------------------------------------------------------------------------- /audio_separation/config/train/farTarget.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/farTarget/train_farTarget.yaml" 2 | NUM_PROCESSES: 14 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | ### 37.6M steps 7 | ## 1 GPU 8 | # NUM_UPDATES: 134285 9 | # CHECKPOINT_INTERVAL: 715 10 | 11 | ## 8 GPU 12 | NUM_UPDATES: 16786 13 | CHECKPOINT_INTERVAL: 89 14 | 15 | 
LOG_INTERVAL: 50 16 | 17 | 18 | RL: 19 | PPO: 20 | num_updates_per_cycle: 6 # 1, 6 21 | 22 | # replace PRETRAIN_DIRNAME with name of directory having checkpoints from passive pretraining 23 | pretrained_passive_separators_ckpt: "runs/passive_pretrain/PRETRAIN_DIRNAME/data/best_ckpt_nonoverlapping_val.pth" 24 | train_passive_separators: False 25 | 26 | hidden_size: 512 27 | 28 | value_loss_coef: 0.5 29 | bin_separation_loss_coef: 1.0 30 | mono_conversion_loss_coef: 1.0 31 | entropy_coef: 0.20 32 | lr_pol: 1.0e-4 33 | lr_sep: 5.0e-4 34 | 35 | clip_param: 0.1 36 | ppo_epoch: 4 37 | num_mini_batch: 1 38 | eps: 1.0e-5 39 | max_grad_norm: 0.5 40 | num_steps: 20 41 | 42 | use_gae: True 43 | gamma: 0.99 44 | tau: 0.95 45 | use_linear_clip_decay: True 46 | use_linear_lr_decay: True 47 | 48 | sep_reward_weight: 0.0 49 | nav_reward_weight: 1.0 50 | reward_window_size: 50 51 | 52 | use_ddppo: True 53 | ddppo_distrib_backend: "NCCL" 54 | short_rollout_threshold: 1.0 # 0.25 55 | sync_frac: 0.6 56 | # master_port: 7738 57 | # master_addr: "127.0.0.9" 58 | -------------------------------------------------------------------------------- /audio_separation/config/train/nearTarget.yaml: -------------------------------------------------------------------------------- 1 | BASE_TASK_CONFIG_PATH: "configs/tasks/nearTarget/train_nearTarget.yaml" 2 | NUM_PROCESSES: 14 3 | SENSORS: ["RGB_SENSOR", "DEPTH_SENSOR"] 4 | EXTRA_DEPTH: True 5 | 6 | ### 37.6M steps 7 | ## 1 GPU 8 | # NUM_UPDATES: 134285 9 | # CHECKPOINT_INTERVAL: 715 10 | 11 | ## 8 GPU 12 | NUM_UPDATES: 16786 13 | CHECKPOINT_INTERVAL: 89 14 | 15 | LOG_INTERVAL: 50 16 | 17 | 18 | RL: 19 | PPO: 20 | num_updates_per_cycle: 6 # 1, 6 21 | 22 | # replace PRETRAIN_DIRNAME with name of directory having checkpoints from passive pretraining 23 | pretrained_passive_separators_ckpt: "runs/passive_pretrain/PRETRAIN_DIRNAME/data/best_ckpt_nonoverlapping_val.pth" 24 | train_passive_separators: False 25 | 26 | hidden_size: 512 27 | 28 | value_loss_coef: 0.5 29 | bin_separation_loss_coef: 1.0 30 | mono_conversion_loss_coef: 1.0 31 | entropy_coef: 0.20 32 | lr_pol: 1.0e-4 33 | lr_sep: 5.0e-4 34 | 35 | clip_param: 0.1 36 | ppo_epoch: 4 37 | num_mini_batch: 1 38 | eps: 1.0e-5 39 | max_grad_norm: 0.5 40 | num_steps: 20 41 | 42 | use_gae: True 43 | gamma: 0.99 44 | tau: 0.95 45 | use_linear_clip_decay: True 46 | use_linear_lr_decay: True 47 | 48 | sep_reward_weight: 1.0 49 | nav_reward_weight: 0.0 50 | extra_reward_multiplier: 10.0 51 | reward_window_size: 50 52 | 53 | use_ddppo: True 54 | ddppo_distrib_backend: "NCCL" 55 | short_rollout_threshold: 1.0 # 0.25 56 | sync_frac: 0.6 57 | # master_port: 7738 58 | # master_addr: "127.0.0.9" 59 | -------------------------------------------------------------------------------- /audio_separation/common/baseline_registry.py: -------------------------------------------------------------------------------- 1 | r"""BaselineRegistry is extended from habitat.Registry to provide 2 | registration for trainer and environments, while keeping Registry 3 | in habitat core intact. 
4 | 5 | Import the baseline registry object using 6 | 7 | ``from baselines.common.baseline_registry import baseline_registry`` 8 | 9 | Various decorators for registry different kind of classes with unique keys 10 | 11 | - Register a environment: ``@registry.register_env`` 12 | - Register a trainer: ``@registry.register_trainer`` 13 | """ 14 | 15 | from typing import Optional 16 | 17 | from habitat.core.registry import Registry 18 | 19 | 20 | class BaselineRegistry(Registry): 21 | @classmethod 22 | def register_trainer(cls, to_register=None, *, name: Optional[str] = None): 23 | r"""Register a RL training algorithm to registry with key 'name'. 24 | 25 | Args: 26 | name: Key with which the trainer will be registered. 27 | If None will use the name of the class. 28 | 29 | """ 30 | from audio_separation.common.base_trainer import BaseTrainer 31 | 32 | return cls._register_impl( 33 | "trainer", to_register, name, assert_type=BaseTrainer 34 | ) 35 | 36 | @classmethod 37 | def get_trainer(cls, name): 38 | return cls._get_impl("trainer", name) 39 | 40 | @classmethod 41 | def register_env(cls, to_register=None, *, name: Optional[str] = None): 42 | r"""Register a environment to registry with key 'name' 43 | currently only support subclass of RLEnv. 44 | 45 | Args: 46 | name: Key with which the env will be registered. 47 | If None will use the name of the class. 48 | 49 | """ 50 | 51 | return cls._register_impl("env", to_register, name) 52 | 53 | @classmethod 54 | def get_env(cls, name): 55 | return cls._get_impl("env", name) 56 | 57 | 58 | baseline_registry = BaselineRegistry() 59 | -------------------------------------------------------------------------------- /scripts/search_for_checkpoint_thru_validation/link_ckpts_for_val.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "demographic-tracy", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 27, 16 | "id": "advisory-jaguar", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "### replace with necessary ABSOLUTE paths to train and val dirs to avoid potential issues with symlinks\n", 21 | "SOURCE_DIR = \"/projects/move2hear-active-AV-separation/runs/train/farTarget/farTarget_38MSteps_8G14P/data\"\n", 22 | "DUMP_DIR = \"/projects/move2hear-active-AV-separation/runs/val/farTarget/farTarget_38MSteps_8G14P_unheard/data\" \n", 23 | "\n", 24 | "assert os.path.isdir(SOURCE_DIR)\n", 25 | "assert os.path.isdir(DUMP_DIR)\n", 26 | "\n", 27 | "START_CKPT_IDX = 100" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 28, 33 | "id": "spoken-composite", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "for _, __, ckpt_files in os.walk(SOURCE_DIR):\n", 38 | " break\n", 39 | " \n", 40 | "for ckpt_idx in range(START_CKPT_IDX, int(ckpt_files[-1].split(\".\")[1]) + 1):\n", 41 | " assert f\"ckpt.{ckpt_idx}.pth\" in ckpt_files\n", 42 | " \n", 43 | " source_ckpt_path = os.path.join(SOURCE_DIR, f\"ckpt.{ckpt_idx}.pth\")\n", 44 | " assert os.path.exists(source_ckpt_path)\n", 45 | " \n", 46 | " dump_ckpt_path = os.path.join(DUMP_DIR, f\"ckpt.{ckpt_idx}.pth\")\n", 47 | "\n", 48 | " os.system(f\"ln -s {source_ckpt_path} {dump_ckpt_path}\") " 49 | ] 50 | } 51 | ], 52 | "metadata": { 53 | "kernelspec": { 54 | "display_name": "Python 3", 55 | "language": "python", 56 | "name": "python3" 57 | }, 58 | "language_info": { 
59 | "codemirror_mode": { 60 | "name": "ipython", 61 | "version": 3 62 | }, 63 | "file_extension": ".py", 64 | "mimetype": "text/x-python", 65 | "name": "python", 66 | "nbconvert_exporter": "python", 67 | "pygments_lexer": "ipython3", 68 | "version": "3.6.12" 69 | } 70 | }, 71 | "nbformat": 4, 72 | "nbformat_minor": 5 73 | } 74 | -------------------------------------------------------------------------------- /audio_separation/common/tensorboard_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import numpy as np 4 | import torch 5 | from torch.utils.tensorboard import SummaryWriter 6 | 7 | 8 | class TensorboardWriter: 9 | def __init__(self, log_dir: str, *args: Any, **kwargs: Any): 10 | r"""A Wrapper for tensorboard SummaryWriter. It creates a dummy writer 11 | when log_dir is empty string or None. It also has functionality that 12 | generates tb video directly from numpy images. 13 | 14 | Args: 15 | log_dir: Save directory location. Will not write to disk if 16 | log_dir is an empty string. 17 | *args: Additional positional args for SummaryWriter 18 | **kwargs: Additional keyword args for SummaryWriter 19 | """ 20 | self.writer = None 21 | if log_dir is not None and len(log_dir) > 0: 22 | self.writer = SummaryWriter(log_dir, *args, **kwargs) 23 | 24 | def __getattr__(self, item): 25 | if self.writer: 26 | return self.writer.__getattribute__(item) 27 | else: 28 | return lambda *args, **kwargs: None 29 | 30 | def __enter__(self): 31 | return self 32 | 33 | def __exit__(self, exc_type, exc_val, exc_tb): 34 | if self.writer: 35 | self.writer.close() 36 | 37 | def add_video_from_np_images( 38 | self, video_name: str, step_idx: int, images: np.ndarray, fps: int = 10 39 | ) -> None: 40 | r"""Write video into tensorboard from images frames. 41 | 42 | Args: 43 | video_name: name of video string. 44 | step_idx: int of checkpoint index to be displayed. 45 | images: list of n frames. Each frame is a np.ndarray of shape. 46 | fps: frame per second for output video. 47 | 48 | Returns: 49 | None. 
50 | """ 51 | if not self.writer: 52 | return 53 | # initial shape of np.ndarray list: N * (H, W, 3) 54 | frame_tensors = [ 55 | torch.from_numpy(np_arr).unsqueeze(0) for np_arr in images 56 | ] 57 | video_tensor = torch.cat(tuple(frame_tensors)) 58 | video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0) 59 | # final shape of video tensor: (1, n, 3, H, W) 60 | self.writer.add_video( 61 | video_name, video_tensor, fps=fps, global_step=step_idx 62 | ) 63 | -------------------------------------------------------------------------------- /scripts/farTarget_eval/copy_individualCkptsNCfgs_switchPolicyEval.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "excellent-variable", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import os" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "id": "wrapped-edwards", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "### replace with necessary paths to checkpoints and dump_dir\n", 22 | "NEAR_TARGET_CKPT_SOURCE_PATH = \"../../runs/train/nearTarget/nearTarget_38MSteps_8G14P/data/ckpt.152.pth\"\n", 23 | "assert os.path.exists(NEAR_TARGET_CKPT_SOURCE_PATH)\n", 24 | "\n", 25 | "FAR_TARGET_CKPT_SOURCE_PATH = \"../../runs/train/farTarget/farTarget_38MSteps_8G14P/data/ckpt.188.pth\"\n", 26 | "assert os.path.exists(FAR_TARGET_CKPT_SOURCE_PATH)\n", 27 | "\n", 28 | "DUMP_DIR = \"../../runs/test/farTarget/farTarget_38MSteps_8G14P/data\"\n", 29 | "if not os.path.isdir(DUMP_DIR):\n", 30 | " os.makedirs(DUMP_DIR)\n", 31 | "DUMP_FILENAME = \"ckpt_polSwitch.pth\"\n", 32 | "DUMP_FILE_PATH = os.path.join(DUMP_DIR, DUMP_FILENAME)" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "id": "sorted-nightmare", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "nearTarget_dct = torch.load(NEAR_TARGET_CKPT_SOURCE_PATH, map_location=\"cpu\")\n", 43 | "farTarget_dct = torch.load(FAR_TARGET_CKPT_SOURCE_PATH, map_location=\"cpu\")" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 5, 49 | "id": "proved-closure", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "dump_dct = {\"state_dict_nav\": farTarget_dct[\"state_dict\"],\n", 54 | " \"config_nav\": farTarget_dct[\"config\"],\n", 55 | " \"state_dict_qualImprov\": nearTarget_dct[\"state_dict\"],\n", 56 | " \"config_qualImprov\": nearTarget_dct[\"config\"]}\n", 57 | "\n", 58 | "torch.save(\n", 59 | " dump_dct, DUMP_FILE_PATH\n", 60 | ")" 61 | ] 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "Python 3", 67 | "language": "python", 68 | "name": "python3" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 3 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython3", 80 | "version": "3.6.12" 81 | } 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 5 85 | } 86 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import argparse 8 | import logging 9 | 10 | import warnings 11 | warnings.filterwarnings('ignore', category=FutureWarning) 12 | warnings.filterwarnings('ignore', category=UserWarning) 13 | import tensorflow as tf 14 | import torch 15 | 16 | from audio_separation.common.baseline_registry import baseline_registry 17 | from audio_separation.config.default import get_config 18 | 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--run-type", 24 | choices=["train", "eval"], 25 | # required=True, 26 | default='train', 27 | help="run type of the experiment (train or eval)", 28 | ) 29 | parser.add_argument( 30 | "--exp-config", 31 | type=str, 32 | # required=True, 33 | default='baselines/config/pointnav_rgb.yaml', 34 | help="path to config yaml containing info about experiment", 35 | ) 36 | parser.add_argument( 37 | "opts", 38 | default=None, 39 | nargs=argparse.REMAINDER, 40 | help="Modify config options from command line", 41 | ) 42 | parser.add_argument( 43 | "--model-dir", 44 | default=None, 45 | help="Modify config options from command line", 46 | ) 47 | parser.add_argument( 48 | "--eval-interval", 49 | type=int, 50 | default=1, 51 | help="Evaluation interval of checkpoints", 52 | ) 53 | parser.add_argument( 54 | "--prev-ckpt-ind", 55 | type=int, 56 | default=-1, 57 | help="Evaluation interval of checkpoints", 58 | ) 59 | args = parser.parse_args() 60 | 61 | # repo = git.Repo(search_parent_directories=True) 62 | # logging.info('Current git head hash code: {}'.format(repo.head.object.hexsha)) 63 | 64 | # run exp 65 | config = get_config(args.exp_config, args.opts, args.model_dir, args.run_type) 66 | trainer_init = baseline_registry.get_trainer(config.TRAINER_NAME) 67 | assert trainer_init is not None, f"{config.TRAINER_NAME} is not supported" 68 | trainer = trainer_init(config) 69 | # torch.set_num_threads(1) 70 | 71 | level = logging.DEBUG if config.DEBUG else logging.INFO 72 | logging.basicConfig(level=level, format='%(asctime)s, %(levelname)s: %(message)s', 73 | datefmt="%Y-%m-%d %H:%M:%S") 74 | 75 | if args.run_type == "train": 76 | trainer.train() 77 | elif args.run_type == "eval": 78 | trainer.eval(args.eval_interval, args.prev_ckpt_ind) 79 | 80 | 81 | if __name__ == "__main__": 82 | main() 83 | -------------------------------------------------------------------------------- /habitat_audio/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def load_points(points_file: str, transform=True, scene_dataset="replica"): 8 | r""" 9 | Helper method to load points data from files stored on disk and transform if necessary 10 | :param points_file: path to files containing points data 11 | :param transform: transform coordinate systems of loaded points for use in Habitat or not 12 | :param scene_dataset: name of scenes dataset ("replica", "mp3d", etc.) 
13 | :return: points in transformed coordinate system for use with Habitat 14 | """ 15 | points_data = np.loadtxt(points_file, delimiter="\t") 16 | if transform: 17 | if scene_dataset == "replica": 18 | points = list(zip( 19 | points_data[:, 1], 20 | points_data[:, 3] - 1.5528907, 21 | -points_data[:, 2]) 22 | ) 23 | elif scene_dataset == "mp3d": 24 | points = list(zip( 25 | points_data[:, 1], 26 | points_data[:, 3] - 1.5, 27 | -points_data[:, 2]) 28 | ) 29 | else: 30 | raise NotImplementedError 31 | else: 32 | points = list(zip( 33 | points_data[:, 1], 34 | points_data[:, 2], 35 | points_data[:, 3]) 36 | ) 37 | points_index = points_data[:, 0].astype(int) 38 | points_dict = dict(zip(points_index, points)) 39 | assert list(points_index) == list(range(len(points))) 40 | return points_dict, points 41 | 42 | 43 | def load_points_data(parent_folder, graph_file, transform=True, scene_dataset="replica"): 44 | r""" 45 | Main method to load points data from files stored on disk and transform if necessary 46 | :param parent_folder: parent folder containing files with points data 47 | :param graph_file: files containing connectivity of points per scene 48 | :param transform: transform coordinate systems of loaded points for use in Habitat or not 49 | :param scene_dataset: name of scenes dataset ("replica", "mp3d", etc.) 50 | :return: 1. points in transformed coordinate system for use with Habitat 51 | 2. graph object containing information about the connectivity of points in a scene 52 | """ 53 | points_file = os.path.join(parent_folder, 'points.txt') 54 | graph_file = os.path.join(parent_folder, graph_file) 55 | 56 | _, points = load_points(points_file, transform=transform, scene_dataset=scene_dataset) 57 | if not os.path.exists(graph_file): 58 | raise FileExistsError(graph_file + ' does not exist!') 59 | else: 60 | with open(graph_file, 'rb') as fo: 61 | graph = pickle.load(fo) 62 | 63 | return points, graph 64 | 65 | 66 | def _to_tensor(v): 67 | if torch.is_tensor(v): 68 | return v 69 | elif isinstance(v, np.ndarray): 70 | return torch.from_numpy(v) 71 | else: 72 | return torch.tensor(v, dtype=torch.float) 73 | -------------------------------------------------------------------------------- /scripts/separated_audio_quality/compute_separation_qualtiy.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "rural-gambling", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "import os\n", 12 | "import pickle" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": null, 18 | "id": "secondary-nicholas", 19 | "metadata": {}, 20 | "outputs": [], 21 | "source": [ 22 | "### replace with necessary path to test dir\n", 23 | "SOURCE_DIR = \"../../runs/test/nearTarget/nearTarget_38MSteps_8G14P\"\n", 24 | "assert os.path.isdir(SOURCE_DIR)\n", 25 | "\n", 26 | "EVAL_METRICS_FILENAME = \"eval_metrics.pkl\"\n", 27 | "\n", 28 | "EVAL_METRICS_FILE_FULL_PATH = os.path.join(SOURCE_DIR, EVAL_METRICS_FILENAME)\n", 29 | "assert os.path.exists(EVAL_METRICS_FILE_FULL_PATH)\n", 30 | "\n", 31 | "TARGET_METRIC = \"si_sdr\" # \"STFT_L2_loss\", \"si_sdr\"\n", 32 | "PRED_TYPE = \"monoFromMem\" # \"mono\", \"monoFromMem\"" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "ordered-curve", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "with open(EVAL_METRICS_FILE_FULL_PATH, \"rb\") as fi:\n", 43 | " 
eval_metrics_dct = pickle.load(fi)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": null, 49 | "id": "criminal-scoop", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "assert PRED_TYPE in eval_metrics_dct\n", 54 | "eval_metrics_dct_thisPredType = eval_metrics_dct[PRED_TYPE]\n", 55 | "\n", 56 | "assert TARGET_METRIC in eval_metrics_dct_thisPredType\n", 57 | "\n", 58 | "last_metricValue_perEpisode = []\n", 59 | "last_stepIdx = None\n", 60 | "for ep_idx in eval_metrics_dct_thisPredType[TARGET_METRIC]:\n", 61 | " last_stepIdx_thisEpisode = sorted(list(eval_metrics_dct_thisPredType[TARGET_METRIC][ep_idx].keys()))[-1]\n", 62 | " if last_stepIdx is None:\n", 63 | " last_stepIdx = last_stepIdx_thisEpisode\n", 64 | " assert last_stepIdx == last_stepIdx_thisEpisode\n", 65 | " last_metricValue_perEpisode.append(eval_metrics_dct_thisPredType[TARGET_METRIC][ep_idx][last_stepIdx_thisEpisode])" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "terminal-mainstream", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "print(f\"{PRED_TYPE} {TARGET_METRIC} mean: {np.mean(last_metricValue_perEpisode)}, std: {np.std(last_metricValue_perEpisode)}\")" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "kernelspec": { 81 | "display_name": "Python 3", 82 | "language": "python", 83 | "name": "python3" 84 | }, 85 | "language_info": { 86 | "codemirror_mode": { 87 | "name": "ipython", 88 | "version": 3 89 | }, 90 | "file_extension": ".py", 91 | "mimetype": "text/x-python", 92 | "name": "python", 93 | "nbconvert_exporter": "python", 94 | "pygments_lexer": "ipython3", 95 | "version": "3.6.12" 96 | } 97 | }, 98 | "nbformat": 4, 99 | "nbformat_minor": 5 100 | } 101 | -------------------------------------------------------------------------------- /audio_separation/rl/models/memory_nets.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | class AcousticMem(nn.Module): 6 | def __init__(self, use_ddppo=False,): 7 | super().__init__() 8 | self._slice_factor = 16 9 | _n_out_audio = self._slice_factor 10 | 11 | if use_ddppo: 12 | self.cnn = nn.Sequential( 13 | nn.Conv2d(_n_out_audio * 2, 32, kernel_size=3, padding=1, bias=False), 14 | nn.ReLU(inplace=True), 15 | nn.Conv2d(32, _n_out_audio, kernel_size=3, padding=1, bias=False), 16 | ) 17 | else: 18 | self.cnn = nn.Sequential( 19 | nn.Conv2d(_n_out_audio * 2, 32, kernel_size=3, padding=1, bias=False), 20 | nn.BatchNorm2d(32), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(32, _n_out_audio, kernel_size=3, padding=1, bias=False), 23 | ) 24 | 25 | self.layer_init() 26 | 27 | def layer_init(self): 28 | for layer in self.cnn: 29 | if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 30 | nn.init.kaiming_normal_( 31 | layer.weight, nn.init.calculate_gain("relu") 32 | ) 33 | if layer.bias is not None: 34 | nn.init.constant_(layer.bias, val=0) 35 | elif isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)): 36 | if layer.affine: 37 | layer.weight.data.fill_(1) 38 | layer.bias.data.zero_() 39 | 40 | def forward(self, pred_mono, prev_pred_monoFromMem_masked): 41 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 42 | pred_mono = pred_mono.permute(0, 3, 1, 2) 43 | prev_pred_monoFromMem_masked = prev_pred_monoFromMem_masked.permute(0, 3, 1, 2) 44 | 45 | # slice along freq dimension into 16 chunks 46 | pred_mono = pred_mono.view(pred_mono.size(0), pred_mono.size(1), self._slice_factor, -1, 
pred_mono.size(3)) 47 | pred_mono = pred_mono.reshape(pred_mono.size(0), -1, pred_mono.size(3), pred_mono.size(4)) 48 | 49 | prev_pred_monoFromMem_masked = prev_pred_monoFromMem_masked.view(prev_pred_monoFromMem_masked.size(0), 50 | prev_pred_monoFromMem_masked.size(1), 51 | self._slice_factor, 52 | -1, 53 | prev_pred_monoFromMem_masked.size(3)) 54 | prev_pred_monoFromMem_masked = prev_pred_monoFromMem_masked.reshape(prev_pred_monoFromMem_masked.size(0), 55 | -1, 56 | prev_pred_monoFromMem_masked.size(3), 57 | prev_pred_monoFromMem_masked.size(4)) 58 | 59 | out = torch.cat((pred_mono, prev_pred_monoFromMem_masked), dim=1) 60 | out = self.cnn(out) 61 | 62 | # deslice 63 | out = out.view(out.size(0), -1, self._slice_factor, out.size(2), out.size(3)) 64 | out = out.reshape(out.size(0), out.size(1), -1, out.size(4)) 65 | 66 | # permute tensor to dimension [BATCH x HEIGHT X WIDTH x CHANNEL] 67 | out = out.permute(0, 2, 3, 1) 68 | 69 | return out 70 | -------------------------------------------------------------------------------- /audio_separation/pretrain/passive/policy.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchsummary import summary 3 | 4 | from audio_separation.rl.models.separator_cnn import PassiveSepEncCNN, PassiveSepDecCNN 5 | 6 | 7 | class PassiveSepEnc(nn.Module): 8 | r"""Network which encodes separated bin or mono outputs 9 | """ 10 | def __init__(self, observation_space, world_rank=0, convert_bin2mono=False,): 11 | super().__init__() 12 | print(observation_space.spaces) 13 | assert 'mixed_bin_audio_mag' in observation_space.spaces 14 | 15 | print(convert_bin2mono) 16 | self.passive_sep_encoder = PassiveSepEncCNN(convert_bin2mono=convert_bin2mono,) 17 | 18 | if world_rank == 0: 19 | audio_shape = observation_space.spaces['mixed_bin_audio_mag'].shape 20 | 21 | if not convert_bin2mono: 22 | summary(self.passive_sep_encoder.cnn, 23 | (audio_shape[2] * 16 + 1, audio_shape[0] // 16, audio_shape[1]), 24 | device='cpu') 25 | else: 26 | summary(self.passive_sep_encoder.cnn, 27 | (audio_shape[2] * 16, audio_shape[0] // 16, audio_shape[1]), 28 | device='cpu') 29 | 30 | def forward(self, observations, mixed_audio=None): 31 | bottleneck_feats, lst_skip_feats = self.passive_sep_encoder(observations, mixed_audio=mixed_audio,) 32 | 33 | return bottleneck_feats, lst_skip_feats 34 | 35 | 36 | class PassiveSepDec(nn.Module): 37 | r"""Network which decodes separated bin or mono outputs feature embeddings 38 | """ 39 | def __init__(self, convert_bin2mono=False,): 40 | super().__init__() 41 | self.passive_sep_decoder = PassiveSepDecCNN(convert_bin2mono=convert_bin2mono,) 42 | 43 | def forward(self, bottleneck_feats, lst_skip_feats): 44 | return self.passive_sep_decoder(bottleneck_feats, lst_skip_feats) 45 | 46 | 47 | class Policy(nn.Module): 48 | r""" 49 | Network for the passive separation in Move2Hear pretraining 50 | """ 51 | def __init__(self, binSep_enc, binSep_dec, bin2mono_enc, bin2mono_dec,): 52 | super().__init__() 53 | self.binSep_enc = binSep_enc 54 | self.binSep_dec = binSep_dec 55 | self.bin2mono_enc = bin2mono_enc 56 | self.bin2mono_dec = bin2mono_dec 57 | 58 | def forward(self): 59 | raise NotImplementedError 60 | 61 | def get_binSepMasks(self, observations): 62 | bottleneck_feats, lst_skip_feats = self.binSep_enc( 63 | observations, 64 | ) 65 | return self.binSep_dec(bottleneck_feats, lst_skip_feats) 66 | 67 | def convert_bin2mono(self, pred_binSepMasks, mixed_audio=None): 68 | bottleneck_feats, 
lst_skip_feats = self.bin2mono_enc( 69 | pred_binSepMasks, mixed_audio=mixed_audio 70 | ) 71 | return self.bin2mono_dec(bottleneck_feats, lst_skip_feats) 72 | 73 | 74 | class Move2HearPassiveWoMemoryPolicy(Policy): 75 | def __init__( 76 | self, 77 | observation_space, 78 | ): 79 | binSep_enc = PassiveSepEnc( 80 | observation_space=observation_space, 81 | ) 82 | binSep_dec = PassiveSepDec() 83 | 84 | bin2mono_enc = PassiveSepEnc( 85 | observation_space=observation_space, 86 | convert_bin2mono=True, 87 | ) 88 | bin2mono_dec = PassiveSepDec( 89 | convert_bin2mono=True, 90 | ) 91 | 92 | super().__init__( 93 | binSep_enc, 94 | binSep_dec, 95 | bin2mono_enc, 96 | bin2mono_dec, 97 | ) 98 | -------------------------------------------------------------------------------- /audio_separation/common/benchmark.py: -------------------------------------------------------------------------------- 1 | r"""Implements evaluation of ``habitat.Agent`` inside ``habitat.Env``. 2 | ``habitat.Benchmark`` creates a ``habitat.Env`` which is specified through 3 | the ``config_env`` parameter in constructor. The evaluation is task agnostic 4 | and is implemented through metrics defined for ``habitat.EmbodiedTask``. 5 | """ 6 | 7 | from collections import defaultdict 8 | from typing import Dict, Optional 9 | import logging 10 | 11 | from tqdm import tqdm 12 | 13 | from habitat import Config 14 | from habitat.core.agent import Agent 15 | # from habitat.core.env import Env 16 | from audio_separation.common.environments import NavRLEnv 17 | from habitat.datasets import make_dataset 18 | 19 | 20 | class Benchmark: 21 | r"""Benchmark for evaluating agents in environments. 22 | """ 23 | 24 | def __init__(self, task_config: Optional[Config] = None) -> None: 25 | r""".. 26 | 27 | :param task_config: config to be used for creating the environment 28 | """ 29 | dummy_config = Config() 30 | dummy_config.RL = Config() 31 | dummy_config.RL.SLACK_REWARD = -0.01 32 | dummy_config.RL.SUCCESS_REWARD = 10 33 | dummy_config.RL.WITH_TIME_PENALTY = True 34 | dummy_config.RL.DISTANCE_REWARD_SCALE = 1 35 | dummy_config.RL.WITH_DISTANCE_REWARD = True 36 | dummy_config.RL.defrost() 37 | dummy_config.TASK_CONFIG = task_config 38 | dummy_config.freeze() 39 | 40 | dataset = make_dataset(id_dataset=task_config.DATASET.TYPE, config=task_config.DATASET) 41 | self._env = NavRLEnv(config=dummy_config, dataset=dataset) 42 | 43 | def evaluate( 44 | self, agent: Agent, num_episodes: Optional[int] = None 45 | ) -> Dict[str, float]: 46 | r""".. 47 | 48 | :param agent: agent to be evaluated in environment. 49 | :param num_episodes: count of number of episodes for which the 50 | evaluation should be run. 51 | :return: dict containing metrics tracked by environment. 
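        Illustrative usage, with ``my_agent`` standing in for any ``habitat.Agent``
        implementation: ``Benchmark(task_config).evaluate(my_agent, num_episodes=10)``
        returns a dict of metrics averaged over the evaluated episodes.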
52 | """ 53 | 54 | if num_episodes is None: 55 | num_episodes = len(self._env.episodes) 56 | else: 57 | assert num_episodes <= len(self._env.episodes), ( 58 | "num_episodes({}) is larger than number of episodes " 59 | "in environment ({})".format( 60 | num_episodes, len(self._env.episodes) 61 | ) 62 | ) 63 | 64 | assert num_episodes > 0, "num_episodes should be greater than 0" 65 | 66 | agg_metrics: Dict = defaultdict(float) 67 | 68 | count_episodes = 0 69 | reward_episodes = 0 70 | step_episodes = 0 71 | for count_episodes in tqdm(range(num_episodes)): 72 | agent.reset() 73 | observations = self._env.reset() 74 | episode_reward = 0 75 | 76 | while not self._env.habitat_env.episode_over: 77 | action = agent.act(observations) 78 | observations, reward, done, info = self._env.step(**action) 79 | logging.debug("Reward: {}".format(reward)) 80 | if done: 81 | logging.debug('Episode reward: {}'.format(episode_reward)) 82 | episode_reward += reward 83 | step_episodes += 1 84 | 85 | metrics = self._env.habitat_env.get_metrics() 86 | for m, v in metrics.items(): 87 | agg_metrics[m] += v 88 | reward_episodes += episode_reward 89 | 90 | avg_metrics = {k: v / count_episodes for k, v in agg_metrics.items()} 91 | logging.info("Average reward: {} in {} episodes".format(reward_episodes / count_episodes, count_episodes)) 92 | logging.info("Average episode steps: {}".format(step_episodes / count_episodes)) 93 | 94 | return avg_metrics 95 | -------------------------------------------------------------------------------- /audio_separation/common/environments.py: -------------------------------------------------------------------------------- 1 | r""" 2 | This file hosts task-specific or trainer-specific environments for trainers. 3 | All environments here should be a (direct or indirect ) subclass of Env class 4 | in habitat. Customized environments should be registered using 5 | ``@baseline_registry.register_env(name="myEnv")` for reusability 6 | """ 7 | 8 | from typing import Optional, Type 9 | import logging 10 | 11 | import habitat 12 | from habitat import Config, Dataset 13 | from audio_separation.common.baseline_registry import baseline_registry 14 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 15 | 16 | 17 | def get_env_class(env_name: str) -> Type[habitat.RLEnv]: 18 | r"""Return environment class based on name. 19 | 20 | Args: 21 | env_name: name of the environment. 22 | 23 | Returns: 24 | Type[habitat.RLEnv]: env class. 
25 | """ 26 | return baseline_registry.get_env(env_name) 27 | 28 | 29 | @baseline_registry.register_env(name="AAViSSEnv") 30 | class AAViSSEnv(habitat.RLEnv): 31 | def __init__(self, config: Config, dataset: Optional[Dataset] = None): 32 | self._rl_config = config.RL 33 | self._config = config 34 | self._core_env_config = config.TASK_CONFIG 35 | self._goal_reached_once = False 36 | 37 | self._previous_target_distance = None 38 | self._episode_distance_covered = None 39 | self._success_distance = self._core_env_config.TASK.SUCCESS_DISTANCE 40 | super().__init__(self._core_env_config, dataset) 41 | 42 | def reset(self): 43 | self._env_step = 0 44 | self._goal_reached_once = False 45 | observation = super().reset() 46 | logging.debug(super().current_episode) 47 | self._previous_target_distance = self.habitat_env.current_episode.info[0]["geodesic_distance"] 48 | return observation 49 | 50 | def step(self, *args, **kwargs): 51 | observation, reward, done, info = super().step(*args, **kwargs) 52 | self._env_step += 1 53 | return observation, reward, done, info 54 | 55 | def get_reward_range(self): 56 | return ( 57 | self._rl_config.SLACK_REWARD - 1.0, 58 | self._rl_config.SUCCESS_REWARD + 1.0, 59 | ) 60 | 61 | def get_reward(self, observations): 62 | reward = 0 63 | # for FarTgt 64 | if self._rl_config.WITH_DISTANCE_REWARD: 65 | current_target_distance = self._distance_target() 66 | reward += (self._previous_target_distance - current_target_distance) * self._rl_config.DISTANCE_REWARD_SCALE 67 | self._previous_target_distance = current_target_distance 68 | 69 | return reward 70 | 71 | def _distance_target(self): 72 | current_position = self._env.sim.get_agent_state().position.tolist() 73 | target_position = self._env.current_episode.goals[0].position 74 | distance = self._env.sim.geodesic_distance( 75 | current_position, target_position 76 | ) 77 | return distance 78 | 79 | def _episode_success(self): 80 | if ( 81 | (not self._env.sim._is_episode_active) 82 | and self._env.sim.reaching_goal 83 | ): 84 | return True 85 | return False 86 | 87 | def _goal_reached(self): 88 | if ( 89 | self._env.sim.reaching_goal 90 | ): 91 | return True 92 | return False 93 | 94 | def get_done(self, observations): 95 | done = False 96 | if self._env.episode_over: 97 | done = True 98 | return done 99 | 100 | def get_info(self, observations): 101 | return self.habitat_env.get_metrics() 102 | 103 | # for data collection 104 | def get_current_episode_id(self): 105 | return self.habitat_env.current_episode.episode_id 106 | -------------------------------------------------------------------------------- /audio_separation/common/utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | from collections import defaultdict 4 | from typing import Dict, List, Optional 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | class Flatten(nn.Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class CustomFixedCategorical(torch.distributions.Categorical): 17 | def sample(self, sample_shape=torch.Size()): 18 | return super().sample(sample_shape).unsqueeze(-1) 19 | 20 | def log_probs(self, actions): 21 | return ( 22 | super() 23 | .log_prob(actions.squeeze(-1)) 24 | .view(actions.size(0), -1) 25 | .sum(-1) 26 | .unsqueeze(-1) 27 | ) 28 | 29 | def mode(self): 30 | return self.probs.argmax(dim=-1, keepdim=True) 31 | 32 | def get_probs(self): 33 | return self.probs 34 | 35 | def get_log_probs(self): 36 | 
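        # add a small epsilon inside the log so actions with exactly zero probability do not yield -inf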
return torch.log(self.probs + 1e-7) 37 | 38 | 39 | class CategoricalNet(nn.Module): 40 | def __init__(self, num_inputs, num_outputs): 41 | super().__init__() 42 | 43 | self.linear = nn.Linear(num_inputs, num_outputs) 44 | 45 | nn.init.orthogonal_(self.linear.weight, gain=0.01) 46 | nn.init.constant_(self.linear.bias, 0) 47 | 48 | def forward(self, x): 49 | x = self.linear(x) 50 | return CustomFixedCategorical(logits=x) 51 | 52 | 53 | def linear_decay(epoch: int, total_num_updates: int) -> float: 54 | r"""Returns a multiplicative factor for linear value decay 55 | 56 | Args: 57 | epoch: current epoch number 58 | total_num_updates: total number of epochs 59 | 60 | Returns: 61 | multiplicative factor that decreases param value linearly 62 | """ 63 | return 1 - (epoch / float(total_num_updates)) 64 | 65 | 66 | def to_tensor(v, device=None): 67 | if torch.is_tensor(v): 68 | return v.to(device=device, dtype=torch.float) 69 | elif isinstance(v, np.ndarray): 70 | return (torch.from_numpy(v)).to(device=device, dtype=torch.float) 71 | else: 72 | return torch.tensor(v, dtype=torch.float, device=device) 73 | 74 | 75 | def batch_obs( 76 | observations: List[Dict], device: Optional[torch.device] = None, 77 | ) -> Dict[str, torch.Tensor]: 78 | r"""Transpose a batch of observation dicts to a dict of batched 79 | observations. 80 | 81 | Args: 82 | observations: list of dicts of observations. 83 | device: The torch.device to put the resulting tensors on. 84 | Will not move the tensors if None 85 | 86 | Returns: 87 | transposed dict of lists of observations. 88 | """ 89 | batch: DefaultDict[str, List] = defaultdict(list) 90 | for obs in observations: 91 | for sensor in obs: 92 | batch[sensor].append(to_tensor(obs[sensor], device=device)) 93 | 94 | for sensor in batch: 95 | batch[sensor] = torch.stack(batch[sensor], dim=0) 96 | 97 | return batch 98 | 99 | 100 | def poll_checkpoint_folder( 101 | checkpoint_folder: str, previous_ckpt_ind: int, eval_interval: int 102 | ) -> Optional[str]: 103 | r""" Return (previous_ckpt_ind + 1)th checkpoint in checkpoint folder 104 | (sorted by time of last modification). 105 | 106 | Args: 107 | checkpoint_folder: directory to look for checkpoints. 108 | previous_ckpt_ind: index of checkpoint last returned. 109 | eval_interval: number of checkpoints between two evaluation 110 | 111 | Returns: 112 | return checkpoint path if (previous_ckpt_ind + 1)th checkpoint is found 113 | else return None. 
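    Example (following the code below): with ``previous_ckpt_ind=-1`` and
    ``eval_interval=1``, the checkpoint with the oldest modification time is returned.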
114 | """ 115 | assert os.path.isdir(checkpoint_folder), ( 116 | f"invalid checkpoint folder " f"path {checkpoint_folder}" 117 | ) 118 | models_paths = list( 119 | filter(os.path.isfile, glob.glob(checkpoint_folder + "/*")) 120 | ) 121 | models_paths.sort(key=os.path.getmtime) 122 | ind = previous_ckpt_ind + eval_interval 123 | if ind < len(models_paths): 124 | return models_paths[ind] 125 | return None 126 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | aiohttp==3.7.4 3 | alabaster==0.7.12 4 | allennlp==0.9.0 5 | appdirs==1.4.4 6 | argon2-cffi==20.1.0 7 | astor==0.8.1 8 | async-generator==1.10 9 | async-timeout==3.0.1 10 | atomicwrites==1.3.0 11 | attrs==20.3.0 12 | audioread==2.1.8 13 | Babel==2.7.0 14 | backcall==0.2.0 15 | beautifulsoup4==4.9.3 16 | bleach==3.2.3 17 | blinker==1.4 18 | blis==0.2.4 19 | boto3==1.10.26 20 | botocore==1.13.26 21 | brotlipy==0.7.0 22 | bs4==0.0.1 23 | cachetools==4.0.0 24 | certifi==2020.12.5 25 | cffi==1.14.0 26 | chardet==4.0.0 27 | Click==7.0 28 | colorama==0.4.1 29 | conllu==1.3.1 30 | cryptography==3.3.1 31 | cycler==0.10.0 32 | cymem==2.0.3 33 | Cython==0.29.21 34 | dataclasses==0.8 35 | decorator==4.4.2 36 | defusedxml==0.6.0 37 | distro==1.5.0 38 | docutils==0.15.2 39 | editdistance==0.5.3 40 | en-core-web-sm==2.1.0 41 | entrypoints==0.3 42 | ffmpeg-python==0.2.0 43 | filelock==3.0.12 44 | flaky==3.6.1 45 | Flask==1.1.1 46 | Flask-Cors==3.0.8 47 | ftfy==5.6 48 | funcsigs==1.0.2 49 | future==0.18.2 50 | gast==0.3.2 51 | gevent==1.4.0 52 | google-auth==1.10.1 53 | google-auth-oauthlib==0.4.1 54 | google-pasta==0.2.0 55 | greenlet==0.4.15 56 | grpcio==1.26.0 57 | gym==0.10.9 58 | h5py==2.10.0 59 | habitat-sim==0.1.4 60 | habitat==0.1.4 61 | idna==2.10 62 | idna-ssl==1.1.0 63 | ifcfg==0.21 64 | imageio==2.9.0 65 | imageio-ffmpeg==0.4.3 66 | imagesize==1.1.0 67 | importlib-metadata==0.23 68 | ipykernel==5.4.3 69 | ipython==7.16.1 70 | ipython-genutils==0.2.0 71 | itsdangerous==1.1.0 72 | jedi==0.18.0 73 | Jinja2==2.10.3 74 | jmespath==0.9.4 75 | joblib==0.14.1 76 | jsonnet==0.14.0 77 | jsonpickle==1.2 78 | jsonschema==3.1.1 79 | jupyter-client==6.1.11 80 | jupyter-core==4.7.0 81 | jupyterlab-pygments==0.1.2 82 | Keras-Applications==1.0.8 83 | Keras-Preprocessing==1.1.0 84 | kiwisolver==1.1.0 85 | librosa==0.8.0 86 | llvmlite==0.31.0 87 | Markdown==3.1.1 88 | MarkupSafe==1.1.1 89 | matplotlib==3.1.2 90 | mir-eval==0.6 91 | mistune==0.8.4 92 | mkl-fft==1.2.0 93 | mkl-random==1.1.1 94 | mkl-service==2.3.0 95 | mock==3.0.5 96 | more-itertools==7.2.0 97 | moviepy==1.0.3 98 | multidict==5.1.0 99 | murmurhash==1.0.2 100 | musdb==0.4.0 101 | museval==0.4.0 102 | nbclient==0.5.1 103 | nbconvert==6.0.7 104 | nbformat==5.1.2 105 | nest-asyncio==1.5.1 106 | networkx==2.1 107 | nltk==3.4.5 108 | notebook==6.2.0 109 | numba==0.48.0 110 | numpy==1.19.5 111 | numpy-quaternion==2020.11.2.17.0.49 112 | numpydoc==0.9.1 113 | oauthlib==3.1.0 114 | olefile==0.46 115 | opencv-python==4.5.1.48 116 | opt-einsum==3.1.0 117 | overrides==2.5 118 | packaging==20.8 119 | pandas==1.1.5 120 | pandocfilters==1.4.3 121 | parsimonious==0.8.1 122 | parso==0.8.1 123 | pexpect==4.8.0 124 | pickleshare==0.7.5 125 | Pillow==8.1.0 126 | plac==0.9.6 127 | pluggy==0.13.1 128 | pooch==1.3.0 129 | preshed==2.0.1 130 | proglog==0.1.9 131 | prometheus-client==0.9.0 132 | prompt-toolkit==3.0.14 133 | protobuf==3.11.2 134 | psutil==5.6.5 135 | 
ptyprocess==0.7.0 136 | py==1.8.0 137 | pyaml==20.4.0 138 | pyasn1==0.4.8 139 | pyasn1-modules==0.2.8 140 | pybind11==2.6.2 141 | pycparser==2.20 142 | pyglet==1.5.14 143 | Pygments==2.4.2 144 | PyJWT==1.7.1 145 | pyOpenSSL==20.0.1 146 | pyparsing==2.4.7 147 | pyroomacoustics==0.4.2 148 | pyrsistent==0.15.5 149 | PySocks==1.7.1 150 | pytest==5.3.0 151 | python-dateutil==2.8.1 152 | pytorch-pretrained-bert==0.6.2 153 | pytorch-transformers==1.1.0 154 | pytz==2019.3 155 | PyWavelets==1.1.1 156 | PyYAML==5.4.1 157 | pyzmq==22.0.0 158 | ray==0.7.6 159 | redis==3.3.11 160 | regex==2019.11.1 161 | requests==2.25.1 162 | requests-oauthlib==1.3.0 163 | resampy==0.2.2 164 | responses==0.10.6 165 | rsa==4.0 166 | s3transfer==0.2.1 167 | scikit-build==0.11.1 168 | scikit-image==0.16.2 169 | scikit-learn==0.22.2.post1 170 | scipy==1.5.4 171 | seaborn==0.11.2 172 | Send2Trash==1.5.0 173 | sentencepiece==0.1.85 174 | simplejson==3.17.2 175 | six==1.15.0 176 | sklearn==0.0 177 | snowballstemmer==2.0.0 178 | SoundFile==0.10.3.post1 179 | soupsieve==2.2.1 180 | spacy==2.1.9 181 | Sphinx==2.2.1 182 | sphinxcontrib-applehelp==1.0.1 183 | sphinxcontrib-devhelp==1.0.1 184 | sphinxcontrib-htmlhelp==1.0.2 185 | sphinxcontrib-jsmath==1.0.1 186 | sphinxcontrib-qthelp==1.0.2 187 | sphinxcontrib-serializinghtml==1.1.3 188 | sqlparse==0.3.0 189 | srsly==0.2.0 190 | stempeg==0.2.3 191 | tensorboard==2.4.0 192 | tensorboard-plugin-wit==1.6.0 193 | tensorflow==2.1.0 194 | tensorflow-estimator==2.1.0 195 | termcolor==1.1.0 196 | terminado==0.9.2 197 | testpath==0.4.4 198 | thinc==7.0.8 199 | torch==1.4.0 200 | torchaudio==0.4.0 201 | torchsummary==1.5.1 202 | torchvision==0.5.0 203 | tornado==6.1 204 | tqdm==4.56.0 205 | traitlets==4.3.3 206 | Unidecode==1.1.1 207 | urllib3==1.26.3 208 | wasabi==0.4.0 209 | wcwidth==0.1.7 210 | webencodings==0.5.1 211 | Werkzeug==0.16.0 212 | word2number==1.1 213 | wrapt==1.12.1 214 | yacs==0.1.8 215 | yarl==1.6.3 216 | zipp==0.6.0 217 | -------------------------------------------------------------------------------- /audio_separation/rl/models/rnn_state_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class RNNStateEncoder(nn.Module): 6 | def __init__( 7 | self, 8 | input_size: int, 9 | hidden_size: int, 10 | num_layers: int = 1, 11 | rnn_type: str = "GRU", 12 | ): 13 | r"""An RNN for encoding the state in RL. 14 | 15 | Supports masking the hidden state during various timesteps in the forward lass 16 | 17 | Args: 18 | input_size: The input size of the RNN 19 | hidden_size: The hidden size 20 | num_layers: The number of recurrent layers 21 | rnn_type: The RNN cell type. 
Must be GRU or LSTM 22 | """ 23 | 24 | super().__init__() 25 | self._num_recurrent_layers = num_layers 26 | self._rnn_type = rnn_type 27 | 28 | self.rnn = getattr(nn, rnn_type)( 29 | input_size=input_size, 30 | hidden_size=hidden_size, 31 | num_layers=num_layers, 32 | ) 33 | 34 | self.layer_init() 35 | 36 | def layer_init(self): 37 | for name, param in self.rnn.named_parameters(): 38 | if "weight" in name: 39 | nn.init.orthogonal_(param) 40 | elif "bias" in name: 41 | nn.init.constant_(param, 0) 42 | 43 | @property 44 | def num_recurrent_layers(self): 45 | return self._num_recurrent_layers * ( 46 | 2 if "LSTM" in self._rnn_type else 1 47 | ) 48 | 49 | def _pack_hidden(self, hidden_states): 50 | if "LSTM" in self._rnn_type: 51 | hidden_states = torch.cat( 52 | [hidden_states[0], hidden_states[1]], dim=0 53 | ) 54 | 55 | return hidden_states 56 | 57 | def _unpack_hidden(self, hidden_states): 58 | if "LSTM" in self._rnn_type: 59 | hidden_states = ( 60 | hidden_states[0 : self._num_recurrent_layers], 61 | hidden_states[self._num_recurrent_layers :], 62 | ) 63 | 64 | return hidden_states 65 | 66 | def _mask_hidden(self, hidden_states, masks): 67 | if isinstance(hidden_states, tuple): 68 | hidden_states = tuple(v * masks for v in hidden_states) 69 | else: 70 | hidden_states = masks * hidden_states 71 | 72 | return hidden_states 73 | 74 | def single_forward(self, x, hidden_states, masks): 75 | r"""Forward for a non-sequence input 76 | """ 77 | hidden_states = self._unpack_hidden(hidden_states) 78 | x, hidden_states = self.rnn( 79 | x.unsqueeze(0), 80 | self._mask_hidden(hidden_states, masks.unsqueeze(0)), 81 | ) 82 | x = x.squeeze(0) 83 | hidden_states = self._pack_hidden(hidden_states) 84 | return x, hidden_states 85 | 86 | def seq_forward(self, x, hidden_states, masks): 87 | r"""Forward for a sequence of length T 88 | 89 | Args: 90 | x: (T, N, -1) Tensor that has been flattened to (T * N, -1) 91 | hidden_states: The starting hidden state. 92 | masks: The masks to be applied to hidden state at every timestep. 93 | A (T, N) tensor flatten to (T * N) 94 | """ 95 | # x is a (T, N, -1) tensor flattened to (T * N, -1) 96 | n = hidden_states.size(1) 97 | t = int(x.size(0) / n) 98 | 99 | # unflatten 100 | x = x.view(t, n, x.size(1)) 101 | masks = masks.view(t, n) 102 | 103 | # steps in sequence which have zero for any agent. Assume t=0 has 104 | # a zero in it. 
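        # masks hold 0 at the first step of a new episode, so these indices mark where the
        # hidden state has to be zeroed; splitting the sequence at them lets every
        # uninterrupted chunk be passed through the RNN in a single call below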
105 | has_zeros = (masks[1:] == 0.0).any(dim=-1).nonzero().squeeze().cpu() 106 | 107 | # +1 to correct the masks[1:] 108 | if has_zeros.dim() == 0: 109 | has_zeros = [has_zeros.item() + 1] # handle scalar 110 | else: 111 | has_zeros = (has_zeros + 1).numpy().tolist() 112 | 113 | # add t=0 and t=T to the list 114 | has_zeros = [0] + has_zeros + [t] 115 | 116 | hidden_states = self._unpack_hidden(hidden_states) 117 | outputs = [] 118 | for i in range(len(has_zeros) - 1): 119 | # process steps that don't have any zeros in masks together 120 | start_idx = has_zeros[i] 121 | end_idx = has_zeros[i + 1] 122 | 123 | rnn_scores, hidden_states = self.rnn( 124 | x[start_idx:end_idx], 125 | self._mask_hidden( 126 | hidden_states, masks[start_idx].view(1, -1, 1) 127 | ), 128 | ) 129 | 130 | outputs.append(rnn_scores) 131 | 132 | # x is a (T, N, -1) tensor 133 | x = torch.cat(outputs, dim=0) 134 | x = x.view(t * n, -1) # flatten 135 | 136 | hidden_states = self._pack_hidden(hidden_states) 137 | return x, hidden_states 138 | 139 | def forward(self, x, hidden_states, masks): 140 | if x.size(0) == hidden_states.size(1): 141 | return self.single_forward(x, hidden_states, masks) 142 | else: 143 | return self.seq_forward(x, hidden_states, masks) 144 | -------------------------------------------------------------------------------- /scripts/search_for_checkpoint_thru_validation/find_bestCkpt_lowestValSTFTLoss.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "white-cylinder", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import os" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "entitled-lighter", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "### replace with necessary path to val dir\n", 21 | "SOURCE_DIR = \"../../runs/val/nearTarget/nearTarget_38MSteps_8G14P/\"\n", 22 | "assert os.path.isdir(SOURCE_DIR)\n", 23 | "\n", 24 | "CKPT_DIR = os.path.join(SOURCE_DIR, \"data\")\n", 25 | "assert os.path.isdir(CKPT_DIR)\n", 26 | "\n", 27 | "TRAIN_LOG_PATH = os.path.join(SOURCE_DIR, \"train.log\")\n", 28 | "assert os.path.exists(TRAIN_LOG_PATH)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": null, 34 | "id": "operating-burden", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "\"\"\"mono and monoFromMem losses\"\"\"\n", 39 | "ckpt2mono_lastStep_dct = {}\n", 40 | "ckpt2monoFromMem_lastStep_dct = {}\n", 41 | "ckpt_number = -1\n", 42 | "look_for_newCkpt = True\n", 43 | "with open(TRAIN_LOG_PATH, \"r\") as fo:\n", 44 | " for line in fo:\n", 45 | " if look_for_newCkpt and (len(line.split(\" \")) >= 3) and (line.split(\" \")[2] == \"Mono\") and (line.split(\" \")[7] == \"last\") and (line.split(\" \")[9] == \"---\"):\n", 46 | " ckpt2mono_lastStep_dct[ckpt_number] = [float(line.split(\" \")[11][:-1]), float(line.split(\" \")[13][:-1])]\n", 47 | " \n", 48 | " if look_for_newCkpt and (len(line.split(\" \")) >= 3) and (line.split(\" \")[2] == \"MonoFromMem\") and (line.split(\" \")[7] == \"last\") and (line.split(\" \")[9] == \"---\"):\n", 49 | " look_for_newCkpt = False\n", 50 | " ckpt2monoFromMem_lastStep_dct[ckpt_number] = [float(line.split(\" \")[11][:-1]), float(line.split(\" \")[13][:-1])]\n", 51 | " \n", 52 | " if (not look_for_newCkpt) and (len(line.split(\" \")) >= 4) and (line.split(\" \")[2] == \"=======current_ckpt:\"):\n", 53 | " look_for_newCkpt = True\n", 54 | " ckpt_number = 
int(line.split(\" \")[-1].split(\"=======\")[0].split(\"/\")[-1].split(\".\")[1]) " 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "passing-louis", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "ckpt_numbers = []\n", 65 | "missing_key = -1\n", 66 | "for (dirpath, dirnames, filenames) in os.walk(CKPT_DIR):\n", 67 | " ckpt_numbers.extend(filenames)\n", 68 | " break\n", 69 | " \n", 70 | "for i in range(len(ckpt_numbers)):\n", 71 | " ckpt_numbers[i] = int(ckpt_numbers[i].split(\".\")[1])\n", 72 | " \n", 73 | "for key in ckpt_numbers:\n", 74 | " if key not in list(ckpt2mono_lastStep_dct.keys()):\n", 75 | " missing_key = key\n", 76 | " \n", 77 | "ckpt2mono_lastStep_dct_final = {}\n", 78 | "ckpt2monoFromMem_lastStep_dct_final = {}\n", 79 | "for key in ckpt2mono_lastStep_dct:\n", 80 | " if key == -1:\n", 81 | " ckpt2mono_lastStep_dct_final[missing_key] = ckpt2mono_lastStep_dct[key]\n", 82 | " ckpt2monoFromMem_lastStep_dct_final[missing_key] = ckpt2monoFromMem_lastStep_dct[key]\n", 83 | " else:\n", 84 | " ckpt2mono_lastStep_dct_final[key] = ckpt2mono_lastStep_dct[key]\n", 85 | " ckpt2monoFromMem_lastStep_dct_final[key] = ckpt2monoFromMem_lastStep_dct[key]" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "parallel-zambia", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "min_mono_lastStep = [float(\"inf\"), 0]\n", 96 | "min_monoFromMem_lastStep = [float(\"inf\"), 0]\n", 97 | "\n", 98 | "for key, value in ckpt2mono_lastStep_dct_final.items():\n", 99 | " if value[0] < min_mono_lastStep[0]:\n", 100 | " min_mono_lastStep = value\n", 101 | " bestCkpt_mono_lastStep = \"ckpt.\" + str(key) + \".pth\"\n", 102 | " \n", 103 | " if ckpt2monoFromMem_lastStep_dct_final[key][0] < min_monoFromMem_lastStep[0]:\n", 104 | " min_monoFromMem_lastStep = ckpt2monoFromMem_lastStep_dct_final[key]\n", 105 | " bestCkpt_monoFromMem_lastStep = \"ckpt.\" + str(key) + \".pth\" \n", 106 | " \n", 107 | "print(\"best validation checkpoint: \", bestCkpt_mono_lastStep,\n", 108 | " \", mono_lastStep: mean -- {}, std -- {}\".format(min_mono_lastStep[0],\n", 109 | " min_mono_lastStep[1]))\n", 110 | "print(\"best validation checkpoint: \", bestCkpt_monoFromMem_lastStep,\n", 111 | " \", monoFromMem_lastStep: mean -- {}, std -- {}\".format(min_monoFromMem_lastStep[0],\n", 112 | " min_monoFromMem_lastStep[1]))" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "Python 3", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.6.12" 133 | } 134 | }, 135 | "nbformat": 4, 136 | "nbformat_minor": 5 137 | } 138 | -------------------------------------------------------------------------------- /audio_separation/rl/models/audio_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from audio_separation.common.utils import Flatten 6 | 7 | 8 | class AudioCNN(nn.Module): 9 | r"""A Simple 3-Conv CNN followed by a fully connected layer for high res spec. 
10 | 11 | Takes in separated audio outputs (bin/monos) and produces an embedding 12 | 13 | Args: 14 | observation_space: The observation_space of the agent 15 | output_size: The size of the embedding vector 16 | encode_monoNmonoFromMem: creates CNN for encoding predicted monaurals (concatenation of passive and acoustic 17 | memory outputs) if set to True 18 | """ 19 | def __init__(self, observation_space, output_size, encode_monoNmonoFromMem=False,): 20 | super().__init__() 21 | self.encode_monoNmonoFromMem = encode_monoNmonoFromMem 22 | # originally 2 channels for binaural or concatenation of monos but spec. sliced up into 16 chunks along the frequency 23 | # dimension (this makes the high-res. specs. easier to deal with) 24 | self._slice_factor = 16 25 | self._n_input_audio = 2 * self._slice_factor 26 | 27 | # kernel size for different CNN layers 28 | self._cnn_layers_kernel_size = [(8, 8), (4, 4), (2, 2)] 29 | 30 | # strides for different CNN layers 31 | self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)] 32 | 33 | cnn_dims = np.array( 34 | [observation_space.spaces["mixed_bin_audio_mag"].shape[0] // 16, 35 | observation_space.spaces["mixed_bin_audio_mag"].shape[1]], 36 | dtype=np.float32 37 | ) 38 | 39 | for kernel_size, stride in zip( 40 | self._cnn_layers_kernel_size, self._cnn_layers_stride 41 | ): 42 | cnn_dims = self._conv_output_dim( 43 | dimension=cnn_dims, 44 | padding=np.array([0, 0], dtype=np.float32), 45 | dilation=np.array([1, 1], dtype=np.float32), 46 | kernel_size=np.array(kernel_size, dtype=np.float32), 47 | stride=np.array(stride, dtype=np.float32), 48 | ) 49 | 50 | self.cnn = nn.Sequential( 51 | nn.Conv2d( 52 | in_channels=self._n_input_audio, 53 | out_channels=32, 54 | kernel_size=self._cnn_layers_kernel_size[0], 55 | stride=self._cnn_layers_stride[0], 56 | ), 57 | nn.ReLU(True), 58 | nn.Conv2d( 59 | in_channels=32, 60 | out_channels=64, 61 | kernel_size=self._cnn_layers_kernel_size[1], 62 | stride=self._cnn_layers_stride[1], 63 | ), 64 | nn.ReLU(True), 65 | nn.Conv2d( 66 | in_channels=64, 67 | out_channels=32, 68 | kernel_size=self._cnn_layers_kernel_size[2], 69 | stride=self._cnn_layers_stride[2], 70 | ), 71 | nn.ReLU(True), 72 | Flatten(), 73 | nn.Linear(32 * cnn_dims[0] * cnn_dims[1], output_size), 74 | nn.ReLU(True), 75 | ) 76 | 77 | self.layer_init() 78 | 79 | def _conv_output_dim( 80 | self, dimension, padding, dilation, kernel_size, stride 81 | ): 82 | r"""Calculates the output height and width based on the input 83 | height and width to the convolution layer. 
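        Restating the computation performed below, for each spatial dimension i:
            out[i] = floor((in[i] + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) - 1) / stride[i] + 1)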
84 | 85 | ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d 86 | """ 87 | assert len(dimension) == 2 88 | out_dimension = [] 89 | for i in range(len(dimension)): 90 | out_dimension.append( 91 | int( 92 | np.floor( 93 | ( 94 | ( 95 | dimension[i] 96 | + 2 * padding[i] 97 | - dilation[i] * (kernel_size[i] - 1) 98 | - 1 99 | ) 100 | / stride[i] 101 | ) 102 | + 1 103 | ) 104 | ) 105 | ) 106 | return tuple(out_dimension) 107 | 108 | def layer_init(self): 109 | for layer in self.cnn: 110 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 111 | nn.init.kaiming_normal_( 112 | layer.weight, nn.init.calculate_gain("relu") 113 | ) 114 | if layer.bias is not None: 115 | nn.init.constant_(layer.bias, val=0) 116 | 117 | def forward(self, observations, pred_binSepMasks=None, pred_monoNmonoFromMem=None,): 118 | cnn_input = [] 119 | 120 | if self.encode_monoNmonoFromMem: 121 | assert pred_monoNmonoFromMem is not None 122 | x = torch.log1p(torch.clamp(pred_monoNmonoFromMem, min=0)) 123 | else: 124 | assert pred_binSepMasks is not None 125 | x = observations["mixed_bin_audio_mag"] 126 | x = torch.exp(x) - 1 127 | x = x * pred_binSepMasks 128 | x = torch.log1p(torch.clamp(x, min=0)) 129 | 130 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 131 | x = x.permute(0, 3, 1, 2) 132 | 133 | # slice along freq dimension into 16 chunks 134 | x = x.view(x.size(0), x.size(1), self._slice_factor, -1, x.size(3)) 135 | x = x.reshape(x.size(0), -1, x.size(3), x.size(4)) 136 | 137 | cnn_input.append(x) 138 | cnn_input = torch.cat(cnn_input, dim=1) 139 | 140 | return self.cnn(cnn_input) 141 | -------------------------------------------------------------------------------- /audio_separation/rl/models/visual_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from audio_separation.common.utils import Flatten 6 | 7 | 8 | class VisualCNN(nn.Module): 9 | r"""A Simple 3-Conv CNN followed by a fully connected layer for low res spec. 
10 | 11 | Takes in observations and produces an embedding of the rgb and/or depth components 12 | 13 | Args: 14 | observation_space: The observation_space of the agent 15 | output_size: The size of the embedding vector 16 | extra_rgb: RGB isn't used for encoding when True 17 | extra_depth: depth isn't used for encoding when True 18 | """ 19 | 20 | def __init__(self, observation_space, output_size, extra_rgb, extra_depth): 21 | super().__init__() 22 | if "rgb" in observation_space.spaces and (not extra_rgb): 23 | self._n_input_rgb = observation_space.spaces["rgb"].shape[2] 24 | else: 25 | self._n_input_rgb = 0 26 | 27 | if "depth" in observation_space.spaces and (not extra_depth): 28 | self._n_input_depth = observation_space.spaces["depth"].shape[2] 29 | else: 30 | self._n_input_depth = 0 31 | 32 | # kernel size for different CNN layers 33 | self._cnn_layers_kernel_size = [(8, 8), (4, 4), (3, 3)] 34 | 35 | # strides for different CNN layers 36 | self._cnn_layers_stride = [(4, 4), (2, 2), (1, 1)] 37 | 38 | if self._n_input_rgb > 0: 39 | # cnn_dims = np.array( 40 | # observation_space.spaces["rgb"].shape[:2], dtype=np.float32 41 | # ) 42 | # hardcoding needed since rgb is of size (720, 720, 3) when rendering videos and this call throws an error 43 | cnn_dims = np.array( 44 | [128, 128], dtype=np.float32 45 | ) 46 | elif self._n_input_depth > 0: 47 | cnn_dims = np.array( 48 | observation_space.spaces["depth"].shape[:2], dtype=np.float32 49 | ) 50 | 51 | if self.is_blind: 52 | self.cnn = nn.Sequential() 53 | else: 54 | for kernel_size, stride in zip( 55 | self._cnn_layers_kernel_size, self._cnn_layers_stride 56 | ): 57 | cnn_dims = self._conv_output_dim( 58 | dimension=cnn_dims, 59 | padding=np.array([0, 0], dtype=np.float32), 60 | dilation=np.array([1, 1], dtype=np.float32), 61 | kernel_size=np.array(kernel_size, dtype=np.float32), 62 | stride=np.array(stride, dtype=np.float32), 63 | ) 64 | 65 | self.cnn = nn.Sequential( 66 | nn.Conv2d( 67 | in_channels=self._n_input_rgb + self._n_input_depth, 68 | out_channels=32, 69 | kernel_size=self._cnn_layers_kernel_size[0], 70 | stride=self._cnn_layers_stride[0], 71 | ), 72 | nn.ReLU(True), 73 | nn.Conv2d( 74 | in_channels=32, 75 | out_channels=64, 76 | kernel_size=self._cnn_layers_kernel_size[1], 77 | stride=self._cnn_layers_stride[1], 78 | ), 79 | nn.ReLU(True), 80 | nn.Conv2d( 81 | in_channels=64, 82 | out_channels=32, 83 | kernel_size=self._cnn_layers_kernel_size[2], 84 | stride=self._cnn_layers_stride[2], 85 | ), 86 | Flatten(), 87 | nn.Linear(32 * cnn_dims[0] * cnn_dims[1], output_size), 88 | nn.ReLU(True), 89 | ) 90 | 91 | self.layer_init() 92 | 93 | def _conv_output_dim( 94 | self, dimension, padding, dilation, kernel_size, stride 95 | ): 96 | r"""Calculates the output height and width based on the input 97 | height and width to the convolution layer. 
98 | 99 | ref: https://pytorch.org/docs/master/nn.html#torch.nn.Conv2d 100 | """ 101 | assert len(dimension) == 2 102 | out_dimension = [] 103 | for i in range(len(dimension)): 104 | out_dimension.append( 105 | int( 106 | np.floor( 107 | ( 108 | ( 109 | dimension[i] 110 | + 2 * padding[i] 111 | - dilation[i] * (kernel_size[i] - 1) 112 | - 1 113 | ) 114 | / stride[i] 115 | ) 116 | + 1 117 | ) 118 | ) 119 | ) 120 | return tuple(out_dimension) 121 | 122 | def layer_init(self): 123 | for layer in self.cnn: 124 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 125 | nn.init.kaiming_normal_( 126 | layer.weight, nn.init.calculate_gain("relu") 127 | ) 128 | if layer.bias is not None: 129 | nn.init.constant_(layer.bias, val=0) 130 | 131 | @property 132 | def is_blind(self): 133 | return self._n_input_rgb + self._n_input_depth == 0 134 | 135 | def forward(self, observations): 136 | cnn_input = [] 137 | if self._n_input_rgb > 0: 138 | rgb_observations = observations["rgb"] 139 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 140 | rgb_observations = rgb_observations.permute(0, 3, 1, 2) 141 | rgb_observations = rgb_observations / 255.0 # normalize RGB 142 | cnn_input.append(rgb_observations) 143 | 144 | if self._n_input_depth > 0: 145 | depth_observations = observations["depth"] 146 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 147 | depth_observations = depth_observations.permute(0, 3, 1, 2) 148 | cnn_input.append(depth_observations) 149 | 150 | cnn_input = torch.cat(cnn_input, dim=1) 151 | 152 | return self.cnn(cnn_input) 153 | 154 | -------------------------------------------------------------------------------- /audio_separation/rl/ppo/ddppo_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shlex 3 | import signal 4 | import subprocess 5 | import threading 6 | from os import path as osp 7 | from typing import Any, Optional, Tuple 8 | 9 | import ifcfg 10 | import torch 11 | from torch import distributed as distrib 12 | 13 | from habitat import logger 14 | 15 | EXIT = threading.Event() 16 | EXIT.clear() 17 | REQUEUE = threading.Event() 18 | REQUEUE.clear() 19 | 20 | 21 | # Default port to initialized the TCP store on 22 | DEFAULT_PORT = 7738 23 | # Default address of world rank 0 24 | DEFAULT_MASTER_ADDR = "127.0.0.2" 25 | 26 | SLURM_JOBID = os.environ.get("SLURM_JOB_ID", None) 27 | INTERRUPTED_STATE_FILE = osp.join( 28 | os.environ["HOME"], ".interrupted_states", f"{SLURM_JOBID}.pth" 29 | ) 30 | 31 | 32 | def _clean_exit_handler(signum, frame): 33 | EXIT.set() 34 | print("Exiting cleanly", flush=True) 35 | 36 | 37 | def _requeue_handler(signal, frame): 38 | print("Got signal to requeue", flush=True) 39 | EXIT.set() 40 | REQUEUE.set() 41 | 42 | 43 | def add_signal_handlers() -> None: 44 | signal.signal(signal.SIGINT, _clean_exit_handler) 45 | signal.signal(signal.SIGTERM, _clean_exit_handler) 46 | 47 | # SIGUSR2 can be sent to all processes to have them cleanup 48 | # and exit nicely. This is nice to use with SLURM as scancel 49 | # sets a 30 second timer for the job to exit, and it can take more than 50 | # 30 seconds for the job to cleanup and exit nicely. When using NCCL, 51 | # forcing the job to exit without cleaning up can be bad. 52 | # scancel --signal SIGUSR2 will set no such timer and will give 53 | # the job ample time to cleanup and exit. 
54 | signal.signal(signal.SIGUSR2, _clean_exit_handler) 55 | 56 | signal.signal(signal.SIGUSR1, _requeue_handler) 57 | 58 | 59 | def save_interrupted_state(state: Any, filename: str = None): 60 | r"""Saves the interrupted job state to the specified filename. 61 | This is useful when working with preemptable job partitions. 62 | 63 | This method will do nothing if SLURM is not currently being used and the filename is the default 64 | 65 | :param state: The state to save 66 | :param filename: The filename. Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth" 67 | """ 68 | if SLURM_JOBID is None and filename is None: 69 | logger.warn("SLURM_JOBID is none, not saving interrupted state") 70 | return 71 | 72 | if filename is None: 73 | filename = INTERRUPTED_STATE_FILE 74 | 75 | torch.save(state, filename) 76 | 77 | 78 | def load_interrupted_state(filename: str = None) -> Optional[Any]: 79 | r"""Loads the saved interrupted state 80 | 81 | :param filename: The filename of the saved state. 82 | Defaults to "${HOME}/.interrupted_states/${SLURM_JOBID}.pth" 83 | 84 | :return: The saved state if the file exists, else none 85 | """ 86 | if SLURM_JOBID is None and filename is None: 87 | return None 88 | 89 | if filename is None: 90 | filename = INTERRUPTED_STATE_FILE 91 | 92 | if not osp.exists(filename): 93 | return None 94 | 95 | return torch.load(filename, map_location="cpu") 96 | 97 | 98 | def requeue_job(): 99 | r"""Requeues the job by calling ``scontrol requeue ${SLURM_JOBID}``""" 100 | if SLURM_JOBID is None: 101 | return 102 | 103 | if not REQUEUE.is_set(): 104 | return 105 | 106 | distrib.barrier() 107 | 108 | if distrib.get_rank() == 0: 109 | logger.info(f"Requeueing job {SLURM_JOBID}") 110 | subprocess.check_call(shlex.split(f"scontrol requeue {SLURM_JOBID}")) 111 | 112 | 113 | def get_ifname() -> str: 114 | return ifcfg.default_interface()["device"] 115 | 116 | 117 | def init_distrib_slurm( 118 | backend: str = "nccl", master_port=8738, master_addr="127.0.0.1", 119 | ) -> Tuple[int, torch.distributed.TCPStore]: # type: ignore 120 | r"""Initializes torch.distributed by parsing environment variables set 121 | by SLURM when ``srun`` is used or by parsing environment variables set 122 | by torch.distributed.launch 123 | 124 | :param backend: Which torch.distributed backend to use 125 | 126 | :returns: Tuple of the local_rank (aka which GPU to use for this process) 127 | and the TCPStore used for the rendezvous 128 | """ 129 | assert ( 130 | torch.distributed.is_available() 131 | ), "torch.distributed must be available" 132 | 133 | if "GLOO_SOCKET_IFNAME" not in os.environ: 134 | os.environ["GLOO_SOCKET_IFNAME"] = get_ifname() 135 | 136 | if "NCCL_SOCKET_IFNAME" not in os.environ: 137 | os.environ["NCCL_SOCKET_IFNAME"] = get_ifname() 138 | 139 | master_port = int(master_port) 140 | 141 | # Check to see if we should parse from torch.distributed.launch 142 | if os.environ.get("LOCAL_RANK", None) is not None: 143 | local_rank = int(os.environ["LOCAL_RANK"]) 144 | world_rank = int(os.environ["RANK"]) 145 | world_size = int(os.environ["WORLD_SIZE"]) 146 | # Else parse from SLURM is using SLURM 147 | elif os.environ.get("SLURM_JOBID", None) is not None: 148 | local_rank = int(os.environ["SLURM_LOCALID"]) 149 | world_rank = int(os.environ["SLURM_PROCID"]) 150 | world_size = int(os.environ["SLURM_NTASKS"]) 151 | # Otherwise setup for just 1 process, this is nice for testing 152 | else: 153 | local_rank = 0 154 | world_rank = 0 155 | world_size = 1 156 | 157 | tcp_store = distrib.TCPStore( # type: 
ignore 158 | master_addr, master_port, world_size, world_rank == 0 159 | ) 160 | 161 | distrib.init_process_group( 162 | backend, store=tcp_store, rank=world_rank, world_size=world_size 163 | ) 164 | 165 | return local_rank, tcp_store 166 | 167 | 168 | def distributed_mean_and_var( 169 | values: torch.Tensor, 170 | ): 171 | r"""Computes the mean and variances of a tensor over multiple workers. 172 | This method is equivalent to first collecting all versions of values and 173 | then computing the mean and variance locally over that 174 | :param values: (*,) shaped tensors to compute mean and variance over. Assumed 175 | to be solely the workers local copy of this tensor, 176 | the resultant mean and variance will be computed 177 | over _all_ workers version of this tensor. 178 | """ 179 | assert distrib.is_initialized(), "Distributed must be initialized" 180 | 181 | world_size = distrib.get_world_size() 182 | mean = values.mean() 183 | distrib.all_reduce(mean) 184 | mean /= world_size 185 | 186 | sq_diff = (values - mean).pow(2).mean() 187 | distrib.all_reduce(sq_diff) 188 | var = sq_diff / world_size 189 | 190 | return mean, var 191 | -------------------------------------------------------------------------------- /audio_separation/rl/models/separator_cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | 4 | 5 | def unet_conv(input_nc, output_nc, kernel_size=(4, 4), norm_layer=nn.BatchNorm2d, padding=(1, 1), stride=(2, 2), bias=False): 6 | downconv = nn.Conv2d(input_nc, output_nc, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias,) 7 | downrelu = nn.LeakyReLU(0.2, True) 8 | if norm_layer is not None: 9 | downnorm = norm_layer(output_nc) 10 | return nn.Sequential(*[downconv, downnorm, downrelu]) 11 | else: 12 | return nn.Sequential(*[downconv, downrelu]) 13 | 14 | 15 | def unet_upconv(input_nc, output_nc, kernel_size=(4, 4), outermost=False, norm_layer=nn.BatchNorm2d, stride=(2, 2), 16 | padding=(1, 1), output_padding=(0, 0), bias=False,): 17 | upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=kernel_size, stride=stride, padding=padding, 18 | output_padding=output_padding, bias=bias) 19 | uprelu = nn.ReLU(True) 20 | upnorm = norm_layer(output_nc) 21 | if not outermost: 22 | return nn.Sequential(*[upconv, upnorm, uprelu]) 23 | else: 24 | return nn.Sequential(*[upconv, nn.Sigmoid()]) 25 | 26 | 27 | class PassiveSepEncCNN(nn.Module): 28 | r"""A U-net encoder for passive separation. 29 | 30 | Takes in mixed binaural audio or predicted clean binaural and produces an clean binaural or clean monaural embeddings 31 | and skip-connection feature list, respectively. 32 | 33 | Args: 34 | convert_bin2mono: creates encoder for converting binaural to monaural if set to True 35 | """ 36 | def __init__(self, convert_bin2mono=False): 37 | super().__init__() 38 | self._convert_bin2mono = convert_bin2mono 39 | # originally 2 channels for binaural or concatenation of monos but spec. sliced up into 16 chunks along the frequency 40 | # dimension (this makes the high-res. specs. 
easier to deal with) 41 | self._slice_factor = 16 42 | self._n_input_audio = 2 * self._slice_factor 43 | if not convert_bin2mono: 44 | self._n_input_audio += 1 45 | 46 | self.cnn = nn.Sequential( 47 | unet_conv(self._n_input_audio, 64,), 48 | unet_conv(64, 64 * 2,), 49 | unet_conv(64 * 2, 64 * 4,), 50 | unet_conv(64 * 4, 64 * 8,), 51 | unet_conv(64 * 8, 64 * 8,) 52 | ) 53 | 54 | self.layer_init() 55 | 56 | def layer_init(self): 57 | for module in self.cnn: 58 | for layer in module: 59 | if isinstance(layer, (nn.Conv2d, nn.Linear)): 60 | nn.init.kaiming_normal_( 61 | layer.weight, nn.init.calculate_gain("leaky_relu", 0.2) 62 | ) 63 | if layer.bias is not None: 64 | nn.init.constant_(layer.bias, val=0) 65 | elif isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)): 66 | if layer.affine: 67 | layer.weight.data.fill_(1) 68 | layer.bias.data.zero_() 69 | 70 | def forward(self, observations, mixed_audio=None,): 71 | cnn_input = [] 72 | 73 | if self._convert_bin2mono: 74 | assert mixed_audio is not None 75 | # observations has pred_binSepMasks 76 | x = observations 77 | mixed_audio = torch.exp(mixed_audio) - 1 78 | x = x * mixed_audio 79 | x = torch.log1p(torch.clamp(x, min=0)) 80 | else: 81 | # observations has all sensor readings 82 | x = observations["mixed_bin_audio_mag"] 83 | 84 | # permute tensor to dimension [BATCH x CHANNEL x HEIGHT X WIDTH] 85 | x = x.permute(0, 3, 1, 2) 86 | 87 | # slice along freq dimension into 16 chunks 88 | x = x.view(x.size(0), x.size(1), self._slice_factor, -1, x.size(3)) 89 | x = x.reshape(x.size(0), -1, x.size(3), x.size(4)) 90 | 91 | # append target class for passive bin. extraction 92 | if not self._convert_bin2mono: 93 | target_class = observations["target_class"] 94 | # adding 1 to the target_class sensor value (probably not necessary) 95 | target_class = target_class.unsqueeze(1).unsqueeze(1).repeat(1, 1, x.size(2), x.size(3)).float() + 1 96 | x = torch.cat((x, target_class), dim=1) 97 | 98 | cnn_input.append(x) 99 | cnn_input = torch.cat(cnn_input, dim=1) 100 | 101 | lst_skip_feats = [] 102 | out = cnn_input 103 | for module in self.cnn: 104 | out = module(out) 105 | lst_skip_feats.append(out) 106 | # return the first N - 1 features (last feature is the bottleneck feature) and invert for convenience during 107 | # upsampling forward pass 108 | return out.reshape(cnn_input.size(0), -1), lst_skip_feats[:-1][::-1] 109 | 110 | 111 | class PassiveSepDecCNN(nn.Module): 112 | r"""A U-net decoder for passive separation. 113 | 114 | Takes in feature embeddings and skip-connection feature list and produces an clean binaural or clean monaural, respectively. 115 | 116 | Args: 117 | convert_bin2mono: creates encoder for converting binaural to monaural if set to True 118 | """ 119 | def __init__(self, convert_bin2mono=False,): 120 | super().__init__() 121 | # originally 2 channels for binaural or concatenation of monos but spec. sliced up into 16 chunks along the frequency 122 | # dimension (this makes the high-res. specs. 
easier to deal with) 123 | self._slice_factor = 16 124 | self._n_out_audio = self._slice_factor 125 | if not convert_bin2mono: 126 | self._n_out_audio *= 2 127 | 128 | self.cnn = nn.Sequential( 129 | unet_upconv(64 * 8, 64 * 8), 130 | unet_upconv(64 * 16, 64 * 4,), 131 | unet_upconv(64 * 8, 64 * 2,), 132 | unet_upconv(64 * 4, 64 * 1), 133 | unet_upconv(64 * 2, self._n_out_audio, padding=(1, 1)), 134 | nn.Sequential(nn.Conv2d(self._n_out_audio, self._n_out_audio, kernel_size=(1, 1),)), 135 | ) 136 | 137 | self.layer_init() 138 | 139 | def layer_init(self): 140 | for module in self.cnn: 141 | for layer in module: 142 | if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 143 | nn.init.kaiming_normal_( 144 | layer.weight, nn.init.calculate_gain("relu") 145 | ) 146 | if layer.bias is not None: 147 | nn.init.constant_(layer.bias, val=0) 148 | elif isinstance(layer, (nn.BatchNorm1d, nn.BatchNorm2d)): 149 | if layer.affine: 150 | layer.weight.data.fill_(1) 151 | layer.bias.data.zero_() 152 | 153 | def forward(self, bottleneck_feats, lst_skip_feats): 154 | out = bottleneck_feats.view(bottleneck_feats.size(0), -1, 1, 1) 155 | 156 | for idx, module in enumerate(self.cnn): 157 | if (idx == 0) or (idx == len(self.cnn) - 1): 158 | out = module(out) 159 | else: 160 | skip_feats = lst_skip_feats[idx - 1] 161 | out = module(torch.cat((out, skip_feats), dim=1)) 162 | 163 | # deslice 164 | out = out.view(out.size(0), -1, self._slice_factor, out.size(2), out.size(3)) 165 | out = out.reshape(out.size(0), out.size(1), -1, out.size(4)) 166 | 167 | # permute tensor to dimension [BATCH x HEIGHT X WIDTH x CHANNEL] 168 | out = out.permute(0, 2, 3, 1) 169 | 170 | return out 171 | -------------------------------------------------------------------------------- /audio_separation/common/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from typing import ClassVar, Dict, List 4 | 5 | import torch 6 | 7 | from habitat import Config, logger 8 | from audio_separation.common.tensorboard_utils import TensorboardWriter 9 | from audio_separation.common.utils import poll_checkpoint_folder 10 | 11 | 12 | class BaseTrainer: 13 | r"""Generic trainer class that serves as a base template for more 14 | specific trainer classes like RL trainer, SLAM or imitation learner. 15 | Includes only the most basic functionality. 16 | """ 17 | 18 | supported_tasks: ClassVar[List[str]] 19 | 20 | def train(self) -> None: 21 | raise NotImplementedError 22 | 23 | def eval(self) -> None: 24 | raise NotImplementedError 25 | 26 | def save_checkpoint(self, file_name) -> None: 27 | raise NotImplementedError 28 | 29 | def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict: 30 | raise NotImplementedError 31 | 32 | 33 | class BaseRLTrainer(BaseTrainer): 34 | r"""Base trainer class for RL trainers. Future RL-specific 35 | methods should be hosted here. 
36 | """ 37 | device: torch.device 38 | config: Config 39 | video_option: List[str] 40 | _flush_secs: int 41 | 42 | def __init__(self, config: Config): 43 | super().__init__() 44 | assert config is not None, "needs config file to initialize trainer" 45 | self.config = config 46 | self._flush_secs = 30 47 | 48 | @property 49 | def flush_secs(self): 50 | return self._flush_secs 51 | 52 | @flush_secs.setter 53 | def flush_secs(self, value: int): 54 | self._flush_secs = value 55 | 56 | def train(self) -> None: 57 | raise NotImplementedError 58 | 59 | def eval(self, eval_interval=1, prev_ckpt_ind=-1) -> None: 60 | r"""Main method of trainer evaluation. Calls _eval_checkpoint() that 61 | is specified in Trainer class that inherits from BaseRLTrainer 62 | 63 | Returns: 64 | None 65 | """ 66 | self.device = ( 67 | torch.device("cuda", self.config.TORCH_GPU_ID) 68 | if torch.cuda.is_available() 69 | else torch.device("cpu") 70 | ) 71 | 72 | if "tensorboard" in self.config.VIDEO_OPTION: 73 | assert ( 74 | len(self.config.TENSORBOARD_DIR) > 0 75 | ), "Must specify a tensorboard directory for video display" 76 | if "disk" in self.config.VIDEO_OPTION: 77 | assert ( 78 | len(self.config.VIDEO_DIR) > 0 79 | ), "Must specify a directory for storing videos on disk" 80 | 81 | with TensorboardWriter( 82 | self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs 83 | ) as writer: 84 | if os.path.isfile(self.config.EVAL_CKPT_PATH_DIR): 85 | # evaluate singe checkpoint 86 | self._eval_checkpoint(self.config.EVAL_CKPT_PATH_DIR, writer) 87 | else: 88 | # evaluate multiple checkpoints in order 89 | while True: 90 | current_ckpt = None 91 | while current_ckpt is None: 92 | current_ckpt = poll_checkpoint_folder( 93 | self.config.EVAL_CKPT_PATH_DIR, prev_ckpt_ind, eval_interval 94 | ) 95 | time.sleep(2) # sleep for 2 secs before polling again 96 | logger.info(f"=======current_ckpt: {current_ckpt}=======") 97 | prev_ckpt_ind += eval_interval 98 | self._eval_checkpoint( 99 | checkpoint_path=current_ckpt, 100 | writer=writer, 101 | checkpoint_index=prev_ckpt_ind 102 | ) 103 | 104 | def _setup_eval_config(self, checkpoint_config: Config) -> Config: 105 | r"""Sets up and returns a merged config for evaluation. Config 106 | object saved from checkpoint is merged into config file specified 107 | at evaluation time with the following overwrite priority: 108 | eval_opts > ckpt_opts > eval_cfg > ckpt_cfg 109 | If the saved config is outdated, only the eval config is returned. 110 | 111 | Args: 112 | checkpoint_config: saved config from checkpoint. 113 | 114 | Returns: 115 | Config: merged config for eval. 
116 | """ 117 | 118 | config = self.config.clone() 119 | 120 | ckpt_cmd_opts = checkpoint_config.CMD_TRAILING_OPTS 121 | eval_cmd_opts = config.CMD_TRAILING_OPTS 122 | 123 | try: 124 | config.merge_from_other_cfg(checkpoint_config) 125 | config.merge_from_other_cfg(self.config) 126 | config.merge_from_list(ckpt_cmd_opts) 127 | config.merge_from_list(eval_cmd_opts) 128 | except KeyError: 129 | logger.info("Saved config is outdated, using solely eval config") 130 | config = self.config.clone() 131 | config.merge_from_list(eval_cmd_opts) 132 | if config.TASK_CONFIG.DATASET.SPLIT == "train": 133 | config.TASK_CONFIG.defrost() 134 | config.TASK_CONFIG.DATASET.SPLIT = "val" 135 | 136 | config.TASK_CONFIG.SIMULATOR.AGENT_0.defrost() 137 | config.TASK_CONFIG.SIMULATOR.AGENT_0.SENSORS = self.config.SENSORS 138 | config.freeze() 139 | 140 | return config 141 | 142 | def _eval_checkpoint( 143 | self, 144 | checkpoint_path: str, 145 | writer: TensorboardWriter, 146 | checkpoint_index: int = 0, 147 | ) -> None: 148 | r"""Evaluates a single checkpoint. Trainer algorithms should 149 | implement this. 150 | 151 | Args: 152 | checkpoint_path: path of checkpoint 153 | writer: tensorboard writer object for logging to tensorboard 154 | checkpoint_index: index of cur checkpoint for logging 155 | 156 | Returns: 157 | None 158 | """ 159 | raise NotImplementedError 160 | 161 | def save_checkpoint(self, file_name) -> None: 162 | raise NotImplementedError 163 | 164 | def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict: 165 | raise NotImplementedError 166 | 167 | @staticmethod 168 | def _pause_envs( 169 | envs_to_pause, 170 | envs, 171 | test_recurrent_hidden_states, 172 | not_done_masks, 173 | current_episode_reward, 174 | prev_actions, 175 | batch, 176 | rgb_frames, 177 | ): 178 | # pausing self.envs with no new episode 179 | if len(envs_to_pause) > 0: 180 | state_index = list(range(envs.num_envs)) 181 | for idx in reversed(envs_to_pause): 182 | state_index.pop(idx) 183 | envs.pause_at(idx) 184 | 185 | # indexing along the batch dimensions 186 | test_recurrent_hidden_states = test_recurrent_hidden_states[ 187 | :, state_index 188 | ] 189 | not_done_masks = not_done_masks[state_index] 190 | current_episode_reward = current_episode_reward[state_index] 191 | prev_actions = prev_actions[state_index] 192 | 193 | for k, v in batch.items(): 194 | batch[k] = v[state_index] 195 | 196 | rgb_frames = [rgb_frames[i] for i in state_index] 197 | 198 | return ( 199 | envs, 200 | test_recurrent_hidden_states, 201 | not_done_masks, 202 | current_episode_reward, 203 | prev_actions, 204 | batch, 205 | rgb_frames, 206 | ) 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Move2Hear: Active Audio-Visual Source Separation 2 | This repository contains the PyTorch implementation of our **ICCV-21 paper** and the associated datasets: 3 | 4 | [Move2Hear: Active Audio-Visual Source Separation](http://vision.cs.utexas.edu/projects/move2hear)
5 | Sagnik Majumder, Ziad Al-Halah, Kristen Grauman
6 | The University of Texas at Austin, Facebook AI Research 7 | 8 | Project website: [http://vision.cs.utexas.edu/projects/move2hear](http://vision.cs.utexas.edu/projects/move2hear) 9 | 10 | ![Move2Hear concept](gfx/concept.png) 11 | 12 |
13 | 14 | ## Related repos 15 | [Active Audio-Visual Separation of Dynamic Sound Sources](https://github.com/SAGNIKMJR/active-AV-dynamic-separation) 16 | 17 | ## Abstract 18 | We introduce the active audio-visual source separation problem, where an agent must move intelligently in order to better isolate the sounds coming from an object of interest in its environment. The agent hears multiple audio sources simultaneously (e.g., a person speaking down the hall in a noisy household) and it must use its eyes and ears to automatically separate out the sounds originating from a target object within a limited time budget. Towards this goal, we introduce a reinforcement learning approach that trains movement policies controlling the agent's camera and microphone placement over time, guided by the improvement in predicted audio separation quality. We demonstrate our approach in scenarios motivated by both augmented reality (system is already co-located with the target object) and mobile robotics (agent begins arbitrarily far from the target object). Using state-of-the-art realistic audio-visual simulations in 3D environments, we demonstrate our model's ability to find minimal movement sequences with maximal payoff for audio source separation. 19 | 20 | ## Dependencies 21 | This code has been tested with ```python 3.6.12```, ```habitat-api 0.1.4```, ```habitat-sim 0.1.4``` and ```torch 1.4.0```. Additional python package requirements are available in ```requirements.txt```. 22 | 23 | First, install the required versions of [habitat-api](https://github.com/facebookresearch/habitat-lab), [habitat-sim](https://github.com/facebookresearch/habitat-sim) and [torch](https://pytorch.org/) inside a [conda](https://www.anaconda.com/) environment. 24 | 25 | Next, install the remaining dependencies either by 26 | ``` 27 | pip3 install -r requirements.txt 28 | ``` 29 | or by parsing ```requirements.txt``` to get the names and versions of individual dependencies and install them individually. 30 | 31 | ## Datasets 32 | Download the AAViSS-specific datasets from [this link](https://bit.ly/3zGpo4A), extract the zip and put it under the project root. The extracted ```data``` directory should have 3 types of data 33 | 1. **audio_data**: the pre-processed and pre-normalized raw monaural audio waveforms for training and evaluation 34 | 2. **passive_datasets**: the dataset (audio source and receiver pair spatial attributes) for pre-training of passive separators 35 | 3. **active_datasets**: the dataset (episode specification) for training of Move2Hear policies 36 | 37 | Make a directory named ```sound_spaces``` and place it in the same directory as the one where the project root resides. Download the [SoundSpaces](https://github.com/facebookresearch/sound-spaces/blob/main/soundspaces/README.md) Matterport3D **binaural RIRs** and **metadata**, and extract them into directories named ```sound_spaces/binaural_rirs/mp3d``` and ```sound_spaces/metadata/mp3d```, respectively. 38 | 39 | Download the [Matterport3D](https://niessner.github.io/Matterport/) dataset, and cache the observations relevant for the SoundSpaces simulator using [this script](https://github.com/facebookresearch/sound-spaces/blob/main/scripts/cache_observations.py) from the [SoundSpaces repository](https://github.com/facebookresearch/sound-spaces). Use resolutions of ```128 x 128``` for both RGB and depth sensors. Place the cached observations for all scenes (.pkl files) in ```sound_spaces/scene_observations_new```. 
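Putting the steps above together, the resulting layout should look roughly like the sketch below (directory names are taken from the instructions above; the comments are only illustrative):
```
move2hear_project_root/
├── data/
│   ├── audio_data/          # pre-processed, pre-normalized monaural waveforms
│   ├── passive_datasets/    # pre-training data for the passive separators
│   └── active_datasets/     # episode specifications for Move2Hear policy training
└── ...
sound_spaces/                # sibling directory of the project root
├── binaural_rirs/mp3d/
├── metadata/mp3d/
└── scene_observations_new/  # cached 128 x 128 RGB and depth observations (.pkl per scene)
```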
40 | 41 | For further information about how the associated datasets are structured, refer to ```audio_separation/config/default.py``` or the task configs. 42 | 43 | ## Code 44 | ###### Pretraining 45 | ``` 46 | CUDA_VISIBLE_DEVICES=0 python3 main.py --exp-config audio_separation/config/pretrain_passive.yaml --model-dir runs/passive_pretrain/PRETRAIN_DIRNAME --run-type train NUM_PROCESSES 1 47 | ``` 48 | ###### Policy Training 49 | 8-GPU DDPPO training: 50 | 1. **Near-Target** 51 | ``` 52 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -u -m torch.distributed.launch --use_env --nproc_per_node 8 main.py --exp-config audio_separation/config/train/nearTarget.yaml --model-dir runs/train/nearTarget/NEAR_TARGET_TRAIN_DIRNAME --run-type train NUM_PROCESSES 14 53 | ``` 54 | 2. **Far-Target** 55 | ``` 56 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -u -m torch.distributed.launch --use_env --nproc_per_node 8 main.py --exp-config audio_separation/config/train/farTarget.yaml --model-dir runs/train/farTarget/FAR_TARGET_TRAIN_DIRNAME --run-type train NUM_PROCESSES 14 57 | ``` 58 | 59 | ###### Validation 60 | First, link the training checkpoints for validation using ```scripts/search_for_checkpoint_thru_validation/link_ckpts_for_val.ipynb```; the best checkpoint is then selected on the basis of validation performance. 61 | 1. **Near-Target** 62 | ``` 63 | CUDA_VISIBLE_DEVICES=0 python3 main.py --exp-config audio_separation/config/val/nearTarget.yaml --model-dir runs/val/nearTarget/NEAR_TARGET_VAL_DIRNAME --run-type eval NUM_PROCESSES 1 64 | ``` 65 | 2. **Far-Target** 66 | ``` 67 | CUDA_VISIBLE_DEVICES=0 python3 main.py --exp-config audio_separation/config/val/farTarget.yaml --model-dir runs/val/farTarget/FAR_TARGET_VAL_DIRNAME --run-type eval NUM_PROCESSES 1 68 | ``` 69 | 70 | Then search for the best checkpoint (the one with the lowest validation STFT loss) using ```scripts/search_for_checkpoint_thru_validation/find_bestCkpt_lowestValSTFTLoss.ipynb```. 71 | 72 | For unheard sounds, use ```config/val/nearTarget_unheard.yaml``` or ```config/val/farTarget_unheard.yaml```, and use the corresponding validation directory. 73 | 74 | 75 | ###### Testing 76 | 1. **Near-Target** 77 | ``` 78 | CUDA_VISIBLE_DEVICES=0 python3 main.py --exp-config audio_separation/config/test/nearTarget.yaml --model-dir runs/test/nearTarget/NEAR_TARGET_TEST_DIRNAME --run-type eval NUM_PROCESSES 1 79 | ``` 80 | 81 | 2. **Far-Target** 82 | Use ```scripts/farTarget_eval/copy_individualCkptsNCfgs_switchPolicyEval.ipynb``` to copy the configs and best checkpoints from Near-Target and Far-Target training into a single checkpoint, which is needed for switching policies during eval. 83 | ``` 84 | CUDA_VISIBLE_DEVICES=0 python3 main.py --exp-config audio_separation/config/test/farTarget.yaml --model-dir runs/test/farTarget/FAR_TARGET_TEST_DIRNAME --run-type eval NUM_PROCESSES 1 85 | ``` 86 | 87 | Compute test metric values (STFT L2 loss or SI-SDR) using ```scripts/separated_audio_quality/compute_separation_qualtiy.ipynb```. 88 | 89 | For unheard sounds, use ```config/test/nearTarget_unheard.yaml``` or ```config/test/farTarget_unheard.yaml```, and use the corresponding test directory. 90 | 91 | 92 | 93 | ## Pretrained models 94 | Download pretrained model checkpoints from this [link](https://utexas.box.com/s/0pdi6goecfvbh45n045r2lbb8jsgmonw). 
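Checkpoints in this repo are saved as a dict holding a ```state_dict``` and the training ```config``` (see, e.g., ```save_checkpoint``` in ```audio_separation/pretrain/passive/passive_trainer.py```). Below is a minimal sketch for inspecting a downloaded checkpoint, assuming it follows the same format; the file name and directory are hypothetical, so adjust them to wherever the archive was extracted.
```
# Minimal sketch (assumption: the downloaded checkpoints follow the repo's
# {"state_dict": ..., "config": ...} format; the path below is only an example).
import torch

ckpt_path = "pretrained_models/best_ckpt_val.pth"  # hypothetical path
ckpt = torch.load(ckpt_path, map_location="cpu")

print(ckpt.keys())                  # expected: dict_keys(['state_dict', 'config'])
print(ckpt["config"].TRAINER_NAME)  # the trainer config stored with the checkpoint
print(len(ckpt["state_dict"]))      # number of saved parameter/buffer tensors
```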
95 | 96 | 97 | ## Citation 98 | ``` 99 | @InProceedings{Majumder_2021_ICCV, 100 | author = {Majumder, Sagnik and Al-Halah, Ziad and Grauman, Kristen}, 101 | title = {Move2Hear: Active Audio-Visual Source Separation}, 102 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 103 | month = {October}, 104 | year = {2021}, 105 | pages = {275-285} 106 | } 107 | ``` 108 | 109 | # License 110 | This project is released under the MIT license, as found in the LICENSE file. 111 | -------------------------------------------------------------------------------- /habitat_audio/dataset.py: -------------------------------------------------------------------------------- 1 | import attr 2 | import gzip 3 | import json 4 | import os 5 | import logging 6 | from typing import List, Optional, Dict 7 | 8 | from habitat.config import Config 9 | from habitat.core.dataset import Dataset 10 | from habitat.core.registry import registry 11 | from habitat.tasks.nav.nav import ( 12 | NavigationEpisode, 13 | NavigationGoal, 14 | ShortestPathPoint, 15 | ) 16 | 17 | 18 | ALL_SCENES_MASK = "*" 19 | CONTENT_SCENES_PATH_FIELD = "content_scenes_path" 20 | DEFAULT_SCENE_PATH_PREFIX = "data/scene_dataset/" 21 | 22 | 23 | @attr.s(auto_attribs=True, kw_only=True) 24 | class NavigationEpisodeCustom(NavigationEpisode): 25 | r"""Class for episode specification that includes all geodesic distances. 26 | Args: 27 | all_geodesic_distances: geodesic distances of agent to source 28 | and in between two sources 29 | """ 30 | 31 | all_geodesic_distances: Optional[Dict[str, str]] = None 32 | gt_actions: Optional[Dict[str, str]] = None 33 | 34 | 35 | @registry.register_dataset(name="AAViSS") 36 | class AAViSSDataset(Dataset): 37 | episodes: List[NavigationEpisodeCustom] 38 | content_scenes_path: str = "{data_path}/content/{scene}.json.gz" 39 | 40 | @staticmethod 41 | def check_config_paths_exist(config: Config) -> bool: 42 | r""" 43 | check if paths to episode datasets exist 44 | :param config: 45 | :return: 46 | """ 47 | return os.path.exists( 48 | config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT) if len(config.SPLIT.split("/")) == 1\ 49 | else config.DATA_PATH.format(version=config.VERSION, split1=config.SPLIT, split2=config.SPLIT.split("/")[-1]) 50 | ) and os.path.exists(config.SCENES_DIR) 51 | 52 | @staticmethod 53 | def get_sounds_in_split(config: Config) -> str: 54 | return [] 55 | 56 | @staticmethod 57 | def get_scenes_to_load(config: Config) -> List[str]: 58 | r"""Return list of scene ids for which dataset has separate files with 59 | episodes. 
60 | """ 61 | assert AAViSSDataset.check_config_paths_exist(config), \ 62 | (config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT)\ 63 | if len(config.SPLIT.split("/")) == 1 else\ 64 | config.DATA_PATH.format(version=config.VERSION, split1=config.SPLIT, split2=config.SPLIT.split("/")[-1]), 65 | config.SCENES_DIR) 66 | dataset_dir = os.path.dirname( 67 | config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT)\ 68 | if len(config.SPLIT.split("/")) == 1 else\ 69 | config.DATA_PATH.format(version=config.VERSION, split1=config.SPLIT, split2=config.SPLIT.split("/")[-1]) 70 | ) 71 | 72 | cfg = config.clone() 73 | cfg.defrost() 74 | cfg.CONTENT_SCENES = [] 75 | dataset = AAViSSDataset(cfg) 76 | return AAViSSDataset._get_scenes_from_folder( 77 | content_scenes_path=dataset.content_scenes_path, 78 | dataset_dir=dataset_dir, 79 | ) 80 | 81 | @staticmethod 82 | def _get_scenes_from_folder(content_scenes_path, dataset_dir): 83 | scenes = [] 84 | content_dir = content_scenes_path.split("{scene}")[0] 85 | scene_dataset_ext = content_scenes_path.split("{scene}")[1] 86 | content_dir = content_dir.format(data_path=dataset_dir) 87 | if not os.path.exists(content_dir): 88 | return scenes 89 | 90 | for filename in os.listdir(content_dir): 91 | if filename.endswith(scene_dataset_ext): 92 | scene = filename[: -len(scene_dataset_ext)] 93 | scenes.append(scene) 94 | scenes.sort() 95 | return scenes 96 | 97 | def __init__(self, config: Optional[Config] = None) -> None: 98 | r"""Class inherited from Dataset that loads Point Navigation dataset. 99 | """ 100 | self.episodes = [] 101 | self._config = config 102 | 103 | if config is None: 104 | return 105 | 106 | datasetfile_path = config.DATA_PATH.format(version=config.VERSION, split=config.SPLIT)\ 107 | if len(config.SPLIT.split("/")) == 1\ 108 | else config.DATA_PATH.format(version=config.VERSION, split1=config.SPLIT, split2=config.SPLIT.split("/")[-1]) 109 | with gzip.open(datasetfile_path, "rt") as f: 110 | self.from_json(f.read(), scenes_dir=config.SCENES_DIR, scene_filename=datasetfile_path) 111 | 112 | # Read separate file for each scene 113 | dataset_dir = os.path.dirname(datasetfile_path) 114 | scenes = config.CONTENT_SCENES 115 | if ALL_SCENES_MASK in scenes: 116 | scenes = AAViSSDataset._get_scenes_from_folder( 117 | content_scenes_path=self.content_scenes_path, 118 | dataset_dir=dataset_dir, 119 | ) 120 | 121 | last_episode_cnt = 0 122 | for scene in scenes: 123 | scene_filename = self.content_scenes_path.format( 124 | data_path=dataset_dir, scene=scene 125 | ) 126 | with gzip.open(scene_filename, "rt") as f: 127 | self.from_json(f.read(), scenes_dir=config.SCENES_DIR, scene_filename=scene_filename) 128 | 129 | num_episode = len(self.episodes) - last_episode_cnt 130 | last_episode_cnt = len(self.episodes) 131 | logging.info('Sampled {} from {}'.format(num_episode, scene)) 132 | 133 | # filter by scenes for data collection 134 | def filter_by_scenes(self, scenes): 135 | r""" 136 | filter all episodes on the basis of scene names 137 | :param scenes: scenes to filter episodes with 138 | :return: filtered episodes 139 | """ 140 | episodes_to_keep = list() 141 | for episode in self.episodes: 142 | episode_scene = episode.scene_id.split("/")[-1].split(".")[0] 143 | if episode_scene in scenes: 144 | episodes_to_keep.append(episode) 145 | self.episodes = episodes_to_keep 146 | 147 | # filter by scenes for data collection 148 | def filter_by_scenes_n_ids(self, scenes_n_ids): 149 | r""" 150 | filter all episodes on the basis of scene names 
and episode IDs 151 | :param scenes_n_ids: scene names and episode IDs to filter all episodes with 152 | :return: filtered episodes 153 | """ 154 | episodes_to_keep = list() 155 | for episode in self.episodes: 156 | episode_scene = episode.scene_id.split("/")[-1].split(".")[0] 157 | episode_id = int(episode.episode_id) 158 | if episode_scene + "_" + str(episode_id) in scenes_n_ids: 159 | episodes_to_keep.append(episode) 160 | self.episodes = episodes_to_keep 161 | 162 | def from_json( 163 | self, json_str: str, scenes_dir: Optional[str] = None, scene_filename: Optional[str] = None 164 | ) -> None: 165 | r""" 166 | loads and reads episodes from per-scene json files 167 | :param json_str: json file name 168 | :param scenes_dir: directory containing json files 169 | :return: None 170 | """ 171 | deserialized = json.loads(json_str) 172 | if CONTENT_SCENES_PATH_FIELD in deserialized: 173 | self.content_scenes_path = deserialized[CONTENT_SCENES_PATH_FIELD] 174 | 175 | episode_cnt = 0 176 | for episode in deserialized["episodes"]: 177 | episode = NavigationEpisodeCustom(**episode) 178 | 179 | if scenes_dir is not None: 180 | if episode.scene_id.startswith(DEFAULT_SCENE_PATH_PREFIX): 181 | episode.scene_id = episode.scene_id[ 182 | len(DEFAULT_SCENE_PATH_PREFIX): 183 | ] 184 | 185 | episode.scene_id = os.path.join(scenes_dir, episode.scene_id) 186 | 187 | for g_index, goal in enumerate(episode.goals): 188 | episode.goals[g_index] = NavigationGoal(**goal) 189 | if episode.shortest_paths is not None: 190 | for path in episode.shortest_paths: 191 | for p_index, point in enumerate(path): 192 | path[p_index] = ShortestPathPoint(**point) 193 | self.episodes.append(episode) 194 | episode_cnt += 1 195 | -------------------------------------------------------------------------------- /habitat_audio/task.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Type, Union 2 | 3 | import numpy as np 4 | from gym import spaces 5 | 6 | from habitat.config import Config 7 | from habitat.core.dataset import Episode 8 | 9 | from habitat.tasks.nav.nav import NavigationTask, Measure, EmbodiedTask, SimulatorTaskAction 10 | from habitat.core.registry import registry 11 | from habitat.core.simulator import ( 12 | Sensor, 13 | SensorTypes, 14 | Simulator, 15 | ) 16 | 17 | from habitat.sims.habitat_simulator.actions import HabitatSimActions 18 | 19 | 20 | def merge_sim_episode_config( 21 | sim_config: Config, episode: Type[Episode] 22 | ) -> Any: 23 | sim_config.defrost() 24 | # here's where the scene update happens, extract the scene name out of the path 25 | sim_config.SCENE = episode.scene_id 26 | sim_config.freeze() 27 | if ( 28 | episode.start_position is not None 29 | and episode.start_rotation is not None 30 | ): 31 | agent_name = sim_config.AGENTS[sim_config.DEFAULT_AGENT_ID] 32 | agent_cfg = getattr(sim_config, agent_name) 33 | agent_cfg.defrost() 34 | agent_cfg.START_POSITION = episode.start_position 35 | agent_cfg.START_ROTATION = episode.start_rotation 36 | agent_cfg.TARGET_CLASS = episode.info[0]["target_label"] 37 | agent_cfg.AUDIO_SOURCE_POSITIONS = [] 38 | for source in episode.goals: 39 | agent_cfg.AUDIO_SOURCE_POSITIONS.append(source.position) 40 | agent_cfg.SOUND_NAMES = [] 41 | for source_info in episode.info: 42 | agent_cfg.SOUND_NAMES.append(source_info["sound"]) 43 | agent_cfg.IS_SET_START_STATE = True 44 | agent_cfg.freeze() 45 | return sim_config 46 | 47 | 48 | @registry.register_task(name="AAViSS") 49 | class AAViSSTask(NavigationTask): 50 
| def overwrite_sim_config( 51 | self, sim_config: Any, episode: Type[Episode] 52 | ) -> Any: 53 | return merge_sim_episode_config(sim_config, episode) 54 | 55 | def _check_episode_is_active(self, *args: Any, **kwargs: Any) -> bool: 56 | return self._sim._is_episode_active 57 | 58 | 59 | @registry.register_sensor 60 | class MixedBinAudioMagSensor(Sensor): 61 | r"""Mixed binaural spectrogram magnitude at the current step 62 | """ 63 | 64 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 65 | self._sim = sim 66 | super().__init__(config=config) 67 | 68 | def _get_uuid(self, *args: Any, **kwargs: Any): 69 | return "mixed_bin_audio_mag" 70 | 71 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 72 | return SensorTypes.PATH 73 | 74 | def _get_observation_space(self, *args: Any, **kwargs: Any): 75 | assert hasattr(self.config, 'FEATURE_SHAPE') 76 | sensor_shape = self.config.FEATURE_SHAPE 77 | 78 | return spaces.Box( 79 | low=np.finfo(np.float32).min, 80 | high=np.finfo(np.float32).max, 81 | shape=sensor_shape, 82 | dtype=np.float32, 83 | ) 84 | 85 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 86 | return self._sim.get_current_mixed_bin_audio_mag_spec() 87 | 88 | 89 | @registry.register_sensor 90 | class MixedBinAudioPhaseSensor(Sensor): 91 | r"""Mixed binaural spectrogram phase at the current step 92 | """ 93 | 94 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 95 | self._sim = sim 96 | super().__init__(config=config) 97 | 98 | def _get_uuid(self, *args: Any, **kwargs: Any): 99 | return "mixed_bin_audio_phase" 100 | 101 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 102 | return SensorTypes.PATH 103 | 104 | def _get_observation_space(self, *args: Any, **kwargs: Any): 105 | assert hasattr(self.config, 'FEATURE_SHAPE') 106 | sensor_shape = self.config.FEATURE_SHAPE 107 | 108 | return spaces.Box( 109 | low=np.finfo(np.float32).min, 110 | high=np.finfo(np.float32).max, 111 | shape=sensor_shape, 112 | dtype=np.float32, 113 | ) 114 | 115 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 116 | return self._sim.get_current_mixed_bin_audio_phase_spec() 117 | 118 | 119 | @registry.register_sensor 120 | class GtMonoComponentsSensor(Sensor): 121 | r"""Ground truth monaural components at the current step 122 | """ 123 | 124 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 125 | self._sim = sim 126 | super().__init__(config=config) 127 | 128 | def _get_uuid(self, *args: Any, **kwargs: Any): 129 | return "gt_mono_comps" 130 | 131 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 132 | return SensorTypes.PATH 133 | 134 | def _get_observation_space(self, *args: Any, **kwargs: Any): 135 | assert hasattr(self.config, 'FEATURE_SHAPE') 136 | sensor_shape = self.config.FEATURE_SHAPE 137 | 138 | return spaces.Box( 139 | low=np.finfo(np.float32).min, 140 | high=np.finfo(np.float32).max, 141 | shape=sensor_shape, 142 | dtype=np.float32, 143 | ) 144 | 145 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 146 | return self._sim.get_current_gt_mono_audio_components() 147 | 148 | 149 | @registry.register_sensor 150 | class GtBinComponentsSensor(Sensor): 151 | r"""Ground truth binaural components at the current step 152 | """ 153 | 154 | def __init__(self, *args: Any, sim: Simulator, config: Config, **kwargs: Any): 155 | self._sim = sim 156 | super().__init__(config=config) 157 | 158 | def 
_get_uuid(self, *args: Any, **kwargs: Any): 159 | return "gt_bin_comps" 160 | 161 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 162 | return SensorTypes.PATH 163 | 164 | def _get_observation_space(self, *args: Any, **kwargs: Any): 165 | assert hasattr(self.config, 'FEATURE_SHAPE') 166 | sensor_shape = self.config.FEATURE_SHAPE 167 | 168 | return spaces.Box( 169 | low=np.finfo(np.float32).min, 170 | high=np.finfo(np.float32).max, 171 | shape=sensor_shape, 172 | dtype=np.float32, 173 | ) 174 | 175 | def get_observation(self, *args: Any, observations, episode: Episode, **kwargs: Any): 176 | return self._sim.get_current_gt_bin_audio_components() 177 | 178 | 179 | @registry.register_sensor(name="TargetClassSensor") 180 | class TargetClassSensor(Sensor): 181 | r"""Target class for the current episode 182 | """ 183 | 184 | def __init__( 185 | self, sim: Union[Simulator, Config], config: Config, *args: Any, **kwargs: Any 186 | ): 187 | super().__init__(config=config) 188 | self._sim = sim 189 | 190 | def _get_uuid(self, *args: Any, **kwargs: Any): 191 | return "target_class" 192 | 193 | def _get_sensor_type(self, *args: Any, **kwargs: Any): 194 | return SensorTypes.COLOR 195 | 196 | def _get_observation_space(self, *args: Any, **kwargs: Any): 197 | return spaces.Box( 198 | low=0, 199 | high=1, 200 | shape=(1,), 201 | dtype=bool 202 | ) 203 | 204 | def get_observation( 205 | self, *args: Any, observations, episode: Episode, **kwargs: Any 206 | ) -> object: 207 | return [self._sim.target_class] 208 | 209 | 210 | @registry.register_measure 211 | class GeoDistanceToTargetAudioSource(Measure): 212 | r"""Geodesic distance to target audio source for every time step 213 | """ 214 | 215 | def __init__( 216 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 217 | ): 218 | self._start_end_episode_distance = None 219 | self._sim = sim 220 | self._config = config 221 | 222 | super().__init__() 223 | 224 | def _get_uuid(self, *args: Any, **kwargs: Any): 225 | return "geo_distance_to_target_audio_source" 226 | 227 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 228 | self._start_end_episode_distance = episode.info[0]["geodesic_distance"] 229 | self._metric = None 230 | self.update_metric(episode=episode, *args, **kwargs) 231 | 232 | def update_metric( 233 | self, *args: Any, episode, **kwargs: Any 234 | ): 235 | current_position = self._sim.get_agent_state().position.tolist() 236 | 237 | distance_to_target = self._sim.geodesic_distance( 238 | current_position, episode.goals[0].position 239 | ) 240 | 241 | self._metric = distance_to_target 242 | 243 | 244 | @registry.register_measure 245 | class NormalizedGeoDistanceToTargetAudioSource(Measure): 246 | r"""Normalized geodesic distance to target audio source for every time step 247 | """ 248 | 249 | def __init__( 250 | self, *args: Any, sim: Simulator, config: Config, **kwargs: Any 251 | ): 252 | self._start_end_episode_distance = None 253 | self._sim = sim 254 | self._config = config 255 | 256 | super().__init__() 257 | 258 | def _get_uuid(self, *args: Any, **kwargs: Any): 259 | return "normalized_geo_distance_to_target_audio_source" 260 | 261 | def reset_metric(self, *args: Any, episode, **kwargs: Any): 262 | self._start_end_episode_distance = episode.info[0]["geodesic_distance"] 263 | self._metric = None 264 | 265 | def update_metric( 266 | self, *args: Any, episode, action, task: EmbodiedTask, **kwargs: Any 267 | ): 268 | current_position = self._sim.get_agent_state().position.tolist() 269 | 270 | distance_to_target = 
self._sim.geodesic_distance( 271 | current_position, episode.goals[0].position 272 | ) 273 | 274 | if self._start_end_episode_distance == 0: 275 | self._metric = -1 276 | else: 277 | self._metric = distance_to_target / self._start_end_episode_distance 278 | 279 | 280 | @registry.register_task_action 281 | class PauseAction(SimulatorTaskAction): 282 | name: str = "PAUSE" 283 | 284 | def step(self, *args: Any, **kwargs: Any): 285 | r"""Update ``_metric``, this method is called from ``Env`` on each 286 | ``step``. 287 | """ 288 | return self._sim.step(HabitatSimActions.PAUSE) 289 | -------------------------------------------------------------------------------- /audio_separation/pretrain/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | import numpy as np 5 | from scipy.io import wavfile 6 | from scipy.signal import fftconvolve 7 | import librosa 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | 12 | # 10 VoxCeleb1 classes, 1 combined music class from MIT MUSIC, 1 combined ESC-50 class (only used as background distractor) 13 | CLASS_NAMES_TO_LABELS = {"id10393": 0, "id10583": 1, "id10061": 2, "id10954": 3, "id10355": 4, "id10799": 5, 14 | "id10203": 6, "id10371": 7, "id10361": 8, "id10254": 9, "music": 10, "esc": 11,} 15 | LABELS_TO_CLASS_NAMES = {} 16 | for key, val in CLASS_NAMES_TO_LABELS.items(): 17 | LABELS_TO_CLASS_NAMES[val] = key 18 | 19 | # STFT params 20 | HOP_LENGTH = 512 21 | N_FFT = 1023 22 | 23 | 24 | class PassiveDataset(Dataset): 25 | def __init__(self, split="train", scene_graphs=None, sim_cfg=None,): 26 | np.random.seed(42) 27 | torch.manual_seed(42) 28 | 29 | self.split = split 30 | self.audio_cfg = sim_cfg.AUDIO 31 | self.passive_dataset_version = self.audio_cfg.PASSIVE_DATASET_VERSION 32 | self.binaural_rir_dir = self.audio_cfg.RIR_DIR 33 | self.use_cache = False if split.split("_")[0] == "train" else True 34 | self.rir_sampling_rate = 16000 35 | self.num_sources = 2 36 | 37 | assert split in ["train", "val", "nonoverlapping_val"] 38 | 39 | # datapoints: locations to place sources and pose of agent 40 | if split == "nonoverlapping_val": 41 | self.sourceAgentLocn_datapoints_dir = os.path.join(self.audio_cfg.SOURCE_AGENT_LOCATION_DATAPOINTS_DIR, 42 | self.passive_dataset_version, "val") 43 | else: 44 | self.sourceAgentLocn_datapoints_dir = os.path.join(self.audio_cfg.SOURCE_AGENT_LOCATION_DATAPOINTS_DIR, 45 | self.passive_dataset_version, split) 46 | 47 | # datapoints: raw mono sounds playing at the sources 48 | if split in ["train", "val"]: 49 | self.audio_dir = self.audio_cfg.PASSIVE_TRAIN_AUDIO_DIR 50 | elif split in "nonoverlapping_val": 51 | self.audio_dir = self.audio_cfg.PASSIVE_NONOVERLAPPING_VAL_AUDIO_DIR 52 | assert os.path.exists(self.audio_dir) 53 | 54 | self.audio_files = [] 55 | for _, __, self.audio_files in os.walk(self.audio_dir): 56 | break 57 | 58 | self.audio_files_per_class = {} 59 | for audio_file in self.audio_files: 60 | assert audio_file.split("_")[0] in CLASS_NAMES_TO_LABELS 61 | if audio_file.split("_")[0] not in self.audio_files_per_class: 62 | self.audio_files_per_class[audio_file.split("_")[0]] = [audio_file] 63 | else: 64 | self.audio_files_per_class[audio_file.split("_")[0]].append(audio_file) 65 | 66 | self.file2audio_dict = dict() 67 | self.load_source_audio() 68 | 69 | self.complete_datapoint_files = list() 70 | self.target_classes = list() 71 | tqdm_over = scene_graphs 72 | for scene in tqdm(tqdm_over): 73 | 
with open(os.path.join(self.sourceAgentLocn_datapoints_dir, scene + ".pkl"), "rb") as fi: 74 | sourceAgentLocn_datapoints = pickle.load(fi)[scene] 75 | 76 | if split.split("_")[0] == "train": 77 | sourceAgentLocn_datapoints = sourceAgentLocn_datapoints[:self.audio_cfg.NUM_PASSIVE_DATAPOINTS_PER_SCENE] 78 | elif (split.split("_")[0] == "val") or (split.split("_")[1] == "val"): 79 | sourceAgentLocn_datapoints = sourceAgentLocn_datapoints[:self.audio_cfg.NUM_PASSIVE_DATAPOINTS_PER_SCENE_EVAL] 80 | 81 | for datapoint in sourceAgentLocn_datapoints: 82 | audio_files, target_class = self.get_target_class_audio_files_for_sources() 83 | self.target_classes.append(target_class) 84 | 85 | receiver_locn = datapoint['r'] 86 | receiver_azimuth = datapoint['azimuth'] 87 | all_source_locns = datapoint['all_s'] 88 | complete_datapoint = [] 89 | for idx, source_locn in enumerate(all_source_locns): 90 | binaural_rir_file = os.path.join(scene, str(receiver_azimuth), f"{receiver_locn}_{source_locn}.wav") 91 | audio_file = audio_files[idx] 92 | complete_datapoint.append((binaural_rir_file, audio_file)) 93 | self.complete_datapoint_files.append(complete_datapoint) 94 | 95 | # this is used to cache mono stfts to prevent redundant computation (depending on the amount of audio data this 96 | # could be removed to save memory footprint) 97 | self._gt_mono_mag_cache = dict() 98 | 99 | if self.use_cache: 100 | self.complete_datapoints = [None] * len(self.complete_datapoint_files) 101 | for item in tqdm(range(len(self.complete_datapoints))): 102 | rirs_audio = self.complete_datapoint_files[item] 103 | mixed_audio, gt_bin_mag, gt_mono_mag, target_class =\ 104 | self.compute_audiospects(rirs_audio, target_class=self.target_classes[item]) 105 | self.complete_datapoints[item] = (mixed_audio, gt_bin_mag, gt_mono_mag, target_class) 106 | 107 | def __len__(self): 108 | return len(self.complete_datapoint_files) 109 | 110 | def __getitem__(self, item): 111 | if not self.use_cache: 112 | audio_files, target_class = self.get_target_class_audio_files_for_sources() 113 | 114 | rirs_audio = self.complete_datapoint_files[item] 115 | # keep the same src locns, and agent locn and pose but resample the source audio 116 | rirs_audio_new = [] 117 | for src_idx, (binaural_rir_file, audio_file) in enumerate(rirs_audio): 118 | rirs_audio_new.append((binaural_rir_file, audio_files[src_idx])) 119 | rirs_audio = rirs_audio_new 120 | 121 | mixed_audio, gt_bin_mag, gt_mono_mag, target_class =\ 122 | self.compute_audiospects(rirs_audio, target_class=target_class) 123 | 124 | mixed_audio = torch.from_numpy(mixed_audio) 125 | gt_bin_mag = torch.from_numpy(np.concatenate(gt_bin_mag, axis=2)) 126 | gt_mono_mag = torch.from_numpy(np.concatenate(gt_mono_mag, axis=2)) 127 | target_class = torch.from_numpy(target_class) 128 | else: 129 | mixed_audio = torch.from_numpy(self.complete_datapoints[item][0]) 130 | gt_bin_mag = torch.from_numpy(np.concatenate(self.complete_datapoints[item][1], axis=2)) 131 | gt_mono_mag = torch.from_numpy(np.concatenate(self.complete_datapoints[item][2], axis=2)) 132 | target_class = torch.from_numpy(self.complete_datapoints[item][3]) 133 | 134 | return mixed_audio, gt_bin_mag, gt_mono_mag, target_class 135 | 136 | def get_target_class_audio_files_for_sources(self): 137 | sampled_classes = (torch.randperm(len(CLASS_NAMES_TO_LABELS)).tolist())[:self.num_sources] 138 | target_class = sampled_classes[0] 139 | assert target_class < len(CLASS_NAMES_TO_LABELS) 140 | while target_class == 11: 141 | sampled_classes = 
(torch.randperm(len(CLASS_NAMES_TO_LABELS)).tolist())[:self.num_sources] 142 | target_class = sampled_classes[0] 143 | assert target_class < len(CLASS_NAMES_TO_LABELS) 144 | 145 | sampled_class_names = [] 146 | audio_files = [] 147 | for sampled_class in sampled_classes: 148 | sampled_class_names.append(LABELS_TO_CLASS_NAMES[sampled_class]) 149 | audio_files_for_sampled_class = self.audio_files_per_class[LABELS_TO_CLASS_NAMES[sampled_class]] 150 | audio_files.append(audio_files_for_sampled_class[torch.randint(len(audio_files_for_sampled_class), 151 | size=(1,)).item()]) 152 | 153 | return audio_files, target_class 154 | 155 | def load_source_audio(self): 156 | for audio_file in tqdm(self.audio_files): 157 | sr, audio_data = wavfile.read(os.path.join(self.audio_dir, audio_file)) 158 | if sr != self.rir_sampling_rate: 159 | audio_data = resample(audio_data, self.rir_sampling_rate) 160 | self.file2audio_dict[audio_file] = audio_data 161 | 162 | def compute_audiospects(self, rirs_audio, target_class): 163 | gt_mono_mag = [] 164 | gt_bin_mag = [] 165 | mixed_binaural_wave = 0 166 | target_class_idx = -1 167 | for idx, rir_audio in enumerate(rirs_audio): 168 | binaural_rir_file = os.path.join(self.binaural_rir_dir, rir_audio[0]) 169 | mono_audio = self.file2audio_dict[rir_audio[1]] 170 | try: 171 | sr, binaural_rir = wavfile.read(binaural_rir_file) 172 | assert sr == self.rir_sampling_rate, "RIR doesn't have sampling frequency of rir_sampling_rate kHz" 173 | except ValueError: 174 | binaural_rir = np.zeros((self.rir_sampling_rate, 2)).astype("float32") 175 | if len(binaural_rir) == 0: 176 | binaural_rir = np.zeros((self.rir_sampling_rate, 2)).astype("float32") 177 | 178 | binaural_convolved = [] 179 | for channel in range(binaural_rir.shape[-1]): 180 | binaural_convolved.append(fftconvolve(mono_audio, binaural_rir[:, channel], mode="same")) 181 | 182 | binaural_convolved = np.array(binaural_convolved) 183 | # this makes sure that the audio is in the range [-32768, 32767] 184 | binaural_convolved = np.round(binaural_convolved).astype("int16").astype("float32") 185 | binaural_convolved *= (1 / 32768) 186 | 187 | # compute target specs 188 | if idx == 0: 189 | # compute gt bin. 
magnitude 190 | fft_windows_l = librosa.stft(np.asfortranarray(binaural_convolved[0]), hop_length=HOP_LENGTH, 191 | n_fft=N_FFT) 192 | magnitude_l, _ = librosa.magphase(fft_windows_l) 193 | 194 | fft_windows_r = librosa.stft(np.asfortranarray(binaural_convolved[1]), hop_length=HOP_LENGTH, 195 | n_fft=N_FFT) 196 | magnitude_r, _ = librosa.magphase(fft_windows_r) 197 | 198 | gt_bin_mag.append(np.stack([magnitude_l, magnitude_r], axis=-1).astype("float32")) 199 | 200 | # compute gt mono magnitude 201 | if rir_audio[1] not in self._gt_mono_mag_cache: 202 | # this makes sure that the audio is in the range [-32768, 32767] 203 | mono_audio = mono_audio.astype("float32") / 32768 204 | 205 | fft_windows = librosa.stft(np.asfortranarray(mono_audio), hop_length=HOP_LENGTH, 206 | n_fft=N_FFT) 207 | magnitude, _ = librosa.magphase(fft_windows) 208 | if np.power(np.mean(np.power(magnitude, 2)), 0.5) != 0.: 209 | magnitude = magnitude * self.audio_cfg.GT_MONO_MAG_NORM / np.power(np.mean(np.power(magnitude, 2)), 0.5) 210 | 211 | self._gt_mono_mag_cache[rir_audio[1]] = magnitude 212 | 213 | mono_magnitude = self._gt_mono_mag_cache[rir_audio[1]] 214 | gt_mono_mag.append(np.expand_dims(mono_magnitude, axis=2).astype("float32")) 215 | 216 | mixed_binaural_wave += binaural_convolved 217 | 218 | mixed_binaural_wave /= len(rirs_audio) 219 | 220 | fft_windows_l = librosa.stft(np.asfortranarray(mixed_binaural_wave[0]), hop_length=HOP_LENGTH, n_fft=N_FFT) 221 | magnitude_l, _ = librosa.magphase(fft_windows_l) 222 | 223 | fft_windows_r = librosa.stft(np.asfortranarray(mixed_binaural_wave[1]), hop_length=HOP_LENGTH, n_fft=N_FFT) 224 | magnitude_r, _ = librosa.magphase(fft_windows_r) 225 | 226 | mixed_mag = np.stack([magnitude_l, magnitude_r], axis=-1).astype("float32") 227 | 228 | return np.log1p(mixed_mag), gt_bin_mag, gt_mono_mag, np.array([target_class]).astype("int64") 229 | -------------------------------------------------------------------------------- /audio_separation/rl/ppo/policy.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchsummary import summary 6 | 7 | from audio_separation.common.utils import CategoricalNet 8 | from audio_separation.rl.models.rnn_state_encoder import RNNStateEncoder 9 | from audio_separation.rl.models.visual_cnn import VisualCNN 10 | from audio_separation.rl.models.audio_cnn import AudioCNN 11 | from audio_separation.rl.models.separator_cnn import PassiveSepEncCNN, PassiveSepDecCNN 12 | from audio_separation.rl.models.memory_nets import AcousticMem 13 | 14 | 15 | class CriticHead(nn.Module): 16 | def __init__(self, input_size): 17 | super().__init__() 18 | self.fc = nn.Linear(input_size, 1) 19 | nn.init.orthogonal_(self.fc.weight) 20 | nn.init.constant_(self.fc.bias, 0) 21 | 22 | def forward(self, x): 23 | return self.fc(x) 24 | 25 | 26 | class Net(nn.Module, metaclass=abc.ABCMeta): 27 | @abc.abstractmethod 28 | def forward(self, observations, rnn_hidden_states, prev_actions, masks): 29 | pass 30 | 31 | @property 32 | @abc.abstractmethod 33 | def output_size(self): 34 | pass 35 | 36 | @property 37 | @abc.abstractmethod 38 | def num_recurrent_layers(self): 39 | pass 40 | 41 | @property 42 | @abc.abstractmethod 43 | def is_blind(self): 44 | pass 45 | 46 | 47 | class PolicyNet(Net): 48 | r"""Network which passes the observations and separated audio outputs through CNNs and concatenates 49 | them into a single vector before passing that through RNN. 
50 | """ 51 | def __init__(self, observation_space, hidden_size, goal_sensor_uuid, extra_rgb=False, extra_depth=False, 52 | world_rank=0,): 53 | super().__init__() 54 | assert 'mixed_bin_audio_mag' in observation_space.spaces 55 | self.goal_sensor_uuid = goal_sensor_uuid 56 | self._hidden_size = hidden_size 57 | 58 | self.visual_encoder = VisualCNN(observation_space, hidden_size, extra_rgb, extra_depth) 59 | self.bin_encoder = AudioCNN(observation_space, hidden_size,) 60 | self.monoNmonoFromMem_encoder = AudioCNN(observation_space, hidden_size, encode_monoNmonoFromMem=True,) 61 | 62 | rnn_input_size = 3 * self._hidden_size 63 | self.state_encoder = RNNStateEncoder(rnn_input_size, self._hidden_size) 64 | 65 | # printing out the network layers 66 | if world_rank == 0: 67 | if (('rgb' in observation_space.spaces) and (not extra_rgb)) and\ 68 | (('depth' in observation_space.spaces) and (not extra_depth)): 69 | assert observation_space.spaces['rgb'].shape[:-1] == observation_space.spaces['depth'].shape[:-1] 70 | # hardcoding needed since rgb is of size (720, 720, 3) when rendering videos and this call throws an error 71 | rgb_shape = (128, 128, 3) # observation_space.spaces['rgb'].shape 72 | depth_shape = observation_space.spaces['depth'].shape 73 | summary(self.visual_encoder.cnn, (rgb_shape[2] + depth_shape[2], rgb_shape[0], rgb_shape[1]), device='cpu') 74 | elif ('rgb' in observation_space.spaces) and (not extra_rgb): 75 | """hardcoding needed since rgb is of size (720, 720, 3) when rendering videos and this call throws an error""" 76 | rgb_shape = (128, 128, 3) # observation_space.spaces['rgb'].shape 77 | summary(self.visual_encoder.cnn, (rgb_shape[2], rgb_shape[0], rgb_shape[1]), device='cpu') 78 | elif ('depth' in observation_space.spaces) and (not extra_depth): 79 | depth_shape = observation_space.spaces['depth'].shape 80 | summary(self.visual_encoder.cnn, (depth_shape[2], depth_shape[0], depth_shape[1]), device='cpu') 81 | 82 | audio_shape = observation_space.spaces['mixed_bin_audio_mag'].shape 83 | summary(self.bin_encoder.cnn, (2 * 16, audio_shape[0] // 16, audio_shape[1]), device='cpu') 84 | summary(self.monoNmonoFromMem_encoder.cnn, (2 * 16, audio_shape[0] // 16, audio_shape[1]), device='cpu') 85 | 86 | @property 87 | def is_blind(self): 88 | return False 89 | 90 | @property 91 | def output_size(self): 92 | return self._hidden_size 93 | 94 | @property 95 | def num_recurrent_layers(self): 96 | return self.state_encoder.num_recurrent_layers 97 | 98 | def forward(self, observations, rnn_hidden_states, masks, pred_binSepMasks=None, pred_mono=None, 99 | pred_monoFromMem=None,): 100 | x = [] 101 | x.append(self.visual_encoder(observations)) 102 | x.append(self.bin_encoder(observations, pred_binSepMasks=pred_binSepMasks)) 103 | x.append(self.monoNmonoFromMem_encoder(observations, pred_monoNmonoFromMem=torch.cat((pred_mono, pred_monoFromMem), dim=3))) 104 | 105 | try: 106 | x1 = torch.cat(x, dim=1) 107 | except AssertionError as error: 108 | for data in x: 109 | print(data.size()) 110 | 111 | try: 112 | x2, rnn_hidden_states_new = self.state_encoder(x1, rnn_hidden_states, masks) 113 | except AssertionError as error: 114 | print(x1.size(), rnn_hidden_states.size(), masks.size(), x2.size(), rnn_hidden_states_new.size()) 115 | 116 | assert not torch.isnan(x2).any().item() 117 | 118 | return x2, rnn_hidden_states_new 119 | 120 | 121 | class PassiveSepEnc(nn.Module): 122 | r"""Network which encodes separated bin or mono outputs 123 | """ 124 | def __init__(self, observation_space, world_rank=0, 
convert_bin2mono=False,): 125 | super().__init__() 126 | assert 'mixed_bin_audio_mag' in observation_space.spaces 127 | 128 | self.passive_sep_encoder = PassiveSepEncCNN(convert_bin2mono=convert_bin2mono,) 129 | 130 | if world_rank == 0: 131 | audio_shape = observation_space.spaces['mixed_bin_audio_mag'].shape 132 | 133 | if not convert_bin2mono: 134 | summary(self.passive_sep_encoder.cnn, 135 | (audio_shape[2] * 16 + 1, audio_shape[0] // 16, audio_shape[1]), 136 | device='cpu') 137 | else: 138 | summary(self.passive_sep_encoder.cnn, 139 | (audio_shape[2] * 16, audio_shape[0] // 16, audio_shape[1]), 140 | device='cpu') 141 | 142 | def forward(self, observations, mixed_audio=None): 143 | bottleneck_feats, lst_skip_feats = self.passive_sep_encoder(observations, mixed_audio=mixed_audio,) 144 | 145 | return bottleneck_feats, lst_skip_feats 146 | 147 | 148 | class PassiveSepDec(nn.Module): 149 | r"""Network which decodes separated bin or mono outputs feature embeddings 150 | """ 151 | def __init__(self, convert_bin2mono=False,): 152 | super().__init__() 153 | self.passive_sep_decoder = PassiveSepDecCNN(convert_bin2mono=convert_bin2mono,) 154 | 155 | def forward(self, bottleneck_feats, lst_skip_feats): 156 | return self.passive_sep_decoder(bottleneck_feats, lst_skip_feats) 157 | 158 | 159 | class Policy(nn.Module): 160 | r""" 161 | Network for the full Move2Hear policy, including separation and action-making 162 | """ 163 | def __init__(self, pol_net, dim_actions, binSep_enc, binSep_dec, bin2mono_enc, bin2mono_dec, acoustic_mem,): 164 | super().__init__() 165 | self.dim_actions = dim_actions 166 | 167 | # full policy with actor and critic 168 | self.pol_net = pol_net 169 | self.action_dist = CategoricalNet( 170 | self.pol_net.output_size, self.dim_actions 171 | ) 172 | self.critic = CriticHead(self.pol_net.output_size) 173 | 174 | self.binSep_enc = binSep_enc 175 | self.binSep_dec = binSep_dec 176 | self.bin2mono_enc = bin2mono_enc 177 | self.bin2mono_dec = bin2mono_dec 178 | self.acoustic_mem = acoustic_mem 179 | 180 | def forward(self): 181 | raise NotImplementedError 182 | 183 | def get_binSepMasks(self, observations): 184 | bottleneck_feats, lst_skip_feats = self.binSep_enc( 185 | observations, 186 | ) 187 | return self.binSep_dec(bottleneck_feats, lst_skip_feats) 188 | 189 | def convert_bin2mono(self, pred_binSepMasks, mixed_audio=None): 190 | bottleneck_feats, lst_skip_feats = self.bin2mono_enc( 191 | pred_binSepMasks, mixed_audio=mixed_audio 192 | ) 193 | return self.bin2mono_dec(bottleneck_feats, lst_skip_feats) 194 | 195 | def get_monoFromMem(self, pred_mono, prev_pred_monoFromMem_masked): 196 | return self.acoustic_mem(pred_mono, prev_pred_monoFromMem_masked) 197 | 198 | def act( 199 | self, 200 | observations, 201 | rnn_hidden_states_pol, 202 | masks, 203 | deterministic=False, 204 | pred_binSepMasks=None, 205 | pred_mono=None, 206 | pred_monoFromMem=None, 207 | ): 208 | feats_pol, rnn_hidden_states_pol = self.pol_net( 209 | observations, 210 | rnn_hidden_states_pol, 211 | masks, 212 | pred_binSepMasks=pred_binSepMasks.detach(), 213 | pred_mono=pred_mono.detach(), 214 | pred_monoFromMem=pred_monoFromMem.detach(), 215 | ) 216 | 217 | dist = self.action_dist(feats_pol) 218 | value = self.critic(feats_pol) 219 | if deterministic: 220 | action = dist.mode() 221 | else: 222 | action = dist.sample() 223 | action_log_probs = dist.log_probs(action) 224 | 225 | return value, action, action_log_probs, rnn_hidden_states_pol, dist.get_probs() 226 | 227 | def get_value( 228 | self, 229 | 
observations, 230 | rnn_hidden_states_pol, 231 | masks, 232 | pred_binSepMasks=None, 233 | pred_mono=None, 234 | pred_monoFromMem=None, 235 | ): 236 | 237 | feats_pol, _ = self.pol_net( 238 | observations, 239 | rnn_hidden_states_pol, 240 | masks, 241 | pred_binSepMasks=pred_binSepMasks.detach(), 242 | pred_mono=pred_mono.detach(), 243 | pred_monoFromMem=pred_monoFromMem.detach(), 244 | ) 245 | 246 | return self.critic(feats_pol) 247 | 248 | def evaluate_actions( 249 | self, 250 | observations, 251 | rnn_hidden_states_pol, 252 | masks, 253 | action, 254 | pred_binSepMasks=None, 255 | pred_mono=None, 256 | pred_monoFromMem=None, 257 | ): 258 | feats_pol, rnn_hidden_states_pol = self.pol_net( 259 | observations, 260 | rnn_hidden_states_pol, 261 | masks, 262 | pred_binSepMasks=pred_binSepMasks, 263 | pred_mono=pred_mono, 264 | pred_monoFromMem=pred_monoFromMem, 265 | ) 266 | 267 | dist = self.action_dist(feats_pol) 268 | value = self.critic(feats_pol) 269 | 270 | action_log_probs = dist.log_probs(action) 271 | dist_entropy = dist.entropy().mean() 272 | 273 | return value, action_log_probs, dist_entropy, rnn_hidden_states_pol 274 | 275 | 276 | class Move2HearPolicy(Policy): 277 | def __init__( 278 | self, 279 | observation_space, 280 | action_space, 281 | goal_sensor_uuid, 282 | hidden_size=512, 283 | extra_rgb=False, 284 | extra_depth=False, 285 | use_ddppo=False, 286 | world_rank=0, 287 | use_smartnav_for_eval_pol_mix=False, 288 | ): 289 | pol_net = PolicyNet( 290 | observation_space=observation_space, 291 | hidden_size=hidden_size, 292 | goal_sensor_uuid=goal_sensor_uuid, 293 | extra_rgb=extra_rgb, 294 | extra_depth=extra_depth, 295 | world_rank=world_rank, 296 | ) 297 | 298 | binSep_enc = PassiveSepEnc( 299 | observation_space=observation_space, 300 | world_rank=world_rank, 301 | ) 302 | binSep_dec = PassiveSepDec() 303 | 304 | bin2mono_enc = PassiveSepEnc( 305 | observation_space=observation_space, 306 | world_rank=world_rank, 307 | convert_bin2mono=True, 308 | ) 309 | bin2mono_dec = PassiveSepDec( 310 | convert_bin2mono=True, 311 | ) 312 | 313 | acoustic_mem = AcousticMem( 314 | use_ddppo=use_ddppo, 315 | ) 316 | 317 | super().__init__( 318 | pol_net, 319 | # action_space.n - 1 if use_smartnav_for_eval_pol_mix else action_space.n, 320 | action_space.n, 321 | binSep_enc, 322 | binSep_dec, 323 | bin2mono_enc, 324 | bin2mono_dec, 325 | acoustic_mem, 326 | ) 327 | 328 | -------------------------------------------------------------------------------- /audio_separation/pretrain/passive/passive_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | from typing import Dict 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from tqdm import tqdm 11 | from torch.utils.data import DataLoader 12 | 13 | from habitat import logger 14 | 15 | from audio_separation.common.base_trainer import BaseRLTrainer 16 | from audio_separation.common.baseline_registry import baseline_registry 17 | from audio_separation.common.env_utils import construct_envs 18 | from audio_separation.common.environments import get_env_class 19 | from audio_separation.common.tensorboard_utils import TensorboardWriter 20 | from audio_separation.pretrain.passive.policy import Move2HearPassiveWoMemoryPolicy 21 | from audio_separation.pretrain.passive.passive import Passive 22 | from audio_separation.pretrain.datasets.dataset import PassiveDataset 23 | from habitat_audio.utils import 
load_points_data 24 | 25 | 26 | SCENE_SPLITS = { 27 | "mp3d": 28 | { 29 | 'train': 30 | ['sT4fr6TAbpF', 'E9uDoFAP3SH', 'VzqfbhrpDEA', 'kEZ7cmS4wCh', '29hnd4uzFmX', 'ac26ZMwG7aT', 31 | 's8pcmisQ38h', 'rPc6DW4iMge', 'EDJbREhghzL', 'mJXqzFtmKg4', 'B6ByNegPMKs', 32 | 'JeFG25nYj2p', '82sE5b5pLXE', 'D7N2EKCX4Sj', '7y3sRwLe3Va', '5LpN3gDmAk7', 33 | 'gTV8FGcVJC9', 'ur6pFq6Qu1A', 'qoiz87JEwZ2', 'PuKPg4mmafe', 'VLzqgDo317F', 'aayBHfsNo7d', 34 | 'JmbYfDe2QKZ', 'XcA2TqTSSAj', '8WUmhLawc2A', 'sKLMLpTHeUy', 'r47D5H71a5s', 'Uxmj2M2itWa', 35 | 'Pm6F8kyY3z2', 'p5wJjkQkbXX', '759xd9YjKW5', 'JF19kD82Mey', 'V2XKFyX4ASd', '1LXtFkjw3qL', 36 | '17DRP5sb8fy', '5q7pvUzZiYa', 'VVfe2KiqLaN', 'Vvot9Ly1tCj', 'ULsKaCPVFJR', 'D7G3Y4RVNrH', 37 | 'uNb9QFRL6hY', 'ZMojNkEp431', '2n8kARJN3HM', 'vyrNrziPKCB', 'e9zR4mvMWw7', 'r1Q1Z4BcV1o', 38 | 'PX4nDJXEHrG', 'YmJkqBEsHnH', 'b8cTxDM8gDG', 'GdvgFV5R1Z5', 'pRbA3pwrgk9', 'jh4fc5c5qoQ', 39 | '1pXnuDYAj8r', 'S9hNv5qa7GM', 'VFuaQ6m2Qom', 'cV4RVeZvu5T', 'SN83YJsR3w2'], 40 | 'val': 41 | ['x8F5xyUWy9e', 'QUCTc6BB5sX', 'EU6Fwq7SyZv', '2azQ1b91cZZ', 'Z6MFQCViBuw', 'pLe4wQe7qrG', 42 | 'oLBMNvg9in8', 'X7HyMhZNoso', 'zsNo4HB9uLZ', 'TbHJrupSAjP', '8194nk5LbLH'], 43 | }, 44 | } 45 | 46 | 47 | EPS = 1e-7 48 | 49 | 50 | @baseline_registry.register_trainer(name="passive") 51 | class PassiveTrainer(BaseRLTrainer): 52 | r"""Trainer class for pretraining passive separators in a supervised fashion 53 | """ 54 | # supported_tasks = ["Nav-v0"] 55 | 56 | def __init__(self, config=None): 57 | super().__init__(config) 58 | self.actor_critic = None 59 | self.agent = None 60 | self.envs = None 61 | 62 | def _setup_passive_agent(self,) -> None: 63 | r"""Sets up agent for passive pretraining. 64 | 65 | Args: 66 | None 67 | Returns: 68 | None 69 | """ 70 | logger.add_filehandler(self.config.LOG_FILE) 71 | passive_cfg = self.config.Pretrain.Passive 72 | 73 | self.actor_critic = Move2HearPassiveWoMemoryPolicy( 74 | observation_space=self.envs.observation_spaces[0], 75 | ) 76 | 77 | self.actor_critic.to(self.device) 78 | self.actor_critic.train() 79 | 80 | self.agent = Passive( 81 | actor_critic=self.actor_critic, 82 | ) 83 | 84 | def save_checkpoint(self, file_name: str,) -> None: 85 | r"""Save checkpoint with specified name. 86 | 87 | Args: 88 | file_name: file name for checkpoint 89 | 90 | Returns: 91 | None 92 | """ 93 | checkpoint = { 94 | "state_dict": self.agent.state_dict(), 95 | "config": self.config, 96 | } 97 | torch.save( 98 | checkpoint, os.path.join(self.config.CHECKPOINT_FOLDER, file_name) 99 | ) 100 | 101 | def load_checkpoint(self, checkpoint_path: str, *args, **kwargs) -> Dict: 102 | r"""Load checkpoint of specified path as a dict. 
103 | 104 | Args: 105 | checkpoint_path: path of target checkpoint 106 | *args: additional positional args 107 | **kwargs: additional keyword args 108 | 109 | Returns: 110 | dict containing checkpoint info 111 | """ 112 | return torch.load(checkpoint_path, *args, **kwargs) 113 | 114 | def get_dataloaders(self): 115 | r""" 116 | build datasets and dataloaders 117 | :return: 118 | dataloaders: PyTorch dataloaders for training and validation 119 | dataset_sizes: sizes of train and val datasets 120 | """ 121 | sim_cfg = self.config.TASK_CONFIG.SIMULATOR 122 | audio_cfg = sim_cfg.AUDIO 123 | 124 | scene_splits = {"train": SCENE_SPLITS[sim_cfg.SCENE_DATASET]["train"], 125 | "val": SCENE_SPLITS[sim_cfg.SCENE_DATASET]["val"], 126 | "nonoverlapping_val": SCENE_SPLITS[sim_cfg.SCENE_DATASET]["val"]} 127 | datasets = dict() 128 | dataloaders = dict() 129 | dataset_sizes = dict() 130 | for split in scene_splits: 131 | scenes = scene_splits[split] 132 | scene_graphs = dict() 133 | for scene in scenes: 134 | _, graph = load_points_data( 135 | os.path.join(audio_cfg.META_DIR, scene), 136 | audio_cfg.GRAPH_FILE, 137 | transform=True, 138 | scene_dataset=sim_cfg.SCENE_DATASET) 139 | scene_graphs[scene] = graph 140 | 141 | datasets[split] = PassiveDataset( 142 | split=split, 143 | scene_graphs=scene_graphs, 144 | sim_cfg=sim_cfg, 145 | ) 146 | 147 | dataloaders[split] = DataLoader(dataset=datasets[split], 148 | batch_size=audio_cfg.BATCH_SIZE, 149 | shuffle=(split == 'train'), 150 | pin_memory=True, 151 | num_workers=audio_cfg.NUM_WORKER, 152 | ) 153 | dataset_sizes[split] = len(datasets[split]) 154 | print('{} has {} samples'.format(split.upper(), dataset_sizes[split])) 155 | return dataloaders, dataset_sizes 156 | 157 | def train(self) -> None: 158 | r"""Main method for training passive separators using supervised learning. 
159 | 160 | Returns: 161 | None 162 | """ 163 | passive_cfg = self.config.Pretrain.Passive 164 | sim_cfg = self.config.TASK_CONFIG.SIMULATOR 165 | audio_cfg = sim_cfg.AUDIO 166 | 167 | logger.info(f"config: {self.config}") 168 | random.seed(self.config.SEED) 169 | np.random.seed(self.config.SEED) 170 | torch.manual_seed(self.config.SEED) 171 | 172 | # just needed to get observation_spaces 173 | self.envs = construct_envs( 174 | self.config, get_env_class(self.config.ENV_NAME) 175 | ) 176 | 177 | self.device = ( 178 | torch.device("cuda", self.config.TORCH_GPU_ID) 179 | if torch.cuda.is_available() 180 | else torch.device("cpu") 181 | ) 182 | 183 | if not os.path.isdir(self.config.CHECKPOINT_FOLDER): 184 | os.makedirs(self.config.CHECKPOINT_FOLDER) 185 | 186 | self._setup_passive_agent() 187 | 188 | logger.info( 189 | "agent number of parameters: {}".format( 190 | sum(param.numel() for param in self.agent.parameters()) 191 | ) 192 | ) 193 | 194 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.actor_critic.parameters()), 195 | lr=passive_cfg.lr, eps=passive_cfg.eps) 196 | 197 | # build datasets and dataloaders 198 | dataloaders, dataset_sizes = self.get_dataloaders() 199 | 200 | best_mono_loss = float('inf') 201 | best_nonoverlapping_mono_loss = float('inf') 202 | with TensorboardWriter( 203 | self.config.TENSORBOARD_DIR, flush_secs=self.flush_secs 204 | ) as writer: 205 | for epoch in range(passive_cfg.NUM_EPOCHS): 206 | logging.info('-' * 10) 207 | logging.info('Epoch {}/{}'.format(epoch, passive_cfg.NUM_EPOCHS - 1)) 208 | 209 | for split in dataloaders.keys(): 210 | # set forward pass mode 211 | if split == "train": 212 | self.actor_critic.train() 213 | else: 214 | self.actor_critic.eval() 215 | 216 | bin_loss_epoch = 0. 217 | mono_loss_epoch = 0. 
218 | for i, data in enumerate(tqdm(dataloaders[split])): 219 | mixed_audio = data[0].to(self.device) 220 | gt_bin_mag = data[1].to(self.device)[..., 0:2] 221 | gt_mono_mag = data[2].to(self.device)[..., :1] 222 | target_class = data[3].to(self.device) 223 | bs = target_class.size(0) 224 | 225 | obs_batch = {"mixed_bin_audio_mag": mixed_audio, "target_class": target_class} 226 | 227 | if split == "train": 228 | pred_binSepMasks = self.actor_critic.get_binSepMasks(obs_batch) 229 | pred_mono =\ 230 | self.actor_critic.convert_bin2mono(pred_binSepMasks.detach(), mixed_audio=mixed_audio) 231 | else: 232 | with torch.no_grad(): 233 | pred_binSepMasks = self.actor_critic.get_binSepMasks(obs_batch) 234 | pred_mono =\ 235 | self.actor_critic.convert_bin2mono(pred_binSepMasks.detach(), 236 | mixed_audio=mixed_audio) 237 | 238 | bin_loss, mono_loss, optimizer =\ 239 | self.optimize_supervised_loss(optimizer=optimizer, 240 | mixed_audio=mixed_audio, 241 | pred_binSepMasks=pred_binSepMasks, 242 | gt_bin_mag=gt_bin_mag, 243 | pred_mono=pred_mono, 244 | gt_mono_mag=gt_mono_mag, 245 | split=split, 246 | ) 247 | 248 | bin_loss_epoch += bin_loss.item() * bs 249 | mono_loss_epoch += mono_loss.item() * bs 250 | 251 | bin_loss_epoch /= dataset_sizes[split] 252 | mono_loss_epoch /= dataset_sizes[split] 253 | 254 | writer.add_scalar('bin_loss/{}'.format(split), bin_loss_epoch, epoch) 255 | writer.add_scalar('mono_loss/{}'.format(split), mono_loss_epoch, epoch) 256 | logging.info('{} -- bin loss: {:.4f}, mono loss: {:.4f}'.format(split.upper(), 257 | bin_loss_epoch, 258 | mono_loss_epoch)) 259 | if split == "val": 260 | if mono_loss_epoch < best_mono_loss: 261 | best_mono_loss = mono_loss_epoch 262 | self.save_checkpoint(f"best_ckpt_val.pth") 263 | elif split == "nonoverlapping_val": 264 | if mono_loss_epoch < best_nonoverlapping_mono_loss: 265 | best_nonoverlapping_mono_loss = mono_loss_epoch 266 | self.save_checkpoint(f"best_ckpt_nonoverlapping_val.pth") 267 | self.envs.close() 268 | 269 | def optimize_supervised_loss(self, optimizer, mixed_audio, pred_binSepMasks, gt_bin_mag, pred_mono, gt_mono_mag, 270 | split='train',): 271 | mixed_audio = torch.exp(mixed_audio) - 1 272 | pred_bin = pred_binSepMasks * mixed_audio 273 | bin_loss = F.l1_loss(pred_bin, gt_bin_mag) 274 | 275 | mono_loss = F.l1_loss(pred_mono, gt_mono_mag) 276 | 277 | if split == "train": 278 | optimizer.zero_grad() 279 | loss = bin_loss + mono_loss 280 | nn.utils.clip_grad_norm_( 281 | self.actor_critic.parameters(), self.config.Pretrain.Passive.max_grad_norm 282 | ) 283 | loss.backward() 284 | optimizer.step() 285 | 286 | return bin_loss, mono_loss, optimizer 287 | -------------------------------------------------------------------------------- /audio_separation/config/default.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | import os 3 | import shutil 4 | 5 | from habitat import get_config as get_task_config 6 | from habitat.config import Config as CN 7 | from habitat.config.default import SIMULATOR_SENSOR 8 | import habitat 9 | 10 | DEFAULT_CONFIG_DIR = "configs/" 11 | CONFIG_FILE_SEPARATOR = "," 12 | # ----------------------------------------------------------------------------- 13 | # EXPERIMENT CONFIG 14 | # ----------------------------------------------------------------------------- 15 | _C = CN() 16 | _C.SEED = 0 17 | _C.BASE_TASK_CONFIG_PATH = "configs/tasks/pointnav.yaml" 18 | _C.TASK_CONFIG = CN() # task_config will be stored as a config node 19 | 
_C.CMD_TRAILING_OPTS = [] # store command line options as list of strings 20 | _C.TRAINER_NAME = "ppo" 21 | _C.ENV_NAME = "AAViSSEnv" 22 | _C.SIMULATOR_GPU_ID = 0 23 | _C.TORCH_GPU_ID = 0 24 | _C.PARALLEL_GPU_IDS = [] 25 | _C.MODEL_DIR = '' 26 | _C.TENSORBOARD_DIR = "tb" 27 | _C.VIDEO_OPTION = [] 28 | _C.EVAL_CKPT_PATH_DIR = "data/checkpoints" # path to ckpt or path to ckpts dir 29 | _C.NUM_PROCESSES = 16 30 | _C.SENSORS = ["RGB_SENSOR", "DEPTH_SENSOR"] 31 | _C.CHECKPOINT_FOLDER = "data/checkpoints" 32 | _C.NUM_UPDATES = 10000 33 | _C.LOG_INTERVAL = 10 34 | _C.LOG_FILE = "train.log" 35 | _C.CHECKPOINT_INTERVAL = 50 36 | _C.USE_VECENV = True 37 | _C.USE_SYNC_VECENV = False 38 | _C.EXTRA_RGB = False 39 | _C.EXTRA_DEPTH = False 40 | _C.DEBUG = False 41 | _C.NUM_SOUNDS_IN_MIX = 2 42 | _C.COMPUTE_EVAL_METRICS = False 43 | _C.EVAL_METRICS_TO_COMPUTE = ['si_sdr',] 44 | _C.EPS_SCENES = [] 45 | _C.EPS_SCENES_N_IDS = [] 46 | _C.JOB_ID = 1 47 | 48 | # ----------------------------------------------------------------------------- 49 | # EVAL CONFIG 50 | # ----------------------------------------------------------------------------- 51 | _C.EVAL = CN() 52 | # The split to evaluate on 53 | _C.EVAL.SPLIT = "val" 54 | _C.EVAL.USE_CKPT_CONFIG = True 55 | 56 | # ----------------------------------------------------------------------------- 57 | # REINFORCEMENT LEARNING (RL) ENVIRONMENT CONFIG 58 | # ----------------------------------------------------------------------------- 59 | _C.RL = CN() 60 | _C.RL.SUCCESS_REWARD = 10.0 61 | _C.RL.SLACK_REWARD = -0.01 62 | _C.RL.WITH_DISTANCE_REWARD = True 63 | _C.RL.DISTANCE_REWARD_SCALE = 1.0 64 | # ----------------------------------------------------------------------------- 65 | # PROXIMAL POLICY OPTIMIZATION (PPO) 66 | # ----------------------------------------------------------------------------- 67 | _C.RL.PPO = CN() 68 | _C.RL.PPO.num_updates_per_cycle = 1 69 | _C.RL.PPO.pretrained_passive_separators_ckpt = "" 70 | _C.RL.PPO.train_passive_separators = False 71 | _C.RL.PPO.clip_param = 0.2 72 | _C.RL.PPO.ppo_epoch = 4 73 | _C.RL.PPO.num_mini_batch = 16 74 | _C.RL.PPO.value_loss_coef = 0.5 75 | _C.RL.PPO.bin_separation_loss_coef = 1.0 76 | _C.RL.PPO.mono_conversion_loss_coef = 1.0 77 | _C.RL.PPO.entropy_coef = 0.01 78 | _C.RL.PPO.lr_pol = 1e-3 79 | _C.RL.PPO.lr_sep = 1e-3 80 | _C.RL.PPO.eps = 1e-5 81 | _C.RL.PPO.max_grad_norm = 0.5 82 | _C.RL.PPO.num_steps = 5 83 | _C.RL.PPO.hidden_size = 512 84 | _C.RL.PPO.use_gae = True 85 | _C.RL.PPO.use_linear_lr_decay = False 86 | _C.RL.PPO.use_linear_clip_decay = False 87 | _C.RL.PPO.gamma = 0.99 88 | _C.RL.PPO.tau = 0.95 89 | _C.RL.PPO.reward_window_size = 50 90 | _C.RL.PPO.nav_reward_weight = 0.0 91 | _C.RL.PPO.sep_reward_weight = 1.0 92 | _C.RL.PPO.extra_reward_multiplier = 10.0 93 | _C.RL.PPO.deterministic_eval = False 94 | _C.RL.PPO.use_ddppo = False 95 | _C.RL.PPO.ddppo_distrib_backend = "NCCL" 96 | _C.RL.PPO.short_rollout_threshold = 0.25 97 | _C.RL.PPO.sync_frac = 0.6 98 | _C.RL.PPO.master_port = 8738 99 | _C.RL.PPO.master_addr = "127.0.0.1" 100 | _C.RL.PPO.switch_policy = False 101 | _C.RL.PPO.time_thres_for_pol_switch = 80 102 | 103 | # ----------------------------------------------------------------------------- 104 | # Pretraining passive separator 105 | # ----------------------------------------------------------------------------- 106 | _C.Pretrain = CN() 107 | _C.Pretrain.Passive = CN() 108 | _C.Pretrain.Passive.lr = 5.0e-4 109 | _C.Pretrain.Passive.eps = 1.0e-5 110 | _C.Pretrain.Passive.max_grad_norm = 0.8 
111 | _C.Pretrain.Passive.NUM_EPOCHS = 1000 112 | 113 | # ----------------------------------------------------------------------------- 114 | # TASK CONFIG 115 | # ----------------------------------------------------------------------------- 116 | _TC = habitat.get_config() 117 | _TC.defrost() 118 | 119 | ########## ACTIONS ########### 120 | # ----------------------------------------------------------------------------- 121 | # PAUSE ACTION 122 | # ----------------------------------------------------------------------------- 123 | _TC.TASK.ACTIONS.PAUSE = CN() 124 | _TC.TASK.ACTIONS.PAUSE.TYPE = "PauseAction" 125 | 126 | ########## SENSORS ########### 127 | # ----------------------------------------------------------------------------- 128 | # MIXED BINAURAL AUDIO MAGNITUDE SENSOR 129 | # ----------------------------------------------------------------------------- 130 | _TC.TASK.MIXED_BIN_AUDIO_MAG_SENSOR = CN() 131 | _TC.TASK.MIXED_BIN_AUDIO_MAG_SENSOR.TYPE = "MixedBinAudioMagSensor" 132 | _TC.TASK.MIXED_BIN_AUDIO_MAG_SENSOR.FEATURE_SHAPE = [512, 32, 2] 133 | # ----------------------------------------------------------------------------- 134 | # MIXED BINAURAL AUDIO PHASE SENSOR 135 | # ----------------------------------------------------------------------------- 136 | _TC.TASK.MIXED_BIN_AUDIO_PHASE_SENSOR = CN() 137 | _TC.TASK.MIXED_BIN_AUDIO_PHASE_SENSOR.TYPE = "MixedBinAudioPhaseSensor" 138 | _TC.TASK.MIXED_BIN_AUDIO_PHASE_SENSOR.FEATURE_SHAPE = [512, 32, 2] 139 | # ----------------------------------------------------------------------------- 140 | # GROUND-TRUTH MONO COMPONENTS SENSOR 141 | # ----------------------------------------------------------------------------- 142 | _TC.TASK.GT_MONO_COMPONENTS_SENSOR = CN() 143 | _TC.TASK.GT_MONO_COMPONENTS_SENSOR.TYPE = "GtMonoComponentsSensor" 144 | # default for 1 sound in the mixture ([mag, phase]) 145 | _TC.TASK.GT_MONO_COMPONENTS_SENSOR.FEATURE_SHAPE = [512, 32, 2] 146 | # ----------------------------------------------------------------------------- 147 | # GROUND-TRUTH BINAURAL COMPONENTS SENSOR 148 | # ----------------------------------------------------------------------------- 149 | _TC.TASK.GT_BIN_COMPONENTS_SENSOR = CN() 150 | _TC.TASK.GT_BIN_COMPONENTS_SENSOR.TYPE = "GtBinComponentsSensor" 151 | # default for 1 sound in the mixture ([mag_l, phase_l, mag_r, phase_r]) 152 | _TC.TASK.GT_BIN_COMPONENTS_SENSOR.FEATURE_SHAPE = [512, 32, 4] 153 | # ----------------------------------------------------------------------------- 154 | # TARGET CLASS SENSOR 155 | # ----------------------------------------------------------------------------- 156 | _TC.TASK.TARGET_CLASS_SENSOR = SIMULATOR_SENSOR.clone() 157 | _TC.TASK.TARGET_CLASS_SENSOR.TYPE = "TargetClassSensor" 158 | 159 | ########## MEASURES IN INFO ########### 160 | # ----------------------------------------------------------------------------- 161 | # Geodesic Distance to Target Audio Source Measure 162 | # ----------------------------------------------------------------------------- 163 | _TC.TASK.GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE = CN() 164 | _TC.TASK.GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE.TYPE = "GeoDistanceToTargetAudioSource" 165 | # ----------------------------------------------------------------------------- 166 | # Normalized Geodesic Distance to Target Audio Source Measure 167 | # ----------------------------------------------------------------------------- 168 | _TC.TASK.NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE = CN() 169 | 
_TC.TASK.NORMALIZED_GEODESIC_DISTANCE_TO_TARGET_AUDIO_SOURCE.TYPE = "NormalizedGeoDistanceToTargetAudioSource" 170 | 171 | # ----------------------------------------------------------------------------- 172 | # simulator config 173 | # ----------------------------------------------------------------------------- 174 | _TC.SIMULATOR.SEED = -1 175 | _TC.SIMULATOR.SCENE_DATASET = "mp3d" 176 | _TC.SIMULATOR.MAX_EPISODE_STEPS = 20 177 | _TC.SIMULATOR.GRID_SIZE = 1.0 178 | _TC.SIMULATOR.USE_RENDERED_OBSERVATIONS = True 179 | _TC.SIMULATOR.RENDERED_OBSERVATIONS = "../sound_spaces/scene_observations_new/" 180 | 181 | # ----------------------------------------------------------------------------- 182 | # audio config 183 | # ----------------------------------------------------------------------------- 184 | _TC.SIMULATOR.AUDIO = CN() 185 | _TC.SIMULATOR.AUDIO.MONO_DIR = "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 186 | _TC.SIMULATOR.AUDIO.RIR_DIR = "../sound_spaces/binaural_rirs/mp3d" 187 | _TC.SIMULATOR.AUDIO.META_DIR = "../sound_spaces/metadata/mp3d" 188 | _TC.SIMULATOR.AUDIO.PASSIVE_DATASET_VERSION = "v1" 189 | _TC.SIMULATOR.AUDIO.SOURCE_AGENT_LOCATION_DATAPOINTS_DIR = "data/passive_datasets/" 190 | _TC.SIMULATOR.AUDIO.PASSIVE_TRAIN_AUDIO_DIR = "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/train_preprocessed" 191 | _TC.SIMULATOR.AUDIO.PASSIVE_NONOVERLAPPING_VAL_AUDIO_DIR = "data/audio_data/VoxCelebV1TenClasses_MITMusic_ESC50/val_preprocessed" 192 | _TC.SIMULATOR.AUDIO.NUM_PASSIVE_DATAPOINTS_PER_SCENE = 30000 193 | _TC.SIMULATOR.AUDIO.NUM_PASSIVE_DATAPOINTS_PER_SCENE_EVAL = 1000 194 | _TC.SIMULATOR.AUDIO.GRAPH_FILE = 'graph.pkl' 195 | _TC.SIMULATOR.AUDIO.POINTS_FILE = 'points.txt' 196 | _TC.SIMULATOR.AUDIO.NUM_WORKER = 4 197 | _TC.SIMULATOR.AUDIO.BATCH_SIZE = 128 198 | _TC.SIMULATOR.AUDIO.GT_MONO_MAG_NORM = 0.0 199 | _TC.SIMULATOR.AUDIO.NORM_TYPE = "l2" 200 | _TC.SIMULATOR.AUDIO.RIR_SAMPLING_RATE = 16000 201 | 202 | # ----------------------------------------------------------------------------- 203 | # Dataset extension 204 | # ----------------------------------------------------------------------------- 205 | _TC.DATASET.VERSION = 'v1' 206 | 207 | 208 | def merge_from_path(config, config_paths): 209 | """ 210 | merge config with configs from config paths 211 | :param config: original unmerged config 212 | :param config_paths: config paths to merge configs from 213 | :return: merged config 214 | """ 215 | if config_paths: 216 | if isinstance(config_paths, str): 217 | if CONFIG_FILE_SEPARATOR in config_paths: 218 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) 219 | else: 220 | config_paths = [config_paths] 221 | 222 | for config_path in config_paths: 223 | config.merge_from_file(config_path) 224 | 225 | return config 226 | 227 | 228 | def get_config( 229 | config_paths: Optional[Union[List[str], str]] = None, 230 | opts: Optional[list] = None, 231 | model_dir: Optional[str] = None, 232 | run_type: Optional[str] = None 233 | ) -> CN: 234 | r"""Create a unified config with default values overwritten by values from 235 | `config_paths` and overwritten by options from `opts`. 236 | Args: 237 | config_paths: List of config paths or string that contains comma 238 | separated list of config paths. 239 | opts: Config options (keys, values) in a list (e.g., passed from 240 | command line into the config. For example, `opts = ['FOO.BAR', 241 | 0.5]`. Argument can be used for parameter sweeping or quick tests. 
242 | model_dir: suffix for output dirs 243 | run_type: either train or eval 244 | """ 245 | config = merge_from_path(_C.clone(), config_paths) 246 | config.TASK_CONFIG = get_task_config(config_paths=config.BASE_TASK_CONFIG_PATH) 247 | 248 | if opts: 249 | config.CMD_TRAILING_OPTS = opts 250 | config.merge_from_list(opts) 251 | 252 | assert model_dir is not None, "set --model-dir" 253 | config.MODEL_DIR = model_dir 254 | config.TENSORBOARD_DIR = os.path.join(config.MODEL_DIR, config.TENSORBOARD_DIR) 255 | config.CHECKPOINT_FOLDER = os.path.join(config.MODEL_DIR, 'data') 256 | config.LOG_FILE = os.path.join(config.MODEL_DIR, config.LOG_FILE) 257 | config.EVAL_CKPT_PATH_DIR = os.path.join(config.MODEL_DIR, 'data') 258 | 259 | dirs = [config.TENSORBOARD_DIR, config.CHECKPOINT_FOLDER] 260 | if run_type == 'train': 261 | # check dirs 262 | if any([os.path.exists(d) for d in dirs]): 263 | for d in dirs: 264 | if os.path.exists(d): 265 | print('{} exists'.format(d)) 266 | key = input('Output directory already exists! Overwrite the folder? (y/n)') 267 | if key == 'y': 268 | for d in dirs: 269 | if os.path.exists(d): 270 | shutil.rmtree(d) 271 | 272 | config.TASK_CONFIG.defrost() 273 | config.TASK_CONFIG.SIMULATOR.USE_SYNC_VECENV = config.USE_SYNC_VECENV 274 | 275 | config.TASK_CONFIG.TASK.GT_MONO_COMPONENTS_SENSOR.FEATURE_SHAPE[2] *= config.NUM_SOUNDS_IN_MIX 276 | config.TASK_CONFIG.TASK.GT_BIN_COMPONENTS_SENSOR.FEATURE_SHAPE[2] *= config.NUM_SOUNDS_IN_MIX 277 | 278 | config.TASK_CONFIG.SIMULATOR.MAX_EPISODE_STEPS = config.TASK_CONFIG.ENVIRONMENT.MAX_EPISODE_STEPS 279 | 280 | if config.RL.PPO.switch_policy: 281 | config.EVAL.USE_CKPT_CONFIG = False 282 | config.NUM_PROCESSES = 1 283 | 284 | config.TASK_CONFIG.freeze() 285 | 286 | config.freeze() 287 | 288 | return config 289 | 290 | 291 | def get_task_config( 292 | config_paths: Optional[Union[List[str], str]] = None, 293 | opts: Optional[list] = None 294 | ) -> habitat.Config: 295 | r""" 296 | get config after merging configs stored in yaml files and command line arguments 297 | :param config_paths: paths to configs 298 | :param opts: optional command line arguments 299 | :return: merged config 300 | """ 301 | config = _TC.clone() 302 | if config_paths: 303 | if isinstance(config_paths, str): 304 | if CONFIG_FILE_SEPARATOR in config_paths: 305 | config_paths = config_paths.split(CONFIG_FILE_SEPARATOR) 306 | else: 307 | config_paths = [config_paths] 308 | 309 | for config_path in config_paths: 310 | config.merge_from_file(config_path) 311 | 312 | if opts: 313 | config.merge_from_list(opts) 314 | 315 | config.freeze() 316 | return config 317 | -------------------------------------------------------------------------------- /audio_separation/rl/ppo/ppo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.optim as optim 5 | 6 | from audio_separation.rl.ppo.ddppo_utils import distributed_mean_and_var 7 | 8 | EPS_PPO = 1e-5 9 | 10 | 11 | class PPO(nn.Module): 12 | def __init__( 13 | self, 14 | actor_critic, 15 | clip_param, 16 | ppo_epoch, 17 | num_mini_batch, 18 | value_loss_coef, 19 | bin_separation_loss_coef, 20 | mono_conversion_loss_coef, 21 | entropy_coef, 22 | lr_pol=None, 23 | lr_sep=None, 24 | eps=None, 25 | max_grad_norm=None, 26 | freeze_passive_separators=False, 27 | use_clipped_value_loss=True, 28 | use_normalized_advantage=True, 29 | ): 30 | super().__init__() 31 | self.actor_critic = actor_critic 32 | 33 | 
self.clip_param = clip_param 34 | self.ppo_epoch = ppo_epoch 35 | self.num_mini_batch = num_mini_batch 36 | 37 | self.value_loss_coef = value_loss_coef 38 | self.bin_separation_loss_coef = bin_separation_loss_coef 39 | self.mono_conversion_loss_coef = mono_conversion_loss_coef 40 | self.entropy_coef = entropy_coef 41 | 42 | self.max_grad_norm = max_grad_norm 43 | self.use_clipped_value_loss = use_clipped_value_loss 44 | self.use_normalized_advantage = use_normalized_advantage 45 | 46 | self.freeze_passive_separators=freeze_passive_separators 47 | 48 | pol_params = list(actor_critic.pol_net.parameters()) + list(actor_critic.action_dist.parameters()) +\ 49 | list(actor_critic.critic.parameters()) 50 | self.optimizer_pol = optim.Adam(pol_params, lr=lr_pol, eps=eps) 51 | 52 | sep_params = list(actor_critic.binSep_enc.parameters()) + list(actor_critic.binSep_dec.parameters()) +\ 53 | list(actor_critic.bin2mono_enc.parameters()) + list(actor_critic.bin2mono_dec.parameters()) +\ 54 | list(actor_critic.acoustic_mem.parameters()) 55 | self.optimizer_sep = optim.Adam(sep_params, lr=lr_sep, eps=eps) 56 | 57 | self.device = next(actor_critic.parameters()).device 58 | 59 | def load_pretrained_passive_separators(self, state_dict): 60 | # loading pretrained weights from passive binaural separator 61 | for name in self.actor_critic.binSep_enc.state_dict(): 62 | self.actor_critic.binSep_enc.state_dict()[name].copy_(state_dict["actor_critic.binSep_enc." + name]) 63 | for name in self.actor_critic.binSep_dec.state_dict(): 64 | self.actor_critic.binSep_dec.state_dict()[name].copy_(state_dict["actor_critic.binSep_dec." + name]) 65 | 66 | # loading pretrained weights from passive bin2mono separator 67 | for name in self.actor_critic.bin2mono_enc.state_dict(): 68 | self.actor_critic.bin2mono_enc.state_dict()[name].copy_(state_dict["actor_critic.bin2mono_enc." + name]) 69 | for name in self.actor_critic.bin2mono_dec.state_dict(): 70 | self.actor_critic.bin2mono_dec.state_dict()[name].copy_(state_dict["actor_critic.bin2mono_dec." 
+ name]) 71 | 72 | def forward(self, *x): 73 | raise NotImplementedError 74 | 75 | def get_advantages(self, rollouts_pol): 76 | advantages = rollouts_pol.returns[:-1] - rollouts_pol.value_preds[:-1] 77 | if not self.use_normalized_advantage: 78 | return advantages 79 | 80 | return (advantages - advantages.mean()) / (advantages.std() + EPS_PPO) 81 | 82 | def update_pol(self, rollouts_pol): 83 | advantages = self.get_advantages(rollouts_pol) 84 | 85 | value_loss_epoch = 0 86 | action_loss_epoch = 0 87 | dist_entropy_epoch = 0 88 | 89 | for e in range(self.ppo_epoch): 90 | data_generator = rollouts_pol.recurrent_generator( 91 | advantages, self.num_mini_batch 92 | ) 93 | 94 | for sample in data_generator: 95 | ( 96 | obs_batch, 97 | recurrent_hidden_states_pol_batch, 98 | pred_binSepMasks_batch, 99 | pred_mono_batch, 100 | pred_monoFromMem_batch, 101 | value_preds_batch, 102 | return_batch, 103 | adv_targ, 104 | actions_batch, 105 | old_action_log_probs_batch, 106 | masks_batch, 107 | ) = sample 108 | 109 | 110 | ( 111 | values, 112 | action_log_probs, 113 | dist_entropy, 114 | _, 115 | ) = self.actor_critic.evaluate_actions( 116 | obs_batch, 117 | recurrent_hidden_states_pol_batch, 118 | masks_batch, 119 | actions_batch, 120 | pred_binSepMasks=pred_binSepMasks_batch, 121 | pred_mono=pred_mono_batch, 122 | pred_monoFromMem=pred_monoFromMem_batch, 123 | ) 124 | 125 | ratio = torch.exp( 126 | action_log_probs - old_action_log_probs_batch 127 | ) 128 | surr1 = ratio * adv_targ 129 | surr2 = ( 130 | torch.clamp( 131 | ratio, 1.0 - self.clip_param, 1.0 + self.clip_param 132 | ) 133 | * adv_targ 134 | ) 135 | action_loss = -torch.min(surr1, surr2).mean() 136 | 137 | if self.use_clipped_value_loss: 138 | value_pred_clipped = value_preds_batch + ( 139 | values - value_preds_batch 140 | ).clamp(-self.clip_param, self.clip_param) 141 | value_losses = (values - return_batch).pow(2) 142 | value_losses_clipped = ( 143 | value_pred_clipped - return_batch 144 | ).pow(2) 145 | value_loss = ( 146 | 0.5 147 | * torch.max(value_losses, value_losses_clipped).mean() 148 | ) 149 | else: 150 | value_loss = 0.5 * (return_batch - values).pow(2).mean() 151 | 152 | self.optimizer_pol.zero_grad() 153 | total_loss = ( 154 | value_loss * self.value_loss_coef 155 | + action_loss 156 | - dist_entropy * self.entropy_coef 157 | ) 158 | 159 | self.before_backward(total_loss) 160 | total_loss.backward() 161 | self.after_backward(total_loss) 162 | 163 | self.before_step_pol() 164 | self.optimizer_pol.step() 165 | self.after_step() 166 | 167 | action_loss_epoch += action_loss.item() 168 | value_loss_epoch += value_loss.item() 169 | dist_entropy_epoch += dist_entropy.item() 170 | 171 | num_updates = self.ppo_epoch * self.num_mini_batch 172 | 173 | action_loss_epoch /= num_updates 174 | value_loss_epoch /= num_updates 175 | dist_entropy_epoch /= num_updates 176 | 177 | return value_loss_epoch, action_loss_epoch, dist_entropy_epoch 178 | 179 | def update_sep(self, rollouts_sep): 180 | bin_loss_epoch = 0. 181 | mono_loss_epoch = 0. 182 | monoFromMem_loss_epoch = 0. 
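        # [Editor's note] The minibatch loop below only trains the acoustic-memory
        # refinement: the passive binaural separator and bin2mono converter run under
        # torch.no_grad() (their weights stay frozen), total_loss is the L1 error between
        # pred_monoFromMem and the ground-truth mono magnitude, and bin_loss / mono_loss
        # are computed solely for logging the frozen separators' quality.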
183 | 184 | for e in range(self.ppo_epoch): 185 | data_generator = rollouts_sep.recurrent_generator(self.num_mini_batch) 186 | 187 | for sample in data_generator: 188 | ( 189 | obs_batch, 190 | pred_monoFromMem_batch, 191 | prev_pred_monoFromMem_batch, 192 | masks_batch 193 | ) = sample 194 | 195 | # use torch.no_grad since passive separators are frozen 196 | with torch.no_grad(): 197 | pred_binSepMasks =\ 198 | self.actor_critic.get_binSepMasks( 199 | obs_batch, 200 | ) 201 | pred_mono =\ 202 | self.actor_critic.convert_bin2mono(pred_binSepMasks.detach(), 203 | mixed_audio=obs_batch["mixed_bin_audio_mag"], 204 | ) 205 | 206 | prev_pred_monoFromMem_masked = prev_pred_monoFromMem_batch *\ 207 | masks_batch.unsqueeze(1).unsqueeze(2).repeat(1, 208 | *pred_mono.size()[1:] 209 | ) 210 | pred_monoFromMem =\ 211 | self.actor_critic.get_monoFromMem(pred_mono, prev_pred_monoFromMem_masked) 212 | 213 | # get monoFromMem loss 214 | gt_mono_mag = obs_batch["gt_mono_comps"][..., 0::2].clone()[..., :1] 215 | monoFromMem_loss = F.l1_loss(pred_monoFromMem, gt_mono_mag) 216 | 217 | # get bin2mono loss 218 | mono_loss = F.l1_loss(pred_mono, gt_mono_mag) 219 | 220 | # get bin-sep loss 221 | gt_bin_mag = obs_batch["gt_bin_comps"][..., 0::2].clone()[..., :2] 222 | pred_bin = (torch.exp(obs_batch["mixed_bin_audio_mag"]) - 1) * pred_binSepMasks 223 | bin_loss = F.l1_loss(pred_bin, gt_bin_mag) 224 | 225 | self.optimizer_sep.zero_grad() 226 | total_loss = monoFromMem_loss 227 | 228 | self.before_backward(total_loss) 229 | total_loss.backward() 230 | self.after_backward(total_loss) 231 | 232 | self.before_step_sep() 233 | self.optimizer_sep.step() 234 | self.after_step() 235 | 236 | bin_loss_epoch += bin_loss.item() 237 | mono_loss_epoch += mono_loss.item() 238 | monoFromMem_loss_epoch += monoFromMem_loss.item() 239 | 240 | num_updates = self.ppo_epoch * self.num_mini_batch 241 | 242 | bin_loss_epoch /= num_updates 243 | mono_loss_epoch /= num_updates 244 | monoFromMem_loss_epoch /= num_updates 245 | 246 | return bin_loss_epoch, mono_loss_epoch, monoFromMem_loss_epoch 247 | 248 | def before_backward(self, loss): 249 | pass 250 | 251 | def after_backward(self, loss): 252 | pass 253 | 254 | def before_step_pol(self): 255 | pol_params = list(self.actor_critic.pol_net.parameters()) +\ 256 | list(self.actor_critic.action_dist.parameters()) +\ 257 | list(self.actor_critic.critic.parameters()) 258 | nn.utils.clip_grad_norm_( 259 | pol_params, self.max_grad_norm 260 | ) 261 | 262 | def before_step_sep(self): 263 | sep_params = list(self.actor_critic.binSep_enc.parameters()) + list(self.actor_critic.binSep_dec.parameters()) +\ 264 | list(self.actor_critic.bin2mono_enc.parameters()) + list(self.actor_critic.bin2mono_dec.parameters()) +\ 265 | list(self.actor_critic.acoustic_mem.parameters()) 266 | nn.utils.clip_grad_norm_( 267 | sep_params, self.max_grad_norm 268 | ) 269 | 270 | def after_step(self): 271 | pass 272 | 273 | 274 | class DecentralizedDistributedMixin: 275 | def _get_advantages_distributed( 276 | self, rollouts_nav 277 | ) -> torch.Tensor: 278 | advantages = rollouts_nav.returns[:-1] - rollouts_nav.value_preds[:-1] 279 | if not self.use_normalized_advantage: 280 | return advantages 281 | 282 | mean, var = distributed_mean_and_var(advantages) 283 | 284 | return (advantages - mean) / (var.sqrt() + EPS_PPO) 285 | 286 | def init_distributed(self, find_unused_params: bool = True) -> None: 287 | r"""Initializes distributed training for the model 288 | 1. 
Broadcasts the model weights from world_rank 0 to all other workers 289 | 2. Adds gradient hooks to the model 290 | :param find_unused_params: Whether or not to filter out unused parameters 291 | before gradient reduction. This *must* be True if 292 | there are any parameters in the model that where unused in the 293 | forward pass, otherwise the gradient reduction 294 | will not work correctly. 295 | """ 296 | # NB: Used to hide the hooks from the nn.Module, 297 | # so they don't show up in the state_dict 298 | class Guard: 299 | def __init__(self, model, device): 300 | if torch.cuda.is_available(): 301 | self.ddp = torch.nn.parallel.DistributedDataParallel( 302 | model, device_ids=[device], output_device=device 303 | ) 304 | else: 305 | self.ddp = torch.nn.parallel.DistributedDataParallel(model) 306 | 307 | self._ddp_hooks = Guard(self.actor_critic, self.device) 308 | self.get_advantages = self._get_advantages_distributed 309 | 310 | self.reducer = self._ddp_hooks.ddp.reducer 311 | self.find_unused_params = find_unused_params 312 | 313 | def before_backward(self, loss): 314 | super().before_backward(loss) 315 | 316 | if self.find_unused_params: 317 | self.reducer.prepare_for_backward([loss]) 318 | else: 319 | self.reducer.prepare_for_backward([]) 320 | 321 | 322 | class DDPPO(DecentralizedDistributedMixin, PPO): 323 | pass 324 | -------------------------------------------------------------------------------- /audio_separation/common/sync_vector_env.py: -------------------------------------------------------------------------------- 1 | from multiprocessing.context import BaseContext 2 | from threading import Thread 3 | from typing import ( 4 | Any, 5 | Callable, 6 | Dict, 7 | List, 8 | Optional, 9 | Sequence, 10 | Set, 11 | Tuple, 12 | Union, 13 | ) 14 | 15 | import gym 16 | import numpy as np 17 | from gym.spaces.dict_space import Dict as SpaceDict 18 | 19 | import habitat 20 | from habitat.config import Config 21 | from habitat.core.env import Env, Observations, RLEnv 22 | from habitat.core.utils import tile_images 23 | 24 | try: 25 | # Use torch.multiprocessing if we can. 26 | # We have yet to find a reason to not use it and 27 | # you are required to use it when sending a torch.Tensor 28 | # between processes 29 | import torch.multiprocessing as mp 30 | except ImportError: 31 | import multiprocessing as mp 32 | 33 | STEP_COMMAND = "step" 34 | RESET_COMMAND = "reset" 35 | RENDER_COMMAND = "render" 36 | CLOSE_COMMAND = "close" 37 | OBSERVATION_SPACE_COMMAND = "observation_space" 38 | ACTION_SPACE_COMMAND = "action_space" 39 | CALL_COMMAND = "call" 40 | EPISODE_COMMAND = "current_episode" 41 | 42 | 43 | def _make_env_fn( 44 | config: Config, dataset: Optional[habitat.Dataset] = None, rank: int = 0 45 | ) -> Env: 46 | """Constructor for default habitat `env.Env`. 47 | 48 | :param config: configuration for environment. 49 | :param dataset: dataset for environment. 
50 | :param rank: rank for setting seed of environment 51 | :return: `env.Env` / `env.RLEnv` object 52 | """ 53 | habitat_env = Env(config=config, dataset=dataset) 54 | habitat_env.seed(config.SEED + rank) 55 | return habitat_env 56 | 57 | 58 | class WorkerEnv: 59 | def __init__(self, env_fn, env_fn_arg, auto_reset_done): 60 | self._env = env_fn(*env_fn_arg) 61 | self._auto_reset_done = auto_reset_done 62 | 63 | def __call__(self, command, data): 64 | while command != CLOSE_COMMAND: 65 | if command == STEP_COMMAND: 66 | # different step methods for habitat.RLEnv and habitat.Env 67 | if isinstance(self._env, habitat.RLEnv) or isinstance( 68 | self._env, gym.Env 69 | ): 70 | # habitat.RLEnv 71 | observations, reward, done, info = self._env.step(**data) 72 | if self._auto_reset_done and done: 73 | observations = self._env.reset() 74 | return observations, reward, done, info 75 | elif isinstance(self._env, habitat.Env): 76 | # habitat.Env 77 | observations = self._env.step(**data) 78 | if self._auto_reset_done and self._env.episode_over: 79 | observations = self._env.reset() 80 | return observations 81 | else: 82 | raise NotImplementedError 83 | 84 | elif command == RESET_COMMAND: 85 | observations = self._env.reset() 86 | return observations 87 | 88 | elif command == RENDER_COMMAND: 89 | return self._env.render(*data[0], **data[1]) 90 | 91 | elif ( 92 | command == OBSERVATION_SPACE_COMMAND 93 | or command == ACTION_SPACE_COMMAND 94 | ): 95 | if isinstance(command, str): 96 | return getattr(self._env, command) 97 | 98 | elif command == CALL_COMMAND: 99 | function_name, function_args = data 100 | if function_args is None or len(function_args) == 0: 101 | result = getattr(self._env, function_name)() 102 | else: 103 | result = getattr(self._env, function_name)(**function_args) 104 | return result 105 | 106 | # TODO: update CALL_COMMAND for getting attribute like this 107 | elif command == EPISODE_COMMAND: 108 | return self._env.current_episode 109 | else: 110 | raise NotImplementedError 111 | 112 | 113 | class SyncVectorEnv: 114 | r"""Vectorized environment which creates multiple processes where each 115 | process runs its own environment. Main class for parallelization of 116 | training and evaluation. 117 | 118 | 119 | All the environments are synchronized on step and reset methods. 120 | """ 121 | 122 | observation_spaces: List[SpaceDict] 123 | action_spaces: List[SpaceDict] 124 | _workers: List[Union[mp.Process, Thread]] 125 | _is_waiting: bool 126 | _num_envs: int 127 | _auto_reset_done: bool 128 | _mp_ctx: BaseContext 129 | _connection_read_fns: List[Callable[[], Any]] 130 | _connection_write_fns: List[Callable[[Any], None]] 131 | 132 | def __init__( 133 | self, 134 | make_env_fn: Callable[..., Union[Env, RLEnv]] = _make_env_fn, 135 | env_fn_args: Sequence[Tuple] = None, 136 | auto_reset_done: bool = True, 137 | multiprocessing_start_method: str = "forkserver", 138 | ) -> None: 139 | """.. 140 | 141 | :param make_env_fn: function which creates a single environment. An 142 | environment can be of type `env.Env` or `env.RLEnv` 143 | :param env_fn_args: tuple of tuple of args to pass to the 144 | `_make_env_fn`. 145 | :param auto_reset_done: automatically reset the environment when 146 | done. This functionality is provided for seamless training 147 | of vectorized environments. 148 | :param multiprocessing_start_method: the multiprocessing method used to 149 | spawn worker processes. 
Valid methods are 150 | :py:`{'spawn', 'forkserver', 'fork'}`; :py:`'forkserver'` is the 151 | recommended method as it works well with CUDA. If :py:`'fork'` is 152 | used, the subproccess must be started before any other GPU useage. 153 | """ 154 | self._is_waiting = False 155 | self._is_closed = True 156 | self._num_envs = len(env_fn_args) 157 | 158 | self._auto_reset_done = auto_reset_done 159 | self.workers = [] 160 | for env_fn_arg in env_fn_args: 161 | worker = WorkerEnv(make_env_fn, env_fn_arg, auto_reset_done=True) 162 | self.workers.append(worker) 163 | 164 | self.observation_spaces = [worker(OBSERVATION_SPACE_COMMAND, None) for worker in self.workers] 165 | self.action_spaces = [worker(ACTION_SPACE_COMMAND, None) for worker in self.workers] 166 | self._paused = [] 167 | 168 | @property 169 | def num_envs(self): 170 | r"""number of individual environments. 171 | """ 172 | return self._num_envs - len(self._paused) 173 | 174 | def current_episodes(self): 175 | results = [worker(EPISODE_COMMAND, None) for worker in self.workers] 176 | return results 177 | 178 | def reset(self): 179 | r"""Reset all the vectorized environments 180 | 181 | :return: list of outputs from the reset method of envs. 182 | """ 183 | results = [worker(RESET_COMMAND, None) for worker in self.workers] 184 | return results 185 | 186 | def reset_at(self, index_env: int): 187 | r"""Reset in the index_env environment in the vector. 188 | 189 | :param index_env: index of the environment to be reset 190 | :return: list containing the output of reset method of indexed env. 191 | """ 192 | results = [self.workers[index_env](RESET_COMMAND, None)] 193 | return results 194 | 195 | def step_at(self, index_env: int, action: Dict[str, Any]): 196 | r"""Step in the index_env environment in the vector. 197 | 198 | :param index_env: index of the environment to be stepped into 199 | :param action: action to be taken 200 | :return: list containing the output of step method of indexed env. 201 | """ 202 | results = [self.workers[index_env](STEP_COMMAND, action)] 203 | return results 204 | 205 | def async_step(self, data: List[Union[int, str, Dict[str, Any]]]) -> None: 206 | r"""Asynchronously step in the environments. 207 | 208 | :param data: list of size _num_envs containing keyword arguments to 209 | pass to `step` method for each Environment. For example, 210 | :py:`[{"action": "TURN_LEFT", "action_args": {...}}, ...]`. 211 | """ 212 | # Backward compatibility 213 | if isinstance(data[0], (int, np.integer, str)): 214 | data = [{"action": {"action": action}} for action in data] 215 | 216 | self._is_waiting = True 217 | for write_fn, args in zip(self._connection_write_fns, data): 218 | write_fn((STEP_COMMAND, args)) 219 | 220 | def wait_step(self) -> List[Observations]: 221 | r"""Wait until all the asynchronized environments have synchronized. 222 | """ 223 | observations = [] 224 | for read_fn in self._connection_read_fns: 225 | observations.append(read_fn()) 226 | self._is_waiting = False 227 | return observations 228 | 229 | def step(self, data: List[Union[int, str, Dict[str, Any]]]) -> List[Any]: 230 | r"""Perform actions in the vectorized environments. 231 | 232 | :param data: list of size _num_envs containing keyword arguments to 233 | pass to `step` method for each Environment. For example, 234 | :py:`[{"action": "TURN_LEFT", "action_args": {...}}, ...]`. 235 | :return: list of outputs from the step method of envs. 
236 | """ 237 | if isinstance(data[0], (int, np.integer, str)): 238 | data = [{"action": {"action": action}} for action in data] 239 | results = [worker(STEP_COMMAND, args) for worker, args in zip(self.workers, data)] 240 | return results 241 | 242 | def close(self) -> None: 243 | if self._is_closed: 244 | return 245 | 246 | for worker in self.workers: 247 | worker(CLOSE_COMMAND, None) 248 | 249 | self._is_closed = True 250 | 251 | def pause_at(self, index: int) -> None: 252 | r"""Pauses computation on this env without destroying the env. 253 | 254 | :param index: which env to pause. All indexes after this one will be 255 | shifted down by one. 256 | 257 | This is useful for not needing to call steps on all environments when 258 | only some are active (for example during the last episodes of running 259 | eval episodes). 260 | """ 261 | worker = self.workers.pop(index) 262 | self._paused.append((index, worker)) 263 | 264 | def resume_all(self) -> None: 265 | r"""Resumes any paused envs. 266 | """ 267 | for index, worker in reversed(self._paused): 268 | self.workers.insert(index, worker) 269 | self._paused = [] 270 | 271 | def call_at( 272 | self, 273 | index: int, 274 | function_name: str, 275 | function_args: Optional[Dict[str, Any]] = None, 276 | ) -> Any: 277 | r"""Calls a function (which is passed by name) on the selected env and 278 | returns the result. 279 | 280 | :param index: which env to call the function on. 281 | :param function_name: the name of the function to call on the env. 282 | :param function_args: optional function args. 283 | :return: result of calling the function. 284 | """ 285 | self._is_waiting = True 286 | self._connection_write_fns[index]( 287 | (CALL_COMMAND, (function_name, function_args)) 288 | ) 289 | result = self._connection_read_fns[index]() 290 | self._is_waiting = False 291 | return result 292 | 293 | def call( 294 | self, 295 | function_names: List[str], 296 | function_args_list: Optional[List[Any]] = None, 297 | ) -> List[Any]: 298 | r"""Calls a list of functions (which are passed by name) on the 299 | corresponding env (by index). 300 | 301 | :param function_names: the name of the functions to call on the envs. 302 | :param function_args_list: list of function args for each function. If 303 | provided, :py:`len(function_args_list)` should be as long as 304 | :py:`len(function_names)`. 305 | :return: result of calling the function. 306 | """ 307 | self._is_waiting = True 308 | if function_args_list is None: 309 | function_args_list = [None] * len(function_names) 310 | assert len(function_names) == len(function_args_list) 311 | func_args = zip(function_names, function_args_list) 312 | for write_fn, func_args_on in zip( 313 | self._connection_write_fns, func_args 314 | ): 315 | write_fn((CALL_COMMAND, func_args_on)) 316 | results = [] 317 | for read_fn in self._connection_read_fns: 318 | results.append(read_fn()) 319 | self._is_waiting = False 320 | return results 321 | 322 | def render( 323 | self, mode: str = "human", *args, **kwargs 324 | ) -> Union[np.ndarray, None]: 325 | r"""Render observations from all environments in a tiled image. 
326 | """ 327 | for write_fn in self._connection_write_fns: 328 | write_fn((RENDER_COMMAND, (args, {"mode": "rgb", **kwargs}))) 329 | images = [read_fn() for read_fn in self._connection_read_fns] 330 | tile = tile_images(images) 331 | if mode == "human": 332 | from habitat.core.utils import try_cv2_import 333 | 334 | cv2 = try_cv2_import() 335 | 336 | cv2.imshow("vecenv", tile[:, :, ::-1]) 337 | cv2.waitKey(1) 338 | return None 339 | elif mode == "rgb_array": 340 | return tile 341 | else: 342 | raise NotImplementedError 343 | 344 | @property 345 | def _valid_start_methods(self) -> Set[str]: 346 | return {"forkserver", "spawn", "fork"} 347 | 348 | def __del__(self): 349 | self.close() 350 | 351 | def __enter__(self): 352 | return self 353 | 354 | def __exit__(self, exc_type, exc_val, exc_tb): 355 | self.close() 356 | --------------------------------------------------------------------------------
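# [Editor's sketch] A minimal, hypothetical usage of SyncVectorEnv from
# audio_separation/common/sync_vector_env.py above. The task-config path and the
# habitat.get_config call are placeholders; the constructor arguments, reset(), step()
# (which wraps plain integer actions into {"action": ...} dicts), and close() follow
# the class definition above.

import habitat

from audio_separation.common.sync_vector_env import SyncVectorEnv, _make_env_fn

# one (config, dataset, rank) tuple per worker environment
configs = [habitat.get_config("configs/tasks/pointnav.yaml") for _ in range(2)]  # placeholder path
env_fn_args = [(cfg, None, rank) for rank, cfg in enumerate(configs)]

envs = SyncVectorEnv(make_env_fn=_make_env_fn, env_fn_args=env_fn_args)
observations = envs.reset()                 # list with one observation dict per worker
results = envs.step([0] * envs.num_envs)    # integer actions; see step()/async_step() above
envs.close()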