├── nvas3d ├── model │ ├── __init__.py │ └── model.py ├── train │ ├── __init__.py │ └── trainer.py ├── utils │ ├── __init__.py │ ├── training_data_generation │ │ ├── upsample_librispeech.py │ │ ├── README.md │ │ ├── script_slakh.py │ │ ├── generate_metadata_square.py │ │ ├── generate_training_data.py │ │ └── generate_test_data.py │ ├── dynamic_utils.py │ ├── utils.py │ ├── plot_utils.py │ └── generate_dataset_utils.py ├── data_loader │ └── __init__.py ├── config │ ├── default_config.yaml │ └── visual_config.yaml ├── assets │ └── saved_models │ │ └── default │ │ └── config.yaml ├── main.py └── baseline │ └── baseline_dsp.py ├── soundspaces_nvas3d ├── __init__.py ├── utils │ ├── __init__.py │ ├── render_scene_script.py │ ├── audio_utils.py │ └── ss_utils.py ├── rir_generation │ ├── __init__.py │ ├── minimal_example_render_rir.py │ ├── generate_rir.py │ └── generate_grid.py ├── image_rendering │ ├── run_generate_envmap.py │ ├── run_generate_target_image.py │ ├── generate_envmap.py │ └── generate_target_image.py └── README.md ├── data ├── source │ ├── drum.flac │ ├── female.flac │ ├── guitar1.flac │ └── guitar2.flac └── objects │ ├── bluetooth_speaker.object_config.json │ └── classic_microphone.object_config.json ├── assets ├── videos │ └── teaser.mp4 └── images │ └── thumbnail.png ├── demo ├── config_demo │ ├── scene1_17DRP5sb8fy.json │ └── path1_17DRP5sb8fy.json ├── run_demo.sh ├── README.md └── generate_demo_data.py ├── CONTRIBUTING.md ├── setup.sh ├── LICENSE ├── CODE_OF_CONDUCT.md ├── README.md └── ACKNOWLEDGEMENTS /nvas3d/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nvas3d/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nvas3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /nvas3d/data_loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/rir_generation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data/source/drum.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/drum.flac -------------------------------------------------------------------------------- /data/source/female.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/female.flac -------------------------------------------------------------------------------- /assets/videos/teaser.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/assets/videos/teaser.mp4 -------------------------------------------------------------------------------- /data/source/guitar1.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/guitar1.flac -------------------------------------------------------------------------------- /data/source/guitar2.flac: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/guitar2.flac -------------------------------------------------------------------------------- /assets/images/thumbnail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/assets/images/thumbnail.png -------------------------------------------------------------------------------- /demo/config_demo/scene1_17DRP5sb8fy.json: -------------------------------------------------------------------------------- 1 | { 2 | "source_idx_list": [ 3 | 14, 4 | 43 5 | ], 6 | "receiver_idx_list": [ 7 | 23, 8 | 28, 9 | 29, 10 | 24 11 | ] 12 | } -------------------------------------------------------------------------------- /data/objects/bluetooth_speaker.object_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "render_asset": "bluetooth_speaker.glb", 3 | "mass": 0.41, 4 | "friction_coefficient": 0.5, 5 | "restitution_coefficient": 0.4, 6 | "use_bounding_box_for_collision": true, 7 | "margin": 0.01 8 | } -------------------------------------------------------------------------------- /data/objects/classic_microphone.object_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "render_asset": "classic_microphone.glb", 3 | "mass": 0.41, 4 | "friction_coefficient": 0.5, 5 | "restitution_coefficient": 0.4, 6 | "use_bounding_box_for_collision": true, 7 | "margin": 0.01 8 | } -------------------------------------------------------------------------------- /soundspaces_nvas3d/image_rendering/run_generate_envmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import subprocess 7 | 8 | # RENDER ENVMAP 9 | room_list = [ 10 | '17DRP5sb8fy' 11 | ] 12 | 13 | for i, room in enumerate(room_list): 14 | print(f'mp3d envmap rendering: {room}, {i+1}/{len(room_list)}') 15 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_envmap.py', '--room', room]) 16 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/image_rendering/run_generate_target_image.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import subprocess 7 | 8 | # RENDER TARGET IMAGE 9 | room_list = [ 10 | '17DRP5sb8fy' 11 | ] 12 | 13 | for i, room in enumerate(room_list): 14 | print(f'mp3d target image rendering: {room}, {i+1}/{len(room_list)}') 15 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_target_image.py', '--room', room]) 16 | -------------------------------------------------------------------------------- /nvas3d/config/default_config.yaml: -------------------------------------------------------------------------------- 1 | save_dir: './nvas3d/assets/saved_models' 2 | use_visual: False 3 | use_deconv: True 4 | 5 | data_loader: 6 | batch_size: 96 7 | num_workers: 8 8 | num_receivers: 4 9 | hop_length: 480 10 | win_length: 1200 11 | n_fft: 2048 12 | nvas3d_dataset: 'nvas3d_square_all_all' 13 | audio_format: 'flac' 14 | 15 | training: 16 | epochs: 200 17 | lr: 0.001 18 | weight_decay: 0.0 19 | detect_loss_weight: 0.1 20 | save_checkpoint_interval: 10 21 | resume: False 22 | checkpoint_dir: None 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /nvas3d/assets/saved_models/default/config.yaml: -------------------------------------------------------------------------------- 1 | save_dir: './nvas3d/assets/saved_models' 2 | use_visual: False 3 | use_deconv: True 4 | 5 | data_loader: 6 | batch_size: 96 7 | num_workers: 8 8 | num_receivers: 4 9 | hop_length: 480 10 | win_length: 1200 11 | n_fft: 2048 12 | nvas3d_dataset: 'nvas3d_demo' 13 | audio_format: 'flac' 14 | 15 | training: 16 | epochs: 200 17 | lr: 0.001 18 | weight_decay: 0.0 19 | detect_loss_weight: 0.1 20 | save_checkpoint_interval: 10 21 | resume: False 22 | checkpoint_dir: None 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /demo/config_demo/path1_17DRP5sb8fy.json: -------------------------------------------------------------------------------- 1 | { 2 | "receiver_idx_list": [ 3 | 146, 4 | 141, 5 | 135, 6 | 132, 7 | 127, 8 | 120, 9 | 110, 10 | 100, 11 | 91, 12 | 83, 13 | 77, 14 | 70, 15 | 60 16 | ], 17 | "receiver_rotation_list": [ 18 | 270, 19 | 270, 20 | 180, 21 | 180, 22 | 180, 23 | 90, 24 | 0, 25 | 0, 26 | 0, 27 | 90, 28 | 180, 29 | 180, 30 | 180 31 | ] 32 | } -------------------------------------------------------------------------------- /nvas3d/config/visual_config.yaml: -------------------------------------------------------------------------------- 1 | save_dir: './nvas3d/saved_models' 2 | use_norm: True 3 | 4 | use_visual: True 5 | use_deconv: True 6 | use_real_imag: True 7 | use_mask: False 8 | use_visual_bg: False 9 | use_beamforming: False 10 | use_legacy: True 11 | 12 | data_loader: 13 | batch_size: 96 14 | num_workers: 8 15 | num_receivers: 4 16 | hop_length: 480 17 | win_length: 1200 18 | n_fft: 2048 19 | nvas3d_dataset: 'nvas3d_square_all_all' 20 | audio_format: 'flac' 21 | 22 | training: 23 | epochs: 200 24 | lr: 0.001 25 | weight_decay: 0.0 26 | detect_loss_weight: 0.1 27 | save_checkpoint_interval: 10 28 | resume: False 29 | checkpoint_dir: None 30 | 31 | 32 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducability, and beyond its publication there are limited plans for future development of the repository. 
4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). -------------------------------------------------------------------------------- /demo/run_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ========================================= 4 | # For licensing see accompanying LICENSE file. 5 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 6 | # ========================================= 7 | 8 | # ============= 9 | # Configuration 10 | # ============= 11 | # For data generation 12 | room='17DRP5sb8fy' 13 | dataset_dir='nvas3d_demo' 14 | scene_config="scene1_17DRP5sb8fy" 15 | 16 | # For dry-sound estimation 17 | model='default' 18 | results_dir="results/$dataset_dir/$model/demo/$room/0" 19 | 20 | # For novel-view acoustic rendering 21 | novel_path_config="path1_17DRP5sb8fy" 22 | 23 | 24 | # ============= 25 | # Execution 26 | # ============= 27 | # Data generation 28 | python demo/generate_demo_data.py --room $room --dataset_dir $dataset_dir --scene_config $scene_config 29 | 30 | # Dry-sound estimation using our model 31 | python demo/test_demo.py --dataset_dir $dataset_dir --model $model 32 | 33 | # Novel-view acoustic rendering 34 | python demo/generate_demo_video.py --results_dir $results_dir --novel_path_config $novel_path_config 35 | 36 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/rir_generation/minimal_example_render_rir.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import itertools 7 | 8 | from soundspaces_nvas3d.utils.ss_utils import render_rir_parallel 9 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 10 | 11 | # Configuration 12 | grid_distance = 2.0 13 | room = '17DRP5sb8fy' 14 | 15 | # Load grid points for the specified room 16 | grid_data = load_room_grid(room, grid_distance=grid_distance) 17 | grid_points = grid_data['grid_points'] 18 | 19 | # Generate all possible combinations of source and receiver indices 20 | num_points = len(grid_points) 21 | pairs = list(itertools.product(range(num_points), repeat=2)) 22 | idx_source, idx_receiver = zip(*pairs) 23 | 24 | # Extract corresponding source and receiver points 25 | source_points = grid_points[list(idx_source)] 26 | receiver_points = grid_points[list(idx_receiver)] 27 | 28 | # Create a room list for the IR generator 29 | room_list = [room] * len(source_points) 30 | 31 | # Generate Room Impulse Responses (IRs) without saving as WAV 32 | ir_list = render_rir_parallel(room_list, source_points, receiver_points) 33 | -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/upsample_librispeech.py: -------------------------------------------------------------------------------- 1 | 2 | import torchaudio 3 | from torchaudio.transforms import Resample 4 | import os 5 | 6 | 7 | def process_directory(input_dir, output_dir, sample_rate=48000): 8 | for dirpath, dirnames, filenames in os.walk(input_dir): 9 | for filename in filenames: 10 | if filename.endswith(".flac"): 11 | input_file_path = os.path.join(dirpath, filename) 12 | waveform_, sr = torchaudio.load(input_file_path) 13 | 14 | # Resample 15 | upsample_transform = Resample(sr, sample_rate) 16 | waveform = upsample_transform(waveform_) 17 | 18 | relative_path = os.path.relpath(dirpath, input_dir) 19 | clip_output_dir = os.path.join(output_dir, relative_path) 20 | 21 | os.makedirs(clip_output_dir, exist_ok=True) 22 | 23 | output_file_path = os.path.join(clip_output_dir, filename) 24 | torchaudio.save(output_file_path, waveform, sample_rate) 25 | 26 | 27 | # Modify the paths and parameters as per your needs 28 | input_dir = "data/source/LibriSpeech/" 29 | output_dir = "data/source/LibriSpeech48k/" 30 | 31 | process_directory(input_dir, output_dir) 32 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/utils/render_scene_script.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import sys 7 | from soundspaces_nvas3d.utils.ss_utils import render_scene_config, render_receiver_image, create_scene 8 | from soundspaces_nvas3d.utils.aihabitat_utils import save_grid_heatmap 9 | 10 | 11 | def execute_scene(args): 12 | filename, room, source_idx_list, all_receiver_idx_list, grid_distance = args 13 | render_scene_config(filename, room, eval(source_idx_list), eval(all_receiver_idx_list), eval(grid_distance), no_query=False) 14 | 15 | 16 | def execute_receiver(args): 17 | dirname, room, source_idx_list, source_class_list, all_receiver_idx_list = args 18 | render_receiver_image(dirname, room, eval(source_idx_list), eval(source_class_list), eval(all_receiver_idx_list)) 19 | 20 | 21 | def execute_heatmap(args): 22 | filename, room, grid_points, prediction_list, receiver_idx_list, grid_distance = args 23 | scene = create_scene(room) 24 | save_grid_heatmap(filename, scene.sim.pathfinder, eval(grid_points), eval(prediction_list), receiver_idx_list=eval(receiver_idx_list)) 25 | 26 | 27 | # Dictionary to map function names to actual function calls 28 | functions = { 29 | "scene": execute_scene, 30 | "receiver": execute_receiver, 31 | "heatmap": execute_heatmap 32 | } 33 | 34 | function_name = sys.argv[1] 35 | 36 | # Execute the corresponding function 37 | if function_name in functions: 38 | functions[function_name](sys.argv[2:]) 39 | else: 40 | print(f"Error: Function {function_name} not found!") 41 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # ========================================= 4 | # For licensing see accompanying LICENSE file. 5 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 6 | # ========================================= 7 | # 8 | # This script automates the installation process described in the soundspaces_nvas3d/README. 9 | # Please refer to soundspaces_nvas3d/README.md for detailed instructions and explanations. 10 | 11 | # Update PYTHONPATH 12 | echo "export PYTHONPATH=\$PYTHONPATH:$(pwd)" >> ~/.bashrc && source ~/.bashrc 13 | 14 | 15 | # Install dependencies for Soundspaces 16 | apt-get update && apt-get upgrade -y 17 | 18 | apt-get install -y --no-install-recommends \ 19 | libjpeg-dev libglm-dev libgl1-mesa-glx libegl1-mesa-dev mesa-utils xorg-dev freeglut3-dev 20 | 21 | # Update conda 22 | conda update -n base -c defaults conda 23 | 24 | # Create conda environment 25 | conda create -n ml-nvas3d python=3.7 cmake=3.14.0 -y 26 | conda activate ml-nvas3d 27 | 28 | # Install torch 29 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia 30 | 31 | # Install habitat-sim 32 | cd .. 33 | git clone https://github.com/facebookresearch/habitat-sim.git 34 | cd habitat-sim 35 | pip install -r requirements.txt 36 | git checkout RLRAudioPropagationUpdate 37 | python setup.py install --headless --audio --with-cuda 38 | 39 | # Install habitat-lab 40 | cd .. 41 | git clone https://github.com/facebookresearch/habitat-lab.git 42 | cd habitat-lab 43 | git checkout v0.2.2 44 | pip install -e . 45 | sed -i '36 s/^/#/' habitat/tasks/rearrange/rearrange_sim.py # remove FetchRobot 46 | 47 | # Install soundspaces 48 | cd .. 49 | git clone https://github.com/facebookresearch/sound-spaces.git 50 | cd sound-spaces 51 | pip install -e . 
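# ----- Editor's note: optional sanity check, not part of the original script. -----
# A minimal sketch, assuming the `ml-nvas3d` conda environment is active and the
# installs above succeeded; the module names `habitat_sim` and `soundspaces` are
# assumptions based on the packages installed above and may vary across versions.
python -c "import habitat_sim, soundspaces; print('habitat-sim and sound-spaces import OK')"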
52 | 53 | # Change directory 54 | cd ml-nvas3d -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/README.md: -------------------------------------------------------------------------------- 1 | # NVAS3D Training Data Generation 2 | This guide describes the process of generating training data for NVAS3D. 3 | 4 | ## Step-by-Step Instructions 5 | 6 | ### 1. Download Matterport3D Data 7 | Download all rooms from [Matterport3D](https://niessner.github.io/Matterport/). 8 | 9 | ### 2. Download Source Audios 10 | Download the following datasets for source audio and place them under `data/source`: 11 | * [Slakh2100](https://zenodo.org/record/4599666) 12 | * [LibriSpeech](https://www.openslr.org/12) 13 | 14 | ### 3. Preprocess Source Audios 15 | #### 3.1. Clip, Split, and Upsample Slakh2100 16 | To process the Slakh2100 dataset (clipping, splitting, and upsampling to 48 kHz), execute the following command: 17 | ```bash 18 | python nvas3d/utils/training_data_generation/script_slakh.py 19 | ``` 20 | The output will be located at `data/MIDI/clip/`. 21 | 22 | #### 3.2. Upsample LibriSpeech 23 | To upsample the LibriSpeech dataset to 48 kHz, execute the following command: 24 | ```bash 25 | python nvas3d/utils/training_data_generation/upsample_librispeech.py 26 | ``` 27 | The output will be located at `data/source/LibriSpeech48k`; move it to `data/MIDI/clip/speech/LibriSpeech48k`. 28 | 29 | ### 4. Generate Metadata for Microphone Configuration 30 | To generate square-shaped microphone configuration metadata, execute the following command: 31 | ```bash 32 | python nvas3d/utils/training_data_generation/generate_metadata_square.py 33 | ``` 34 | The output metadata will be located at `data/nvas3d_square/`. 35 | 36 | ### 5. Generate Training Data 37 | Finally, to generate the training data for NVAS3D, execute the following command: 38 | ```bash 39 | python nvas3d/utils/training_data_generation/generate_training_data.py 40 | ``` 41 | The generated data will be located at `data/nvas3d_square_all_all`. 42 | 43 | ## Acknowledgements 44 | * [LibriSpeech](https://www.openslr.org/12) is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/). 45 | 46 | * [Slakh2100](http://www.slakh.com) is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/). 47 | 48 | * [Matterport3D](https://niessner.github.io/Matterport/): Matterport3D-based task datasets and trained models are distributed with the [Matterport3D Terms of Use](https://kaldir.vc.in.tum.de/matterport/MP_TOS.pdf) and are licensed under [CC BY-NC-SA 3.0 US](https://creativecommons.org/licenses/by-nc-sa/3.0/us/). 49 | 50 | 51 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/rir_generation/generate_rir.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import argparse 8 | import itertools 9 | 10 | from soundspaces_nvas3d.utils.ss_utils import render_rir_parallel 11 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 12 | 13 | 14 | def generate_rir(args: argparse.Namespace) -> None: 15 | """ 16 | Generate Room Impulse Responses (RIRs) for the given room and grid distance.
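    Note (added in editing, derived from the code below): one WAV file is written per
    (source, receiver) grid pair to
    {args.dirname}/rir_mp3d/grid_{grid_distance}/{room}/ir_{room}_{source_idx}_{receiver_idx}.wav,
    where the grid-distance component uses "_" in place of ".".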
17 | """ 18 | 19 | grid_distance_str = str(args.grid_distance).replace(".", "_") 20 | dirname = os.path.join(args.dirname, f'rir_mp3d/grid_{grid_distance_str}', args.room) 21 | os.makedirs(dirname, exist_ok=True) 22 | 23 | grid_data = load_room_grid(args.room, grid_distance=args.grid_distance) 24 | grid_points = grid_data['grid_points'] 25 | num_points = len(grid_points) 26 | 27 | # Generate combinations of source and receiver indices 28 | pairs = list(itertools.product(range(num_points), repeat=2)) 29 | source_indices, receiver_indices = zip(*pairs) 30 | 31 | room_list = [args.room] * len(source_indices) 32 | source_points = grid_points[list(source_indices)] 33 | receiver_points = grid_points[list(receiver_indices)] 34 | filename_list = [ 35 | os.path.join(dirname, f'ir_{args.room}_{source_idx}_{receiver_idx}.wav') 36 | for source_idx, receiver_idx in zip(source_indices, receiver_indices) 37 | ] 38 | 39 | render_rir_parallel(room_list, source_points, receiver_points, filename_list) 40 | 41 | 42 | if __name__ == '__main__': 43 | parser = argparse.ArgumentParser(description="Generate Room Impulse Response for given parameters.") 44 | 45 | parser.add_argument('--room', 46 | default='17DRP5sb8fy', 47 | type=str, 48 | help='MP3D room identifier') 49 | 50 | parser.add_argument('--grid_distance', 51 | default=2.0, 52 | type=float, 53 | help='Distance between grid points in meters') 54 | 55 | parser.add_argument('--dirname', 56 | default='data/examples', 57 | type=str, 58 | help='Directory to save generated RIRs') 59 | 60 | args = parser.parse_args() 61 | generate_rir(args) 62 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/image_rendering/generate_envmap.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import argparse 8 | 9 | import numpy as np 10 | import matplotlib.pyplot as plt 11 | 12 | from soundspaces_nvas3d.utils.ss_utils import create_scene 13 | 14 | """ 15 | Notes: 16 | - MP3D room should be located at: data/scene_datasets/mp3d/{room} 17 | - Grid data should be present at: data/scene_datasets/metadata/mp3d/grid_1_0/grid_{room}.npy 18 | (Refer to: rir_generation/generate_grid.py for grid generation) 19 | """ 20 | 21 | room_list = [ 22 | '17DRP5sb8fy' 23 | ] 24 | 25 | 26 | def main(args): 27 | room = args.room 28 | grid_distance = args.grid_distance 29 | 30 | print(room) 31 | image_size = (192, 144) 32 | 33 | # Load grid 34 | grid_distance_str = str(grid_distance).replace(".", "_") 35 | dirname_grid = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}' 36 | filename_grid = f'{dirname_grid}/grid_{room}.npy' 37 | grid_info = np.load(filename_grid, allow_pickle=True).item() 38 | 39 | grid_points = grid_info['grid_points'] 40 | dirname = f'data/examples/envmap_mp3d/grid_{grid_distance_str}/{room}' 41 | os.makedirs(dirname, exist_ok=True) 42 | 43 | scene = create_scene(room, image_size=image_size) 44 | 45 | for receiver_idx in range(grid_points.shape[0]): 46 | receiver_position = grid_points[receiver_idx] 47 | scene.update_receiver_position(receiver_position) 48 | rgb, depth = scene.render_envmap() 49 | 50 | filename_rgb = f'{dirname}/envmap_rgb_{room}_{receiver_idx}.png' 51 | filename_depth = f'{dirname}/envmap_depth_{room}_{receiver_idx}.png' 52 | 53 | plt.imsave(filename_rgb, rgb) 54 | plt.imsave(filename_depth, depth) 55 | 56 | filename_rgb = f'{dirname}/envmap_rgb_{room}_{receiver_idx}.npy' 57 | filename_depth = f'{dirname}/envmap_depth_{room}_{receiver_idx}.npy' 58 | 59 | np.save(filename_rgb, rgb) 60 | np.save(filename_depth, depth) 61 | 62 | 63 | if __name__ == '__main__': 64 | parser = argparse.ArgumentParser() 65 | 66 | parser.add_argument('--room', 67 | default='17DRP5sb8fy', 68 | type=str, 69 | help='mp3d room') 70 | 71 | parser.add_argument('--grid_distance', 72 | default=1.0, 73 | type=float, 74 | help='distance between grid points') 75 | 76 | args = parser.parse_args() 77 | 78 | main(args) 79 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2023 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. 
may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | 41 | ------------------------------------------------------------------------------- 42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY: 43 | 44 | The software includes a number of subcomponents with separate 45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS. 46 | ------------------------------------------------------------------------------- -------------------------------------------------------------------------------- /nvas3d/utils/dynamic_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import numpy as np 7 | import typing as T 8 | 9 | from scipy.signal import oaconvolve 10 | 11 | 12 | def setup_dynamic_interp( 13 | receiver_position: np.ndarray, 14 | total_samples: int, 15 | ) -> T.Tuple[np.ndarray, np.ndarray]: 16 | """ 17 | Setup moving path with a constant speed for a receiver, given its positions in 3D space. 18 | 19 | Args: 20 | - receiver_position: Receiver positions in 3D space of shape (num_positions, 3). 21 | - total_samples: Total number of samples in the audio. 22 | 23 | Returns: 24 | - interp_index: Indices representing the start positions for interpolation. 25 | - interp_weight: Weight values for linear interpolation. 
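    Example (illustrative only, added in editing; assumes a two-segment straight path
    and a 48000-sample signal):

        >>> positions = np.array([[0., 0., 0.], [1., 0., 0.], [1., 1., 0.]])
        >>> idx, w = setup_dynamic_interp(positions, total_samples=48000)
        >>> idx.shape == (48000,) and w.shape == (48000,)
        True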
26 | """ 27 | 28 | # Calculate the number of samples per interval 29 | distance = np.linalg.norm(np.diff(receiver_position, axis=0), axis=1) 30 | speed_per_sample = distance.sum() / total_samples 31 | samples_per_interval = np.round(distance / speed_per_sample).astype(int) 32 | 33 | # Distribute rounding errors 34 | error = total_samples - samples_per_interval.sum() 35 | for i in np.random.choice(len(samples_per_interval), abs(error)): 36 | samples_per_interval[i] += np.sign(error) 37 | 38 | # Calculate indices and weights for linear interpolation 39 | interp_index = np.repeat(np.arange(len(distance)), samples_per_interval) 40 | interp_weight = np.concatenate([np.linspace(0, 1, num, endpoint=False) for num in samples_per_interval]) 41 | 42 | return interp_index, interp_weight.astype(np.float32) 43 | 44 | 45 | def convolve_moving_receiver( 46 | source_audio: np.ndarray, 47 | rirs: np.ndarray, 48 | interp_index: T.List[int], 49 | interp_weight: T.List[float] 50 | ) -> np.ndarray: 51 | """ 52 | Apply convolution between an audio signal and moving impulse responses (IRs). 53 | 54 | Args: 55 | - source_audio: Source audio of shape (audio_len,) 56 | - rirs: RIRs of shape (num_positions, num_channels, ir_length) 57 | - interp_index: Indices representing the start positions for interpolation of shape (audio_len,). 58 | - interp_weight: Weight values for linear interpolation of shape (audio_len,). 59 | 60 | Returns: 61 | - Convolved audio signal of shape (num_channels, audio_len) 62 | """ 63 | 64 | num_channels = rirs.shape[1] 65 | audio_len = source_audio.shape[0] 66 | 67 | # Perform convolution for each position and channel 68 | convolved_audios = oaconvolve(source_audio[None, None, :], rirs, axes=-1)[..., :audio_len] 69 | 70 | # NumPy fancy indexing and broadcasting for interpolation 71 | start_audio = convolved_audios[interp_index, np.arange(num_channels)[:, None], np.arange(audio_len)] 72 | end_audio = convolved_audios[interp_index + 1, np.arange(num_channels)[:, None], np.arange(audio_len)] 73 | interp_weight = interp_weight[None, :] 74 | 75 | # Apply linear interpolation 76 | moving_audio = (1 - interp_weight) * start_audio + interp_weight * end_audio 77 | 78 | return moving_audio 79 | -------------------------------------------------------------------------------- /nvas3d/main.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import yaml 8 | import shutil 9 | import logging 10 | import argparse 11 | 12 | import torch 13 | import torch.multiprocessing as mp 14 | import torch.distributed as dist 15 | from torch.nn.parallel import DistributedDataParallel 16 | 17 | from nvas3d.train.trainer import Trainer 18 | from nvas3d.data_loader.data_loader import SSAVDataLoader 19 | from nvas3d.model.model import NVASNet 20 | 21 | 22 | def setup_distributed_training(local_rank): 23 | world_size = torch.cuda.device_count() 24 | dist.init_process_group(backend='nccl', 25 | init_method='tcp://localhost:12355', 26 | rank=local_rank, 27 | world_size=world_size) 28 | device = torch.device(f'cuda:{local_rank}') 29 | return device 30 | 31 | 32 | def main(local_rank, args): 33 | is_ddp = args.gpu is None 34 | device = setup_distributed_training(local_rank) if is_ddp else torch.device(f'cuda:{args.gpu}') 35 | 36 | if not is_ddp: 37 | torch.cuda.set_device(args.gpu) 38 | 39 | # Load and parse config file 40 | with open(args.config, 'r') as f: 41 | config = yaml.safe_load(f) 42 | 43 | # Create a directory and copy config 44 | save_dir = os.path.join(config['save_dir'], f'{args.exp}') 45 | os.makedirs(save_dir, exist_ok=True) 46 | shutil.copy(args.config, f'{save_dir}/config.yaml') 47 | 48 | # Initialize DataLoader 49 | data_loader = SSAVDataLoader(config['use_visual'], config['use_deconv'], is_ddp, **config['data_loader']) 50 | 51 | # Initialize and deploy model to device 52 | model = NVASNet(config['data_loader']['num_receivers'], config['use_visual']) 53 | model = model.to(device) 54 | if is_ddp: 55 | model = DistributedDataParallel(model, device_ids=[local_rank]) 56 | 57 | # Train the model 58 | trainer = Trainer(model, data_loader, device, save_dir, config['use_deconv'], config['training']) 59 | trainer.train() 60 | 61 | if is_ddp: 62 | dist.destroy_process_group() 63 | 64 | 65 | if __name__ == '__main__': 66 | torch.manual_seed(42) 67 | 68 | logging.basicConfig(level=logging.INFO, format='%(asctime)s, %(levelname)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S") 69 | 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument('--config', 72 | type=str, 73 | default='./nvas3d/config/default_config.yaml', 74 | help='Path to the configuration file.') 75 | 76 | parser.add_argument('--exp', 77 | type=str, 78 | default='default_exp', 79 | help='Experiment name') 80 | 81 | parser.add_argument('--gpu', 82 | type=int, 83 | default=None, 84 | help='GPU ID to use') 85 | 86 | args = parser.parse_args() 87 | 88 | if args.gpu is not None: 89 | main(0, args) # Single GPU mode 90 | else: 91 | mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count(), join=True) 92 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Novel-View Acoustic Synthesis from 3D Reconstructed Rooms 2 | 3 | [[Paper](http://arxiv.org/abs/2310.15130)] 4 | [[Docs](/soundspaces_nvas3d/README.md)] 5 | [[Demo docs](/demo/README.md)] 6 | [[Video1](https://docs-assets.developer.apple.com/ml-research/models/nvas/nvas3d_turn.mp4)] 7 | [[Video2](https://docs-assets.developer.apple.com/ml-research/models/nvas/teaser.mp4)] 8 | 9 | 10 | > [Click on the thumbnail image](https://docs-assets.developer.apple.com/ml-research/models/nvas/teaser.mp4) below to watch a video showcasing our Novel-View Acoustic Synthesis. 11 | > 12 | > 🎧 _For the optimal experience, using a headset is recommended._ 13 | 14 |

15 | 16 | <a href="https://docs-assets.developer.apple.com/ml-research/models/nvas/teaser.mp4"> 17 | <img src="assets/images/thumbnail.png" alt="Demo Video" width="100%"> 18 | </a> 

19 | 20 | Welcome to the official code repository for "Novel-View Acoustic Synthesis from 3D Reconstructed Rooms". 21 | This project estimates the sound anywhere in a scene containing multiple unknown sound sources, hence resulting in novel-view acoustic synthesis, given audio recordings from multiple microphones and the 3D geometry and material of a scene. 22 | 23 | 24 | ["Novel-View Acoustic Synthesis from 3D Reconstructed Rooms"](http://arxiv.org/abs/2310.15130)\ 25 | [Byeongjoo Ahn](https://byeongjooahn.github.io), 26 | [Karren Yang](https://karreny.github.io), 27 | [Brian Hamilton](https://www.brianhamilton.co), 28 | [Jonathan Sheaffer](https://www.linkedin.com/in/jsheaffer/), 29 | [Anurag Ranjan](https://anuragranj.github.io), 30 | [Miguel Sarabia](https://scholar.google.co.uk/citations?user=U2mA-EAAAAAJ&hl=en), 31 | [Oncel Tuzel](https://www.onceltuzel.net), 32 | [Jen-Hao Rick Chang](https://rick-chang.github.io) 33 | 34 | ## Directory Structure 35 | ```yaml 36 | . 37 | ├── demo/ # Quickstart and demo 38 | │ ├── ... 39 | ├── nvas3d/ # Implementation of our model 40 | │ ├── ... 41 | └── soundspaces_nvas3d/ # SoundSpaces integration for NVAS3D 42 | ├── ... 43 | ``` 44 | 45 | ## Installation: SoundSpaces 46 | Follow our [Step-by-Step Installation Guide](soundspaces_nvas3d/README.md) for rendering room impulse responses (RIRs) and images in Matterport3D rooms using SoundSpaces. 47 | 48 | ## Quickstart: Demo 49 | Refer to the [Demo Guide](demo/README.md) for instructions on data generation, dry sound estimation using our model, and novel-view acoustic rendering. 50 | 51 | ### Download the Pretrained Model 52 | Download [our pretrained model](https://docs-assets.developer.apple.com/ml-research/models/nvas/checkpoint_200.pt) and place it in the `nvas3d/assets/saved_models/default/checkpoints/` directory. 53 | 54 | ### Launch the Demo 55 | To get started with the full pipeline quickly: 56 | ```bash 57 | bash demo/run_demo.sh 58 | ``` 59 | 60 | ## Training 61 | After [Training Data Generation](nvas3d/utils/training_data_generation/README.md), start the training process with: 62 | ```bash 63 | python main.py --config ./nvas3d/config/default_config.yaml --exp default_exp 64 | ``` 65 | 66 | ## Acknowledgements 67 | We thank Dirk Schroeder and David Romblom for insightful discussions and feedback, Changan Chen for the assistance with SoundSpaces. 68 | 69 | 78 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/rir_generation/generate_grid.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import argparse 8 | 9 | import numpy as np 10 | import torch 11 | import typing as T 12 | 13 | from soundspaces_nvas3d.soundspaces_nvas3d import Scene 14 | 15 | # Set rooms for which grids need to be generated 16 | ROOM_LIST = [ 17 | '17DRP5sb8fy' 18 | ] 19 | 20 | 21 | def save_xy_grid_points(room: str, 22 | grid_distance: float, 23 | dirname: str 24 | ) -> T.Dict[str, T.Any]: 25 | """ 26 | Save xy grid points given a mp3d room 27 | """ 28 | 29 | filename_npy = f'{dirname}/grid_{room}.npy' 30 | filename_png = f'{dirname}/grid_{room}.png' 31 | 32 | scene = Scene( 33 | room, 34 | [None], # placeholder for source class 35 | include_visual_sensor=False, 36 | add_source_mesh=False, 37 | device=torch.device('cpu') 38 | ) 39 | grid_points = scene.generate_xy_grid_points(grid_distance, filename_png=filename_png) 40 | room_size = scene.sim.pathfinder.navigable_area 41 | 42 | grid_info = dict( 43 | grid_points=grid_points, 44 | room_size=room_size, 45 | grid_distance=grid_distance, 46 | ) 47 | np.save(filename_npy, grid_info) 48 | 49 | return grid_info 50 | 51 | 52 | def calculate_statistics(data_dict): 53 | values = list(data_dict.values()) 54 | average = np.mean(values) 55 | minimum = np.min(values) 56 | maximum = np.max(values) 57 | return average, minimum, maximum 58 | 59 | 60 | def main(args): 61 | # Set grid distance 62 | grid_distance = args.grid_distance 63 | grid_distance_str = str(grid_distance).replace(".", "_") 64 | 65 | # Generate grid points 66 | dirname = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}' 67 | os.makedirs(dirname, exist_ok=True) 68 | 69 | # Define lists to store data 70 | num_points_dict = {} 71 | room_size_dict = {} 72 | 73 | for room in ROOM_LIST: 74 | # Note: mp3d room should be under data/scene_datasets/mp3d/{room} 75 | grid_info = save_xy_grid_points(room, grid_distance, dirname) 76 | 77 | # Append data to lists 78 | num_points_dict[room] = grid_info['grid_points'].shape[0] 79 | room_size_dict[room] = grid_info['room_size'] 80 | 81 | # Calculate statistics 82 | num_points_average, num_points_min, num_points_max = calculate_statistics(num_points_dict) 83 | room_size_average, room_size_min, room_size_max = calculate_statistics(room_size_dict) 84 | 85 | # Save statistics to txt file 86 | filename_satistics = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}/statistics.txt' 87 | with open(filename_satistics, 'w') as file: 88 | file.write('[Number of points statistics]:\n') 89 | file.write(f'Average: {num_points_average:.2f}\n') 90 | file.write(f'Minimum: {num_points_min}\n') 91 | file.write(f'Maximum: {num_points_max}\n') 92 | file.write('\n') 93 | file.write('[Room size statistics]:\n') 94 | file.write(f'Average: {room_size_average:.2f}\n') 95 | file.write(f'Minimum: {room_size_min:.2f}\n') 96 | file.write(f'Maximum: {room_size_max:.2f}\n') 97 | 98 | file.write('\n') 99 | file.write('[Room-wise statistics]:\n') 100 | for room in ROOM_LIST: 101 | file.write(f'{room} - Num Points: {num_points_dict[room]}, \t Room Size: {room_size_dict[room]:.2f}\n') 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | 107 | parser.add_argument('--grid_distance', 108 | default=1.0, 109 | type=float, 110 | help='distance between grid points') 111 | 112 | args = parser.parse_args() 113 | 114 | main(args) 115 | -------------------------------------------------------------------------------- /nvas3d/utils/utils.py: -------------------------------------------------------------------------------- 1 | # 2 
| # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # Portions of this code are derived from VIDA (CC-BY-NC). 6 | # Original work available at: https://github.com/facebookresearch/learning-audio-visual-dereverberation/tree/main 7 | 8 | import typing as T 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | 14 | source_class_map = { 15 | 'female': 0, # speech: 22961 16 | 'male': 1, 17 | 'Bass': 2, # 51678 18 | 'Brass': 3, # 7216 19 | 'Chromatic Percussion': 4, # 4154 20 | 'Drums': 5, # 47291 21 | 'Guitar': 6, # 75848 22 | 'Organ': 7, # 10438 23 | 'Piano': 8, # 59328 24 | 'Pipe': 9, # 8200 25 | 'Reed': 10, # 10039 26 | 'Strings': 11, # 3169 27 | 'Strings (continued)': 12, # 48429 28 | 'Synth Lead': 13, # 7053 29 | 'Synth Pad': 14 # 11210 30 | } 31 | 32 | MP3D_SCENE_SPLITS = { 33 | 'demo': ['17DRP5sb8fy'], 34 | } 35 | 36 | 37 | def get_key_from_value(my_dict, target_value): 38 | for key, value in my_dict.items(): 39 | if value == target_value: 40 | return key 41 | return None 42 | 43 | 44 | def parse_librispeech_metadata(filename: str) -> T.Dict: 45 | """ 46 | Reads LibriSpeech metadata from a csv file and returns a dictionary. 47 | Each entry in the dictionary maps a reader_id (as integer) to its corresponding gender. 48 | """ 49 | 50 | import csv 51 | 52 | # Dictionary to store reader_id and corresponding gender 53 | librispeech_metadata = {} 54 | 55 | with open(filename, 'r') as file: 56 | reader = csv.reader(file, delimiter='|') 57 | for row in reader: 58 | # Skip comment lines and header 59 | if row[0].startswith(';') or row[0].strip() == 'ID': 60 | continue 61 | reader_id = int(row[0]) # Convert string to integer 62 | sex = row[1].strip() # Remove extra spaces 63 | librispeech_metadata[reader_id] = sex 64 | 65 | return librispeech_metadata 66 | 67 | 68 | def parse_ir(filename_ir: str): 69 | """ 70 | Extracts the room, source index and receiver index information from an IR filename. 71 | The function assumes that these elements are present at the end of the filename, separated by underscores. 72 | """ 73 | 74 | parts = filename_ir.split('_') 75 | room = parts[-3] 76 | source_idx = int(parts[-2]) 77 | receiver_idx = int(parts[-1].split('.')[0]) 78 | 79 | return room, source_idx, receiver_idx 80 | 81 | 82 | def count_parameters(model): 83 | """ 84 | Returns the number of trainable parameters in a given PyTorch model. 85 | """ 86 | 87 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 88 | 89 | 90 | #################### 91 | # From Changan's code 92 | #################### 93 | 94 | def complex_norm( 95 | complex_tensor: torch.Tensor, 96 | power: float = 1.0 97 | ) -> torch.Tensor: 98 | """ 99 | Compute the norm of complex tensor input. 
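    Example (illustrative only, added in editing; assumes the last dimension holds
    the real and imaginary parts of each complex value):

        >>> complex_norm(torch.tensor([[3.0, 4.0]]))
        tensor([5.])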
100 | """ 101 | 102 | # Replace by torch.norm once issue is fixed 103 | # https://github.com/pytorch/pytorch/issues/34279 104 | return complex_tensor.pow(2.).sum(-1).pow(0.5 * power) 105 | 106 | 107 | def normalize(audio, norm='peak'): 108 | if norm == 'peak': 109 | peak = abs(audio).max() 110 | if peak != 0: 111 | return audio / peak 112 | else: 113 | return audio 114 | elif norm == 'rms': 115 | if torch.is_tensor(audio): 116 | audio = audio.numpy() 117 | audio_without_padding = np.trim_zeros(audio, trim='b') 118 | rms = np.sqrt(np.mean(np.square(audio_without_padding))) * 100 119 | if rms != 0: 120 | return audio / rms 121 | else: 122 | return audio 123 | else: 124 | raise NotImplementedError 125 | 126 | 127 | def overlap_chunk(input, dimension, size, step, left_padding): 128 | """ 129 | Input shape is [Frequency bins, Frame numbers] 130 | """ 131 | input = F.pad(input, (left_padding, size), 'constant', 0) 132 | return input.unfold(dimension, size, step) 133 | -------------------------------------------------------------------------------- /nvas3d/utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import torchaudio 11 | from torchmetrics.audio import SignalDistortionRatio 12 | 13 | from nvas3d.utils.audio_utils import psnr, clip_two 14 | 15 | 16 | def save_specgram(filename, stft): 17 | mag = torch.sqrt(stft[..., 0]**2 + stft[..., 1]**2) 18 | fig, ax = plt.subplots(1, 1, figsize=(14, 4)) 19 | cax1 = ax.imshow(torch.log10(mag), aspect='auto', origin='lower') 20 | fig.colorbar(cax1, ax=ax, orientation='horizontal') 21 | ax.set_title('Magnitude spectrogram') 22 | 23 | plt.tight_layout() 24 | plt.savefig(filename) 25 | plt.close(fig) 26 | 27 | 28 | def plot_specgram(filename, waveform, sample_rate, plot_phase=False): 29 | 30 | mag = compute_spectrogram(waveform, use_mag=True).squeeze() 31 | phase = compute_spectrogram(waveform, use_phase=True).squeeze() 32 | 33 | phase = phase * (180 / np.pi) 34 | 35 | if plot_phase: 36 | fig, ax = plt.subplots(2, 1, figsize=(14, 8)) 37 | cax1 = ax[0].imshow(torch.log10(mag), aspect='auto', origin='lower') 38 | fig.colorbar(cax1, ax=ax[0], orientation='horizontal') 39 | ax[0].set_title('Magnitude spectrogram') 40 | 41 | cax2 = ax[1].imshow(phase, aspect='auto', origin='lower', cmap='twilight') 42 | fig.colorbar(cax2, ax=ax[1], orientation='horizontal') 43 | ax[1].set_title('Phase spectrogram') 44 | else: 45 | fig, ax = plt.subplots(1, 1, figsize=(14, 4)) 46 | cax1 = ax.imshow(torch.log10(mag), aspect='auto', origin='lower') 47 | fig.colorbar(cax1, ax=ax, orientation='horizontal') 48 | ax.set_title('Magnitude spectrogram') 49 | 50 | plt.tight_layout() 51 | plt.savefig(filename) 52 | plt.close(fig) 53 | 54 | 55 | def plot_waveform(filename, waveform, sample_rate): 56 | plt.figure(figsize=(14, 4)) 57 | plt.plot(np.linspace(0, len(waveform[0]) / sample_rate, len(waveform[0])), waveform[0]) 58 | plt.title('Waveform') 59 | plt.xlabel('Time [s]') 60 | plt.savefig(filename) 61 | plt.close() 62 | 63 | 64 | def plot_debug(filename, waveform, sample_rate, save_png=False, save_normalized=False, reference=None): 65 | waveform = waveform.reshape(1, -1) 66 | 67 | if reference is not None: 68 | reference = reference.reshape(1, -1) 69 | waveform, reference = clip_two(waveform, reference) 70 | # assert waveform.shape[-1] 
== reference.shape[-1] 71 | psnr_score = psnr(reference.detach().cpu(), waveform.detach().cpu()) 72 | compute_sdr = SignalDistortionRatio() 73 | sdr = compute_sdr(reference.detach().cpu(), waveform.detach().cpu()) 74 | 75 | if save_normalized: 76 | waveform /= waveform.abs().max() 77 | 78 | if save_png: 79 | plot_specgram(f'{filename}_specgram.png', waveform, sample_rate) 80 | plot_waveform(f'{filename}_waveform.png', waveform, sample_rate) 81 | 82 | if reference is not None: 83 | filename = f'{filename}({psnr_score:.1f})({sdr:.1f})' 84 | torchaudio.save(f'{filename}.wav', waveform, sample_rate) 85 | 86 | 87 | def compute_spectrogram(audio_data, n_fft=2048, hop_length=480, win_length=1200, use_mag=False, use_phase=False, use_complex=False): 88 | 89 | audio_data = to_tensor(audio_data) 90 | stft = torch.stft(audio_data, n_fft=n_fft, hop_length=hop_length, win_length=win_length, 91 | window=torch.hamming_window(win_length, device=audio_data.device), pad_mode='constant', 92 | return_complex=not use_complex) 93 | 94 | if use_mag: 95 | spectrogram = stft.abs().unsqueeze(-1) 96 | elif use_phase: 97 | spectrogram = stft.angle().unsqueeze(-1) 98 | elif use_complex: 99 | # one channel for real and one channel for imaginary 100 | spectrogram = stft 101 | else: 102 | raise ValueError 103 | 104 | return spectrogram 105 | 106 | 107 | def to_tensor(v): 108 | if torch.is_tensor(v): 109 | return v 110 | elif isinstance(v, np.ndarray): 111 | return torch.from_numpy(v) 112 | else: 113 | return torch.tensor(v, dtype=torch.float) 114 | -------------------------------------------------------------------------------- /demo/README.md: -------------------------------------------------------------------------------- 1 | # Novel-View Acoustic Synthesis from 3D Reconstructed Rooms: Demo 2 | 3 | This guide is for the demonstration for NVAS from 3D reconstructed rooms. It provides instructions for data generation, dry sound estimation using our model, and novel-view acoustic rendering. 4 | 5 | ## Quick Start 6 | 7 | ### Download the Pretrained Model 8 | Download [our pretrained model](https://docs-assets.developer.apple.com/ml-research/models/nvas/checkpoint_200.pt) and place it in the `nvas3d/assets/saved_models/default/checkpoints/` directory. 9 | To get started with the full pipeline quickly: 10 | ```bash 11 | bash demo/run_demo.sh 12 | ``` 13 | 14 | ### Launch the Demo 15 | To get started with the full pipeline quickly: 16 | ```bash 17 | bash demo/run_demo.sh 18 | ``` 19 | 20 | ## Detailed Workflow 21 | For a more detailed approach, you can explore each segment of the workflow individually: 22 | 23 | ### 1. Data Generation 24 | Generate and save the demo data specific to a room. Please ensure you've installed SoundSpaces and downloaded the necessary data (sample Matterport3D room of 17DRP5sb8fy and material config) by following the instructions in [soundspaces_nvas3d/README.md](../soundspaces_nvas3d/README.md). 25 | 26 | ```bash 27 | python demo/generate_demo_data.py --room 17DRP5sb8fy 28 | ``` 29 | 30 | The generated data will be structured as: 31 | ```yaml 32 | data/nvas3d_demo/demo/{room}/0 33 | ├── receiver/ # Receiver audio. 34 | ├── wiener/ # Deconvolved audio to accelerate tests. 35 | ├── source/ # Ground truth dry audio. 36 | ├── reverb1/ # Ground truth reverberant audio for source 1. 37 | ├── reverb2/ # Ground truth reverberant audio for source 2. 38 | ├── ir_receiver/ # Ground truth RIRs from source to receiver. 39 | ├── config.png # Visualization of room configuration. 
40 | └── metadata.pt # Metadata (source indices, classes, grid points, and room info). 41 | ``` 42 | 43 | Additionally, visualizations of room indices will be located at `data/nvas3d_demo/{room}/index_{grid_distance}.png`. If not, please ensure that [headless rendering in Habitat-Sim](https://github.com/facebookresearch/habitat-sim/issues?q=headless) is configured correctly. 44 | 45 | > [!Note] 46 | > If you want to modify the scene configuration (e.g.,, source locations, receiver locations), edit `demo/nvas/config_demo/scene_config.json`. 47 | 48 | ### 2. Dry Sound Estimation Using Our Model 49 | Run the NVAS3D model on the your dataset: 50 | ```bash 51 | python demo/test_demo.py 52 | ``` 53 | 54 | The results will be structured as: 55 | ```yaml 56 | {results_demo} = results/nvas3d_demo/default/demo/{room}/0 57 | │ 58 | ├── results_drysound/ 59 | │ ├── dry1_estimated.wav # Estimated dry sound for source 1 location. 60 | │ ├── dry2_estimated.wav # Estimated dry sound for source 2 location. 61 | │ ├── dry1_gt.wav # Ground-truth dry sound for source 1. 62 | │ ├── dry2_gt.wav # Ground-truth dry sound for source 2. 63 | │ └── detected/ 64 | │ └── dry_{query_idx}.wav # Estimated dry sound for query idx if positive. 65 | │ 66 | ├── baseline_dsp/ 67 | │ ├── deconv_and_sum1.wav # Baseline result of dry sound for source 1. 68 | │ └── deconv_and_sum2.wav # Baseline result of dry sound for source 2. 69 | │ 70 | ├── results_detection/ 71 | │ ├── detection_heatmap.png # Heatmap visualizing source detection results. 72 | │ └── metadata.pt # Detection results and metadata. 73 | │ 74 | └── metrics.json # Quantitative metrics. 75 | 76 | ``` 77 | 78 | ### 3. Novel-view Acoustic Rendering 79 | Render the novel-view video integrated with the sound: 80 | 81 | ```bash 82 | python demo/generate_demo_video.py 83 | ``` 84 | 85 | The video results will be structured as: 86 | ```yaml 87 | results/nvas3d_demo/default/demo/{room}/0 88 | └── video/ 89 | ├── moving_audio.wav # Audio only: Interpolated sound for the moving receiver. 90 | ├── moving_audio_1.wav # Audio only: Interpolated sound from separated Source 1 for the moving receiver. 91 | ├── moving_audio_2.wav # Audio only: Interpolated sound from separated Source 2 for the moving receiver. 92 | ├── moving_video.mp4 # Video only: Interpolated video for the moving receiver. 93 | ├── nvas.mp4 # Video with Audio: NVAS video results with combined audio from all sources. 94 | ├── nvas_source1.mp4 # Video with Audio: NVAS video results for separated Source 1 audio. 95 | ├── nvas_source2.mp4 # Video with Audio: NVAS video results for separated Source 2 audio. 96 | └── rgb_receiver.png # Image: A static view rendered from the receiver's perspective for reference. 97 | ``` 98 | 99 | > [!Note] 100 | > If you want to modify the novel receiver's path, edit `demo/nvas/config_demo/path1_novel_receiver.json`. 101 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this software may utilize the following copyrighted material, the use of which is hereby acknowledged. 
3 | 4 | _____________________ 5 | 6 | Facebook, Inc (PyTorch) 7 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 8 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 9 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 10 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 11 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 12 | Copyright (c) 2011-2013 NYU (Clement Farabet) 13 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) 14 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 15 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) 16 | 17 | All rights reserved. 18 | 19 | Redistribution and use in source and binary forms, with or without 20 | modification, are permitted provided that the following conditions are met: 21 | 22 | 1. Redistributions of source code must retain the above copyright 23 | notice, this list of conditions and the following disclaimer. 24 | 25 | 2. Redistributions in binary form must reproduce the above copyright 26 | notice, this list of conditions and the following disclaimer in the 27 | documentation and/or other materials provided with the distribution. 28 | 29 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America 30 | and IDIAP Research Institute nor the names of its contributors may be 31 | used to endorse or promote products derived from this software without 32 | specific prior written permission. 33 | 34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 37 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 38 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 39 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 40 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 41 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 42 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 43 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 44 | POSSIBILITY OF SUCH DAMAGE. 45 | 46 | Meta Platforms, Inc (Habitat-Sim) 47 | MIT License 48 | 49 | Copyright (c) Meta Platforms, Inc. and its affiliates. 50 | 51 | Permission is hereby granted, free of charge, to any person obtaining a copy 52 | of this software and associated documentation files (the "Software"), to deal 53 | in the Software without restriction, including without limitation the rights 54 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 55 | copies of the Software, and to permit persons to whom the Software is 56 | furnished to do so, subject to the following conditions: 57 | 58 | The above copyright notice and this permission notice shall be included in all 59 | copies or substantial portions of the Software. 60 | 61 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 62 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 63 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 64 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 65 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 66 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 67 | SOFTWARE. 68 | 69 | Meta Platforms, Inc (Habitat-Lab) 70 | MIT License 71 | 72 | Copyright (c) Meta Platforms, Inc. and its affiliates. 73 | 74 | Permission is hereby granted, free of charge, to any person obtaining a copy 75 | of this software and associated documentation files (the "Software"), to deal 76 | in the Software without restriction, including without limitation the rights 77 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 78 | copies of the Software, and to permit persons to whom the Software is 79 | furnished to do so, subject to the following conditions: 80 | 81 | The above copyright notice and this permission notice shall be included in all 82 | copies or substantial portions of the Software. 83 | 84 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 85 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 86 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 87 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 88 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 89 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 90 | SOFTWARE. 91 | 92 | Meta Platforms, Inc (SoundSpaces) 93 | Attribution 4.0 International 94 | 95 | Meta Platforms, Inc (VisualVoice) 96 | Attribution-NonCommercial 4.0 International 97 | -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/script_slakh.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | # Script to clip and split Slakh2100 dataset: https://zenodo.org/record/4599666 6 | # 7 | 8 | import os 9 | import yaml 10 | import shutil 11 | import numpy as np 12 | 13 | import torchaudio 14 | from torchaudio.transforms import Resample 15 | 16 | 17 | # Set the silence threshold 18 | THRESHOLD = 1e-2 19 | 20 | 21 | def read_instruments(source_dir): 22 | """ 23 | Read instrument metadata from source directory 24 | """ 25 | 26 | inst_dict = {} # dictionary to store instrument data 27 | # Loop through every file in the source directory 28 | for root, dirs, files in os.walk(source_dir): 29 | for file in files: 30 | if file.endswith("metadata.yaml"): 31 | with open(os.path.join(root, file), 'r') as yaml_file: 32 | metadata = yaml.safe_load(yaml_file) 33 | 34 | # Add instrument name and class to dictionary 35 | for stem, stem_data in metadata['stems'].items(): 36 | if stem_data['midi_program_name'] not in inst_dict.keys(): 37 | inst_dict[stem_data['midi_program_name']] = stem_data['inst_class'] 38 | return inst_dict 39 | 40 | 41 | def copy_files_based_on_metadata(source_dir, target_dir, query): 42 | """ 43 | Copy files based on metadata 44 | """ 45 | 46 | # Walk through each file in the source directory 47 | for root, dirs, files in os.walk(source_dir): 48 | for file in files: 49 | if file.endswith("metadata.yaml"): 50 | with open(os.path.join(root, file), 'r') as yaml_file: 51 | try: 52 | metadata = yaml.safe_load(yaml_file) 53 | 54 | # If instrument matches query, copy the associated flac file to target directory 55 | for stem, stem_data in metadata['stems'].items(): 56 | if stem_data['midi_program_name'] == query: 57 | source_flac_file = os.path.join(root.replace('metadata.yaml', ''), "stems", f"{stem}.flac") 58 | target_flac_file = source_flac_file.replace(source_dir, target_dir) 59 | 60 | os.makedirs(os.path.dirname(target_flac_file), exist_ok=True) 61 | 62 | # Copy the .flac file 63 | if os.path.exists(source_flac_file): 64 | shutil.copyfile(source_flac_file, target_flac_file) 65 | 66 | except yaml.YAMLError as exc: 67 | print(exc) 68 | 69 | 70 | def save_clips(waveform, output_path, sample_rate, min_length, max_length): 71 | """ 72 | Save clips from a given waveform 73 | """ 74 | 75 | # Calculate clip lengths in samples 76 | min_length_samples = int(min_length * sample_rate) 77 | max_length_samples = int(max_length * sample_rate) 78 | 79 | # Get total number of samples in the waveform 80 | total_samples = waveform.shape[1] 81 | 82 | start = 0 83 | end = np.random.randint(min_length_samples, max_length_samples + 1) 84 | clip_number = 1 85 | 86 | # Keep creating clips until we've covered the whole waveform 87 | while end <= total_samples: 88 | # Slice the waveform to get the clip 89 | clip = waveform[:, start:end] 90 | 91 | # Check if the clip contains all zeros (is silent) 92 | if abs(clip).mean() > THRESHOLD: 93 | # Save the clip 94 | output_clip_path = f"{output_path.rsplit('.', 1)[0]}_clip_{clip_number}.flac" 95 | torchaudio.save(output_clip_path, clip, sample_rate) 96 | 97 | # Increment the clip number 98 | clip_number += 1 99 | 100 | # Update the start and end for the next clip 101 | start = end 102 | end = start + np.random.randint(min_length_samples, max_length_samples + 1) 103 | 104 | 105 | def process_directory(input_dir, output_dir, sample_rate=48000, min_length=6, max_length=10): 106 | """ 107 | Process a directory of audio files 108 | """ 109 | 110 | # Walk through each file in the input directory 111 | for dirpath, dirnames, filenames in os.walk(input_dir): 112 | 
for filename in filenames: 113 | if filename.endswith(".flac"): 114 | input_file_path = os.path.join(dirpath, filename) 115 | waveform_, sr = torchaudio.load(input_file_path) 116 | 117 | # Resample the audio 118 | upsample_transform = Resample(sr, sample_rate) 119 | waveform = upsample_transform(waveform_) 120 | 121 | relative_path = os.path.relpath(dirpath, input_dir) 122 | clip_output_dir = os.path.join(output_dir, relative_path) 123 | 124 | os.makedirs(clip_output_dir, exist_ok=True) 125 | 126 | output_file_path = os.path.join(clip_output_dir, filename) 127 | 128 | # set audio length 129 | split = dirpath.split('/')[-3] 130 | if split == 'test': 131 | min_length = 20 132 | max_length = 21 133 | else: 134 | min_length = 6 135 | max_length = 7 136 | # Save clips from the resampled audio 137 | save_clips(waveform, output_file_path, sample_rate, min_length, max_length) 138 | 139 | 140 | # Read instruments 141 | source_dir = 'data/source/slakh2100_flac_redux' 142 | inst_dict = read_instruments(source_dir) 143 | print(inst_dict) 144 | 145 | # Copy query instruments 146 | os.makedirs('data/MIDI', exist_ok=True) 147 | os.makedirs('data/MIDI/full', exist_ok=True) 148 | 149 | # Loop through each instrument in the dictionary 150 | for query, inst_class in inst_dict.items(): 151 | os.makedirs(f'data/MIDI/full/{inst_class}', exist_ok=True) 152 | target_dir = f'data/MIDI/full/{inst_class}/{query}' 153 | print(target_dir) 154 | 155 | # Copy each instrument file to the target directory 156 | for subdir in ['train', 'test', 'validation']: 157 | copy_files_based_on_metadata(os.path.join(source_dir, subdir), os.path.join(target_dir, subdir), query) 158 | print('Copy done!') 159 | 160 | # Clip the copied audio files 161 | os.makedirs('data/MIDI/clip', exist_ok=True) 162 | for query, inst_class in inst_dict.items(): 163 | os.makedirs(f'data/MIDI/clip/{inst_class}', exist_ok=True) 164 | target_dir = f'data/MIDI/full/{inst_class}/{query}' 165 | clip_dir = f'data/MIDI/clip/{inst_class}/{query}' 166 | print(clip_dir) 167 | 168 | process_directory(target_dir, clip_dir) 169 | print('Clip done!') 170 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/image_rendering/generate_target_image.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import numpy as np 8 | import glob 9 | import torch 10 | import matplotlib.pyplot as plt 11 | from soundspaces_nvas3d.soundspaces_nvas3d import Receiver 12 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 13 | from soundspaces_nvas3d.utils.ss_utils import create_scene 14 | import argparse 15 | 16 | 17 | def contains_png_files(dirname): 18 | return len(glob.glob(os.path.join(dirname, "*.png"))) > 0 19 | 20 | 21 | def render_target_images(args): 22 | room = args.room 23 | grid_distance = args.grid_distance 24 | image_size = (192, 144) 25 | 26 | # Download soundspaces dataset if not exists 27 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points'] 28 | grid_distance_str = str(grid_distance).replace(".", "_") 29 | if args.dirname is None: 30 | dirname = f'data/examples/target_image_mp3d/grid_{grid_distance_str}/{room}' 31 | else: 32 | dirname = args.dirname 33 | os.makedirs(dirname, exist_ok=True) 34 | 35 | # # Return if there is a png file already 36 | # if contains_png_files(dirname): 37 | # return 38 | 39 | scene = create_scene(room, image_size=image_size) 40 | sample_rate = 48000 41 | position = [0.0, 0, 0] 42 | rotation = 0.0 43 | receiver = Receiver(position, rotation, sample_rate) 44 | 45 | # Render image from 1m distance 46 | dist = 1.0 47 | # south east north west 48 | position_offset_list = [[0.0, 0.0, dist], [dist, 0.0, 0.0], [0.0, 0.0, -dist], [-dist, 0.0, 0.0]] 49 | rotation_offset_list = [0.0, 90.0, 180.0, 270.0] 50 | 51 | source_class_list = ['female', 'male', 'guitar'] 52 | # source_class_list = ['guitar'] 53 | 54 | if args.source_idx_list is None: 55 | source_idx_list = range(grid_points.shape[0]) 56 | else: 57 | source_idx_list = list(map(int, args.source_idx_list.split())) 58 | 59 | for source_idx in source_idx_list: 60 | # print(f'{room}: {source_idx}/{len(source_idx_list)}') 61 | source_position = grid_points[source_idx] 62 | 63 | rgb = [] 64 | depth = [] 65 | 66 | for position_offset, rotation_offset in zip(position_offset_list, rotation_offset_list): 67 | receiver.position = source_position + torch.tensor(position_offset) 68 | receiver.rotation = rotation_offset 69 | scene.update_receiver(receiver) 70 | rgb_, depth_ = scene.render_image() 71 | rgb.append(rgb_) 72 | depth.append(depth_) 73 | 74 | rgb = np.concatenate(rgb, axis=1) 75 | depth = np.concatenate(depth, axis=1) 76 | 77 | filename_rgb_png = f'{dirname}/rgb_{room}_{source_idx}.png' 78 | filename_depth_png = f'{dirname}/depth_{room}_{source_idx}.png' 79 | 80 | plt.imsave(filename_rgb_png, rgb) 81 | plt.imsave(filename_depth_png, depth) 82 | 83 | filename_rgb_npy = f'{dirname}/rgb_{room}_{source_idx}.npy' 84 | filename_depth_npy = f'{dirname}/depth_{room}_{source_idx}.npy' 85 | 86 | np.save(filename_rgb_npy, rgb) 87 | np.save(filename_depth_npy, depth) 88 | 89 | # Optional: Render with mesh. 
Requires mesh data for habitat under data/objects/{source.mesh} 90 | # source_female = Source( 91 | # position=position, 92 | # rotation=random.uniform(0, 360), 93 | # dry_sound='', 94 | # mesh='female', 95 | # device=torch.device('cpu') 96 | # ) 97 | 98 | # source_male = Source( 99 | # position=position, 100 | # rotation=random.uniform(0, 360), 101 | # dry_sound='', 102 | # mesh='male', 103 | # device=torch.device('cpu') 104 | # ) 105 | 106 | # source_guitar = Source( 107 | # position=position, 108 | # rotation=random.uniform(0, 360), 109 | # dry_sound='', 110 | # mesh='guitar', 111 | # device=torch.device('cpu') 112 | # ) 113 | 114 | # source_list = [source_female, source_male, source_guitar] 115 | # # source_list = [source_guitar] 116 | 117 | # for source_class, source in zip(source_class_list, source_list): 118 | # scene.add_source_mesh = True 119 | # scene.source_list = [None] 120 | # source.position = source_position 121 | 122 | # scene.update_source(source, 0) 123 | # rgb = [] 124 | # depth = [] 125 | 126 | # for position_offset, rotation_offset in zip(position_offset_list, rotation_offset_list): 127 | # receiver.position = source_position + torch.tensor(position_offset) 128 | # receiver.rotation = rotation_offset 129 | # scene.update_receiver(receiver) 130 | # rgb_, depth_ = scene.render_image() 131 | # rgb.append(rgb_) 132 | # depth.append(depth_) 133 | 134 | # scene.sim.get_rigid_object_manager().remove_all_objects() 135 | 136 | # rgb = np.concatenate(rgb, axis=1) 137 | # depth = np.concatenate(depth, axis=1) 138 | 139 | # filename_rgb_png = f'{dirname}/rgb_{room}_{source_class}_{source_idx}.png' 140 | # filename_depth_png = f'{dirname}/depth_{room}_{source_class}_{source_idx}.png' 141 | 142 | # plt.imsave(filename_rgb_png, rgb) 143 | # plt.imsave(filename_depth_png, depth) 144 | 145 | # filename_rgb_npy = f'{dirname}/rgb_{room}_{source_class}_{source_idx}.npy' 146 | # filename_depth_npy = f'{dirname}/depth_{room}_{source_class}_{source_idx}.npy' 147 | 148 | # np.save(filename_rgb_npy, rgb) 149 | # np.save(filename_depth_npy, depth) 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser() 154 | 155 | parser.add_argument('--room', 156 | default='17DRP5sb8fy', 157 | type=str, 158 | help='mp3d room') 159 | 160 | parser.add_argument('--source_idx_list', 161 | default=None, 162 | type=str, 163 | help='source_idx_list') 164 | 165 | parser.add_argument('--dirname', 166 | default=None, 167 | type=str, 168 | help='dirname') 169 | 170 | parser.add_argument('--grid_distance', 171 | default=1.0, 172 | type=float, 173 | help='distance between grid points') 174 | 175 | args = parser.parse_args() 176 | 177 | render_target_images(args) 178 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/README.md: -------------------------------------------------------------------------------- 1 | # SoundSpaces for NVAS3D 2 | This guide provides a step-by-step process to set up and generate data using SoundSpaces for NVAS3D. You can also quickly start the installation process by running the `setup.sh` script included in this guide. 3 | 4 | ## Prerequisites for [SoundSpaces](https://github.com/facebookresearch/sound-spaces) 5 | - Ubuntu 20.04 or a similar Linux distribution 6 | - CUDA 7 | - [Conda](https://docs.conda.io/en/latest/miniconda.html) 8 | 9 | 10 | ## Installation 11 | Here we repeat the installation steps from [SoundSpaces Installation Guide](https://github.com/facebookresearch/sound-spaces/blob/main/INSTALLATION.md). 
12 | 13 | ### 0. Update `PYTHONPATH` 14 | Ensure your `PYTHONPATH` is updated: 15 | ```bash 16 | cd ml-nvas3d # the root of the repository 17 | export PYTHONPATH=$PYTHONPATH:$(pwd) 18 | ``` 19 | 20 | ### 1. Install SoundSpaces Dependencies 21 | Install required dependencies for SoundSpaces: 22 | ```bash 23 | apt-get update && apt-get upgrade -y && \ 24 | apt-get install -y --no-install-recommends libjpeg-dev libglm-dev libgl1-mesa-glx libegl1-mesa-dev mesa-utils xorg-dev freeglut3-dev 25 | ``` 26 | 27 | ### 2. Update Conda 28 | Ensure that your Conda installation is up to date: 29 | ```bash 30 | conda update -n base -c defaults conda 31 | ``` 32 | 33 | ### 3. Create and Activate Conda Environment 34 | Create a new Conda environment named nvas3d with Python 3.7 and cmake 3.14.0: 35 | ```bash 36 | conda create -n nvas3d python=3.7 cmake=3.14.0 -y && \ 37 | conda activate nvas3d 38 | ``` 39 | 40 | ### 4. Install PyTorch 41 | Install PyTorch, torchvision, torchaudio, and the CUDA toolkit. Replace the version accordingly with your CUDA version: 42 | ```bash 43 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia 44 | ``` 45 | 46 | ### 5. Install [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/tree/main) 47 | Install Habitat-Sim: 48 | ```bash 49 | # step outside the repo root and git clone Habitat-Sim 50 | cd .. && \ 51 | git clone https://github.com/facebookresearch/habitat-sim.git && \ 52 | cd habitat-sim && \ 53 | pip install -r requirements.txt && \ 54 | git checkout RLRAudioPropagationUpdate && \ 55 | python setup.py install --headless --audio --with-cuda 56 | ``` 57 | This guide is based on commit `30f4cc7`. 58 | 59 | ### 6. Install [Habitat-Lab](https://github.com/facebookresearch/habitat-lab) 60 | Install Habitat-Lab: 61 | ```bash 62 | # step outside the repo root and git clone Habitat-Lab 63 | cd .. &&\ 64 | git clone https://github.com/facebookresearch/habitat-lab.git && \ 65 | cd habitat-lab && \ 66 | git checkout v0.2.2 && \ 67 | pip install -e . && \ 68 | sed -i '36 s/^/#/' habitat/tasks/rearrange/rearrange_sim.py # remove FetchRobot 69 | ``` 70 | This guide is based on commit `49f7c15`. 71 | 72 | 73 | ### 7. Install [SoundSpaces](https://github.com/facebookresearch/sound-spaces/tree/main) 74 | Install SoundSpaces: 75 | ```bash 76 | # step outside the repo root and git clone SoundSpaces 77 | cd .. &&\ 78 | git clone https://github.com/facebookresearch/sound-spaces.git && \ 79 | cd sound-spaces && \ 80 | pip install -e . 81 | ``` 82 | This guide is based on commit `3768a50`. 83 | 84 | ### 8. Install Additional Packages 85 | Install additional Python packages needed: 86 | ```bash 87 | pip install scipy torchmetrics pyroomacoustics 88 | ``` 89 | 90 | ## Quick Installation 91 | To streamline the installation process, run the setup.sh script, which encapsulates all the steps listed above: 92 | ```bash 93 | bash setup.sh 94 | ``` 95 | 96 | ## Preparation of Matterport3D Room and Material 97 | 98 | ### 1. 
Download Example MP3D Room 99 | Follow these steps to download the MP3D data into the correct directory: 100 | 101 | 102 | (1) **Switch to habitat-sim directory**: 103 | ```bash 104 | cd /path/to/habitat-sim 105 | ``` 106 | 107 | (2) **Run the dataset download script**: 108 | ```bash 109 | python src_python/habitat_sim/utils/datasets_download.py --uids mp3d_example_scene 110 | ``` 111 | 112 | (3) **Copy the downloaded data to this repository**: 113 | ```bash 114 | mkdir -p /path/to/ml-nvas3d/data/scene_datasets/ 115 | cp -r data/scene_datasets/mp3d_example /path/to/ml-nvas3d/data/scene_datasets/mp3d 116 | ``` 117 | 118 | After executing the above steps, ensure the existence of the `data/scene_datasets/mp3d/17DRP5sb8fy` directory. For additional rooms, you might want to consider downloading from [Matterport3D](https://niessner.github.io/Matterport/). 119 | 120 | ### 2. Download Material Configuration File 121 | Download the material configuration file into the correct directory: 122 | 123 | ```bash 124 | cd /path/to/ml-nvas3d && \ 125 | mkdir data/material && \ 126 | cd data/material && wget https://raw.githubusercontent.com/facebookresearch/rlr-audio-propagation/main/RLRAudioPropagationPkg/data/mp3d_material_config.json 127 | ``` 128 | 129 | > [!Note] 130 | > Now, you are ready to run the [Demo](../demo/README.md) using our pretrained model. 131 | > 132 | > To explore more data, you can consider running [Training Data Generation](../nvas3d/utils/training_data_generation/README.md). 133 | 134 | 135 | ## Extended Usage: SoundSpaces for NVAS3D 136 | For users interested in generating more data, follow the steps outlined below. 137 | 138 | ### Generate grid points in room: 139 | To create grid points within the room, execute the following command: 140 | ```bash 141 | python soundspaces_nvas3d/rir_generation/generate_grid.py --grid_distance 1.0 142 | ``` 143 | * Input directory: `data/scene_datasets/mp3d/` 144 | * Output directory: `data/scene_datasets/metadata/mp3d/grid_{grid_distance}` 145 | 146 | ### Generate RIRs: 147 | To generate room impulse responses for all grid point pairs, execute the following command: 148 | ```bash 149 | python soundspaces_nvas3d/rir_generation/generate_rir.py --room 17DRP5sb8fy 150 | ``` 151 | * Input directory: `data/scene_datasets/mp3d/{room}` 152 | * Output directory: `data/examples/rir_mp3d/grid_{grid_distance}/{room}` 153 | 154 | ### Minimal Example of RIR Generation Using SoundSpaces 155 | We provide a minimal example of RIR generation using our codebase.
To generate sample RIRs, execute the following command: 156 | ```bash 157 | python soundspaces_nvas3d/rir_generation/minimal_example_render_rir.py 158 | ``` 159 | 160 | --- 161 | Optionally, for those interested in training an audio-visual network, images can be rendered using the following methods: 162 | ### Render target images: 163 | To render images centered at individual grid points, execute the following command: 164 | ```bash 165 | python soundspaces_nvas3d/image_rendering/run_generate_target_image.py 166 | ``` 167 | * Input directory: `data/scene_datasets/mp3d/{room}` 168 | * Output directory: `data/examples/target_image_mp3d/grid_{grid_distance}/{room}` 169 | 170 | ### Render environment maps: 171 | To render environment maps from all grid points, execute the following command: 172 | ```bash 173 | python soundspaces_nvas3d/image_rendering/run_generate_envmap.py 174 | ``` 175 | * Input directory: `data/scene_datasets/mp3d/{room}` 176 | * Output directory: `data/examples/envmap_3d/grid_{grid_distance}/{room}` 177 | 178 | -------------------------------------------------------------------------------- /nvas3d/utils/generate_dataset_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import glob 7 | import random 8 | import numpy as np 9 | import typing as T 10 | from scipy.signal import fftconvolve 11 | 12 | import torch 13 | import torchaudio 14 | import torch.nn.functional as F 15 | 16 | MIN_LENGTH_AUDIO = 5 * 48000 17 | 18 | 19 | def load_ir_source_receiver( 20 | ir_dir: str, 21 | room: str, 22 | source_idx: int, 23 | receiver_idx_list: T.List[int], 24 | ir_length: int 25 | ) -> T.List[torch.Tensor]: 26 | """ 27 | Load impulse responses for specific source and receivers in a room. 28 | 29 | Args: 30 | - ir_dir: Directory containing impulse response files. 31 | - room: Name of the room. 32 | - source_idx: Index of the source. 33 | - receiver_idx_list: List of receiver indices. 34 | - ir_length: Length of the impulse response to be loaded. 35 | 36 | Returns: 37 | - List of loaded impulse responses (first channel only). 38 | """ 39 | 40 | ir_list = [] 41 | for receiver_idx in receiver_idx_list: 42 | filename_ir = f'{ir_dir}/{room}/ir_{room}_{source_idx}_{receiver_idx}.wav' 43 | ir, _ = torchaudio.load(filename_ir) 44 | if ir[0].shape[0] > ir_length: 45 | ir0 = ir[0][:ir_length] 46 | else: 47 | ir0 = F.pad(ir[0], (0, ir_length - ir[0].shape[0])) 48 | ir_list.append(ir0) 49 | 50 | return ir_list 51 | 52 | 53 | def load_ir_source_receiver_allchannel( 54 | ir_dir: str, 55 | room: str, 56 | source_idx: int, 57 | receiver_idx_list: T.List[int], 58 | ir_length: int 59 | ) -> T.List[torch.Tensor]: 60 | """ 61 | Load impulse responses for all channels for specific source and receivers in a room. 62 | 63 | Args: 64 | - ir_dir: Directory containing impulse response files. 65 | - room: Name of the room. 66 | - source_idx: Index of the source. 67 | - receiver_idx_list: List of receiver indices. 68 | - ir_length: Length of the impulse response to be loaded. 69 | 70 | Returns: 71 | - List of loaded impulse responses for all channels.
72 | """ 73 | 74 | ir_list = [] 75 | for receiver_idx in receiver_idx_list: 76 | filename_ir = f'{ir_dir}/{room}/ir_{room}_{source_idx}_{receiver_idx}.wav' 77 | ir, _ = torchaudio.load(filename_ir) 78 | if ir.shape[-1] > ir_length: 79 | ir = ir[..., :ir_length] 80 | else: 81 | ir = F.pad(ir, (0, ir_length - ir[0].shape[0])) 82 | ir_list.append(ir) 83 | 84 | return ir_list 85 | 86 | 87 | def save_audio_list( 88 | filename: str, 89 | audio_list: T.List[torch.Tensor], 90 | sample_rate: int, 91 | audio_format: str 92 | ): 93 | """ 94 | Save a list of audio tensors to files. 95 | 96 | Args: 97 | - filename: Filename to save audio. 98 | - audio_list: List of audio tensors to save. 99 | - sample_rate: Sample rate of audio. 100 | - audio_format: File format to save audio. 101 | """ 102 | 103 | for idx_audio, audio in enumerate(audio_list): 104 | torchaudio.save(f'{filename}_{idx_audio+1}.{audio_format}', audio.unsqueeze(0), sample_rate) 105 | 106 | 107 | def clip_source( 108 | source1_audio: torch.Tensor, 109 | source2_audio: torch.Tensor, 110 | len_clip: int 111 | ) -> T.Tuple[torch.Tensor, torch.Tensor]: 112 | """ 113 | Clip source audio tensors for faster convolution. 114 | 115 | Args: 116 | - source1_audio: First source audio tensor. 117 | - source2_audio: Second source audio tensor. 118 | - len_clip: Desired length of the output audio tensors. 119 | 120 | Returns: 121 | - Clipped source1_audio and source2_audio. 122 | """ 123 | 124 | # pad audio 125 | if len_clip > source1_audio.shape[0]: 126 | source1_audio = F.pad(source1_audio, (0, len_clip - source1_audio.shape[0])) 127 | source1_audio = F.pad(source1_audio, (0, max(0, len_clip - source1_audio.shape[0]))) 128 | source2_audio = F.pad(source2_audio, (0, max(0, len_clip - source2_audio.shape[0]))) 129 | 130 | # clip 131 | start_index = np.random.randint(0, source1_audio.shape[0] - len_clip) \ 132 | if source1_audio.shape[0] != len_clip else 0 133 | source1_audio_clipped = source1_audio[start_index: start_index + len_clip] 134 | source2_audio_clipped = source2_audio[start_index: start_index + len_clip] 135 | 136 | return source1_audio_clipped, source2_audio_clipped 137 | 138 | 139 | def compute_reverb( 140 | source_audio: torch.Tensor, 141 | ir_list: T.List[torch.Tensor], 142 | padding: str = 'valid' 143 | ) -> T.List[torch.Tensor]: 144 | """ 145 | Compute reverberated audio signals by convolving source audio with impulse responses. 146 | 147 | Args: 148 | - source_audio: Source audio signal (dry) to be reverberated. 149 | - ir_list: List of impulse responses for reverberation. 150 | - padding: Padding mode for convolution ('valid' or 'full'). 151 | 152 | Returns: 153 | - A list of reverberated audio signals. 
154 | """ 155 | 156 | reverb_list = [] 157 | for ir in ir_list: 158 | reverb = fftconvolve(source_audio, ir, padding) 159 | reverb_list.append(torch.from_numpy(reverb)) 160 | 161 | return reverb_list 162 | 163 | 164 | #################### 165 | # Sampling source audio to generate data 166 | #################### 167 | 168 | def sample_speech(files_librispeech, librispeech_metadata): 169 | source_speech = torch.zeros(1) # Initialize with a tensor of zeros 170 | while torch.all(source_speech == 0) or source_speech.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found 171 | filename_source = random.choice(files_librispeech) 172 | speaker_id = int(filename_source.split('/')[6]) 173 | speaker_gender = librispeech_metadata[speaker_id] 174 | if speaker_gender == 'M': 175 | source_class = 'male' 176 | else: 177 | source_class = 'female' 178 | source_speech, _ = torchaudio.load(filename_source) 179 | source_speech = source_speech.reshape(-1) 180 | 181 | return source_speech, source_class 182 | 183 | 184 | def sample_nonspeech(all_instruments_dir): 185 | class_dir = random.choice(all_instruments_dir) 186 | 187 | # Ensure that the class is not 'Speech' 188 | while 'Speech' in class_dir: 189 | class_dir = random.choice(all_instruments_dir) 190 | 191 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True) 192 | 193 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros 194 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found 195 | filename_source = random.choice(files_source) 196 | source_class = class_dir.split('/')[3] 197 | source_audio, _ = torchaudio.load(filename_source) 198 | source_audio = source_audio.reshape(-1) 199 | 200 | return source_audio, source_class 201 | 202 | 203 | def sample_acoustic_guitar(all_instruments_dir): 204 | guitar_dir = [dirname for dirname in all_instruments_dir if dirname.split('/')[4] == 'Acoustic Guitar (steel)'] 205 | 206 | class_dir = random.choice(guitar_dir) 207 | 208 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True) 209 | 210 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros 211 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found 212 | filename_source = random.choice(files_source) 213 | source_class = 'guitar' 214 | source_audio, _ = torchaudio.load(filename_source) 215 | source_audio = source_audio.reshape(-1) 216 | 217 | return source_audio, source_class 218 | 219 | 220 | def sample_instrument(all_instruments_dir, librispeech_metadata, classname): 221 | guitar_dir = [dirname for dirname in all_instruments_dir if dirname.split('/')[3] == classname] # e.g., Guitar 222 | 223 | class_dir = random.choice(guitar_dir) 224 | 225 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True) 226 | 227 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros 228 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found 229 | filename_source = random.choice(files_source) 230 | source_class = 'guitar' 231 | source_audio, _ = torchaudio.load(filename_source) 232 | source_audio = source_audio.reshape(-1) 233 | 234 | return source_audio, source_class 235 | 236 | 237 | def sample_all(all_instruments_dir, librispeech_metadata): 238 | class_dir = random.choice(all_instruments_dir) 239 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True) 240 | 241 | 
source_audio = torch.zeros(1) # Initialize with a tensor of zeros 242 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found 243 | filename_source = random.choice(files_source) 244 | source_class = class_dir.split('/')[3] 245 | source_audio, _ = torchaudio.load(filename_source) 246 | source_audio = source_audio.reshape(-1) 247 | 248 | if source_class == 'Speech': 249 | speaker_id = int(filename_source.split('/')[6]) 250 | speaker_gender = librispeech_metadata[speaker_id] 251 | if speaker_gender == 'M': 252 | source_class = 'male' 253 | else: 254 | source_class = 'female' 255 | 256 | return source_audio, source_class 257 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/utils/audio_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | import torchaudio 8 | import matplotlib.pyplot as plt 9 | import torch.nn.functional as F 10 | import typing as T 11 | 12 | 13 | def fft_conv( 14 | signal: torch.Tensor, 15 | kernel: torch.Tensor, 16 | is_cpu: bool = False 17 | ) -> torch.Tensor: 18 | """ 19 | Perform convolution of a signal and a kernel using Fast Fourier Transform (FFT). 20 | 21 | Args: 22 | - signal (torch.Tensor): Input signal tensor. 23 | - kernel (torch.Tensor): Kernel tensor. 24 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU. 25 | 26 | Returns: 27 | - torch.Tensor: Convolved signal. 28 | """ 29 | 30 | if is_cpu: 31 | signal = signal.detach().cpu() 32 | kernel = kernel.detach().cpu() 33 | 34 | padded_signal = F.pad(signal.reshape(-1), (0, kernel.size(-1) - 1)) 35 | padded_kernel = F.pad(kernel.reshape(-1), (0, signal.size(-1) - 1)) 36 | 37 | signal_fr = torch.fft.rfftn(padded_signal, dim=-1) 38 | kernel_fr = torch.fft.rfftn(padded_kernel, dim=-1) 39 | 40 | output_fr = signal_fr * kernel_fr 41 | output = torch.fft.irfftn(output_fr, dim=-1) 42 | 43 | return output 44 | 45 | 46 | def wiener_deconv( 47 | signal: torch.Tensor, 48 | kernel: torch.Tensor, 49 | snr: float, 50 | is_cpu: bool = False 51 | ) -> torch.Tensor: 52 | """ 53 | Perform Wiener deconvolution on a given signal using a specified kernel. 54 | 55 | Args: 56 | - signal (torch.Tensor): Input signal tensor. 57 | - kernel (torch.Tensor): Kernel tensor. 58 | - snr (float): Signal-to-noise ratio. 59 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU. 60 | 61 | Returns: 62 | - torch.Tensor: Deconvolved signal. 63 | """ 64 | 65 | if is_cpu: 66 | signal = signal.detach().cpu() 67 | kernel = kernel.detach().cpu() 68 | 69 | n_fft = signal.shape[-1] + kernel.shape[-1] - 1 70 | signal_fr = torch.fft.rfft(signal.reshape(-1), n=n_fft) 71 | kernel_fr = torch.fft.rfft(kernel.reshape(-1), n=n_fft) 72 | 73 | wiener_filter_fr = torch.conj(kernel_fr) / (torch.abs(kernel_fr)**2 + 1 / snr) 74 | 75 | filtered_signal_fr = wiener_filter_fr * signal_fr 76 | 77 | filtered_signal = torch.fft.irfft(filtered_signal_fr) 78 | 79 | # Crop the filtered signal to the original size 80 | filtered_signal = filtered_signal[:signal.shape[-1]] 81 | 82 | return filtered_signal 83 | 84 | 85 | def wiener_deconv_list( 86 | signal: T.List[torch.Tensor], 87 | kernel: T.List[torch.Tensor], 88 | snr: float, 89 | is_cpu: bool = False 90 | ) -> torch.Tensor: 91 | """ 92 | wiener_deconv for list input. 
93 | 94 | Args: 95 | - signal (torch.Tensor): List of signals. 96 | - kernel (torch.Tensor): List of kernels. 97 | - snr (float): Signal-to-noise ratio. 98 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU. 99 | 100 | Returns: 101 | - torch.Tensor: Deconvolved signal. 102 | """ 103 | 104 | M = len(signal) 105 | if isinstance(signal, list): 106 | signal = torch.stack(signal).reshape(M, -1) 107 | assert signal.shape[0] == M 108 | kernel = torch.stack(kernel).reshape(M, -1) 109 | snr /= abs(kernel).max() 110 | 111 | if is_cpu: 112 | signal = signal.detach().cpu() 113 | kernel = kernel.detach().cpu() 114 | 115 | n_batch, n_samples = signal.shape 116 | 117 | # Pad the signals and kernels to avoid circular convolution 118 | padded_signal = F.pad(signal, (0, kernel.shape[-1] - 1)) 119 | padded_kernel = F.pad(kernel, (0, signal.shape[-1] - 1)) 120 | 121 | # Compute the Fourier transforms 122 | signal_fr = torch.fft.rfft(padded_signal, dim=-1) 123 | kernel_fr = torch.fft.rfft(padded_kernel, dim=-1) 124 | 125 | # Compute the Wiener filter in the frequency domain 126 | wiener_filter_fr = torch.conj(kernel_fr) / (torch.abs(kernel_fr)**2 + 1 / snr) 127 | 128 | # Apply the Wiener filter 129 | filtered_signal_fr = wiener_filter_fr * signal_fr 130 | 131 | # Compute the inverse Fourier transform 132 | filtered_signal = torch.fft.irfft(filtered_signal_fr, dim=-1) 133 | 134 | # Crop the filtered signals to the original size 135 | filtered_signal = filtered_signal[:, :n_samples] 136 | 137 | filtered_signal_list = [filtered_signal[i] for i in range(filtered_signal.size(0))] 138 | 139 | return filtered_signal_list 140 | 141 | 142 | def save_audio( 143 | filename: str, 144 | waveform: torch.Tensor, # (ch, time) in cpu 145 | sample_rate: float 146 | ): 147 | """ 148 | Save an audio waveform to a file. 149 | 150 | Args: 151 | - filename (str): Output filename. 152 | - waveform (torch.Tensor): Audio waveform tensor. 153 | - sample_rate (float): Sample rate of the audio. 154 | """ 155 | 156 | torchaudio.save(filename, waveform, sample_rate=sample_rate) 157 | 158 | 159 | def plot_waveform(filename: str, 160 | waveform: torch.Tensor, # (ch, time) in cpu 161 | sample_rate: float = 48000, 162 | title: str = None, 163 | sharex: bool = True, 164 | xlim: T.Tuple[float, float] = None, 165 | ylim: T.Tuple[float, float] = None, 166 | color: str = 'orange', 167 | waveform_ref: torch.Tensor = None 168 | ): 169 | """ 170 | Plot an audio waveform. 171 | 172 | Args: 173 | - filename (str): Output filename. 174 | - waveform (torch.Tensor): Audio waveform tensor. 175 | - sample_rate (float, optional): Sample rate of the audio. Defaults to 48000. 176 | - title (str, optional): Title for the plot. 177 | - sharex (bool, optional): Whether to share the x-axis across subplots. Defaults to True. 178 | - xlim (T.Tuple[float, float], optional): Limits for x-axis. 179 | - ylim (T.Tuple[float, float], optional): Limits for y-axis. 180 | - color (str, optional): Color for the waveform plot. Defaults to 'orange'. 181 | - waveform_ref (torch.Tensor, optional): Reference waveform for comparison. 
182 | """ 183 | 184 | num_channels, num_frames = waveform.shape 185 | time_axis = torch.arange(0, num_frames) / sample_rate 186 | if waveform_ref is not None: 187 | num_frames_ref = waveform_ref.shape[0] # should be 1D 188 | time_axis_ref = torch.arange(0, num_frames_ref) / sample_rate 189 | 190 | if ylim is None: 191 | margin = 1.1 192 | ylim = (margin * waveform.min(), margin * waveform.max()) 193 | 194 | figure, axes = plt.subplots(num_channels, 1, sharex=sharex) 195 | if num_channels == 1: 196 | axes = [axes] 197 | for c in range(num_channels): 198 | if waveform_ref is None: 199 | axes[c].plot(time_axis, waveform[c], color=color, linewidth=1) 200 | else: 201 | axes[c].plot(time_axis, waveform[c], color=color, alpha=0.5, linewidth=1, label='signal') 202 | axes[c].plot(time_axis_ref, waveform_ref, color='r', alpha=0.5, linewidth=1, label='reference') 203 | axes[c].legend() 204 | axes[c].grid(True) 205 | if num_channels > 1: 206 | axes[c].set_ylabel(f'Channel {c+1}') 207 | if xlim: 208 | axes[c].set_xlim(xlim) 209 | if ylim: 210 | axes[c].set_ylim(ylim) 211 | if title is not None: 212 | figure.suptitle(title) 213 | 214 | if filename is not None: 215 | plt.savefig(filename) 216 | else: 217 | plt.show(block=False) 218 | 219 | 220 | def print_stats( 221 | waveform: torch.Tensor, 222 | sample_rate: T.Optional[float] = None, 223 | src: T.Optional[str] = None 224 | ): 225 | """ 226 | Print the statistics of a given waveform. 227 | 228 | Args: 229 | - waveform (torch.Tensor): Input audio waveform tensor. 230 | - sample_rate (float, optional): Sample rate of the audio. 231 | - src (str, optional): Source of the audio, for display purposes. 232 | """ 233 | 234 | if src: 235 | print("-" * 10) 236 | print("Source:", src) 237 | print("-" * 10) 238 | if sample_rate: 239 | print("Sample Rate:", sample_rate) 240 | print("Shape:", tuple(waveform.shape)) 241 | print("Dtype:", waveform.dtype) 242 | print(f" - Max: {waveform.max().item():6.3f}") 243 | print(f" - Min: {waveform.min().item():6.3f}") 244 | print(f" - Mean: {waveform.mean().item():6.3f}") 245 | print(f" - Std Dev: {waveform.std().item():6.3f}") 246 | print() 247 | print(waveform) 248 | print() 249 | 250 | 251 | def plot_specgram( 252 | filename: str, 253 | waveform: torch.Tensor, 254 | sample_rate: float, 255 | title: T.Optional[str] = None, 256 | xlim: T.Optional[T.Tuple[float, float]] = None 257 | ): 258 | """ 259 | Plot the spectrogram of a given audio waveform. 260 | 261 | Args: 262 | - filename (str): Output filename. 263 | - waveform (torch.Tensor): Audio waveform tensor. 264 | - sample_rate (float): Sample rate of the audio. 265 | - title (str, optional): Title for the plot. 266 | - xlim (T.Tuple[float, float], optional): Limits for x-axis. 
267 | """ 268 | 269 | waveform = waveform.numpy() if isinstance(waveform, torch.Tensor) else waveform 270 | 271 | num_channels, num_frames = waveform.shape 272 | time_axis = torch.arange(0, num_frames) / sample_rate 273 | 274 | figure, axes = plt.subplots(num_channels, 1) 275 | if num_channels == 1: 276 | axes = [axes] 277 | for c in range(num_channels): 278 | axes[c].specgram(waveform[c], Fs=sample_rate) 279 | if num_channels > 1: 280 | axes[c].set_ylabel(f'Channel {c+1}') 281 | if xlim: 282 | axes[c].set_xlim(xlim) 283 | if title is not None: 284 | figure.suptitle(title) 285 | 286 | if filename is not None: 287 | plt.savefig(filename) 288 | else: 289 | plt.show(block=False) 290 | 291 | 292 | def plot_debug( 293 | filename: str, 294 | waveform: torch.Tensor, 295 | sample_rate: float, 296 | save_png: bool = False 297 | ): 298 | """ 299 | Generate and save waveform and spectrogram plots, and save audio waveform to a file. 300 | 301 | Args: 302 | - filename (str): Base filename for outputs. 303 | - waveform (torch.Tensor): Audio waveform tensor. 304 | - sample_rate (float): Sample rate of the audio. 305 | - save_png (bool, optional): Whether to save plots as PNG files. Defaults to False. 306 | """ 307 | 308 | waveform = waveform.reshape(1, -1) 309 | if save_png: 310 | plot_specgram(f'{filename}_specgram.png', waveform, sample_rate) 311 | plot_waveform(f'{filename}_waveform.png', waveform, sample_rate) 312 | torchaudio.save(f'{filename}.wav', waveform, sample_rate) 313 | -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/generate_metadata_square.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import json 8 | import random 9 | import argparse 10 | import matplotlib.pyplot as plt 11 | from collections import defaultdict 12 | from itertools import combinations 13 | 14 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 15 | from nvas3d.utils.utils import MP3D_SCENE_SPLITS 16 | 17 | 18 | def find_squares(points, grid_distance, tolerance=1e-5): 19 | squares = [] 20 | for i, point1 in enumerate(points): 21 | x1, y1, z1 = point1 22 | for j, point2 in enumerate(points): 23 | if i != j: 24 | x2, y2, z2 = point2 25 | if abs(x2 - (x1 + grid_distance)) < tolerance and abs(z2 - z1) < tolerance: 26 | for k, point3 in enumerate(points): 27 | if k != i and k != j: 28 | x3, y3, z3 = point3 29 | if abs(x3 - x1) < tolerance and abs(z3 - (z1 + grid_distance)) < tolerance: 30 | for l, point4 in enumerate(points): 31 | if l != i and l != j and l != k: 32 | x4, y4, z4 = point4 33 | if abs(x4 - (x1 + grid_distance)) < tolerance and abs(z4 - (z1 + grid_distance)) < tolerance: 34 | squares.append((i, j, l, k)) 35 | return squares 36 | 37 | 38 | def plot_points_and_squares(filename, points, squares): 39 | fig, ax = plt.subplots() 40 | 41 | # plot all points 42 | ax.scatter(*zip(*[(x, z) for x, _, z in points]), c='blue') 43 | 44 | # plot squares 45 | for square_indices in squares: 46 | square_points = [points[i] for i in square_indices] 47 | square_points.append(square_points[0]) # Add the first point again to close the square 48 | ax.plot(*zip(*[(x, z) for x, _, z in square_points]), c='red') 49 | 50 | ax.set_aspect('equal', 'box') # Ensure the aspect ratio is equal 51 | plt.savefig(filename) 52 | plt.show() 53 | 54 | 55 | def main(args): 56 | dataset_name = args.dataset_name 57 | num_pairs_per_room = args.num_pairs_per_room 58 | random.seed(0) 59 | 60 | grid_distance = args.grid_distance 61 | grid_distance_str = str(grid_distance).replace(".", "_") 62 | 63 | os.makedirs(f'data/{dataset_name}/metadata/grid_{grid_distance_str}', exist_ok=True) 64 | 65 | filesize_valid_total = 0.0 66 | filesize_invalid_total = 0.0 67 | num_valid_total = 0.0 68 | num_invalid_total = 0.0 69 | mean_firstnz_total = 0.0 70 | mean_rt60_total = 0.0 71 | mean_maximum_total = 0.0 72 | mean_mean_total = 0.0 73 | mean_num_points_total = 0.0 74 | mean_room_size_total = 0.0 75 | 76 | count_total = 0 77 | 78 | is_debug = False 79 | # Read jsons 80 | for split in ['demo']: # ['train', 'val', 'test', 'demo']: 81 | 82 | filesize_valid_total_ = 0.0 83 | filesize_invalid_total_ = 0.0 84 | num_valid_total_ = 0.0 85 | num_invalid_total_ = 0.0 86 | mean_firstnz_total_ = 0.0 87 | mean_rt60_total_ = 0.0 88 | mean_maximum_total_ = 0.0 89 | mean_mean_total_ = 0.0 90 | mean_num_points_total_ = 0.0 91 | mean_room_size_total_ = 0.0 92 | count_split = 0 93 | 94 | for i_room, room in enumerate(MP3D_SCENE_SPLITS[split]): 95 | if is_debug: 96 | print(f'room: {room} in {split} ({i_room}/{len(MP3D_SCENE_SPLITS[split])})') 97 | filename = f'data/metadata_grid/grid_{grid_distance_str}/{room}.json' 98 | with open(filename, 'r') as file: 99 | metadata = json.load(file) 100 | 101 | filesize_valid_total += metadata['filesize_valid_in_mb'] 102 | filesize_invalid_total += metadata['filesize_invalid_in_mb'] 103 | num_valid_total += metadata['num_valid'] 104 | num_invalid_total += metadata['num_invalid'] 105 | mean_firstnz_total = metadata['mean_firstnz'] 106 | mean_rt60_total = metadata['mean_rt60'] 107 | mean_maximum_total = metadata['mean_maximum'] 108 | mean_mean_total = metadata['mean_mean'] 109 | 
mean_num_points_total = metadata['num_points'] 110 | mean_room_size_total = metadata['room_size'] 111 | 112 | filesize_valid_total_ += metadata['filesize_valid_in_mb'] 113 | filesize_invalid_total_ += metadata['filesize_invalid_in_mb'] 114 | num_valid_total_ += metadata['num_valid'] 115 | num_invalid_total_ += metadata['num_invalid'] 116 | mean_firstnz_total_ = metadata['mean_firstnz'] 117 | mean_rt60_total_ = metadata['mean_rt60'] 118 | mean_maximum_total_ = metadata['mean_maximum'] 119 | mean_mean_total_ = metadata['mean_mean'] 120 | mean_num_points_total_ = metadata['num_points'] 121 | mean_room_size_total_ = metadata['room_size'] 122 | 123 | count_split += 1 124 | count_total += 1 125 | 126 | # find valid 127 | points = load_room_grid(room, grid_distance)['grid_points'] 128 | squares = find_squares(points, grid_distance) 129 | # filename = f'data/metadata_grid/grid_{grid_distance_str}/{room}_square.png' 130 | # plot_points_and_squares(filename, points, squares) 131 | 132 | # Initialize defaultdicts to count source_idx and receiver_idx occurrences 133 | source_idx_counts = defaultdict(set) 134 | receiver_idx_counts = defaultdict(set) 135 | 136 | # Iterate over keys in valid metadata 137 | for key in metadata.keys(): 138 | if key.startswith(room): 139 | # Extract source_idx and receiver_idx from key 140 | room, source_idx, receiver_idx = key.split("_") 141 | 142 | if metadata[key]['Is Valid'] == 'True': 143 | # Add receiver_idx to the set associated with source_idx 144 | # and vice versa 145 | source_idx_counts[source_idx].add(receiver_idx) 146 | receiver_idx_counts[receiver_idx].add(source_idx) 147 | 148 | # Initialize the empty dictionary where each value is a set 149 | square_to_source_idxs = defaultdict(set) 150 | 151 | # Iterate over each source_idx and its corresponding valid indices 152 | for source_idx, valid_indices in source_idx_counts.items(): 153 | # Iterate over each square 154 | for square in squares: 155 | # Check if all indices of the current square are valid for the current source_idx 156 | if all(str(idx) in valid_indices for idx in square): 157 | # Add the source_idx to the square's set of valid source_idxs 158 | square_to_source_idxs[square].add(int(source_idx)) 159 | 160 | # Filter out squares that have less than 3 valid source_idxs (2 for positive, 1 for negative) 161 | three_source_idx_squares = {square: source_idxs for square, source_idxs in square_to_source_idxs.items() if len(source_idxs) >= 3} 162 | 163 | # Initialize the list of pairs 164 | pairs = [] 165 | 166 | # Iterate over the dictionary to build pairs 167 | for square, source_idxs in three_source_idx_squares.items(): 168 | # Generate all combinations of 3 source indices 169 | for source_idx_pair in combinations(source_idxs, 3): 170 | # Avoid adding pairs where any source_idx is in square (receiver indices) 171 | if not any(source_idx in square for source_idx in source_idx_pair): 172 | # Find novel receiver that is valid to source 1 and source 2 173 | common_receiver = source_idx_counts[str(source_idx_pair[0])].intersection(source_idx_counts[str(source_idx_pair[1])]) 174 | common_receiver = common_receiver - set(square) - set([source_idx_pair[0]]) - set([source_idx_pair[1]]) - set(square) 175 | if common_receiver: 176 | novel = int(random.choice(list(common_receiver))) 177 | 178 | # Add the pair to the list 179 | pairs.append((source_idx_pair, square, novel)) 180 | 181 | # If there are less than num_pairs_per_room 182 | if len(pairs) < num_pairs_per_room: 183 | selected_pairs = pairs 184 | else: 185 | # 
Randomly select num_pairs_per_room 186 | selected_pairs = random.sample(pairs, num_pairs_per_room) 187 | 188 | # Save the lists to JSON files 189 | total_dict = {} 190 | total_dict['squares'] = squares 191 | total_dict['selected_pairs'] = selected_pairs 192 | # total_dict['square_to_source_idxs_keys'] = list(square_to_source_idxs.keys()) 193 | # total_dict['square_to_source_idxs_values'] = list(square_to_source_idxs.values()) 194 | 195 | filename = f'data/{dataset_name}/metadata/grid_{grid_distance_str}/{room}_square.json' 196 | with open(filename, 'w') as file: 197 | json.dump(total_dict, file) 198 | 199 | if is_debug: 200 | print(split) 201 | print(filesize_valid_total_ / 1024) # GB 202 | print(filesize_invalid_total_ / 1024) # GB 203 | print(num_valid_total_) 204 | print(num_invalid_total_) 205 | print(mean_firstnz_total_ / count_split) 206 | print(mean_rt60_total_ / count_split) 207 | print(mean_maximum_total_ / count_split) 208 | print(mean_mean_total_ / count_split) 209 | print(mean_num_points_total_ / count_split) 210 | print(mean_room_size_total_ / count_split) 211 | 212 | if is_debug: 213 | print(filesize_valid_total / 1024) # GB 214 | print(filesize_invalid_total / 1024) # GB 215 | print(num_valid_total) 216 | print(num_invalid_total) 217 | print(mean_firstnz_total / count_total) 218 | print(mean_rt60_total / count_total) 219 | print(mean_maximum_total / count_total) 220 | print(mean_mean_total / count_total) 221 | print(mean_num_points_total / count_total) 222 | print(mean_room_size_total / count_total) 223 | 224 | 225 | if __name__ == '__main__': 226 | parser = argparse.ArgumentParser(description="Generate metadata for square-shape microphone array.") 227 | 228 | parser.add_argument('--grid_distance', 229 | default=1.0, 230 | type=float, 231 | help='Distance between grid points in meters') 232 | 233 | parser.add_argument('--num_pairs_per_room', 234 | default=1000, 235 | type=int, 236 | help='Number of pairs per room to generate') 237 | 238 | parser.add_argument('--dataset_name', 239 | default='nvas3d_square', 240 | type=str, 241 | help='Name of the dataset') 242 | 243 | args = parser.parse_args() 244 | main(args) 245 | -------------------------------------------------------------------------------- /nvas3d/model/model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | # Portions of this code are derived from VisualVoice (CC-BY-NC). 
6 | # Original work available at: https://github.com/facebookresearch/VisualVoice 7 | # 8 | 9 | import torch 10 | import torchvision 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | class NVASNet(nn.Module): 16 | def __init__(self, num_receivers, use_visual=False): 17 | super(NVASNet, self).__init__() 18 | 19 | self.num_receivers = num_receivers 20 | self.use_visual = use_visual 21 | 22 | if self.use_visual: 23 | self.rgb_net = VisualNet(torchvision.models.resnet18(pretrained=True), 3) 24 | self.depth_net = VisualNet(torchvision.models.resnet18(pretrained=True), 1) 25 | 26 | concat_size = 512 * 2 27 | self.pooling = nn.AdaptiveAvgPool2d((1, 1)) 28 | self.conv1x1 = create_conv(concat_size, 512, 1, 0) 29 | 30 | input_channel = 2 * num_receivers 31 | output_channel = 2 32 | 33 | self.audio_net = AudioNet(64, input_channel, output_channel, self.use_visual) 34 | self.audio_net.apply(weights_init) 35 | 36 | def forward(self, inputs, disable_detection=False): 37 | visual_features = [] 38 | if self.use_visual: 39 | visual_features.append(self.rgb_net(inputs['rgb'])) 40 | visual_features.append(self.depth_net(inputs['depth'])) 41 | 42 | if len(visual_features) != 0: 43 | # concatenate channel-wise 44 | concat_visual_features = torch.cat(visual_features, dim=1) 45 | concat_visual_features = self.conv1x1(concat_visual_features) 46 | concat_visual_features = self.pooling(concat_visual_features) 47 | else: 48 | concat_visual_features = None 49 | 50 | # Dereverber 51 | pred_stft, audio_feat, source_detection = self.audio_net(inputs['input_stft'], concat_visual_features, disable_detection) 52 | output = {'pred_stft': pred_stft} 53 | 54 | # Source identifier 55 | output['source_detection'] = source_detection 56 | 57 | if len(visual_features) != 0: 58 | audio_embed = self.pooling(audio_feat).squeeze(-1).squeeze(-1) 59 | visual_embed = concat_visual_features.squeeze(-1).squeeze(-1) 60 | output['audio_feat'] = F.normalize(audio_embed, p=2, dim=1) 61 | output['visual_feat'] = F.normalize(visual_embed, p=2, dim=1) 62 | 63 | return output 64 | 65 | 66 | def unet_conv(input_nc, output_nc, use_norm=False, norm_layer=nn.BatchNorm2d): 67 | downconv = nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1) 68 | downrelu = nn.LeakyReLU(0.2, True) 69 | if use_norm: 70 | downnorm = norm_layer(output_nc) 71 | return nn.Sequential(*[downconv, downnorm, downrelu]) 72 | else: 73 | return nn.Sequential(*[downconv, downrelu]) 74 | 75 | 76 | def unet_upconv(input_nc, output_nc, outermost=False, use_sigmoid=False, use_tanh=False, use_norm=False, norm_layer=nn.BatchNorm2d): 77 | upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1) 78 | uprelu = nn.ReLU(True) 79 | if use_norm and not outermost: 80 | upnorm = norm_layer(output_nc) 81 | return nn.Sequential(*[upconv, upnorm, uprelu]) 82 | else: 83 | if outermost: 84 | if use_sigmoid: 85 | return nn.Sequential(*[upconv, nn.Sigmoid()]) 86 | elif use_tanh: 87 | return nn.Sequential(*[upconv, nn.Tanh()]) 88 | else: 89 | return nn.Sequential(*[upconv]) 90 | else: 91 | return nn.Sequential(*[upconv, uprelu]) 92 | 93 | 94 | class conv_block(nn.Module): 95 | def __init__(self, ch_in, ch_out, use_norm=False): 96 | super(conv_block, self).__init__() 97 | if use_norm: 98 | self.conv = nn.Sequential( 99 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 100 | nn.BatchNorm2d(ch_out), 101 | nn.LeakyReLU(0.2, True), 102 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 103 | 
nn.BatchNorm2d(ch_out), 104 | nn.LeakyReLU(0.2, True) 105 | ) 106 | else: 107 | self.conv = nn.Sequential( 108 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 109 | nn.LeakyReLU(0.2, True), 110 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 111 | nn.LeakyReLU(0.2, True) 112 | ) 113 | 114 | def forward(self, x): 115 | x = self.conv(x) 116 | return x 117 | 118 | 119 | class up_conv(nn.Module): 120 | def __init__(self, ch_in, ch_out, outermost=False, use_norm=False, scale_factor=(2., 1.)): 121 | super(up_conv, self).__init__() 122 | if not outermost: 123 | if use_norm: 124 | self.up = nn.Sequential( 125 | nn.Upsample(scale_factor=scale_factor), 126 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 127 | nn.BatchNorm2d(ch_out), 128 | nn.ReLU(inplace=True) 129 | ) 130 | else: 131 | self.up = nn.Sequential( 132 | nn.Upsample(scale_factor=scale_factor), 133 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 134 | nn.ReLU(inplace=True) 135 | ) 136 | else: 137 | self.up = nn.Sequential( 138 | nn.Upsample(scale_factor=scale_factor), 139 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True), 140 | nn.Sigmoid() 141 | ) 142 | 143 | def forward(self, x): 144 | x = self.up(x) 145 | return x 146 | 147 | 148 | def create_conv(input_channels, output_channels, kernel, paddings, batch_norm=False, use_relu=True, stride=1): 149 | model = [nn.Conv2d(input_channels, output_channels, kernel, stride=stride, padding=paddings)] 150 | if batch_norm: 151 | model.append(nn.BatchNorm2d(output_channels)) 152 | if use_relu: 153 | model.append(nn.ReLU()) 154 | return nn.Sequential(*model) 155 | 156 | 157 | def weights_init(m): 158 | classname = m.__class__.__name__ 159 | if classname.find('Conv') != -1: 160 | m.weight.data.normal_(0.0, 0.02) 161 | elif classname.find('BatchNorm2d') != -1: 162 | m.weight.data.normal_(1.0, 0.02) 163 | m.bias.data.fill_(0) 164 | elif classname.find('Linear') != -1: 165 | m.weight.data.normal_(0.0, 0.02) 166 | 167 | 168 | class VisualNet(nn.Module): 169 | def __init__(self, original_resnet, num_channel=3): 170 | super(VisualNet, self).__init__() 171 | original_resnet.conv1 = nn.Conv2d(num_channel, 64, kernel_size=7, stride=2, padding=3, bias=False) 172 | layers = list(original_resnet.children())[0:-2] 173 | self.feature_extraction = nn.Sequential(*layers) # features before conv1x1 174 | 175 | def forward(self, x): 176 | x = self.feature_extraction(x) 177 | return x 178 | 179 | 180 | class AudioNet(nn.Module): 181 | def __init__(self, ngf=64, input_nc=2, output_nc=2, use_visual=False, use_norm=True, audioVisual_feature_dim=512): 182 | super(AudioNet, self).__init__() 183 | 184 | self.use_visual = use_visual 185 | if use_visual: 186 | audioVisual_feature_dim += 512 187 | 188 | # initialize layers 189 | self.audionet_convlayer1 = unet_conv(input_nc, ngf, use_norm=False) 190 | self.audionet_convlayer2 = unet_conv(ngf, ngf * 2, use_norm=use_norm) 191 | self.audionet_convlayer3 = conv_block(ngf * 2, ngf * 4, use_norm=use_norm) 192 | self.audionet_convlayer4 = conv_block(ngf * 4, ngf * 8, use_norm=use_norm) 193 | self.audionet_convlayer5 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm) 194 | self.audionet_convlayer6 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm) 195 | self.audionet_convlayer7 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm) 196 | self.audionet_convlayer8 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm) 197 | self.frequency_time_pool = nn.MaxPool2d([2, 2]) 
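# [Editor's note] Descriptive summary added for readability; it documents the
# existing code only. The eight conv layers above form the encoder of a U-Net:
# frequency_time_pool halves both the frequency and time axes, while
# frequency_pool (defined next) halves frequency only. Each audionet_upconvlayer
# in the decoder consumes the concatenation of the previous decoder feature and
# the matching encoder feature (skip connection), which is why its input channel
# counts are doubled (ngf * 16, ngf * 8, ngf * 4, ngf * 2).
#
# Minimal usage sketch (commented out, hedged): only the channel count
# 2 * num_receivers comes from this file; the 1024 x 256 STFT size is inferred
# from the pooling schedule and is an assumption, not a documented requirement.
#
#   model = NVASNet(num_receivers=4, use_visual=False)
#   dummy = {'input_stft': torch.randn(1, 2 * 4, 1024, 256)}
#   out = model(dummy)
#   # out['pred_stft']: (1, 2, 1024, 256); out['source_detection']: (1, 1)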
198 | self.frequency_pool = nn.MaxPool2d([2, 1]) 199 | self.audionet_upconvlayer1 = up_conv(audioVisual_feature_dim, ngf * 8, use_norm=use_norm) 200 | self.audionet_upconvlayer2 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm) 201 | self.audionet_upconvlayer3 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm, scale_factor=(2., 2.)) 202 | self.audionet_upconvlayer4 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm, scale_factor=(2., 2.)) 203 | self.audionet_upconvlayer5 = up_conv(ngf * 16, ngf * 4, use_norm=use_norm, scale_factor=(2., 2.)) 204 | self.audionet_upconvlayer6 = up_conv(ngf * 8, ngf * 2, use_norm=use_norm, scale_factor=(2., 2.)) 205 | self.audionet_upconvlayer7 = unet_upconv(ngf * 4, ngf, use_norm=use_norm) 206 | self.audionet_upconvlayer8 = unet_upconv(ngf * 2, output_nc, True, use_norm=use_norm) 207 | self.Sigmoid = nn.Sigmoid() 208 | self.Tanh = nn.Tanh() 209 | 210 | # Source identifier 211 | self.source_detector = nn.Sequential( 212 | nn.Conv2d(audioVisual_feature_dim, 64, kernel_size=3, stride=1, padding=1), # maintains spatial dimensions as 4x4 213 | nn.ReLU(), 214 | nn.MaxPool2d(kernel_size=2, stride=2), # reduces spatial dimensions to 2x2 215 | nn.Conv2d(64, 32, kernel_size=2, stride=1), # reduces spatial dimensions to 1x1 216 | nn.ReLU(), 217 | nn.Flatten(), 218 | nn.Linear(32, 1), 219 | ) 220 | 221 | def forward(self, audio_mix_stft, visual_feat, disable_detection=False): 222 | audio_conv1feature = self.audionet_convlayer1(audio_mix_stft) 223 | audio_conv2feature = self.audionet_convlayer2(audio_conv1feature) 224 | audio_conv3feature = self.audionet_convlayer3(audio_conv2feature) 225 | audio_conv3feature = self.frequency_time_pool(audio_conv3feature) 226 | audio_conv4feature = self.audionet_convlayer4(audio_conv3feature) 227 | audio_conv4feature = self.frequency_time_pool(audio_conv4feature) 228 | audio_conv5feature = self.audionet_convlayer5(audio_conv4feature) 229 | audio_conv5feature = self.frequency_time_pool(audio_conv5feature) 230 | audio_conv6feature = self.audionet_convlayer6(audio_conv5feature) 231 | audio_conv6feature = self.frequency_time_pool(audio_conv6feature) 232 | audio_conv7feature = self.audionet_convlayer7(audio_conv6feature) 233 | audio_conv7feature = self.frequency_pool(audio_conv7feature) 234 | audio_conv8feature = self.audionet_convlayer8(audio_conv7feature) 235 | audio_conv8feature = self.frequency_pool(audio_conv8feature) 236 | 237 | audioVisual_feature = audio_conv8feature 238 | if self.use_visual: 239 | visual_feat = visual_feat.view(visual_feat.shape[0], -1, 1, 1) # flatten visual feature 240 | visual_feat = visual_feat.repeat(1, 1, audio_conv8feature.shape[-2], 241 | audio_conv8feature.shape[-1]) # tile visual feature 242 | 243 | audioVisual_feature = torch.cat((visual_feat, audioVisual_feature), dim=1) 244 | 245 | if not disable_detection: 246 | source_detection = self.source_detector(audioVisual_feature) 247 | else: 248 | source_detection = torch.tensor([[0.0]]) 249 | 250 | audio_upconv1feature = self.audionet_upconvlayer1(audioVisual_feature) 251 | audio_upconv2feature = self.audionet_upconvlayer2(torch.cat((audio_upconv1feature, audio_conv7feature), dim=1)) 252 | audio_upconv3feature = self.audionet_upconvlayer3(torch.cat((audio_upconv2feature, audio_conv6feature), dim=1)) 253 | audio_upconv4feature = self.audionet_upconvlayer4(torch.cat((audio_upconv3feature, audio_conv5feature), dim=1)) 254 | audio_upconv5feature = self.audionet_upconvlayer5(torch.cat((audio_upconv4feature, audio_conv4feature), dim=1)) 255 | audio_upconv6feature = 
self.audionet_upconvlayer6(torch.cat((audio_upconv5feature, audio_conv3feature), dim=1)) 256 | audio_upconv7feature = self.audionet_upconvlayer7(torch.cat((audio_upconv6feature, audio_conv2feature), dim=1)) 257 | prediction = self.audionet_upconvlayer8(torch.cat((audio_upconv7feature, audio_conv1feature), dim=1)) 258 | 259 | pred_stft = prediction 260 | 261 | return pred_stft, audio_conv8feature, source_detection 262 | -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/generate_training_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import glob 8 | import json 9 | import random 10 | import shutil 11 | from tqdm import tqdm 12 | 13 | import torch 14 | import torchaudio 15 | 16 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 17 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list 18 | from nvas3d.utils.audio_utils import clip_two 19 | from nvas3d.utils.utils import normalize, parse_librispeech_metadata, MP3D_SCENE_SPLITS 20 | from nvas3d.utils.generate_dataset_utils import sample_speech, sample_nonspeech, sample_acoustic_guitar, sample_all, clip_source, load_ir_source_receiver, save_audio_list, compute_reverb 21 | 22 | SOURCE1_DATA = 'all' # speech, nonspeech, guitar 23 | SOURCE2_DATA = 'all' 24 | 25 | random.seed(42) 26 | 27 | DATASET_NAME = f'nvas3d_square_{SOURCE1_DATA}_{SOURCE2_DATA}' 28 | os.makedirs(f'data/{DATASET_NAME}', exist_ok=True) 29 | 30 | grid_distance = 1.0 31 | grid_distance_str = str(grid_distance).replace(".", "_") 32 | target_shape_t = 256 33 | ir_length = 72000 34 | ir_clip_idx = ir_length - 1 35 | hop_length = 480 36 | len_clip = hop_length * (target_shape_t - 1) + ir_length - 1 37 | sample_rate = 48000 38 | snr = 100 39 | audio_format = 'flac' 40 | 41 | for split in ['train', 'val', 'test', 'demo']: 42 | # LibriSpeech 43 | if split == 'train': 44 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/train' 45 | elif split == 'val': 46 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation' 47 | elif split == 'test': 48 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/test' 49 | else: 50 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation' 51 | files_librispeech = glob.glob(librispeech_dir + '/**/*.flac', recursive=True) 52 | librispeech_metadata = parse_librispeech_metadata(f'data/MIDI/clip/Speech/LibriSpeech48k/SPEAKERS.TXT') 53 | 54 | # MIDI 55 | if split == 'train': 56 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'train')) if os.path.isdir(path)] 57 | elif split == 'val': 58 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)] 59 | elif split == 'test': 60 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'test')) if os.path.isdir(path)] 61 | else: 62 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)] 63 | 64 | # RIR 65 | ir_dir = f'data/nvas3d_square/ir/{split}/grid_{grid_distance_str}' 66 | 67 | # Image 68 | dirname_sourceimage = f'data/nvas3d_square/image/{split}/grid_{grid_distance_str}' 69 | 70 | # Iterate over rooms 71 | for i_room, room in enumerate(tqdm(MP3D_SCENE_SPLITS[split])): 72 | 
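# [Editor's note] Summary of the per-pair generation loop below, added for
# readability; it describes the existing code only. For every
# (source_idx_list, receiver_idx_list, novel_receiver_idx) tuple listed in
# {room}_square.json, the script: (1) samples two dry sources, (2) clips and
# normalizes them, (3) applies the pre-rendered RIRs to each source with
# compute_reverb, (4) sums the two reverberant signals into the receiver
# mixtures, (5) precomputes Wiener deconvolutions of the mixtures against the
# candidate RIRs, (6) copies the rendered RGB/depth images, and (7) stores
# metadata.pt with the source/receiver indices and classes.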
grid_points = load_room_grid(room, grid_distance)['grid_points'] 73 | total_pairs = [] 74 | filename = f'data/nvas3d_square/metadata/grid_{grid_distance_str}/{room}_square.json' 75 | with open(filename, 'r') as file: 76 | square_data = json.load(file) 77 | pairs = square_data['selected_pairs'] 78 | # Add each pair with room id to the total list 79 | for pair in pairs: 80 | total_pairs.append((room,) + tuple(pair)) 81 | 82 | random.shuffle(total_pairs) 83 | 84 | for i_pair, pair in enumerate(tqdm(total_pairs)): 85 | dirname = f'data/{DATASET_NAME}/{split}/{room}/{i_pair}' 86 | # os.makedirs(dirname, exist_ok=True) 87 | 88 | room, source_idx_list, receiver_idx_list, novel_receiver_idx = pair 89 | 90 | # Compute source 91 | if SOURCE1_DATA == 'speech': 92 | source1_audio, source1_class = sample_speech(files_librispeech, librispeech_metadata) 93 | elif SOURCE1_DATA == 'nonspeech': 94 | source1_audio, source1_class = sample_nonspeech(all_instruments_dir) 95 | elif SOURCE1_DATA == 'guitar': 96 | source1_audio, source1_class = sample_acoustic_guitar(all_instruments_dir) 97 | else: 98 | source1_audio, source1_class = sample_all(all_instruments_dir, librispeech_metadata) 99 | 100 | if SOURCE2_DATA == 'speech': 101 | source2_audio, source2_class = sample_speech(files_librispeech, librispeech_metadata) 102 | elif SOURCE2_DATA == 'nonspeech': 103 | source2_audio, source2_class = sample_nonspeech(all_instruments_dir) 104 | elif SOURCE2_DATA == 'guitar': 105 | source2_audio, source2_class = sample_acoustic_guitar(all_instruments_dir) 106 | else: 107 | source2_audio, source2_class = sample_all(all_instruments_dir, librispeech_metadata) 108 | 109 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio) 110 | if not split == 'test': 111 | source1_audio, source2_audio = clip_source(source1_audio, source2_audio, len_clip) 112 | 113 | source1_audio = normalize(source1_audio) 114 | source2_audio = normalize(source2_audio) 115 | 116 | if torch.isnan(source1_audio).any() or torch.isnan(source2_audio).any(): 117 | continue 118 | 119 | if split == 'test': 120 | if source1_audio.shape[0] < sample_rate * 10: # skip for short (<10s) 121 | continue 122 | 123 | os.makedirs(f'{dirname}/source', exist_ok=True) 124 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[ir_clip_idx:].unsqueeze(0), sample_rate) 125 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[ir_clip_idx:].unsqueeze(0), sample_rate) 126 | 127 | # Save IR 128 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True) 129 | ir1_list = load_ir_source_receiver(ir_dir, room, source_idx_list[0], receiver_idx_list, ir_length) 130 | ir2_list = load_ir_source_receiver(ir_dir, room, source_idx_list[1], receiver_idx_list, ir_length) 131 | ir3_list = load_ir_source_receiver(ir_dir, room, source_idx_list[2], receiver_idx_list, ir_length) 132 | save_audio_list(f'{dirname}/ir_receiver/ir1', ir1_list, sample_rate, audio_format) 133 | save_audio_list(f'{dirname}/ir_receiver/ir2', ir2_list, sample_rate, audio_format) 134 | save_audio_list(f'{dirname}/ir_receiver/ir3', ir3_list, sample_rate, audio_format) 135 | 136 | os.makedirs(f'{dirname}/ir_novel', exist_ok=True) 137 | ir1_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[0], [novel_receiver_idx], ir_length)[0] 138 | ir2_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[1], [novel_receiver_idx], ir_length)[0] 139 | torchaudio.save(f'{dirname}/ir_novel/ir1_novel.{audio_format}', ir1_novel.unsqueeze(0), sample_rate) 140 | 
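# [Editor's note] Hedged clarification of the two helpers used below; their
# implementations live in nvas3d/utils/generate_dataset_utils.py and
# soundspaces_nvas3d/utils/audio_utils.py (not shown in this excerpt), so the
# following is an illustrative sketch rather than the actual code:
#   - compute_reverb(source_audio, ir_list) presumably convolves the dry source
#     with each RIR, conceptually
#     torchaudio.functional.fftconvolve(source_audio, ir)[..., :source_audio.shape[-1]]
#   - wiener_deconv_list(receiver_list, ir_list, snr) presumably applies
#     frequency-domain Wiener deconvolution,
#     S_hat(f) = R(f) * conj(H(f)) / (|H(f)|^2 + 1/snr),
#     to estimate the dry signal that explains each receiver mixture.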
torchaudio.save(f'{dirname}/ir_novel/ir2_novel.{audio_format}', ir2_novel.unsqueeze(0), sample_rate) 141 | 142 | # Save reverb 143 | os.makedirs(f'{dirname}/reverb', exist_ok=True) 144 | reverb1_list = compute_reverb(source1_audio, ir1_list) 145 | reverb2_list = compute_reverb(source2_audio, ir2_list) 146 | save_audio_list(f'{dirname}/reverb/reverb1', reverb1_list, sample_rate, audio_format) 147 | save_audio_list(f'{dirname}/reverb/reverb2', reverb2_list, sample_rate, audio_format) 148 | 149 | # Compute receiver 150 | os.makedirs(f'{dirname}/receiver', exist_ok=True) 151 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)] 152 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format) 153 | 154 | # Save Weiner 155 | os.makedirs(f'{dirname}/wiener', exist_ok=True) 156 | wiener1_list = wiener_deconv_list(receiver_list, ir1_list, snr) 157 | wiener2_list = wiener_deconv_list(receiver_list, ir2_list, snr) 158 | wiener3_list = wiener_deconv_list(receiver_list, ir3_list, snr) 159 | save_audio_list(f'{dirname}/wiener/wiener1', wiener1_list, sample_rate, audio_format) 160 | save_audio_list(f'{dirname}/wiener/wiener2', wiener2_list, sample_rate, audio_format) 161 | save_audio_list(f'{dirname}/wiener/wiener3', wiener3_list, sample_rate, audio_format) 162 | 163 | # Copy image 164 | os.makedirs(f'{dirname}/image', exist_ok=True) 165 | 166 | if source1_class.lower() == 'male' or source1_class.lower() == 'female': 167 | source1_class_render = source1_class.lower() 168 | else: 169 | source1_class_render = 'guitar' 170 | 171 | if source2_class.lower() == 'male' or source2_class.lower() == 'female': 172 | source2_class_render = source2_class.lower() 173 | else: 174 | source2_class_render = 'guitar' 175 | rgb1 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source1_class_render}_{source_idx_list[0]}.npy' 176 | rgb2 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source2_class_render}_{source_idx_list[1]}.npy' 177 | depth1 = f'{dirname_sourceimage}/{room}/depth_{room}_{source1_class_render}_{source_idx_list[0]}.npy' 178 | depth2 = f'{dirname_sourceimage}/{room}/depth_{room}_{source2_class_render}_{source_idx_list[1]}.npy' 179 | 180 | rgb1_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[0]}.npy' 181 | rgb2_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[1]}.npy' 182 | rgb3_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[2]}.npy' 183 | depth1_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[0]}.npy' 184 | depth2_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[1]}.npy' 185 | depth3_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[2]}.npy' 186 | 187 | shutil.copy(rgb1, f'{dirname}/image/rgb1.npy') 188 | shutil.copy(rgb2, f'{dirname}/image/rgb2.npy') 189 | shutil.copy(depth1, f'{dirname}/image/depth1.npy') 190 | shutil.copy(depth2, f'{dirname}/image/depth2.npy') 191 | 192 | shutil.copy(rgb1_bg, f'{dirname}/image/rgb1_bg.npy') 193 | shutil.copy(rgb2_bg, f'{dirname}/image/rgb2_bg.npy') 194 | shutil.copy(rgb3_bg, f'{dirname}/image/rgb3_bg.npy') 195 | shutil.copy(depth1_bg, f'{dirname}/image/depth1_bg.npy') 196 | shutil.copy(depth2_bg, f'{dirname}/image/depth2_bg.npy') 197 | shutil.copy(depth3_bg, f'{dirname}/image/depth3_bg.npy') 198 | 199 | # # png (optional) 200 | # rgb1 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source1_class_render}_{source_idx_list[0]}.png' 201 | # rgb2 = 
f'{dirname_sourceimage}/{room}/rgb_{room}_{source2_class_render}_{source_idx_list[1]}.png' 202 | # depth1 = f'{dirname_sourceimage}/{room}/depth_{room}_{source1_class_render}_{source_idx_list[0]}.png' 203 | # depth2 = f'{dirname_sourceimage}/{room}/depth_{room}_{source2_class_render}_{source_idx_list[1]}.png' 204 | 205 | # rgb1_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[0]}.png' 206 | # rgb2_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[1]}.png' 207 | # rgb3_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[2]}.png' 208 | # depth1_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[0]}.png' 209 | # depth2_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[1]}.png' 210 | # depth3_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[2]}.png' 211 | 212 | # shutil.copy(rgb1, f'{dirname}/image/rgb1.png') 213 | # shutil.copy(rgb2, f'{dirname}/image/rgb2.png') 214 | # shutil.copy(depth1, f'{dirname}/image/depth1.png') 215 | # shutil.copy(depth2, f'{dirname}/image/depth2.png') 216 | 217 | # shutil.copy(rgb1_bg, f'{dirname}/image/rgb1_bg.png') 218 | # shutil.copy(rgb2_bg, f'{dirname}/image/rgb2_bg.png') 219 | # shutil.copy(rgb3_bg, f'{dirname}/image/rgb3_bg.png') 220 | # shutil.copy(depth1_bg, f'{dirname}/image/depth1_bg.png') 221 | # shutil.copy(depth2_bg, f'{dirname}/image/depth2_bg.png') 222 | # shutil.copy(depth3_bg, f'{dirname}/image/depth3_bg.png') 223 | 224 | # Save metadata 225 | metadata = { 226 | 'source1_idx': source_idx_list[0], 227 | 'source2_idx': source_idx_list[1], 228 | 'source3_idx': source_idx_list[2], 229 | 'receiver_idx_list': receiver_idx_list, 230 | 'novel_receiver_idx': novel_receiver_idx, 231 | 'source1_class': source1_class, 232 | 'source2_class': source2_class, 233 | 'grid_points': grid_points, 234 | 'room': room, 235 | } 236 | torch.save(metadata, f'{dirname}/metadata.pt') 237 | 238 | pass 239 | -------------------------------------------------------------------------------- /demo/generate_demo_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import json 8 | import random 9 | import argparse 10 | import subprocess 11 | import typing as T 12 | 13 | import torch 14 | import torchaudio 15 | 16 | from soundspaces_nvas3d.utils.ss_utils import render_ir_parallel_room_idx, create_scene 17 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 18 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list 19 | from nvas3d.utils.audio_utils import clip_two 20 | from nvas3d.utils.utils import normalize 21 | from nvas3d.utils.generate_dataset_utils import load_ir_source_receiver, save_audio_list, compute_reverb 22 | 23 | 24 | def generate_rir( 25 | args: argparse.Namespace, 26 | room: str, 27 | source_idx_list: T.List[int], 28 | receiver_idx_list: T.List[int] 29 | ): 30 | """ 31 | Generates and saves Room Impulse Response (RIR) data for pairs of source_idx_list and receiver_idx_list. 32 | 33 | Args: 34 | - args: Parsed command line arguments for dirname and grid distance. 35 | - room: Name of the room. 36 | - source_idx_list: List of source indices. 37 | - receiver_idx_list: List of receiver indices. 
38 | """ 39 | 40 | ir_dir = f'data/{args.dataset_dir}/temp/ir/grid_{str(args.grid_distance).replace(".", "_")}' 41 | filename_ir = f'{ir_dir}/{room}/ir' 42 | os.makedirs(filename_ir, exist_ok=True) 43 | render_ir_parallel_room_idx(room, source_idx_list, receiver_idx_list, filename_ir, args.grid_distance) 44 | 45 | 46 | def visualize_grid( 47 | args: argparse.Namespace, 48 | room: str, 49 | image_size: T.Tuple[int, int] 50 | ): 51 | """ 52 | Visualizes grid points for a given room. 53 | 54 | Args: 55 | - args: Parsed command line arguments for dirname and grid distance. 56 | - room: Name of the room. 57 | - image_size: Dimensions of the output image. 58 | """ 59 | 60 | scene = create_scene(room, image_size=image_size) 61 | os.makedirs(f'data/{args.dataset_dir}/demo/{room}', exist_ok=True) 62 | scene.generate_xy_grid_points(args.grid_distance, filename_png=f'data/{args.dataset_dir}/demo/{room}/index_{str(args.grid_distance).replace(".", "_")}.png') 63 | scene.generate_xy_grid_points(0.5, filename_png=f'data/{args.dataset_dir}/demo/{room}/index_0_5.png') 64 | 65 | 66 | def process_audio_sources( 67 | dirname: str, 68 | source1_path: str, 69 | source2_path: str, 70 | rir_clip_idx: int, 71 | sample_rate: int, 72 | audio_format: str 73 | ) -> T.Tuple[torch.Tensor, torch.Tensor]: 74 | """ 75 | Preprocess and saves audio for dry sound. 76 | 77 | Args: 78 | - dirname: Directory name for saving the processed audio. 79 | - source1_path: Path to the first audio source. 80 | - source2_path: Path to the second audio source. 81 | - rir_clip_idx: Index for clipping the RIR. 82 | - sample_rate: Sampling rate of the audio. 83 | - audio_format: Format to save the audio (e.g., 'flac', 'wav'). 84 | 85 | Returns: 86 | - Processed audio tensors for source1 and source2. 87 | """ 88 | 89 | source1_audio, _ = torchaudio.load(source1_path) 90 | source2_audio, _ = torchaudio.load(source2_path) 91 | source1_audio = source1_audio.reshape(-1) 92 | source2_audio = source2_audio.reshape(-1) 93 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio) 94 | source1_audio = normalize(source1_audio) 95 | source2_audio = normalize(source2_audio) 96 | os.makedirs(f'{dirname}/source', exist_ok=True) 97 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[rir_clip_idx:].unsqueeze(0), sample_rate) 98 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[rir_clip_idx:].unsqueeze(0), sample_rate) 99 | 100 | return source1_audio, source2_audio 101 | 102 | 103 | def save_rir_data( 104 | dirname: str, 105 | rir_dir: str, 106 | room: str, 107 | source_idx_list: T.List[int], 108 | receiver_idx_list: T.List[int], 109 | rir_length: int, 110 | sample_rate: int, 111 | audio_format: str 112 | ) -> T.Tuple[T.List[torch.Tensor], T.List[torch.Tensor]]: 113 | """ 114 | Saves RIR data for the pair of sources and receivers. 115 | 116 | Args: 117 | - dirname: Directory name for saving the RIR data. 118 | - rir_dir: Directory where RIR data is located. 119 | - room: Name of the room. 120 | - source_idx_list: List of source indices. 121 | - receiver_idx_list: List of receiver indices. 122 | - rir_length: Length of the RIR. 123 | - sample_rate: Sampling rate of the audio. 124 | - audio_format: Format to save the audio (e.g., 'flac', 'wav'). 125 | 126 | Returns: 127 | - Lists of IRs from source1 and source2. 
128 | """ 129 | 130 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True) 131 | rir1_list = load_ir_source_receiver(rir_dir, room, source_idx_list[0], receiver_idx_list, rir_length) 132 | rir2_list = load_ir_source_receiver(rir_dir, room, source_idx_list[1], receiver_idx_list, rir_length) 133 | save_audio_list(f'{dirname}/ir_receiver/ir1', rir1_list, sample_rate, audio_format) 134 | save_audio_list(f'{dirname}/ir_receiver/ir2', rir2_list, sample_rate, audio_format) 135 | return rir1_list, rir2_list 136 | 137 | 138 | def reverb_audio( 139 | filename: str, 140 | source_audio: torch.Tensor, 141 | rir_list: T.List[torch.Tensor], 142 | sample_rate: int, 143 | audio_format: str 144 | ) -> T.List[torch.Tensor]: 145 | """ 146 | Applies reverberation to audio using provided IR. 147 | 148 | Args: 149 | - filename: Directory name for saving the reverberant audio. 150 | - source_audio: Source audio tensor. 151 | - rir_list: List of RIR tensors. 152 | - sample_rate: Sampling rate of the audio. 153 | - audio_format: Format to save the audio (e.g., 'flac', 'wav'). 154 | 155 | Returns: 156 | - List of reverberated audio. 157 | """ 158 | 159 | reverb_list = compute_reverb(source_audio, rir_list) 160 | save_audio_list(filename, reverb_list, sample_rate, audio_format) 161 | return reverb_list 162 | 163 | 164 | def mix_audio( 165 | dirname: str, 166 | reverb1_list: T.List[torch.Tensor], 167 | reverb2_list: T.List[torch.Tensor], 168 | sample_rate: int, 169 | audio_format: str 170 | ) -> T.List[torch.Tensor]: 171 | """ 172 | Mixes two reverberant audio. 173 | 174 | Args: 175 | - dirname: Directory name for saving the mixed audio. 176 | - reverb1_list: List of the first reverberant audio. 177 | - reverb2_list: List of the second reverberant audio. 178 | - sample_rate: Sampling rate of the audio. 179 | - audio_format: Format to save the audio (e.g., 'flac'). 180 | 181 | Returns: 182 | - List containing mixed audio. 183 | """ 184 | 185 | os.makedirs(f'{dirname}/receiver', exist_ok=True) 186 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)] 187 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format) 188 | return receiver_list 189 | 190 | 191 | def save_topdown_view( 192 | function_name: str, 193 | args_list: T.List[T.Union[str, int, float]] 194 | ): 195 | """ 196 | Saves a topdown view of a given scene. 197 | 198 | Args: 199 | - function_name: Name of the function to be used for visualization. 200 | - args_list: List of arguments for the visualization function. 201 | (room, source_idx_list, receiver_idx_list, grid_distance) 202 | """ 203 | 204 | subprocess.run(["python", "soundspaces_nvas3d/utils/render_scene_script.py", function_name, *map(str, args_list)]) 205 | 206 | 207 | def main(args): 208 | """ 209 | Generate and save demo data for a specific room. 210 | 211 | Directory Structure and Contents: 212 | ├── {data_demo} = data/nvas3d_demo/demo/{room}/0 213 | │ ├── receiver/ : Receiver audio. 214 | │ ├── wiener/ : Deconvolved audio (auxiliary data to accelerate tests) (wiener_{query_idx}_{receiver_id}). 215 | │ ├── source/ : Ground truth dry audio. 216 | │ ├── reverb1/ : Ground truth reverberant audio for source 1. 217 | │ ├── reverb2/ : Ground truth reverberant audio for source 2. 218 | │ ├── ir_receiver/ : Ground truth RIRs from source to receiver. 219 | │ ├── config.png : Visualization of room configuration. 220 | │ └── metadata.pt : Metadata containing source indices, classes, grid points, and room information. 
221 | 222 | Additional Visualizations: 223 | ├── data/nvas3d_demo/{room} : Room index visualizations. 224 | """ 225 | 226 | # Seed for reproducibility 227 | random.seed(42) 228 | 229 | # Directory setup for data 230 | os.makedirs(f'data/{args.dataset_dir}', exist_ok=True) 231 | 232 | # Extract and load room and grid related data 233 | room = args.room 234 | grid_distance = args.grid_distance 235 | grid_points = load_room_grid(room, grid_distance)['grid_points'] 236 | rir_length = args.rir_length 237 | sample_rate = args.sample_rate 238 | snr = args.snr 239 | audio_format = args.audio_format 240 | scene_config = args.scene_config 241 | 242 | source1_path = args.source1_path 243 | source2_path = args.source2_path 244 | 245 | # Set source and receiver locations 246 | with open(f'demo/config_demo/{scene_config}.json', 'r') as file: 247 | json_scene = json.load(file) 248 | source_idx_list = json_scene['source_idx_list'] 249 | receiver_idx_list = json_scene['receiver_idx_list'] 250 | 251 | # Generate and save RIR data 252 | generate_rir(args, room, list(range(len(grid_points))), receiver_idx_list) 253 | 254 | # Prepare directory for demo data 255 | dirname = f'data/{args.dataset_dir}/demo/{room}/0' 256 | os.makedirs(dirname, exist_ok=True) 257 | 258 | # Visualize the grid points within the room 259 | visualize_grid(args, room, (400, 300)) 260 | 261 | # Preprocess audio 262 | rir_clip_idx = rir_length - 1 263 | source1_class, source2_class = 'female', 'drum' 264 | source1_audio, source2_audio = process_audio_sources(dirname, source1_path, source2_path, rir_clip_idx, sample_rate, audio_format) 265 | 266 | # Reverb audio for source 1 and source 2 267 | rir_dir = f'data/{args.dataset_dir}/temp/ir/grid_{str(grid_distance).replace(".", "_")}' 268 | rir1_list, rir2_list = save_rir_data(dirname, rir_dir, room, source_idx_list, receiver_idx_list, rir_length, sample_rate, audio_format) 269 | os.makedirs(f'{dirname}/reverb', exist_ok=True) 270 | reverb1_list = reverb_audio(f'{dirname}/reverb/reverb1', source1_audio, rir1_list, sample_rate, audio_format) 271 | reverb2_list = reverb_audio(f'{dirname}/reverb/reverb2', source2_audio, rir2_list, sample_rate, audio_format) 272 | 273 | # Mix both reverberant audios 274 | receiver_list = mix_audio(dirname, reverb1_list, reverb2_list, sample_rate, audio_format) 275 | 276 | # Visualize topdown view of the scene 277 | save_topdown_view("scene", [f'{dirname}/config.png', room, source_idx_list[:2], receiver_idx_list, args.grid_distance]) 278 | 279 | # Save metadata 280 | metadata = { 281 | 'source1_idx': source_idx_list[0], 282 | 'source2_idx': source_idx_list[1], 283 | 'receiver_idx_list': receiver_idx_list, 284 | 'source1_class': source1_class, 285 | 'source2_class': source2_class, 286 | 'grid_points': grid_points, 287 | 'grid_distance': grid_distance, 288 | 'room': room, 289 | } 290 | torch.save(metadata, f'{dirname}/metadata.pt') 291 | 292 | # Save deconvolved audio using Wiener deconvolution 293 | for query_idx in range(len(grid_points)): 294 | if (query_idx in receiver_idx_list): 295 | continue 296 | 297 | # load RIR 298 | rir_query_list = load_ir_source_receiver(rir_dir, room, query_idx, receiver_idx_list, rir_length) 299 | 300 | # save Weiner 301 | os.makedirs(f'{dirname}/wiener', exist_ok=True) 302 | wiener_list = wiener_deconv_list(receiver_list, rir_query_list, snr) 303 | save_audio_list(f'{dirname}/wiener/wiener{query_idx}', wiener_list, sample_rate, audio_format) 304 | 305 | 306 | if __name__ == '__main__': 307 | parser = argparse.ArgumentParser() 308 | 
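# [Editor's note] Example invocation using the defaults registered below
# (assumes the Matterport3D room 17DRP5sb8fy has been downloaded and the
# bundled clips exist under data/source/):
#
#   python demo/generate_demo_data.py \
#       --room 17DRP5sb8fy \
#       --scene_config scene1_17DRP5sb8fy \
#       --source1_path data/source/female.flac \
#       --source2_path data/source/drum.flac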
309 | parser.add_argument('--room', default='17DRP5sb8fy', type=str, help='mp3d room') 310 | parser.add_argument('--dataset_dir', default='nvas3d_demo', type=str, help='dirname') 311 | parser.add_argument('--grid_distance', default=1.0, type=float, help='Distance between grid points') 312 | parser.add_argument('--rir_length', default=72000, type=int, help='IR length') 313 | parser.add_argument('--sample_rate', default=48000, type=int, help='Sample rate') 314 | parser.add_argument('--snr', default=100, type=int, help='SNR for Wiener deconvolution') 315 | parser.add_argument('--audio_format', default='flac', type=str, help='Audio format to save') 316 | parser.add_argument('--scene_config', default='scene1_17DRP5sb8fy', type=str, help='Scene configuration json') 317 | parser.add_argument('--source1_path', default='data/source/female.flac', type=str, help='Filename for source 1') 318 | parser.add_argument('--source2_path', default='data/source/drum.flac', type=str, help='Filename for source 2') 319 | 320 | args = parser.parse_args() 321 | main(args) 322 | -------------------------------------------------------------------------------- /soundspaces_nvas3d/utils/ss_utils.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import sys 8 | import random 9 | import itertools 10 | import typing as T 11 | import multiprocessing 12 | import matplotlib.pyplot as plt 13 | from contextlib import contextmanager 14 | from tqdm import tqdm 15 | 16 | import torch 17 | import torchaudio 18 | 19 | from soundspaces_nvas3d.soundspaces_nvas3d import Receiver, Source, Scene 20 | from soundspaces_nvas3d.utils.aihabitat_utils import save_grid_config, load_room_grid 21 | 22 | 23 | @contextmanager 24 | def suppress_stdout_and_stderr(): 25 | """ 26 | Suppress the logs from SoundSpaces. 27 | """ 28 | 29 | original_stdout_fd = os.dup(sys.stdout.fileno()) 30 | original_stderr_fd = os.dup(sys.stderr.fileno()) 31 | devnull = os.open(os.devnull, os.O_WRONLY) 32 | 33 | try: 34 | os.dup2(devnull, sys.stdout.fileno()) 35 | os.dup2(devnull, sys.stderr.fileno()) 36 | yield 37 | finally: 38 | os.dup2(original_stdout_fd, sys.stdout.fileno()) 39 | os.dup2(original_stderr_fd, sys.stderr.fileno()) 40 | os.close(devnull) 41 | 42 | 43 | def create_scene(room: str, 44 | receiver_position: T.Tuple[float, float, float] = [0.0, 0.0, 0.0], 45 | sample_rate: float = 48000, 46 | image_size: T.Tuple[int, int] = (512, 256), 47 | include_visual_sensor: bool = True, 48 | hfov: float = 90.0 49 | ) -> Scene: 50 | """ 51 | Create a soundspaces scene to render IR. 
52 | """ 53 | 54 | # Note: Make sure mp3d room is downloaded 55 | with suppress_stdout_and_stderr(): 56 | # Create a receiver 57 | receiver = Receiver( 58 | position=receiver_position, 59 | rotation=0, 60 | sample_rate=sample_rate 61 | ) 62 | 63 | scene = Scene( 64 | room, 65 | [None], # placeholder for source class 66 | receiver=receiver, 67 | include_visual_sensor=include_visual_sensor, 68 | add_source_mesh=False, 69 | device=torch.device('cpu'), 70 | add_source=False, 71 | image_size=image_size, 72 | hfov=hfov 73 | ) 74 | 75 | return scene 76 | 77 | 78 | def render_ir(room: str, 79 | source_position: T.Tuple[float, float, float], 80 | receiver_position: T.Tuple[float, float, float], 81 | filename: str = None, 82 | receiver_rotation: float = None, 83 | sample_rate: float = 48000, 84 | use_default_material: bool = False, 85 | channel_type: str = 'Ambisonics', 86 | channel_order: int = 1 87 | ) -> torch.Tensor: 88 | """ 89 | Render impulse response for a source and receiver pair in the mp3d room. 90 | """ 91 | 92 | if receiver_rotation is None: 93 | receiver_rotation = 90 94 | 95 | # Create a receiver 96 | receiver = Receiver( 97 | position=receiver_position, 98 | rotation=receiver_rotation, 99 | sample_rate=sample_rate 100 | ) 101 | 102 | # Create a source 103 | source = Source( 104 | position=source_position, 105 | rotation=0, 106 | dry_sound='', 107 | mesh='', 108 | device=torch.device('cpu') 109 | ) 110 | 111 | scene = Scene( 112 | room, 113 | [None], # placeholder for source class 114 | receiver=receiver, 115 | source_list=[source], 116 | include_visual_sensor=False, 117 | add_source_mesh=False, 118 | device=torch.device('cpu'), 119 | use_default_material=use_default_material, 120 | channel_type=channel_type, 121 | channel_order=channel_order 122 | ) 123 | 124 | # Render IR 125 | scene.add_audio_sensor() 126 | # with suppress_stdout_and_stderr(): 127 | ir = scene.render_ir(0) 128 | 129 | # Save file if dirname is given 130 | if filename is not None: 131 | torchaudio.save(filename, ir, sample_rate=sample_rate) 132 | else: 133 | return ir 134 | 135 | 136 | def render_rir_parallel(room_list: T.List[str], 137 | source_position_list: T.List[T.Tuple[float, float, float]], 138 | receiver_position_list: T.List[T.Tuple[float, float, float]], 139 | filename_list: T.List[str] = None, 140 | receiver_rotation_list: T.List[float] = None, 141 | batch_size: int = 64, 142 | sample_rate: float = 48000, 143 | use_default_material: bool = False, 144 | channel_type: str = 'Ambisonics', 145 | channel_order: int = 1 146 | ) -> T.List[torch.Tensor]: 147 | """ 148 | Run render_ir parallely for all elements of zip(source_position_list, receiver_position_list). 
149 | """ 150 | 151 | assert len(room_list) == len(source_position_list) 152 | assert len(source_position_list) == len(receiver_position_list) 153 | 154 | if filename_list is None: 155 | is_return = True 156 | else: 157 | is_return = False 158 | 159 | if receiver_rotation_list is None: 160 | receiver_rotation_list = [0] * len(receiver_position_list) 161 | 162 | # Note: Make sure all rooms are downloaded 163 | 164 | # Calculate the number of batches 165 | num_points = len(source_position_list) 166 | num_batches = (num_points + batch_size - 1) // batch_size 167 | 168 | # Use tqdm to display the progress bar 169 | progress_bar = tqdm(total=num_points) 170 | 171 | def update_progress(*_): 172 | progress_bar.update() 173 | 174 | ir_list = [] 175 | # Process the tasks in batches 176 | for batch_idx in range(num_batches): 177 | # Calculate the start and end indices of the current batch 178 | start_idx = batch_idx * batch_size 179 | end_idx = min(start_idx + batch_size, num_points) 180 | if is_return: 181 | batch = [(room_list[i], source_position_list[i], receiver_position_list[i], None, receiver_rotation_list[i]) for i in range(start_idx, end_idx)] 182 | else: 183 | batch = [(room_list[i], source_position_list[i], receiver_position_list[i], filename_list[i], receiver_rotation_list[i]) for i in range(start_idx, end_idx)] 184 | 185 | # Create a multiprocessing Pool for the current batch 186 | with multiprocessing.Pool() as pool: 187 | tasks = [] 188 | for room, source_position, receiver_position, filename, receiver_rotation in batch: 189 | # Apply async mapping of process_ir function 190 | task = pool.apply_async(render_ir, args=(room, source_position, receiver_position, filename, receiver_rotation, sample_rate, use_default_material, channel_type, channel_order), callback=update_progress) 191 | tasks.append(task) 192 | 193 | # Wait for all tasks in the batch to complete and collect results 194 | for task in tasks: 195 | if is_return: 196 | ir = task.get() # Block until the result is ready 197 | ir_list.append(ir) # Append the result to the list 198 | else: 199 | task.get() 200 | if is_return: 201 | return ir_list 202 | 203 | 204 | def render_ir_parallel_room_idx(room: str, 205 | source_idx_list: T.List[int], 206 | receiver_idx_list: T.List[int], 207 | filename: str = None, 208 | grid_distance=1.0, 209 | batch_size: int = 64, 210 | sample_rate: float = 48000, 211 | use_default_material: bool = False, 212 | channel_type='Ambisonics' # Binaural 213 | ) -> T.List[torch.Tensor]: 214 | """ 215 | Run render_ir parallely for all elements of all_pair(source_idx_list, receiver_idx_list) 216 | """ 217 | 218 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points'] 219 | 220 | source_idx_pair_list, receiver_idx_pair_list = all_pairs(source_idx_list, receiver_idx_list) # only for filename 221 | receiver_points = grid_points[receiver_idx_list] 222 | source_points = grid_points[source_idx_list] 223 | 224 | source_points_pair, receiver_points_pair = all_pairs(source_points, receiver_points) 225 | 226 | room_list = [room] * len(source_points_pair) 227 | if filename is not None: 228 | filename_list = [f'{filename}_{room}_{source_idx}_{receiver_idx}.wav' 229 | for source_idx, receiver_idx in zip(source_idx_pair_list, receiver_idx_pair_list)] 230 | else: 231 | filename_list = None 232 | 233 | # Render IR for grid points 234 | ir_list = render_rir_parallel(room_list, 235 | source_points_pair, 236 | receiver_points_pair, 237 | filename_list, 238 | batch_size=batch_size, 239 | 
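# [Editor's note, descriptive comment on existing behavior]
# render_rir_parallel returns the rendered IR tensors only when
# filename_list is None; whenever `filename` is given to this function,
# ir_list below ends up None and callers are expected to read the
# saved .wav files instead.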
sample_rate=sample_rate, 240 | use_default_material=use_default_material, 241 | channel_type=channel_type) 242 | 243 | return ir_list, source_idx_pair_list, receiver_idx_pair_list 244 | 245 | 246 | def render_receiver_image(dirname: str, 247 | room: str, 248 | source_idx_list: T.List[int], 249 | source_class_list: T.List[str], 250 | receiver_idx_list: T.List[int], 251 | filename: str = None, 252 | grid_distance=1.0, 253 | hfov=120, 254 | image_size=(1024, 1024) 255 | ): 256 | 257 | # load grid points 258 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points'] 259 | 260 | receiver_points = grid_points[receiver_idx_list] 261 | source_points = grid_points[source_idx_list] 262 | # source_points_pair, receiver_points_pair = all_pairs(source_points, receiver_points) 263 | 264 | # initialize scene 265 | scene = create_scene(room, image_size=image_size, hfov=hfov) 266 | scene.add_source_mesh = True 267 | 268 | # initialize receiver 269 | sample_rate = 48000 270 | position = [0.0, 0, 0] 271 | rotation = 0.0 272 | receiver = Receiver(position, rotation, sample_rate) 273 | 274 | # set source 275 | source_list = [] 276 | for source_idx, source_class in zip(source_idx_list, source_class_list): 277 | position = grid_points[source_idx] 278 | if source_class == 'male' or source_class == 'female': 279 | source = Source( 280 | position=position, 281 | rotation=random.uniform(0, 360), 282 | dry_sound='', 283 | mesh=source_class, 284 | device=torch.device('cpu') 285 | ) 286 | source_list.append(source) 287 | else: 288 | source = Source( 289 | position=position, 290 | rotation=random.uniform(0, 360), 291 | dry_sound='', 292 | mesh='guitar', # All instruments use guitar mesh 293 | device=torch.device('cpu') 294 | ) 295 | source_list.append(source) 296 | 297 | # add mesh 298 | scene.source_list = [None] * len(source_idx_list) 299 | for id, source in enumerate(source_list): 300 | scene.update_source(source, id) 301 | 302 | # see source1 direction 303 | source_x = source_points[0][0] 304 | source_z = source_points[0][2] 305 | 306 | # render images 307 | rgb_list = [] 308 | depth_list = [] 309 | for receiver_idx in receiver_idx_list: 310 | receiver.position = grid_points[receiver_idx] 311 | # all receiver sees source 1 direction 312 | rotation_source1 = calculate_degree(source_x - receiver.position[0], source_z - receiver.position[2]) 313 | receiver.rotation = rotation_source1 314 | scene.update_receiver(receiver) 315 | rgb, depth = scene.render_image() 316 | rgb_list.append(rgb) 317 | depth_list.append(depth) 318 | 319 | # for index debug 320 | # source_idx_pair_list, receiver_idx_pair_list = all_pairs(source_idx_list, receiver_idx_list) 321 | 322 | for i in range(len(receiver_idx_list)): 323 | plt.imsave(f'{dirname}/{i+1}.png', rgb_list[i]) 324 | plt.imsave(f'{dirname}/{i+1}_depth.png', depth_list[i], cmap='gray') 325 | 326 | # return rgb_list, depth_list 327 | 328 | 329 | def render_scene_config(filename: str, 330 | room: str, 331 | source_idx_list: T.List[int], 332 | receiver_idx_list: T.List[int], 333 | grid_distance=1.0, 334 | no_query=False): 335 | 336 | # load grid points 337 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points'] 338 | 339 | # config 340 | scene = create_scene(room) 341 | save_grid_config(filename, scene.sim.pathfinder, grid_points, receiver_idx_list=receiver_idx_list, source_idx_list=source_idx_list, no_query=no_query) 342 | 343 | 344 | # Additional utility functions 345 | def all_pairs(list1, list2): 346 | list_pair = 
list(itertools.product(list1, list2)) 347 | 348 | list1_pair, list2_pair = zip(*list_pair) 349 | list1_pair = list(list1_pair) 350 | list2_pair = list(list2_pair) 351 | 352 | return list1_pair, list2_pair 353 | 354 | 355 | def calculate_degree(x, y): 356 | radian = torch.atan2(y, x) 357 | degree = torch.rad2deg(radian) 358 | # Adjusting for the described mapping 359 | degree = (-degree - 90) % 360 360 | return degree 361 | -------------------------------------------------------------------------------- /nvas3d/utils/training_data_generation/generate_test_data.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import os 7 | import glob 8 | import json 9 | import random 10 | import subprocess 11 | import concurrent.futures 12 | from itertools import product 13 | 14 | from tqdm import tqdm 15 | 16 | import torch 17 | import torchaudio 18 | 19 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid 20 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list 21 | from nvas3d.utils.audio_utils import clip_two 22 | from nvas3d.utils.utils import normalize, parse_librispeech_metadata, MP3D_SCENE_SPLITS 23 | from nvas3d.utils.generate_dataset_utils import sample_speech, sample_nonspeech, sample_acoustic_guitar, sample_all, sample_instrument, clip_source, load_ir_source_receiver, save_audio_list, compute_reverb 24 | 25 | os.makedirs('data/temp', exist_ok=True) 26 | 27 | 28 | SOURCE1_DATA = 'Guitar' 29 | SOURCE2_DATA = 'Guitar' 30 | 31 | num_id_per_room = 1 32 | 33 | random.seed(42) 34 | 35 | DATASET_NAME = f'nvas3d_square_{SOURCE1_DATA}_{SOURCE2_DATA}_queryall_{num_id_per_room}_v3' 36 | os.makedirs(f'data/{DATASET_NAME}', exist_ok=True) 37 | 38 | grid_distance = 1.0 39 | grid_distance_str = str(grid_distance).replace(".", "_") 40 | target_shape_t = 256 41 | ir_length = 72000 42 | ir_clip_idx = ir_length - 1 43 | hop_length = 480 44 | len_clip = hop_length * (target_shape_t - 1) + ir_length - 1 45 | sample_rate = 48000 46 | snr = 100 47 | audio_format = 'flac' 48 | 49 | for split in ['val']: 50 | # LibriSpeech 51 | if split == 'train': 52 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/train' 53 | elif split == 'val': 54 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation' 55 | elif split == 'test': 56 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/test' 57 | else: 58 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation' 59 | files_librispeech = glob.glob(librispeech_dir + '/**/*.flac', recursive=True) 60 | librispeech_metadata = parse_librispeech_metadata(f'data/MIDI/clip/Speech/LibriSpeech48k/SPEAKERS.TXT') 61 | 62 | # MIDI 63 | if split == 'train': 64 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'train')) if os.path.isdir(path)] 65 | elif split == 'val': 66 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)] 67 | elif split == 'test': 68 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'test')) if os.path.isdir(path)] 69 | else: 70 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)] 71 | 72 | # RIR 73 | if split == 'val_trainscene': 74 | split_scene = 'train' 75 | else: 76 | split_scene = split 77 | ir_dir = 
f'data/nvas3d_square/ir/{split_scene}/grid_{grid_distance_str}' 78 | 79 | # Image 80 | dirname_sourceimage = f'data/nvas3d_square/image/{split_scene}/grid_{grid_distance_str}' 81 | 82 | # Iterate over rooms 83 | for i_room, room in enumerate(tqdm(MP3D_SCENE_SPLITS[split_scene])): 84 | grid_points = load_room_grid(room, grid_distance)['grid_points'] 85 | num_points = grid_points.shape[0] 86 | total_pairs = [] 87 | filename = f'data/nvas3d_square/metadata/grid_{grid_distance_str}/{room}_square.json' # from generate_metadata_square.json 88 | with open(filename, 'r') as file: 89 | square_data = json.load(file) 90 | pairs_all = square_data['selected_pairs'] 91 | # Add each pair with room id to the total list 92 | 93 | random.shuffle(pairs_all) 94 | 95 | # pairs = pairs[:num_id_per_room] 96 | 97 | pairs = [] 98 | for pair in pairs_all: 99 | source_idx_list, receiver_idx_list, novel_receiver_idx = pair 100 | if (novel_receiver_idx not in source_idx_list) and (novel_receiver_idx not in receiver_idx_list): 101 | pairs.append(pair) 102 | # else: 103 | # print(f'invalid idx: {source_idx_list}, {receiver_idx_list}, {novel_receiver_idx}') 104 | if len(pairs) >= num_id_per_room: 105 | break 106 | 107 | # All IRs 108 | # Initialize a list to store all combinations 109 | all_combinations = [] 110 | 111 | # Iterate over selected pairs 112 | for pair in pairs: 113 | # Unpack the pair 114 | _, receiver_idxs, _ = pair 115 | # Get all combinations of source and receiver indices 116 | comb = product(list(range(num_points)), receiver_idxs) 117 | # Add these combinations to the list 118 | all_combinations.extend(comb) 119 | all_combinations = list(set(all_combinations)) # remove redundancy 120 | 121 | # download wav files # Replace to render IR 122 | # temp_list = set() 123 | # with concurrent.futures.ThreadPoolExecutor() as executor: 124 | # for source_idx in executor.map(download_wav, all_combinations): 125 | # temp_list.add(source_idx) 126 | # temp_list = list(temp_list) 127 | 128 | # Render image 129 | dirname_target_image = f'data/{DATASET_NAME}/{split}/{room}/image' 130 | os.makedirs(dirname_target_image, exist_ok=True) 131 | query_idx_list = list(range(num_points)) 132 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_target_image.py', '--room', room, '--dirname', dirname_target_image, '--source_idx_list', ' '.join(map(str, query_idx_list))]) 133 | 134 | # For each pair, make data 135 | for i_pair, pair in enumerate(tqdm(pairs)): 136 | dirname = f'data/{DATASET_NAME}/{split}/{room}/{i_pair}' 137 | source_idx_list, receiver_idx_list, novel_receiver_idx = pair 138 | 139 | os.makedirs(dirname, exist_ok=True) 140 | 141 | # Compute source 142 | os.makedirs(f'{dirname}/source', exist_ok=True) 143 | if SOURCE1_DATA == 'speech': 144 | source1_audio, source1_class = sample_speech(files_librispeech, librispeech_metadata) 145 | elif SOURCE1_DATA == 'nonspeech': 146 | source1_audio, source1_class = sample_nonspeech(all_instruments_dir) 147 | elif SOURCE1_DATA == 'guitar': 148 | source1_audio, source1_class = sample_acoustic_guitar(all_instruments_dir) 149 | elif SOURCE1_DATA == 'all': 150 | source1_audio, source1_class = sample_all(all_instruments_dir, librispeech_metadata) 151 | else: 152 | source1_audio, source1_class = sample_instrument(all_instruments_dir, librispeech_metadata, SOURCE1_DATA) 153 | 154 | if SOURCE2_DATA == 'speech': 155 | source2_audio, source2_class = sample_speech(files_librispeech, librispeech_metadata) 156 | elif SOURCE2_DATA == 'nonspeech': 157 | source2_audio, 
source2_class = sample_nonspeech(all_instruments_dir) 158 | elif SOURCE2_DATA == 'guitar': 159 | source2_audio, source2_class = sample_acoustic_guitar(all_instruments_dir) 160 | elif SOURCE2_DATA == 'all': 161 | source2_audio, source2_class = sample_all(all_instruments_dir, librispeech_metadata) 162 | else: 163 | source2_audio, source2_class = sample_instrument(all_instruments_dir, librispeech_metadata, SOURCE2_DATA) 164 | 165 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio) 166 | if not (split == 'test' or split == 'val'): # check 167 | source1_audio, source2_audio = clip_source(source1_audio, source2_audio, len_clip) 168 | source1_audio = normalize(source1_audio) 169 | source2_audio = normalize(source2_audio) 170 | 171 | if torch.isnan(source1_audio).any() or torch.isnan(source2_audio).any(): 172 | continue 173 | 174 | if split == 'test' or split == 'val': 175 | if source1_audio.shape[0] < sample_rate * 10: # skip for short (<10s) 176 | continue 177 | 178 | os.makedirs(f'{dirname}/source', exist_ok=True) 179 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[ir_clip_idx:].unsqueeze(0), sample_rate) 180 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[ir_clip_idx:].unsqueeze(0), sample_rate) 181 | 182 | # Save IR 183 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True) 184 | ir1_list = load_ir_source_receiver(ir_dir, room, source_idx_list[0], receiver_idx_list, ir_length) 185 | ir2_list = load_ir_source_receiver(ir_dir, room, source_idx_list[1], receiver_idx_list, ir_length) 186 | ir3_list = load_ir_source_receiver(ir_dir, room, source_idx_list[2], receiver_idx_list, ir_length) 187 | save_audio_list(f'{dirname}/ir_receiver/ir1', ir1_list, sample_rate, audio_format) 188 | save_audio_list(f'{dirname}/ir_receiver/ir2', ir2_list, sample_rate, audio_format) 189 | save_audio_list(f'{dirname}/ir_receiver/ir3', ir3_list, sample_rate, audio_format) 190 | 191 | os.makedirs(f'{dirname}/ir_novel', exist_ok=True) 192 | ir1_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[0], [novel_receiver_idx], ir_length)[0] 193 | ir2_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[1], [novel_receiver_idx], ir_length)[0] 194 | torchaudio.save(f'{dirname}/ir_novel/ir1_novel.{audio_format}', ir1_novel.unsqueeze(0), sample_rate) 195 | torchaudio.save(f'{dirname}/ir_novel/ir2_novel.{audio_format}', ir2_novel.unsqueeze(0), sample_rate) 196 | 197 | # Save reverb 198 | os.makedirs(f'{dirname}/reverb', exist_ok=True) 199 | reverb1_list = compute_reverb(source1_audio, ir1_list) 200 | reverb2_list = compute_reverb(source2_audio, ir2_list) 201 | save_audio_list(f'{dirname}/reverb/reverb1', reverb1_list, sample_rate, audio_format) 202 | save_audio_list(f'{dirname}/reverb/reverb2', reverb2_list, sample_rate, audio_format) 203 | 204 | # Compute receiver 205 | os.makedirs(f'{dirname}/receiver', exist_ok=True) 206 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)] 207 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format) 208 | 209 | # Compute novel 210 | os.makedirs(f'{dirname}/receiver_novel', exist_ok=True) 211 | reverb1_novel = compute_reverb(source1_audio, [ir1_novel])[0] 212 | reverb2_novel = compute_reverb(source2_audio, [ir2_novel])[0] 213 | receiver_novel = reverb1_novel + reverb2_novel 214 | torchaudio.save(f'{dirname}/receiver_novel/receiver_novel.{audio_format}', receiver_novel.unsqueeze(0), sample_rate) 215 | 216 | # # Render binaural IR
and audio 217 | # if not os.path.exists(f'{dirname}/ir_novel_binaural'): 218 | # os.makedirs(f'{dirname}/ir_novel_binaural', exist_ok=True) 219 | # ir_novel_list, source_idx_pair_list, receiver_idx_pair_list = render_ir_parallel_room_idx(room, source_idx_list, [novel_receiver_idx], filename=None, use_default_material=False, channel_type='Binaural') 220 | # for idx_novel, ir_novel in enumerate(ir_novel_list[:2]): # first two sources 221 | # if ir_novel[0].shape[0] > ir_length: 222 | # ir_binaural = ir_novel[:][:ir_length] 223 | # else: 224 | # ir_binaural = F.pad(ir_novel[:], (0, ir_length - ir_novel.shape[1])) 225 | # torchaudio.save(f'{dirname}/ir_novel_binaural/ir{idx_novel+1}_novel_binaural.{audio_format}', ir_binaural, sample_rate) 226 | 227 | # # set image rendering 228 | # all_receiver_idx_list = receiver_idx_list.copy() 229 | # all_receiver_idx_list.append(novel_receiver_idx) 230 | 231 | # def run_function(function_name, *args): # use subprocess because of memory leak in soundspaces 232 | # subprocess.run(["python", "script/render_scene_script.py", function_name, *map(str, args)]) 233 | 234 | # # Render novel-view RGBD image 235 | # os.makedirs(f'{dirname}/image_receiver', exist_ok=True) 236 | # source_class_list = [source1_class, source2_class] 237 | # run_function("receiver", f'{dirname}/image_receiver', room, source_idx_list[:2], source_class_list, all_receiver_idx_list) 238 | 239 | # # Save topdown view 240 | # run_function("scene", f'{dirname}/config.png', room, source_idx_list[:2], all_receiver_idx_list) 241 | 242 | # Save metadata 243 | metadata = { 244 | 'source1_idx': source_idx_list[0], 245 | 'source2_idx': source_idx_list[1], 246 | 'receiver_idx_list': receiver_idx_list, 247 | 'novel_receiver_idx': novel_receiver_idx, 248 | 'source1_class': source1_class, 249 | 'source2_class': source2_class, 250 | 'grid_points': grid_points, 251 | 'room': room, 252 | } 253 | 254 | torch.save(metadata, f'{dirname}/metadata.pt') 255 | 256 | # Query 257 | for query_idx in range(num_points): 258 | if (query_idx in receiver_idx_list): 259 | continue 260 | 261 | # Load IR 262 | ir_query_list = load_ir_source_receiver(f'data/temp/ir', room, query_idx, receiver_idx_list, ir_length) 263 | 264 | # Save Weiner 265 | os.makedirs(f'{dirname}/wiener', exist_ok=True) 266 | wiener_list = wiener_deconv_list(receiver_list, ir_query_list, snr) 267 | 268 | save_audio_list(f'{dirname}/wiener/wiener{query_idx}', wiener_list, sample_rate, audio_format) 269 | 270 | # # debug delay 271 | # delay_ir = torch.nonzero(ir_query_list[0], as_tuple=True)[0][0] 272 | # receiver_point = grid_points[receiver_idx_list[0]] 273 | # query_point = grid_points[query_idx] 274 | 275 | # c = 343 276 | # dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset 277 | # delay_dist = (int(dist / c * sample_rate)) 278 | # print(delay_ir, delay_dist) 279 | # pass 280 | -------------------------------------------------------------------------------- /nvas3d/train/trainer.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
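# [Editor's note] Hypothetical construction sketch for the Trainer defined in
# this file; the training entry point itself is not shown in this excerpt, and
# every value below is an illustrative assumption, not a setting taken from
# this repository:
#
#   config = {'lr': 1e-4, 'weight_decay': 0.0, 'epochs': 100, 'resume': False,
#             'detect_loss_weight': 1.0, 'save_checkpoint_interval': 10,
#             'checkpoint_dir': ''}
#   trainer = Trainer(model, data_loader, torch.device('cuda'), save_dir='runs/exp0',
#                     use_deconv=False, config=config, save_audio=False)
#   trainer.train()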
4 | # 5 | 6 | import os 7 | import glob 8 | import time 9 | import logging 10 | import typing as T 11 | import matplotlib.pyplot as plt 12 | 13 | from tqdm import tqdm 14 | from collections import defaultdict 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.distributed as dist 19 | from torch.utils.tensorboard import SummaryWriter 20 | 21 | from nvas3d.utils.audio_utils import spec2wav, compute_metrics 22 | from nvas3d.utils.utils import source_class_map, get_key_from_value 23 | from nvas3d.utils.plot_utils import plot_debug 24 | 25 | 26 | class Trainer: 27 | def __init__( 28 | self, 29 | model: torch.nn.Module, 30 | data_loader, 31 | device: torch.device, 32 | save_dir: str, 33 | use_deconv: bool, 34 | config: T.Dict[str, T.Any], 35 | save_audio: bool = False 36 | ) -> None: 37 | """ 38 | Initializes the Trainer class. 39 | 40 | Args: 41 | - model: Model to be trained. 42 | - data_loader: Data loader supplying training data. 43 | - device: Device (CPU or GPU) to be used for training. 44 | - save_dir: Directory to save the model. 45 | - use_deconv: Flag to use deconvolution. 46 | - config: Configuration with training parameters. 47 | """ 48 | 49 | # Args 50 | self.model = model.to(device, dtype=torch.float) 51 | self.data_loader = data_loader 52 | self.device = device 53 | self.save_dir = save_dir 54 | self.use_deconv = use_deconv 55 | self.save_audio = save_audio 56 | 57 | # Optimizer 58 | self.optimizer = torch.optim.Adam( 59 | filter(lambda p: p.requires_grad, self.model.parameters()), 60 | lr=config['lr'], 61 | weight_decay=config['weight_decay'] 62 | ) 63 | 64 | # Loss 65 | self.regressor_criterion = nn.MSELoss().to(device=self.device) 66 | self.bce_criterion = nn.BCEWithLogitsLoss().to(device=self.device) 67 | 68 | # Training configuration 69 | self.epochs = config['epochs'] 70 | self.resume = config['resume'] 71 | self.detect_loss_weight = config['detect_loss_weight'] 72 | self.save_checkpoint_interval = config['save_checkpoint_interval'] 73 | 74 | # Tensorboard 75 | tb_dir = f'{self.save_dir}/tb' 76 | os.makedirs(tb_dir, exist_ok=True) 77 | self.writer = SummaryWriter(log_dir=tb_dir) 78 | 79 | # Resume 80 | if self.resume: 81 | print(f'Resume from {config["checkpoint_dir"]}...') 82 | latest_checkpoint = self.find_latest_checkpoint(config['checkpoint_dir']) 83 | if latest_checkpoint is not None: 84 | start_epoch = self.load_checkpoint(latest_checkpoint) + 1 85 | else: 86 | start_epoch = 0 87 | else: 88 | start_epoch = 0 89 | self.start_epoch = start_epoch 90 | 91 | def train(self): 92 | scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, pow(0.1, 1 / self.epochs)) 93 | 94 | since = time.time() 95 | 96 | for epoch in range(self.start_epoch, self.epochs + 1): 97 | logging.info('-' * 10) 98 | logging.info('Epoch {}/{}'.format(epoch, self.epochs)) 99 | 100 | # Set DistributedSampler for each epoch if model is DDP 101 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and self.model.training: 102 | self.data_loader.train_sampler.set_epoch(epoch) 103 | 104 | # Initialize variables to store metrics 105 | running_loss = defaultdict(float) 106 | running_metrics = defaultdict(float) 107 | num_data_point = defaultdict(int) 108 | count_metrics = defaultdict(int) 109 | 110 | for train_mode in ['Train', 'Val']: 111 | # Skip validation if not saving checkpoint this epoch 112 | if train_mode == 'Val' and epoch % self.save_checkpoint_interval != 0: 113 | continue 114 | 115 | data_loader = (self.data_loader.get_train_data() if train_mode == 'Train'
116 | else self.data_loader.get_val_data()) 117 | 118 | # Set mode 119 | if train_mode == 'Train': 120 | self.model.train() 121 | else: 122 | self.model.eval() 123 | 124 | with torch.set_grad_enabled(train_mode == 'Train'): 125 | pbar = tqdm(data_loader, desc=f"Epoch {epoch}") 126 | for data, room, id in pbar: 127 | # Preprocess 128 | for key, value in data.items(): 129 | data[key] = value.to(device=self.device, dtype=torch.float) 130 | 131 | # Set mask 132 | positive_mask = (data['is_source'] > 0.5).reshape(-1) 133 | nonzero_mask = torch.any(data['source1_audio'], dim=1) 134 | nan_mask = torch.any(torch.isnan(data['source1_audio']), dim=1) 135 | valid_mask = nonzero_mask & ~nan_mask 136 | valid_positive_mask = positive_mask & valid_mask 137 | if nan_mask.any(): 138 | nan_room = [x for x, m in zip(room, nan_mask) if m] 139 | nan_id = [x for x, m in zip(id, nan_mask) if m] 140 | print(f'{nan_room}, {nan_id}: invalid audio') 141 | data['input_stft'][nan_mask] = 0 # will ignore in the loss 142 | 143 | # Parse data 144 | source1_stft = data['source1_stft'] # [n_batch, 2, F, N] 145 | input_stft = data['input_stft'] # [n_batch, n_receivers*2, F, N] 146 | 147 | source1_stft_nondc = source1_stft[:, :, 1:, :] 148 | 149 | # Zero the parameter gradients 150 | self.optimizer.zero_grad() 151 | 152 | # Forward 153 | output = self.model(data) 154 | pred_stft = output['pred_stft'] # [n_batch, n_receivers*2, F, N] 155 | 156 | # Compute detection loss on valid mask 157 | detect_loss = self.bce_criterion(output['source_detection'][valid_mask], data['is_source'][valid_mask]) 158 | 159 | # Compute magnitude loss on valid positive mask, if any 160 | mag_loss = (self.regressor_criterion(pred_stft[valid_positive_mask], source1_stft_nondc[valid_positive_mask]) 161 | if valid_positive_mask.any() else torch.tensor(0.0, device=self.device)) 162 | 163 | # Combine losses 164 | loss = mag_loss + detect_loss * self.detect_loss_weight 165 | assert not torch.isnan(loss), 'Loss is NaN' 166 | 167 | # Backward 168 | if train_mode == 'Train': 169 | loss.backward() 170 | self.optimizer.step() 171 | 172 | # Save metrics 173 | if epoch % self.save_checkpoint_interval == 0 and valid_positive_mask.any(): 174 | save_idx = valid_positive_mask.nonzero(as_tuple=True)[0][0] 175 | room = room[save_idx] 176 | id_value = int(id[save_idx]) 177 | 178 | # Prepare audio data 179 | full_pred_stft_save = input_stft[save_idx].clone().permute(1, 2, 0)[..., :2] 180 | pred_stft_save = pred_stft[save_idx].permute(1, 2, 0) 181 | full_pred_stft_save[1:, :, :] = pred_stft_save # keep dc from input 182 | pred_audio = spec2wav(full_pred_stft_save) 183 | source1_audio_save = data['source1_audio'][save_idx] 184 | 185 | # Save debug data 186 | if self.save_audio: 187 | self.save_audio_data(save_idx, pred_audio, source1_audio_save, data, train_mode, room, id_value, epoch) 188 | 189 | # Compute metrics for dry-sound estimation 190 | dry_psnr, dry_sdr = compute_metrics(pred_audio, source1_audio_save) 191 | 192 | # Compute metrics for detection 193 | predicted = (output['source_detection'] > 0.5).int() 194 | accuracy_true = (predicted[data['is_source'] == 1] == 1).float().mean().item() 195 | accuracy_false = (predicted[data['is_source'] == 0] == 0).float().mean().item() 196 | detection_accuracy = (predicted == data['is_source']).float().mean().item() 197 | 198 | metrics_dict = { 199 | 'dry_psnr': dry_psnr, 200 | 'dry_sdr': dry_sdr, 201 | 'detection_accuracy': detection_accuracy, 202 | 'accuracy_true': accuracy_true, 203 | 'accuracy_false': accuracy_false 204 | } 
205 | 206 | # Metric reduction for DDP 207 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): 208 | for metric, value in metrics_dict.items(): 209 | value_tensor = torch.tensor(value).to(self.device) 210 | dist.all_reduce(value_tensor, op=dist.ReduceOp.SUM) 211 | running_metrics[f'{train_mode}/{metric}'] += value_tensor.item() 212 | count_metrics[f'{train_mode}/{metric}'] += dist.get_world_size() 213 | else: 214 | for metric, value in metrics_dict.items(): 215 | running_metrics[f'{train_mode}/{metric}'] += value 216 | count_metrics[f'{train_mode}/{metric}'] += 1 217 | 218 | # Loss reduction for DDP 219 | loss_dict = { 220 | 'total_loss': loss, 221 | 'mag_loss': mag_loss, 222 | 'detect_loss': detect_loss 223 | } 224 | B = input_stft.shape[0] 225 | 226 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): 227 | for loss_type, loss_value in loss_dict.items(): 228 | loss_tensor = torch.tensor(loss_value).to(self.device) 229 | dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM) 230 | normalized_loss = loss_tensor.item() / dist.get_world_size() 231 | running_loss[f'{train_mode}/{loss_type}'] += normalized_loss 232 | else: 233 | for loss_type, loss_value in loss_dict.items(): 234 | current_loss = loss_dict[loss_type].item() * B 235 | running_loss[f'{train_mode}/{loss_type}'] += current_loss 236 | num_data_point[train_mode] += B 237 | 238 | # Log, save checkpoints, write tensorboard 239 | should_log = not isinstance(self.model, torch.nn.parallel.DistributedDataParallel) or (dist.get_rank() == 0) 240 | if should_log: 241 | # Log and save checkpoints 242 | self.log_and_save(running_loss, running_metrics, epoch) 243 | 244 | # Write tensorboard 245 | for key, value in running_loss.items(): 246 | self.writer.add_scalar(key, value, epoch) 247 | for key, value in running_metrics.items(): 248 | self.writer.add_scalar(key, value, epoch) 249 | 250 | scheduler.step() 251 | 252 | time_elapsed = time.time() - since 253 | logging.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s') 254 | 255 | print('Training finished.') 256 | 257 | def save_checkpoint(self, epoch): 258 | """ 259 | Save checkpoint. 260 | """ 261 | 262 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel): 263 | checkpoint = { 264 | 'epoch': epoch, 265 | 'model_state_dict': self.model.module.state_dict(), 266 | 'optimizer_state_dict': self.optimizer.state_dict() 267 | } 268 | 269 | else: 270 | checkpoint = { 271 | 'epoch': epoch, 272 | 'model_state_dict': self.model.state_dict(), 273 | 'optimizer_state_dict': self.optimizer.state_dict() 274 | } 275 | 276 | filename = f'checkpoints/checkpoint_{epoch}.pt' 277 | os.makedirs(os.path.join(self.save_dir, 'checkpoints'), exist_ok=True) 278 | torch.save(checkpoint, os.path.join(self.save_dir, filename)) 279 | 280 | def load_checkpoint(self, checkpoint_path): 281 | """ 282 | Load checkpoint. 283 | """ 284 | 285 | checkpoint = torch.load(checkpoint_path) 286 | self.model.load_state_dict(checkpoint['model_state_dict']) 287 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 288 | return checkpoint['epoch'] 289 | 290 | def find_latest_checkpoint(self, checkpoint_dir): 291 | """ 292 | Find latest checkpoint. 
293 | """ 294 | 295 | checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '**', 'checkpoint_*.pt'), recursive=True) 296 | if checkpoint_files: 297 | latest_checkpoint = max(checkpoint_files, key=os.path.getctime) 298 | return latest_checkpoint 299 | else: 300 | return None 301 | 302 | def log_and_save(self, running_loss, running_metrics, epoch): 303 | """ 304 | Log values to Tensorboard and save model checkpoint. 305 | """ 306 | 307 | # Logging to Tensorboard 308 | for key, value in running_loss.items(): 309 | self.writer.add_scalar(key, value, epoch) 310 | for key, value in running_metrics.items(): 311 | self.writer.add_scalar(key, value, epoch) 312 | 313 | # Save checkpoint 314 | if epoch % self.save_checkpoint_interval == 0: 315 | self.save_checkpoint(epoch) 316 | 317 | def save_audio_data(self, save_idx, pred_audio, source1_audio_save, data, train_mode, room, id_value, epoch, sample_rate=48000, save_normalized=True): 318 | """ 319 | Function to save audio data. 320 | """ 321 | 322 | # Set directory 323 | source1_class_save = get_key_from_value(source_class_map, data['source1_class'][save_idx]) 324 | source2_class_save = get_key_from_value(source_class_map, data['source2_class'][save_idx]) 325 | debug_dir = f'{self.save_dir}/{train_mode}_audio/{source1_class_save}-{source2_class_save}' 326 | 327 | # Prepare audio 328 | input_audio_mean = data['input_audio'][save_idx].mean(dim=0) 329 | audio_name_list = ['pred', 'input_mean', 'pred_gt', 'input_0', 'receiver_0'] 330 | audio_data_list = [pred_audio, input_audio_mean, source1_audio_save, data['input_audio'][save_idx][0], data['receiver_audio'][save_idx][0]] 331 | 332 | # Save 333 | os.makedirs(f'{debug_dir}/{room}/{id_value}', exist_ok=True) 334 | for audio_name, audio_data in zip(audio_name_list, audio_data_list): 335 | filename_val = f'{debug_dir}/{room}/{id_value}/{epoch}_{audio_name}' 336 | plot_debug(filename_val, audio_data.reshape(1, -1).detach().cpu(), 337 | sample_rate=sample_rate, save_normalized=save_normalized, reference=source1_audio_save) 338 | 339 | # Save RGB image if available 340 | if 'rgb' in data.keys(): 341 | filename_val = f'{debug_dir}/{room}/{id_value}/{epoch}_rgb.png' 342 | plt.imsave(filename_val, data['rgb'][save_idx].permute(1, 2, 0).detach().cpu().numpy()) 343 | -------------------------------------------------------------------------------- /nvas3d/baseline/baseline_dsp.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE file. 3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | import os 7 | import json 8 | import random 9 | import argparse 10 | 11 | import torch 12 | import torchaudio 13 | from tqdm import tqdm 14 | from scipy.signal import fftconvolve 15 | from torchmetrics.classification import BinaryAUROC 16 | 17 | from nvas3d.utils.audio_utils import cs_list, compute_metrics_si 18 | 19 | 20 | def normalize(audio): 21 | """ 22 | Normalize the signal so that its maximum absolute value is 1. 23 | """ 24 | return audio / audio.abs().max() 25 | 26 | 27 | def compute_delayed(audio_list, delay_list): 28 | """ 29 | Remove the given delays from the audio signals by shifting them earlier in time. 30 | """ 31 | 32 | delayed_audio_list = [] 33 | for audio, delay in zip(audio_list, delay_list): 34 | delayed_audio = torch.zeros_like(audio) 35 | delayed_audio[:, :-delay] = audio[:, delay:] # shift earlier by delay samples (assumes delay > 0) 36 | 37 | delayed_audio_list.append(delayed_audio) 38 | return delayed_audio_list 39 | 40 | 41 | class BaselineDSP: 42 | def __init__(self, split, num_receivers, audio_format, dataset_name): 43 | self.num_receivers = num_receivers 44 | self.split = split 45 | self.audio_format = audio_format 46 | self.dataset_name = dataset_name 47 | self.dataset_dir = f'data/{self.dataset_name}/{split}' 48 | self.room_id_pairs = self._get_room_id_pairs() 49 | # random.shuffle(self.room_id_pairs) 50 | 51 | def _get_room_id_pairs(self): 52 | pairs = [] 53 | for room in os.listdir(self.dataset_dir): 54 | room_path = os.path.join(self.dataset_dir, room) 55 | if os.path.isdir(room_path): # Ensure room_path is a directory 56 | for id in os.listdir(room_path): 57 | id_path = os.path.join(room_path, id) 58 | if os.path.isdir(id_path) and id != 'image': # Ensure id_path is a directory and its name isn't 'image' 59 | pairs.append((room, id)) 60 | return pairs 61 | 62 | def load_audio_list(self, audio_list_name): 63 | audio_list = [] 64 | for m in range(self.num_receivers): 65 | audio, _ = torchaudio.load(f'{audio_list_name}_{m+1}.{self.audio_format}') 66 | audio_list.append(audio) 67 | return audio_list 68 | 69 | def separation(self, num_eval=None): 70 | """ 71 | Baseline DSP for separation using delay-and-sum and deconv-and-sum. 72 | """ 73 | 74 | # metrics_names = ['psnr', 'sdr', 'sipsnr', 'sisdr'] 75 | metrics_names = ['psnr', 'sdr'] 76 | methods = ['dry1_deconv', 'dry2_deconv', 'reverb1_deconv', 'reverb2_deconv', 77 | 'dry1_delay', 'dry2_delay', 'reverb1_delay', 'reverb2_delay'] 78 | metrics = {method: {metric: [] for metric in metrics_names} for method in methods} 79 | 80 | loop_length = num_eval if num_eval is not None else len(self.room_id_pairs) 81 | for room_id_pair in tqdm(self.room_id_pairs[:loop_length]): 82 | room, id = room_id_pair 83 | data_dir = f'{self.dataset_dir}/{room}/{id}' 84 | # Set query: positive or negative 85 | 86 | metadata = torch.load(f'{data_dir}/metadata.pt') 87 | grid_points = metadata['grid_points'] 88 | num_points = grid_points.shape[0] 89 | source1_idx = metadata['source1_idx'] 90 | source2_idx = metadata['source2_idx'] 91 | receiver_idx_list = metadata['receiver_idx_list'] 92 | 93 | source1_audio, _ = torchaudio.load(f'{data_dir}/source/source1.{self.audio_format}') 94 | source1_audio = source1_audio[0] 95 | source2_audio, _ = torchaudio.load(f'{data_dir}/source/source2.{self.audio_format}') 96 | source2_audio = source2_audio[0] 97 | reverb1_audio = self.load_audio_list(f'{data_dir}/reverb/reverb1') 98 | reverb1_audio = reverb1_audio[0][0] 99 | reverb2_audio = self.load_audio_list(f'{data_dir}/reverb/reverb2') 100 | reverb2_audio = reverb2_audio[0][0] 101 | receiver_audio =
self.load_audio_list(f'{data_dir}/receiver/receiver') 102 | 103 | ir1 = self.load_audio_list(f'{data_dir}/ir_receiver/ir1') 104 | ir2 = self.load_audio_list(f'{data_dir}/ir_receiver/ir2') 105 | 106 | for query_idx in range(num_points): 107 | if (query_idx not in receiver_idx_list): 108 | deconv_audio = self.load_audio_list(f'{data_dir}/wiener/wiener{query_idx}') 109 | 110 | deconv_and_sum = torch.stack(deconv_audio, dim=0).reshape(self.num_receivers, -1).mean(dim=0) 111 | 112 | # compute delay from distance 113 | c = 343 114 | sample_rate = 48000 115 | query_point = grid_points[query_idx] 116 | delay_list = [] 117 | for receiver_idx in receiver_idx_list: 118 | receiver_point = grid_points[receiver_idx] 119 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset 120 | delay_list.append(int(dist / c * sample_rate)) 121 | 122 | delayed_audio = compute_delayed(receiver_audio, delay_list) 123 | delay_and_sum_0 = torch.stack(delayed_audio, dim=0).reshape(self.num_receivers, -1).mean(dim=0) 124 | 125 | # # apply delay again to align with reverb1 for fair evaluation 126 | # delay = delay_list[0] 127 | # delay_and_sum = torch.zeros_like(delay_and_sum_0) 128 | # delay_and_sum[delay:] = delay_and_sum_0[:-delay] 129 | 130 | if query_idx == source1_idx: 131 | dry1 = deconv_and_sum 132 | dry1_delay = delay_and_sum_0 133 | 134 | receiver_point = grid_points[receiver_idx_list[0]] 135 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset 136 | delay = int(dist / c * sample_rate) 137 | reverb1_delay = torch.zeros_like(delay_and_sum_0) 138 | reverb1_delay[delay:] = delay_and_sum_0[:-delay] 139 | 140 | # reverb1 141 | elif query_idx == source2_idx: 142 | dry2 = deconv_and_sum 143 | dry2_delay = delay_and_sum_0 144 | 145 | receiver_point = grid_points[receiver_idx_list[0]] 146 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset 147 | delay = int(dist / c * sample_rate) 148 | reverb2_delay = torch.zeros_like(delay_and_sum_0) 149 | reverb2_delay[delay:] = delay_and_sum_0[:-delay] 150 | 151 | m = 0 152 | ir1_m = ir1[m][0] 153 | pred_reverb1 = torch.from_numpy(fftconvolve(dry1, ir1_m)[:len(reverb1_audio)]) 154 | 155 | ir2_m = ir2[m][0] 156 | pred_reverb2 = torch.from_numpy(fftconvolve(dry2, ir2_m)[:len(reverb2_audio)]) 157 | 158 | def compute_and_save_metrics(processed_audio, reference_audio, metric_lists, filename, save_audio=True): 159 | psnr_score, sdr_score, _, _ = compute_metrics_si(processed_audio, reference_audio) 160 | 161 | metric_lists['psnr'].append(psnr_score) 162 | metric_lists['sdr'].append(sdr_score) 163 | 164 | # Saving the processed audio 165 | if save_audio: 166 | os.makedirs(os.path.dirname(filename), exist_ok=True) # Ensuring the directory exists 167 | torchaudio.save(filename, normalize(processed_audio.unsqueeze(0)), sample_rate=48000) 168 | 169 | current_dir = f'test_results/{self.dataset_name}/baseline/{self.split}/{room}/{id}' 170 | os.makedirs(current_dir, exist_ok=True) 171 | 172 | # dry (deconv-and-sum) 173 | filename = f'{current_dir}/dry1.wav' 174 | compute_and_save_metrics(dry1, source1_audio, metrics['dry1_deconv'], filename) 175 | filename = f'{current_dir}/dry2.wav' 176 | compute_and_save_metrics(dry2, source2_audio, metrics['dry2_deconv'], filename) 177 | 178 | # dry (delay-and-sum) 179 | filename = f'{current_dir}/dry1_delay.wav' 180 | compute_and_save_metrics(dry1_delay, source1_audio, metrics['dry1_delay'], filename) 181 | filename = f'{current_dir}/dry2_delay.wav' 182 | compute_and_save_metrics(dry2_delay, 
source2_audio, metrics['dry2_delay'], filename) 183 | 184 | # reverb (deconv-and-sum) 185 | filename = f'{current_dir}/reverb1.wav' 186 | compute_and_save_metrics(pred_reverb1, reverb1_audio, metrics['reverb1_deconv'], filename) 187 | filename = f'{current_dir}/reverb2.wav' 188 | compute_and_save_metrics(pred_reverb2, reverb2_audio, metrics['reverb2_deconv'], filename) 189 | 190 | # reverb (delay-and-sum) 191 | filename = f'{current_dir}/reverb1_delay.wav' 192 | compute_and_save_metrics(reverb1_delay, reverb1_audio, metrics['reverb1_delay'], filename) 193 | filename = f'{current_dir}/reverb2_delay.wav' 194 | compute_and_save_metrics(reverb2_delay, reverb2_audio, metrics['reverb2_delay'], filename) 195 | 196 | # Print metrics 197 | for method, method_metrics in metrics.items(): 198 | print(f'\n[{method} audio]') 199 | for metric_name, metric_values in method_metrics.items(): 200 | tensor_values = torch.tensor(metric_values) 201 | is_valid = torch.isfinite(tensor_values) & ~torch.isnan(tensor_values) 202 | valid_values = tensor_values[is_valid] # Exclude both 'inf' and 'nan' values 203 | print(f'{method} {metric_name}: {torch.tensor(valid_values).mean()}') 204 | 205 | # Build the dictionary with averages 206 | averaged_metrics = {} 207 | for method, method_metrics in metrics.items(): 208 | averaged_metrics[method] = {} 209 | for metric_name, metric_values in method_metrics.items(): 210 | tensor_values = torch.tensor(metric_values) 211 | is_valid = torch.isfinite(tensor_values) & ~torch.isnan(tensor_values) 212 | valid_values = tensor_values[is_valid] # Exclude both 'inf' and 'nan' values 213 | print(f'{method} {metric_name}: {torch.tensor(valid_values).mean()}') 214 | averaged_value = float(valid_values.mean()) if valid_values.nelement() > 0 else float('nan') # Avoid potential empty tensor 215 | averaged_metrics[method][metric_name] = averaged_value 216 | 217 | # Save to a JSON file 218 | filename_json = f'test_results/{self.dataset_name}/baseline/{self.split}/metrics.json' 219 | with open(filename_json, 'w') as json_file: 220 | json.dump(averaged_metrics, json_file, indent=4) 221 | 222 | def detection(self, num_eval=None): 223 | """ 224 | Baseline DSP for detection using cosine similarity of delay signals or deconvolved signals 225 | """ 226 | 227 | cs_deconv_list = [] 228 | cs_delay_list = [] 229 | cs_receiver_list = [] 230 | positive_list = [] 231 | 232 | loop_length = num_eval if num_eval is not None else len(self.room_id_pairs) 233 | for room_id_pair in tqdm(self.room_id_pairs[:loop_length]): 234 | room, id = room_id_pair 235 | data_dir = f'{self.dataset_dir}/{room}/{id}' 236 | 237 | metadata = torch.load(f'{data_dir}/metadata.pt') 238 | grid_points = metadata['grid_points'] 239 | num_points = grid_points.shape[0] 240 | source1_idx = metadata['source1_idx'] 241 | source2_idx = metadata['source2_idx'] 242 | source_idx_list = [source1_idx, source2_idx] 243 | receiver_idx_list = metadata['receiver_idx_list'] 244 | receiver_audio = self.load_audio_list(f'{data_dir}/receiver/receiver') 245 | 246 | for query_idx in range(num_points): 247 | if (query_idx not in receiver_idx_list): 248 | deconv_audio = self.load_audio_list(f'{data_dir}/wiener/wiener{query_idx}') 249 | 250 | # compute delay from distance 251 | c = 343 252 | sample_rate = 48000 253 | query_point = grid_points[query_idx] 254 | delay_list = [] 255 | for receiver_idx in receiver_idx_list: 256 | receiver_point = grid_points[receiver_idx] 257 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset 258 | 
delay_list.append(int(dist / c * sample_rate)) 259 | 260 | delayed_audio = compute_delayed(receiver_audio, delay_list) 261 | 262 | cs = cs_list(deconv_audio) 263 | cs_delay = cs_list(delayed_audio) 264 | cs_receiver = cs_list(receiver_audio) 265 | if query_idx in source_idx_list: 266 | positive_list.append(True) 267 | else: 268 | positive_list.append(False) 269 | 270 | cs_deconv_list.append(cs) 271 | cs_delay_list.append(cs_delay) 272 | cs_receiver_list.append(cs_receiver) 273 | 274 | cs_deconv_list = torch.tensor(cs_deconv_list) 275 | cs_delay_list = torch.tensor(cs_delay_list) 276 | cs_receiver_list = torch.tensor(cs_receiver_list) 277 | positive_list = torch.tensor(positive_list) 278 | print(f'deconv_p: {cs_deconv_list[positive_list].mean()}') 279 | print(f'deconv_n: {cs_deconv_list[~positive_list].mean()}') 280 | print(f'delay_p: {cs_delay_list[positive_list].mean()}') 281 | print(f'delay_n: {cs_delay_list[~positive_list].mean()}') 282 | print(f'receiver_p: {cs_receiver_list[positive_list].mean()}') 283 | print(f'receiver_n: {cs_receiver_list[~positive_list].mean()}') 284 | 285 | # AUC 286 | metric = BinaryAUROC(thresholds=None) 287 | auc_deconv = metric(cs_deconv_list, positive_list) 288 | auc_delay = metric(cs_delay_list, positive_list) 289 | auc_receiver = metric(cs_receiver_list, positive_list) 290 | print(f'AUC (receiver, delay, deconv): {auc_receiver}, {auc_delay}, {auc_deconv}') 291 | 292 | # # search best threshold 293 | # cs_deconv_p_tensor = torch.tensor(cs_deconv_p_list) 294 | # cs_deconv_n_tensor = torch.tensor(cs_deconv_n_list) 295 | # th_list = torch.arange(-0.1, 0.3, 0.01) 296 | # accuracy_list = [] 297 | # for th in th_list: 298 | # accuracy_p = (cs_deconv_p_tensor > th).sum() 299 | # accuracy_n = (cs_deconv_n_tensor <= th).sum() 300 | # accuracy_ = accuracy_p + accuracy_n 301 | # accuracy = float(accuracy_) / (len(cs_deconv_p_tensor) + len(cs_deconv_n_tensor)) 302 | # accuracy_list.append(accuracy) 303 | 304 | # max_accuracy = torch.tensor(accuracy_list).max() 305 | # max_idx = accuracy_list.index(max_accuracy) 306 | # max_th = th_list[max_idx] 307 | 308 | # print(f'max_accuracy: {max_accuracy}') 309 | # print(f'max_th: {max_th}') 310 | # # print(accuracy_list) 311 | 312 | # accuracy_p = (cs_deconv_p_tensor > max_th).float().mean() 313 | # accuracy_n = (cs_deconv_n_tensor <= max_th).float().mean() 314 | # print(accuracy_p) 315 | # print(accuracy_n) 316 | 317 | # max_th 318 | 319 | 320 | if __name__ == "__main__": 321 | random.seed(42) 322 | 323 | # Config parsing 324 | parser = argparse.ArgumentParser() 325 | parser.add_argument('--task', type=str, default='separation') # separation, detection 326 | parser.add_argument('--split', type=str, default='demo') 327 | parser.add_argument('--num_receivers', type=int, default=4) 328 | parser.add_argument('--audio_format', type=str, default='flac') 329 | parser.add_argument('--dataset_name', type=str, default='nvas3d_demo') 330 | parser.add_argument('--num_eval', type=int, default=None) # None to test all data 331 | args = parser.parse_args() 332 | 333 | # Run baseline 334 | baseline = BaselineDSP(args.split, args.num_receivers, args.audio_format, args.dataset_name) 335 | if args.task == 'separation': 336 | baseline.separation(args.num_eval) 337 | elif args.task == 'detection': 338 | baseline.detection(args.num_eval) 339 | --------------------------------------------------------------------------------
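Usage note (added to this listing; not a file in the repository): a minimal sketch of how the DSP baseline in nvas3d/baseline/baseline_dsp.py above might be driven from Python instead of its argparse entry point, assuming the package is importable and a demo dataset exists under data/nvas3d_demo/demo.

# Hypothetical usage sketch; names mirror the argparse defaults in baseline_dsp.py.
from nvas3d.baseline.baseline_dsp import BaselineDSP

baseline = BaselineDSP(split='demo', num_receivers=4, audio_format='flac', dataset_name='nvas3d_demo')
baseline.separation(num_eval=10)  # delay-and-sum / deconv-and-sum separation; writes test_results/nvas3d_demo/baseline/demo/metrics.json
baseline.detection(num_eval=10)   # cosine-similarity detection; prints per-method statistics and AUROC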