├── nvas3d
│   ├── model
│   │   ├── __init__.py
│   │   └── model.py
│   ├── train
│   │   ├── __init__.py
│   │   └── trainer.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── training_data_generation
│   │   │   ├── upsample_librispeech.py
│   │   │   ├── README.md
│   │   │   ├── script_slakh.py
│   │   │   ├── generate_metadata_square.py
│   │   │   ├── generate_training_data.py
│   │   │   └── generate_test_data.py
│   │   ├── dynamic_utils.py
│   │   ├── utils.py
│   │   ├── plot_utils.py
│   │   └── generate_dataset_utils.py
│   ├── data_loader
│   │   └── __init__.py
│   ├── config
│   │   ├── default_config.yaml
│   │   └── visual_config.yaml
│   ├── assets
│   │   └── saved_models
│   │       └── default
│   │           └── config.yaml
│   ├── main.py
│   └── baseline
│       └── baseline_dsp.py
├── soundspaces_nvas3d
│   ├── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── render_scene_script.py
│   │   ├── audio_utils.py
│   │   └── ss_utils.py
│   ├── rir_generation
│   │   ├── __init__.py
│   │   ├── minimal_example_render_rir.py
│   │   ├── generate_rir.py
│   │   └── generate_grid.py
│   ├── image_rendering
│   │   ├── run_generate_envmap.py
│   │   ├── run_generate_target_image.py
│   │   ├── generate_envmap.py
│   │   └── generate_target_image.py
│   └── README.md
├── data
│   ├── source
│   │   ├── drum.flac
│   │   ├── female.flac
│   │   ├── guitar1.flac
│   │   └── guitar2.flac
│   └── objects
│       ├── bluetooth_speaker.object_config.json
│       └── classic_microphone.object_config.json
├── assets
│   ├── videos
│   │   └── teaser.mp4
│   └── images
│       └── thumbnail.png
├── demo
│   ├── config_demo
│   │   ├── scene1_17DRP5sb8fy.json
│   │   └── path1_17DRP5sb8fy.json
│   ├── run_demo.sh
│   ├── README.md
│   └── generate_demo_data.py
├── CONTRIBUTING.md
├── setup.sh
├── LICENSE
├── CODE_OF_CONDUCT.md
├── README.md
└── ACKNOWLEDGEMENTS
/nvas3d/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nvas3d/train/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nvas3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/nvas3d/data_loader/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/rir_generation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/data/source/drum.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/drum.flac
--------------------------------------------------------------------------------
/data/source/female.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/female.flac
--------------------------------------------------------------------------------
/assets/videos/teaser.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/assets/videos/teaser.mp4
--------------------------------------------------------------------------------
/data/source/guitar1.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/guitar1.flac
--------------------------------------------------------------------------------
/data/source/guitar2.flac:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/data/source/guitar2.flac
--------------------------------------------------------------------------------
/assets/images/thumbnail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-nvas3d/HEAD/assets/images/thumbnail.png
--------------------------------------------------------------------------------
/demo/config_demo/scene1_17DRP5sb8fy.json:
--------------------------------------------------------------------------------
1 | {
2 | "source_idx_list": [
3 | 14,
4 | 43
5 | ],
6 | "receiver_idx_list": [
7 | 23,
8 | 28,
9 | 29,
10 | 24
11 | ]
12 | }
--------------------------------------------------------------------------------
/data/objects/bluetooth_speaker.object_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "render_asset": "bluetooth_speaker.glb",
3 | "mass": 0.41,
4 | "friction_coefficient": 0.5,
5 | "restitution_coefficient": 0.4,
6 | "use_bounding_box_for_collision": true,
7 | "margin": 0.01
8 | }
--------------------------------------------------------------------------------
/data/objects/classic_microphone.object_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "render_asset": "classic_microphone.glb",
3 | "mass": 0.41,
4 | "friction_coefficient": 0.5,
5 | "restitution_coefficient": 0.4,
6 | "use_bounding_box_for_collision": true,
7 | "margin": 0.01
8 | }
--------------------------------------------------------------------------------
/soundspaces_nvas3d/image_rendering/run_generate_envmap.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import subprocess
7 |
8 | # RENDER ENVMAP
9 | room_list = [
10 | '17DRP5sb8fy'
11 | ]
12 |
13 | for i, room in enumerate(room_list):
14 | print(f'mp3d envmap rendering: {room}, {i+1}/{len(room_list)}')
15 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_envmap.py', '--room', room])
16 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/image_rendering/run_generate_target_image.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import subprocess
7 |
8 | # RENDER TARGET IMAGE
9 | room_list = [
10 | '17DRP5sb8fy'
11 | ]
12 |
13 | for i, room in enumerate(room_list):
14 | print(f'mp3d target image rendering: {room}, {i+1}/{len(room_list)}')
15 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_target_image.py', '--room', room])
16 |
--------------------------------------------------------------------------------
/nvas3d/config/default_config.yaml:
--------------------------------------------------------------------------------
1 | save_dir: './nvas3d/assets/saved_models'
2 | use_visual: False
3 | use_deconv: True
4 |
5 | data_loader:
6 | batch_size: 96
7 | num_workers: 8
8 | num_receivers: 4
9 | hop_length: 480
10 | win_length: 1200
11 | n_fft: 2048
12 | nvas3d_dataset: 'nvas3d_square_all_all'
13 | audio_format: 'flac'
14 |
15 | training:
16 | epochs: 200
17 | lr: 0.001
18 | weight_decay: 0.0
19 | detect_loss_weight: 0.1
20 | save_checkpoint_interval: 10
21 | resume: False
22 | checkpoint_dir: None
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/nvas3d/assets/saved_models/default/config.yaml:
--------------------------------------------------------------------------------
1 | save_dir: './nvas3d/assets/saved_models'
2 | use_visual: False
3 | use_deconv: True
4 |
5 | data_loader:
6 | batch_size: 96
7 | num_workers: 8
8 | num_receivers: 4
9 | hop_length: 480
10 | win_length: 1200
11 | n_fft: 2048
12 | nvas3d_dataset: 'nvas3d_demo'
13 | audio_format: 'flac'
14 |
15 | training:
16 | epochs: 200
17 | lr: 0.001
18 | weight_decay: 0.0
19 | detect_loss_weight: 0.1
20 | save_checkpoint_interval: 10
21 | resume: False
22 | checkpoint_dir: None
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/demo/config_demo/path1_17DRP5sb8fy.json:
--------------------------------------------------------------------------------
1 | {
2 | "receiver_idx_list": [
3 | 146,
4 | 141,
5 | 135,
6 | 132,
7 | 127,
8 | 120,
9 | 110,
10 | 100,
11 | 91,
12 | 83,
13 | 77,
14 | 70,
15 | 60
16 | ],
17 | "receiver_rotation_list": [
18 | 270,
19 | 270,
20 | 180,
21 | 180,
22 | 180,
23 | 90,
24 | 0,
25 | 0,
26 | 0,
27 | 90,
28 | 180,
29 | 180,
30 | 180
31 | ]
32 | }
--------------------------------------------------------------------------------
/nvas3d/config/visual_config.yaml:
--------------------------------------------------------------------------------
1 | save_dir: './nvas3d/saved_models'
2 | use_norm: True
3 |
4 | use_visual: True
5 | use_deconv: True
6 | use_real_imag: True
7 | use_mask: False
8 | use_visual_bg: False
9 | use_beamforming: False
10 | use_legacy: True
11 |
12 | data_loader:
13 | batch_size: 96
14 | num_workers: 8
15 | num_receivers: 4
16 | hop_length: 480
17 | win_length: 1200
18 | n_fft: 2048
19 | nvas3d_dataset: 'nvas3d_square_all_all'
20 | audio_format: 'flac'
21 |
22 | training:
23 | epochs: 200
24 | lr: 0.001
25 | weight_decay: 0.0
26 | detect_loss_weight: 0.1
27 | save_checkpoint_interval: 10
28 | resume: False
29 | checkpoint_dir: None
30 |
31 |
32 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducability, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
--------------------------------------------------------------------------------
/demo/run_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # =========================================
4 | # For licensing see accompanying LICENSE file.
5 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
6 | # =========================================
7 |
8 | # =============
9 | # Configuration
10 | # =============
11 | # For data generation
12 | room='17DRP5sb8fy'
13 | dataset_dir='nvas3d_demo'
14 | scene_config="scene1_17DRP5sb8fy"
15 |
16 | # For dry-sound estimation
17 | model='default'
18 | results_dir="results/$dataset_dir/$model/demo/$room/0"
19 |
20 | # For novel-view acoustic rendering
21 | novel_path_config="path1_17DRP5sb8fy"
22 |
23 |
24 | # =============
25 | # Execution
26 | # =============
27 | # Data generation
28 | python demo/generate_demo_data.py --room $room --dataset_dir $dataset_dir --scene_config $scene_config
29 |
30 | # Dry-sound estimation using our model
31 | python demo/test_demo.py --dataset_dir $dataset_dir --model $model
32 |
33 | # Novel-view acoustic rendering
34 | python demo/generate_demo_video.py --results_dir $results_dir --novel_path_config $novel_path_config
35 |
36 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/rir_generation/minimal_example_render_rir.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import itertools
7 |
8 | from soundspaces_nvas3d.utils.ss_utils import render_rir_parallel
9 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
10 |
11 | # Configuration
12 | grid_distance = 2.0
13 | room = '17DRP5sb8fy'
14 |
15 | # Load grid points for the specified room
16 | grid_data = load_room_grid(room, grid_distance=grid_distance)
17 | grid_points = grid_data['grid_points']
18 |
19 | # Generate all possible combinations of source and receiver indices
20 | num_points = len(grid_points)
21 | pairs = list(itertools.product(range(num_points), repeat=2))
22 | idx_source, idx_receiver = zip(*pairs)
23 |
24 | # Extract corresponding source and receiver points
25 | source_points = grid_points[list(idx_source)]
26 | receiver_points = grid_points[list(idx_receiver)]
27 |
28 | # Create a room list for the IR generator
29 | room_list = [room] * len(source_points)
30 |
31 | # Generate Room Impulse Responses (IRs) without saving as WAV
32 | ir_list = render_rir_parallel(room_list, source_points, receiver_points)
33 |
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/upsample_librispeech.py:
--------------------------------------------------------------------------------
1 |
2 | import torchaudio
3 | from torchaudio.transforms import Resample
4 | import os
5 |
6 |
7 | def process_directory(input_dir, output_dir, sample_rate=48000):
8 | for dirpath, dirnames, filenames in os.walk(input_dir):
9 | for filename in filenames:
10 | if filename.endswith(".flac"):
11 | input_file_path = os.path.join(dirpath, filename)
12 | waveform_, sr = torchaudio.load(input_file_path)
13 |
14 | # Resample
15 | upsample_transform = Resample(sr, sample_rate)
16 | waveform = upsample_transform(waveform_)
17 |
18 | relative_path = os.path.relpath(dirpath, input_dir)
19 | clip_output_dir = os.path.join(output_dir, relative_path)
20 |
21 | os.makedirs(clip_output_dir, exist_ok=True)
22 |
23 | output_file_path = os.path.join(clip_output_dir, filename)
24 | torchaudio.save(output_file_path, waveform, sample_rate)
25 |
26 |
27 | # Modify the paths and parameters as per your needs
28 | input_dir = "data/source/LibriSpeech/"
29 | output_dir = "data/source/LibriSpeech48k/"
30 |
31 | process_directory(input_dir, output_dir)
32 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/utils/render_scene_script.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import sys
7 | from soundspaces_nvas3d.utils.ss_utils import render_scene_config, render_receiver_image, create_scene
8 | from soundspaces_nvas3d.utils.aihabitat_utils import save_grid_heatmap
9 |
10 |
11 | def execute_scene(args):
12 | filename, room, source_idx_list, all_receiver_idx_list, grid_distance = args
13 | render_scene_config(filename, room, eval(source_idx_list), eval(all_receiver_idx_list), eval(grid_distance), no_query=False)
14 |
15 |
16 | def execute_receiver(args):
17 | dirname, room, source_idx_list, source_class_list, all_receiver_idx_list = args
18 | render_receiver_image(dirname, room, eval(source_idx_list), eval(source_class_list), eval(all_receiver_idx_list))
19 |
20 |
21 | def execute_heatmap(args):
22 | filename, room, grid_points, prediction_list, receiver_idx_list, grid_distance = args
23 | scene = create_scene(room)
24 | save_grid_heatmap(filename, scene.sim.pathfinder, eval(grid_points), eval(prediction_list), receiver_idx_list=eval(receiver_idx_list))
25 |
26 |
27 | # Dictionary to map function names to actual function calls
28 | functions = {
29 | "scene": execute_scene,
30 | "receiver": execute_receiver,
31 | "heatmap": execute_heatmap
32 | }
33 |
34 | function_name = sys.argv[1]
35 |
36 | # Execute the corresponding function
37 | if function_name in functions:
38 | functions[function_name](sys.argv[2:])
39 | else:
40 | print(f"Error: Function {function_name} not found!")
41 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # =========================================
4 | # For licensing see accompanying LICENSE file.
5 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
6 | # =========================================
7 | #
8 | # This script automates the installation process described in the soundspaces_nvas3d/README.
9 | # Please refer to soundspaces_nvas3d/README.md for detailed instructions and explanations.
10 |
11 | # Update PYTHONPATH
12 | echo "export PYTHONPATH=\$PYTHONPATH:$(pwd)" >> ~/.bashrc && source ~/.bashrc
13 |
14 |
15 | # Install dependencies for Soundspaces
16 | apt-get update && apt-get upgrade -y
17 |
18 | apt-get install -y --no-install-recommends \
19 | libjpeg-dev libglm-dev libgl1-mesa-glx libegl1-mesa-dev mesa-utils xorg-dev freeglut3-dev
20 |
21 | # Update conda
22 | conda update -n base -c defaults conda
23 |
24 | # Create conda environment
25 | conda create -n ml-nvas3d python=3.7 cmake=3.14.0 -y
26 | conda activate ml-nvas3d
27 |
28 | # Install torch
29 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
30 |
31 | # Install habitat-sim
32 | cd ..
33 | git clone https://github.com/facebookresearch/habitat-sim.git
34 | cd habitat-sim
35 | pip install -r requirements.txt
36 | git checkout RLRAudioPropagationUpdate
37 | python setup.py install --headless --audio --with-cuda
38 |
39 | # Install habitat-lab
40 | cd ..
41 | git clone https://github.com/facebookresearch/habitat-lab.git
42 | cd habitat-lab
43 | git checkout v0.2.2
44 | pip install -e .
45 | sed -i '36 s/^/#/' habitat/tasks/rearrange/rearrange_sim.py # remove FetchRobot
46 |
47 | # Install soundspaces
48 | cd ..
49 | git clone https://github.com/facebookresearch/sound-spaces.git
50 | cd sound-spaces
51 | pip install -e .
52 |
53 | # Change directory
54 | cd ml-nvas3d
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/README.md:
--------------------------------------------------------------------------------
1 | # NVAS3D Training Data Generation
2 | This guid provides the process of generating training data for NVAS3D.
3 |
4 | ## Step-by-Step Instructions
5 |
6 | ### 1. Download Matterport3D Data
7 | Download all rooms from [Matterport3D](https://niessner.github.io/Matterport/).
8 |
9 | ### 2. Download Source Audios
10 | Download the following dataset for source audios and locate them at `data/source`
11 | * [Slakh2100](https://zenodo.org/record/4599666)
12 | * [LibriSpeech](https://www.openslr.org/12)
13 |
14 | ### 3. Preprocess Source Audios
15 | #### 3.1. Clip, Split, and Upsample Slakh2100
16 | To process Slakh2100 dataset (clipping, splitting, and upsampling to 48kHz), execute the following command:
17 | ```bash
18 | python nvas3d/training_data_generation/script_slakh.py
19 | ```
20 | The output will be located at `data/MIDI/clip/`.
21 |
22 | #### 3.2. Upsample LibriSpeech
23 | To upsample LibriSpeech dataset to 48kHz, execute the following command:
24 | ```bash
25 | python nvas3d/training_data_generation/upsample_librispeech.py
26 | ```
27 | The output will be located at `data/source/LibriSpeech48k`, and move it to `data/MIDI/clip/speech/LibriSpeech48k`.
28 |
29 | ### 4. Generate Metadata for Microphone Configuration
30 | To generate square-shaped microphone configuration metadata, execute the following command:
31 | ```bash
32 | python nvas3d/training_data_generation/generate_metadata_square.py
33 | ```
34 | The output metadata will be located at `data/nvas3d_square/`
35 |
36 | ### 5. Generate Training Data
37 | Finally, to generate the training data for NVAS3D, execute the following command:
38 | ```bash
39 | python nvas3d/training_data_generation/generate_training_data.py
40 | ```
41 | The generated data will be located at `data/nvas3d_square_all_all`.
42 |
43 | ## Acknowledgements
44 | * [LibriSpeech](https://www.openslr.org/12) is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/).
45 |
46 | * [Slakh2100](http://www.slakh.com) is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/).
47 |
48 | * [Matterport3D](https://niessner.github.io/Matterport/): Matterport3D-based task datasets and trained models are distributed with the [Matterport3D Terms of Use](https://kaldir.vc.in.tum.de/matterport/MP_TOS.pdf) and are licensed under [CC BY-NC-SA 3.0 US](https://creativecommons.org/licenses/by-nc-sa/3.0/us/).
49 |
50 |
51 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/rir_generation/generate_rir.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import argparse
8 | import itertools
9 |
10 | from soundspaces_nvas3d.utils.ss_utils import render_rir_parallel
11 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
12 |
13 |
14 | def generate_rir(args: argparse.Namespace) -> None:
15 | """
16 | Generate Room Impulse Response (RIR) based on given room and grid distance.
17 | """
18 |
19 | grid_distance_str = str(args.grid_distance).replace(".", "_")
20 | dirname = os.path.join(args.dirname, f'rir_mp3d/grid_{grid_distance_str}', args.room)
21 | os.makedirs(dirname, exist_ok=True)
22 |
23 | grid_data = load_room_grid(args.room, grid_distance=args.grid_distance)
24 | grid_points = grid_data['grid_points']
25 | num_points = len(grid_points)
26 |
27 | # Generate combinations of source and receiver indices
28 | pairs = list(itertools.product(range(num_points), repeat=2))
29 | source_indices, receiver_indices = zip(*pairs)
30 |
31 | room_list = [args.room] * len(source_indices)
32 | source_points = grid_points[list(source_indices)]
33 | receiver_points = grid_points[list(receiver_indices)]
34 | filename_list = [
35 | os.path.join(dirname, f'ir_{args.room}_{source_idx}_{receiver_idx}.wav')
36 | for source_idx, receiver_idx in zip(source_indices, receiver_indices)
37 | ]
38 |
39 | render_rir_parallel(room_list, source_points, receiver_points, filename_list)
40 |
41 |
42 | if __name__ == '__main__':
43 | parser = argparse.ArgumentParser(description="Generate Room Impulse Response for given parameters.")
44 |
45 | parser.add_argument('--room',
46 | default='17DRP5sb8fy',
47 | type=str,
48 | help='MP3D room identifier')
49 |
50 | parser.add_argument('--grid_distance',
51 | default=2.0,
52 | type=float,
53 | help='Distance between grid points in meters')
54 |
55 | parser.add_argument('--dirname',
56 | default='data/examples',
57 | type=str,
58 | help='Directory to save generated RIRs')
59 |
60 | args = parser.parse_args()
61 | generate_rir(args)
62 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/image_rendering/generate_envmap.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import argparse
8 |
9 | import numpy as np
10 | import matplotlib.pyplot as plt
11 |
12 | from soundspaces_nvas3d.utils.ss_utils import create_scene
13 |
14 | """
15 | Notes:
16 | - MP3D room should be located at: data/scene_datasets/mp3d/{room}
17 | - Grid data should be present at: data/scene_datasets/metadata/mp3d/grid_1_0/grid_{room}.npy
18 | (Refer to: rir_generation/generate_grid.py for grid generation)
19 | """
20 |
21 | room_list = [
22 | '17DRP5sb8fy'
23 | ]
24 |
25 |
26 | def main(args):
27 | room = args.room
28 | grid_distance = args.grid_distance
29 |
30 | print(room)
31 | image_size = (192, 144)
32 |
33 | # Load grid
34 | grid_distance_str = str(grid_distance).replace(".", "_")
35 | dirname_grid = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}'
36 | filename_grid = f'{dirname_grid}/grid_{room}.npy'
37 | grid_info = np.load(filename_grid, allow_pickle=True).item()
38 |
39 | grid_points = grid_info['grid_points']
40 | dirname = f'data/examples/envmap_mp3d/grid_{grid_distance_str}/{room}'
41 | os.makedirs(dirname, exist_ok=True)
42 |
43 | scene = create_scene(room, image_size=image_size)
44 |
45 | for receiver_idx in range(grid_points.shape[0]):
46 | receiver_position = grid_points[receiver_idx]
47 | scene.update_receiver_position(receiver_position)
48 | rgb, depth = scene.render_envmap()
49 |
50 | filename_rgb = f'{dirname}/envmap_rgb_{room}_{receiver_idx}.png'
51 | filename_depth = f'{dirname}/envmap_depth_{room}_{receiver_idx}.png'
52 |
53 | plt.imsave(filename_rgb, rgb)
54 | plt.imsave(filename_depth, depth)
55 |
56 | filename_rgb = f'{dirname}/envmap_rgb_{room}_{receiver_idx}.npy'
57 | filename_depth = f'{dirname}/envmap_depth_{room}_{receiver_idx}.npy'
58 |
59 | np.save(filename_rgb, rgb)
60 | np.save(filename_depth, depth)
61 |
62 |
63 | if __name__ == '__main__':
64 | parser = argparse.ArgumentParser()
65 |
66 | parser.add_argument('--room',
67 | default='17DRP5sb8fy',
68 | type=str,
69 | help='mp3d room')
70 |
71 | parser.add_argument('--grid_distance',
72 | default=1.0,
73 | type=float,
74 | help='distance between grid points')
75 |
76 | args = parser.parse_args()
77 |
78 | main(args)
79 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2023 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED IN THIS REPOSITORY:
43 |
44 | The software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 | -------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/nvas3d/utils/dynamic_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import numpy as np
7 | import typing as T
8 |
9 | from scipy.signal import oaconvolve
10 |
11 |
12 | def setup_dynamic_interp(
13 | receiver_position: np.ndarray,
14 | total_samples: int,
15 | ) -> T.Tuple[np.ndarray, np.ndarray]:
16 | """
17 | Setup moving path with a constant speed for a receiver, given its positions in 3D space.
18 |
19 | Args:
20 | - receiver_position: Receiver positions in 3D space of shape (num_positions, 3).
21 | - total_samples: Total number of samples in the audio.
22 |
23 | Returns:
24 | - interp_index: Indices representing the start positions for interpolation.
25 | - interp_weight: Weight values for linear interpolation.
26 | """
27 |
28 | # Calculate the number of samples per interval
29 | distance = np.linalg.norm(np.diff(receiver_position, axis=0), axis=1)
30 | speed_per_sample = distance.sum() / total_samples
31 | samples_per_interval = np.round(distance / speed_per_sample).astype(int)
32 |
33 | # Distribute rounding errors
34 | error = total_samples - samples_per_interval.sum()
35 | for i in np.random.choice(len(samples_per_interval), abs(error)):
36 | samples_per_interval[i] += np.sign(error)
37 |
38 | # Calculate indices and weights for linear interpolation
39 | interp_index = np.repeat(np.arange(len(distance)), samples_per_interval)
40 | interp_weight = np.concatenate([np.linspace(0, 1, num, endpoint=False) for num in samples_per_interval])
41 |
42 | return interp_index, interp_weight.astype(np.float32)
43 |
44 |
45 | def convolve_moving_receiver(
46 | source_audio: np.ndarray,
47 | rirs: np.ndarray,
48 | interp_index: T.List[int],
49 | interp_weight: T.List[float]
50 | ) -> np.ndarray:
51 | """
52 | Apply convolution between an audio signal and moving impulse responses (IRs).
53 |
54 | Args:
55 | - source_audio: Source audio of shape (audio_len,)
56 | - rirs: RIRs of shape (num_positions, num_channels, ir_length)
57 | - interp_index: Indices representing the start positions for interpolation of shape (audio_len,).
58 | - interp_weight: Weight values for linear interpolation of shape (audio_len,).
59 |
60 | Returns:
61 | - Convolved audio signal of shape (num_channels, audio_len)
62 | """
63 |
64 | num_channels = rirs.shape[1]
65 | audio_len = source_audio.shape[0]
66 |
67 | # Perform convolution for each position and channel
68 | convolved_audios = oaconvolve(source_audio[None, None, :], rirs, axes=-1)[..., :audio_len]
69 |
70 | # NumPy fancy indexing and broadcasting for interpolation
71 | start_audio = convolved_audios[interp_index, np.arange(num_channels)[:, None], np.arange(audio_len)]
72 | end_audio = convolved_audios[interp_index + 1, np.arange(num_channels)[:, None], np.arange(audio_len)]
73 | interp_weight = interp_weight[None, :]
74 |
75 | # Apply linear interpolation
76 | moving_audio = (1 - interp_weight) * start_audio + interp_weight * end_audio
77 |
78 | return moving_audio
79 |
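# Example usage (illustrative sketch, not part of the original module): convolve a
# 1-second, 48 kHz source with RIRs at three receiver positions and interpolate
# along the path. Array shapes follow the docstrings above; the data is random.
if __name__ == '__main__':
    rng = np.random.default_rng(0)

    audio_len = 48000
    num_positions, num_channels, ir_length = 3, 2, 4800

    source_audio = rng.standard_normal(audio_len).astype(np.float32)
    rirs = rng.standard_normal((num_positions, num_channels, ir_length)).astype(np.float32)
    receiver_position = np.array([[0.0, 0.0, 0.0],
                                  [1.0, 0.0, 0.0],
                                  [1.0, 0.0, 1.0]])

    # Map each output sample to a segment of the path and a blending weight,
    # then render the moving-receiver audio.
    interp_index, interp_weight = setup_dynamic_interp(receiver_position, audio_len)
    moving_audio = convolve_moving_receiver(source_audio, rirs, interp_index, interp_weight)
    print(moving_audio.shape)  # (num_channels, audio_len)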
--------------------------------------------------------------------------------
/nvas3d/main.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import yaml
8 | import shutil
9 | import logging
10 | import argparse
11 |
12 | import torch
13 | import torch.multiprocessing as mp
14 | import torch.distributed as dist
15 | from torch.nn.parallel import DistributedDataParallel
16 |
17 | from nvas3d.train.trainer import Trainer
18 | from nvas3d.data_loader.data_loader import SSAVDataLoader
19 | from nvas3d.model.model import NVASNet
20 |
21 |
22 | def setup_distributed_training(local_rank):
23 | world_size = torch.cuda.device_count()
24 | dist.init_process_group(backend='nccl',
25 | init_method='tcp://localhost:12355',
26 | rank=local_rank,
27 | world_size=world_size)
28 | device = torch.device(f'cuda:{local_rank}')
29 | return device
30 |
31 |
32 | def main(local_rank, args):
33 | is_ddp = args.gpu is None
34 | device = setup_distributed_training(local_rank) if is_ddp else torch.device(f'cuda:{args.gpu}')
35 |
36 | if not is_ddp:
37 | torch.cuda.set_device(args.gpu)
38 |
39 | # Load and parse config file
40 | with open(args.config, 'r') as f:
41 | config = yaml.safe_load(f)
42 |
43 | # Create a directory and copy config
44 | save_dir = os.path.join(config['save_dir'], f'{args.exp}')
45 | os.makedirs(save_dir, exist_ok=True)
46 | shutil.copy(args.config, f'{save_dir}/config.yaml')
47 |
48 | # Initialize DataLoader
49 | data_loader = SSAVDataLoader(config['use_visual'], config['use_deconv'], is_ddp, **config['data_loader'])
50 |
51 | # Initialize and deploy model to device
52 | model = NVASNet(config['data_loader']['num_receivers'], config['use_visual'])
53 | model = model.to(device)
54 | if is_ddp:
55 | model = DistributedDataParallel(model, device_ids=[local_rank])
56 |
57 | # Train the model
58 | trainer = Trainer(model, data_loader, device, save_dir, config['use_deconv'], config['training'])
59 | trainer.train()
60 |
61 | if is_ddp:
62 | dist.destroy_process_group()
63 |
64 |
65 | if __name__ == '__main__':
66 | torch.manual_seed(42)
67 |
68 | logging.basicConfig(level=logging.INFO, format='%(asctime)s, %(levelname)s: %(message)s', datefmt="%Y-%m-%d %H:%M:%S")
69 |
70 | parser = argparse.ArgumentParser()
71 | parser.add_argument('--config',
72 | type=str,
73 | default='./nvas3d/config/default_config.yaml',
74 | help='Path to the configuration file.')
75 |
76 | parser.add_argument('--exp',
77 | type=str,
78 | default='default_exp',
79 | help='Experiment name')
80 |
81 | parser.add_argument('--gpu',
82 | type=int,
83 | default=None,
84 | help='GPU ID to use')
85 |
86 | args = parser.parse_args()
87 |
88 | if args.gpu is not None:
89 | main(0, args) # Single GPU mode
90 | else:
91 | mp.spawn(main, args=(args,), nprocs=torch.cuda.device_count(), join=True)
92 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Novel-View Acoustic Synthesis from 3D Reconstructed Rooms
2 |
3 | [[Paper](http://arxiv.org/abs/2310.15130)]
4 | [[Docs](/soundspaces_nvas3d/README.md)]
5 | [[Demo docs](/demo/README.md)]
6 | [[Video1](https://docs-assets.developer.apple.com/ml-research/models/nvas/nvas3d_turn.mp4)]
7 | [[Video2](https://docs-assets.developer.apple.com/ml-research/models/nvas/teaser.mp4)]
8 |
9 |
10 | > [Click on the thumbnail image](https://docs-assets.developer.apple.com/ml-research/models/nvas/teaser.mp4) below to watch a video showcasing our Novel-View Acoustic Synthesis.
11 | >
12 | > 🎧 _For the optimal experience, using a headset is recommended._
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | Welcome to the official code repository for "Novel-View Acoustic Synthesis from 3D Reconstructed Rooms".
21 | Given audio recordings from multiple microphones and the 3D geometry and materials of a scene, this project estimates the sound at any point in the scene, even when it contains multiple unknown sound sources, thereby enabling novel-view acoustic synthesis.
22 |
23 |
24 | ["Novel-View Acoustic Synthesis from 3D Reconstructed Rooms"](http://arxiv.org/abs/2310.15130)\
25 | [Byeongjoo Ahn](https://byeongjooahn.github.io),
26 | [Karren Yang](https://karreny.github.io),
27 | [Brian Hamilton](https://www.brianhamilton.co),
28 | [Jonathan Sheaffer](https://www.linkedin.com/in/jsheaffer/),
29 | [Anurag Ranjan](https://anuragranj.github.io),
30 | [Miguel Sarabia](https://scholar.google.co.uk/citations?user=U2mA-EAAAAAJ&hl=en),
31 | [Oncel Tuzel](https://www.onceltuzel.net),
32 | [Jen-Hao Rick Chang](https://rick-chang.github.io)
33 |
34 | ## Directory Structure
35 | ```yaml
36 | .
37 | ├── demo/ # Quickstart and demo
38 | │ ├── ...
39 | ├── nvas3d/ # Implementation of our model
40 | │ ├── ...
41 | └── soundspaces_nvas3d/ # SoundSpaces integration for NVAS3D
42 | ├── ...
43 | ```
44 |
45 | ## Installation: SoundSpaces
46 | Follow our [Step-by-Step Installation Guide](soundspaces_nvas3d/README.md) for rendering room impulse responses (RIRs) and images in Matterport3D rooms using SoundSpaces.
47 |
48 | ## Quickstart: Demo
49 | Refer to the [Demo Guide](demo/README.md) for instructions on data generation, dry sound estimation using our model, and novel-view acoustic rendering.
50 |
51 | ### Download the Pretrained Model
52 | Download [our pretrained model](https://docs-assets.developer.apple.com/ml-research/models/nvas/checkpoint_200.pt) and place it in the `nvas3d/assets/saved_models/default/checkpoints/` directory.
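For example, from the repository root (a minimal sketch assuming `curl` is available; any download method works):
```bash
mkdir -p nvas3d/assets/saved_models/default/checkpoints
curl -L -o nvas3d/assets/saved_models/default/checkpoints/checkpoint_200.pt \
    https://docs-assets.developer.apple.com/ml-research/models/nvas/checkpoint_200.pt
```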
53 |
54 | ### Launch the Demo
55 | To get started with the full pipeline quickly:
56 | ```bash
57 | bash demo/run_demo.sh
58 | ```
59 |
60 | ## Training
61 | After [Training Data Generation](nvas3d/utils/training_data_generation/README.md), start the training process with:
62 | ```bash
63 | python nvas3d/main.py --config ./nvas3d/config/default_config.yaml --exp default_exp
64 | ```
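
By default, training uses DistributedDataParallel across all visible GPUs (see `nvas3d/main.py`). To train on a single GPU instead, pass the `--gpu` flag, for example:
```bash
python nvas3d/main.py --config ./nvas3d/config/default_config.yaml --exp default_exp --gpu 0
```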
65 |
66 | ## Acknowledgements
67 | We thank Dirk Schroeder and David Romblom for insightful discussions and feedback, and Changan Chen for his assistance with SoundSpaces.
68 |
69 |
78 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/rir_generation/generate_grid.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import argparse
8 |
9 | import numpy as np
10 | import torch
11 | import typing as T
12 |
13 | from soundspaces_nvas3d.soundspaces_nvas3d import Scene
14 |
15 | # Set rooms for which grids need to be generated
16 | ROOM_LIST = [
17 | '17DRP5sb8fy'
18 | ]
19 |
20 |
21 | def save_xy_grid_points(room: str,
22 | grid_distance: float,
23 | dirname: str
24 | ) -> T.Dict[str, T.Any]:
25 | """
26 | Save xy grid points given a mp3d room
27 | """
28 |
29 | filename_npy = f'{dirname}/grid_{room}.npy'
30 | filename_png = f'{dirname}/grid_{room}.png'
31 |
32 | scene = Scene(
33 | room,
34 | [None], # placeholder for source class
35 | include_visual_sensor=False,
36 | add_source_mesh=False,
37 | device=torch.device('cpu')
38 | )
39 | grid_points = scene.generate_xy_grid_points(grid_distance, filename_png=filename_png)
40 | room_size = scene.sim.pathfinder.navigable_area
41 |
42 | grid_info = dict(
43 | grid_points=grid_points,
44 | room_size=room_size,
45 | grid_distance=grid_distance,
46 | )
47 | np.save(filename_npy, grid_info)
48 |
49 | return grid_info
50 |
51 |
52 | def calculate_statistics(data_dict):
53 | values = list(data_dict.values())
54 | average = np.mean(values)
55 | minimum = np.min(values)
56 | maximum = np.max(values)
57 | return average, minimum, maximum
58 |
59 |
60 | def main(args):
61 | # Set grid distance
62 | grid_distance = args.grid_distance
63 | grid_distance_str = str(grid_distance).replace(".", "_")
64 |
65 | # Generate grid points
66 | dirname = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}'
67 | os.makedirs(dirname, exist_ok=True)
68 |
69 | # Define lists to store data
70 | num_points_dict = {}
71 | room_size_dict = {}
72 |
73 | for room in ROOM_LIST:
74 | # Note: mp3d room should be under data/scene_datasets/mp3d/{room}
75 | grid_info = save_xy_grid_points(room, grid_distance, dirname)
76 |
77 | # Append data to lists
78 | num_points_dict[room] = grid_info['grid_points'].shape[0]
79 | room_size_dict[room] = grid_info['room_size']
80 |
81 | # Calculate statistics
82 | num_points_average, num_points_min, num_points_max = calculate_statistics(num_points_dict)
83 | room_size_average, room_size_min, room_size_max = calculate_statistics(room_size_dict)
84 |
85 | # Save statistics to txt file
86 |     filename_statistics = f'data/scene_datasets/metadata/mp3d/grid_{grid_distance_str}/statistics.txt'
87 |     with open(filename_statistics, 'w') as file:
88 | file.write('[Number of points statistics]:\n')
89 | file.write(f'Average: {num_points_average:.2f}\n')
90 | file.write(f'Minimum: {num_points_min}\n')
91 | file.write(f'Maximum: {num_points_max}\n')
92 | file.write('\n')
93 | file.write('[Room size statistics]:\n')
94 | file.write(f'Average: {room_size_average:.2f}\n')
95 | file.write(f'Minimum: {room_size_min:.2f}\n')
96 | file.write(f'Maximum: {room_size_max:.2f}\n')
97 |
98 | file.write('\n')
99 | file.write('[Room-wise statistics]:\n')
100 | for room in ROOM_LIST:
101 | file.write(f'{room} - Num Points: {num_points_dict[room]}, \t Room Size: {room_size_dict[room]:.2f}\n')
102 |
103 |
104 | if __name__ == '__main__':
105 | parser = argparse.ArgumentParser()
106 |
107 | parser.add_argument('--grid_distance',
108 | default=1.0,
109 | type=float,
110 | help='distance between grid points')
111 |
112 | args = parser.parse_args()
113 |
114 | main(args)
115 |
--------------------------------------------------------------------------------
/nvas3d/utils/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | # Portions of this code are derived from VIDA (CC-BY-NC).
6 | # Original work available at: https://github.com/facebookresearch/learning-audio-visual-dereverberation/tree/main
7 |
8 | import typing as T
9 | import numpy as np
10 |
11 | import torch
12 | import torch.nn.functional as F
13 |
14 | source_class_map = {
15 | 'female': 0, # speech: 22961
16 | 'male': 1,
17 | 'Bass': 2, # 51678
18 | 'Brass': 3, # 7216
19 | 'Chromatic Percussion': 4, # 4154
20 | 'Drums': 5, # 47291
21 | 'Guitar': 6, # 75848
22 | 'Organ': 7, # 10438
23 | 'Piano': 8, # 59328
24 | 'Pipe': 9, # 8200
25 | 'Reed': 10, # 10039
26 | 'Strings': 11, # 3169
27 | 'Strings (continued)': 12, # 48429
28 | 'Synth Lead': 13, # 7053
29 | 'Synth Pad': 14 # 11210
30 | }
31 |
32 | MP3D_SCENE_SPLITS = {
33 | 'demo': ['17DRP5sb8fy'],
34 | }
35 |
36 |
37 | def get_key_from_value(my_dict, target_value):
38 | for key, value in my_dict.items():
39 | if value == target_value:
40 | return key
41 | return None
42 |
43 |
44 | def parse_librispeech_metadata(filename: str) -> T.Dict:
45 | """
46 | Reads LibriSpeech metadata from a csv file and returns a dictionary.
47 | Each entry in the dictionary maps a reader_id (as integer) to its corresponding gender.
48 | """
49 |
50 | import csv
51 |
52 | # Dictionary to store reader_id and corresponding gender
53 | librispeech_metadata = {}
54 |
55 | with open(filename, 'r') as file:
56 | reader = csv.reader(file, delimiter='|')
57 | for row in reader:
58 | # Skip comment lines and header
59 | if row[0].startswith(';') or row[0].strip() == 'ID':
60 | continue
61 | reader_id = int(row[0]) # Convert string to integer
62 | sex = row[1].strip() # Remove extra spaces
63 | librispeech_metadata[reader_id] = sex
64 |
65 | return librispeech_metadata
66 |
67 |
68 | def parse_ir(filename_ir: str):
69 | """
70 | Extracts the room, source index and receiver index information from an IR filename.
71 | The function assumes that these elements are present at the end of the filename, separated by underscores.
72 | """
73 |
74 | parts = filename_ir.split('_')
75 | room = parts[-3]
76 | source_idx = int(parts[-2])
77 | receiver_idx = int(parts[-1].split('.')[0])
78 |
79 | return room, source_idx, receiver_idx
80 |
81 |
82 | def count_parameters(model):
83 | """
84 | Returns the number of trainable parameters in a given PyTorch model.
85 | """
86 |
87 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
88 |
89 |
90 | ####################
91 | # From Changan's code
92 | ####################
93 |
94 | def complex_norm(
95 | complex_tensor: torch.Tensor,
96 | power: float = 1.0
97 | ) -> torch.Tensor:
98 | """
99 | Compute the norm of complex tensor input.
100 | """
101 |
102 | # Replace by torch.norm once issue is fixed
103 | # https://github.com/pytorch/pytorch/issues/34279
104 | return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)
105 |
106 |
107 | def normalize(audio, norm='peak'):
108 | if norm == 'peak':
109 | peak = abs(audio).max()
110 | if peak != 0:
111 | return audio / peak
112 | else:
113 | return audio
114 | elif norm == 'rms':
115 | if torch.is_tensor(audio):
116 | audio = audio.numpy()
117 | audio_without_padding = np.trim_zeros(audio, trim='b')
118 | rms = np.sqrt(np.mean(np.square(audio_without_padding))) * 100
119 | if rms != 0:
120 | return audio / rms
121 | else:
122 | return audio
123 | else:
124 | raise NotImplementedError
125 |
126 |
127 | def overlap_chunk(input, dimension, size, step, left_padding):
128 | """
129 | Input shape is [Frequency bins, Frame numbers]
130 | """
131 | input = F.pad(input, (left_padding, size), 'constant', 0)
132 | return input.unfold(dimension, size, step)
133 |
--------------------------------------------------------------------------------
/nvas3d/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | import torch
10 | import torchaudio
11 | from torchmetrics.audio import SignalDistortionRatio
12 |
13 | from nvas3d.utils.audio_utils import psnr, clip_two
14 |
15 |
16 | def save_specgram(filename, stft):
17 | mag = torch.sqrt(stft[..., 0]**2 + stft[..., 1]**2)
18 | fig, ax = plt.subplots(1, 1, figsize=(14, 4))
19 | cax1 = ax.imshow(torch.log10(mag), aspect='auto', origin='lower')
20 | fig.colorbar(cax1, ax=ax, orientation='horizontal')
21 | ax.set_title('Magnitude spectrogram')
22 |
23 | plt.tight_layout()
24 | plt.savefig(filename)
25 | plt.close(fig)
26 |
27 |
28 | def plot_specgram(filename, waveform, sample_rate, plot_phase=False):
29 |
30 | mag = compute_spectrogram(waveform, use_mag=True).squeeze()
31 | phase = compute_spectrogram(waveform, use_phase=True).squeeze()
32 |
33 | phase = phase * (180 / np.pi)
34 |
35 | if plot_phase:
36 | fig, ax = plt.subplots(2, 1, figsize=(14, 8))
37 | cax1 = ax[0].imshow(torch.log10(mag), aspect='auto', origin='lower')
38 | fig.colorbar(cax1, ax=ax[0], orientation='horizontal')
39 | ax[0].set_title('Magnitude spectrogram')
40 |
41 | cax2 = ax[1].imshow(phase, aspect='auto', origin='lower', cmap='twilight')
42 | fig.colorbar(cax2, ax=ax[1], orientation='horizontal')
43 | ax[1].set_title('Phase spectrogram')
44 | else:
45 | fig, ax = plt.subplots(1, 1, figsize=(14, 4))
46 | cax1 = ax.imshow(torch.log10(mag), aspect='auto', origin='lower')
47 | fig.colorbar(cax1, ax=ax, orientation='horizontal')
48 | ax.set_title('Magnitude spectrogram')
49 |
50 | plt.tight_layout()
51 | plt.savefig(filename)
52 | plt.close(fig)
53 |
54 |
55 | def plot_waveform(filename, waveform, sample_rate):
56 | plt.figure(figsize=(14, 4))
57 | plt.plot(np.linspace(0, len(waveform[0]) / sample_rate, len(waveform[0])), waveform[0])
58 | plt.title('Waveform')
59 | plt.xlabel('Time [s]')
60 | plt.savefig(filename)
61 | plt.close()
62 |
63 |
64 | def plot_debug(filename, waveform, sample_rate, save_png=False, save_normalized=False, reference=None):
65 | waveform = waveform.reshape(1, -1)
66 |
67 | if reference is not None:
68 | reference = reference.reshape(1, -1)
69 | waveform, reference = clip_two(waveform, reference)
70 | # assert waveform.shape[-1] == reference.shape[-1]
71 | psnr_score = psnr(reference.detach().cpu(), waveform.detach().cpu())
72 | compute_sdr = SignalDistortionRatio()
73 | sdr = compute_sdr(reference.detach().cpu(), waveform.detach().cpu())
74 |
75 | if save_normalized:
76 | waveform /= waveform.abs().max()
77 |
78 | if save_png:
79 | plot_specgram(f'{filename}_specgram.png', waveform, sample_rate)
80 | plot_waveform(f'{filename}_waveform.png', waveform, sample_rate)
81 |
82 | if reference is not None:
83 | filename = f'{filename}({psnr_score:.1f})({sdr:.1f})'
84 | torchaudio.save(f'{filename}.wav', waveform, sample_rate)
85 |
86 |
87 | def compute_spectrogram(audio_data, n_fft=2048, hop_length=480, win_length=1200, use_mag=False, use_phase=False, use_complex=False):
88 |
89 | audio_data = to_tensor(audio_data)
90 | stft = torch.stft(audio_data, n_fft=n_fft, hop_length=hop_length, win_length=win_length,
91 | window=torch.hamming_window(win_length, device=audio_data.device), pad_mode='constant',
92 | return_complex=not use_complex)
93 |
94 | if use_mag:
95 | spectrogram = stft.abs().unsqueeze(-1)
96 | elif use_phase:
97 | spectrogram = stft.angle().unsqueeze(-1)
98 | elif use_complex:
99 | # one channel for real and one channel for imaginary
100 | spectrogram = stft
101 | else:
102 | raise ValueError
103 |
104 | return spectrogram
105 |
106 |
107 | def to_tensor(v):
108 | if torch.is_tensor(v):
109 | return v
110 | elif isinstance(v, np.ndarray):
111 | return torch.from_numpy(v)
112 | else:
113 | return torch.tensor(v, dtype=torch.float)
114 |
--------------------------------------------------------------------------------
/demo/README.md:
--------------------------------------------------------------------------------
1 | # Novel-View Acoustic Synthesis from 3D Reconstructed Rooms: Demo
2 |
3 | This guide covers the demonstration of NVAS from 3D reconstructed rooms. It provides instructions for data generation, dry sound estimation using our model, and novel-view acoustic rendering.
4 |
5 | ## Quick Start
6 |
7 | ### Download the Pretrained Model
8 | Download [our pretrained model](https://docs-assets.developer.apple.com/ml-research/models/nvas/checkpoint_200.pt) and place it in the `nvas3d/assets/saved_models/default/checkpoints/` directory.
13 |
14 | ### Launch the Demo
15 | To get started with the full pipeline quickly:
16 | ```bash
17 | bash demo/run_demo.sh
18 | ```
19 |
20 | ## Detailed Workflow
21 | For a more detailed approach, you can explore each segment of the workflow individually:
22 |
23 | ### 1. Data Generation
24 | Generate and save the demo data for a specific room. Please ensure you've installed SoundSpaces and downloaded the necessary data (the sample Matterport3D room `17DRP5sb8fy` and the material config) by following the instructions in [soundspaces_nvas3d/README.md](../soundspaces_nvas3d/README.md).
25 |
26 | ```bash
27 | python demo/generate_demo_data.py --room 17DRP5sb8fy
28 | ```
29 |
30 | The generated data will be structured as:
31 | ```yaml
32 | data/nvas3d_demo/demo/{room}/0
33 | ├── receiver/ # Receiver audio.
34 | ├── wiener/ # Deconvolved audio to accelerate tests.
35 | ├── source/ # Ground truth dry audio.
36 | ├── reverb1/ # Ground truth reverberant audio for source 1.
37 | ├── reverb2/ # Ground truth reverberant audio for source 2.
38 | ├── ir_receiver/ # Ground truth RIRs from source to receiver.
39 | ├── config.png # Visualization of room configuration.
40 | └── metadata.pt # Metadata (source indices, classes, grid points, and room info).
41 | ```
42 |
43 | Additionally, visualizations of room indices will be located at `data/nvas3d_demo/{room}/index_{grid_distance}.png`. If they are missing, please ensure that [headless rendering in Habitat-Sim](https://github.com/facebookresearch/habitat-sim/issues?q=headless) is configured correctly.
44 |
45 | > [!Note]
46 | > If you want to modify the scene configuration (e.g., source locations, receiver locations), edit `demo/config_demo/scene1_17DRP5sb8fy.json` (its structure is shown below).
47 |
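For reference, the scene configuration shipped with the demo (`demo/config_demo/scene1_17DRP5sb8fy.json`) simply lists source and receiver grid indices:
```json
{
    "source_idx_list": [14, 43],
    "receiver_idx_list": [23, 28, 29, 24]
}
```
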
48 | ### 2. Dry Sound Estimation Using Our Model
49 | Run the NVAS3D model on your dataset:
50 | ```bash
51 | python demo/test_demo.py
52 | ```
53 |
54 | The results will be structured as:
55 | ```yaml
56 | {results_demo} = results/nvas3d_demo/default/demo/{room}/0
57 | │
58 | ├── results_drysound/
59 | │ ├── dry1_estimated.wav # Estimated dry sound for source 1 location.
60 | │ ├── dry2_estimated.wav # Estimated dry sound for source 2 location.
61 | │ ├── dry1_gt.wav # Ground-truth dry sound for source 1.
62 | │ ├── dry2_gt.wav # Ground-truth dry sound for source 2.
63 | │ └── detected/
64 | │ └── dry_{query_idx}.wav # Estimated dry sound for query idx if positive.
65 | │
66 | ├── baseline_dsp/
67 | │ ├── deconv_and_sum1.wav # Baseline result of dry sound for source 1.
68 | │ └── deconv_and_sum2.wav # Baseline result of dry sound for source 2.
69 | │
70 | ├── results_detection/
71 | │ ├── detection_heatmap.png # Heatmap visualizing source detection results.
72 | │ └── metadata.pt # Detection results and metadata.
73 | │
74 | └── metrics.json # Quantitative metrics.
75 |
76 | ```
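
To read the quantitative results programmatically, a short sketch (assuming `metrics.json` is a flat JSON object; no specific keys are assumed):
```python
import json

results_demo = 'results/nvas3d_demo/default/demo/17DRP5sb8fy/0'

with open(f'{results_demo}/metrics.json') as f:
    metrics = json.load(f)

# Print whatever metrics the demo produced.
for name, value in metrics.items():
    print(f'{name}: {value}')
```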
77 |
78 | ### 3. Novel-view Acoustic Rendering
79 | Render the novel-view video integrated with the sound:
80 |
81 | ```bash
82 | python demo/generate_demo_video.py
83 | ```
84 |
85 | The video results will be structured as:
86 | ```yaml
87 | results/nvas3d_demo/default/demo/{room}/0
88 | └── video/
89 | ├── moving_audio.wav # Audio only: Interpolated sound for the moving receiver.
90 | ├── moving_audio_1.wav # Audio only: Interpolated sound from separated Source 1 for the moving receiver.
91 | ├── moving_audio_2.wav # Audio only: Interpolated sound from separated Source 2 for the moving receiver.
92 | ├── moving_video.mp4 # Video only: Interpolated video for the moving receiver.
93 | ├── nvas.mp4 # Video with Audio: NVAS video results with combined audio from all sources.
94 | ├── nvas_source1.mp4 # Video with Audio: NVAS video results for separated Source 1 audio.
95 | ├── nvas_source2.mp4 # Video with Audio: NVAS video results for separated Source 2 audio.
96 | └── rgb_receiver.png # Image: A static view rendered from the receiver's perspective for reference.
97 | ```
98 |
99 | > [!Note]
100 | > If you want to modify the novel receiver's path, edit the path config in `demo/config_demo/` (e.g., `path1_17DRP5sb8fy.json`).
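
To inspect the rendered audio tracks directly (for example, before playing the videos), a small sketch using the filenames listed above:
```python
import torchaudio

video_dir = 'results/nvas3d_demo/default/demo/17DRP5sb8fy/0/video'

for name in ['moving_audio.wav', 'moving_audio_1.wav', 'moving_audio_2.wav']:
    audio, sample_rate = torchaudio.load(f'{video_dir}/{name}')
    duration = audio.shape[-1] / sample_rate
    print(f'{name}: {audio.shape[0]} channel(s), {duration:.1f} s at {sample_rate} Hz')
```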
101 |
--------------------------------------------------------------------------------
/ACKNOWLEDGEMENTS:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this software may utilize the following copyrighted material, the use of which is hereby acknowledged.
3 |
4 | _____________________
5 |
6 | Facebook, Inc (PyTorch)
7 | Copyright (c) 2016- Facebook, Inc (Adam Paszke)
8 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
9 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
10 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
11 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
12 | Copyright (c) 2011-2013 NYU (Clement Farabet)
13 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
14 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
15 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
16 |
17 | All rights reserved.
18 |
19 | Redistribution and use in source and binary forms, with or without
20 | modification, are permitted provided that the following conditions are met:
21 |
22 | 1. Redistributions of source code must retain the above copyright
23 | notice, this list of conditions and the following disclaimer.
24 |
25 | 2. Redistributions in binary form must reproduce the above copyright
26 | notice, this list of conditions and the following disclaimer in the
27 | documentation and/or other materials provided with the distribution.
28 |
29 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
30 | and IDIAP Research Institute nor the names of its contributors may be
31 | used to endorse or promote products derived from this software without
32 | specific prior written permission.
33 |
34 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
35 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
36 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
37 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
38 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
39 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
40 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
41 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
42 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
43 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
44 | POSSIBILITY OF SUCH DAMAGE.
45 |
46 | Meta Platforms, Inc (Habitat-Sim)
47 | MIT License
48 |
49 | Copyright (c) Meta Platforms, Inc. and its affiliates.
50 |
51 | Permission is hereby granted, free of charge, to any person obtaining a copy
52 | of this software and associated documentation files (the "Software"), to deal
53 | in the Software without restriction, including without limitation the rights
54 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
55 | copies of the Software, and to permit persons to whom the Software is
56 | furnished to do so, subject to the following conditions:
57 |
58 | The above copyright notice and this permission notice shall be included in all
59 | copies or substantial portions of the Software.
60 |
61 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
62 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
63 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
64 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
65 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
66 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
67 | SOFTWARE.
68 |
69 | Meta Platforms, Inc (Habitat-Lab)
70 | MIT License
71 |
72 | Copyright (c) Meta Platforms, Inc. and its affiliates.
73 |
74 | Permission is hereby granted, free of charge, to any person obtaining a copy
75 | of this software and associated documentation files (the "Software"), to deal
76 | in the Software without restriction, including without limitation the rights
77 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
78 | copies of the Software, and to permit persons to whom the Software is
79 | furnished to do so, subject to the following conditions:
80 |
81 | The above copyright notice and this permission notice shall be included in all
82 | copies or substantial portions of the Software.
83 |
84 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
85 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
86 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
87 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
88 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
89 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
90 | SOFTWARE.
91 |
92 | Meta Platforms, Inc (SoundSpaces)
93 | Attribution 4.0 International
94 |
95 | Meta Platforms, Inc (VisualVoice)
96 | Attribution-NonCommercial 4.0 International
97 |
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/script_slakh.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | # Script to clip and split Slakh2100 dataset: https://zenodo.org/record/4599666
6 | #
7 |
8 | import os
9 | import yaml
10 | import shutil
11 | import numpy as np
12 |
13 | import torchaudio
14 | from torchaudio.transforms import Resample
15 |
16 |
17 | # Set the silence threshold
18 | THRESHOLD = 1e-2
19 |
20 |
21 | def read_instruments(source_dir):
22 | """
23 | Read instrument metadata from source directory
24 | """
25 |
26 | inst_dict = {} # dictionary to store instrument data
27 | # Loop through every file in the source directory
28 | for root, dirs, files in os.walk(source_dir):
29 | for file in files:
30 | if file.endswith("metadata.yaml"):
31 | with open(os.path.join(root, file), 'r') as yaml_file:
32 | metadata = yaml.safe_load(yaml_file)
33 |
34 | # Add instrument name and class to dictionary
35 | for stem, stem_data in metadata['stems'].items():
36 | if stem_data['midi_program_name'] not in inst_dict.keys():
37 | inst_dict[stem_data['midi_program_name']] = stem_data['inst_class']
38 | return inst_dict
39 |
40 |
41 | def copy_files_based_on_metadata(source_dir, target_dir, query):
42 | """
43 | Copy files based on metadata
44 | """
45 |
46 | # Walk through each file in the source directory
47 | for root, dirs, files in os.walk(source_dir):
48 | for file in files:
49 | if file.endswith("metadata.yaml"):
50 | with open(os.path.join(root, file), 'r') as yaml_file:
51 | try:
52 | metadata = yaml.safe_load(yaml_file)
53 |
54 | # If instrument matches query, copy the associated flac file to target directory
55 | for stem, stem_data in metadata['stems'].items():
56 | if stem_data['midi_program_name'] == query:
57 | source_flac_file = os.path.join(root.replace('metadata.yaml', ''), "stems", f"{stem}.flac")
58 | target_flac_file = source_flac_file.replace(source_dir, target_dir)
59 |
60 | os.makedirs(os.path.dirname(target_flac_file), exist_ok=True)
61 |
62 | # Copy the .flac file
63 | if os.path.exists(source_flac_file):
64 | shutil.copyfile(source_flac_file, target_flac_file)
65 |
66 | except yaml.YAMLError as exc:
67 | print(exc)
68 |
69 |
70 | def save_clips(waveform, output_path, sample_rate, min_length, max_length):
71 | """
72 | Save clips from a given waveform
73 | """
74 |
75 | # Calculate clip lengths in samples
76 | min_length_samples = int(min_length * sample_rate)
77 | max_length_samples = int(max_length * sample_rate)
78 |
79 | # Get total number of samples in the waveform
80 | total_samples = waveform.shape[1]
81 |
82 | start = 0
83 | end = np.random.randint(min_length_samples, max_length_samples + 1)
84 | clip_number = 1
85 |
86 | # Keep creating clips until we've covered the whole waveform
87 | while end <= total_samples:
88 | # Slice the waveform to get the clip
89 | clip = waveform[:, start:end]
90 |
91 |         # Keep the clip only if it is not silent (mean absolute amplitude above THRESHOLD)
92 | if abs(clip).mean() > THRESHOLD:
93 | # Save the clip
94 | output_clip_path = f"{output_path.rsplit('.', 1)[0]}_clip_{clip_number}.flac"
95 | torchaudio.save(output_clip_path, clip, sample_rate)
96 |
97 | # Increment the clip number
98 | clip_number += 1
99 |
100 | # Update the start and end for the next clip
101 | start = end
102 | end = start + np.random.randint(min_length_samples, max_length_samples + 1)
103 |
104 |
105 | def process_directory(input_dir, output_dir, sample_rate=48000, min_length=6, max_length=10):
106 | """
107 | Process a directory of audio files
108 | """
109 |
110 | # Walk through each file in the input directory
111 | for dirpath, dirnames, filenames in os.walk(input_dir):
112 | for filename in filenames:
113 | if filename.endswith(".flac"):
114 | input_file_path = os.path.join(dirpath, filename)
115 | waveform_, sr = torchaudio.load(input_file_path)
116 |
117 | # Resample the audio
118 | upsample_transform = Resample(sr, sample_rate)
119 | waveform = upsample_transform(waveform_)
120 |
121 | relative_path = os.path.relpath(dirpath, input_dir)
122 | clip_output_dir = os.path.join(output_dir, relative_path)
123 |
124 | os.makedirs(clip_output_dir, exist_ok=True)
125 |
126 | output_file_path = os.path.join(clip_output_dir, filename)
127 |
128 | # set audio length
129 | split = dirpath.split('/')[-3]
130 | if split == 'test':
131 | min_length = 20
132 | max_length = 21
133 | else:
134 | min_length = 6
135 | max_length = 7
136 | # Save clips from the resampled audio
137 | save_clips(waveform, output_file_path, sample_rate, min_length, max_length)
138 |
139 |
140 | # Read instruments
141 | source_dir = 'data/source/slakh2100_flac_redux'
142 | inst_dict = read_instruments(source_dir)
143 | print(inst_dict)
144 |
145 | # Copy query instruments
146 | os.makedirs('data/MIDI', exist_ok=True)
147 | os.makedirs('data/MIDI/full', exist_ok=True)
148 |
149 | # Loop through each instrument in the dictionary
150 | for query, inst_class in inst_dict.items():
151 | os.makedirs(f'data/MIDI/full/{inst_class}', exist_ok=True)
152 | target_dir = f'data/MIDI/full/{inst_class}/{query}'
153 | print(target_dir)
154 |
155 | # Copy each instrument file to the target directory
156 | for subdir in ['train', 'test', 'validation']:
157 | copy_files_based_on_metadata(os.path.join(source_dir, subdir), os.path.join(target_dir, subdir), query)
158 | print('Copy done!')
159 |
160 | # Clip the copied audio files
161 | os.makedirs('data/MIDI/clip', exist_ok=True)
162 | for query, inst_class in inst_dict.items():
163 | os.makedirs(f'data/MIDI/clip/{inst_class}', exist_ok=True)
164 | target_dir = f'data/MIDI/full/{inst_class}/{query}'
165 | clip_dir = f'data/MIDI/clip/{inst_class}/{query}'
166 | print(clip_dir)
167 |
168 | process_directory(target_dir, clip_dir)
169 | print('Clip done!')
170 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/image_rendering/generate_target_image.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import numpy as np
8 | import glob
9 | import torch
10 | import matplotlib.pyplot as plt
11 | from soundspaces_nvas3d.soundspaces_nvas3d import Receiver
12 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
13 | from soundspaces_nvas3d.utils.ss_utils import create_scene
14 | import argparse
15 |
16 |
17 | def contains_png_files(dirname):
18 | return len(glob.glob(os.path.join(dirname, "*.png"))) > 0
19 |
20 |
21 | def render_target_images(args):
22 | room = args.room
23 | grid_distance = args.grid_distance
24 | image_size = (192, 144)
25 |
26 |     # Load grid points for the room
27 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points']
28 | grid_distance_str = str(grid_distance).replace(".", "_")
29 | if args.dirname is None:
30 | dirname = f'data/examples/target_image_mp3d/grid_{grid_distance_str}/{room}'
31 | else:
32 | dirname = args.dirname
33 | os.makedirs(dirname, exist_ok=True)
34 |
35 | # # Return if there is a png file already
36 | # if contains_png_files(dirname):
37 | # return
38 |
39 | scene = create_scene(room, image_size=image_size)
40 | sample_rate = 48000
41 | position = [0.0, 0, 0]
42 | rotation = 0.0
43 | receiver = Receiver(position, rotation, sample_rate)
44 |
45 | # Render image from 1m distance
46 | dist = 1.0
47 | # south east north west
48 | position_offset_list = [[0.0, 0.0, dist], [dist, 0.0, 0.0], [0.0, 0.0, -dist], [-dist, 0.0, 0.0]]
49 | rotation_offset_list = [0.0, 90.0, 180.0, 270.0]
50 |
51 | source_class_list = ['female', 'male', 'guitar']
52 | # source_class_list = ['guitar']
53 |
54 | if args.source_idx_list is None:
55 | source_idx_list = range(grid_points.shape[0])
56 | else:
57 | source_idx_list = list(map(int, args.source_idx_list.split()))
58 |
59 | for source_idx in source_idx_list:
60 | # print(f'{room}: {source_idx}/{len(source_idx_list)}')
61 | source_position = grid_points[source_idx]
62 |
63 | rgb = []
64 | depth = []
65 |
66 | for position_offset, rotation_offset in zip(position_offset_list, rotation_offset_list):
67 | receiver.position = source_position + torch.tensor(position_offset)
68 | receiver.rotation = rotation_offset
69 | scene.update_receiver(receiver)
70 | rgb_, depth_ = scene.render_image()
71 | rgb.append(rgb_)
72 | depth.append(depth_)
73 |
74 | rgb = np.concatenate(rgb, axis=1)
75 | depth = np.concatenate(depth, axis=1)
76 |
77 | filename_rgb_png = f'{dirname}/rgb_{room}_{source_idx}.png'
78 | filename_depth_png = f'{dirname}/depth_{room}_{source_idx}.png'
79 |
80 | plt.imsave(filename_rgb_png, rgb)
81 | plt.imsave(filename_depth_png, depth)
82 |
83 | filename_rgb_npy = f'{dirname}/rgb_{room}_{source_idx}.npy'
84 | filename_depth_npy = f'{dirname}/depth_{room}_{source_idx}.npy'
85 |
86 | np.save(filename_rgb_npy, rgb)
87 | np.save(filename_depth_npy, depth)
88 |
89 | # Optional: Render with mesh. Requires mesh data for habitat under data/objects/{source.mesh}
90 | # source_female = Source(
91 | # position=position,
92 | # rotation=random.uniform(0, 360),
93 | # dry_sound='',
94 | # mesh='female',
95 | # device=torch.device('cpu')
96 | # )
97 |
98 | # source_male = Source(
99 | # position=position,
100 | # rotation=random.uniform(0, 360),
101 | # dry_sound='',
102 | # mesh='male',
103 | # device=torch.device('cpu')
104 | # )
105 |
106 | # source_guitar = Source(
107 | # position=position,
108 | # rotation=random.uniform(0, 360),
109 | # dry_sound='',
110 | # mesh='guitar',
111 | # device=torch.device('cpu')
112 | # )
113 |
114 | # source_list = [source_female, source_male, source_guitar]
115 | # # source_list = [source_guitar]
116 |
117 | # for source_class, source in zip(source_class_list, source_list):
118 | # scene.add_source_mesh = True
119 | # scene.source_list = [None]
120 | # source.position = source_position
121 |
122 | # scene.update_source(source, 0)
123 | # rgb = []
124 | # depth = []
125 |
126 | # for position_offset, rotation_offset in zip(position_offset_list, rotation_offset_list):
127 | # receiver.position = source_position + torch.tensor(position_offset)
128 | # receiver.rotation = rotation_offset
129 | # scene.update_receiver(receiver)
130 | # rgb_, depth_ = scene.render_image()
131 | # rgb.append(rgb_)
132 | # depth.append(depth_)
133 |
134 | # scene.sim.get_rigid_object_manager().remove_all_objects()
135 |
136 | # rgb = np.concatenate(rgb, axis=1)
137 | # depth = np.concatenate(depth, axis=1)
138 |
139 | # filename_rgb_png = f'{dirname}/rgb_{room}_{source_class}_{source_idx}.png'
140 | # filename_depth_png = f'{dirname}/depth_{room}_{source_class}_{source_idx}.png'
141 |
142 | # plt.imsave(filename_rgb_png, rgb)
143 | # plt.imsave(filename_depth_png, depth)
144 |
145 | # filename_rgb_npy = f'{dirname}/rgb_{room}_{source_class}_{source_idx}.npy'
146 | # filename_depth_npy = f'{dirname}/depth_{room}_{source_class}_{source_idx}.npy'
147 |
148 | # np.save(filename_rgb_npy, rgb)
149 | # np.save(filename_depth_npy, depth)
150 |
151 |
152 | if __name__ == '__main__':
153 | parser = argparse.ArgumentParser()
154 |
155 | parser.add_argument('--room',
156 | default='17DRP5sb8fy',
157 | type=str,
158 | help='mp3d room')
159 |
160 | parser.add_argument('--source_idx_list',
161 | default=None,
162 | type=str,
163 | help='source_idx_list')
164 |
165 | parser.add_argument('--dirname',
166 | default=None,
167 | type=str,
168 | help='dirname')
169 |
170 | parser.add_argument('--grid_distance',
171 | default=1.0,
172 | type=float,
173 | help='distance between grid points')
174 |
175 | args = parser.parse_args()
176 |
177 | render_target_images(args)
178 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/README.md:
--------------------------------------------------------------------------------
1 | # SoundSpaces for NVAS3D
2 | This guide provides a step-by-step process to set up and generate data using SoundSpaces for NVAS3D. You can also quickly start the installation process by running the `setup.sh` script included in this guide.
3 |
4 | ## Prerequisites for [SoundSpaces](https://github.com/facebookresearch/sound-spaces)
5 | - Ubuntu 20.04 or a similar Linux distribution
6 | - CUDA
7 | - [Conda](https://docs.conda.io/en/latest/miniconda.html)
8 |
9 |
10 | ## Installation
11 | Here we repeat the installation steps from [SoundSpaces Installation Guide](https://github.com/facebookresearch/sound-spaces/blob/main/INSTALLATION.md).
12 |
13 | ### 0. Update `PYTHONPATH`
14 | Ensure your `PYTHONPATH` is updated:
15 | ```bash
16 | cd ml-nvas3d # the root of the repository
17 | export PYTHONPATH=$PYTHONPATH:$(pwd)
18 | ```
19 |
20 | ### 1. Install SoundSpaces Dependencies
21 | Install required dependencies for SoundSpaces:
22 | ```bash
23 | apt-get update && apt-get upgrade -y && \
24 | apt-get install -y --no-install-recommends libjpeg-dev libglm-dev libgl1-mesa-glx libegl1-mesa-dev mesa-utils xorg-dev freeglut3-dev
25 | ```
26 |
27 | ### 2. Update Conda
28 | Ensure that your Conda installation is up to date:
29 | ```bash
30 | conda update -n base -c defaults conda
31 | ```
32 |
33 | ### 3. Create and Activate Conda Environment
34 | Create a new Conda environment named nvas3d with Python 3.7 and cmake 3.14.0:
35 | ```bash
36 | conda create -n nvas3d python=3.7 cmake=3.14.0 -y && \
37 | conda activate nvas3d
38 | ```
39 |
40 | ### 4. Install PyTorch
41 | Install PyTorch, torchvision, torchaudio, and the CUDA toolkit. Replace the version accordingly with your CUDA version:
42 | ```bash
43 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
44 | ```
45 |
46 | ### 5. Install [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/tree/main)
47 | Install Habitat-Sim:
48 | ```bash
49 | # step outside the repo root and git clone Habitat-Sim
50 | cd .. && \
51 | git clone https://github.com/facebookresearch/habitat-sim.git && \
52 | cd habitat-sim && \
53 | pip install -r requirements.txt && \
54 | git checkout RLRAudioPropagationUpdate && \
55 | python setup.py install --headless --audio --with-cuda
56 | ```
57 | This guide is based on commit `30f4cc7`.
58 |
59 | ### 6. Install [Habitat-Lab](https://github.com/facebookresearch/habitat-lab)
60 | Install Habitat-Lab:
61 | ```bash
62 | # step outside the repo root and git clone Habitat-Lab
63 | cd .. &&\
64 | git clone https://github.com/facebookresearch/habitat-lab.git && \
65 | cd habitat-lab && \
66 | git checkout v0.2.2 && \
67 | pip install -e . && \
68 | sed -i '36 s/^/#/' habitat/tasks/rearrange/rearrange_sim.py # remove FetchRobot
69 | ```
70 | This guide is based on commit `49f7c15`.
71 |
72 |
73 | ### 7. Install [SoundSpaces](https://github.com/facebookresearch/sound-spaces/tree/main)
74 | Install SoundSpaces:
75 | ```bash
76 | # step outside the repo root and git clone SoundSpaces
77 | cd .. &&\
78 | git clone https://github.com/facebookresearch/sound-spaces.git && \
79 | cd sound-spaces && \
80 | pip install -e .
81 | ```
82 | This guide is based on commit `3768a50`.
83 |
84 | ### 8. Install Additional Packages
85 | Install additional Python packages needed:
86 | ```bash
87 | pip install scipy torchmetrics pyroomacoustics
88 | ```
89 |
90 | ## Quick Installation
91 | To streamline the installation process, run the setup.sh script, which encapsulates all the steps listed above:
92 | ```bash
93 | bash setup.sh
94 | ```
95 |
96 | ## Preparation of Matterport3D Room and Material
97 |
98 | ### 1. Download Example MP3D Room
99 | Follow these steps to download the MP3D data in the correct directory:
100 |
101 |
102 | (1) **Switch to the habitat-sim directory**:
103 | ```bash
104 | cd /path/to/habitat-sim
105 | ```
106 |
107 | (2) **Run the dataset download script**:
108 | ```bash
109 | python src_python/habitat_sim/utils/datasets_download.py --uids mp3d_example_scene
110 | ```
111 |
112 | (3) **Copy the downloaded data to this repository**:
113 | ```bash
114 | mkdir -p /path/to/ml-nvas3d/data/scene_datasets/
115 | cp -r data/scene_datasets/mp3d_example /path/to/ml-nvas3d/data/scene_datasets/mp3d
116 | ```
117 |
118 | After executing the above steps, ensure the existence of the `data/scene_datasets/mp3d/17DRP5sb8fy` directory. For additional rooms, you might want to consider downloading from [Matterport3D](https://niessner.github.io/Matterport/).
119 |
120 | ### 2. Download Material Configuration File
121 | Download the material configuration file in the correct directory:
122 |
123 | ```bash
124 | cd /path/to/ml-nvas3d && \
125 | mkdir data/material && \
126 | cd data/material && wget https://raw.githubusercontent.com/facebookresearch/rlr-audio-propagation/main/RLRAudioPropagationPkg/data/mp3d_material_config.json
127 | ```
128 |
129 | > [!Note]
130 | > Now, you are ready to run [Demo](../demo/README.md) using our pretrained model.
131 | >
132 | > To explore with more data, you can consider running [Training Data Generation](../nvas3d/utils/training_data_generation/README.md).
133 |
134 |
135 | ## Extended Usage: SoundSpaces for NVAS3D
136 | For users interested in generating more data, follow the steps outlined below.
137 |
138 | ### Generate grid points in room:
139 | To create grid points within the room, execute the following command:
140 | ```bash
141 | python soundspaces_nvas3d/rir_generation/generate_grid.py --grid_distance 1.0
142 | ```
143 | * Input directory: `data/scene_datasets/mp3d/`
144 | * Output directory: `data/scene_datasets/metadata/mp3d/grid_{grid_distance}`
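
To sanity-check the generated grid, it can be reloaded with the same helper used elsewhere in this codebase (a minimal sketch; run it from the repository root with `PYTHONPATH` set as in step 0):
```python
from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid

# Load the grid generated above for the example room.
grid = load_room_grid('17DRP5sb8fy', grid_distance=1.0)
grid_points = grid['grid_points']  # 3D positions of the grid points
print(f'{grid_points.shape[0]} grid points')
```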
145 |
146 | ### Generate RIRs:
147 | To generate room impulse responses for all grid point pairs, execute the following command:
148 | ```bash
149 | python soundspaces_nvas3d/rir_generation/generate_rir.py --room 17DRP5sb8fy
150 | ```
151 | * Input directory: `data/scene_datasets/mp3d/{room}`
152 | * Output directory: `data/examples/rir_mp3d/grid_{grid_distance}/{room}`
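
Each generated RIR is a multichannel `.wav` file named `ir_{room}_{source_idx}_{receiver_idx}.wav`, matching the loaders in `nvas3d/utils/generate_dataset_utils.py`. Below is a minimal auralization sketch; it assumes `grid_distance=1.0` maps to a `grid_1_0` subdirectory, that grid indices 0 and 1 are valid for the room, and that the dry source and RIR sample rates match.
```python
import torchaudio

from soundspaces_nvas3d.utils.audio_utils import fft_conv

room = '17DRP5sb8fy'
rir_dir = f'data/examples/rir_mp3d/grid_1_0/{room}'  # assumed output layout for grid_distance=1.0

# RIR from grid point 0 (source) to grid point 1 (receiver); pick indices valid for your room.
ir, sample_rate = torchaudio.load(f'{rir_dir}/ir_{room}_0_1.wav')
dry, _ = torchaudio.load('data/source/female.flac')  # dry source shipped with this repository

# Convolve the first channels and save the reverberant result.
reverb = fft_conv(dry[0], ir[0]).unsqueeze(0)
torchaudio.save('reverb_example.wav', reverb, sample_rate)
```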
153 |
154 | ### Minimal Example of RIR Generation Using SoundSpaces
155 | We provide a minimal example of RIR generation using our codebase. To generate sample RIRs, execute the following command:
156 | ```bash
157 | python soundspaces_nvas3d/rir_generation/minimal_example_render_rir.py
158 | ```
159 |
160 | ---
161 | Optionally, for those interested in training an audio-visual network, images can be rendered using the following methods:
162 | ### Render target images:
163 | To render images centered at individual grid points, execute the following command:
164 | ```bash
165 | python soundspaces_nvas3d/image_rendering/run_generate_target_image.py
166 | ```
167 | * Input directory: `data/scene_datasets/mp3d/{room}`
168 | * Output directory: `data/examples/target_image_mp3d/grid_{grid_distance}/{room}`
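
The rendered views are saved both as `.png` previews and as raw `.npy` arrays (four views, facing south/east/north/west, concatenated along the width; see `soundspaces_nvas3d/image_rendering/generate_target_image.py`). A quick check, assuming the default layout for `grid_distance=1.0`:
```python
import numpy as np

room = '17DRP5sb8fy'
dirname = f'data/examples/target_image_mp3d/grid_1_0/{room}'  # default layout for grid_distance=1.0

# Grid index 0 is a placeholder; any rendered index works.
rgb = np.load(f'{dirname}/rgb_{room}_0.npy')
depth = np.load(f'{dirname}/depth_{room}_0.npy')
print(rgb.shape, depth.shape)
```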
169 |
170 | ### Render environment maps:
171 | To render environment maps from all grid points, execute the following command:
172 | ```bash
173 | python soundspaces_nvas3d/image_rendering/run_generate_envmap.py
174 | ```
175 | * Input directory: `data/scene_datasets/mp3d/{room}`
176 | * Output directory: `data/examples/envmap_3d/grid_{grid_distance}/{room}`
177 |
178 |
--------------------------------------------------------------------------------
/nvas3d/utils/generate_dataset_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import glob
7 | import random
8 | import numpy as np
9 | import typing as T
10 | from scipy.signal import fftconvolve
11 |
12 | import torch
13 | import torchaudio
14 | import torch.nn.functional as F
15 |
16 | MIN_LENGTH_AUDIO = 5 * 48000
17 |
18 |
19 | def load_ir_source_receiver(
20 | ir_dir: str,
21 | room: str,
22 | source_idx: int,
23 | receiver_idx_list: T.List[int],
24 | ir_length: int
25 | ) -> T.List[torch.Tensor]:
26 | """
27 | Load impulse responses for specific source and receivers in a room.
28 |
29 | Args:
30 | - ir_dir: Directory containing impulse response files.
31 | - room: Name of the room.
32 | - source_idx: Index of the source.
33 | - receiver_idx_list: List of receiver indices.
34 | - ir_length: Length of the impulse response to be loaded.
35 |
36 | Returns:
37 | - List of loaded impulse responses (first channel only).
38 | """
39 |
40 | ir_list = []
41 | for receiver_idx in receiver_idx_list:
42 | filename_ir = f'{ir_dir}/{room}/ir_{room}_{source_idx}_{receiver_idx}.wav'
43 | ir, _ = torchaudio.load(filename_ir)
44 | if ir[0].shape[0] > ir_length:
45 | ir0 = ir[0][:ir_length]
46 | else:
47 | ir0 = F.pad(ir[0], (0, ir_length - ir[0].shape[0]))
48 | ir_list.append(ir0)
49 |
50 | return ir_list
51 |
52 |
53 | def load_ir_source_receiver_allchannel(
54 | ir_dir: str,
55 | room: str,
56 | source_idx: int,
57 | receiver_idx_list: T.List[int],
58 | ir_length: int
59 | ) -> T.List[torch.Tensor]:
60 | """
61 | Load impulse responses for all channels for specific source and receivers in a room.
62 |
63 | Args:
64 | - ir_dir: Directory containing impulse response files.
65 | - room: Name of the room.
66 | - source_idx: Index of the source.
67 | - receiver_idx_list: List of receiver indices.
68 | - ir_length: Length of the impulse response to be loaded.
69 |
70 | Returns:
71 | - List of loaded impulse responses for all channels.
72 | """
73 |
74 | ir_list = []
75 | for receiver_idx in receiver_idx_list:
76 | filename_ir = f'{ir_dir}/{room}/ir_{room}_{source_idx}_{receiver_idx}.wav'
77 | ir, _ = torchaudio.load(filename_ir)
78 | if ir.shape[-1] > ir_length:
79 | ir = ir[..., :ir_length]
80 | else:
81 | ir = F.pad(ir, (0, ir_length - ir[0].shape[0]))
82 | ir_list.append(ir)
83 |
84 | return ir_list
85 |
86 |
87 | def save_audio_list(
88 | filename: str,
89 | audio_list: T.List[torch.Tensor],
90 | sample_rate: int,
91 | audio_format: str
92 | ):
93 | """
94 | Save a list of audio tensors to files.
95 |
96 | Args:
97 | - filename: Filename to save audio.
98 | - audio_list: List of audio tensors to save.
99 | - sample_rate: Sample rate of audio.
100 | - audio_format: File format to save audio.
101 | """
102 |
103 | for idx_audio, audio in enumerate(audio_list):
104 | torchaudio.save(f'{filename}_{idx_audio+1}.{audio_format}', audio.unsqueeze(0), sample_rate)
105 |
106 |
107 | def clip_source(
108 | source1_audio: torch.Tensor,
109 | source2_audio: torch.Tensor,
110 | len_clip: int
111 | ) -> T.Tuple[torch.Tensor, torch.Tensor]:
112 | """
113 | Clip source audio tensors for faster convolution.
114 |
115 | Args:
116 | - source1_audio: First source audio tensor.
117 | - source2_audio: Second source audio tensor.
118 | - len_clip: Desired length of the output audio tensors.
119 |
120 | Returns:
121 | - Clipped source1_audio and source2_audio.
122 | """
123 |
124 |     # Pad audio if shorter than the desired clip length
127 | source1_audio = F.pad(source1_audio, (0, max(0, len_clip - source1_audio.shape[0])))
128 | source2_audio = F.pad(source2_audio, (0, max(0, len_clip - source2_audio.shape[0])))
129 |
130 | # clip
131 | start_index = np.random.randint(0, source1_audio.shape[0] - len_clip) \
132 | if source1_audio.shape[0] != len_clip else 0
133 | source1_audio_clipped = source1_audio[start_index: start_index + len_clip]
134 | source2_audio_clipped = source2_audio[start_index: start_index + len_clip]
135 |
136 | return source1_audio_clipped, source2_audio_clipped
137 |
138 |
139 | def compute_reverb(
140 | source_audio: torch.Tensor,
141 | ir_list: T.List[torch.Tensor],
142 | padding: str = 'valid'
143 | ) -> T.List[torch.Tensor]:
144 | """
145 | Compute reverberated audio signals by convolving source audio with impulse responses.
146 |
147 | Args:
148 | - source_audio: Source audio signal (dry) to be reverberated.
149 | - ir_list: List of impulse responses for reverberation.
150 | - padding: Padding mode for convolution ('valid' or 'full').
151 |
152 | Returns:
153 | - A list of reverberated audio signals.
154 | """
155 |
156 | reverb_list = []
157 | for ir in ir_list:
158 | reverb = fftconvolve(source_audio, ir, padding)
159 | reverb_list.append(torch.from_numpy(reverb))
160 |
161 | return reverb_list
162 |
163 |
164 | ####################
165 | # Sampling source audio to generate data
166 | ####################
167 |
168 | def sample_speech(files_librispeech, librispeech_metadata):
169 | source_speech = torch.zeros(1) # Initialize with a tensor of zeros
170 | while torch.all(source_speech == 0) or source_speech.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found
171 | filename_source = random.choice(files_librispeech)
172 | speaker_id = int(filename_source.split('/')[6])
173 | speaker_gender = librispeech_metadata[speaker_id]
174 | if speaker_gender == 'M':
175 | source_class = 'male'
176 | else:
177 | source_class = 'female'
178 | source_speech, _ = torchaudio.load(filename_source)
179 | source_speech = source_speech.reshape(-1)
180 |
181 | return source_speech, source_class
182 |
183 |
184 | def sample_nonspeech(all_instruments_dir):
185 | class_dir = random.choice(all_instruments_dir)
186 |
187 | # Ensure that the class is not 'Speech'
188 | while 'Speech' in class_dir:
189 | class_dir = random.choice(all_instruments_dir)
190 |
191 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True)
192 |
193 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros
194 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found
195 | filename_source = random.choice(files_source)
196 | source_class = class_dir.split('/')[3]
197 | source_audio, _ = torchaudio.load(filename_source)
198 | source_audio = source_audio.reshape(-1)
199 |
200 | return source_audio, source_class
201 |
202 |
203 | def sample_acoustic_guitar(all_instruments_dir):
204 | guitar_dir = [dirname for dirname in all_instruments_dir if dirname.split('/')[4] == 'Acoustic Guitar (steel)']
205 |
206 | class_dir = random.choice(guitar_dir)
207 |
208 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True)
209 |
210 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros
211 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found
212 | filename_source = random.choice(files_source)
213 | source_class = 'guitar'
214 | source_audio, _ = torchaudio.load(filename_source)
215 | source_audio = source_audio.reshape(-1)
216 |
217 | return source_audio, source_class
218 |
219 |
220 | def sample_instrument(all_instruments_dir, librispeech_metadata, classname):
221 | guitar_dir = [dirname for dirname in all_instruments_dir if dirname.split('/')[3] == classname] # e.g., Guitar
222 |
223 | class_dir = random.choice(guitar_dir)
224 |
225 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True)
226 |
227 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros
228 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found
229 | filename_source = random.choice(files_source)
230 | source_class = 'guitar'
231 | source_audio, _ = torchaudio.load(filename_source)
232 | source_audio = source_audio.reshape(-1)
233 |
234 | return source_audio, source_class
235 |
236 |
237 | def sample_all(all_instruments_dir, librispeech_metadata):
238 | class_dir = random.choice(all_instruments_dir)
239 | files_source = glob.glob(class_dir + '/**/*.flac', recursive=True)
240 |
241 | source_audio = torch.zeros(1) # Initialize with a tensor of zeros
242 | while torch.all(source_audio == 0) or source_audio.shape[-1] < MIN_LENGTH_AUDIO: # Continue until a non-zero tensor is found
243 | filename_source = random.choice(files_source)
244 | source_class = class_dir.split('/')[3]
245 | source_audio, _ = torchaudio.load(filename_source)
246 | source_audio = source_audio.reshape(-1)
247 |
248 | if source_class == 'Speech':
249 | speaker_id = int(filename_source.split('/')[6])
250 | speaker_gender = librispeech_metadata[speaker_id]
251 | if speaker_gender == 'M':
252 | source_class = 'male'
253 | else:
254 | source_class = 'female'
255 |
256 | return source_audio, source_class
257 |
--------------------------------------------------------------------------------
/soundspaces_nvas3d/utils/audio_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import torch
7 | import torchaudio
8 | import matplotlib.pyplot as plt
9 | import torch.nn.functional as F
10 | import typing as T
11 |
12 |
13 | def fft_conv(
14 | signal: torch.Tensor,
15 | kernel: torch.Tensor,
16 | is_cpu: bool = False
17 | ) -> torch.Tensor:
18 | """
19 | Perform convolution of a signal and a kernel using Fast Fourier Transform (FFT).
20 |
21 | Args:
22 | - signal (torch.Tensor): Input signal tensor.
23 | - kernel (torch.Tensor): Kernel tensor.
24 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU.
25 |
26 | Returns:
27 | - torch.Tensor: Convolved signal.
28 | """
29 |
30 | if is_cpu:
31 | signal = signal.detach().cpu()
32 | kernel = kernel.detach().cpu()
33 |
34 | padded_signal = F.pad(signal.reshape(-1), (0, kernel.size(-1) - 1))
35 | padded_kernel = F.pad(kernel.reshape(-1), (0, signal.size(-1) - 1))
36 |
37 | signal_fr = torch.fft.rfftn(padded_signal, dim=-1)
38 | kernel_fr = torch.fft.rfftn(padded_kernel, dim=-1)
39 |
40 | output_fr = signal_fr * kernel_fr
41 | output = torch.fft.irfftn(output_fr, dim=-1)
42 |
43 | return output
44 |
45 |
46 | def wiener_deconv(
47 | signal: torch.Tensor,
48 | kernel: torch.Tensor,
49 | snr: float,
50 | is_cpu: bool = False
51 | ) -> torch.Tensor:
52 | """
53 | Perform Wiener deconvolution on a given signal using a specified kernel.
54 |
55 | Args:
56 | - signal (torch.Tensor): Input signal tensor.
57 | - kernel (torch.Tensor): Kernel tensor.
58 | - snr (float): Signal-to-noise ratio.
59 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU.
60 |
61 | Returns:
62 | - torch.Tensor: Deconvolved signal.
63 | """
64 |
65 | if is_cpu:
66 | signal = signal.detach().cpu()
67 | kernel = kernel.detach().cpu()
68 |
69 | n_fft = signal.shape[-1] + kernel.shape[-1] - 1
70 | signal_fr = torch.fft.rfft(signal.reshape(-1), n=n_fft)
71 | kernel_fr = torch.fft.rfft(kernel.reshape(-1), n=n_fft)
72 |
73 | wiener_filter_fr = torch.conj(kernel_fr) / (torch.abs(kernel_fr)**2 + 1 / snr)
74 |
75 | filtered_signal_fr = wiener_filter_fr * signal_fr
76 |
77 | filtered_signal = torch.fft.irfft(filtered_signal_fr)
78 |
79 | # Crop the filtered signal to the original size
80 | filtered_signal = filtered_signal[:signal.shape[-1]]
81 |
82 | return filtered_signal
83 |
84 |
85 | def wiener_deconv_list(
86 | signal: T.List[torch.Tensor],
87 | kernel: T.List[torch.Tensor],
88 | snr: float,
89 | is_cpu: bool = False
90 | ) -> torch.Tensor:
91 | """
92 |     Wiener deconvolution applied to lists of signals and kernels (batched wiener_deconv).
93 |
94 | Args:
95 |     - signal (T.List[torch.Tensor]): List of signal tensors.
96 |     - kernel (T.List[torch.Tensor]): List of kernel tensors.
97 | - snr (float): Signal-to-noise ratio.
98 | - is_cpu (bool, optional): Flag to determine if the operation should be on the CPU.
99 |
100 | Returns:
101 |     - T.List[torch.Tensor]: List of deconvolved signals.
102 | """
103 |
104 | M = len(signal)
105 | if isinstance(signal, list):
106 | signal = torch.stack(signal).reshape(M, -1)
107 | assert signal.shape[0] == M
108 | kernel = torch.stack(kernel).reshape(M, -1)
109 | snr /= abs(kernel).max()
110 |
111 | if is_cpu:
112 | signal = signal.detach().cpu()
113 | kernel = kernel.detach().cpu()
114 |
115 | n_batch, n_samples = signal.shape
116 |
117 | # Pad the signals and kernels to avoid circular convolution
118 | padded_signal = F.pad(signal, (0, kernel.shape[-1] - 1))
119 | padded_kernel = F.pad(kernel, (0, signal.shape[-1] - 1))
120 |
121 | # Compute the Fourier transforms
122 | signal_fr = torch.fft.rfft(padded_signal, dim=-1)
123 | kernel_fr = torch.fft.rfft(padded_kernel, dim=-1)
124 |
125 | # Compute the Wiener filter in the frequency domain
126 | wiener_filter_fr = torch.conj(kernel_fr) / (torch.abs(kernel_fr)**2 + 1 / snr)
127 |
128 | # Apply the Wiener filter
129 | filtered_signal_fr = wiener_filter_fr * signal_fr
130 |
131 | # Compute the inverse Fourier transform
132 | filtered_signal = torch.fft.irfft(filtered_signal_fr, dim=-1)
133 |
134 | # Crop the filtered signals to the original size
135 | filtered_signal = filtered_signal[:, :n_samples]
136 |
137 | filtered_signal_list = [filtered_signal[i] for i in range(filtered_signal.size(0))]
138 |
139 | return filtered_signal_list
140 |
141 |
142 | def save_audio(
143 | filename: str,
144 | waveform: torch.Tensor, # (ch, time) in cpu
145 | sample_rate: float
146 | ):
147 | """
148 | Save an audio waveform to a file.
149 |
150 | Args:
151 | - filename (str): Output filename.
152 | - waveform (torch.Tensor): Audio waveform tensor.
153 | - sample_rate (float): Sample rate of the audio.
154 | """
155 |
156 | torchaudio.save(filename, waveform, sample_rate=sample_rate)
157 |
158 |
159 | def plot_waveform(filename: str,
160 | waveform: torch.Tensor, # (ch, time) in cpu
161 | sample_rate: float = 48000,
162 | title: str = None,
163 | sharex: bool = True,
164 | xlim: T.Tuple[float, float] = None,
165 | ylim: T.Tuple[float, float] = None,
166 | color: str = 'orange',
167 | waveform_ref: torch.Tensor = None
168 | ):
169 | """
170 | Plot an audio waveform.
171 |
172 | Args:
173 | - filename (str): Output filename.
174 | - waveform (torch.Tensor): Audio waveform tensor.
175 | - sample_rate (float, optional): Sample rate of the audio. Defaults to 48000.
176 | - title (str, optional): Title for the plot.
177 | - sharex (bool, optional): Whether to share the x-axis across subplots. Defaults to True.
178 | - xlim (T.Tuple[float, float], optional): Limits for x-axis.
179 | - ylim (T.Tuple[float, float], optional): Limits for y-axis.
180 | - color (str, optional): Color for the waveform plot. Defaults to 'orange'.
181 | - waveform_ref (torch.Tensor, optional): Reference waveform for comparison.
182 | """
183 |
184 | num_channels, num_frames = waveform.shape
185 | time_axis = torch.arange(0, num_frames) / sample_rate
186 | if waveform_ref is not None:
187 | num_frames_ref = waveform_ref.shape[0] # should be 1D
188 | time_axis_ref = torch.arange(0, num_frames_ref) / sample_rate
189 |
190 | if ylim is None:
191 | margin = 1.1
192 | ylim = (margin * waveform.min(), margin * waveform.max())
193 |
194 | figure, axes = plt.subplots(num_channels, 1, sharex=sharex)
195 | if num_channels == 1:
196 | axes = [axes]
197 | for c in range(num_channels):
198 | if waveform_ref is None:
199 | axes[c].plot(time_axis, waveform[c], color=color, linewidth=1)
200 | else:
201 | axes[c].plot(time_axis, waveform[c], color=color, alpha=0.5, linewidth=1, label='signal')
202 | axes[c].plot(time_axis_ref, waveform_ref, color='r', alpha=0.5, linewidth=1, label='reference')
203 | axes[c].legend()
204 | axes[c].grid(True)
205 | if num_channels > 1:
206 | axes[c].set_ylabel(f'Channel {c+1}')
207 | if xlim:
208 | axes[c].set_xlim(xlim)
209 | if ylim:
210 | axes[c].set_ylim(ylim)
211 | if title is not None:
212 | figure.suptitle(title)
213 |
214 | if filename is not None:
215 | plt.savefig(filename)
216 | else:
217 | plt.show(block=False)
218 |
219 |
220 | def print_stats(
221 | waveform: torch.Tensor,
222 | sample_rate: T.Optional[float] = None,
223 | src: T.Optional[str] = None
224 | ):
225 | """
226 | Print the statistics of a given waveform.
227 |
228 | Args:
229 | - waveform (torch.Tensor): Input audio waveform tensor.
230 | - sample_rate (float, optional): Sample rate of the audio.
231 | - src (str, optional): Source of the audio, for display purposes.
232 | """
233 |
234 | if src:
235 | print("-" * 10)
236 | print("Source:", src)
237 | print("-" * 10)
238 | if sample_rate:
239 | print("Sample Rate:", sample_rate)
240 | print("Shape:", tuple(waveform.shape))
241 | print("Dtype:", waveform.dtype)
242 | print(f" - Max: {waveform.max().item():6.3f}")
243 | print(f" - Min: {waveform.min().item():6.3f}")
244 | print(f" - Mean: {waveform.mean().item():6.3f}")
245 | print(f" - Std Dev: {waveform.std().item():6.3f}")
246 | print()
247 | print(waveform)
248 | print()
249 |
250 |
251 | def plot_specgram(
252 | filename: str,
253 | waveform: torch.Tensor,
254 | sample_rate: float,
255 | title: T.Optional[str] = None,
256 | xlim: T.Optional[T.Tuple[float, float]] = None
257 | ):
258 | """
259 | Plot the spectrogram of a given audio waveform.
260 |
261 | Args:
262 | - filename (str): Output filename.
263 | - waveform (torch.Tensor): Audio waveform tensor.
264 | - sample_rate (float): Sample rate of the audio.
265 | - title (str, optional): Title for the plot.
266 | - xlim (T.Tuple[float, float], optional): Limits for x-axis.
267 | """
268 |
269 | waveform = waveform.numpy() if isinstance(waveform, torch.Tensor) else waveform
270 |
271 | num_channels, num_frames = waveform.shape
272 | time_axis = torch.arange(0, num_frames) / sample_rate
273 |
274 | figure, axes = plt.subplots(num_channels, 1)
275 | if num_channels == 1:
276 | axes = [axes]
277 | for c in range(num_channels):
278 | axes[c].specgram(waveform[c], Fs=sample_rate)
279 | if num_channels > 1:
280 | axes[c].set_ylabel(f'Channel {c+1}')
281 | if xlim:
282 | axes[c].set_xlim(xlim)
283 | if title is not None:
284 | figure.suptitle(title)
285 |
286 | if filename is not None:
287 | plt.savefig(filename)
288 | else:
289 | plt.show(block=False)
290 |
291 |
292 | def plot_debug(
293 | filename: str,
294 | waveform: torch.Tensor,
295 | sample_rate: float,
296 | save_png: bool = False
297 | ):
298 | """
299 | Generate and save waveform and spectrogram plots, and save audio waveform to a file.
300 |
301 | Args:
302 | - filename (str): Base filename for outputs.
303 | - waveform (torch.Tensor): Audio waveform tensor.
304 | - sample_rate (float): Sample rate of the audio.
305 | - save_png (bool, optional): Whether to save plots as PNG files. Defaults to False.
306 | """
307 |
308 | waveform = waveform.reshape(1, -1)
309 | if save_png:
310 | plot_specgram(f'{filename}_specgram.png', waveform, sample_rate)
311 | plot_waveform(f'{filename}_waveform.png', waveform, sample_rate)
312 | torchaudio.save(f'{filename}.wav', waveform, sample_rate)
313 |
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/generate_metadata_square.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import json
8 | import random
9 | import argparse
10 | import matplotlib.pyplot as plt
11 | from collections import defaultdict
12 | from itertools import combinations
13 |
14 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
15 | from nvas3d.utils.utils import MP3D_SCENE_SPLITS
16 |
17 |
18 | def find_squares(points, grid_distance, tolerance=1e-5):
19 | squares = []
20 | for i, point1 in enumerate(points):
21 | x1, y1, z1 = point1
22 | for j, point2 in enumerate(points):
23 | if i != j:
24 | x2, y2, z2 = point2
25 | if abs(x2 - (x1 + grid_distance)) < tolerance and abs(z2 - z1) < tolerance:
26 | for k, point3 in enumerate(points):
27 | if k != i and k != j:
28 | x3, y3, z3 = point3
29 | if abs(x3 - x1) < tolerance and abs(z3 - (z1 + grid_distance)) < tolerance:
30 | for l, point4 in enumerate(points):
31 | if l != i and l != j and l != k:
32 | x4, y4, z4 = point4
33 | if abs(x4 - (x1 + grid_distance)) < tolerance and abs(z4 - (z1 + grid_distance)) < tolerance:
34 | squares.append((i, j, l, k))
35 | return squares
36 |
37 |
38 | def plot_points_and_squares(filename, points, squares):
39 | fig, ax = plt.subplots()
40 |
41 | # plot all points
42 | ax.scatter(*zip(*[(x, z) for x, _, z in points]), c='blue')
43 |
44 | # plot squares
45 | for square_indices in squares:
46 | square_points = [points[i] for i in square_indices]
47 | square_points.append(square_points[0]) # Add the first point again to close the square
48 | ax.plot(*zip(*[(x, z) for x, _, z in square_points]), c='red')
49 |
50 | ax.set_aspect('equal', 'box') # Ensure the aspect ratio is equal
51 | plt.savefig(filename)
52 | plt.show()
53 |
54 |
55 | def main(args):
56 | dataset_name = args.dataset_name
57 | num_pairs_per_room = args.num_pairs_per_room
58 | random.seed(0)
59 |
60 | grid_distance = args.grid_distance
61 | grid_distance_str = str(grid_distance).replace(".", "_")
62 |
63 | os.makedirs(f'data/{dataset_name}/metadata/grid_{grid_distance_str}', exist_ok=True)
64 |
65 | filesize_valid_total = 0.0
66 | filesize_invalid_total = 0.0
67 | num_valid_total = 0.0
68 | num_invalid_total = 0.0
69 | mean_firstnz_total = 0.0
70 | mean_rt60_total = 0.0
71 | mean_maximum_total = 0.0
72 | mean_mean_total = 0.0
73 | mean_num_points_total = 0.0
74 | mean_room_size_total = 0.0
75 |
76 | count_total = 0
77 |
78 | is_debug = False
79 | # Read jsons
80 | for split in ['demo']: # ['train', 'val', 'test', 'demo']:
81 |
82 | filesize_valid_total_ = 0.0
83 | filesize_invalid_total_ = 0.0
84 | num_valid_total_ = 0.0
85 | num_invalid_total_ = 0.0
86 | mean_firstnz_total_ = 0.0
87 | mean_rt60_total_ = 0.0
88 | mean_maximum_total_ = 0.0
89 | mean_mean_total_ = 0.0
90 | mean_num_points_total_ = 0.0
91 | mean_room_size_total_ = 0.0
92 | count_split = 0
93 |
94 | for i_room, room in enumerate(MP3D_SCENE_SPLITS[split]):
95 | if is_debug:
96 | print(f'room: {room} in {split} ({i_room}/{len(MP3D_SCENE_SPLITS[split])})')
97 | filename = f'data/metadata_grid/grid_{grid_distance_str}/{room}.json'
98 | with open(filename, 'r') as file:
99 | metadata = json.load(file)
100 |
101 | filesize_valid_total += metadata['filesize_valid_in_mb']
102 | filesize_invalid_total += metadata['filesize_invalid_in_mb']
103 | num_valid_total += metadata['num_valid']
104 | num_invalid_total += metadata['num_invalid']
105 |             mean_firstnz_total += metadata['mean_firstnz']
106 |             mean_rt60_total += metadata['mean_rt60']
107 |             mean_maximum_total += metadata['mean_maximum']
108 |             mean_mean_total += metadata['mean_mean']
109 |             mean_num_points_total += metadata['num_points']
110 |             mean_room_size_total += metadata['room_size']
111 |
112 | filesize_valid_total_ += metadata['filesize_valid_in_mb']
113 | filesize_invalid_total_ += metadata['filesize_invalid_in_mb']
114 | num_valid_total_ += metadata['num_valid']
115 | num_invalid_total_ += metadata['num_invalid']
116 |             mean_firstnz_total_ += metadata['mean_firstnz']
117 |             mean_rt60_total_ += metadata['mean_rt60']
118 |             mean_maximum_total_ += metadata['mean_maximum']
119 |             mean_mean_total_ += metadata['mean_mean']
120 |             mean_num_points_total_ += metadata['num_points']
121 |             mean_room_size_total_ += metadata['room_size']
122 |
123 | count_split += 1
124 | count_total += 1
125 |
126 | # find valid
127 | points = load_room_grid(room, grid_distance)['grid_points']
128 | squares = find_squares(points, grid_distance)
129 | # filename = f'data/metadata_grid/grid_{grid_distance_str}/{room}_square.png'
130 | # plot_points_and_squares(filename, points, squares)
131 |
132 | # Initialize defaultdicts to count source_idx and receiver_idx occurrences
133 | source_idx_counts = defaultdict(set)
134 | receiver_idx_counts = defaultdict(set)
135 |
136 | # Iterate over keys in valid metadata
137 | for key in metadata.keys():
138 | if key.startswith(room):
139 | # Extract source_idx and receiver_idx from key
140 | room, source_idx, receiver_idx = key.split("_")
141 |
142 | if metadata[key]['Is Valid'] == 'True':
143 | # Add receiver_idx to the set associated with source_idx
144 | # and vice versa
145 | source_idx_counts[source_idx].add(receiver_idx)
146 | receiver_idx_counts[receiver_idx].add(source_idx)
147 |
148 | # Initialize the empty dictionary where each value is a set
149 | square_to_source_idxs = defaultdict(set)
150 |
151 | # Iterate over each source_idx and its corresponding valid indices
152 | for source_idx, valid_indices in source_idx_counts.items():
153 | # Iterate over each square
154 | for square in squares:
155 | # Check if all indices of the current square are valid for the current source_idx
156 | if all(str(idx) in valid_indices for idx in square):
157 | # Add the source_idx to the square's set of valid source_idxs
158 | square_to_source_idxs[square].add(int(source_idx))
159 |
160 | # Filter out squares that have less than 3 valid source_idxs (2 for positive, 1 for negative)
161 | three_source_idx_squares = {square: source_idxs for square, source_idxs in square_to_source_idxs.items() if len(source_idxs) >= 3}
162 |
163 | # Initialize the list of pairs
164 | pairs = []
165 |
166 | # Iterate over the dictionary to build pairs
167 | for square, source_idxs in three_source_idx_squares.items():
168 | # Generate all combinations of 3 source indices
169 | for source_idx_pair in combinations(source_idxs, 3):
170 | # Avoid adding pairs where any source_idx is in square (receiver indices)
171 | if not any(source_idx in square for source_idx in source_idx_pair):
172 | # Find novel receiver that is valid to source 1 and source 2
173 | common_receiver = source_idx_counts[str(source_idx_pair[0])].intersection(source_idx_counts[str(source_idx_pair[1])])
174 |                         common_receiver = common_receiver - set(square) - set([source_idx_pair[0]]) - set([source_idx_pair[1]])
175 | if common_receiver:
176 | novel = int(random.choice(list(common_receiver)))
177 |
178 | # Add the pair to the list
179 | pairs.append((source_idx_pair, square, novel))
180 |
181 | # If there are less than num_pairs_per_room
182 | if len(pairs) < num_pairs_per_room:
183 | selected_pairs = pairs
184 | else:
185 | # Randomly select num_pairs_per_room
186 | selected_pairs = random.sample(pairs, num_pairs_per_room)
187 |
188 | # Save the lists to JSON files
189 | total_dict = {}
190 | total_dict['squares'] = squares
191 | total_dict['selected_pairs'] = selected_pairs
192 | # total_dict['square_to_source_idxs_keys'] = list(square_to_source_idxs.keys())
193 | # total_dict['square_to_source_idxs_values'] = list(square_to_source_idxs.values())
194 |
195 | filename = f'data/{dataset_name}/metadata/grid_{grid_distance_str}/{room}_square.json'
196 | with open(filename, 'w') as file:
197 | json.dump(total_dict, file)
198 |
199 | if is_debug:
200 | print(split)
201 | print(filesize_valid_total_ / 1024) # GB
202 | print(filesize_invalid_total_ / 1024) # GB
203 | print(num_valid_total_)
204 | print(num_invalid_total_)
205 | print(mean_firstnz_total_ / count_split)
206 | print(mean_rt60_total_ / count_split)
207 | print(mean_maximum_total_ / count_split)
208 | print(mean_mean_total_ / count_split)
209 | print(mean_num_points_total_ / count_split)
210 | print(mean_room_size_total_ / count_split)
211 |
212 | if is_debug:
213 | print(filesize_valid_total / 1024) # GB
214 | print(filesize_invalid_total / 1024) # GB
215 | print(num_valid_total)
216 | print(num_invalid_total)
217 | print(mean_firstnz_total / count_total)
218 | print(mean_rt60_total / count_total)
219 | print(mean_maximum_total / count_total)
220 | print(mean_mean_total / count_total)
221 | print(mean_num_points_total / count_total)
222 | print(mean_room_size_total / count_total)
223 |
224 |
225 | if __name__ == '__main__':
226 | parser = argparse.ArgumentParser(description="Generate metadata for square-shape microphone array.")
227 |
228 | parser.add_argument('--grid_distance',
229 | default=1.0,
230 | type=float,
231 | help='Distance between grid points in meters')
232 |
233 | parser.add_argument('--num_pairs_per_room',
234 | default=1000,
235 | type=int,
236 | help='Number of pairs per room to generate')
237 |
238 | parser.add_argument('--dataset_name',
239 | default='nvas3d_square',
240 | type=str,
241 | help='Name of the dataset')
242 |
243 | args = parser.parse_args()
244 | main(args)
245 |
--------------------------------------------------------------------------------
/nvas3d/model/model.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 | # Portions of this code are derived from VisualVoice (CC-BY-NC).
6 | # Original work available at: https://github.com/facebookresearch/VisualVoice
7 | #
8 |
9 | import torch
10 | import torchvision
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | class NVASNet(nn.Module):
16 | def __init__(self, num_receivers, use_visual=False):
17 | super(NVASNet, self).__init__()
18 |
19 | self.num_receivers = num_receivers
20 | self.use_visual = use_visual
21 |
22 | if self.use_visual:
23 | self.rgb_net = VisualNet(torchvision.models.resnet18(pretrained=True), 3)
24 | self.depth_net = VisualNet(torchvision.models.resnet18(pretrained=True), 1)
25 |
26 | concat_size = 512 * 2
27 | self.pooling = nn.AdaptiveAvgPool2d((1, 1))
28 | self.conv1x1 = create_conv(concat_size, 512, 1, 0)
29 |
30 | input_channel = 2 * num_receivers
31 | output_channel = 2
32 |
33 | self.audio_net = AudioNet(64, input_channel, output_channel, self.use_visual)
34 | self.audio_net.apply(weights_init)
35 |
36 | def forward(self, inputs, disable_detection=False):
37 | visual_features = []
38 | if self.use_visual:
39 | visual_features.append(self.rgb_net(inputs['rgb']))
40 | visual_features.append(self.depth_net(inputs['depth']))
41 |
42 | if len(visual_features) != 0:
43 | # concatenate channel-wise
44 | concat_visual_features = torch.cat(visual_features, dim=1)
45 | concat_visual_features = self.conv1x1(concat_visual_features)
46 | concat_visual_features = self.pooling(concat_visual_features)
47 | else:
48 | concat_visual_features = None
49 |
50 | # Dereverber
51 | pred_stft, audio_feat, source_detection = self.audio_net(inputs['input_stft'], concat_visual_features, disable_detection)
52 | output = {'pred_stft': pred_stft}
53 |
54 | # Source identifier
55 | output['source_detection'] = source_detection
56 |
57 | if len(visual_features) != 0:
58 | audio_embed = self.pooling(audio_feat).squeeze(-1).squeeze(-1)
59 | visual_embed = concat_visual_features.squeeze(-1).squeeze(-1)
60 | output['audio_feat'] = F.normalize(audio_embed, p=2, dim=1)
61 | output['visual_feat'] = F.normalize(visual_embed, p=2, dim=1)
62 |
63 | return output
64 |
65 |
66 | def unet_conv(input_nc, output_nc, use_norm=False, norm_layer=nn.BatchNorm2d):
67 | downconv = nn.Conv2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1)
68 | downrelu = nn.LeakyReLU(0.2, True)
69 | if use_norm:
70 | downnorm = norm_layer(output_nc)
71 | return nn.Sequential(*[downconv, downnorm, downrelu])
72 | else:
73 | return nn.Sequential(*[downconv, downrelu])
74 |
75 |
76 | def unet_upconv(input_nc, output_nc, outermost=False, use_sigmoid=False, use_tanh=False, use_norm=False, norm_layer=nn.BatchNorm2d):
77 | upconv = nn.ConvTranspose2d(input_nc, output_nc, kernel_size=4, stride=2, padding=1)
78 | uprelu = nn.ReLU(True)
79 | if use_norm and not outermost:
80 | upnorm = norm_layer(output_nc)
81 | return nn.Sequential(*[upconv, upnorm, uprelu])
82 | else:
83 | if outermost:
84 | if use_sigmoid:
85 | return nn.Sequential(*[upconv, nn.Sigmoid()])
86 | elif use_tanh:
87 | return nn.Sequential(*[upconv, nn.Tanh()])
88 | else:
89 | return nn.Sequential(*[upconv])
90 | else:
91 | return nn.Sequential(*[upconv, uprelu])
92 |
93 |
94 | class conv_block(nn.Module):
95 | def __init__(self, ch_in, ch_out, use_norm=False):
96 | super(conv_block, self).__init__()
97 | if use_norm:
98 | self.conv = nn.Sequential(
99 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
100 | nn.BatchNorm2d(ch_out),
101 | nn.LeakyReLU(0.2, True),
102 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
103 | nn.BatchNorm2d(ch_out),
104 | nn.LeakyReLU(0.2, True)
105 | )
106 | else:
107 | self.conv = nn.Sequential(
108 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
109 | nn.LeakyReLU(0.2, True),
110 | nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
111 | nn.LeakyReLU(0.2, True)
112 | )
113 |
114 | def forward(self, x):
115 | x = self.conv(x)
116 | return x
117 |
118 |
119 | class up_conv(nn.Module):
120 | def __init__(self, ch_in, ch_out, outermost=False, use_norm=False, scale_factor=(2., 1.)):
121 | super(up_conv, self).__init__()
122 | if not outermost:
123 | if use_norm:
124 | self.up = nn.Sequential(
125 | nn.Upsample(scale_factor=scale_factor),
126 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
127 | nn.BatchNorm2d(ch_out),
128 | nn.ReLU(inplace=True)
129 | )
130 | else:
131 | self.up = nn.Sequential(
132 | nn.Upsample(scale_factor=scale_factor),
133 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
134 | nn.ReLU(inplace=True)
135 | )
136 | else:
137 | self.up = nn.Sequential(
138 | nn.Upsample(scale_factor=scale_factor),
139 | nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
140 | nn.Sigmoid()
141 | )
142 |
143 | def forward(self, x):
144 | x = self.up(x)
145 | return x
146 |
147 |
148 | def create_conv(input_channels, output_channels, kernel, paddings, batch_norm=False, use_relu=True, stride=1):
149 | model = [nn.Conv2d(input_channels, output_channels, kernel, stride=stride, padding=paddings)]
150 | if batch_norm:
151 | model.append(nn.BatchNorm2d(output_channels))
152 | if use_relu:
153 | model.append(nn.ReLU())
154 | return nn.Sequential(*model)
155 |
156 |
157 | def weights_init(m):
158 | classname = m.__class__.__name__
159 | if classname.find('Conv') != -1:
160 | m.weight.data.normal_(0.0, 0.02)
161 | elif classname.find('BatchNorm2d') != -1:
162 | m.weight.data.normal_(1.0, 0.02)
163 | m.bias.data.fill_(0)
164 | elif classname.find('Linear') != -1:
165 | m.weight.data.normal_(0.0, 0.02)
166 |
167 |
168 | class VisualNet(nn.Module):
169 | def __init__(self, original_resnet, num_channel=3):
170 | super(VisualNet, self).__init__()
171 | original_resnet.conv1 = nn.Conv2d(num_channel, 64, kernel_size=7, stride=2, padding=3, bias=False)
172 | layers = list(original_resnet.children())[0:-2]
173 | self.feature_extraction = nn.Sequential(*layers) # features before conv1x1
174 |
175 | def forward(self, x):
176 | x = self.feature_extraction(x)
177 | return x
178 |
179 |
180 | class AudioNet(nn.Module):
181 | def __init__(self, ngf=64, input_nc=2, output_nc=2, use_visual=False, use_norm=True, audioVisual_feature_dim=512):
182 | super(AudioNet, self).__init__()
183 |
184 | self.use_visual = use_visual
185 | if use_visual:
186 | audioVisual_feature_dim += 512
187 |
188 | # initialize layers
189 | self.audionet_convlayer1 = unet_conv(input_nc, ngf, use_norm=False)
190 | self.audionet_convlayer2 = unet_conv(ngf, ngf * 2, use_norm=use_norm)
191 | self.audionet_convlayer3 = conv_block(ngf * 2, ngf * 4, use_norm=use_norm)
192 | self.audionet_convlayer4 = conv_block(ngf * 4, ngf * 8, use_norm=use_norm)
193 | self.audionet_convlayer5 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm)
194 | self.audionet_convlayer6 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm)
195 | self.audionet_convlayer7 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm)
196 | self.audionet_convlayer8 = conv_block(ngf * 8, ngf * 8, use_norm=use_norm)
197 | self.frequency_time_pool = nn.MaxPool2d([2, 2])
198 | self.frequency_pool = nn.MaxPool2d([2, 1])
199 | self.audionet_upconvlayer1 = up_conv(audioVisual_feature_dim, ngf * 8, use_norm=use_norm)
200 | self.audionet_upconvlayer2 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm)
201 | self.audionet_upconvlayer3 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm, scale_factor=(2., 2.))
202 | self.audionet_upconvlayer4 = up_conv(ngf * 16, ngf * 8, use_norm=use_norm, scale_factor=(2., 2.))
203 | self.audionet_upconvlayer5 = up_conv(ngf * 16, ngf * 4, use_norm=use_norm, scale_factor=(2., 2.))
204 | self.audionet_upconvlayer6 = up_conv(ngf * 8, ngf * 2, use_norm=use_norm, scale_factor=(2., 2.))
205 | self.audionet_upconvlayer7 = unet_upconv(ngf * 4, ngf, use_norm=use_norm)
206 | self.audionet_upconvlayer8 = unet_upconv(ngf * 2, output_nc, True, use_norm=use_norm)
207 | self.Sigmoid = nn.Sigmoid()
208 | self.Tanh = nn.Tanh()
209 |
210 | # Source identifier
211 | self.source_detector = nn.Sequential(
212 | nn.Conv2d(audioVisual_feature_dim, 64, kernel_size=3, stride=1, padding=1), # maintains spatial dimensions as 4x4
213 | nn.ReLU(),
214 | nn.MaxPool2d(kernel_size=2, stride=2), # reduces spatial dimensions to 2x2
215 | nn.Conv2d(64, 32, kernel_size=2, stride=1), # reduces spatial dimensions to 1x1
216 | nn.ReLU(),
217 | nn.Flatten(),
218 | nn.Linear(32, 1),
219 | )
220 |
221 | def forward(self, audio_mix_stft, visual_feat, disable_detection=False):
222 | audio_conv1feature = self.audionet_convlayer1(audio_mix_stft)
223 | audio_conv2feature = self.audionet_convlayer2(audio_conv1feature)
224 | audio_conv3feature = self.audionet_convlayer3(audio_conv2feature)
225 | audio_conv3feature = self.frequency_time_pool(audio_conv3feature)
226 | audio_conv4feature = self.audionet_convlayer4(audio_conv3feature)
227 | audio_conv4feature = self.frequency_time_pool(audio_conv4feature)
228 | audio_conv5feature = self.audionet_convlayer5(audio_conv4feature)
229 | audio_conv5feature = self.frequency_time_pool(audio_conv5feature)
230 | audio_conv6feature = self.audionet_convlayer6(audio_conv5feature)
231 | audio_conv6feature = self.frequency_time_pool(audio_conv6feature)
232 | audio_conv7feature = self.audionet_convlayer7(audio_conv6feature)
233 | audio_conv7feature = self.frequency_pool(audio_conv7feature)
234 | audio_conv8feature = self.audionet_convlayer8(audio_conv7feature)
235 | audio_conv8feature = self.frequency_pool(audio_conv8feature)
236 |
237 | audioVisual_feature = audio_conv8feature
238 | if self.use_visual:
239 | visual_feat = visual_feat.view(visual_feat.shape[0], -1, 1, 1) # flatten visual feature
240 | visual_feat = visual_feat.repeat(1, 1, audio_conv8feature.shape[-2],
241 | audio_conv8feature.shape[-1]) # tile visual feature
242 |
243 | audioVisual_feature = torch.cat((visual_feat, audioVisual_feature), dim=1)
244 |
245 | if not disable_detection:
246 | source_detection = self.source_detector(audioVisual_feature)
247 | else:
248 | source_detection = torch.tensor([[0.0]])
249 |
250 | audio_upconv1feature = self.audionet_upconvlayer1(audioVisual_feature)
251 | audio_upconv2feature = self.audionet_upconvlayer2(torch.cat((audio_upconv1feature, audio_conv7feature), dim=1))
252 | audio_upconv3feature = self.audionet_upconvlayer3(torch.cat((audio_upconv2feature, audio_conv6feature), dim=1))
253 | audio_upconv4feature = self.audionet_upconvlayer4(torch.cat((audio_upconv3feature, audio_conv5feature), dim=1))
254 | audio_upconv5feature = self.audionet_upconvlayer5(torch.cat((audio_upconv4feature, audio_conv4feature), dim=1))
255 | audio_upconv6feature = self.audionet_upconvlayer6(torch.cat((audio_upconv5feature, audio_conv3feature), dim=1))
256 | audio_upconv7feature = self.audionet_upconvlayer7(torch.cat((audio_upconv6feature, audio_conv2feature), dim=1))
257 | prediction = self.audionet_upconvlayer8(torch.cat((audio_upconv7feature, audio_conv1feature), dim=1))
258 |
259 | pred_stft = prediction
260 |
261 | return pred_stft, audio_conv8feature, source_detection
262 |
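263 | 
264 | if __name__ == '__main__':
265 |     # Minimal smoke-test sketch. The STFT shape below (1024 frequency bins x 256
266 |     # frames) is an assumption inferred from the pooling structure above (conv8
267 |     # features end up 4x4) and from target_shape_t=256 in the data-generation
268 |     # scripts; adjust it to the actual front-end STFT if it differs.
269 |     model = NVASNet(num_receivers=4, use_visual=False)
270 |     dummy_inputs = {'input_stft': torch.randn(1, 2 * 4, 1024, 256)}  # [B, 2*num_receivers, F, T]
271 |     output = model(dummy_inputs)
272 |     print(output['pred_stft'].shape)         # torch.Size([1, 2, 1024, 256])
273 |     print(output['source_detection'].shape)  # torch.Size([1, 1])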
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/generate_training_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import glob
8 | import json
9 | import random
10 | import shutil
11 | from tqdm import tqdm
12 |
13 | import torch
14 | import torchaudio
15 |
16 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
17 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list
18 | from nvas3d.utils.audio_utils import clip_two
19 | from nvas3d.utils.utils import normalize, parse_librispeech_metadata, MP3D_SCENE_SPLITS
20 | from nvas3d.utils.generate_dataset_utils import sample_speech, sample_nonspeech, sample_acoustic_guitar, sample_all, clip_source, load_ir_source_receiver, save_audio_list, compute_reverb
21 |
22 | SOURCE1_DATA = 'all' # speech, nonspeech, guitar
23 | SOURCE2_DATA = 'all'
24 |
25 | random.seed(42)
26 |
27 | DATASET_NAME = f'nvas3d_square_{SOURCE1_DATA}_{SOURCE2_DATA}'
28 | os.makedirs(f'data/{DATASET_NAME}', exist_ok=True)
29 |
30 | grid_distance = 1.0
31 | grid_distance_str = str(grid_distance).replace(".", "_")
32 | target_shape_t = 256
33 | ir_length = 72000
34 | ir_clip_idx = ir_length - 1
35 | hop_length = 480
36 | len_clip = hop_length * (target_shape_t - 1) + ir_length - 1
37 | sample_rate = 48000
38 | snr = 100
39 | audio_format = 'flac'
40 |
41 | for split in ['train', 'val', 'test', 'demo']:
42 | # LibriSpeech
43 | if split == 'train':
44 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/train'
45 | elif split == 'val':
46 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation'
47 | elif split == 'test':
48 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/test'
49 | else:
50 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation'
51 | files_librispeech = glob.glob(librispeech_dir + '/**/*.flac', recursive=True)
52 | librispeech_metadata = parse_librispeech_metadata(f'data/MIDI/clip/Speech/LibriSpeech48k/SPEAKERS.TXT')
53 |
54 | # MIDI
55 | if split == 'train':
56 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'train')) if os.path.isdir(path)]
57 | elif split == 'val':
58 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)]
59 | elif split == 'test':
60 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'test')) if os.path.isdir(path)]
61 | else:
62 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)]
63 |
64 | # RIR
65 | ir_dir = f'data/nvas3d_square/ir/{split}/grid_{grid_distance_str}'
66 |
67 | # Image
68 | dirname_sourceimage = f'data/nvas3d_square/image/{split}/grid_{grid_distance_str}'
69 |
70 | # Iterate over rooms
71 | for i_room, room in enumerate(tqdm(MP3D_SCENE_SPLITS[split])):
72 | grid_points = load_room_grid(room, grid_distance)['grid_points']
73 | total_pairs = []
74 | filename = f'data/nvas3d_square/metadata/grid_{grid_distance_str}/{room}_square.json'
75 | with open(filename, 'r') as file:
76 | square_data = json.load(file)
77 | pairs = square_data['selected_pairs']
78 | # Add each pair with room id to the total list
79 | for pair in pairs:
80 | total_pairs.append((room,) + tuple(pair))
81 |
82 | random.shuffle(total_pairs)
83 |
84 | for i_pair, pair in enumerate(tqdm(total_pairs)):
85 | dirname = f'data/{DATASET_NAME}/{split}/{room}/{i_pair}'
86 | # os.makedirs(dirname, exist_ok=True)
87 |
88 | room, source_idx_list, receiver_idx_list, novel_receiver_idx = pair
89 |
90 | # Compute source
91 | if SOURCE1_DATA == 'speech':
92 | source1_audio, source1_class = sample_speech(files_librispeech, librispeech_metadata)
93 | elif SOURCE1_DATA == 'nonspeech':
94 | source1_audio, source1_class = sample_nonspeech(all_instruments_dir)
95 | elif SOURCE1_DATA == 'guitar':
96 | source1_audio, source1_class = sample_acoustic_guitar(all_instruments_dir)
97 | else:
98 | source1_audio, source1_class = sample_all(all_instruments_dir, librispeech_metadata)
99 |
100 | if SOURCE2_DATA == 'speech':
101 | source2_audio, source2_class = sample_speech(files_librispeech, librispeech_metadata)
102 | elif SOURCE2_DATA == 'nonspeech':
103 | source2_audio, source2_class = sample_nonspeech(all_instruments_dir)
104 | elif SOURCE2_DATA == 'guitar':
105 | source2_audio, source2_class = sample_acoustic_guitar(all_instruments_dir)
106 | else:
107 | source2_audio, source2_class = sample_all(all_instruments_dir, librispeech_metadata)
108 |
109 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio)
110 | if not split == 'test':
111 | source1_audio, source2_audio = clip_source(source1_audio, source2_audio, len_clip)
112 |
113 | source1_audio = normalize(source1_audio)
114 | source2_audio = normalize(source2_audio)
115 |
116 | if torch.isnan(source1_audio).any() or torch.isnan(source2_audio).any():
117 | continue
118 |
119 | if split == 'test':
120 | if source1_audio.shape[0] < sample_rate * 10: # skip for short (<10s)
121 | continue
122 |
123 | os.makedirs(f'{dirname}/source', exist_ok=True)
124 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[ir_clip_idx:].unsqueeze(0), sample_rate)
125 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[ir_clip_idx:].unsqueeze(0), sample_rate)
126 |
127 | # Save IR
128 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True)
129 | ir1_list = load_ir_source_receiver(ir_dir, room, source_idx_list[0], receiver_idx_list, ir_length)
130 | ir2_list = load_ir_source_receiver(ir_dir, room, source_idx_list[1], receiver_idx_list, ir_length)
131 | ir3_list = load_ir_source_receiver(ir_dir, room, source_idx_list[2], receiver_idx_list, ir_length)
132 | save_audio_list(f'{dirname}/ir_receiver/ir1', ir1_list, sample_rate, audio_format)
133 | save_audio_list(f'{dirname}/ir_receiver/ir2', ir2_list, sample_rate, audio_format)
134 | save_audio_list(f'{dirname}/ir_receiver/ir3', ir3_list, sample_rate, audio_format)
135 |
136 | os.makedirs(f'{dirname}/ir_novel', exist_ok=True)
137 | ir1_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[0], [novel_receiver_idx], ir_length)[0]
138 | ir2_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[1], [novel_receiver_idx], ir_length)[0]
139 | torchaudio.save(f'{dirname}/ir_novel/ir1_novel.{audio_format}', ir1_novel.unsqueeze(0), sample_rate)
140 | torchaudio.save(f'{dirname}/ir_novel/ir2_novel.{audio_format}', ir2_novel.unsqueeze(0), sample_rate)
141 |
142 | # Save reverb
143 | os.makedirs(f'{dirname}/reverb', exist_ok=True)
144 | reverb1_list = compute_reverb(source1_audio, ir1_list)
145 | reverb2_list = compute_reverb(source2_audio, ir2_list)
146 | save_audio_list(f'{dirname}/reverb/reverb1', reverb1_list, sample_rate, audio_format)
147 | save_audio_list(f'{dirname}/reverb/reverb2', reverb2_list, sample_rate, audio_format)
148 |
149 | # Compute receiver
150 | os.makedirs(f'{dirname}/receiver', exist_ok=True)
151 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)]
152 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format)
153 |
154 |             # Save Wiener
155 | os.makedirs(f'{dirname}/wiener', exist_ok=True)
156 | wiener1_list = wiener_deconv_list(receiver_list, ir1_list, snr)
157 | wiener2_list = wiener_deconv_list(receiver_list, ir2_list, snr)
158 | wiener3_list = wiener_deconv_list(receiver_list, ir3_list, snr)
159 | save_audio_list(f'{dirname}/wiener/wiener1', wiener1_list, sample_rate, audio_format)
160 | save_audio_list(f'{dirname}/wiener/wiener2', wiener2_list, sample_rate, audio_format)
161 | save_audio_list(f'{dirname}/wiener/wiener3', wiener3_list, sample_rate, audio_format)
162 |
163 | # Copy image
164 | os.makedirs(f'{dirname}/image', exist_ok=True)
165 |
166 | if source1_class.lower() == 'male' or source1_class.lower() == 'female':
167 | source1_class_render = source1_class.lower()
168 | else:
169 | source1_class_render = 'guitar'
170 |
171 | if source2_class.lower() == 'male' or source2_class.lower() == 'female':
172 | source2_class_render = source2_class.lower()
173 | else:
174 | source2_class_render = 'guitar'
175 | rgb1 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source1_class_render}_{source_idx_list[0]}.npy'
176 | rgb2 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source2_class_render}_{source_idx_list[1]}.npy'
177 | depth1 = f'{dirname_sourceimage}/{room}/depth_{room}_{source1_class_render}_{source_idx_list[0]}.npy'
178 | depth2 = f'{dirname_sourceimage}/{room}/depth_{room}_{source2_class_render}_{source_idx_list[1]}.npy'
179 |
180 | rgb1_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[0]}.npy'
181 | rgb2_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[1]}.npy'
182 | rgb3_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[2]}.npy'
183 | depth1_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[0]}.npy'
184 | depth2_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[1]}.npy'
185 | depth3_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[2]}.npy'
186 |
187 | shutil.copy(rgb1, f'{dirname}/image/rgb1.npy')
188 | shutil.copy(rgb2, f'{dirname}/image/rgb2.npy')
189 | shutil.copy(depth1, f'{dirname}/image/depth1.npy')
190 | shutil.copy(depth2, f'{dirname}/image/depth2.npy')
191 |
192 | shutil.copy(rgb1_bg, f'{dirname}/image/rgb1_bg.npy')
193 | shutil.copy(rgb2_bg, f'{dirname}/image/rgb2_bg.npy')
194 | shutil.copy(rgb3_bg, f'{dirname}/image/rgb3_bg.npy')
195 | shutil.copy(depth1_bg, f'{dirname}/image/depth1_bg.npy')
196 | shutil.copy(depth2_bg, f'{dirname}/image/depth2_bg.npy')
197 | shutil.copy(depth3_bg, f'{dirname}/image/depth3_bg.npy')
198 |
199 | # # png (optional)
200 | # rgb1 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source1_class_render}_{source_idx_list[0]}.png'
201 | # rgb2 = f'{dirname_sourceimage}/{room}/rgb_{room}_{source2_class_render}_{source_idx_list[1]}.png'
202 | # depth1 = f'{dirname_sourceimage}/{room}/depth_{room}_{source1_class_render}_{source_idx_list[0]}.png'
203 | # depth2 = f'{dirname_sourceimage}/{room}/depth_{room}_{source2_class_render}_{source_idx_list[1]}.png'
204 |
205 | # rgb1_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[0]}.png'
206 | # rgb2_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[1]}.png'
207 | # rgb3_bg = f'{dirname_sourceimage}/{room}/rgb_{room}_{source_idx_list[2]}.png'
208 | # depth1_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[0]}.png'
209 | # depth2_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[1]}.png'
210 | # depth3_bg = f'{dirname_sourceimage}/{room}/depth_{room}_{source_idx_list[2]}.png'
211 |
212 | # shutil.copy(rgb1, f'{dirname}/image/rgb1.png')
213 | # shutil.copy(rgb2, f'{dirname}/image/rgb2.png')
214 | # shutil.copy(depth1, f'{dirname}/image/depth1.png')
215 | # shutil.copy(depth2, f'{dirname}/image/depth2.png')
216 |
217 | # shutil.copy(rgb1_bg, f'{dirname}/image/rgb1_bg.png')
218 | # shutil.copy(rgb2_bg, f'{dirname}/image/rgb2_bg.png')
219 | # shutil.copy(rgb3_bg, f'{dirname}/image/rgb3_bg.png')
220 | # shutil.copy(depth1_bg, f'{dirname}/image/depth1_bg.png')
221 | # shutil.copy(depth2_bg, f'{dirname}/image/depth2_bg.png')
222 | # shutil.copy(depth3_bg, f'{dirname}/image/depth3_bg.png')
223 |
224 | # Save metadata
225 | metadata = {
226 | 'source1_idx': source_idx_list[0],
227 | 'source2_idx': source_idx_list[1],
228 | 'source3_idx': source_idx_list[2],
229 | 'receiver_idx_list': receiver_idx_list,
230 | 'novel_receiver_idx': novel_receiver_idx,
231 | 'source1_class': source1_class,
232 | 'source2_class': source2_class,
233 | 'grid_points': grid_points,
234 | 'room': room,
235 | }
236 | torch.save(metadata, f'{dirname}/metadata.pt')
237 |
238 | pass
239 |
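240 | # Each pair directory data/{DATASET_NAME}/{split}/{room}/{i_pair}/ ends up with
241 | # source/, ir_receiver/, ir_novel/, reverb/, receiver/, wiener/, image/ and
242 | # metadata.pt, mirroring the saves above. With the constants at the top,
243 | #   len_clip = 480 * (256 - 1) + 72000 - 1 = 194,399 samples (~4.05 s at 48 kHz),
244 | # i.e. hop_length * (target_shape_t - 1) samples remain for the STFT frames after
245 | # the first ir_length - 1 warm-up samples are dropped.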
--------------------------------------------------------------------------------
/demo/generate_demo_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import json
8 | import random
9 | import argparse
10 | import subprocess
11 | import typing as T
12 |
13 | import torch
14 | import torchaudio
15 |
16 | from soundspaces_nvas3d.utils.ss_utils import render_ir_parallel_room_idx, create_scene
17 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
18 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list
19 | from nvas3d.utils.audio_utils import clip_two
20 | from nvas3d.utils.utils import normalize
21 | from nvas3d.utils.generate_dataset_utils import load_ir_source_receiver, save_audio_list, compute_reverb
22 |
23 |
24 | def generate_rir(
25 | args: argparse.Namespace,
26 | room: str,
27 | source_idx_list: T.List[int],
28 | receiver_idx_list: T.List[int]
29 | ):
30 | """
 31 |     Generates and saves Room Impulse Response (RIR) data for all (source, receiver) pairs from source_idx_list and receiver_idx_list.
32 |
33 | Args:
34 | - args: Parsed command line arguments for dirname and grid distance.
35 | - room: Name of the room.
36 | - source_idx_list: List of source indices.
37 | - receiver_idx_list: List of receiver indices.
38 | """
39 |
40 | ir_dir = f'data/{args.dataset_dir}/temp/ir/grid_{str(args.grid_distance).replace(".", "_")}'
41 | filename_ir = f'{ir_dir}/{room}/ir'
42 | os.makedirs(filename_ir, exist_ok=True)
43 | render_ir_parallel_room_idx(room, source_idx_list, receiver_idx_list, filename_ir, args.grid_distance)
44 |
45 |
46 | def visualize_grid(
47 | args: argparse.Namespace,
48 | room: str,
49 | image_size: T.Tuple[int, int]
50 | ):
51 | """
52 | Visualizes grid points for a given room.
53 |
54 | Args:
55 | - args: Parsed command line arguments for dirname and grid distance.
56 | - room: Name of the room.
57 | - image_size: Dimensions of the output image.
58 | """
59 |
60 | scene = create_scene(room, image_size=image_size)
61 | os.makedirs(f'data/{args.dataset_dir}/demo/{room}', exist_ok=True)
62 | scene.generate_xy_grid_points(args.grid_distance, filename_png=f'data/{args.dataset_dir}/demo/{room}/index_{str(args.grid_distance).replace(".", "_")}.png')
63 | scene.generate_xy_grid_points(0.5, filename_png=f'data/{args.dataset_dir}/demo/{room}/index_0_5.png')
64 |
65 |
66 | def process_audio_sources(
67 | dirname: str,
68 | source1_path: str,
69 | source2_path: str,
70 | rir_clip_idx: int,
71 | sample_rate: int,
72 | audio_format: str
73 | ) -> T.Tuple[torch.Tensor, torch.Tensor]:
74 | """
 75 |     Preprocesses and saves the dry source audio.
76 |
77 | Args:
78 | - dirname: Directory name for saving the processed audio.
79 | - source1_path: Path to the first audio source.
80 | - source2_path: Path to the second audio source.
81 | - rir_clip_idx: Index for clipping the RIR.
82 | - sample_rate: Sampling rate of the audio.
83 | - audio_format: Format to save the audio (e.g., 'flac', 'wav').
84 |
85 | Returns:
86 | - Processed audio tensors for source1 and source2.
87 | """
88 |
89 | source1_audio, _ = torchaudio.load(source1_path)
90 | source2_audio, _ = torchaudio.load(source2_path)
91 | source1_audio = source1_audio.reshape(-1)
92 | source2_audio = source2_audio.reshape(-1)
93 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio)
94 | source1_audio = normalize(source1_audio)
95 | source2_audio = normalize(source2_audio)
96 | os.makedirs(f'{dirname}/source', exist_ok=True)
97 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[rir_clip_idx:].unsqueeze(0), sample_rate)
98 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[rir_clip_idx:].unsqueeze(0), sample_rate)
99 |
100 | return source1_audio, source2_audio
101 |
102 |
103 | def save_rir_data(
104 | dirname: str,
105 | rir_dir: str,
106 | room: str,
107 | source_idx_list: T.List[int],
108 | receiver_idx_list: T.List[int],
109 | rir_length: int,
110 | sample_rate: int,
111 | audio_format: str
112 | ) -> T.Tuple[T.List[torch.Tensor], T.List[torch.Tensor]]:
113 | """
114 | Saves RIR data for the pair of sources and receivers.
115 |
116 | Args:
117 | - dirname: Directory name for saving the RIR data.
118 | - rir_dir: Directory where RIR data is located.
119 | - room: Name of the room.
120 | - source_idx_list: List of source indices.
121 | - receiver_idx_list: List of receiver indices.
122 | - rir_length: Length of the RIR.
123 | - sample_rate: Sampling rate of the audio.
124 | - audio_format: Format to save the audio (e.g., 'flac', 'wav').
125 |
126 | Returns:
127 | - Lists of IRs from source1 and source2.
128 | """
129 |
130 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True)
131 | rir1_list = load_ir_source_receiver(rir_dir, room, source_idx_list[0], receiver_idx_list, rir_length)
132 | rir2_list = load_ir_source_receiver(rir_dir, room, source_idx_list[1], receiver_idx_list, rir_length)
133 | save_audio_list(f'{dirname}/ir_receiver/ir1', rir1_list, sample_rate, audio_format)
134 | save_audio_list(f'{dirname}/ir_receiver/ir2', rir2_list, sample_rate, audio_format)
135 | return rir1_list, rir2_list
136 |
137 |
138 | def reverb_audio(
139 | filename: str,
140 | source_audio: torch.Tensor,
141 | rir_list: T.List[torch.Tensor],
142 | sample_rate: int,
143 | audio_format: str
144 | ) -> T.List[torch.Tensor]:
145 | """
146 |     Applies reverberation to the source audio using the provided RIRs.
147 |
148 | Args:
149 | - filename: Directory name for saving the reverberant audio.
150 | - source_audio: Source audio tensor.
151 | - rir_list: List of RIR tensors.
152 | - sample_rate: Sampling rate of the audio.
153 | - audio_format: Format to save the audio (e.g., 'flac', 'wav').
154 |
155 | Returns:
156 | - List of reverberated audio.
157 | """
158 |
159 | reverb_list = compute_reverb(source_audio, rir_list)
160 | save_audio_list(filename, reverb_list, sample_rate, audio_format)
161 | return reverb_list
162 |
163 |
164 | def mix_audio(
165 | dirname: str,
166 | reverb1_list: T.List[torch.Tensor],
167 | reverb2_list: T.List[torch.Tensor],
168 | sample_rate: int,
169 | audio_format: str
170 | ) -> T.List[torch.Tensor]:
171 | """
172 |     Mixes the two sets of reverberant audio.
173 |
174 | Args:
175 | - dirname: Directory name for saving the mixed audio.
176 | - reverb1_list: List of the first reverberant audio.
177 | - reverb2_list: List of the second reverberant audio.
178 | - sample_rate: Sampling rate of the audio.
179 | - audio_format: Format to save the audio (e.g., 'flac').
180 |
181 | Returns:
182 | - List containing mixed audio.
183 | """
184 |
185 | os.makedirs(f'{dirname}/receiver', exist_ok=True)
186 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)]
187 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format)
188 | return receiver_list
189 |
190 |
191 | def save_topdown_view(
192 | function_name: str,
193 | args_list: T.List[T.Union[str, int, float]]
194 | ):
195 | """
196 | Saves a topdown view of a given scene.
197 |
198 | Args:
199 | - function_name: Name of the function to be used for visualization.
200 | - args_list: List of arguments for the visualization function.
201 | (room, source_idx_list, receiver_idx_list, grid_distance)
202 | """
203 |
204 | subprocess.run(["python", "soundspaces_nvas3d/utils/render_scene_script.py", function_name, *map(str, args_list)])
205 |
206 |
207 | def main(args):
208 | """
209 | Generate and save demo data for a specific room.
210 |
211 | Directory Structure and Contents:
212 | ├── {data_demo} = data/nvas3d_demo/demo/{room}/0
213 | │ ├── receiver/ : Receiver audio.
214 |     │   ├── wiener/         : Deconvolved audio, saved to speed up tests (wiener_{query_idx}_{receiver_id}).
215 |     │   ├── source/         : Ground truth dry audio.
216 |     │   ├── reverb/reverb1_* : Ground truth reverberant audio for source 1.
217 |     │   ├── reverb/reverb2_* : Ground truth reverberant audio for source 2.
218 | │ ├── ir_receiver/ : Ground truth RIRs from source to receiver.
219 | │ ├── config.png : Visualization of room configuration.
220 | │ └── metadata.pt : Metadata containing source indices, classes, grid points, and room information.
221 |
222 | Additional Visualizations:
223 | ├── data/nvas3d_demo/{room} : Room index visualizations.
224 | """
225 |
226 | # Seed for reproducibility
227 | random.seed(42)
228 |
229 | # Directory setup for data
230 | os.makedirs(f'data/{args.dataset_dir}', exist_ok=True)
231 |
232 | # Extract and load room and grid related data
233 | room = args.room
234 | grid_distance = args.grid_distance
235 | grid_points = load_room_grid(room, grid_distance)['grid_points']
236 | rir_length = args.rir_length
237 | sample_rate = args.sample_rate
238 | snr = args.snr
239 | audio_format = args.audio_format
240 | scene_config = args.scene_config
241 |
242 | source1_path = args.source1_path
243 | source2_path = args.source2_path
244 |
245 | # Set source and receiver locations
246 | with open(f'demo/config_demo/{scene_config}.json', 'r') as file:
247 | json_scene = json.load(file)
248 | source_idx_list = json_scene['source_idx_list']
249 | receiver_idx_list = json_scene['receiver_idx_list']
250 |
251 | # Generate and save RIR data
252 | generate_rir(args, room, list(range(len(grid_points))), receiver_idx_list)
253 |
254 | # Prepare directory for demo data
255 | dirname = f'data/{args.dataset_dir}/demo/{room}/0'
256 | os.makedirs(dirname, exist_ok=True)
257 |
258 | # Visualize the grid points within the room
259 | visualize_grid(args, room, (400, 300))
260 |
261 | # Preprocess audio
262 | rir_clip_idx = rir_length - 1
263 | source1_class, source2_class = 'female', 'drum'
264 | source1_audio, source2_audio = process_audio_sources(dirname, source1_path, source2_path, rir_clip_idx, sample_rate, audio_format)
265 |
266 | # Reverb audio for source 1 and source 2
267 | rir_dir = f'data/{args.dataset_dir}/temp/ir/grid_{str(grid_distance).replace(".", "_")}'
268 | rir1_list, rir2_list = save_rir_data(dirname, rir_dir, room, source_idx_list, receiver_idx_list, rir_length, sample_rate, audio_format)
269 | os.makedirs(f'{dirname}/reverb', exist_ok=True)
270 | reverb1_list = reverb_audio(f'{dirname}/reverb/reverb1', source1_audio, rir1_list, sample_rate, audio_format)
271 | reverb2_list = reverb_audio(f'{dirname}/reverb/reverb2', source2_audio, rir2_list, sample_rate, audio_format)
272 |
273 | # Mix both reverberant audios
274 | receiver_list = mix_audio(dirname, reverb1_list, reverb2_list, sample_rate, audio_format)
275 |
276 | # Visualize topdown view of the scene
277 | save_topdown_view("scene", [f'{dirname}/config.png', room, source_idx_list[:2], receiver_idx_list, args.grid_distance])
278 |
279 | # Save metadata
280 | metadata = {
281 | 'source1_idx': source_idx_list[0],
282 | 'source2_idx': source_idx_list[1],
283 | 'receiver_idx_list': receiver_idx_list,
284 | 'source1_class': source1_class,
285 | 'source2_class': source2_class,
286 | 'grid_points': grid_points,
287 | 'grid_distance': grid_distance,
288 | 'room': room,
289 | }
290 | torch.save(metadata, f'{dirname}/metadata.pt')
291 |
292 | # Save deconvolved audio using Wiener deconvolution
293 | for query_idx in range(len(grid_points)):
294 | if (query_idx in receiver_idx_list):
295 | continue
296 |
297 | # load RIR
298 | rir_query_list = load_ir_source_receiver(rir_dir, room, query_idx, receiver_idx_list, rir_length)
299 |
300 |         # save Wiener
301 | os.makedirs(f'{dirname}/wiener', exist_ok=True)
302 | wiener_list = wiener_deconv_list(receiver_list, rir_query_list, snr)
303 | save_audio_list(f'{dirname}/wiener/wiener{query_idx}', wiener_list, sample_rate, audio_format)
304 |
305 |
306 | if __name__ == '__main__':
307 | parser = argparse.ArgumentParser()
308 |
309 | parser.add_argument('--room', default='17DRP5sb8fy', type=str, help='mp3d room')
310 | parser.add_argument('--dataset_dir', default='nvas3d_demo', type=str, help='dirname')
311 | parser.add_argument('--grid_distance', default=1.0, type=float, help='Distance between grid points')
312 | parser.add_argument('--rir_length', default=72000, type=int, help='IR length')
313 | parser.add_argument('--sample_rate', default=48000, type=int, help='Sample rate')
314 | parser.add_argument('--snr', default=100, type=int, help='SNR for Wiener deconvolution')
315 | parser.add_argument('--audio_format', default='flac', type=str, help='Audio format to save')
316 | parser.add_argument('--scene_config', default='scene1_17DRP5sb8fy', type=str, help='Scene configuration json')
317 | parser.add_argument('--source1_path', default='data/source/female.flac', type=str, help='Filename for source 1')
318 | parser.add_argument('--source2_path', default='data/source/drum.flac', type=str, help='Filename for source 2')
319 |
320 | args = parser.parse_args()
321 | main(args)
322 |
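323 | # Example invocation (all values are the argparse defaults above; demo/run_demo.sh
324 | # is assumed to wrap a call like this):
325 | #   python demo/generate_demo_data.py \
326 | #       --room 17DRP5sb8fy \
327 | #       --scene_config scene1_17DRP5sb8fy \
328 | #       --source1_path data/source/female.flac \
329 | #       --source2_path data/source/drum.flac
330 | # Outputs are written under data/nvas3d_demo/demo/17DRP5sb8fy/0/, as described in main().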
--------------------------------------------------------------------------------
/soundspaces_nvas3d/utils/ss_utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import sys
8 | import random
9 | import itertools
10 | import typing as T
11 | import multiprocessing
12 | import matplotlib.pyplot as plt
13 | from contextlib import contextmanager
14 | from tqdm import tqdm
15 |
16 | import torch
17 | import torchaudio
18 |
19 | from soundspaces_nvas3d.soundspaces_nvas3d import Receiver, Source, Scene
20 | from soundspaces_nvas3d.utils.aihabitat_utils import save_grid_config, load_room_grid
21 |
22 |
23 | @contextmanager
24 | def suppress_stdout_and_stderr():
25 | """
26 | Suppress the logs from SoundSpaces.
27 | """
28 |
29 | original_stdout_fd = os.dup(sys.stdout.fileno())
30 | original_stderr_fd = os.dup(sys.stderr.fileno())
31 | devnull = os.open(os.devnull, os.O_WRONLY)
32 |
33 | try:
34 | os.dup2(devnull, sys.stdout.fileno())
35 | os.dup2(devnull, sys.stderr.fileno())
36 | yield
37 | finally:
38 | os.dup2(original_stdout_fd, sys.stdout.fileno())
39 | os.dup2(original_stderr_fd, sys.stderr.fileno())
40 | os.close(devnull)
41 |
42 |
43 | def create_scene(room: str,
44 | receiver_position: T.Tuple[float, float, float] = [0.0, 0.0, 0.0],
45 | sample_rate: float = 48000,
46 | image_size: T.Tuple[int, int] = (512, 256),
47 | include_visual_sensor: bool = True,
48 | hfov: float = 90.0
49 | ) -> Scene:
50 | """
51 | Create a soundspaces scene to render IR.
52 | """
53 |
54 | # Note: Make sure mp3d room is downloaded
55 | with suppress_stdout_and_stderr():
56 | # Create a receiver
57 | receiver = Receiver(
58 | position=receiver_position,
59 | rotation=0,
60 | sample_rate=sample_rate
61 | )
62 |
63 | scene = Scene(
64 | room,
65 | [None], # placeholder for source class
66 | receiver=receiver,
67 | include_visual_sensor=include_visual_sensor,
68 | add_source_mesh=False,
69 | device=torch.device('cpu'),
70 | add_source=False,
71 | image_size=image_size,
72 | hfov=hfov
73 | )
74 |
75 | return scene
76 |
77 |
78 | def render_ir(room: str,
79 | source_position: T.Tuple[float, float, float],
80 | receiver_position: T.Tuple[float, float, float],
81 | filename: str = None,
82 | receiver_rotation: float = None,
83 | sample_rate: float = 48000,
84 | use_default_material: bool = False,
85 | channel_type: str = 'Ambisonics',
86 | channel_order: int = 1
 87 |               ) -> T.Optional[torch.Tensor]:
88 | """
89 | Render impulse response for a source and receiver pair in the mp3d room.
90 | """
91 |
92 | if receiver_rotation is None:
93 | receiver_rotation = 90
94 |
95 | # Create a receiver
96 | receiver = Receiver(
97 | position=receiver_position,
98 | rotation=receiver_rotation,
99 | sample_rate=sample_rate
100 | )
101 |
102 | # Create a source
103 | source = Source(
104 | position=source_position,
105 | rotation=0,
106 | dry_sound='',
107 | mesh='',
108 | device=torch.device('cpu')
109 | )
110 |
111 | scene = Scene(
112 | room,
113 | [None], # placeholder for source class
114 | receiver=receiver,
115 | source_list=[source],
116 | include_visual_sensor=False,
117 | add_source_mesh=False,
118 | device=torch.device('cpu'),
119 | use_default_material=use_default_material,
120 | channel_type=channel_type,
121 | channel_order=channel_order
122 | )
123 |
124 | # Render IR
125 | scene.add_audio_sensor()
126 | # with suppress_stdout_and_stderr():
127 | ir = scene.render_ir(0)
128 |
129 | # Save file if dirname is given
130 | if filename is not None:
131 | torchaudio.save(filename, ir, sample_rate=sample_rate)
132 | else:
133 | return ir
134 |
135 |
136 | def render_rir_parallel(room_list: T.List[str],
137 | source_position_list: T.List[T.Tuple[float, float, float]],
138 | receiver_position_list: T.List[T.Tuple[float, float, float]],
139 | filename_list: T.List[str] = None,
140 | receiver_rotation_list: T.List[float] = None,
141 | batch_size: int = 64,
142 | sample_rate: float = 48000,
143 | use_default_material: bool = False,
144 | channel_type: str = 'Ambisonics',
145 | channel_order: int = 1
146 |                         ) -> T.Optional[T.List[torch.Tensor]]:
147 | """
148 |     Run render_ir in parallel for all elements of zip(source_position_list, receiver_position_list).
149 | """
150 |
151 | assert len(room_list) == len(source_position_list)
152 | assert len(source_position_list) == len(receiver_position_list)
153 |
154 | if filename_list is None:
155 | is_return = True
156 | else:
157 | is_return = False
158 |
159 | if receiver_rotation_list is None:
160 | receiver_rotation_list = [0] * len(receiver_position_list)
161 |
162 | # Note: Make sure all rooms are downloaded
163 |
164 | # Calculate the number of batches
165 | num_points = len(source_position_list)
166 | num_batches = (num_points + batch_size - 1) // batch_size
167 |
168 | # Use tqdm to display the progress bar
169 | progress_bar = tqdm(total=num_points)
170 |
171 | def update_progress(*_):
172 | progress_bar.update()
173 |
174 | ir_list = []
175 | # Process the tasks in batches
176 | for batch_idx in range(num_batches):
177 | # Calculate the start and end indices of the current batch
178 | start_idx = batch_idx * batch_size
179 | end_idx = min(start_idx + batch_size, num_points)
180 | if is_return:
181 | batch = [(room_list[i], source_position_list[i], receiver_position_list[i], None, receiver_rotation_list[i]) for i in range(start_idx, end_idx)]
182 | else:
183 | batch = [(room_list[i], source_position_list[i], receiver_position_list[i], filename_list[i], receiver_rotation_list[i]) for i in range(start_idx, end_idx)]
184 |
185 | # Create a multiprocessing Pool for the current batch
186 | with multiprocessing.Pool() as pool:
187 | tasks = []
188 | for room, source_position, receiver_position, filename, receiver_rotation in batch:
189 | # Apply async mapping of process_ir function
190 | task = pool.apply_async(render_ir, args=(room, source_position, receiver_position, filename, receiver_rotation, sample_rate, use_default_material, channel_type, channel_order), callback=update_progress)
191 | tasks.append(task)
192 |
193 | # Wait for all tasks in the batch to complete and collect results
194 | for task in tasks:
195 | if is_return:
196 | ir = task.get() # Block until the result is ready
197 | ir_list.append(ir) # Append the result to the list
198 | else:
199 | task.get()
200 | if is_return:
201 | return ir_list
202 |
203 |
204 | def render_ir_parallel_room_idx(room: str,
205 | source_idx_list: T.List[int],
206 | receiver_idx_list: T.List[int],
207 | filename: str = None,
208 | grid_distance=1.0,
209 | batch_size: int = 64,
210 | sample_rate: float = 48000,
211 | use_default_material: bool = False,
212 |                                 channel_type='Ambisonics'  # or 'Binaural'
213 |                                 ) -> T.Tuple[T.List[torch.Tensor], T.List[int], T.List[int]]:
214 | """
215 |     Run render_ir in parallel for all pairs from all_pairs(source_idx_list, receiver_idx_list).
216 | """
217 |
218 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points']
219 |
220 | source_idx_pair_list, receiver_idx_pair_list = all_pairs(source_idx_list, receiver_idx_list) # only for filename
221 | receiver_points = grid_points[receiver_idx_list]
222 | source_points = grid_points[source_idx_list]
223 |
224 | source_points_pair, receiver_points_pair = all_pairs(source_points, receiver_points)
225 |
226 | room_list = [room] * len(source_points_pair)
227 | if filename is not None:
228 | filename_list = [f'{filename}_{room}_{source_idx}_{receiver_idx}.wav'
229 | for source_idx, receiver_idx in zip(source_idx_pair_list, receiver_idx_pair_list)]
230 | else:
231 | filename_list = None
232 |
233 | # Render IR for grid points
234 | ir_list = render_rir_parallel(room_list,
235 | source_points_pair,
236 | receiver_points_pair,
237 | filename_list,
238 | batch_size=batch_size,
239 | sample_rate=sample_rate,
240 | use_default_material=use_default_material,
241 | channel_type=channel_type)
242 |
243 | return ir_list, source_idx_pair_list, receiver_idx_pair_list
244 |
245 |
246 | def render_receiver_image(dirname: str,
247 | room: str,
248 | source_idx_list: T.List[int],
249 | source_class_list: T.List[str],
250 | receiver_idx_list: T.List[int],
251 | filename: str = None,
252 | grid_distance=1.0,
253 | hfov=120,
254 | image_size=(1024, 1024)
255 | ):
256 |
257 | # load grid points
258 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points']
259 |
260 | receiver_points = grid_points[receiver_idx_list]
261 | source_points = grid_points[source_idx_list]
262 | # source_points_pair, receiver_points_pair = all_pairs(source_points, receiver_points)
263 |
264 | # initialize scene
265 | scene = create_scene(room, image_size=image_size, hfov=hfov)
266 | scene.add_source_mesh = True
267 |
268 | # initialize receiver
269 | sample_rate = 48000
270 | position = [0.0, 0, 0]
271 | rotation = 0.0
272 | receiver = Receiver(position, rotation, sample_rate)
273 |
274 | # set source
275 | source_list = []
276 | for source_idx, source_class in zip(source_idx_list, source_class_list):
277 | position = grid_points[source_idx]
278 | if source_class == 'male' or source_class == 'female':
279 | source = Source(
280 | position=position,
281 | rotation=random.uniform(0, 360),
282 | dry_sound='',
283 | mesh=source_class,
284 | device=torch.device('cpu')
285 | )
286 | source_list.append(source)
287 | else:
288 | source = Source(
289 | position=position,
290 | rotation=random.uniform(0, 360),
291 | dry_sound='',
292 | mesh='guitar', # All instruments use guitar mesh
293 | device=torch.device('cpu')
294 | )
295 | source_list.append(source)
296 |
297 | # add mesh
298 | scene.source_list = [None] * len(source_idx_list)
299 | for id, source in enumerate(source_list):
300 | scene.update_source(source, id)
301 |
302 | # see source1 direction
303 | source_x = source_points[0][0]
304 | source_z = source_points[0][2]
305 |
306 | # render images
307 | rgb_list = []
308 | depth_list = []
309 | for receiver_idx in receiver_idx_list:
310 | receiver.position = grid_points[receiver_idx]
311 | # all receiver sees source 1 direction
312 | rotation_source1 = calculate_degree(source_x - receiver.position[0], source_z - receiver.position[2])
313 | receiver.rotation = rotation_source1
314 | scene.update_receiver(receiver)
315 | rgb, depth = scene.render_image()
316 | rgb_list.append(rgb)
317 | depth_list.append(depth)
318 |
319 | # for index debug
320 | # source_idx_pair_list, receiver_idx_pair_list = all_pairs(source_idx_list, receiver_idx_list)
321 |
322 | for i in range(len(receiver_idx_list)):
323 | plt.imsave(f'{dirname}/{i+1}.png', rgb_list[i])
324 | plt.imsave(f'{dirname}/{i+1}_depth.png', depth_list[i], cmap='gray')
325 |
326 | # return rgb_list, depth_list
327 |
328 |
329 | def render_scene_config(filename: str,
330 | room: str,
331 | source_idx_list: T.List[int],
332 | receiver_idx_list: T.List[int],
333 | grid_distance=1.0,
334 | no_query=False):
335 |
336 | # load grid points
337 | grid_points = load_room_grid(room, grid_distance=grid_distance)['grid_points']
338 |
339 | # config
340 | scene = create_scene(room)
341 | save_grid_config(filename, scene.sim.pathfinder, grid_points, receiver_idx_list=receiver_idx_list, source_idx_list=source_idx_list, no_query=no_query)
342 |
343 |
344 | # Additional utility functions
345 | def all_pairs(list1, list2):
346 | list_pair = list(itertools.product(list1, list2))
347 |
348 | list1_pair, list2_pair = zip(*list_pair)
349 | list1_pair = list(list1_pair)
350 | list2_pair = list(list2_pair)
351 |
352 | return list1_pair, list2_pair
353 |
354 |
355 | def calculate_degree(x, y):
356 | radian = torch.atan2(y, x)
357 | degree = torch.rad2deg(radian)
358 |     # Map the atan2 angle onto the receiver rotation convention used by the scene
359 | degree = (-degree - 90) % 360
360 | return degree
361 |
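362 | # Usage sketches (room name and grid indices are illustrative; the mp3d scene and
363 | # its precomputed grid must be available locally):
364 | #
365 | #   all_pairs([0, 1], [10, 11])  # -> ([0, 0, 1, 1], [10, 11, 10, 11])
366 | #
367 | #   ir_list, src_pair_idx, rcv_pair_idx = render_ir_parallel_room_idx(
368 | #       '17DRP5sb8fy', source_idx_list=[3], receiver_idx_list=[10, 11],
369 | #       grid_distance=1.0)
370 | #   # ir_list[i] is the Ambisonics RIR for the pair (src_pair_idx[i], rcv_pair_idx[i])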
--------------------------------------------------------------------------------
/nvas3d/utils/training_data_generation/generate_test_data.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import glob
8 | import json
9 | import random
10 | import subprocess
11 | import concurrent.futures
12 | from itertools import product
13 |
14 | from tqdm import tqdm
15 |
16 | import torch
17 | import torchaudio
18 |
19 | from soundspaces_nvas3d.utils.aihabitat_utils import load_room_grid
20 | from soundspaces_nvas3d.utils.audio_utils import wiener_deconv_list
21 | from nvas3d.utils.audio_utils import clip_two
22 | from nvas3d.utils.utils import normalize, parse_librispeech_metadata, MP3D_SCENE_SPLITS
23 | from nvas3d.utils.generate_dataset_utils import sample_speech, sample_nonspeech, sample_acoustic_guitar, sample_all, sample_instrument, clip_source, load_ir_source_receiver, save_audio_list, compute_reverb
24 |
25 | os.makedirs('data/temp', exist_ok=True)
26 |
27 |
28 | SOURCE1_DATA = 'Guitar'
29 | SOURCE2_DATA = 'Guitar'
30 |
31 | num_id_per_room = 1
32 |
33 | random.seed(42)
34 |
35 | DATASET_NAME = f'nvas3d_square_{SOURCE1_DATA}_{SOURCE2_DATA}_queryall_{num_id_per_room}_v3'
36 | os.makedirs(f'data/{DATASET_NAME}', exist_ok=True)
37 |
38 | grid_distance = 1.0
39 | grid_distance_str = str(grid_distance).replace(".", "_")
40 | target_shape_t = 256
41 | ir_length = 72000
42 | ir_clip_idx = ir_length - 1
43 | hop_length = 480
44 | len_clip = hop_length * (target_shape_t - 1) + ir_length - 1
45 | sample_rate = 48000
46 | snr = 100
47 | audio_format = 'flac'
48 |
49 | for split in ['val']:
50 | # LibriSpeech
51 | if split == 'train':
52 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/train'
53 | elif split == 'val':
54 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation'
55 | elif split == 'test':
56 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/test'
57 | else:
58 | librispeech_dir = f'data/MIDI/clip/Speech/LibriSpeech48k/validation'
59 | files_librispeech = glob.glob(librispeech_dir + '/**/*.flac', recursive=True)
60 | librispeech_metadata = parse_librispeech_metadata(f'data/MIDI/clip/Speech/LibriSpeech48k/SPEAKERS.TXT')
61 |
62 | # MIDI
63 | if split == 'train':
64 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'train')) if os.path.isdir(path)]
65 | elif split == 'val':
66 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)]
67 | elif split == 'test':
68 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'test')) if os.path.isdir(path)]
69 | else:
70 | all_instruments_dir = [path for path in glob.glob(os.path.join('data/MIDI/clip', '*/*', 'validation')) if os.path.isdir(path)]
71 |
72 | # RIR
73 | if split == 'val_trainscene':
74 | split_scene = 'train'
75 | else:
76 | split_scene = split
77 | ir_dir = f'data/nvas3d_square/ir/{split_scene}/grid_{grid_distance_str}'
78 |
79 | # Image
80 | dirname_sourceimage = f'data/nvas3d_square/image/{split_scene}/grid_{grid_distance_str}'
81 |
82 | # Iterate over rooms
83 | for i_room, room in enumerate(tqdm(MP3D_SCENE_SPLITS[split_scene])):
84 | grid_points = load_room_grid(room, grid_distance)['grid_points']
85 | num_points = grid_points.shape[0]
86 | total_pairs = []
 87 |         filename = f'data/nvas3d_square/metadata/grid_{grid_distance_str}/{room}_square.json'  # from generate_metadata_square.py
88 | with open(filename, 'r') as file:
89 | square_data = json.load(file)
90 | pairs_all = square_data['selected_pairs']
 91 |         # Shuffle candidate pairs and keep only those whose novel receiver is not among the sources or receivers
92 |
93 | random.shuffle(pairs_all)
94 |
95 | # pairs = pairs[:num_id_per_room]
96 |
97 | pairs = []
98 | for pair in pairs_all:
99 | source_idx_list, receiver_idx_list, novel_receiver_idx = pair
100 | if (novel_receiver_idx not in source_idx_list) and (novel_receiver_idx not in receiver_idx_list):
101 | pairs.append(pair)
102 | # else:
103 | # print(f'invalid idx: {source_idx_list}, {receiver_idx_list}, {novel_receiver_idx}')
104 | if len(pairs) >= num_id_per_room:
105 | break
106 |
107 | # All IRs
108 | # Initialize a list to store all combinations
109 | all_combinations = []
110 |
111 | # Iterate over selected pairs
112 | for pair in pairs:
113 | # Unpack the pair
114 | _, receiver_idxs, _ = pair
115 | # Get all combinations of source and receiver indices
116 | comb = product(list(range(num_points)), receiver_idxs)
117 | # Add these combinations to the list
118 | all_combinations.extend(comb)
119 | all_combinations = list(set(all_combinations)) # remove redundancy
120 |
121 | # download wav files # Replace to render IR
122 | # temp_list = set()
123 | # with concurrent.futures.ThreadPoolExecutor() as executor:
124 | # for source_idx in executor.map(download_wav, all_combinations):
125 | # temp_list.add(source_idx)
126 | # temp_list = list(temp_list)
127 |
128 | # Render image
129 | dirname_target_image = f'data/{DATASET_NAME}/{split}/{room}/image'
130 | os.makedirs(dirname_target_image, exist_ok=True)
131 | query_idx_list = list(range(num_points))
132 | subprocess.run(['python', 'soundspaces_nvas3d/image_rendering/generate_target_image.py', '--room', room, '--dirname', dirname_target_image, '--source_idx_list', ' '.join(map(str, query_idx_list))])
133 |
134 | # For each pair, make data
135 | for i_pair, pair in enumerate(tqdm(pairs)):
136 | dirname = f'data/{DATASET_NAME}/{split}/{room}/{i_pair}'
137 | source_idx_list, receiver_idx_list, novel_receiver_idx = pair
138 |
139 | os.makedirs(dirname, exist_ok=True)
140 |
141 | # Compute source
142 | os.makedirs(f'{dirname}/source', exist_ok=True)
143 | if SOURCE1_DATA == 'speech':
144 | source1_audio, source1_class = sample_speech(files_librispeech, librispeech_metadata)
145 | elif SOURCE1_DATA == 'nonspeech':
146 | source1_audio, source1_class = sample_nonspeech(all_instruments_dir)
147 | elif SOURCE1_DATA == 'guitar':
148 | source1_audio, source1_class = sample_acoustic_guitar(all_instruments_dir)
149 | elif SOURCE1_DATA == 'all':
150 | source1_audio, source1_class = sample_all(all_instruments_dir, librispeech_metadata)
151 | else:
152 | source1_audio, source1_class = sample_instrument(all_instruments_dir, librispeech_metadata, SOURCE1_DATA)
153 |
154 | if SOURCE2_DATA == 'speech':
155 | source2_audio, source2_class = sample_speech(files_librispeech, librispeech_metadata)
156 | elif SOURCE2_DATA == 'nonspeech':
157 | source2_audio, source2_class = sample_nonspeech(all_instruments_dir)
158 | elif SOURCE2_DATA == 'guitar':
159 | source2_audio, source2_class = sample_acoustic_guitar(all_instruments_dir)
160 | elif SOURCE2_DATA == 'all':
161 |                 source2_audio, source2_class = sample_all(all_instruments_dir, librispeech_metadata)
162 | else:
163 | source2_audio, source2_class = sample_instrument(all_instruments_dir, librispeech_metadata, SOURCE2_DATA)
164 |
165 | source1_audio, source2_audio = clip_two(source1_audio, source2_audio)
166 | if not (split == 'test' or split == 'val'): # check
167 | source1_audio, source2_audio = clip_source(source1_audio, source2_audio, len_clip)
168 | source1_audio = normalize(source1_audio)
169 | source2_audio = normalize(source2_audio)
170 |
171 | if torch.isnan(source1_audio).any() or torch.isnan(source2_audio).any():
172 | continue
173 |
174 | if split == 'test' or split == 'val':
175 | if source1_audio.shape[0] < sample_rate * 10: # skip for short (<10s)
176 | continue
177 |
178 | os.makedirs(f'{dirname}/source', exist_ok=True)
179 | torchaudio.save(f'{dirname}/source/source1.{audio_format}', source1_audio[ir_clip_idx:].unsqueeze(0), sample_rate)
180 | torchaudio.save(f'{dirname}/source/source2.{audio_format}', source2_audio[ir_clip_idx:].unsqueeze(0), sample_rate)
181 |
182 | # Save IR
183 | os.makedirs(f'{dirname}/ir_receiver', exist_ok=True)
184 | ir1_list = load_ir_source_receiver(ir_dir, room, source_idx_list[0], receiver_idx_list, ir_length)
185 | ir2_list = load_ir_source_receiver(ir_dir, room, source_idx_list[1], receiver_idx_list, ir_length)
186 | ir3_list = load_ir_source_receiver(ir_dir, room, source_idx_list[2], receiver_idx_list, ir_length)
187 | save_audio_list(f'{dirname}/ir_receiver/ir1', ir1_list, sample_rate, audio_format)
188 | save_audio_list(f'{dirname}/ir_receiver/ir2', ir2_list, sample_rate, audio_format)
189 | save_audio_list(f'{dirname}/ir_receiver/ir3', ir3_list, sample_rate, audio_format)
190 |
191 | os.makedirs(f'{dirname}/ir_novel', exist_ok=True)
192 | ir1_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[0], [novel_receiver_idx], ir_length)[0]
193 | ir2_novel = load_ir_source_receiver(ir_dir, room, source_idx_list[1], [novel_receiver_idx], ir_length)[0]
194 | torchaudio.save(f'{dirname}/ir_novel/ir1_novel.{audio_format}', ir1_novel.unsqueeze(0), sample_rate)
195 | torchaudio.save(f'{dirname}/ir_novel/ir2_novel.{audio_format}', ir2_novel.unsqueeze(0), sample_rate)
196 |
197 | # Save reverb
198 | os.makedirs(f'{dirname}/reverb', exist_ok=True)
199 | reverb1_list = compute_reverb(source1_audio, ir1_list)
200 | reverb2_list = compute_reverb(source2_audio, ir2_list)
201 | save_audio_list(f'{dirname}/reverb/reverb1', reverb1_list, sample_rate, audio_format)
202 | save_audio_list(f'{dirname}/reverb/reverb2', reverb2_list, sample_rate, audio_format)
203 |
204 | # Compute receiver
205 | os.makedirs(f'{dirname}/receiver', exist_ok=True)
206 | receiver_list = [reverb1 + reverb2 for reverb1, reverb2 in zip(reverb1_list, reverb2_list)]
207 | save_audio_list(f'{dirname}/receiver/receiver', receiver_list, sample_rate, audio_format)
208 |
209 | # Compute novel
210 | os.makedirs(f'{dirname}/receiver_novel', exist_ok=True)
211 | reverb1_novel = compute_reverb(source1_audio, [ir1_novel])[0]
212 | reverb2_novel = compute_reverb(source2_audio, [ir2_novel])[0]
213 | receiver_novel = reverb1_novel + reverb2_novel
214 | torchaudio.save(f'{dirname}/receiver_novel/receiver_novel.{audio_format}', receiver_novel.unsqueeze(0), sample_rate)
215 |
216 | # # Render binaural IR and audio
217 | # if not os.path.exists(f'{dirname}/ir_novel_binaural'):
218 | # os.makedirs(f'{dirname}/ir_novel_binaural', exist_ok=True)
219 | # ir_novel_list, source_idx_pair_list, receiver_idx_pair_list = render_ir_parallel_room_idx(room, source_idx_list, [novel_receiver_idx], filename=None, use_default_material=False, channel_type='Binaural')
220 | # for idx_novel, ir_novel in enumerate(ir_novel_list[:2]): # first two sources
221 | # if ir_novel[0].shape[0] > ir_length:
222 |     #             ir_binaural = ir_novel[:, :ir_length]
223 | # else:
224 |     #             ir_binaural = F.pad(ir_novel, (0, ir_length - ir_novel.shape[1]))
225 | # torchaudio.save(f'{dirname}/ir_novel_binaural/ir{idx_novel+1}_novel_binaural.{audio_format}', ir_binaural, sample_rate)
226 |
227 | # # set image rendering
228 | # all_receiver_idx_list = receiver_idx_list.copy()
229 | # all_receiver_idx_list.append(novel_receiver_idx)
230 |
231 | # def run_function(function_name, *args): # use subprocess because of memory leak in soundspaces
232 | # subprocess.run(["python", "script/render_scene_script.py", function_name, *map(str, args)])
233 |
234 | # # Render novel-view RGBD image
235 | # os.makedirs(f'{dirname}/image_receiver', exist_ok=True)
236 | # source_class_list = [source1_class, source2_class]
237 | # run_function("receiver", f'{dirname}/image_receiver', room, source_idx_list[:2], source_class_list, all_receiver_idx_list)
238 |
239 | # # Save topdown view
240 | # run_function("scene", f'{dirname}/config.png', room, source_idx_list[:2], all_receiver_idx_list)
241 |
242 | # Save metadata
243 | metadata = {
244 | 'source1_idx': source_idx_list[0],
245 | 'source2_idx': source_idx_list[1],
246 | 'receiver_idx_list': receiver_idx_list,
247 | 'novel_receiver_idx': novel_receiver_idx,
248 | 'source1_class': source1_class,
249 | 'source2_class': source2_class,
250 | 'grid_points': grid_points,
251 | 'room': room,
252 | }
253 |
254 | torch.save(metadata, f'{dirname}/metadata.pt')
255 |
256 | # Query
257 | for query_idx in range(num_points):
258 | if (query_idx in receiver_idx_list):
259 | continue
260 |
261 | # Load IR
262 | ir_query_list = load_ir_source_receiver(f'data/temp/ir', room, query_idx, receiver_idx_list, ir_length)
263 |
264 |             # Save Wiener deconvolution
265 | os.makedirs(f'{dirname}/wiener', exist_ok=True)
266 | wiener_list = wiener_deconv_list(receiver_list, ir_query_list, snr)
267 |
268 | save_audio_list(f'{dirname}/wiener/wiener{query_idx}', wiener_list, sample_rate, audio_format)
269 |
270 | # # debug delay
271 | # delay_ir = torch.nonzero(ir_query_list[0], as_tuple=True)[0][0]
272 | # receiver_point = grid_points[receiver_idx_list[0]]
273 | # query_point = grid_points[query_idx]
274 |
275 | # c = 343
276 | # dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset
277 | # delay_dist = (int(dist / c * sample_rate))
278 | # print(delay_ir, delay_dist)
279 | # pass
280 |
--------------------------------------------------------------------------------
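
Note: `wiener_deconv_list` (imported from the nvas3d utils, not shown in this listing) is applied above to deconvolve each receiver mixture by the IR of a query position. A minimal single-channel sketch of the frequency-domain Wiener deconvolution it is assumed to perform (illustrative only; `wiener_deconv` is not a repository function, and `snr` is treated here as a linear power ratio):

import torch

def wiener_deconv(receiver: torch.Tensor, ir: torch.Tensor, snr: float) -> torch.Tensor:
    # Deconvolve one receiver signal by one room impulse response in the frequency domain.
    n = receiver.shape[-1] + ir.shape[-1] - 1        # linear-convolution length
    Y = torch.fft.rfft(receiver, n=n)
    H = torch.fft.rfft(ir, n=n)
    W = H.conj() / (H.abs() ** 2 + 1.0 / snr)        # Wiener filter; 1/snr regularizes near-zero bins
    x_hat = torch.fft.irfft(Y * W, n=n)
    return x_hat[..., :receiver.shape[-1]]           # crop back to the receiver length

The 1/snr term bounds the filter where the IR spectrum is close to zero, which keeps the deconvolved estimate stable.
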
/nvas3d/train/trainer.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import glob
8 | import time
9 | import logging
10 | import typing as T
11 | import matplotlib.pyplot as plt
12 |
13 | from tqdm import tqdm
14 | from collections import defaultdict
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torch.distributed as dist
19 | from torch.utils.tensorboard import SummaryWriter
20 |
21 | from nvas3d.utils.audio_utils import spec2wav, compute_metrics
22 | from nvas3d.utils.utils import source_class_map, get_key_from_value
23 | from nvas3d.utils.plot_utils import plot_debug
24 |
25 |
26 | class Trainer:
27 | def __init__(
28 | self,
29 | model: torch.nn.Module,
30 | data_loader,
31 | device: torch.device,
32 | save_dir: str,
33 | use_deconv: bool,
34 | config: T.Dict[str, T.Any],
35 | save_audio: bool = False
36 | ) -> None:
37 | """
38 | Initializes the Trainer class.
39 |
40 | Args:
41 | - model: Model to be trained.
42 | - data_loader: Data loader supplying training data.
43 | - device: Device (CPU or GPU) to be used for training.
44 | - save_dir: Directory to save the model.
45 | - use_deconv: Flag to use deconvolution.
46 | - config: Configuration with training parameters.
47 | """
48 |
49 | # Args
50 | self.model = model.to(device, dtype=torch.float)
51 | self.data_loader = data_loader
52 | self.device = device
53 | self.save_dir = save_dir
54 | self.use_deconv = use_deconv
55 | self.save_audio = save_audio
56 |
57 |         # Optimizer
58 | self.optimizer = torch.optim.Adam(
59 | filter(lambda p: p.requires_grad, self.model.parameters()),
60 | lr=config['lr'],
61 | weight_decay=config['weight_decay']
62 | )
63 |
64 | # Loss
65 | self.regressor_criterion = nn.MSELoss().to(device=self.device)
66 | self.bce_criterion = nn.BCEWithLogitsLoss().to(device=self.device)
67 |
68 | # Training configuration
69 | self.epochs = config['epochs']
70 | self.resume = config['resume']
71 | self.detect_loss_weight = config['detect_loss_weight']
72 | self.save_checkpoint_interval = config['save_checkpoint_interval']
73 |
74 | # Tensorboard
75 | tb_dir = f'{self.save_dir}/tb'
76 | os.makedirs(tb_dir, exist_ok=True)
77 | self.writer = SummaryWriter(log_dir=tb_dir)
78 |
79 | # Resume
80 | if self.resume:
81 | print(f'Resume from {config["checkpoint_dir"]}...')
82 | latest_checkpoint = self.find_latest_checkpoint(config['checkpoint_dir'])
83 | if latest_checkpoint is not None:
84 | start_epoch = self.load_checkpoint(latest_checkpoint) + 1
85 | else:
86 | start_epoch = 0
87 | else:
88 | start_epoch = 0
89 | self.start_epoch = start_epoch
90 |
91 | def train(self):
92 | scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, pow(0.1, 1 / self.epochs))
93 |
94 | since = time.time()
95 |
96 | for epoch in range(self.start_epoch, self.epochs + 1):
97 | logging.info('-' * 10)
98 | logging.info('Epoch {}/{}'.format(epoch, self.epochs))
99 |
100 | # Set DistributedSampler for each epoch if model is DDP
101 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and self.model.training:
102 | self.data_loader.train_sampler.set_epoch(epoch)
103 |
104 | # Initialize variables to store metrics
105 | running_loss = defaultdict(float)
106 | running_metrics = defaultdict(float)
107 | num_data_point = defaultdict(int)
108 | count_metrics = defaultdict(int)
109 |
110 | for train_mode in ['Train', 'Val']:
111 | # Skip validation if not saving checkpoint this epoch
112 | if train_mode == 'Val' and epoch % self.save_checkpoint_interval != 0:
113 | continue
114 |
115 | data_loader = (self.data_loader.get_train_data() if train_mode == 'Train'
116 | else self.data_loader.get_val_data())
117 |
118 | # Set mode
119 | if train_mode == 'Train':
120 | self.model.train()
121 | else:
122 | self.model.eval()
123 |
124 | with torch.set_grad_enabled(train_mode == 'Train'):
125 | pbar = tqdm(data_loader, desc=f"Epoch {epoch}")
126 | for data, room, id in pbar:
127 | # Preprocess
128 | for key, value in data.items():
129 | data[key] = value.to(device=self.device, dtype=torch.float)
130 |
131 | # Set mask
132 | positive_mask = (data['is_source'] > 0.5).reshape(-1)
133 | nonzero_mask = torch.any(data['source1_audio'], dim=1)
134 | nan_mask = torch.any(torch.isnan(data['source1_audio']), dim=1)
135 | valid_mask = nonzero_mask & ~nan_mask
136 | valid_positive_mask = positive_mask & valid_mask
137 | if nan_mask.any():
138 | nan_room = [x for x, m in zip(room, nan_mask) if m]
139 | nan_id = [x for x, m in zip(id, nan_mask) if m]
140 | print(f'{nan_room}, {nan_id}: invalid audio')
141 | data['input_stft'][nan_mask] = 0 # will ignore in the loss
142 |
143 | # Parse data
144 | source1_stft = data['source1_stft'] # [n_batch, 2, F, N]
145 | input_stft = data['input_stft'] # [n_batch, n_receivers*2, F, N]
146 |
147 | source1_stft_nondc = source1_stft[:, :, 1:, :]
148 |
149 | # Zero the parameter gradients
150 | self.optimizer.zero_grad()
151 |
152 | # Forward
153 | output = self.model(data)
154 | pred_stft = output['pred_stft'] # [n_batch, n_receivers*2, F, N]
155 |
156 | # Compute detection loss on valid mask
157 | detect_loss = self.bce_criterion(output['source_detection'][valid_mask], data['is_source'][valid_mask])
158 |
159 | # Compute magnitude loss on valid positive mask, if any
160 | mag_loss = (self.regressor_criterion(pred_stft[valid_positive_mask], source1_stft_nondc[valid_positive_mask])
161 | if valid_positive_mask.any() else torch.tensor(0.0, device=self.device))
162 |
163 | # Combine losses
164 | loss = mag_loss + detect_loss * self.detect_loss_weight
165 | assert not torch.isnan(loss), 'Loss is NaN'
166 |
167 | # Backward
168 | if train_mode == 'Train':
169 | loss.backward()
170 | self.optimizer.step()
171 |
172 | # Save metrics
173 | if epoch % self.save_checkpoint_interval == 0 and valid_positive_mask.any():
174 | save_idx = valid_positive_mask.nonzero(as_tuple=True)[0][0]
175 | room = room[save_idx]
176 | id_value = int(id[save_idx])
177 |
178 | # Prepare audio data
179 | full_pred_stft_save = input_stft[save_idx].clone().permute(1, 2, 0)[..., :2]
180 | pred_stft_save = pred_stft[save_idx].permute(1, 2, 0)
181 | full_pred_stft_save[1:, :, :] = pred_stft_save # keep dc from input
182 | pred_audio = spec2wav(full_pred_stft_save)
183 | source1_audio_save = data['source1_audio'][save_idx]
184 |
185 | # Save debug data
186 | if self.save_audio:
187 | self.save_audio_data(save_idx, pred_audio, source1_audio_save, data, train_mode, room, id_value, epoch)
188 |
189 | # Compute metrics for dry-sound estimation
190 | dry_psnr, dry_sdr = compute_metrics(pred_audio, source1_audio_save)
191 |
192 | # Compute metrics for detection
193 |                             predicted = (torch.sigmoid(output['source_detection']) > 0.5).int()  # model outputs logits
194 | accuracy_true = (predicted[data['is_source'] == 1] == 1).float().mean().item()
195 | accuracy_false = (predicted[data['is_source'] == 0] == 0).float().mean().item()
196 | detection_accuracy = (predicted == data['is_source']).float().mean().item()
197 |
198 | metrics_dict = {
199 | 'dry_psnr': dry_psnr,
200 | 'dry_sdr': dry_sdr,
201 | 'detection_accuracy': detection_accuracy,
202 | 'accuracy_true': accuracy_true,
203 | 'accuracy_false': accuracy_false
204 | }
205 |
206 | # Metric reduction for DDP
207 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
208 | for metric, value in metrics_dict.items():
209 | value_tensor = torch.tensor(value).to(self.device)
210 | dist.all_reduce(value_tensor, op=dist.ReduceOp.SUM)
211 | running_metrics[f'{train_mode}/{metric}'] += value_tensor.item()
212 | count_metrics[f'{train_mode}/{metric}'] += dist.get_world_size()
213 | else:
214 | for metric, value in metrics_dict.items():
215 | running_metrics[f'{train_mode}/{metric}'] += value
216 | count_metrics[f'{train_mode}/{metric}'] += 1
217 |
218 | # Loss reduction for DDP
219 | loss_dict = {
220 | 'total_loss': loss,
221 | 'mag_loss': mag_loss,
222 | 'detect_loss': detect_loss
223 | }
224 | B = input_stft.shape[0]
225 |
226 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
227 | for loss_type, loss_value in loss_dict.items():
228 |                                 loss_tensor = loss_value.detach().clone().to(self.device)
229 | dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
230 | normalized_loss = loss_tensor.item() / dist.get_world_size()
231 | running_loss[f'{train_mode}/{loss_type}'] += normalized_loss
232 | else:
233 | for loss_type, loss_value in loss_dict.items():
234 | current_loss = loss_dict[loss_type].item() * B
235 | running_loss[f'{train_mode}/{loss_type}'] += current_loss
236 | num_data_point[train_mode] += B
237 |
238 | # Log, save checkpoints, write tensorboard
239 | should_log = not isinstance(self.model, torch.nn.parallel.DistributedDataParallel) or (dist.get_rank() == 0)
240 | if should_log:
241 | # Log and save checkpoints
242 | self.log_and_save(running_loss, running_metrics, epoch)
243 |
244 | # Write tensorboard
245 | for key, value in running_loss.items():
246 | self.writer.add_scalar(key, value, epoch)
247 | for key, value in running_metrics.items():
248 | self.writer.add_scalar(key, value, epoch)
249 |
250 | scheduler.step()
251 |
252 | time_elapsed = time.time() - since
253 | logging.info(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
254 |
255 | print('Training finished.')
256 |
257 | def save_checkpoint(self, epoch):
258 | """
259 | Save checkpoint.
260 | """
261 |
262 | if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
263 | checkpoint = {
264 | 'epoch': epoch,
265 | 'model_state_dict': self.model.module.state_dict(),
266 | 'optimizer_state_dict': self.optimizer.state_dict()
267 | }
268 |
269 | else:
270 | checkpoint = {
271 | 'epoch': epoch,
272 | 'model_state_dict': self.model.state_dict(),
273 | 'optimizer_state_dict': self.optimizer.state_dict()
274 | }
275 |
276 | filename = f'checkpoints/checkpoint_{epoch}.pt'
277 | os.makedirs(os.path.join(self.save_dir, 'checkpoints'), exist_ok=True)
278 | torch.save(checkpoint, os.path.join(self.save_dir, filename))
279 |
280 | def load_checkpoint(self, checkpoint_path):
281 | """
282 | Load checkpoint.
283 | """
284 |
285 | checkpoint = torch.load(checkpoint_path)
286 |         (self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model).load_state_dict(checkpoint['model_state_dict'])
287 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
288 | return checkpoint['epoch']
289 |
290 | def find_latest_checkpoint(self, checkpoint_dir):
291 | """
292 | Find latest checkpoint.
293 | """
294 |
295 | checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '**', 'checkpoint_*.pt'), recursive=True)
296 | if checkpoint_files:
297 | latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
298 | return latest_checkpoint
299 | else:
300 | return None
301 |
302 | def log_and_save(self, running_loss, running_metrics, epoch):
303 | """
304 | Log values to Tensorboard and save model checkpoint.
305 | """
306 |
307 | # Logging to Tensorboard
308 | for key, value in running_loss.items():
309 | self.writer.add_scalar(key, value, epoch)
310 | for key, value in running_metrics.items():
311 | self.writer.add_scalar(key, value, epoch)
312 |
313 | # Save checkpoint
314 | if epoch % self.save_checkpoint_interval == 0:
315 | self.save_checkpoint(epoch)
316 |
317 | def save_audio_data(self, save_idx, pred_audio, source1_audio_save, data, train_mode, room, id_value, epoch, sample_rate=48000, save_normalized=True):
318 | """
319 | Function to save audio data.
320 | """
321 |
322 | # Set directory
323 | source1_class_save = get_key_from_value(source_class_map, data['source1_class'][save_idx])
324 | source2_class_save = get_key_from_value(source_class_map, data['source2_class'][save_idx])
325 | debug_dir = f'{self.save_dir}/{train_mode}_audio/{source1_class_save}-{source2_class_save}'
326 |
327 | # Prepare audio
328 | input_audio_mean = data['input_audio'][save_idx].mean(dim=0)
329 | audio_name_list = ['pred', 'input_mean', 'pred_gt', 'input_0', 'receiver_0']
330 | audio_data_list = [pred_audio, input_audio_mean, source1_audio_save, data['input_audio'][save_idx][0], data['receiver_audio'][save_idx][0]]
331 |
332 | # Save
333 | os.makedirs(f'{debug_dir}/{room}/{id_value}', exist_ok=True)
334 | for audio_name, audio_data in zip(audio_name_list, audio_data_list):
335 | filename_val = f'{debug_dir}/{room}/{id_value}/{epoch}_{audio_name}'
336 | plot_debug(filename_val, audio_data.reshape(1, -1).detach().cpu(),
337 | sample_rate=sample_rate, save_normalized=save_normalized, reference=source1_audio_save)
338 |
339 | # Save RGB image if available
340 | if 'rgb' in data.keys():
341 | filename_val = f'{debug_dir}/{room}/{id_value}/{epoch}_rgb.png'
342 | plt.imsave(filename_val, data['rgb'][save_idx].permute(1, 2, 0).detach().cpu().numpy())
343 |
--------------------------------------------------------------------------------
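
Note: a minimal sketch of driving the Trainer above. The config keys mirror those read in __init__; the model and data-loader constructors are placeholders rather than repository APIs (the loader is assumed to expose get_train_data(), get_val_data(), and, under DDP, train_sampler):

import torch
from nvas3d.train.trainer import Trainer

config = {
    'lr': 1e-4,
    'weight_decay': 0.0,
    'epochs': 100,
    'resume': False,
    'checkpoint_dir': 'runs/demo/checkpoints',   # only read when resume is True
    'detect_loss_weight': 1.0,
    'save_checkpoint_interval': 10,
}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# model = build_model(...)              # hypothetical model factory
# data_loader = build_data_loader(...)  # hypothetical wrapper with get_train_data()/get_val_data()
# trainer = Trainer(model, data_loader, device, save_dir='runs/demo',
#                   use_deconv=True, config=config, save_audio=False)
# trainer.train()
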
/nvas3d/baseline/baseline_dsp.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2023 Apple Inc. All Rights Reserved.
4 | #
5 |
6 | import os
7 | import json
8 | import random
9 | import argparse
10 |
11 | import torch
12 | import torchaudio
13 | from tqdm import tqdm
14 | from scipy.signal import fftconvolve
15 | from torchmetrics.classification import BinaryAUROC
16 |
17 | from nvas3d.utils.audio_utils import cs_list, compute_metrics_si
18 |
19 |
20 | def normalize(audio):
21 | """
22 |     Normalize the signal so that its maximum absolute value is 1.
23 | """
24 | return audio / audio.abs().max()
25 |
26 |
27 | def compute_delayed(audio_list, delay_list):
28 | """
29 |     Remove the propagation delay from each audio signal.
30 | """
31 |
32 | delayed_audio_list = []
33 | for audio, delay in zip(audio_list, delay_list):
34 | delayed_audio = torch.zeros_like(audio)
35 |         delayed_audio[:, :audio.shape[-1] - delay] = audio[:, delay:]  # safe when delay == 0
36 |
37 | delayed_audio_list.append(delayed_audio)
38 | return delayed_audio_list
39 |
40 |
41 | class BaselineDSP:
42 | def __init__(self, split, num_receivers, audio_format, dataset_name):
43 | self.num_receivers = num_receivers
44 | self.split = split
45 | self.audio_format = audio_format
46 | self.dataset_name = dataset_name
47 | self.dataset_dir = f'data/{self.dataset_name}/{split}'
48 | self.room_id_pairs = self._get_room_id_pairs()
49 | # random.shuffle(self.room_id_pairs)
50 |
51 | def _get_room_id_pairs(self):
52 | pairs = []
53 | for room in os.listdir(self.dataset_dir):
54 | room_path = os.path.join(self.dataset_dir, room)
55 | if os.path.isdir(room_path): # Ensure room_path is a directory
56 | for id in os.listdir(room_path):
57 | id_path = os.path.join(room_path, id)
58 | if os.path.isdir(id_path) and id != 'image': # Ensure id_path is a directory and its name isn't 'image'
59 | pairs.append((room, id))
60 | return pairs
61 |
62 | def load_audio_list(self, audio_list_name):
63 | audio_list = []
64 | for m in range(self.num_receivers):
65 | audio, _ = torchaudio.load(f'{audio_list_name}_{m+1}.{self.audio_format}')
66 | audio_list.append(audio)
67 | return audio_list
68 |
69 | def separation(self, num_eval=None):
70 | """
71 | Baseline DSP for separation using delay-and-sum and deconv-and-sum
72 | """
73 |
74 | # metrics_names = ['psnr', 'sdr', 'sipsnr', 'sisdr']
75 | metrics_names = ['psnr', 'sdr']
76 | methods = ['dry1_deconv', 'dry2_deconv', 'reverb1_deconv', 'reverb2_deconv',
77 | 'dry1_delay', 'dry2_delay', 'reverb1_delay', 'reverb2_delay']
78 | metrics = {method: {metric: [] for metric in metrics_names} for method in methods}
79 |
80 | loop_length = num_eval if num_eval is not None else len(self.room_id_pairs)
81 | for room_id_pair in tqdm(self.room_id_pairs[:loop_length]):
82 | room, id = room_id_pair
83 | data_dir = f'{self.dataset_dir}/{room}/{id}'
84 | # Set query: positive or negative
85 |
86 | metadata = torch.load(f'{data_dir}/metadata.pt')
87 | grid_points = metadata['grid_points']
88 | num_points = grid_points.shape[0]
89 | source1_idx = metadata['source1_idx']
90 | source2_idx = metadata['source2_idx']
91 | receiver_idx_list = metadata['receiver_idx_list']
92 |
93 | source1_audio, _ = torchaudio.load(f'{data_dir}/source/source1.{self.audio_format}')
94 | source1_audio = source1_audio[0]
95 | source2_audio, _ = torchaudio.load(f'{data_dir}/source/source2.{self.audio_format}')
96 | source2_audio = source2_audio[0]
97 | reverb1_audio = self.load_audio_list(f'{data_dir}/reverb/reverb1')
98 | reverb1_audio = reverb1_audio[0][0]
99 | reverb2_audio = self.load_audio_list(f'{data_dir}/reverb/reverb2')
100 | reverb2_audio = reverb2_audio[0][0]
101 | receiver_audio = self.load_audio_list(f'{data_dir}/receiver/receiver')
102 |
103 | ir1 = self.load_audio_list(f'{data_dir}/ir_receiver/ir1')
104 | ir2 = self.load_audio_list(f'{data_dir}/ir_receiver/ir2')
105 |
106 | for query_idx in range(num_points):
107 | if (query_idx not in receiver_idx_list):
108 | deconv_audio = self.load_audio_list(f'{data_dir}/wiener/wiener{query_idx}')
109 |
110 | deconv_and_sum = torch.stack(deconv_audio, dim=0).reshape(self.num_receivers, -1).mean(dim=0)
111 |
112 | # compute delay from distance
113 | c = 343
114 | sample_rate = 48000
115 | query_point = grid_points[query_idx]
116 | delay_list = []
117 | for receiver_idx in receiver_idx_list:
118 | receiver_point = grid_points[receiver_idx]
119 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset
120 | delay_list.append(int(dist / c * sample_rate))
121 |
122 | delayed_audio = compute_delayed(receiver_audio, delay_list)
123 | delay_and_sum_0 = torch.stack(delayed_audio, dim=0).reshape(self.num_receivers, -1).mean(dim=0)
124 |
125 | # # apply delay again to align with reverb1 for fair evaluation
126 | # delay = delay_list[0]
127 | # delay_and_sum = torch.zeros_like(delay_and_sum_0)
128 | # delay_and_sum[delay:] = delay_and_sum_0[:-delay]
129 |
130 | if query_idx == source1_idx:
131 | dry1 = deconv_and_sum
132 | dry1_delay = delay_and_sum_0
133 |
134 | receiver_point = grid_points[receiver_idx_list[0]]
135 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset
136 | delay = int(dist / c * sample_rate)
137 | reverb1_delay = torch.zeros_like(delay_and_sum_0)
138 | reverb1_delay[delay:] = delay_and_sum_0[:-delay]
139 |
140 | # reverb1
141 | elif query_idx == source2_idx:
142 | dry2 = deconv_and_sum
143 | dry2_delay = delay_and_sum_0
144 |
145 | receiver_point = grid_points[receiver_idx_list[0]]
146 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset
147 | delay = int(dist / c * sample_rate)
148 | reverb2_delay = torch.zeros_like(delay_and_sum_0)
149 | reverb2_delay[delay:] = delay_and_sum_0[:-delay]
150 |
151 |             m = 0
152 |             ir1_m = ir1[m][0]
153 |             pred_reverb1 = torch.from_numpy(fftconvolve(dry1, ir1_m)[:len(reverb1_audio)])
154 |
155 |             ir2_m = ir2[m][0]
156 |             pred_reverb2 = torch.from_numpy(fftconvolve(dry2, ir2_m)[:len(reverb2_audio)])
157 |
158 |             def compute_and_save_metrics(processed_audio, reference_audio, metric_lists, filename, save_audio=True):
159 |                 psnr_score, sdr_score, _, _ = compute_metrics_si(processed_audio, reference_audio)
160 |
161 |                 metric_lists['psnr'].append(psnr_score)
162 |                 metric_lists['sdr'].append(sdr_score)
163 |
164 |                 # Saving the processed audio
165 |                 if save_audio:
166 |                     os.makedirs(os.path.dirname(filename), exist_ok=True) # Ensuring the directory exists
167 |                     torchaudio.save(filename, normalize(processed_audio.unsqueeze(0)), sample_rate=48000)
168 |
169 |             current_dir = f'test_results/{self.dataset_name}/baseline/{self.split}/{room}/{id}'
170 |             os.makedirs(current_dir, exist_ok=True)
171 |
172 |             # dry (deconv-and-sum)
173 |             filename = f'{current_dir}/dry1.wav'
174 |             compute_and_save_metrics(dry1, source1_audio, metrics['dry1_deconv'], filename)
175 |             filename = f'{current_dir}/dry2.wav'
176 |             compute_and_save_metrics(dry2, source2_audio, metrics['dry2_deconv'], filename)
177 |
178 |             # dry (delay-and-sum)
179 |             filename = f'{current_dir}/dry1_delay.wav'
180 |             compute_and_save_metrics(dry1_delay, source1_audio, metrics['dry1_delay'], filename)
181 |             filename = f'{current_dir}/dry2_delay.wav'
182 |             compute_and_save_metrics(dry2_delay, source2_audio, metrics['dry2_delay'], filename)
183 |
184 |             # reverb (deconv-and-sum)
185 |             filename = f'{current_dir}/reverb1.wav'
186 |             compute_and_save_metrics(pred_reverb1, reverb1_audio, metrics['reverb1_deconv'], filename)
187 |             filename = f'{current_dir}/reverb2.wav'
188 |             compute_and_save_metrics(pred_reverb2, reverb2_audio, metrics['reverb2_deconv'], filename)
189 |
190 |             # reverb (delay-and-sum)
191 |             filename = f'{current_dir}/reverb1_delay.wav'
192 |             compute_and_save_metrics(reverb1_delay, reverb1_audio, metrics['reverb1_delay'], filename)
193 |             filename = f'{current_dir}/reverb2_delay.wav'
194 |             compute_and_save_metrics(reverb2_delay, reverb2_audio, metrics['reverb2_delay'], filename)
195 |
196 | # Print metrics
197 | for method, method_metrics in metrics.items():
198 | print(f'\n[{method} audio]')
199 | for metric_name, metric_values in method_metrics.items():
200 | tensor_values = torch.tensor(metric_values)
201 |                 is_valid = torch.isfinite(tensor_values)  # exclude both 'inf' and 'nan' values
202 |                 valid_values = tensor_values[is_valid]
203 |                 print(f'{method} {metric_name}: {valid_values.mean()}')
204 |
205 | # Build the dictionary with averages
206 | averaged_metrics = {}
207 | for method, method_metrics in metrics.items():
208 | averaged_metrics[method] = {}
209 | for metric_name, metric_values in method_metrics.items():
210 | tensor_values = torch.tensor(metric_values)
211 |                 is_valid = torch.isfinite(tensor_values)  # exclude both 'inf' and 'nan' values
212 |                 valid_values = tensor_values[is_valid]
213 |                 # (per-metric means already printed above)
214 | averaged_value = float(valid_values.mean()) if valid_values.nelement() > 0 else float('nan') # Avoid potential empty tensor
215 | averaged_metrics[method][metric_name] = averaged_value
216 |
217 | # Save to a JSON file
218 | filename_json = f'test_results/{self.dataset_name}/baseline/{self.split}/metrics.json'
219 | with open(filename_json, 'w') as json_file:
220 | json.dump(averaged_metrics, json_file, indent=4)
221 |
222 | def detection(self, num_eval=None):
223 | """
224 |         Baseline DSP for detection using the cosine similarity of delayed or deconvolved signals
225 | """
226 |
227 | cs_deconv_list = []
228 | cs_delay_list = []
229 | cs_receiver_list = []
230 | positive_list = []
231 |
232 | loop_length = num_eval if num_eval is not None else len(self.room_id_pairs)
233 | for room_id_pair in tqdm(self.room_id_pairs[:loop_length]):
234 | room, id = room_id_pair
235 | data_dir = f'{self.dataset_dir}/{room}/{id}'
236 |
237 | metadata = torch.load(f'{data_dir}/metadata.pt')
238 | grid_points = metadata['grid_points']
239 | num_points = grid_points.shape[0]
240 | source1_idx = metadata['source1_idx']
241 | source2_idx = metadata['source2_idx']
242 | source_idx_list = [source1_idx, source2_idx]
243 | receiver_idx_list = metadata['receiver_idx_list']
244 | receiver_audio = self.load_audio_list(f'{data_dir}/receiver/receiver')
245 |
246 | for query_idx in range(num_points):
247 | if (query_idx not in receiver_idx_list):
248 | deconv_audio = self.load_audio_list(f'{data_dir}/wiener/wiener{query_idx}')
249 |
250 | # compute delay from distance
251 | c = 343
252 | sample_rate = 48000
253 | query_point = grid_points[query_idx]
254 | delay_list = []
255 | for receiver_idx in receiver_idx_list:
256 | receiver_point = grid_points[receiver_idx]
257 | dist = (query_point - receiver_point).norm() - 0.1 # soundspaces offset
258 | delay_list.append(int(dist / c * sample_rate))
259 |
260 | delayed_audio = compute_delayed(receiver_audio, delay_list)
261 |
262 | cs = cs_list(deconv_audio)
263 | cs_delay = cs_list(delayed_audio)
264 | cs_receiver = cs_list(receiver_audio)
265 | if query_idx in source_idx_list:
266 | positive_list.append(True)
267 | else:
268 | positive_list.append(False)
269 |
270 | cs_deconv_list.append(cs)
271 | cs_delay_list.append(cs_delay)
272 | cs_receiver_list.append(cs_receiver)
273 |
274 | cs_deconv_list = torch.tensor(cs_deconv_list)
275 | cs_delay_list = torch.tensor(cs_delay_list)
276 | cs_receiver_list = torch.tensor(cs_receiver_list)
277 | positive_list = torch.tensor(positive_list)
278 | print(f'deconv_p: {cs_deconv_list[positive_list].mean()}')
279 | print(f'deconv_n: {cs_deconv_list[~positive_list].mean()}')
280 | print(f'delay_p: {cs_delay_list[positive_list].mean()}')
281 | print(f'delay_n: {cs_delay_list[~positive_list].mean()}')
282 | print(f'receiver_p: {cs_receiver_list[positive_list].mean()}')
283 | print(f'receiver_n: {cs_receiver_list[~positive_list].mean()}')
284 |
285 | # AUC
286 | metric = BinaryAUROC(thresholds=None)
287 | auc_deconv = metric(cs_deconv_list, positive_list)
288 | auc_delay = metric(cs_delay_list, positive_list)
289 | auc_receiver = metric(cs_receiver_list, positive_list)
290 | print(f'AUC (receiver, delay, deconv): {auc_receiver}, {auc_delay}, {auc_deconv}')
291 |
292 | # # search best threshold
293 | # cs_deconv_p_tensor = torch.tensor(cs_deconv_p_list)
294 | # cs_deconv_n_tensor = torch.tensor(cs_deconv_n_list)
295 | # th_list = torch.arange(-0.1, 0.3, 0.01)
296 | # accuracy_list = []
297 | # for th in th_list:
298 | # accuracy_p = (cs_deconv_p_tensor > th).sum()
299 | # accuracy_n = (cs_deconv_n_tensor <= th).sum()
300 | # accuracy_ = accuracy_p + accuracy_n
301 | # accuracy = float(accuracy_) / (len(cs_deconv_p_tensor) + len(cs_deconv_n_tensor))
302 | # accuracy_list.append(accuracy)
303 |
304 | # max_accuracy = torch.tensor(accuracy_list).max()
305 | # max_idx = accuracy_list.index(max_accuracy)
306 | # max_th = th_list[max_idx]
307 |
308 | # print(f'max_accuracy: {max_accuracy}')
309 | # print(f'max_th: {max_th}')
310 | # # print(accuracy_list)
311 |
312 | # accuracy_p = (cs_deconv_p_tensor > max_th).float().mean()
313 | # accuracy_n = (cs_deconv_n_tensor <= max_th).float().mean()
314 | # print(accuracy_p)
315 | # print(accuracy_n)
316 |
317 | # max_th
318 |
319 |
320 | if __name__ == "__main__":
321 | random.seed(42)
322 |
323 | # Config parsing
324 | parser = argparse.ArgumentParser()
325 | parser.add_argument('--task', type=str, default='separation') # separation, detection
326 | parser.add_argument('--split', type=str, default='demo')
327 | parser.add_argument('--num_receivers', type=int, default=4)
328 | parser.add_argument('--audio_format', type=str, default='flac')
329 | parser.add_argument('--dataset_name', type=str, default='nvas3d_demo')
330 | parser.add_argument('--num_eval', type=int, default=None) # None to test all data
331 | args = parser.parse_args()
332 |
333 | # Run baseline
334 | baseline = BaselineDSP(args.split, args.num_receivers, args.audio_format, args.dataset_name)
335 | if args.task == 'separation':
336 | baseline.separation(args.num_eval)
337 | elif args.task == 'detection':
338 | baseline.detection(args.num_eval)
339 |
--------------------------------------------------------------------------------
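
Note: the delay-and-sum branch above time-aligns each receiver channel by its geometric propagation delay before averaging. A compact, self-contained sketch of that step (c = 343 m/s, the 48 kHz sample rate, and the 0.1 m SoundSpaces offset are taken from the code above; delay_and_sum is an illustrative name):

import torch

def delay_and_sum(receivers: torch.Tensor, query_point: torch.Tensor,
                  receiver_points: torch.Tensor, sample_rate: int = 48000,
                  c: float = 343.0, offset: float = 0.1) -> torch.Tensor:
    # receivers: [M, T]; query_point: [D]; receiver_points: [M, D]
    aligned = torch.zeros_like(receivers)
    for m in range(receivers.shape[0]):
        dist = (query_point - receiver_points[m]).norm() - offset
        delay = int(dist / c * sample_rate)  # e.g. 3.0 m -> int(3.0 / 343 * 48000) = 419 samples
        if delay > 0:
            aligned[m, :-delay] = receivers[m, delay:]
        else:
            aligned[m] = receivers[m]
    return aligned.mean(dim=0)  # average of the time-aligned channels

If the query position coincides with a true source, the aligned channels add coherently and the average approximates the reverberant source signal; elsewhere they add incoherently, which is also what the cosine-similarity detection above exploits.
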