├── .gitignore ├── LICENSE ├── README.md ├── data ├── Classes.yaml ├── README.md ├── compile │ ├── consolidate_datasets.py │ ├── disco_noise_label_collector.py │ ├── esc50_label_collector.py │ ├── fsd50k_label_collector.py │ ├── musdb18_label_collector.py │ └── ontology.py ├── multi_ch_simulator.py ├── ontology.json └── utils.py ├── experiments └── dc_waveformer │ └── config.json ├── requirements.txt └── src ├── __init__.py ├── helpers ├── __init__.py ├── eval_utils.py └── utils.py └── training ├── __init__.py ├── datasets ├── curated_binaural.py ├── curated_binaural_augrir.py └── semaudio_binaural_base.py ├── dcc_tf.py ├── dcc_tf_binaural.py ├── eval.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | 43 | # Editor 44 | *.swp 45 | 46 | # Checkpoints 47 | *.ckpt 48 | *.pth 49 | *.pt 50 | 51 | # Logs 52 | *.log 53 | 54 | # Data 55 | data/*/ 56 | 57 | # Model files 58 | *.onnx 59 | *.tflite 60 | *.mlmodel 61 | *_pb/ 62 | *required_operators.config 63 | *required_operators.with_runtime_opt.config 64 | 65 | .DS_Store 66 | *.slurm 67 | 68 | */tmp/ 69 | tmp.ipynb 70 | experiments/*/results* 71 | experiments/*/*.csv 72 | .ipynb_checkpoints/ 73 | vscode.* 74 | tensorboard/ 75 | 76 | # Xcode 77 | *.xcworkspace/ 78 | Pods/ 79 | Podfile.lock 80 | 81 | # ONNX 82 | *.onnx 83 | *.ort 84 | 85 | # STEAM AUDIO API 86 | motion/libs/** 87 | motion/obj 88 | motion/moving_sources 89 | motion/*.wav 90 | motion/audio 91 | 92 | __pycache__/ 93 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Bandhav Veluri 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Semantic Hearing 2 | 3 | [![Gradio demo](https://img.shields.io/badge/DL.ACM-abs-green)](https://dl.acm.org/doi/10.1145/3586183.3606779) [![Gradio demo](https://img.shields.io/badge/DL.ACM-pdf-green)](https://dl.acm.org/doi/pdf/10.1145/3586183.3606779) 4 | 5 | This repository provides code for the binaural target sound extraction model proposed in the paper, _Semantic Hearing: Programming Acoustic Scenes with Binaural Hearables_, presented at UIST'23. This model helps us create systems that let you control what you want to hear in the environment, in real-time, using noise-cancelling earbuds & headphones. 6 | 7 | https://github.com/vb000/SemanticHearing/assets/16723254/f1b33d8c-179a-4d50-92aa-6a99dde696d0 8 | 9 | ## Conda environment setup 10 | 11 | conda create --name semhear python=3.8 12 | conda activate semhear 13 | pip install -r requirements.txt 14 | 15 | ## Training 16 | 17 | # Data 18 | wget -P data https://semantichearing.cs.washington.edu/BinauralCuratedDataset.tar 19 | 20 | # Train 21 | python -m src.training.train experiments/dc_waveformer --use_cuda 22 | 23 | ## Evaluation 24 | 25 | # Checkpoint 26 | wget -P experiments/dc_waveformer https://semantichearing.cs.washington.edu/39.pt 27 | 28 | # Eval 29 | python -m src.training.eval experiments/dc_waveformer --use_cuda 30 | 31 | ### BibTeX 32 | 33 | ``` 34 | @inproceedings{10.1145/3586183.3606779, 35 | author = {Veluri, Bandhav and Itani, Malek and Chan, Justin and Yoshioka, Takuya and Gollakota, Shyamnath}, 36 | title = {Semantic Hearing: Programming Acoustic Scenes with Binaural Hearables}, 37 | year = {2023}, 38 | isbn = {9798400701320}, 39 | publisher = {Association for Computing Machinery}, 40 | address = {New York, NY, USA}, 41 | url = {https://doi.org/10.1145/3586183.3606779}, 42 | doi = {10.1145/3586183.3606779}, 43 | abstract = {Imagine being able to listen to the birds chirping in a park without hearing the chatter from other hikers, or being able to block out traffic noise on a busy street while still being able to hear emergency sirens and car honks. We introduce semantic hearing, a novel capability for hearable devices that enables them to, in real-time, focus on, or ignore, specific sounds from real-world environments, while also preserving the spatial cues. To achieve this, we make two technical contributions: 1) we present the first neural network that can achieve binaural target sound extraction in the presence of interfering sounds and background noise, and 2) we design a training methodology that allows our system to generalize to real-world use. Results show that our system can operate with 20 sound classes and that our transformer-based network has a runtime of 6.56 ms on a connected smartphone. In-the-wild evaluation with participants in previously unseen indoor and outdoor scenarios shows that our proof-of-concept system can extract the target sounds and generalize to preserve the spatial cues in its binaural output. 
Project page with code: https://semantichearing.cs.washington.edu}, 44 | booktitle = {Proceedings of the 36th Annual ACM Symposium on User Interface Software and Technology}, 45 | articleno = {89}, 46 | numpages = {15}, 47 | keywords = {Spatial computing, binaural target sound extraction, attention, earable computing, causal neural networks, noise cancellation}, 48 | location = {San Francisco, CA, USA}, 49 | series = {UIST '23} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /data/Classes.yaml: -------------------------------------------------------------------------------- 1 | # [NATURE] 2 | ocean: 3 | - Waves, surf 4 | 5 | thunderstorm: 6 | - Thunderstorm 7 | 8 | # [ANIMAL] 9 | dog: 10 | - Bark 11 | 12 | door_knock: 13 | - Knock 14 | 15 | cat: 16 | - Meow 17 | 18 | birds_chirping: 19 | - Chirp, tweet 20 | 21 | cricket: 22 | - Cricket 23 | 24 | cock_a_doodle_doo: 25 | - Crowing, cock-a-doodle-doo 26 | 27 | # [HUMAN] 28 | baby_cry: 29 | - Baby cry, infant cry 30 | 31 | speech: 32 | - Speech 33 | 34 | singing: 35 | - Singing 36 | 37 | # [MUSIC] 38 | music: 39 | - Melody 40 | 41 | # [Sounds of Things] 42 | gunshot: 43 | - Gunshot, gunfire 44 | 45 | glass_breaking: 46 | - Shatter 47 | 48 | computer_typing: 49 | - Computer keyboard 50 | 51 | toilet_flush: 52 | - Toilet flush 53 | 54 | hammer: 55 | - Hammer 56 | 57 | siren: 58 | - Siren 59 | 60 | alarm_clock: 61 | - Alarm clock 62 | 63 | car_horn: 64 | - Vehicle horn, car horn, honking 65 | 66 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Sourcing audio from multiple datasets 2 | As described in the paper, we source our audio files from four datasets: FSD50K, ESC-50, MUSDB18 and noise files from the DISCO dataset. 3 | 4 | Since our training and evaluation procedures make use of the Python Scaper library, we need to organize the data files in a way that the Scaper format. This happens in a multi-step procedure. 5 | 6 | ## Generating splits and mapping classes to AudioSet 7 | 8 | First, for each dataset we source from, we need to create three CSV files, one for each of the data splits (train, test and val). The CSV files contain some useful information about each audio file, ex: the sound classes and AudioSet ID of each file. These CSV files are also used to map classes that are outside of the AudioSet ontology into the closest label in the ontology. 9 | 10 | These CSV files are generated by a dedicated python script for each dataset. They can be found in ```data/compile/{DATASET_NAME}_label_collector.py```. For example, the script to generate the CSV files for FSD50K can be found in ```data/compile/fsd50k_label_collector.py```. 11 | 12 | ## Consolidating (uniting) splits from multiple dataset into a scaper format 13 | 14 | Once we have a description for each file in each dataset, given by the generated CSV files in the previous step, we now need to organize these audio files in the appropriate format to create JAMS files using Scaper. We to create two distinct directories, one for background sounds (not in our target sound classes), and one for foreground sounds (user can choose these classes). The foreground sounds are enumerated in ```data/Classes.yaml```. Here, we define the name of the class (as defined in our project), and what AudioSet classes belong to this class. 
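For example, this is the entry that maps our ```car_horn``` class to its AudioSet label (copied from ```data/Classes.yaml``` above; each entry is a list of one or more AudioSet label names):

```yaml
car_horn:
  - Vehicle horn, car horn, honking
```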
You can move audio classes to and from the foreground set by changing this definition file. 15 | 16 | We consider the AudioSet ontology as a directed graph, where nodes are the audio classes and the edges are directed from an audio class to all its specialized sound classes. For example, if we consider the node "Aircraft", then there are edges from this node to each of "Aircraft engine", "Helicopter" and "Fixed-wing aircraft, airplane". 17 | 18 | When picking sound examples for a particular sound class, we choose all audio files that contain only nodes reachable from this class. These audio files will be placed in the directory of foreground sounds for this audio class. 19 | 20 | Roughly speaking, if an audio file is not part of our foreground set, then we consider it as a background sound. Specifically, audio files that are not included in the ```data/Classes.yaml``` and do not contain a sound class that is reachable from any foreground class is placed in the directory of background sounds. 21 | 22 | Audio files are placed into the Scaper directories by creating symbolic links, so as not to increase dataset size. In addition to organizing the audio files in this way, the dataset consolidation step also creates a CSV containing the start and end times of the sounds in each file, which is useful when creating audio mixtures. 23 | 24 | This step is done in ```data/compile/consolidate_datasets.py```. Note that this file is set up to call all the label collection scripts before writing the symlinks. Once you finish this step, you can create JAMS files using Scaper. 25 | -------------------------------------------------------------------------------- /data/compile/consolidate_datasets.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import subprocess 3 | import argparse 4 | from tqdm import tqdm 5 | import yaml, json 6 | import typing 7 | import pandas as pd 8 | import random 9 | from ontology import Ontology 10 | 11 | my_ontology = Ontology('data/ontology.json') 12 | 13 | # Merges several datasets into a single Scaper Format 14 | # Replaces Soundscape Gen 15 | # Given datasets distributed as: 16 | # - base_dir/ 17 | # -- Dataset1/ 18 | # --- train.csv 19 | # --- test.csv 20 | # --- val.csv 21 | # --- audio/*.wav 22 | # -- Dataset2/ 23 | # --- ... 24 | # -- Dataset3/ 25 | # --- ... 26 | 27 | # Generates: 28 | # - base_dir/ 29 | # -- ScaperFormat/ 30 | # --- train/ 31 | # ---- class1/*.wav 32 | # ---- class2/*.wav 33 | # --- test/ 34 | # ---- ... 35 | # --- val/ 36 | # ---- ... 37 | 38 | # [1] Go over each dataset 39 | # [2] Go over dataset type 40 | # [3] Go over label in dataset 41 | # [4] Check which classname this item label belongs to 42 | # [5] Create symlink between wavefile under the relevant classname directory 43 | 44 | # `classname` refers to the label name in *our* subset, `label` is reserved for dataset name 45 | 46 | all_samples = [] 47 | 48 | def meta_csv_to_dict(meta): 49 | """ 50 | Convert a ['fname', 'labels', 'mids'] headed dataframe to 51 | a dict with labels as keys, and list of file names as values. 
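    Illustrative example (hypothetical values; the loop below groups rows by
    the 'id' column written by the label collectors):

        id        fname
        /m/xxxx   dataset_a/dog1.wav
        /m/xxxx   dataset_a/dog2.wav

        => {'/m/xxxx': ['dataset_a/dog1.wav', 'dataset_a/dog2.wav']}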
52 | """ 53 | samples_dict = {} 54 | samples = pd.read_csv(meta) 55 | ids = list(samples['id'].unique()) 56 | for _id in ids: 57 | samples_dict[_id] = list(samples.loc[samples['id'] == _id]['fname']) 58 | 59 | return samples_dict 60 | 61 | def is_valid_background(label_id, foreground_ids): 62 | """ 63 | A label is valid background label iff it is not in an exclude subtree, 64 | and neither an ancestor nor a child of any foreground label. 65 | Since AudioSet hierarchies is more of a DAG than a tree, the concept of 66 | ancestor is a bit unclear. Instead, we just check if either node is reachable 67 | from the other. 68 | (we can precompute this at the beginning of each node but it's not much of a bottleneck) 69 | """ 70 | 71 | excluded_subtrees = [my_ontology.MUSIC, 72 | my_ontology.get_id_from_name('Human voice')] 73 | for subtree in excluded_subtrees: 74 | if my_ontology.is_reachable(subtree, label_id): 75 | return False 76 | 77 | for fg_id in foreground_ids: 78 | if my_ontology.is_reachable(label_id, fg_id) or\ 79 | my_ontology.is_reachable(fg_id, label_id): 80 | return False 81 | 82 | return True 83 | 84 | 85 | from scipy.io.wavfile import read 86 | import librosa 87 | import numpy as np 88 | 89 | 90 | def pcm2float(sig, dtype='float32'): 91 | """Convert PCM signal to floating point with a range from -1 to 1. 92 | Use dtype='float32' for single precision. 93 | Parameters 94 | ---------- 95 | sig : array_like 96 | Input array, must have integral type. 97 | dtype : data type, optional 98 | Desired (floating point) data type. 99 | Returns 100 | ------- 101 | numpy.ndarray 102 | Normalized floating point data. 103 | See Also 104 | -------- 105 | float2pcm, dtype 106 | """ 107 | sig = np.asarray(sig) 108 | if sig.dtype.kind not in 'iu': 109 | raise TypeError("'sig' must be an array of integers") 110 | dtype = np.dtype(dtype) 111 | if dtype.kind != 'f': 112 | raise TypeError("'dtype' must be a floating point type") 113 | 114 | i = np.iinfo(sig.dtype) 115 | abs_max = 2 ** (i.bits - 1) 116 | offset = i.min + abs_max 117 | return (sig.astype(dtype) - offset) / abs_max 118 | 119 | from scipy.ndimage import uniform_filter1d 120 | def trim_silence(s): 121 | # data, sr = librosa.load(s) 122 | sr, data = read(s) 123 | if len(data.shape) > 1: 124 | data = np.sum(data, axis=1) 125 | 126 | if data.dtype != np.float32: 127 | data = pcm2float(data) 128 | start, end = librosa.effects.trim(data, top_db=40)[1] 129 | 130 | data = data[start:end] 131 | 132 | window_size = int(round(1 * 44100)) 133 | avg_power = uniform_filter1d(data**2, size=window_size, mode='constant') 134 | threshold = 0.1 * avg_power.max() 135 | 136 | mask = avg_power < threshold 137 | if mask.any(): 138 | first_silence = np.argmax(mask) 139 | else: 140 | first_silence = end 141 | 142 | return start, first_silence, end 143 | 144 | def write_scaper_source(dataset_name: str, 145 | dataset_type: str, 146 | base_dir: str, 147 | fg_dest_dir: str, 148 | bg_dest_dir: str, 149 | id2classname: typing.Dict, 150 | dry_run: bool) -> None: 151 | dataset_path = os.path.join(base_dir, dataset_name) 152 | fg_out_dir = os.path.join(fg_dest_dir, dataset_type) 153 | bg_out_dir = os.path.join(bg_dest_dir, dataset_type) 154 | 155 | file_list_csv = os.path.join(dataset_path, f"{dataset_type}.csv") 156 | dataset = pd.read_csv(file_list_csv) 157 | 158 | print(f"Consolidating dataset {dataset_name}/{dataset_type}...") 159 | 160 | for index, sample_data in tqdm(dataset.iterrows(), total=dataset.shape[0]): 161 | # Check if we want to include this class 162 | if 
sample_data["id"] in id2classname: 163 | out_dir = fg_out_dir 164 | classname = id2classname[sample_data["id"]] 165 | elif is_valid_background(sample_data["id"], list(id2classname.keys())): 166 | out_dir = bg_out_dir 167 | classname = my_ontology.get_label(sample_data["id"]) 168 | else: 169 | continue 170 | 171 | out_path = os.path.join(out_dir, classname) 172 | os.makedirs(out_path, exist_ok=True) 173 | 174 | s = os.path.join('..', '..', '..', dataset_name, sample_data['fname']) 175 | 176 | fname = os.path.join(dataset_name.lower() + '_' + os.path.basename(sample_data['fname'])) 177 | d = os.path.join(out_path, fname) 178 | 179 | if dry_run: 180 | print("Would symlink %s to %s" % (s, d)) 181 | continue 182 | else: 183 | start_sample, first_silence, end_sample = trim_silence(os.path.join(dataset_path, sample_data['fname'])) 184 | assert start_sample < end_sample 185 | all_samples.append({'fname':d, 'start_sample':int(start_sample), 'end_sample':int(end_sample), 'first_silence':first_silence}) 186 | os.symlink(s, d) 187 | 188 | def read_yaml(yaml_path): 189 | with open(yaml_path, "r") as stream: 190 | yaml_data = yaml.safe_load(stream) 191 | 192 | return yaml_data 193 | 194 | def read_json(json_path): 195 | with open(json_path, "r") as stream: 196 | json_data = json.load(stream) 197 | 198 | return json_data 199 | 200 | def preprocess_ontology(ontology): 201 | res = {} 202 | for sound in ontology: 203 | res[sound['id']] = sound 204 | res[sound['name']] = sound 205 | return res 206 | 207 | def get_subtree(_id, ontology): 208 | subtree = [_id] 209 | for child_id in ontology[_id]['child_ids']: 210 | subtree.extend(get_subtree(child_id, ontology)) 211 | 212 | return subtree 213 | 214 | if __name__ == '__main__': 215 | random.seed(0) 216 | 217 | parser = argparse.ArgumentParser() 218 | parser.add_argument( 219 | '--datasets_dir', type=str, default='data/BinauralCuratedDataset', 220 | help="Path to directory containing all datasets.") 221 | parser.add_argument( 222 | '--class_definitions', type=str, default='data/Classes.yaml', 223 | help="Path to class susbet selection.") 224 | parser.add_argument( 225 | '--ontology', type=str, default='data/ontology.json', 226 | help="Path to ontology definition.") 227 | parser.add_argument( 228 | '--fg_output_dir', type=str, default='data/BinauralCuratedDataset/scaper_fmt', 229 | help="Path to directory to write scaper formatted data.") 230 | parser.add_argument( 231 | '--bg_output_dir', type=str, default='data/BinauralCuratedDataset/bg_scaper_fmt', 232 | help="Path to directory to write scaper formatted data.") 233 | parser.add_argument( 234 | '--dry_run', action='store_true', help="Dry run. 
Do not write any files.") 235 | args = parser.parse_args() 236 | 237 | datasets = ['FSD50K', 'ESC-50', 'musdb18', 'disco_noise'] 238 | dataset_types = ['train', 'val', 'test'] 239 | 240 | for dset in datasets: 241 | print(f'Collecting dataset {dset}...') 242 | collector_name = 'data/compile/' + dset.replace('-', '').lower() + '_label_collector.py' 243 | subprocess.run(['python', collector_name]) 244 | 245 | # Construct dict that maps an AudioSet class ID to a classname used in our dataset 246 | # Also helpful when classnames across datasets are different but same ID is used 247 | id2classname = {} 248 | class_data = read_yaml(args.class_definitions) 249 | ontology = read_json(args.ontology) 250 | ontology = preprocess_ontology(ontology) 251 | 252 | for class_name, class_list in class_data.items(): 253 | for element in class_list: 254 | class_id = ontology[element]['id'] 255 | 256 | # Map entire subtree to classname 257 | class_ids = get_subtree(class_id, ontology) 258 | 259 | for cid in class_ids: 260 | id2classname[cid] = class_name 261 | 262 | print(id2classname) 263 | 264 | for dataset_name in datasets: 265 | for dataset_type in dataset_types: 266 | write_scaper_source(dataset_name=dataset_name, 267 | dataset_type=dataset_type, 268 | base_dir=args.datasets_dir, 269 | fg_dest_dir=args.fg_output_dir, 270 | bg_dest_dir=args.bg_output_dir, 271 | id2classname=id2classname, 272 | dry_run=args.dry_run) 273 | 274 | df = pd.DataFrame.from_records(all_samples) 275 | df.to_csv(os.path.join(args.datasets_dir, 'start_times.csv')) 276 | -------------------------------------------------------------------------------- /data/compile/disco_noise_label_collector.py: -------------------------------------------------------------------------------- 1 | import os, sys, glob 2 | import argparse 3 | import random 4 | import pandas as pd 5 | import numpy as np 6 | from data.compile.ontology import Ontology 7 | from sklearn.model_selection import train_test_split 8 | 9 | 10 | dictionary = { 11 | "baby":"Baby cry, infant cry", 12 | "blender":"Blender", 13 | "dishwasher":None, 14 | "electric_shaver_toothbrush":"Toothbrush", 15 | "fan":"Mechanical fan", 16 | "frying":"Frying (food)", 17 | "printer":"Printer", 18 | "vacuum_cleaner":"Vacuum cleaner", 19 | "washing_machine":None, 20 | "water":"Water", 21 | } 22 | 23 | 24 | class DiscoNoiseLabelCollector(): 25 | def __init__(self, dataset_dir, ontology_path) -> None: 26 | self.ontology = Ontology(ontology_path) 27 | self.dataset_dir = dataset_dir 28 | 29 | self.files = {} 30 | 31 | for label in os.listdir(os.path.join(dataset_dir, 'train')): 32 | label_dir = os.path.join(dataset_dir, 'train', label) 33 | for x in glob.glob(os.path.join(label_dir, '*')): 34 | if label not in self.files: 35 | self.files[label] = [] 36 | 37 | self.files[label].append(x) 38 | 39 | for label in os.listdir(os.path.join(dataset_dir, 'test')): 40 | label_dir = os.path.join(dataset_dir, 'test', label) 41 | for x in glob.glob(os.path.join(label_dir, '*')): 42 | if label not in self.files: 43 | self.files[label] = [] 44 | 45 | self.files[label].append(x) 46 | 47 | def write_samples(self): 48 | train = [] 49 | test = [] 50 | val = [] 51 | 52 | for label in self.files: 53 | audio_set_label = dictionary[label] 54 | 55 | # Skip labels with no AudioSet equivalent 56 | if audio_set_label is None: 57 | continue 58 | 59 | _id = self.ontology.get_id_from_name(audio_set_label) 60 | 61 | train_files, test_files = train_test_split(self.files[label], test_size=0.33) 62 | 63 | random.shuffle(train_files) 64 | 
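            # Hold out roughly 10% of the remaining (shuffled) training files as a
            # validation split; the 33% test split was taken above by train_test_split.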
val_split = int(round(0.1 * len(train_files))) 65 | 66 | val_files = train_files[:val_split] 67 | train_files = train_files[val_split:] 68 | 69 | train.extend([dict(id=_id, label=audio_set_label, 70 | fname=os.path.relpath(fname, self.dataset_dir) ) for fname in train_files]) 71 | test.extend([dict(id=_id, label=audio_set_label, 72 | fname=os.path.relpath(fname, self.dataset_dir) ) for fname in test_files]) 73 | val.extend([dict(id=_id, label=audio_set_label, 74 | fname=os.path.relpath(fname, self.dataset_dir)) for fname in val_files]) 75 | 76 | train = pd.DataFrame.from_records(train) 77 | val = pd.DataFrame.from_records(val) 78 | test = pd.DataFrame.from_records(test) 79 | 80 | train.to_csv(os.path.join(self.dataset_dir, 'train.csv'), index=False) 81 | val.to_csv(os.path.join(self.dataset_dir, 'val.csv'), index=False) 82 | test.to_csv(os.path.join(self.dataset_dir, 'test.csv'), index=False) 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--dataset_dir', type=str, default='data/BinauralCuratedDataset/disco_noises') 87 | 88 | args = parser.parse_args() 89 | 90 | random.seed(0) 91 | np.random.seed(0) 92 | 93 | label_collector = DiscoNoiseLabelCollector(args.dataset_dir, 'data/ontology.json') 94 | label_collector.write_samples() 95 | -------------------------------------------------------------------------------- /data/compile/esc50_label_collector.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | 4 | import pandas as pd 5 | import numpy as np 6 | from data.compile.ontology import Ontology 7 | 8 | 9 | dictionary = { 10 | "dog":"Bark", 11 | "rooster":"Crowing, cock-a-doodle-doo", 12 | "pig":"Pig", 13 | "cow":"Cattle, bovinae", 14 | "frog":"Frog", 15 | "cat":"Meow", 16 | "hen":"Chicken, rooster", 17 | "insects":"Insect", 18 | "sheep":"Sheep", 19 | "crow":"Crow", 20 | 21 | "rain":"Rain", 22 | "sea_waves":"Waves, surf", 23 | "crackling_fire":"Crackle", 24 | "crickets":"Cricket", 25 | "chirping_birds":"Chirp, tweet", 26 | "water_drops":"Drip", 27 | "wind":"Wind", 28 | "pouring_water":"Pour", 29 | "toilet_flush":"Toilet flush", 30 | "thunderstorm":"Thunderstorm", 31 | 32 | "crying_baby":"Baby cry, infant cry", 33 | "sneezing":"Sneeze", 34 | "clapping":"Clapping", 35 | "breathing":"Breathing", 36 | "coughing":"Cough", 37 | "footsteps":"Walk, footsteps", 38 | "laughing":"Laughter", 39 | "brushing_teeth":"Toothbrush", 40 | "snoring":"Snoring", 41 | "drinking_sipping":None, 42 | 43 | "door_wood_knock":"Knock", 44 | "mouse_click":None, 45 | "keyboard_typing":"Computer keyboard", 46 | "door_wood_creaks":"Creak", 47 | "can_opening":None, 48 | "washing_machine":None, 49 | "vacuum_cleaner":"Vacuum cleaner", 50 | "clock_alarm":"Alarm clock", 51 | "clock_tick":"Tick-tock", 52 | "glass_breaking":"Shatter", 53 | 54 | "helicopter":"Helicopter", 55 | "chainsaw":"Chainsaw", 56 | "siren":"Siren", 57 | "car_horn":"Vehicle horn, car horn, honking", 58 | "engine":"Engine", 59 | "train":"Train", 60 | "church_bells":"Church bell", 61 | "airplane":"Fixed-wing aircraft, airplane", 62 | "fireworks":"Fireworks", 63 | "hand_saw":"Sawing" 64 | } 65 | 66 | def write_csv(dataset, csv_name): 67 | dataset = dataset[dataset.apply(lambda x: dictionary[x['category']] is not None)] 68 | 69 | 70 | class ESC50LabelCollector(): 71 | def __init__(self, dataset_dir, ontology_path) -> None: 72 | self.ontology = Ontology(ontology_path) 73 | 74 | # Load metadata 75 | meta = 
pd.read_csv(os.path.join(dataset_dir, 'meta/esc50.csv')) 76 | 77 | # # Create a audio path column 78 | # meta['audio_path'] = meta['filename'].apply( 79 | # lambda x: os.path.join('..', '..', '..', 'ESC-50-master', 'audio', x)) 80 | 81 | # Use first 3 folds for training, 4th for validation, 5th for testing 82 | self.train_meta = meta[meta['fold'] <= 3] 83 | self.val_meta = meta[meta['fold'] == 4] 84 | self.test_meta = meta[meta['fold'] == 5] 85 | 86 | self.dataset_dir = dataset_dir 87 | 88 | def filter_samples(self, dataset: pd.DataFrame): 89 | dataset['label'] = dataset['category'].apply(lambda x: dictionary[x]) 90 | dataset = dataset.dropna().copy() 91 | 92 | dataset['fname'] = dataset['filename'].apply(lambda x: os.path.join('audio', x)) 93 | dataset['id'] = dataset['label'].apply(lambda x: self.ontology.get_id_from_name(x)) 94 | 95 | return dataset 96 | 97 | def write_samples(self): 98 | columns = ['fname', 'label', 'id'] 99 | 100 | train = self.filter_samples(self.train_meta) 101 | train = train[columns] 102 | 103 | val = self.filter_samples(self.val_meta) 104 | val = val[columns] 105 | 106 | test = self.filter_samples(self.test_meta) 107 | test = test[columns] 108 | 109 | train.to_csv(os.path.join(self.dataset_dir, 'train.csv'), index=False) 110 | val.to_csv(os.path.join(self.dataset_dir, 'val.csv'), index=False) 111 | test.to_csv(os.path.join(self.dataset_dir, 'test.csv'), index=False) 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument('--dataset_dir', type=str, default='data/BinauralCuratedDataset/ESC-50') 116 | 117 | args = parser.parse_args() 118 | 119 | label_collector = ESC50LabelCollector(args.dataset_dir, 'data/ontology.json') 120 | label_collector.write_samples() 121 | -------------------------------------------------------------------------------- /data/compile/fsd50k_label_collector.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import urllib.request 5 | 6 | from sklearn.model_selection import train_test_split 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | from data.compile.ontology import Ontology 10 | import numpy as np 11 | import random 12 | 13 | 14 | # This assumes all samples are leaves in the vocabulary. In general, this is not true. 15 | 16 | class FSD50KLabelCollector(): 17 | """ 18 | FSD50K Curator dataset. 19 | 20 | Args: 21 | root_dir (str): Root directory of the FSD50K dataset. 22 | onotology (json file): AudioSet ontology file. 
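    Typical usage (mirrors the __main__ block at the bottom of this file):

        collector = FSD50KLabelCollector(
            'data/BinauralCuratedDataset/FSD50K', 'data/ontology.json')
        collector.write_samples('data/BinauralCuratedDataset/FSD50K')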
23 | """ 24 | def __init__(self, root_dir, ontology_path): 25 | self.root_dir = root_dir 26 | 27 | self.ontology = Ontology(ontology_path) 28 | 29 | # Label ratings 30 | with open(os.path.join(root_dir, 'FSD50K.metadata', 31 | 'pp_pnp_ratings_FSD50K.json')) as ratings_file: 32 | self.pp_pnp_ratings = json.load(ratings_file) 33 | 34 | def is_pp_sample(self, fname): 35 | # assert fname in self.pp_pnp_ratings.keys(), "fname not in ratings" 36 | # if id not in self.pp_pnp_ratings[fname].keys(): 37 | # return False 38 | # assert id in self.pp_pnp_ratings[fname].keys(), \ 39 | # "id not in ratings: id=%s fname=%s" % (id, fname) 40 | 41 | label_ratings = self.pp_pnp_ratings[fname] 42 | 43 | for node_id in label_ratings.keys(): 44 | label_rating = label_ratings[node_id] 45 | counts = {1.0: 0, 0.5: 0, 0: 0, -1: 0} 46 | for r in label_rating: 47 | # if r not in counts.keys(): 48 | # counts[r] = 0 49 | counts[r] += 1 50 | 51 | if counts[0.0] > 0 or counts[-1] > 0 or counts[1.0] < 2: 52 | return False 53 | 54 | return True 55 | 56 | def get_sample_split(self): 57 | dev_samples = pd.read_csv(os.path.join(self.root_dir, 'FSD50K.metadata', 'collection', 'collection_dev.csv')) 58 | eval_samples = pd.read_csv(os.path.join(self.root_dir, 'FSD50K.metadata', 'collection', 'collection_eval.csv')) 59 | 60 | train = dev_samples 61 | test = eval_samples 62 | 63 | return train, test 64 | 65 | def _curate_samples(self, samples, exclude=[]): 66 | # Format data 67 | samples = samples.dropna().copy() 68 | samples['fname'] = samples['fname'].apply(lambda x: str(x)) 69 | samples['mids'] = samples['mids'].apply(lambda x: x.split(',')) 70 | samples['labels'] = samples['labels'].apply(lambda x: x.split(',')) 71 | 72 | # Filter out samples without multiple true-positive ratings 73 | samples['pp_sample'] = samples.apply( 74 | lambda x: self.is_pp_sample(x['fname']), axis=1) 75 | samples = samples[samples['pp_sample'] == True] 76 | 77 | # Remove samples with source ambiguous sounds 78 | #samples = samples[samples['mids'].apply( 79 | # lambda x: not any([self.ontology.is_source_ambiguous(n) for n in x] ))] 80 | 81 | # Remove samples with multiple labels 82 | samples = samples[samples['mids'].apply( 83 | lambda x: len(x) == 1)] 84 | 85 | # Get sample id from mids 86 | samples['id'] = samples['mids'].apply( 87 | lambda x: x[0]) 88 | 89 | # Convert ID to AudioSet label 90 | samples['label'] = samples['id'].apply( 91 | lambda x: self.ontology.get_label(x)) 92 | 93 | return samples 94 | 95 | def _write_samples(self, dset, output, exclude=[]): 96 | if dset == 'train' or dset == 'val': 97 | src_dir = 'FSD50K.dev_audio' 98 | elif dset == 'test': 99 | src_dir = 'FSD50K.eval_audio' 100 | samples = self._curate_samples(dset=dset, exclude=exclude) 101 | samples = samples[['label', 'fname', 'id']] 102 | samples.columns = ['label', 'fname', 'id'] 103 | samples['fname'] = samples['fname'].apply( 104 | lambda x: os.path.join(src_dir, '%s.wav' % x)) 105 | samples.to_csv(output) 106 | 107 | def _plot_stats(self, dset='train', exclude=[], figsize=(20, 5)): 108 | samples = self._curate_samples(dset=dset, exclude=exclude) 109 | samples['root_label'] = samples['leaf_id'].apply( 110 | lambda x: self.ontology[self.ontology.get_ancestor_ids(x)[0]]['name']) 111 | print("Sample count: %d" % len(samples)) 112 | print("Leaf node count: %d" % len(samples['leaf_label'].unique())) 113 | plt.figure(figsize=figsize) 114 | samples['leaf_label'].value_counts().plot(kind='bar') 115 | plt.grid(True) 116 | plt.show() 117 | plt.figure(figsize=figsize) 118 | 
samples['root_label'].value_counts().plot(kind='bar') 119 | plt.grid(True) 120 | plt.show() 121 | 122 | def curate_samples(self): 123 | train, test = self.get_sample_split() 124 | 125 | train_samples_curated = self._curate_samples(train) 126 | 127 | train_samples = pd.DataFrame() 128 | val_samples = pd.DataFrame() 129 | 130 | train_labels = sorted(list(set(train_samples_curated['label']))) 131 | for label in train_labels: 132 | samples = train_samples_curated[train_samples_curated['label'] == label] 133 | 134 | if len(samples) == 1: 135 | continue 136 | 137 | train, val = train_test_split(samples, test_size=0.1) 138 | 139 | 140 | val_samples = pd.concat([val_samples, val]) 141 | train_samples = pd.concat([train_samples, train]) 142 | 143 | test_samples = self._curate_samples(test) 144 | 145 | train_src_dir = 'FSD50K.dev_audio' 146 | train_samples = train_samples[['label', 'fname', 'id']] 147 | train_samples.columns = ['label', 'fname', 'id'] 148 | train_samples['fname'] = train_samples['fname'].apply( 149 | lambda x: os.path.join(train_src_dir, '%s.wav' % x)) 150 | 151 | val_src_dir = 'FSD50K.dev_audio' 152 | val_samples = val_samples[['label', 'fname', 'id']] 153 | val_samples.columns = ['label', 'fname', 'id'] 154 | val_samples['fname'] = val_samples['fname'].apply( 155 | lambda x: os.path.join(val_src_dir, '%s.wav' % x)) 156 | 157 | test_src_dir = 'FSD50K.eval_audio' 158 | test_samples = test_samples[['label', 'fname', 'id']] 159 | test_samples.columns = ['label', 'fname', 'id'] 160 | test_samples['fname'] = test_samples['fname'].apply( 161 | lambda x: os.path.join(test_src_dir, '%s.wav' % x)) 162 | 163 | train_samples = train_samples[ 164 | train_samples['label'].map(train_samples['label'].value_counts()) >= 0] 165 | 166 | # List common labels across train, val and test 167 | common_labels = list( 168 | set(train_samples['label'].unique()) & \ 169 | set(val_samples['label'].unique()) & \ 170 | set(test_samples['label'].unique())) 171 | 172 | # Filter out samples with labels that are not common across 173 | # train, val and test. 
174 | train_samples = train_samples[train_samples['label'].isin(common_labels)] 175 | val_samples = val_samples[val_samples['label'].isin(common_labels)] 176 | test_samples = test_samples[test_samples['label'].isin(common_labels)] 177 | 178 | return train_samples, val_samples, test_samples 179 | 180 | def write_samples(self, output_dir): 181 | train_samples, val_samples, test_samples = self.curate_samples() 182 | train_samples.to_csv(os.path.join(output_dir, 'train.csv'), index=False) 183 | val_samples.to_csv(os.path.join(output_dir, 'val.csv'), index=False) 184 | test_samples.to_csv(os.path.join(output_dir, 'test.csv'), index=False) 185 | 186 | def plot_stats(self, figsize=(20, 5)): 187 | train_samples, val_samples, test_samples = self.curate_samples() 188 | print(list(train_samples['label'].unique())) 189 | print("Train sample count: %d" % len(train_samples)) 190 | print("Val sample count: %d" % len(val_samples)) 191 | print("Test sample count: %d" % len(test_samples)) 192 | print("Train leaf node count: %d" % len(train_samples['label'].unique())) 193 | plt.figure(figsize=figsize) 194 | train_samples['label'].value_counts().plot(kind='bar') 195 | plt.grid(True) 196 | plt.show() 197 | print("Val leaf node count: %d" % len(val_samples['label'].unique())) 198 | plt.figure(figsize=figsize) 199 | val_samples['label'].value_counts().plot(kind='bar') 200 | plt.grid(True) 201 | plt.show() 202 | print("Test leaf node count: %d" % len(test_samples['label'].unique())) 203 | plt.figure(figsize=figsize) 204 | test_samples['label'].value_counts().plot(kind='bar') 205 | plt.grid(True) 206 | plt.show() 207 | 208 | if __name__ == '__main__': 209 | parser = argparse.ArgumentParser() 210 | parser.add_argument( 211 | '--dataset_dir', type=str, default='data/BinauralCuratedDataset/FSD50K', 212 | help="Root directory for the FSD50K dataset") 213 | args = parser.parse_args() 214 | 215 | random.seed(0) 216 | np.random.seed(0) 217 | 218 | # assert not os.path.exists(args.output_dir), \ 219 | # "Ouput dir %s already exists" % args.output_dir 220 | # os.makedirs(args.output_dir, exist_ok=True) 221 | 222 | fsd50k_curator = FSD50KLabelCollector(args.dataset_dir, 'data/ontology.json') 223 | fsd50k_curator.write_samples(args.dataset_dir) 224 | -------------------------------------------------------------------------------- /data/compile/musdb18_label_collector.py: -------------------------------------------------------------------------------- 1 | import os, glob 2 | import argparse 3 | import sys 4 | 5 | import pandas as pd 6 | import numpy as np 7 | import random 8 | import ffmpegio 9 | import tqdm 10 | 11 | import torchaudio, librosa 12 | from scipy.io.wavfile import write as wavwrite 13 | 14 | from data.compile.ontology import Ontology 15 | 16 | 17 | def read_audio_file_torch(file_path): 18 | waveform, sample_rate = torchaudio.load(file_path) 19 | return waveform 20 | 21 | def read_audio_file(file_path, sr): 22 | """ 23 | Reads audio file to system memory. 24 | """ 25 | return librosa.core.load(file_path, mono=False, sr=sr)[0] 26 | 27 | 28 | def write_audio_file(file_path, data, sr): 29 | """ 30 | Writes audio file to system memory. 
31 | @param file_path: Path of the file to write to 32 | @param data: Audio signal to write (n_channels x n_samples) 33 | @param sr: Sampling rate 34 | """ 35 | wavwrite(file_path, sr, data) 36 | 37 | def convert_videos(video_paths, audio_dir, segment_duration_s): 38 | os.makedirs(audio_dir, exist_ok=True) 39 | 40 | instrumental_dir = os.path.join(audio_dir, 'instrumental') 41 | os.makedirs(instrumental_dir, exist_ok=True) 42 | 43 | vocals_dir = os.path.join(audio_dir, 'vocals') 44 | os.makedirs(vocals_dir, exist_ok=True) 45 | 46 | for i in tqdm.tqdm(range(len(video_paths))): 47 | path = video_paths[i] 48 | 49 | print(path) 50 | 51 | song_name = os.path.basename(path) 52 | audio_streams = ffmpegio.probe.audio_streams_basic(path) 53 | duration_samples = audio_streams[0]['duration'].numerator 54 | sr = audio_streams[0]['sample_rate'] 55 | segment_duration_samples = int(round(sr * segment_duration_s)) 56 | 57 | # Remaining audio must be at least 1/2 chunk size 58 | num_chunks = 1 + (duration_samples - segment_duration_samples // 2 - 1) // segment_duration_samples 59 | 60 | for chunk_id in tqdm.tqdm(range(num_chunks)): 61 | start_time = chunk_id * segment_duration_s 62 | _, mixture = ffmpegio.audio.read(path, ss=start_time, t=segment_duration_s, ac=1) 63 | _, vocals = ffmpegio.audio.read(path, ss=start_time, t=segment_duration_s, map=[['0','4']], ac=1) 64 | 65 | instrumental = mixture - vocals 66 | 67 | # Save audio files only if they are not completely silent (i.e. no vocals this chunk) 68 | if (np.abs(vocals) > 5e-3).any(): 69 | vocals_path = os.path.join(vocals_dir, f'{song_name}_v_{chunk_id}.wav') 70 | write_audio_file(vocals_path, vocals, sr) 71 | 72 | if (np.abs(instrumental) > 5e-3).any(): 73 | instrumental_path = os.path.join(instrumental_dir, f'{song_name}_i_{chunk_id}.wav') 74 | write_audio_file(instrumental_path, instrumental, sr) 75 | 76 | class MUSDB18LabelCollector(): 77 | def __init__(self, ontology_path) -> None: 78 | self.ontology = Ontology(ontology_path) 79 | 80 | 81 | def write_csv(self, dataset_dir, dataset_type): 82 | samples = [] 83 | 84 | preproc_dir = os.path.join(dataset_dir, 'audio', dataset_type) 85 | 86 | instrumental_dir = os.path.join(preproc_dir, 'instrumental') 87 | vocals_dir = os.path.join(preproc_dir, 'vocals') 88 | 89 | for sample_path in glob.glob(os.path.join(vocals_dir, '*.wav')): 90 | rel_path = os.path.relpath(sample_path, dataset_dir) 91 | label = 'Singing' 92 | sample = dict(label=label, 93 | fname=rel_path, 94 | id=self.ontology.get_id_from_name(label)) 95 | samples.append(sample) 96 | 97 | for sample_path in glob.glob(os.path.join(instrumental_dir, '*.wav')): 98 | rel_path = os.path.relpath(sample_path, dataset_dir) 99 | label = 'Melody' 100 | sample = dict(label=label, 101 | fname=rel_path, 102 | id=self.ontology.get_id_from_name(label)) 103 | samples.append(sample) 104 | 105 | df = pd.DataFrame.from_records(samples) 106 | output_csv = os.path.join(dataset_dir, dataset_type + '.csv') 107 | df.to_csv(output_csv) 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser() 111 | parser.add_argument('--dataset_dir', type=str, default='data/BinauralCuratedDataset/musdb18') 112 | parser.add_argument('--segment_duration_s', type=str, default=15) 113 | args = parser.parse_args() 114 | 115 | random.seed(0) 116 | 117 | assert os.path.exists(args.dataset_dir), f"Path {args.dataset_dir} to dataset is invalid (not found)" 118 | 119 | audio_dir = os.path.join(args.dataset_dir, 'audio') 120 | 121 | if not os.path.exists(audio_dir): 122 | 
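        # Preprocessing: convert_videos() above chops each MUSDB18 stem file into
        # segment_duration_s chunks, separates the vocal stem from the instrumental
        # mix, and writes each non-silent chunk out as a wav file.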
print("[INFO] DATASET HAS NOT BEEN PREPROCESSED - PREPROCESSING... (THIS MAY TAKE SOME TIME)") 123 | os.makedirs(audio_dir, exist_ok=True) 124 | 125 | test_video_list = sorted(list(glob.glob(os.path.join(args.dataset_dir, 'test', '*')))) 126 | 127 | # Split train into train & val sets 128 | train_video_list = sorted(list(glob.glob(os.path.join(args.dataset_dir, 'train', '*')))) 129 | 130 | random.shuffle(train_video_list) 131 | val_split = int(round(0.1 * len(train_video_list))) 132 | val_video_list = train_video_list[:val_split] 133 | train_video_list = train_video_list[val_split:] 134 | 135 | convert_videos(train_video_list, os.path.join(audio_dir, 'train'), args.segment_duration_s) 136 | convert_videos(test_video_list, os.path.join(audio_dir, 'test'), args.segment_duration_s) 137 | convert_videos(val_video_list, os.path.join(audio_dir, 'val'), args.segment_duration_s) 138 | 139 | collector = MUSDB18LabelCollector('data/ontology.json') 140 | collector.write_csv(args.dataset_dir, 'train') 141 | collector.write_csv(args.dataset_dir, 'test') 142 | collector.write_csv(args.dataset_dir, 'val') 143 | -------------------------------------------------------------------------------- /data/compile/ontology.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | class Ontology(object): 5 | ANIMAL = '/m/0jbk' 6 | SOUNDS_OF_THINGS = '/t/dd00041' 7 | HUMAN_SOUNDS = '/m/0dgw9r' 8 | NATURAL_SOUNDS = "/m/059j3w" 9 | SOURCE_AMBIGUOUS_SOUNDS = "/t/dd00098" 10 | CHANNEL_ENVIRONMENT_BACKGROUND = "/t/dd00123" 11 | MUSIC = "/m/04rlf" 12 | ROOT = "__root__" 13 | 14 | def __init__(self, path) -> None: 15 | with open(path, 'rb') as f: 16 | ontology_list = json.load(f) 17 | 18 | root_node = {} 19 | root_node['child_ids'] = [self.SOURCE_AMBIGUOUS_SOUNDS, 20 | self.ANIMAL, 21 | self.SOUNDS_OF_THINGS, 22 | self.HUMAN_SOUNDS, 23 | self.NATURAL_SOUNDS, 24 | self.CHANNEL_ENVIRONMENT_BACKGROUND, 25 | self.MUSIC] 26 | root_node['id'] = self.ROOT 27 | root_node['name'] = self.ROOT 28 | ontology_list.append(root_node) 29 | 30 | self.ontology = {item['id']: item for item in ontology_list} 31 | 32 | self._dfs() 33 | self.mark_source_ambiguous_sounds() 34 | 35 | def _dfs(self, node_id=None): 36 | if node_id is None: 37 | node_id = self.ROOT 38 | self.ontology[node_id]['depth'] = 0 39 | self.ontology[node_id]['parent_id'] = None 40 | else: 41 | parent_node = self.ontology[node_id]['parent_id'] 42 | self.ontology[node_id]['depth'] = self.ontology[parent_node]['depth'] + 1 43 | 44 | self.ontology[node_id]['source_ambiguous'] = 0 45 | 46 | for child_id in self.ontology[node_id]['child_ids']: 47 | self.ontology[child_id]['parent_id'] = node_id 48 | self._dfs(node_id=child_id) 49 | 50 | def mark_source_ambiguous_sounds(self, node_id=None): 51 | if node_id is None: 52 | node_id = self.SOURCE_AMBIGUOUS_SOUNDS 53 | 54 | self.ontology[node_id]['source_ambiguous'] = 1 55 | 56 | for child_id in self.ontology[node_id]['child_ids']: 57 | self.mark_source_ambiguous_sounds(child_id) 58 | 59 | def is_source_ambiguous(self, node_id): 60 | # print("NODE", node_id) 61 | return self.ontology[node_id]['source_ambiguous'] 62 | 63 | def get_label(self, node_id): 64 | return self.ontology[node_id]['name'] 65 | 66 | def get_id_from_name(self, name): 67 | for _id in self.ontology: 68 | if self.ontology[_id]['name'] == name: 69 | return _id 70 | 71 | assert 0, f"Could not find AudioSet class with name \'{name}\'" 72 | 73 | def unsmear(self, args): 74 | x = sorted(args, key = lambda x: 
-self.ontology[x]['depth']) 75 | 76 | unsmeared = [] 77 | removed = [] 78 | for i in range(len(args)): 79 | if i in removed: 80 | continue 81 | node_id = args[i] 82 | unsmeared.append(node_id) 83 | while self.ontology[node_id]['parent_id'] is not None: 84 | node_id = self.ontology[node_id]['parent_id'] 85 | try: 86 | idx = args.index(node_id) 87 | removed.append(idx) 88 | except: 89 | pass 90 | 91 | return unsmeared 92 | 93 | 94 | def is_leaf_node(self, id): 95 | assert id in self.ontology.keys(), "id not in ontology" 96 | 97 | if self.ontology[id]['child_ids'] == []: 98 | return True 99 | 100 | return False 101 | 102 | def get_ancestor_ids(self, id): 103 | assert id in self.ontology.keys(), "id not in ontology" 104 | 105 | ancestor_ids = [id] 106 | parent_id = self.ontology[id]['parent_id'] 107 | while parent_id is not None: 108 | ancestor_ids.append(parent_id) 109 | parent_id = self.ontology[parent_id]['parent_id'] 110 | return list(reversed(ancestor_ids)) 111 | 112 | def is_reachable(self, parent, child): 113 | assert parent in self.ontology.keys(), "parent not in ontology" 114 | assert child in self.ontology.keys(), "child not in ontology" 115 | 116 | if parent == child: 117 | return True 118 | for child_id in self.ontology[parent]['child_ids']: 119 | if self.is_reachable(child_id, child): 120 | return True 121 | return False 122 | 123 | def get_leaf_nodes(self, ids): 124 | leaf_nodes = [] 125 | for id in ids: 126 | if self.is_leaf_node(id): 127 | leaf_nodes.append(id) 128 | 129 | return leaf_nodes 130 | 131 | def get_unique_leaf_node(self, ids): 132 | leaf_nodes = self.get_leaf_nodes(ids) 133 | 134 | if len(leaf_nodes) != 1: 135 | return None 136 | 137 | return leaf_nodes[0] 138 | 139 | def is_unique_branch(self, ids, debug=False): 140 | ids = sorted(ids, key=lambda x: -self.ontology[x]['depth']) 141 | 142 | bottom = ids[0] 143 | 144 | ancestor_ids = self.get_ancestor_ids(bottom) 145 | 146 | for _id in ids: 147 | if _id not in ancestor_ids: 148 | return False 149 | 150 | return True -------------------------------------------------------------------------------- /data/multi_ch_simulator.py: -------------------------------------------------------------------------------- 1 | import pyroomacoustics as pra 2 | import numpy as np 3 | import random 4 | import json 5 | import os, glob 6 | import sofa 7 | import torch 8 | import torchaudio.transforms as AT 9 | from data.utils import read_audio_file 10 | 11 | from scipy.signal import convolve 12 | from scipy.ndimage import convolve1d 13 | 14 | 15 | import time 16 | 17 | class BaseSimulator(object): 18 | def __init__(self): 19 | pass 20 | 21 | def preprocess(self, audio): 22 | return audio 23 | 24 | def postprocess(self, audio): 25 | return audio 26 | 27 | def randomize_sources(self, num_sources): 28 | pass 29 | 30 | def get_metadata(self): 31 | metadata = {} 32 | 33 | metadata['duration'] = self.D 34 | metadata['sofa'] = self.sofa 35 | 36 | metadata['mic_positions'] = self.mic_positions 37 | 38 | metadata['sources'] = [] 39 | for i, source_id in enumerate(self.source_order): 40 | source = {'position':self.source_positions[i], 41 | 'order':source_id, 42 | 'hrtf_index':self.hrtf_indices[i], 43 | 'label':self.source_labels[i]} 44 | metadata['sources'].append(source) 45 | 46 | metadata['num_background'] = self.num_background_sources 47 | 48 | return metadata 49 | 50 | def save(self, path): 51 | metadata = self.get_metadata() 52 | 53 | with open(path, 'w') as f: 54 | json.dump(metadata, f, indent=4) 55 | 56 | def simulate(self, audio: np.ndarray) 
-> np.ndarray: 57 | """ 58 | Simulates RIR 59 | audio: (C x T) 60 | """ 61 | num_sources = audio.shape[0] 62 | 63 | #t1 = time.time() 64 | 65 | rirs = self.get_rirs() 66 | 67 | #t2 = time.time() 68 | 69 | #t_rir = t2 - t1 70 | 71 | #t1 = time.time() 72 | x = self.preprocess(audio) 73 | 74 | output = [] 75 | for i in range(num_sources): 76 | rir = rirs[i] 77 | waveform = x[i] 78 | 79 | left = convolve(waveform, rir[0]) 80 | left = self.postprocess(left) 81 | 82 | right = convolve(waveform, rir[1]) 83 | right = self.postprocess(right) 84 | 85 | binaural = np.stack([left, right]) 86 | output.append(binaural) 87 | 88 | output = np.array(output, dtype=np.float32) 89 | #t2 = time.time() 90 | 91 | #t_convolve = t2 - t1 92 | 93 | #print('RIR time:', t_rir) 94 | #print('Convolution time:', t_convolve) 95 | 96 | return output 97 | 98 | def initialize_room_with_random_params(self, 99 | num_sources: int, 100 | duration: float, 101 | ann_list: list, 102 | nbackground_sources: int = 1): 103 | self.D = duration 104 | 105 | self.source_labels = [] 106 | for i in range(num_sources): 107 | self.source_labels.append(ann_list[i]) 108 | 109 | # Randomize source choose order 110 | # First k sources correspond to background sources 111 | # Next n - k sources are foreground sources 112 | n = num_sources 113 | k = nbackground_sources 114 | self.source_order = [i for i in range(n - k)] 115 | np.random.shuffle(self.source_order) 116 | self.source_order = [i for i in range(n - k, n)] + self.source_order 117 | 118 | self.num_background_sources = k 119 | 120 | return self 121 | 122 | def seed(self, seed_value): 123 | np.random.seed(seed_value) 124 | random.seed(seed_value) 125 | 126 | class CATTRIR_Simulator(BaseSimulator): 127 | def __init__(self, dset_text_file, **kwargs) -> None: 128 | super().__init__() 129 | 130 | dset_dir = os.path.dirname(dset_text_file) 131 | with open(dset_text_file, 'r') as f: 132 | self.rt60_list = f.read().split('\n') 133 | self.rt60_dirs = [os.path.join(dset_dir, x) for x in self.rt60_list] 134 | 135 | def randomize_sources(self, num_sources): 136 | source_positions = [] 137 | hrtf_indices = [] 138 | rirs = sorted(os.listdir(self.room_dir)) 139 | random_source_rir_wavs = random.sample(rirs, num_sources) 140 | 141 | angles = [] 142 | for f in random_source_rir_wavs: 143 | angle = int(f[f.rfind('_')+1:-4]) 144 | angles.append(angle) 145 | 146 | for i in range(num_sources): 147 | pos = [np.cos(np.deg2rad(angle)), np.sin(np.deg2rad(angle))] 148 | source_positions.append(pos) 149 | hrtf_indices.append(angle) 150 | 151 | return source_positions, hrtf_indices 152 | 153 | def get_rirs(self): 154 | num_sources = len(self.source_positions) 155 | rt60 = os.path.basename(self.room_dir) 156 | 157 | rirs = [] 158 | for i in range(num_sources): 159 | path = os.path.join(self.room_dir, f'CATT_{rt60}_{self.hrtf_indices[i]}.wav') 160 | rir = read_audio_file(path, 44100) 161 | rirs.append(rir.astype(np.float32)) 162 | 163 | return rirs 164 | 165 | def initialize_room_with_random_params(self, 166 | num_sources: int, 167 | duration: float, 168 | ann_list: list, 169 | nbackground_sources: int = 1): 170 | 171 | self.room_dir = self.rt60_dirs[np.random.randint(len(self.rt60_dirs))] 172 | self.sofa = self.room_dir # TODO: Implement this better 173 | 174 | self.mic_positions = [[0, 0.9, 0], [0, -0.9, 0]] 175 | self.source_positions, self.hrtf_indices = self.randomize_sources(num_sources) 176 | 177 | return super().initialize_room_with_random_params(num_sources, 178 | duration, 179 | ann_list, 180 | 
nbackground_sources) 181 | 182 | class SOFASimulator(BaseSimulator): 183 | def __init__(self, sofa_text_file, **kwargs) -> None: 184 | super().__init__() 185 | self.hrtf_cache = {} 186 | self.sofa_dict = {} 187 | sofa_dir = os.path.dirname(sofa_text_file) 188 | with open(sofa_text_file, 'r') as f: 189 | self.subject_sofa_list = f.read().split('\n') 190 | self.sofa_files = [os.path.join(sofa_dir, x) for x in self.subject_sofa_list] 191 | 192 | for f in self.sofa_files: 193 | self.sofa_dict[f] = sofa.Database.open(f) 194 | 195 | self.kwargs = kwargs 196 | 197 | def initialize_room_with_random_params(self, 198 | num_sources: int, 199 | duration: float, 200 | ann_list: list, 201 | nbackground_sources: int = 1): 202 | 203 | self.sofa = self.sofa_files[np.random.randint(len(self.sofa_files))] 204 | self.HRTF = self.sofa_dict[self.sofa]#sofa.Database.open(self.sofa) 205 | mic_positions = self.HRTF.Receiver.Position.get_values(system="cartesian")[..., 0] 206 | self.mic_positions = mic_positions.tolist() 207 | self.source_positions, self.hrtf_indices = self.randomize_sources(num_sources) 208 | 209 | return super().initialize_room_with_random_params(num_sources, 210 | duration, 211 | ann_list, 212 | nbackground_sources) 213 | def get_rirs(self): 214 | num_sources = len(self.source_positions) 215 | rirs = [] 216 | for i in range(num_sources): 217 | key = self.sofa + str(sorted(list(self.hrtf_indices[i].items()))) 218 | #print('KEY', key) 219 | if key in self.hrtf_cache: 220 | rir = self.hrtf_cache[key] 221 | else: 222 | rir = self.HRTF.Data.IR.get_values(indices=self.hrtf_indices[i]).astype(np.float32) 223 | self.hrtf_cache[key] = rir.copy() 224 | rirs.append(rir) 225 | return rirs 226 | 227 | class CIPIC_Simulator(SOFASimulator): 228 | def randomize_sources(self, num_sources): 229 | source_positions = [] 230 | hrtf_indices = [] 231 | random_source_positions = random.sample(range(self.HRTF.Dimensions.M), num_sources) 232 | for i in range(num_sources): 233 | sofa_indices = {"M":random_source_positions[i]} 234 | pos = self.HRTF.Source.Position.get_values(system="cartesian", indices=sofa_indices).tolist() 235 | source_positions.append(pos) 236 | hrtf_indices.append(sofa_indices) 237 | 238 | return source_positions, hrtf_indices 239 | 240 | 241 | class CIPIC_HRTF_Simulator(CIPIC_Simulator): pass 242 | 243 | class BRIR48kHz_Simulator(CIPIC_HRTF_Simulator): 244 | def __init__(self, sofa_text_file, **kwargs): 245 | super().__init__(sofa_text_file, **kwargs) 246 | self.presampler = AT.Resample(self.kwargs['sr'], 48000) 247 | self.postsampler = AT.Resample(48000, self.kwargs['sr']) 248 | 249 | def preprocess(self, audio: np.ndarray) -> np.ndarray: 250 | audio = self.presampler(torch.from_numpy(audio)) 251 | return audio.numpy() 252 | 253 | def postprocess(self, audio: np.ndarray) -> np.ndarray: 254 | audio = self.postsampler(torch.from_numpy(audio)) 255 | return audio.numpy() 256 | 257 | 258 | # Salford-BBC Spatially-sampled Binaural Room Impulse Responses 259 | # https://usir.salford.ac.uk/id/eprint/30868/ 260 | class SBSBRIR_Simulator(BRIR48kHz_Simulator): 261 | def randomize_sources(self, num_sources): 262 | source_positions = [] 263 | hrtf_indices = [] 264 | 265 | random_source_positions = random.sample(range(self.HRTF.Dimensions.E), num_sources) 266 | random_measurement_rotation = np.random.randint(self.HRTF.Dimensions.M) 267 | for i in range(num_sources): 268 | #sofa_indices = {"M":0, "E":0} 269 | sofa_indices = {"M":random_measurement_rotation, "E":random_source_positions[i]} 270 | pos = 
self.HRTF.Emitter.Position.get_values(system="cartesian", indices=sofa_indices).tolist() 271 | source_positions.append(pos) 272 | hrtf_indices.append(sofa_indices) 273 | 274 | return source_positions, hrtf_indices 275 | 276 | def preprocess(self, audio: np.ndarray) -> np.ndarray: 277 | audio = super().preprocess(audio) 278 | return audio * 15 # Gain because RIRs are very low for some reason 279 | 280 | # Real Room BRIRs 281 | # https://github.com/IoSR-Surrey/RealRoomBRIRs 282 | class RRBRIR_Simulator(BRIR48kHz_Simulator): pass 283 | 284 | class Multi_Ch_Simulator(BaseSimulator): 285 | # simulators = [CIPIC_Simulator] 286 | # simulators = [ CATTRIR_Simulator] 287 | # simulators = [ SBSBRIR_Simulator] 288 | # simulators = [RRBRIR_Simulator] 289 | # simulators = [SBSBRIR_Simulator, RRBRIR_Simulator, CATTRIR_Simulator] # UNCOMMENT FOR REVERBED HRTF ONLY 290 | def __init__(self, hrtf_dir, dset_type: str, sr: int, reverb: bool = True) -> None: 291 | self.hrtf_dir = hrtf_dir 292 | self.dset = dset_type 293 | self.sr = sr 294 | 295 | if reverb: 296 | simulators = [CIPIC_Simulator, SBSBRIR_Simulator, RRBRIR_Simulator, CATTRIR_Simulator] 297 | else: 298 | simulators = [CIPIC_Simulator] 299 | 300 | #simulators = [SBSBRIR_Simulator] 301 | self.simulators = [sim(os.path.join(self.hrtf_dir, sim.__name__[:-len("_Simulator")], self.dset + '_hrtf.txt'), sr=self.sr) for sim in simulators] 302 | 303 | def get_random_simulator(self) -> BaseSimulator: 304 | sim = random.choice(self.simulators) 305 | #print("Using simulator", type(sim)) 306 | return sim#(os.path.join(self.hrtf_dir, sim.__name__[:-len("_Simulator")], self.dset + '_hrtf.txt'),sr=self.sr) 307 | 308 | class PRASimulator(object): 309 | def __init__(self, 310 | n_mics = 2, 311 | min_absorption=0.6, 312 | max_absorption=1, 313 | fs=44100, 314 | max_order=15, 315 | mean_mic_distance=13.9, 316 | mic_distance_var=0.7, 317 | mic_array_keepout=0.5, 318 | min_room_length=6, 319 | max_room_length=8, 320 | min_room_width=6, 321 | max_room_width=8) -> None: 322 | """ 323 | Mic distance is by default the average of the 324 | median Bitragion Breadth for men and women 325 | """ 326 | # Constant across samples 327 | self.M = n_mics 328 | self.fs = fs 329 | self.K = mic_array_keepout 330 | self.max_order = max_order 331 | self.min_absorption = min_absorption 332 | self.max_absorption = max_absorption 333 | self.R = mean_mic_distance 334 | self.V = mic_distance_var 335 | 336 | self.min_room_length = min_room_length 337 | self.max_room_length = max_room_length 338 | self.min_room_width = min_room_width 339 | self.max_room_width = max_room_width 340 | 341 | def initialize_room_with_random_params(self, num_sources: int, duration: float): 342 | self.D = duration 343 | self.mic_distance = np.random.normal(self.R, scale=self.V ** 0.5) * 1e-2 344 | self.mic_positions = [[-self.mic_distance/2, 0], [self.mic_distance/2, 0]] 345 | self.absorption = np.random.uniform(self.min_absorption, self.max_absorption) 346 | 347 | self.L = np.random.uniform(self.min_room_length, self.max_room_length) 348 | self.W = np.random.uniform(self.min_room_width, self.max_room_width) 349 | 350 | self.left_wall = -self.L / 2 351 | self.right_wall = self.L / 2 352 | 353 | self.bottom_wall = -self.W / 2 354 | self.top_wall = self.W / 2 355 | 356 | self.source_positions = [] 357 | for i in range(num_sources): 358 | source_pos = self._get_random_source_pos(self.left_wall, 359 | self.right_wall, 360 | self.bottom_wall, 361 | self.top_wall, 362 | self.K) 363 | self.source_positions.append(source_pos) 
364 | 365 | # Randomize source choose order 366 | self.source_order = [i for i in range(num_sources - 1)] 367 | np.random.shuffle(self.source_order) 368 | self.source_order = [num_sources-1] + self.source_order # Background is always last 369 | 370 | return self 371 | 372 | def intialize_from_metadata(self, metadata_path): 373 | with open(metadata_path, 'r') as f: 374 | metadata = json.load(f) 375 | 376 | self.D = metadata['duration'] 377 | self.M = metadata['n_mics'] 378 | self.fs = metadata['sampling_rate'] 379 | self.max_order = metadata['max_order'] 380 | self.absorption = metadata['absorption'] 381 | self.mic_distance = metadata['mic_distance'] 382 | self.mic_positions = metadata['mic_positions'] 383 | 384 | room_desc = metadata['room'] 385 | self.L = room_desc['length'] 386 | self.W = room_desc['width'] 387 | 388 | self.left_wall = -self.L / 2 389 | self.right_wall = self.L / 2 390 | 391 | self.bottom_wall = -self.W / 2 392 | self.top_wall = self.W / 2 393 | 394 | self.source_order = [] 395 | self.source_positions = [] 396 | 397 | source_list = metadata['sources'] 398 | for source in source_list: 399 | source_id = source['order'] 400 | source_position = source['position'] 401 | 402 | self.source_order.append(source_id) 403 | self.source_positions.append(source_position) 404 | 405 | return self 406 | 407 | def get_metadata(self): 408 | metadata = {} 409 | 410 | metadata['duration'] = self.D 411 | metadata['sampling_rate'] = self.fs 412 | metadata['max_order'] = self.max_order 413 | 414 | metadata['n_mics'] = self.M 415 | metadata['absorption'] = self.absorption 416 | metadata['mic_distance'] = self.mic_distance 417 | metadata['mic_positions'] = self.mic_positions 418 | 419 | room_desc = {} 420 | room_desc['length'] = self.L 421 | room_desc['width'] = self.W 422 | metadata['room'] = room_desc 423 | 424 | metadata['sources'] = [] 425 | for i, source_id in enumerate(self.source_order): 426 | source = {'position':self.source_positions[i], 'order':source_id} 427 | metadata['sources'].append(source) 428 | 429 | return metadata 430 | 431 | def save(self, path): 432 | metadata = self.get_metadata() 433 | 434 | with open(path, 'w') as f: 435 | json.dump(metadata, f, indent=4) 436 | 437 | def simulate(self, 438 | source_audio): 439 | """ 440 | Input: list of source_audio (T,) 441 | returns y (M, T) 442 | """ 443 | 444 | corners = np.array([[self.left_wall, self.bottom_wall], 445 | [self.right_wall, self.bottom_wall], 446 | [self.right_wall, self.top_wall], 447 | [self.left_wall, self.top_wall]]).T 448 | 449 | room = pra.room.Room.from_corners(corners, 450 | absorption=self.absorption, 451 | fs=self.fs, 452 | max_order=self.max_order) 453 | 454 | mic_array = np.array(self.mic_positions).T#pra.circular_2D_array(center=[0., 0.], M=self.M, phi0=180, radius=self.mic_distance * 0.5 * 1e-2) 455 | room.add_microphone_array(mic_array) 456 | 457 | for i, source_pos in enumerate(self.source_positions): 458 | room.add_source(source_pos, signal=source_audio[i]) 459 | 460 | y = room.simulate(return_premix=True) 461 | 462 | total_samples = int(round(self.D * self.fs)) 463 | return y[..., :total_samples] 464 | 465 | def _get_random_source_pos(self, L, R, B, T, K): 466 | pos = [0, 0] 467 | 468 | while np.linalg.norm(pos) < K: 469 | x = np.random.uniform(L, R) 470 | y = np.random.uniform(B, T) 471 | 472 | pos = [x, y] 473 | 474 | return pos 475 | 476 | 477 | def test(): 478 | n_sources = 5 479 | duration = 1 480 | save_path = 'mymetadata.json' 481 | 482 | simulator = 
PRASimulator().initialize_room_with_random_params(n_sources, duration) 483 | simulator.save(save_path) 484 | 485 | simulator2 = PRASimulator().intialize_from_metadata(save_path) 486 | 487 | x = [np.random.random(44100) for i in range(n_sources)] 488 | # x = [np.sin(2 * np.pi * 440 * np.arange(0, 1, 1/44100)) for i in range(n_sources)] 489 | y = simulator2.simulate(x) * 1e3 490 | 491 | import soundfile as sf 492 | sf.write('audio.wav', y[0].T, 44100) 493 | 494 | 495 | if __name__ == "__main__": 496 | test() 497 | -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | from scipy.io.wavfile import write as wavwrite 3 | 4 | 5 | def read_audio_file(file_path, sr): 6 | """ 7 | Reads audio file to system memory. 8 | """ 9 | return librosa.core.load(file_path, mono=False, sr=sr)[0] 10 | 11 | 12 | def write_audio_file(file_path, data, sr): 13 | """ 14 | Writes audio file to system memory. 15 | @param file_path: Path of the file to write to 16 | @param data: Audio signal to write (n_channels x n_samples) 17 | @param sr: Sampling rate 18 | """ 19 | wavwrite(file_path, sr, data.T) -------------------------------------------------------------------------------- /experiments/dc_waveformer/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model": "src.training.dcc_tf_binaural", 3 | "base_metric": "scale_invariant_signal_noise_ratio", 4 | "fix_lr_epochs": 40, 5 | "epochs": 80, 6 | "batch_size": 16, 7 | "eval_batch_size": 64, 8 | "n_workers": 16, 9 | "model_params": { 10 | "L": 32, 11 | "label_len": 20, 12 | "model_dim": 256, 13 | "num_enc_layers": 10, 14 | "num_dec_layers": 1, 15 | "dec_buf_len": 13, 16 | "dec_chunk_size": 13, 17 | "use_pos_enc": true, 18 | "conditioning": "mult", 19 | "out_buf_len": 4 20 | }, 21 | "train_dataset": "src.training.datasets.curated_binaural_augrir.CuratedBinauralAugRIRDataset", 22 | "train_data_args": { 23 | "fg_dir": "data/BinauralCuratedDataset/scaper_fmt/train", 24 | "bg_dir": "data/BinauralCuratedDataset/TAU-acoustic-sounds/TAU-urban-acoustic-scenes-2019-development", 25 | "bg_scaper_dir": "data/BinauralCuratedDataset/bg_scaper_fmt/train", 26 | "jams_dir": "data/BinauralCuratedDataset/jams_hard/train", 27 | "hrtf_dir": "data/BinauralCuratedDataset/hrtf", 28 | "dset": "train", 29 | "sr": 44100, 30 | "resample_rate": null, 31 | "reverb": true 32 | }, 33 | "val_dataset": "src.training.datasets.curated_binaural_augrir.CuratedBinauralAugRIRDataset", 34 | "val_data_args": { 35 | "fg_dir": "data/BinauralCuratedDataset/scaper_fmt/val", 36 | "bg_dir": "data/BinauralCuratedDataset/TAU-acoustic-sounds/TAU-urban-acoustic-scenes-2019-development", 37 | "bg_scaper_dir": "data/BinauralCuratedDataset/bg_scaper_fmt/val", 38 | "jams_dir": "data/BinauralCuratedDataset/jams_hard/val", 39 | "hrtf_dir": "data/BinauralCuratedDataset/hrtf", 40 | "dset": "val", 41 | "sr": 44100, 42 | "resample_rate": null, 43 | "reverb": true 44 | }, 45 | "test_dataset": "src.training.datasets.curated_binaural_augrir.CuratedBinauralAugRIRDataset", 46 | "test_data_args": { 47 | "fg_dir": "data/BinauralCuratedDataset/scaper_fmt/test", 48 | "bg_dir": "data/BinauralCuratedDataset/TAU-acoustic-sounds/TAU-urban-acoustic-scenes-2019-evaluation", 49 | "bg_scaper_dir": "data/BinauralCuratedDataset/bg_scaper_fmt/test", 50 | "jams_dir": "data/BinauralCuratedDataset/jams/test", 51 | "hrtf_dir": 
"data/BinauralCuratedDataset/hrtf", 52 | "dset": "test", 53 | "sr": 44100, 54 | "resample_rate": null, 55 | "reverb": true 56 | }, 57 | "optim": { 58 | "lr": 0.0005, 59 | "weight_decay": 0.0 60 | }, 61 | "lr_sched": { 62 | "mode": "max", 63 | "factor": 0.5, 64 | "patience": 5, 65 | "min_lr": 5e-06, 66 | "threshold": 0.1, 67 | "threshold_mode": "abs" 68 | }, 69 | "commit_hash": "dce742247886d0c98116fea3602c78bc215e5591" 70 | } 71 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ### Requirements 2 | librosa 3 | torch==1.13.1 4 | torchaudio==0.13.1 5 | soundfile 6 | scipy 7 | matplotlib 8 | tqdm 9 | numpy 10 | pandas 11 | speechbrain 12 | tensorflow 13 | tensorflow-probability 14 | torchmetrics==0.10.0 15 | seaborn 16 | ipykernel 17 | scaper 18 | thop==0.1.1.post2209072238 19 | openl3 20 | youtube_dl 21 | transformers 22 | bs4 23 | pyroomacoustics 24 | python-sofa==0.2.0 25 | onnx 26 | onnxruntime 27 | torch_tb_profiler 28 | ffmpegio 29 | noisereduce 30 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vb000/SemanticHearing/07e9786c7a741f0a7c722dcde66a2679ca068c50/src/__init__.py -------------------------------------------------------------------------------- /src/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vb000/SemanticHearing/07e9786c7a741f0a7c722dcde66a2679ca068c50/src/helpers/__init__.py -------------------------------------------------------------------------------- /src/helpers/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from scipy import signal 4 | from scipy.fft import rfft, irfft 5 | from scipy.signal import stft 6 | from pyroomacoustics.doa import srp 7 | from pyroomacoustics.experimental.localization import tdoa 8 | import pyroomacoustics as pra 9 | import src.helpers.utils as utils 10 | import torch 11 | 12 | try: 13 | import mklfft as fft 14 | except ImportError: 15 | import numpy.fft as fft 16 | 17 | 18 | def tdoa2(x1, x2, interp=1, fs=1, phat=True, t_max=None): 19 | """ 20 | This function computes the time difference of arrival (TDOA) 21 | of the signal at the two microphones. This in turns is used to infer 22 | the direction of arrival (DOA) of the signal. 23 | Specifically if s(k) is the signal at the reference microphone and 24 | s_2(k) at the second microphone, then for signal arriving with DOA 25 | theta we have 26 | s_2(k) = s(k - tau) 27 | with 28 | tau = fs*d*sin(theta)/c 29 | where d is the distance between the two microphones and c the speed of sound. 30 | We recover tau using the Generalized Cross Correlation - Phase Transform (GCC-PHAT) 31 | method. The reference is 32 | Knapp, C., & Carter, G. C. (1976). The generalized correlation method for estimation of time delay. 
33 | Parameters 34 | ---------- 35 | x1 : nd-array 36 | The signal of the reference microphone 37 | x2 : nd-array 38 | The signal of the second microphone 39 | interp : int, optional (default 1) 40 | The interpolation value for the cross-correlation, it can 41 | improve the time resolution (and hence DOA resolution) 42 | fs : int, optional (default 44100 Hz) 43 | The sampling frequency of the input signal 44 | Return 45 | ------ 46 | theta : float 47 | the angle of arrival (in radian (I think)) 48 | pwr : float 49 | the magnitude of the maximum cross correlation coefficient 50 | delay : float 51 | the delay between the two microphones (in seconds) 52 | """ 53 | # zero padded length for the FFT 54 | n = x1.shape[-1] + x2.shape[-1] - 1 55 | if n % 2 != 0: 56 | n += 1 57 | 58 | # Generalized Cross Correlation Phase Transform 59 | # Used to find the delay between the two microphones 60 | # up to line 71 61 | X1 = fft.rfft(np.array(x1, dtype=np.float32), n=n, axis=-1) 62 | X2 = fft.rfft(np.array(x2, dtype=np.float32), n=n, axis=-1) 63 | 64 | if phat: 65 | X1 /= np.abs(X1) 66 | X2 /= np.abs(X2) 67 | 68 | cc = fft.irfft(X1 * np.conj(X2), n=interp * n, axis=-1) 69 | 70 | # maximum possible delay given distance between microphones 71 | 72 | if t_max is None: 73 | t_max = n // 2 + 1 74 | 75 | # reorder the cross-correlation coefficients 76 | cc = np.concatenate((cc[..., -t_max:], cc[..., :t_max]), axis=-1) 77 | 78 | # import matplotlib.pyplot as plt 79 | 80 | # t = np.arange(-t_max/fs, (t_max)/fs, 1/fs) * 1e6 81 | # plt.plot(t, cc[15]) 82 | # plt.show() 83 | 84 | # pick max cross correlation index as delay 85 | tau = np.argmax(np.abs(cc), axis=-1) 86 | tau -= t_max # because zero time is at the center of the array 87 | 88 | return tau / (fs * interp) 89 | 90 | 91 | from sklearn.utils.extmath import weighted_mode 92 | def framewise_gccphat(x, frame_dur, sr, window='tukey'): 93 | TMAX = int(round(1e-3 * sr)) 94 | frame_width = int(round(frame_dur * sr)) 95 | 96 | # Total number of frames 97 | T = 1 + (x.shape[-1] - 1)// frame_width 98 | 99 | # Drop samples to get a multiple of frame size 100 | if x.shape[-1] % T != 0: 101 | x = x[..., -x.shape[-1]%T:] 102 | 103 | assert x.shape[-1] % T == 0 104 | frames = np.array(np.split(x, T, axis=-1)) 105 | 106 | window = signal.get_window(window, frame_width) 107 | frames = frames * window 108 | 109 | # Consider only frames that have energy above some threshold (ignore silence) 110 | ENERGY_THRESHOLD = 5e-4 111 | frame_energy = np.max(np.mean(frames**2, axis=-1)**0.5, axis=-1) 112 | mask = frame_energy > ENERGY_THRESHOLD 113 | frames = frames[mask] 114 | 115 | fw_gccphat = tdoa2(frames[..., 0, :], frames[..., 1, :], fs=sr, t_max=TMAX) 116 | 117 | # print(mask) 118 | # print(fw_gccphat) 119 | # print(frame_energy[mask]) 120 | itd = weighted_mode(fw_gccphat, frame_energy[mask], axis=-1)[0] 121 | return itd[0] 122 | 123 | def fw_itd_diff(s_est, s_gt, sr, frame_duration=0.25): 124 | """ 125 | Computes frame-wise delta ITD 126 | """ 127 | # print("GT") 128 | itd_gt = framewise_gccphat(s_gt, frame_duration, sr) * 1e6 129 | # print("GT FW_ITD", itd_gt) 130 | # print("EST") 131 | itd_est = framewise_gccphat(s_est, frame_duration, sr) * 1e6 132 | # print("EST FW_ITD", itd_est) 133 | return np.abs(itd_est - itd_gt) 134 | 135 | def cal_interaural_error(predictions, targets, sr, debug=False): 136 | """Compute ITD and ILD errors 137 | input: (1, time, channel, speaker) 138 | """ 139 | 140 | TMAX = int(round(1e-3 * sr)) 141 | EPS = 1e-8 142 | s_target = targets[0] # [T,E,C] 
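    # Note: TMAX restricts the GCC-PHAT lag search to roughly +/- 1 ms, and the
    # ITD values below are scaled by 1e6, i.e. reported in microseconds.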
143 | s_prediction = predictions[0] # [T,E,C] 144 | 145 | # ITD is computed with generalized cross-correlation phase transform (GCC-PHAT) 146 | ITD_target = [ 147 | tdoa2( 148 | s_target[:, 0, i].cpu().numpy(), 149 | s_target[:, 1, i].cpu().numpy(), 150 | fs=sr, 151 | t_max=TMAX 152 | ) 153 | * 10 ** 6 154 | for i in range(s_target.shape[-1]) 155 | ] 156 | if debug: 157 | print("TARGET ITD", ITD_target) 158 | 159 | ITD_prediction = [ 160 | tdoa2( 161 | s_prediction[:, 0, i].cpu().numpy(), 162 | s_prediction[:, 1, i].cpu().numpy(), 163 | fs=sr, 164 | t_max=TMAX, 165 | ) 166 | * 10 ** 6 167 | for i in range(s_prediction.shape[-1]) 168 | ] 169 | 170 | if debug: 171 | print("PREDICTED ITD", ITD_prediction) 172 | 173 | ITD_error1 = np.mean( 174 | np.abs(np.array(ITD_target) - np.array(ITD_prediction)) 175 | ) 176 | ITD_error2 = np.mean( 177 | np.abs(np.array(ITD_target) - np.array(ITD_prediction)[::-1]) 178 | ) 179 | ITD_error = min(ITD_error1, ITD_error2) 180 | 181 | # ILD = 10 * log_10(||s_left||^2 / ||s_right||^2) 182 | ILD_target_beforelog = torch.sum(s_target[:, 0] ** 2, dim=0) / ( 183 | torch.sum(s_target[:, 1] ** 2, dim=0) + EPS 184 | ) 185 | ILD_target = 10 * torch.log10(ILD_target_beforelog + EPS) # [C] 186 | ILD_prediction_beforelog = torch.sum(s_prediction[:, 0] ** 2, dim=0) / ( 187 | torch.sum(s_prediction[:, 1] ** 2, dim=0) + EPS 188 | ) 189 | ILD_prediction = 10 * torch.log10(ILD_prediction_beforelog + EPS) # [C] 190 | 191 | ILD_error1 = torch.mean(torch.abs(ILD_target - ILD_prediction)) 192 | ILD_error2 = torch.mean(torch.abs(ILD_target - ILD_prediction.flip(0))) 193 | ILD_error = min(ILD_error1.item(), ILD_error2.item()) 194 | 195 | return ITD_error, ILD_error 196 | 197 | def compute_itd(s_left, s_right, sr, t_max = None): 198 | corr = signal.correlate(s_left, s_right) 199 | lags = signal.correlation_lags(len(s_left), len(s_right)) 200 | corr /= np.max(corr) 201 | 202 | mid = len(corr)//2 + 1 203 | 204 | # print(corr[-t_max:]) 205 | cc = np.concatenate((corr[-mid:], corr[:mid])) 206 | 207 | if t_max is not None: 208 | # if False: 209 | # print(cc[-t_max:].shape) 210 | cc = np.concatenate([cc[-t_max+1:], cc[:t_max+1]]) 211 | else: 212 | t_max = mid 213 | 214 | # print("OKKK", cc.shape) 215 | # t = np.arange(-t_max/sr, (t_max)/sr, 1/sr) * 1e6 216 | # plt.plot(t, np.abs(cc)) 217 | # plt.show() 218 | tau = np.argmax(np.abs(cc)) 219 | tau -= t_max 220 | # tau = lags[x] 221 | # print(tau/ sr * 1e6) 222 | 223 | return tau / sr * 1e6 224 | 225 | 226 | def compute_doa(mic_pos, s, sr, nfft=2048, num_sources=1): 227 | # freq_range = [100, 20000] 228 | 229 | X = pra.transform.stft.analysis(s.T, nfft, nfft // 2, ) 230 | X = X.transpose([2, 1, 0]) 231 | 232 | algo_names = ['SRP', 'MUSIC', 'FRIDA', 'TOPS', 'WAVES', 'CSSM', 'NormMUSIC'] 233 | 234 | srp = pra.doa.algorithms['NormMUSIC'](mic_pos.T, sr, nfft, c=343, num_sources=num_sources) 235 | srp.locate_sources(X) 236 | 237 | values = srp.grid.values 238 | phi = np.linspace(-np.pi, np.pi, 360) 239 | 240 | values = np.roll(values, shift=180) 241 | 242 | # plt.plot(phi * 180 / np.pi, values) 243 | # plt.xlim([-90, 90]) 244 | # plt.show() 245 | 246 | peak_idx = 90 + np.argmax(values[90:270]) 247 | return phi[peak_idx] 248 | 249 | def doa_diff(mic_pos, est, gt, sr): 250 | doa_est = compute_doa(mic_pos, est, sr) 251 | doa_gt = compute_doa(mic_pos, gt, sr) 252 | return np.abs(doa_gt - doa_est) 253 | 254 | def gcc_phat(s_left, s_right, sr): 255 | X = rfft(s_left) 256 | Y = rfft(s_right) 257 | 258 | Z = X * np.conj(Y) 259 | 260 | y = irfft(np.exp(1j * 
np.angle(Z))) 261 | center = (len(y) + 1)//2 262 | y = np.concatenate([y[center:], y[:center]]) 263 | lags = (np.linspace(0, len(y), len(y)) - ((len(y) + 1) / 2)) / sr 264 | x = np.argmax(y) 265 | tau = lags[x] 266 | 267 | return lags, y 268 | 269 | def compute_ild(s_left, s_right): 270 | sum_sq_left = np.sum(s_left ** 2, axis=-1) 271 | sum_sq_right = np.sum(s_right ** 2, axis=-1) 272 | # print(sum_sq_left) 273 | # print(sum_sq_right) 274 | return 10 * np.log10(sum_sq_left / sum_sq_right) 275 | 276 | def itd_diff(s_est, s_gt, sr): 277 | """ 278 | Computes the ITD error between model estimate and ground truth 279 | input: (*, 2, T), (*, 2, T) 280 | """ 281 | TMAX = int(round(1e-3 * sr)) 282 | itd_est = compute_itd(s_est[..., 0, :], s_est[..., 1, :], sr, TMAX) 283 | itd_gt = compute_itd(s_gt[..., 0, :], s_gt[..., 1, :], sr, TMAX) 284 | return np.abs(itd_est - itd_gt) 285 | 286 | def gcc_phat_diff(s_est, s_gt, sr): 287 | TMAX = int(round(1e-3 * sr)) 288 | itd_est = tdoa2(s_est[..., 0, :], s_est[..., 1, :], fs=sr, t_max=TMAX) 289 | itd_gt = tdoa2(s_gt[..., 0, :], s_gt[..., 1, :], fs=sr, t_max=TMAX) 290 | return np.abs(itd_est - itd_gt) * 10 ** 6 291 | 292 | def ild_diff(s_est, s_gt): 293 | """ 294 | Computes the ILD error between model estimate and ground truth 295 | input: (*, 2, T), (*, 2, T) 296 | """ 297 | ild_est = compute_ild(s_est[..., 0, :], s_est[..., 1, :]) 298 | ild_gt = compute_ild(s_gt[..., 0, :], s_gt[..., 1, :]) 299 | return np.abs(ild_est - ild_gt) 300 | 301 | def si_sdr(estimated_signal, reference_signals, scaling=True): 302 | """ 303 | This is a scale invariant SDR. See https://arxiv.org/pdf/1811.02508.pdf 304 | or https://github.com/sigsep/bsseval/issues/3 for the motivation and 305 | explanation 306 | Input: 307 | estimated_signal and reference signals are (N,) numpy arrays 308 | Returns: SI-SDR as scalar 309 | """ 310 | 311 | Rss = np.dot(reference_signals, reference_signals) 312 | this_s = reference_signals 313 | 314 | if scaling: 315 | # get the scaling factor for clean sources 316 | a = np.dot(this_s, estimated_signal) / Rss 317 | else: 318 | a = 1 319 | 320 | e_true = a * this_s 321 | e_res = estimated_signal - e_true 322 | 323 | Sss = (e_true**2).sum() 324 | Snn = (e_res**2).sum() 325 | 326 | SDR = 10 * np.log10(Sss/Snn) 327 | 328 | return SDR 329 | 330 | 331 | if __name__ == "__main__": 332 | fs = 44100 333 | 334 | corners = np.array([[-2, 2], 335 | [2, 2], 336 | [2, -2], 337 | [-2, -2]]).T 338 | 339 | room = pra.room.Room.from_corners(corners, 340 | absorption=1, 341 | fs=fs, 342 | max_order=1) 343 | 344 | # x = utils.read_audio_file('outputs/bin_gt.wav', fs) 345 | x_gt = utils.read_audio_file('save_examples_few/00622/gt.wav', fs) 346 | x_est = utils.read_audio_file('save_examples_few/00622/binaural.wav', fs) 347 | 348 | # framewise_gccphat(x, 0.25, fs) 349 | print(fw_itd_diff(x_est, x_gt, fs)) 350 | # x = utils.read_audio_file('save_examples_val/00000/gt.wav', fs) 351 | # y = utils.read_audio_file('tests/sample_audio2.wav', fs) 352 | # mic_positions = np.array([[0, 0.09], [0, -0.09]]) 353 | # room.add_microphone_array(mic_positions.T) 354 | 355 | # a1 = 50 * np.pi / 180 356 | # a2 = 60 * np.pi / 180 357 | 358 | # s1 = np.array([np.cos(a1), np.sin(a1)]) 359 | # room.add_source(s1, signal=x) 360 | 361 | # # s2 = np.array([np.cos(a2), np.sin(a2)]) 362 | # # room.add_source(s2, signal=y) 363 | 364 | # room.simulate() 365 | 366 | # s = room.mic_array.signals # (M, T) 367 | # s = s.transpose() # (T, M) 368 | # s = np.reshape(s, (1, *s.shape, 1)) 369 | 370 | # s_est = 
s.copy() + np.random.normal(0, 1e-2, s.shape) 371 | # s_est[0, :, 0, 0] = np.roll(s_est[0, : , 0, 0], shift=222) 372 | 373 | # s = torch.from_numpy(s) 374 | # s_est = torch.from_numpy(s_est) 375 | 376 | # # itd_error, ild_error = cal_interaural_error(s_est, s, fs) 377 | # # print('ITD', itd_error) 378 | # # print('ILD', ild_error) 379 | 380 | # itd_error = itd_diff(s_est, s, fs) 381 | # print('ITD', itd_error) 382 | 383 | # doa = compute_doa(mic_positions, s, fs, num_sources=2) 384 | # print(doa * 180 / np.pi) 385 | 386 | # x = np.array([x[0], x[0]]) 387 | # x[0] = np.roll(x[0], shift=2) * 0.5 388 | # # np.random.seed(0) 389 | # x = x + np.random.normal(loc=0, scale=1e-2, size=x.shape) 390 | 391 | # x = x[:, 140000:140000 + 190000] 392 | # x = x 393 | 394 | # fig, ax = plt.subplots() 395 | # ax.plot(x[0]) 396 | # ax.plot(x[1]) 397 | 398 | # tdoa2(x[0, :], x[1, :], fs=fs, t_max=44) 399 | # utils.write_audio_file('gcc.wav', x, fs) 400 | 401 | # tau = compute_itd(x, y, 44100) 402 | # # print(tau) 403 | # lags, z = gcc_phat(x, y, 44100) 404 | # plt.plot(t, x) 405 | # plt.plot(t, y) 406 | # plt.plot(lags, z) 407 | # plt.grid() 408 | # plt.show() 409 | -------------------------------------------------------------------------------- /src/helpers/utils.py: -------------------------------------------------------------------------------- 1 | """A collection of useful helper functions""" 2 | 3 | import os 4 | import logging 5 | import json 6 | import importlib 7 | 8 | import torch 9 | from torch.profiler import profile, record_function, ProfilerActivity 10 | import pandas as pd 11 | from torchmetrics.functional import( 12 | scale_invariant_signal_noise_ratio as si_snr, 13 | signal_noise_ratio as snr, 14 | signal_distortion_ratio as sdr, 15 | scale_invariant_signal_distortion_ratio as si_sdr) 16 | import matplotlib.pyplot as plt 17 | 18 | class Params(): 19 | """Class that loads hyperparameters from a json file. 
20 | Example: 21 | ``` 22 | params = Params(json_path) 23 | print(params.learning_rate) 24 | params.learning_rate = 0.5 # change the value of learning_rate in params 25 | ``` 26 | """ 27 | 28 | def __init__(self, json_path): 29 | with open(json_path) as f: 30 | params = json.load(f) 31 | self.__dict__.update(params) 32 | 33 | def save(self, json_path): 34 | with open(json_path, 'w') as f: 35 | json.dump(self.__dict__, f, indent=4) 36 | 37 | def update(self, json_path): 38 | """Loads parameters from json file""" 39 | with open(json_path) as f: 40 | params = json.load(f) 41 | self.__dict__.update(params) 42 | 43 | @property 44 | def dict(self): 45 | """Gives dict-like access to Params instance by `params.dict['learning_rate']""" 46 | return self.__dict__ 47 | 48 | def save_graph(train_metrics, test_metrics, save_dir): 49 | metrics = [snr, si_snr] 50 | results = {'train_loss': train_metrics['loss'], 51 | 'test_loss' : test_metrics['loss']} 52 | 53 | for m_fn in metrics: 54 | results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__] 55 | results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__] 56 | 57 | results_pd = pd.DataFrame(results) 58 | 59 | results_pd.to_csv(os.path.join(save_dir, 'results.csv')) 60 | 61 | fig, temp_ax = plt.subplots(2, 3, figsize=(15,10)) 62 | axs=[] 63 | for i in temp_ax: 64 | for j in i: 65 | axs.append(j) 66 | 67 | x = range(len(train_metrics['loss'])) 68 | axs[0].plot(x, train_metrics['loss'], label='train') 69 | axs[0].plot(x, test_metrics['loss'], label='test') 70 | axs[0].set(ylabel='Loss') 71 | axs[0].set(xlabel='Epoch') 72 | axs[0].set_title('loss',fontweight='bold') 73 | axs[0].legend() 74 | 75 | for i in range(len(metrics)): 76 | axs[i+1].plot(x, train_metrics[metrics[i].__name__], label='train') 77 | axs[i+1].plot(x, test_metrics[metrics[i].__name__], label='test') 78 | axs[i+1].set(xlabel='Epoch') 79 | axs[i+1].set_title(metrics[i].__name__,fontweight='bold') 80 | axs[i+1].legend() 81 | 82 | plt.tight_layout() 83 | plt.savefig(os.path.join(save_dir, 'results.png')) 84 | plt.close(fig) 85 | 86 | def import_attr(import_path): 87 | module, attr = import_path.rsplit('.', 1) 88 | return getattr(importlib.import_module(module), attr) 89 | 90 | def set_logger(log_path): 91 | """Set the logger to log info in terminal and file `log_path`. 92 | In general, it is useful to have a logger so that every output to the terminal is saved 93 | in a permanent file. Here we save it to `model_dir/train.log`. 94 | Example: 95 | ``` 96 | logging.info("Starting training...") 97 | ``` 98 | Args: 99 | log_path: (string) where to log 100 | """ 101 | logger = logging.getLogger() 102 | logger.setLevel(logging.INFO) 103 | logger.handlers.clear() 104 | 105 | # Logging to a file 106 | file_handler = logging.FileHandler(log_path) 107 | file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s')) 108 | logger.addHandler(file_handler) 109 | 110 | # Logging to console 111 | stream_handler = logging.StreamHandler() 112 | stream_handler.setFormatter(logging.Formatter('%(message)s')) 113 | logger.addHandler(stream_handler) 114 | 115 | def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False): 116 | """Loads model parameters (state_dict) from file_path. 
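    Optimizer and LR-scheduler states are restored only when `optim` / `lr_sched`
    are passed, and the function returns the (epoch, train_metrics, val_metrics)
    entries stored in the checkpoint.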
117 | 118 | Args: 119 | checkpoint: (string) filename which needs to be loaded 120 | model: (torch.nn.Module) model for which the parameters are loaded 121 | data_parallel: (bool) if the model is a data parallel model 122 | """ 123 | if not os.path.exists(checkpoint): 124 | raise("File doesn't exist {}".format(checkpoint)) 125 | 126 | state_dict = torch.load(checkpoint) 127 | 128 | if data_parallel: 129 | state_dict['model_state_dict'] = { 130 | 'module.' + k: state_dict['model_state_dict'][k] 131 | for k in state_dict['model_state_dict'].keys()} 132 | model.load_state_dict(state_dict['model_state_dict']) 133 | 134 | if optim is not None: 135 | optim.load_state_dict(state_dict['optim_state_dict']) 136 | 137 | if lr_sched is not None: 138 | lr_sched.load_state_dict(state_dict['lr_sched_state_dict']) 139 | 140 | return state_dict['epoch'], state_dict['train_metrics'], \ 141 | state_dict['val_metrics'] 142 | 143 | def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None, 144 | train_metrics=None, val_metrics=None, data_parallel=False): 145 | """Saves model parameters (state_dict) to file_path. 146 | 147 | Args: 148 | checkpoint: (string) filename which needs to be loaded 149 | model: (torch.nn.Module) model for which the parameters are loaded 150 | data_parallel: (bool) if the model is a data parallel model 151 | """ 152 | if os.path.exists(checkpoint): 153 | raise("File already exists {}".format(checkpoint)) 154 | 155 | model_state_dict = model.state_dict() 156 | if data_parallel: 157 | model_state_dict = { 158 | k.partition('module.')[2]: 159 | model_state_dict[k] for k in model_state_dict.keys()} 160 | 161 | optim_state_dict = None if not optim else optim.state_dict() 162 | lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict() 163 | 164 | state_dict = { 165 | 'epoch': epoch, 166 | 'model_state_dict': model_state_dict, 167 | 'optim_state_dict': optim_state_dict, 168 | 'lr_sched_state_dict': lr_sched_state_dict, 169 | 'train_metrics': train_metrics, 170 | 'val_metrics': val_metrics 171 | } 172 | 173 | torch.save(state_dict, checkpoint) 174 | 175 | def model_size(model): 176 | """ 177 | Returns size of the `model` in millions of parameters. 178 | """ 179 | num_train_params = sum( 180 | p.numel() for p in model.parameters() if p.requires_grad) 181 | return num_train_params / 1e6 182 | 183 | def run_time(model, inputs, profiling=False): 184 | """ 185 | Returns runtime of a model in ms. 
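    The model is first warmed up with 100 forward passes, then a single call
    model(*inputs) is profiled on CPU with torch.profiler, so `inputs` must be a
    tuple/list of positional arguments; pass profiling=True to also print the
    per-operator table.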
186 | """ 187 | # Warmup 188 | for _ in range(100): 189 | output = model(*inputs) 190 | 191 | with profile(activities=[ProfilerActivity.CPU], 192 | record_shapes=True) as prof: 193 | with record_function("model_inference"): 194 | output = model(*inputs) 195 | 196 | # Print profiling results 197 | if profiling: 198 | print(prof.key_averages().table(sort_by="self_cpu_time_total", 199 | row_limit=20)) 200 | 201 | # Return runtime in ms 202 | return prof.profiler.self_cpu_time_total / 1000 203 | 204 | def format_lr_info(optimizer): 205 | lr_info = "" 206 | for i, pg in enumerate(optimizer.param_groups): 207 | lr_info += " {group %d: params=%.5fM lr=%.1E}" % ( 208 | i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr']) 209 | return lr_info 210 | 211 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vb000/SemanticHearing/07e9786c7a741f0a7c722dcde66a2679ca068c50/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/datasets/curated_binaural.py: -------------------------------------------------------------------------------- 1 | from src.training.datasets.semaudio_binaural_base import SemAudioBinauralBaseDataset 2 | 3 | class CuratedBinauralDataset(SemAudioBinauralBaseDataset): 4 | """ 5 | Torch dataset object for synthetically rendered spatial data. 6 | """ 7 | labels = [ 8 | "alarm_clock", "baby_cry", "birds_chirping", "cat", "car_horn", 9 | "cock_a_doodle_doo", "cricket", "computer_typing", 10 | "dog", "glass_breaking", "gunshot", "hammer", "music", 11 | "ocean", "door_knock", "singing", "siren", "speech", 12 | "thunderstorm", "toilet_flush"] 13 | 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.labels = [ 17 | "alarm_clock", "baby_cry", "birds_chirping", "cat", "car_horn", 18 | "cock_a_doodle_doo", "cricket", "computer_typing", 19 | "dog", "glass_breaking", "gunshot", "hammer", "music", 20 | "ocean", "door_knock", "singing", "siren", "speech", 21 | "thunderstorm", "toilet_flush"] 22 | -------------------------------------------------------------------------------- /src/training/datasets/curated_binaural_augrir.py: -------------------------------------------------------------------------------- 1 | from src.training.datasets.curated_binaural import CuratedBinauralDataset 2 | import os 3 | import sofa 4 | import json 5 | import scaper 6 | import random 7 | import torch 8 | import numpy as np 9 | from random import randrange 10 | 11 | from data.multi_ch_simulator import Multi_Ch_Simulator 12 | 13 | import hashlib 14 | 15 | 16 | class CuratedBinauralAugRIRDataset(CuratedBinauralDataset): 17 | """ 18 | Torch dataset object for synthetically rendered spatial data. 
19 | """ 20 | def __init__(self, *args, **kwargs): 21 | self.scaper_bg_dir = kwargs['bg_scaper_dir'] 22 | kwargs.pop('bg_scaper_dir', None) 23 | 24 | if 'reverb' in kwargs: 25 | self.reverb = kwargs['reverb'] 26 | kwargs.pop('reverb', None) 27 | else: 28 | self.reverb = True 29 | 30 | super().__init__(*args, **kwargs) 31 | 32 | # Simulate 33 | self.simulator = Multi_Ch_Simulator(self.hrtf_dir, self.dset, self.sr, self.reverb) 34 | 35 | def load_sample(self, sample_dir, hrtf_dir, fg_dir, bg_dir, num_targets, 36 | sample_targets=False, resampler=None): 37 | """ 38 | Loads a single sample: 39 | sample_dir: Path to sample directory (jams file + metadata JSON) 40 | hrtf_dir: Path to hrtf dataset (sofa files) 41 | fg_dir: Path to foreground dataset () 42 | bg_dir: Path to background dataset (TAU) 43 | num_targets: Number of gt targets to choose. 44 | sample_targets: Whether or not targets should be randomly chosen from the list 45 | channel: Channel index to select. If None, returns all channels 46 | """ 47 | 48 | sample_path = sample_dir 49 | 50 | # Load HRIR 51 | metadata_path = os.path.join(sample_path, 'metadata.json') 52 | with open(metadata_path, 'rb') as f: 53 | metadata = json.load(f) 54 | 55 | # Load background audio 56 | bg_jamsfile = os.path.join(sample_path, 'background.jams') 57 | _, _, _, bg_event_audio_list = scaper.generate_from_jams( 58 | bg_jamsfile, fg_path=self.scaper_bg_dir, bg_path=bg_dir) 59 | 60 | # Load foreground audio 61 | fg_jamsfile = os.path.join(sample_path, 'mixture.jams') 62 | mixture, _, _, fg_event_audio_list = scaper.generate_from_jams( 63 | fg_jamsfile, fg_path=fg_dir, bg_path='.') 64 | 65 | # Read number of background sources 66 | num_background = metadata['num_background'] 67 | 68 | source_labels = [] 69 | source_list = metadata['sources'] 70 | for i in range(len(source_list)): 71 | label = source_list[i]['label'] 72 | 73 | # Sanity check 74 | if i < num_background: 75 | assert label not in self.labels, "Background sources are not in the right order" 76 | else: 77 | assert label in self.labels, "Foreground sources are not in the right order" 78 | 79 | source_labels.append(label) 80 | 81 | # Concatenate event audio lists 82 | event_audio_list = np.array(bg_event_audio_list + fg_event_audio_list, dtype=np.float32)[..., 0] 83 | 84 | # Generate random simulator 85 | simulator = self.simulator.get_random_simulator() 86 | 87 | if self.dset == 'test': 88 | seed = int.from_bytes(hashlib.sha256(str(sample_dir).encode()).digest()[:4], 'little') # 32-bit int 89 | simulator.seed(seed) 90 | 91 | total_samples = mixture.shape[0] 92 | gt_audio = simulator.initialize_room_with_random_params(len(source_list), 0, source_labels, num_background)\ 93 | .simulate(event_audio_list)[..., :total_samples] 94 | metadata = simulator.get_metadata() 95 | 96 | # Load source information 97 | sources = [] 98 | source_list = metadata['sources'] 99 | for i in range(len(source_list)): 100 | order = source_list[i]['order'] 101 | pos = source_list[i]['position'] 102 | label = source_list[i]['label'] 103 | 104 | sources.append((order, i, pos, gt_audio[i], label)) 105 | 106 | # Sort sources by order 107 | sources = sorted(sources, key=lambda x: x[0]) 108 | 109 | gt_events = [x[4] for x in sources] 110 | 111 | # Remove background from gt_events 112 | gt_events = gt_events[:-num_background] 113 | 114 | if sample_targets: 115 | labels = random.sample(gt_events, randrange(1,num_targets+1)) 116 | else: 117 | labels = gt_events[:num_targets] 118 | 119 | label_vector = self.get_label_vector(labels) 120 
| 121 | # Generate mixture and gt audio 122 | mixture = np.sum(gt_audio, axis=0) 123 | gt = np.zeros_like(mixture) 124 | 125 | # Go over each source and convolve with HRTF 126 | metadata['chosen_sources'] = [] 127 | for source in sources: 128 | _, i, _, audio, label = source 129 | 130 | if label in labels: 131 | gt += audio 132 | metadata['chosen_sources'].append(i) 133 | 134 | mixture = torch.from_numpy(mixture) 135 | gt = torch.from_numpy(gt) 136 | 137 | maxval = float(torch.max(torch.abs(mixture))) 138 | # If maxval > 1, normalize so that input is between [-1, 1] 139 | if maxval > 1: 140 | mixture = mixture / maxval 141 | gt = gt / maxval 142 | 143 | maxval = 1 144 | 145 | # # Augment scale 146 | # if self.dset != 'test': 147 | # random_amplitude = np.random.uniform(0.2, 1) 148 | # random_scale = random_amplitude / maxval 149 | # mixture *= random_scale 150 | # gt *= random_scale 151 | # metadata['random_amplitude'] = random_amplitude 152 | 153 | if resampler is not None: 154 | mixture = resampler(mixture.to(torch.float)) 155 | gt = resampler(gt.to(torch.float)) 156 | 157 | return mixture, gt, label_vector, metadata 158 | -------------------------------------------------------------------------------- /src/training/datasets/semaudio_binaural_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch dataset object for synthetically rendered spatial data. 3 | """ 4 | 5 | import os 6 | import json 7 | import random 8 | from pathlib import Path 9 | import logging 10 | import warnings 11 | 12 | import numpy as np 13 | import pandas as pd 14 | import matplotlib.pyplot as plt 15 | import scaper 16 | import torch 17 | import torchaudio 18 | import torchaudio.transforms as AT 19 | import sofa 20 | from random import randrange 21 | 22 | # Ignore scaper normalization warnings 23 | warnings.filterwarnings( 24 | "ignore", message="Soundscape audio is clipping!") 25 | warnings.filterwarnings( 26 | "ignore", message="Peak normalization applied") 27 | 28 | class SemAudioBinauralBaseDataset(torch.utils.data.Dataset): # type: ignore 29 | """ 30 | Base class for FSD Sound Scapes dataset 31 | """ 32 | def __init__(self, fg_dir, bg_dir, jams_dir, hrtf_dir, dset, sr=None, 33 | resample_rate=None, max_num_targets=1): 34 | assert dset in ['train', 'val', 'test'], \ 35 | "`dset` must be one of ['train', 'val', 'test']" 36 | 37 | self.labels = None 38 | self.fg_dir = fg_dir 39 | self.bg_dir = bg_dir 40 | self.hrtf_dir = hrtf_dir 41 | self.jams_dir = jams_dir 42 | self.dset = dset 43 | logging.info("Loading dataset: jams=%s fg_dir=%s bg_dir=%s" % 44 | (self.jams_dir, self.fg_dir, self.bg_dir)) 45 | 46 | self.samples = sorted(list(Path(self.jams_dir).glob('[0-9]*'))) 47 | self.hrtf_dir = hrtf_dir 48 | 49 | self.max_num_targets = max_num_targets 50 | 51 | jamsfile = os.path.join(self.samples[0], 'mixture.jams') 52 | _, jams, _, _ = scaper.generate_from_jams( 53 | jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir) 54 | _sr = jams['annotations'][0]['sandbox']['scaper']['sr'] 55 | assert _sr == sr, "Sampling rate provided does not match the data" 56 | 57 | if resample_rate is not None: 58 | self.resampler = AT.Resample(sr, resample_rate) 59 | self.sr = resample_rate 60 | else: 61 | self.resampler = lambda a: a 62 | self.sr = sr 63 | 64 | def get_label_vector(self, labels): 65 | """ 66 | Generates a multi-hot vector corresponding to `labels`. 
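        For example, with the 20-class list defined in CuratedBinauralDataset,
        get_label_vector(['speech', 'dog']) returns a length-20 tensor with ones
        at the indices of 'speech' and 'dog' and zeros elsewhere; repeated
        labels trigger the assertion below.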
67 | """ 68 | vector = torch.zeros(len(self.labels)) 69 | 70 | for label in labels: 71 | idx = self.labels.index(label) 72 | assert vector[idx] == 0, "Repeated labels" 73 | vector[idx] = 1 74 | 75 | return vector 76 | 77 | def load_sample(self, sample_dir, hrtf_dir, fg_dir, bg_dir, num_targets, 78 | sample_targets=False, resampler=None): 79 | """ 80 | Loads a single sample: 81 | sample_dir: Path to sample directory (jams file + metadata JSON) 82 | hrtf_dir: Path to hrtf dataset (sofa files) 83 | fg_dir: Path to foreground dataset () 84 | bg_dir: Path to background dataset (TAU) 85 | num_targets: Number of gt targets to choose. 86 | sample_targets: Whether or not targets should be randomly chosen from the list 87 | channel: Channel index to select. If None, returns all channels 88 | """ 89 | 90 | sample_path = sample_dir 91 | 92 | # Load HRIR 93 | metadata_path = os.path.join(sample_path, 'metadata.json') 94 | with open(metadata_path, 'rb') as f: 95 | metadata = json.load(f) 96 | 97 | HRTF_path = os.path.join(hrtf_dir, os.path.basename(metadata['sofa'])) 98 | HRTF = sofa.Database.open(HRTF_path) 99 | 100 | # Load Audio 101 | jamsfile = os.path.join(sample_path, 'mixture.jams') 102 | mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams( 103 | jamsfile, fg_path=fg_dir, bg_path=bg_dir) 104 | 105 | # Load source information 106 | sources = [] 107 | source_list = metadata['sources'] 108 | for i in range(len(source_list)): 109 | order = source_list[i]['order'] 110 | pos = source_list[i]['position'] 111 | rir_id = source_list[i]['hrtf_index'] 112 | label = source_list[i]['label'] 113 | if i == 0: 114 | assert label == 'background', "Background not first source" 115 | 116 | rir = HRTF.Data.IR.get_values(indices={"M":rir_id}) 117 | sources.append((order, i, pos, rir, label)) 118 | 119 | # Sort sources by order 120 | sources = sorted(sources, key=lambda x: x[0]) 121 | 122 | gt_events = [x[4] for x in sources] 123 | gt_events = gt_events[:-1] # Remove background from gt_events 124 | 125 | if sample_targets: 126 | labels = random.sample(gt_events, randrange(1,num_targets+1)) 127 | else: 128 | labels = gt_events[:num_targets] 129 | 130 | label_vector = self.get_label_vector(labels) 131 | 132 | mixture = np.zeros((2, mixture.shape[0])) 133 | 134 | gt = np.zeros_like(mixture) 135 | 136 | # Go over each source and convolve with HRTF 137 | metadata['chosen_sources'] = [] 138 | for source in sources: 139 | _, i, _, rir, label = source 140 | 141 | # Get audio event as single-channel 142 | a = event_audio_list[i] 143 | a = a[..., 0] 144 | 145 | # Convolve single-channel with HRTF to get binaural 146 | tmp = np.zeros_like(mixture) 147 | tmp[0] = np.convolve(a, rir[0], mode='same') 148 | tmp[1] = np.convolve(a, rir[1], mode='same') 149 | 150 | if label in labels: 151 | gt += tmp 152 | metadata['chosen_sources'].append(i) 153 | 154 | mixture += tmp 155 | 156 | mixture = torch.from_numpy(mixture) 157 | gt = torch.from_numpy(gt) 158 | 159 | maxval = (torch.max(torch.abs(mixture)) + 1e-6) 160 | mixture = mixture / maxval 161 | gt = gt / maxval 162 | 163 | if resampler is not None: 164 | mixture = resampler(mixture.to(torch.float)) 165 | gt = resampler(gt.to(torch.float)) 166 | 167 | # Add microphone positions to metadata 168 | mic_positions = HRTF.Receiver.Position.get_values(system="cartesian")[..., 0] 169 | metadata['mic_positions'] = mic_positions.tolist() 170 | 171 | return mixture, gt, label_vector, metadata 172 | 173 | def __len__(self): 174 | return len(self.samples) 175 | 176 | def 
__getitem__(self, idx): 177 | sample_dir = self.samples[idx] 178 | 179 | sample_targets=False 180 | if self.dset == 'train': 181 | num_targets = self.max_num_targets 182 | sample_targets=True 183 | elif self.dset == 'val': 184 | num_targets = idx%self.max_num_targets + 1 185 | elif self.dset == 'test': 186 | num_targets = self.max_num_targets 187 | 188 | mixture, gt, label_vector, metadata = \ 189 | self.load_sample( 190 | sample_dir=sample_dir, hrtf_dir=self.hrtf_dir, fg_dir=self.fg_dir, 191 | bg_dir=self.bg_dir, num_targets=num_targets, sample_targets=sample_targets, 192 | resampler=self.resampler) 193 | 194 | mixture = self.resampler(mixture.to(torch.float)) 195 | gt = self.resampler(gt.to(torch.float)) 196 | 197 | # Ground-truth shifts using cross-correlation between gt channels 198 | _gt = gt.numpy() 199 | _gt = _gt / np.max(np.abs(_gt), axis=1, keepdims=True) 200 | 201 | shift = np.argmax(np.correlate(_gt[0][32:-32], _gt[1])) - 32 202 | shift = torch.tensor(shift) 203 | 204 | inputs = { 205 | 'mixture': mixture, 206 | 'label_vector': label_vector, 207 | 'shift': shift, 208 | 'metadata': metadata, 209 | } 210 | 211 | return inputs, gt 212 | 213 | def to(self, inputs, gt, device): 214 | inputs['mixture'] = inputs['mixture'].to(device) 215 | inputs['label_vector'] = inputs['label_vector'].to(device) 216 | inputs['shift'] = inputs['shift'].to(device) 217 | gt = gt.to(device) 218 | return inputs, gt 219 | 220 | def output_to(self, output, device): 221 | for k, v in output.items(): 222 | output[k] = v.to(device) 223 | return output 224 | 225 | def output_detach(self, output): 226 | for k, v in output.items(): 227 | output[k] = v.detach() 228 | return output 229 | 230 | def collate_fn(self, batch): 231 | inputs, gt = zip(*batch) 232 | inputs = { 233 | 'mixture': torch.stack([i['mixture'] for i in inputs]), 234 | 'label_vector': torch.stack([i['label_vector'] for i in inputs]), 235 | 'shift': torch.stack([i['shift'] for i in inputs]), 236 | 'metadata': [i['metadata'] for i in inputs] 237 | } 238 | gt = torch.stack(gt) 239 | return inputs, gt 240 | 241 | def tensorboard_add_sample(self, writer, tag, sample, step): 242 | """ 243 | Adds a sample of FSDSynthDataset to tensorboard. 
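        `sample` is an (inputs, output, gt) tuple; audio is resampled to at most
        8 kHz and peak-normalized before the waveforms and audio clips are
        written to the writer.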
244 | """ 245 | resample_rate = 8000 if self.sr > 8000 else self.sr 246 | 247 | inputs, output, gt = sample 248 | m = inputs['mixture'] 249 | o = output['x'] 250 | labels = [] 251 | for lv in inputs['label_vector']: 252 | label = '' 253 | for i, l in enumerate(lv): 254 | if l == 1: 255 | label += self.labels[i] + ';' 256 | labels.append(label) 257 | 258 | # Resample to 8kHz and normalize 259 | m, gt, o = ( 260 | torchaudio.functional.resample(_, self.sr, resample_rate).cpu() 261 | for _ in (m, gt, o)) 262 | m, gt, o = (_ / _.abs().max() for _ in (m, gt, o)) 263 | 264 | def _add_audio(a, audio_tag, axis, plt_title): 265 | for i, ch in enumerate(a[:1]): 266 | axis.plot(ch, label='mic %d' % i) 267 | writer.add_audio( 268 | '%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate) 269 | axis.set_title(plt_title) 270 | axis.legend() 271 | 272 | n_samples = min(8, m.shape[0]) 273 | for b in range(n_samples): 274 | # Add waveforms 275 | rows = 3 # input, output, gt 276 | fig = plt.figure(figsize=(10, 2 * rows)) 277 | axes = fig.subplots(rows, 1, sharex=True) 278 | l = labels[b] 279 | _add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed") 280 | _add_audio(o[b], '%s/sample_%d/1_out_%s' % (tag, b, l), axes[1], f"Out ({l})") 281 | _add_audio(gt[b], '%s/sample_%d/2_gt_%s' % (tag, b, l), axes[2], f"GT ({l})") 282 | writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step) 283 | 284 | def tensorboard_add_metrics(self, writer, tag, metrics, step): 285 | """ 286 | Add metrics to tensorboard. 287 | """ 288 | vals = np.asarray(metrics['scale_invariant_signal_noise_ratio']) 289 | 290 | writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step) 291 | 292 | return 293 | -------------------------------------------------------------------------------- /src/training/dcc_tf.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import OrderedDict 3 | from typing import Optional 4 | import logging 5 | 6 | from torch import Tensor 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | from torchmetrics.functional import( 12 | scale_invariant_signal_noise_ratio as si_snr, 13 | signal_noise_ratio as snr, 14 | signal_distortion_ratio as sdr, 15 | scale_invariant_signal_distortion_ratio as si_sdr) 16 | 17 | from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding 18 | 19 | def mod_pad(x, chunk_size, pad): 20 | # Mod pad the input to perform integer number of 21 | # inferences 22 | mod = 0 23 | if (x.shape[-1] % chunk_size) != 0: 24 | mod = chunk_size - (x.shape[-1] % chunk_size) 25 | 26 | x = F.pad(x, (0, mod)) 27 | x = F.pad(x, pad) 28 | 29 | return x, mod 30 | 31 | class LayerNormPermuted(nn.LayerNorm): 32 | def __init__(self, *args, **kwargs): 33 | super(LayerNormPermuted, self).__init__(*args, **kwargs) 34 | 35 | def forward(self, x): 36 | """ 37 | Args: 38 | x: [B, C, T] 39 | """ 40 | x = x.permute(0, 2, 1) # [B, T, C] 41 | x = super().forward(x) 42 | x = x.permute(0, 2, 1) # [B, C, T] 43 | return x 44 | 45 | class DepthwiseSeparableConv(nn.Module): 46 | """ 47 | Depthwise separable convolutions 48 | """ 49 | def __init__(self, in_channels, out_channels, kernel_size, stride, 50 | padding, dilation): 51 | super(DepthwiseSeparableConv, self).__init__() 52 | 53 | self.layers = nn.Sequential( 54 | nn.Conv1d(in_channels, in_channels, kernel_size, stride, 55 | padding, groups=in_channels, dilation=dilation), 56 | 
LayerNormPermuted(in_channels), 57 | nn.ReLU(), 58 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, 59 | padding=0), 60 | LayerNormPermuted(out_channels), 61 | nn.ReLU(), 62 | ) 63 | 64 | def forward(self, x): 65 | return self.layers(x) 66 | 67 | class DilatedCausalConvEncoder(nn.Module): 68 | """ 69 | A dilated causal convolution based encoder for encoding 70 | time domain audio input into latent space. 71 | """ 72 | def __init__(self, channels, num_layers, kernel_size=3): 73 | super(DilatedCausalConvEncoder, self).__init__() 74 | self.channels = channels 75 | self.num_layers = num_layers 76 | self.kernel_size = kernel_size 77 | 78 | # Compute buffer lengths for each layer 79 | # buf_length[i] = (kernel_size - 1) * dilation[i] 80 | self.buf_lengths = [(kernel_size - 1) * 2**i 81 | for i in range(num_layers)] 82 | 83 | # Compute buffer start indices for each layer 84 | self.buf_indices = [0] 85 | for i in range(num_layers - 1): 86 | self.buf_indices.append( 87 | self.buf_indices[-1] + self.buf_lengths[i]) 88 | 89 | # Dilated causal conv layers aggregate previous context to obtain 90 | # contexful encoded input. 91 | _dcc_layers = OrderedDict() 92 | for i in range(num_layers): 93 | dcc_layer = DepthwiseSeparableConv( 94 | channels, channels, kernel_size=3, stride=1, 95 | padding=0, dilation=2**i) 96 | _dcc_layers.update({'dcc_%d' % i: dcc_layer}) 97 | self.dcc_layers = nn.Sequential(_dcc_layers) 98 | 99 | def init_ctx_buf(self, batch_size, device): 100 | """ 101 | Returns an initialized context buffer for a given batch size. 102 | """ 103 | return torch.zeros( 104 | (batch_size, self.channels, 105 | (self.kernel_size - 1) * (2**self.num_layers - 1)), 106 | device=device) 107 | 108 | def forward(self, x, ctx_buf): 109 | """ 110 | Encodes input audio `x` into latent space, and aggregates 111 | contextual information in `ctx_buf`. Also generates new context 112 | buffer with updated context. 113 | Args: 114 | x: [B, in_channels, T] 115 | Input multi-channel audio. 116 | ctx_buf: {[B, channels, self.buf_length[0]], ...} 117 | A list of tensors holding context for each dilation 118 | causal conv layer. (len(ctx_buf) == self.num_layers) 119 | Returns: 120 | ctx_buf: {[B, channels, self.buf_length[0]], ...} 121 | Updated context buffer with output as the 122 | last element. 
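        Note: the total cached context across layers is
        (kernel_size - 1) * (2**num_layers - 1) latent frames; with
        kernel_size=3 and num_layers=10 (the dc_waveformer config) that is
        2 * 1023 = 2046 frames.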
123 | """ 124 | T = x.shape[-1] # Sequence length 125 | 126 | for i in range(self.num_layers): 127 | buf_start_idx = self.buf_indices[i] 128 | buf_end_idx = self.buf_indices[i] + self.buf_lengths[i] 129 | 130 | # DCC input: concatenation of current output and context 131 | dcc_in = torch.cat( 132 | (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1) 133 | 134 | # Push current output to the context buffer 135 | ctx_buf[..., buf_start_idx:buf_end_idx] = \ 136 | dcc_in[..., -self.buf_lengths[i]:] 137 | 138 | # Residual connection 139 | x = x + self.dcc_layers[i](dcc_in) 140 | 141 | return x, ctx_buf 142 | 143 | class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer): 144 | """ 145 | Adapted from: 146 | "https://github.com/alexmt-scale/causal-transformer-decoder/blob/" 147 | "0caf6ad71c46488f76d89845b0123d2550ef792f/" 148 | "causal_transformer_decoder/model.py#L77" 149 | """ 150 | def forward( 151 | self, 152 | tgt: Tensor, 153 | memory: Optional[Tensor] = None, 154 | chunk_size: int = 1 155 | ) -> Tensor: 156 | tgt_last_tok = tgt[:, -chunk_size:, :] 157 | 158 | # self attention part 159 | tmp_tgt, sa_map = self.self_attn( 160 | tgt_last_tok, 161 | tgt, 162 | tgt, 163 | attn_mask=None, # not needed because we only care about the last token 164 | key_padding_mask=None, 165 | ) 166 | tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt) 167 | tgt_last_tok = self.norm1(tgt_last_tok) 168 | 169 | # encoder-decoder attention 170 | ca_map = None 171 | if memory is not None: 172 | tmp_tgt, ca_map = self.multihead_attn( 173 | tgt_last_tok, 174 | memory, 175 | memory, 176 | attn_mask=None, # Attend to the entire chunk 177 | key_padding_mask=None, 178 | ) 179 | tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt) 180 | tgt_last_tok = self.norm2(tgt_last_tok) 181 | 182 | # final feed-forward network 183 | tmp_tgt = self.linear2( 184 | self.dropout(self.activation(self.linear1(tgt_last_tok))) 185 | ) 186 | tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt) 187 | tgt_last_tok = self.norm3(tgt_last_tok) 188 | return tgt_last_tok, sa_map, ca_map 189 | 190 | class CausalTransformerDecoder(nn.Module): 191 | """ 192 | A casual transformer decoder which decodes input vectors using 193 | precisely `ctx_len` past vectors in the sequence, and using no future 194 | vectors at all. 
195 | """ 196 | def __init__(self, model_dim, ctx_len, chunk_size, num_layers, 197 | nhead, use_pos_enc, ff_dim, conditioning='conv'): 198 | super(CausalTransformerDecoder, self).__init__() 199 | self.num_layers = num_layers 200 | self.model_dim = model_dim 201 | self.ctx_len = ctx_len 202 | self.chunk_size = chunk_size 203 | self.nhead = nhead 204 | self.use_pos_enc = use_pos_enc 205 | self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size) 206 | self.pos_enc_tgt = PositionalEncoding(model_dim, max_len=1000) 207 | self.pos_enc_mem = PositionalEncoding(model_dim, max_len=100) 208 | self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer( 209 | d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim, 210 | batch_first=True) for _ in range(num_layers)]) 211 | self.conditioning = conditioning 212 | 213 | if conditioning == 'film': 214 | self.film = nn.Sequential( 215 | nn.Linear(model_dim, 2 * model_dim), 216 | nn.ReLU()) 217 | 218 | def init_ctx_buf(self, batch_size, device): 219 | return torch.zeros( 220 | (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim), 221 | device=device) 222 | 223 | def _causal_unfold(self, x): 224 | """ 225 | Unfolds the sequence into a batch of sequences 226 | prepended with `ctx_len` previous values. 227 | 228 | Args: 229 | x: [B, ctx_len + L, C] 230 | ctx_len: int 231 | Returns: 232 | [B * L, ctx_len + 1, C] 233 | """ 234 | B, T, C = x.shape 235 | x = x.permute(0, 2, 1) # [B, C, ctx_len + L] 236 | x = self.unfold(x.unsqueeze(-1)) # [B, C * (ctx_len + chunk_size), -1] 237 | x = x.permute(0, 2, 1) 238 | x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size) 239 | x = x.reshape(-1, C, self.ctx_len + self.chunk_size) 240 | x = x.permute(0, 2, 1) 241 | return x 242 | 243 | def forward(self, input, embedding, ctx_buf, K=4000): 244 | """ 245 | Args: 246 | input: [B, model_dim, T] 247 | embedding: [B, NE, model_dim, embed_len] 248 | ctx_buf: [B, num_layers, ctx_len, model_dim] 249 | K: int 250 | Number of batches to process at once to avoid OOM. 251 | Returns: 252 | output: [B, model_dim, T] 253 | ctx_buf: [B, num_layers, ctx_len, model_dim] 254 | """ 255 | 256 | # Mod pad the input so the sequence length is a multiple 257 | # of chunk_size. 
258 | input, mod = mod_pad(input, self.chunk_size, (0, 0)) 259 | 260 | # Init 261 | B, C, T = input.shape 262 | output = input.permute(0, 2, 1).contiguous() 263 | mem = None 264 | 265 | if self.conditioning == 'conv': 266 | # Convolutional/mutltiplicative conditioning 267 | input = input.view(1, B * C, T) 268 | input = F.pad( 269 | input, (embedding.shape[-1] - 1, 0)) # [1, B * C, T + embed_len - 1] 270 | emb_filter = torch.mean(embedding, dim=1).reshape(B * C, 1, -1) 271 | output = F.conv1d(input, emb_filter, groups=B * C) 272 | output = output.view(B, C, T) 273 | output = output.permute(0, 2, 1) 274 | elif self.conditioning == 'attn': 275 | # Use cross attn for conditioning 276 | mem = embedding.permute(0, 1, 3, 2) # [B, NE, embed_len, C] 277 | if self.use_pos_enc: 278 | mem = mem.view(-1, mem.shape[-2], mem.shape[-1]) 279 | mem = mem + self.pos_enc_mem(mem) 280 | mem = mem.view(B, -1, mem.shape[-2], mem.shape[-1]) 281 | mem = mem.reshape(B, -1, mem.shape[-1]) # [B, NE * embed_len, C] 282 | mem = mem.unsqueeze(1).repeat( 283 | 1, (T // self.chunk_size), 1, 1 284 | ) # [B, T // chunk_size, NE * embed_len, C] 285 | mem = mem.reshape( 286 | -1, mem.shape[-2], mem.shape[-1] 287 | ) # [B * (T // chunk_size), NE * embed_len, C] 288 | elif self.conditioning == 'film': 289 | # Use FILM for conditioning 290 | emb_filter = torch.mean(embedding, dim=(1, 3)) # [B, C] 291 | emb_filter = self.film(emb_filter) # [B, 2 * C] 292 | gamma, beta = emb_filter.chunk(2, dim=-1) 293 | output = output * gamma.unsqueeze(1) + beta.unsqueeze(1) 294 | else: 295 | emb_filter = torch.mean(embedding, dim=(1, 3)) # [B, C] 296 | output = output * emb_filter.unsqueeze(1) # [B, T, C] 297 | 298 | for i, layer in enumerate(self.tf_dec_layers): 299 | # Prepend the context to the input and update the context 300 | # [B, ctx_len + T, C] 301 | tgt = torch.cat([ctx_buf[:, i, :, :], output], dim=1) 302 | ctx_buf[:, i, :, :] = tgt[:, -self.ctx_len:, :] 303 | 304 | # Unfold the sequence into a batch of sequences prepended 305 | # with `ctx_len` previous values. 306 | # [B * (T // chunk_size), ctx_len + chunk_size, C] 307 | tgt = self._causal_unfold(tgt) 308 | 309 | # Positional encoding 310 | if i == 0 and self.use_pos_enc: 311 | tgt = tgt + self.pos_enc_tgt(tgt) 312 | 313 | _tgt = torch.zeros_like(tgt)[:, :self.chunk_size, :] 314 | for k in range(int(math.ceil(tgt.shape[0] / K))): 315 | s, e = k * K, (k + 1) * K 316 | _mem = None if mem is None else mem[s:e] 317 | _tgt[s:e], _, _ = layer(tgt[s:e], _mem, self.chunk_size) 318 | 319 | output = _tgt.reshape(B, T, C) 320 | 321 | # Remove the mod padding 322 | output = output.permute(0, 2, 1) 323 | if mod != 0: 324 | output = output[:, :, :-mod] 325 | 326 | return output, ctx_buf 327 | 328 | class MaskNet(nn.Module): 329 | def __init__(self, model_dim, num_enc_layers, dec_buf_len, 330 | dec_chunk_size, num_dec_layers, use_pos_enc, conditioning): 331 | super(MaskNet, self).__init__() 332 | 333 | # Encoder based on dilated causal convolutions. 334 | self.encoder = DilatedCausalConvEncoder(channels=model_dim, 335 | num_layers=num_enc_layers) 336 | 337 | # Transformer decoder that operates on chunks of size 338 | # buffer size. 
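        # With the dc_waveformer config (dec_buf_len=13, dec_chunk_size=13,
        # num_dec_layers=1), each decoded chunk of 13 frames attends over a
        # 26-frame window (13 cached + 13 current).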
339 | self.decoder = CausalTransformerDecoder( 340 | model_dim=model_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size, 341 | num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc, 342 | ff_dim=2 * model_dim, conditioning=conditioning) 343 | 344 | def forward(self, x, l, enc_buf, dec_buf): 345 | """ 346 | Generates a mask based on encoded input `e` and the one-hot 347 | label `label`. 348 | 349 | Args: 350 | x: [B, C, T] 351 | Input audio sequence 352 | l: [B, C] 353 | Label embedding 354 | ctx_buf: {[B, C, ], ...} 355 | List of context buffers maintained by DCC encoder 356 | """ 357 | # Enocder the label integrated input 358 | e, enc_buf = self.encoder(x, enc_buf) 359 | 360 | # Decoder conditioned on embedding 361 | m, dec_buf = self.decoder(input=e, embedding=l, ctx_buf=dec_buf) 362 | 363 | return m, enc_buf, dec_buf 364 | 365 | class Net(nn.Module): 366 | def __init__(self, label_len, L=8, 367 | model_dim=512, num_enc_layers=10, 368 | dec_buf_len=100, num_dec_layers=2, 369 | dec_chunk_size=72, out_buf_len=2, 370 | use_pos_enc=True, conditioning="mult", lookahead=True): 371 | super(Net, self).__init__() 372 | self.L = L 373 | self.out_buf_len = out_buf_len 374 | self.model_dim = model_dim 375 | self.lookahead = lookahead 376 | 377 | # Input conv to convert input audio to a latent representation 378 | kernel_size = 3 * L if lookahead else L 379 | self.in_conv = nn.Sequential( 380 | nn.Conv1d(in_channels=1, 381 | out_channels=model_dim, kernel_size=kernel_size, stride=L, 382 | padding=0, bias=False), 383 | nn.ReLU()) 384 | 385 | # Label embedding layer 386 | self.label_embedding = nn.Sequential( 387 | nn.Linear(label_len, 512), 388 | nn.LayerNorm(512), 389 | nn.ReLU(), 390 | nn.Linear(512, model_dim), 391 | nn.LayerNorm(model_dim), 392 | nn.ReLU()) 393 | 394 | # Mask generator 395 | self.mask_gen = MaskNet( 396 | model_dim=model_dim, num_enc_layers=num_enc_layers, 397 | dec_buf_len=dec_buf_len, 398 | dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers, 399 | use_pos_enc=use_pos_enc, conditioning=conditioning) 400 | 401 | # Output conv layer 402 | self.out_conv = nn.Sequential( 403 | nn.ConvTranspose1d( 404 | in_channels=model_dim, out_channels=1, 405 | kernel_size=(out_buf_len + 1) * L, 406 | stride=L, 407 | padding=out_buf_len * L, bias=False), 408 | nn.Tanh()) 409 | 410 | def init_buffers(self, batch_size, device): 411 | enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device) 412 | dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device) 413 | out_buf = torch.zeros(batch_size, self.model_dim, self.out_buf_len, 414 | device=device) 415 | return enc_buf, dec_buf, out_buf 416 | 417 | def predict(self, x, label, enc_buf, dec_buf, out_buf, pad=True): 418 | mod = 0 419 | if pad: 420 | pad_size = (self.L, self.L) if self.lookahead else (0, 0) 421 | x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size) 422 | 423 | # Generate latent space representation of the input 424 | x = self.in_conv(x) 425 | 426 | # Generate label embedding 427 | l = self.label_embedding(label) # [B, label_len] --> [B, channels] 428 | l = l.unsqueeze(1).unsqueeze(-1) # [B, 1, channels, 1] 429 | 430 | # Generate mask corresponding to the label 431 | m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf) 432 | 433 | # Apply mask and decode 434 | x = x * m 435 | x = torch.cat((out_buf, x), dim=-1) 436 | out_buf = x[..., -self.out_buf_len:] 437 | x = self.out_conv(x) 438 | 439 | # Remove mod padding, if present. 
440 | if mod != 0: 441 | x = x[:, :, :-mod] 442 | 443 | return x, enc_buf, dec_buf, out_buf 444 | 445 | def forward(self, inputs, init_enc_buf=None, init_dec_buf=None, 446 | init_out_buf=None, pad=True): 447 | """ 448 | Extracts the audio corresponding to the `label` in the given 449 | `mixture`. Generates `chunk_size` samples per iteration. 450 | Args: 451 | mixed: [B, n_mics, T] 452 | input audio mixture 453 | label: [B, num_labels] 454 | one hot label 455 | Returns: 456 | out: [B, n_spk, T] 457 | extracted audio with sounds corresponding to the `label` 458 | """ 459 | x, label = inputs['mixture'], inputs['label_vector'] 460 | 461 | if init_enc_buf is None or init_dec_buf is None or init_out_buf is None: 462 | assert init_enc_buf is None and \ 463 | init_dec_buf is None and \ 464 | init_out_buf is None, \ 465 | "Both buffers have to initialized, or " \ 466 | "both of them have to be None." 467 | enc_buf, dec_buf, out_buf = self.init_buffers( 468 | x.shape[0], x.device) 469 | else: 470 | enc_buf, dec_buf, out_buf = \ 471 | init_enc_buf, init_dec_buf, init_out_buf 472 | 473 | x, enc_buf, dec_buf, out_buf = self.predict( 474 | x, label, enc_buf, dec_buf, out_buf) 475 | 476 | if init_enc_buf is None: 477 | return x 478 | else: 479 | return x, enc_buf, dec_buf, out_buf 480 | 481 | # Define optimizer, loss and metrics 482 | 483 | def optimizer(model, data_parallel=False, **kwargs): 484 | # Trainable parameters 485 | params = [p for p in model.parameters() if p.requires_grad] 486 | return optim.Adam(params, **kwargs) 487 | 488 | def loss(pred, tgt): 489 | return -si_snr(pred, tgt).mean() 490 | 491 | def metrics(mixed, output, gt): 492 | """ Function to compute metrics """ 493 | metrics = {} 494 | 495 | def metric_i(metric, src, pred, tgt): 496 | _vals = [] 497 | for s, t, p in zip(src, tgt, pred): 498 | _vals.append((metric(p, t) - metric(s, t)).cpu().item()) 499 | return _vals 500 | 501 | for m_fn in [snr, si_snr]: 502 | metrics[m_fn.__name__] = metric_i(m_fn, 503 | mixed[:, :gt.shape[1], :], 504 | output, 505 | gt) 506 | 507 | return metrics 508 | 509 | if __name__ == '__main__': 510 | net = CausalTransformerDecoder( 511 | model_dim=8, ctx_len=4, chunk_size=4, num_layers=2, nhead=4, conditioning='attn', 512 | use_pos_enc=True, ff_dim=16 513 | ) 514 | x = torch.randn(2, 8, 16) 515 | e = torch.randn(2, 2, 8, 2) 516 | buf = torch.rand(2, 2, 4, 8) 517 | out = net(x, e, buf) 518 | print(out[0].shape, out[1].shape) 519 | -------------------------------------------------------------------------------- /src/training/dcc_tf_binaural.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from collections import OrderedDict 4 | from typing import Optional 5 | import logging 6 | from copy import deepcopy 7 | 8 | import numpy as np 9 | from torch import Tensor 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | import torch.optim as optim 14 | import torchaudio 15 | from torchmetrics.functional import( 16 | scale_invariant_signal_noise_ratio as si_snr, 17 | signal_noise_ratio as snr, 18 | signal_distortion_ratio as sdr, 19 | scale_invariant_signal_distortion_ratio as si_sdr) 20 | 21 | from src.training.dcc_tf import mod_pad, MaskNet 22 | from src.helpers.eval_utils import itd_diff, ild_diff 23 | 24 | class Net(nn.Module): 25 | def __init__(self, label_len, L=8, 26 | model_dim=512, num_enc_layers=10, 27 | dec_buf_len=100, num_dec_layers=2, 28 | dec_chunk_size=72, out_buf_len=2, 29 | use_pos_enc=True, 
conditioning="mult", lookahead=True, 30 | pretrained_path=None): 31 | super(Net, self).__init__() 32 | self.L = L 33 | self.out_buf_len = out_buf_len 34 | self.model_dim = model_dim 35 | self.lookahead = lookahead 36 | 37 | # Input conv to convert input audio to a latent representation 38 | kernel_size = 3 * L if lookahead else L 39 | self.in_conv = nn.Sequential( 40 | nn.Conv1d(in_channels=2, 41 | out_channels=model_dim, kernel_size=kernel_size, stride=L, 42 | padding=0, bias=False), 43 | nn.ReLU()) 44 | 45 | # Label embedding layer 46 | self.label_embedding = nn.Sequential( 47 | nn.Linear(label_len, 512), 48 | nn.LayerNorm(512), 49 | nn.ReLU(), 50 | nn.Linear(512, model_dim), 51 | nn.LayerNorm(model_dim), 52 | nn.ReLU()) 53 | 54 | # Mask generator 55 | self.mask_gen = MaskNet( 56 | model_dim=model_dim, num_enc_layers=num_enc_layers, 57 | dec_buf_len=dec_buf_len, 58 | dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers, 59 | use_pos_enc=use_pos_enc, conditioning=conditioning) 60 | 61 | # Output conv layer 62 | self.out_conv = nn.Sequential( 63 | nn.ConvTranspose1d( 64 | in_channels=model_dim, out_channels=2, 65 | kernel_size=(out_buf_len + 1) * L, 66 | stride=L, 67 | padding=out_buf_len * L, bias=False), 68 | nn.Tanh()) 69 | 70 | if pretrained_path is not None: 71 | state_dict = torch.load(pretrained_path)['model_state_dict'] 72 | 73 | # Load all the layers except label_embedding and freeze them 74 | for name, param in self.named_parameters(): 75 | if 'label_embedding' not in name: 76 | param.data = state_dict[name] 77 | param.requires_grad = False 78 | 79 | def init_buffers(self, batch_size, device): 80 | enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device) 81 | dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device) 82 | out_buf = torch.zeros(batch_size, self.model_dim, self.out_buf_len, 83 | device=device) 84 | return enc_buf, dec_buf, out_buf 85 | 86 | def predict(self, x, label, enc_buf, dec_buf, out_buf): 87 | # Generate latent space representation of the input 88 | x = self.in_conv(x) 89 | 90 | # Generate label embedding 91 | l = self.label_embedding(label) # [B, label_len] --> [B, channels] 92 | l = l.unsqueeze(1).unsqueeze(-1) # [B, 1, channels, 1] 93 | 94 | # Generate mask corresponding to the label 95 | m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf) 96 | 97 | # Apply mask and decode 98 | x = x * m 99 | x = torch.cat((out_buf, x), dim=-1) 100 | out_buf = x[..., -self.out_buf_len:] 101 | x = self.out_conv(x) 102 | 103 | return x, enc_buf, dec_buf, out_buf 104 | 105 | def forward(self, inputs, init_enc_buf=None, init_dec_buf=None, 106 | init_out_buf=None, pad=True, writer=None, step=None, idx=None): 107 | """ 108 | Extracts the audio corresponding to the `label` in the given 109 | `mixture`. Generates `chunk_size` samples per iteration. 110 | Args: 111 | mixed: [B, n_mics, T] 112 | input audio mixture 113 | label: [B, num_labels] 114 | one hot label 115 | Returns: 116 | out: [B, n_spk, T] 117 | extracted audio with sounds corresponding to the `label` 118 | """ 119 | x, label = inputs['mixture'], inputs['label_vector'] 120 | 121 | if init_enc_buf is None or init_dec_buf is None or init_out_buf is None: 122 | assert init_enc_buf is None and \ 123 | init_dec_buf is None and \ 124 | init_out_buf is None, \ 125 | "Both buffers have to initialized, or " \ 126 | "both of them have to be None." 
127 |             enc_buf, dec_buf, out_buf = self.init_buffers(
128 |                 x.shape[0], x.device)
129 |         else:
130 |             enc_buf, dec_buf, out_buf = \
131 |                 init_enc_buf, init_dec_buf, init_out_buf
132 | 
133 |         mod = 0
134 |         if pad:
135 |             pad_size = (self.L, self.L) if self.lookahead else (0, 0)
136 |             x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)
137 | 
138 |         x, enc_buf, dec_buf, out_buf = self.predict(
139 |             x, label, enc_buf, dec_buf, out_buf)
140 | 
141 |         # Remove mod padding, if present.
142 |         if mod != 0:
143 |             x = x[:, :, :-mod]
144 | 
145 |         out = {'x': x}
146 | 
147 |         if init_enc_buf is None:
148 |             return out
149 |         else:
150 |             return out, enc_buf, dec_buf, out_buf
151 | 
152 | # Define optimizer, loss and metrics
153 | 
154 | def optimizer(model, data_parallel=False, **kwargs):
155 |     params = [p for p in model.parameters() if p.requires_grad]
156 |     return optim.Adam(params, **kwargs)
157 | 
158 | def loss(_output, tgt):
159 |     pred = _output['x']
160 |     return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()
161 | 
162 | def metrics(inputs, _output, gt):
163 |     """ Function to compute metrics """
164 |     mixed = inputs['mixture']
165 |     output = _output['x']
166 |     metrics = {}
167 | 
168 |     def metric_i(metric, src, pred, tgt):
169 |         _vals = []
170 |         for s, t, p in zip(src, tgt, pred):
171 |             _vals.append(torch.mean((metric(p, t) - metric(s, t))).cpu().item())
172 |         return _vals
173 | 
174 |     for m_fn in [snr, si_snr]:
175 |         metrics[m_fn.__name__] = metric_i(m_fn,
176 |                                           mixed[:, :gt.shape[1], :],
177 |                                           output,
178 |                                           gt)
179 | 
180 |     return metrics
181 | 
182 | def test_metrics(inputs, _output, gt):
183 |     test_metrics = metrics(inputs, _output, gt)
184 |     output = _output['x']
185 |     delta_itds, delta_ilds, snrs = [], [], []
186 |     for o, g in zip(output, gt):
187 |         delta_itds.append(itd_diff(o.cpu(), g.cpu(), sr=44100))
188 |         delta_ilds.append(ild_diff(o.cpu().numpy(), g.cpu().numpy()))
189 |         snrs.append(torch.mean(si_snr(o, g)).cpu().item())
190 |     test_metrics['delta_ITD'] = delta_itds
191 |     test_metrics['delta_ILD'] = delta_ilds
192 |     test_metrics['si_snr'] = snrs
193 |     return test_metrics
194 | 
195 | def format_results(idx, inputs, output, gt, metrics, output_dir=None):
196 |     results = metrics
197 |     results['metadata'] = inputs['metadata']
198 |     results = deepcopy(results)
199 | 
200 |     # Save audio
201 |     if output_dir is not None:
202 |         output = output['x']
203 |         for i in range(output.shape[0]):
204 |             out_dir = os.path.join(output_dir, f'{idx + i:03d}')
205 |             os.makedirs(out_dir)
206 |             torchaudio.save(
207 |                 os.path.join(out_dir, 'mixture.wav'), inputs['mixture'][i], 44100)
208 |             torchaudio.save(
209 |                 os.path.join(out_dir, 'gt.wav'), gt[i], 44100)
210 |             torchaudio.save(
211 |                 os.path.join(out_dir, 'output.wav'), output[i], 44100)
212 | 
213 |     return results
214 | 
215 | if __name__ == "__main__":
216 |     torch.random.manual_seed(0)
217 | 
218 |     model = Net(41)
219 |     model.eval()
220 | 
221 |     with torch.no_grad():
222 |         x = torch.randn(1, 2, 417)
223 |         emb = torch.randn(1, 41)
224 | 
225 |         y = model({'mixture': x, 'label_vector': emb})
226 | 
227 |         # Forward returns a dict when no context buffers are passed in.
228 |         print(f"{y['x'].shape=}")
229 |         print(f"First channel data:\n{y['x'][0, 0]}")
--------------------------------------------------------------------------------
/src/training/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Test script to evaluate the model.
3 | """ 4 | 5 | import argparse 6 | import importlib 7 | import multiprocessing 8 | import os, glob 9 | import logging 10 | import json 11 | 12 | import numpy as np 13 | import torch 14 | import pandas as pd 15 | import torch.nn as nn 16 | from torch.utils.tensorboard import SummaryWriter 17 | from torch.profiler import profile, record_function, ProfilerActivity 18 | from tqdm import tqdm # pylint: disable=unused-import 19 | from torchmetrics.functional import( 20 | scale_invariant_signal_noise_ratio as si_snr, 21 | signal_noise_ratio as snr, 22 | signal_distortion_ratio as sdr, 23 | scale_invariant_signal_distortion_ratio as si_sdr) 24 | 25 | from src.helpers import utils 26 | 27 | def test_epoch(model: nn.Module, device: torch.device, 28 | test_loader: torch.utils.data.dataloader.DataLoader, 29 | n_items: int, loss_fn, metrics_fn, 30 | results_fn = None, results_path: str = None, output_dir: str = None, 31 | profiling: bool = False, epoch: int = 0, 32 | writer: SummaryWriter = None) -> float: 33 | """ 34 | Evaluate the network. 35 | """ 36 | model.eval() 37 | metrics = {} 38 | losses = [] 39 | runtimes = [] 40 | results = [] 41 | 42 | with torch.no_grad(): 43 | for batch_idx, (inp, tgt) in \ 44 | enumerate(tqdm(test_loader, desc='Test', ncols=100)): 45 | # Move data to device 46 | inp, tgt = test_loader.dataset.to(inp, tgt, device) 47 | 48 | # Run through the model 49 | if profiling: 50 | with profile(activities=[ProfilerActivity.CPU], 51 | record_shapes=True) as prof: 52 | with record_function("model_inference"): 53 | output = model(inp, writer=writer, step=epoch, idx=batch_idx) 54 | if profiling: 55 | logging.info( 56 | prof.key_averages().table(sort_by="self_cpu_time_total", 57 | row_limit=20)) 58 | else: 59 | output = model(inp, writer=writer, step=epoch, idx=batch_idx) 60 | 61 | # Compute loss 62 | loss = loss_fn(output, tgt) 63 | 64 | # Compute metrics 65 | metrics_batch = metrics_fn(inp, output, tgt) 66 | for k in metrics_batch.keys(): 67 | if not k in metrics: 68 | metrics[k] = metrics_batch[k] 69 | else: 70 | metrics[k] += metrics_batch[k] 71 | 72 | output = test_loader.dataset.output_to(output, 'cpu') 73 | inp, tgt = test_loader.dataset.to(inp, tgt, 'cpu') 74 | 75 | # Results to save 76 | if results_path is not None: 77 | results.append(results_fn( 78 | batch_idx * test_loader.batch_size, 79 | inp, output, tgt, metrics_batch, output_dir=output_dir)) 80 | 81 | losses += [loss.item()] 82 | if profiling: 83 | runtimes += [ # Runtime per sample in ms 84 | prof.profiler.self_cpu_time_total / (test_loader.batch_size * 1e3)] 85 | else: 86 | runtimes += [0.0] 87 | 88 | output = test_loader.dataset.output_to(output, 'cpu') 89 | inp, tgt = test_loader.dataset.to(inp, tgt, 'cpu') 90 | if writer is not None: 91 | if batch_idx == 0: 92 | test_loader.dataset.tensorboard_add_sample( 93 | writer, tag='Test', 94 | sample=(inp, output, tgt), 95 | step=epoch) 96 | test_loader.dataset.tensorboard_add_metrics( 97 | writer, tag='Test', metrics=metrics_batch, step=epoch) 98 | 99 | if n_items is not None and batch_idx == (n_items - 1): 100 | break 101 | 102 | if results_path is not None: 103 | torch.save(results, results_path) 104 | logging.info("Saved results to %s" % results_path) 105 | 106 | avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()} 107 | avg_metrics['loss'] = np.mean(losses) 108 | avg_metrics['runtime'] = np.mean(runtimes) 109 | avg_metrics_str = "Test:" 110 | for m in avg_metrics.keys(): 111 | avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m]) 112 | 
logging.info(avg_metrics_str) 113 | 114 | return avg_metrics 115 | 116 | def evaluate(network, args: argparse.Namespace): 117 | """ 118 | Evaluate the model on a given dataset. 119 | """ 120 | 121 | # Load dataset 122 | data_test = utils.import_attr(args.test_dataset)(**args.test_data_args) 123 | logging.info("Loaded test dataset %d elements" % len(data_test)) 124 | 125 | # Set up the device and workers. 126 | use_cuda = args.use_cuda and torch.cuda.is_available() 127 | if use_cuda: 128 | gpu_ids = args.gpu_ids if args.gpu_ids is not None\ 129 | else range(torch.cuda.device_count()) 130 | device_ids = [_ for _ in gpu_ids] 131 | data_parallel = len(device_ids) > 1 132 | device = 'cuda:%d' % device_ids[0] 133 | torch.cuda.set_device(device_ids[0]) 134 | logging.info("Using CUDA devices: %s" % str(device_ids)) 135 | else: 136 | data_parallel = False 137 | device = torch.device('cpu') 138 | logging.info("Using device: CPU") 139 | 140 | # Set multiprocessing params 141 | num_workers = min(multiprocessing.cpu_count(), args.n_workers) 142 | kwargs = { 143 | 'num_workers': num_workers, 144 | 'pin_memory': True 145 | } if use_cuda else {} 146 | 147 | # Set up data loader 148 | test_loader = torch.utils.data.DataLoader( 149 | data_test, batch_size=args.eval_batch_size, collate_fn=data_test.collate_fn, 150 | **kwargs) 151 | 152 | # Set up model 153 | model = network.Net(**args.model_params) 154 | if use_cuda and data_parallel: 155 | model = nn.DataParallel(model, device_ids=device_ids) 156 | logging.info("Using data parallel model") 157 | model.to(device) 158 | 159 | # Load weights 160 | if args.pretrain_path == "best": 161 | ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt')) 162 | ckpts.sort( 163 | key=lambda _: int(os.path.splitext(os.path.basename(_))[0])) 164 | val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric] 165 | best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__) 166 | args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch) 167 | logging.info( 168 | "Found 'best' validation %s=%.02f at %s" % 169 | (args.base_metric, val_metrics[best_epoch], args.pretrain_path)) 170 | if args.pretrain_path != "": 171 | utils.load_checkpoint( 172 | args.pretrain_path, model, data_parallel=data_parallel) 173 | logging.info("Loaded pretrain weights from %s" % args.pretrain_path) 174 | 175 | # Results csv file 176 | results_fn = network.format_results 177 | results_path = os.path.join(args.exp_dir, 'results.eval.pth') 178 | if args.output_dir is not None: 179 | os.makedirs(args.output_dir, exist_ok=True) 180 | 181 | # Evaluate 182 | try: 183 | return test_epoch( 184 | model, device, test_loader, args.n_items, network.loss, 185 | network.test_metrics, results_fn, results_path, args.output_dir, args.profiling) 186 | except KeyboardInterrupt: 187 | print("Interrupted") 188 | except Exception as _: # pylint: disable=broad-except 189 | import traceback # pylint: disable=import-outside-toplevel 190 | traceback.print_exc() 191 | 192 | def get_unique_hparams(exps): 193 | """ 194 | Return a list of unique hyperparameters across the set of experiments. 195 | """ 196 | # Read config files into a dataframe 197 | configs = [] 198 | for e in exps: 199 | with open(os.path.join(e, 'config.json')) as f: 200 | configs.append(pd.json_normalize(json.load(f))) 201 | configs = pd.concat(configs, ignore_index=True) 202 | 203 | # Remove constant colums from configs dataframe. None values are considered constant. 
204 |     configs = configs.loc[:, configs.nunique() > 1]
205 | 
206 |     return configs.to_dict('records')
207 | 
208 | if __name__ == '__main__':
209 |     parser = argparse.ArgumentParser()
210 |     # Data Params
211 |     parser.add_argument('experiments', nargs='+', type=str,
212 |                         default=None,
213 |                         help="List of experiments to evaluate. "
214 |                              "Provide only one experiment when providing a "
215 |                              "pretrained path. If a pretrained path is not "
216 |                              "provided, the epoch with the best validation metric "
217 |                              "is used for evaluation.")
218 |     parser.add_argument('--results', type=str, default="",
219 |                         help="Path to the CSV file to store results.")
220 |     parser.add_argument('--output_dir', type=str, default=None,
221 |                         help="Path to the directory to store outputs.")
222 | 
223 |     # System params
224 |     parser.add_argument('--n_items', type=int, default=None,
225 |                         help="Number of items to test.")
226 |     parser.add_argument('--pretrain_path', type=str, default="best",
227 |                         help="Path to pretrained weights")
228 |     parser.add_argument('--profiling', dest='profiling', action='store_true',
229 |                         help="Enable or disable profiling.")
230 |     parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
231 |                         help="Whether to use cuda")
232 |     parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
233 |                         help="List of GPU ids used for training. "
234 |                              "Eg., --gpu_ids 2 4. All GPUs are used by default.")
235 |     args = parser.parse_args()
236 | 
237 |     results = []
238 |     unique_hparams = get_unique_hparams(args.experiments)
239 |     if len(unique_hparams) == 0:
240 |         unique_hparams = [{}]
241 | 
242 |     for exp_dir, hparams in zip(args.experiments, unique_hparams):
243 |         eval_args = argparse.Namespace(**vars(args))
244 |         eval_args.exp_dir = exp_dir
245 | 
246 |         utils.set_logger(os.path.join(exp_dir, 'eval.log'))
247 |         logging.info("Evaluating %s ..." % exp_dir)
248 | 
249 |         # Load model and training params
250 |         params = utils.Params(os.path.join(exp_dir, 'config.json'))
251 |         for k, v in params.__dict__.items():
252 |             vars(eval_args)[k] = v
253 | 
254 |         network = importlib.import_module(eval_args.model)
255 |         logging.info("Imported the model from '%s'."
% eval_args.model) 256 | 257 | curr_res = evaluate(network, eval_args) 258 | for k, v in hparams.items(): 259 | curr_res[k] = v 260 | results.append(curr_res) 261 | 262 | del eval_args 263 | 264 | if args.results != "": 265 | print("Writing results to %s" % args.results) 266 | pd.DataFrame(results).to_csv(args.results, index=False) 267 | -------------------------------------------------------------------------------- /src/training/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main training script for training on synthetic data 3 | """ 4 | 5 | import argparse 6 | import multiprocessing 7 | import os 8 | import logging 9 | from pathlib import Path 10 | import random 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | from torch.utils.tensorboard import SummaryWriter 18 | from tqdm import tqdm # pylint: disable=unused-import 19 | from torchmetrics.functional import( 20 | scale_invariant_signal_noise_ratio as si_snr, 21 | signal_noise_ratio as snr, 22 | signal_distortion_ratio as sdr, 23 | scale_invariant_signal_distortion_ratio as si_sdr) 24 | 25 | from src.helpers import utils 26 | from src.training.eval import test_epoch 27 | 28 | def train_epoch(model: nn.Module, device: torch.device, 29 | optimizer: optim.Optimizer, 30 | train_loader: torch.utils.data.dataloader.DataLoader, 31 | n_items: int, epoch: int = 0, 32 | writer: SummaryWriter = None) -> float: 33 | 34 | """ 35 | Train a single epoch. 36 | """ 37 | # Set the model to training. 38 | model.train() 39 | 40 | # Training loop 41 | losses = [] 42 | metrics = {} 43 | 44 | tensorboard_trace_handler = torch.profiler.tensorboard_trace_handler( 45 | writer.log_dir) 46 | with tqdm(total=len(train_loader), desc='Train', ncols=100) as t: 47 | # with torch.profiler.profile( 48 | # schedule=torch.profiler.schedule( 49 | # skip_first=10, 50 | # wait=2, 51 | # warmup=2, 52 | # active=6, 53 | # repeat=2), 54 | # on_trace_ready=tensorboard_trace_handler, 55 | # profile_memory=True, 56 | # with_stack=True 57 | # ) as profiler: 58 | for batch_idx, (inp, tgt) in enumerate(train_loader): 59 | # Move data to device 60 | inp, tgt = train_loader.dataset.to(inp, tgt, device) 61 | 62 | # Reset grad 63 | optimizer.zero_grad() 64 | 65 | # Run through the model 66 | output = model(inp) 67 | 68 | # Compute loss 69 | loss = network.loss(output, tgt) 70 | 71 | losses.append(loss.item()) 72 | 73 | # Backpropagation 74 | loss.backward() 75 | 76 | # Gradient clipping 77 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 78 | 79 | # Update the weights 80 | optimizer.step() 81 | 82 | # Compute metrics 83 | output = train_loader.dataset.output_detach(output) 84 | metrics_batch = network.metrics(inp, output, tgt) 85 | for k in metrics_batch.keys(): 86 | if not k in metrics: 87 | metrics[k] = metrics_batch[k] 88 | else: 89 | metrics[k] += metrics_batch[k] 90 | 91 | output = train_loader.dataset.output_to(output, 'cpu') 92 | inp, tgt = train_loader.dataset.to(inp, tgt, 'cpu') 93 | if writer is not None and batch_idx == 0: 94 | train_loader.dataset.tensorboard_add_sample( 95 | writer, tag='Train', 96 | sample=(inp, output, tgt), 97 | step=epoch) 98 | 99 | # Step the profiler 100 | # profiler.step() 101 | 102 | # Show current loss in the progress meter 103 | t.set_postfix(loss='%.05f'%loss.item()) 104 | t.update() 105 | 106 | if n_items is not None and batch_idx == n_items: 107 | break 108 | 109 | avg_metrics = 
        {k: np.mean(metrics[k]) for k in metrics.keys()}
110 |     avg_metrics['loss'] = np.mean(losses)
111 |     avg_metrics_str = "Train:"
112 |     for m in avg_metrics.keys():
113 |         avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
114 |     logging.info(avg_metrics_str)
115 | 
116 |     return avg_metrics
117 | 
118 | def train(args: argparse.Namespace):
119 |     """
120 |     Train the network.
121 |     """
122 |     # Load dataset
123 |     data_train = utils.import_attr(args.train_dataset)(**args.train_data_args)
124 |     logging.info("Loaded train dataset containing %d elements" %
125 |                  (len(data_train)))
126 |     data_val = utils.import_attr(args.val_dataset)(**args.val_data_args)
127 |     logging.info("Loaded validation dataset containing %d elements" %
128 |                  (len(data_val)))
129 | 
130 |     # Set up the device and workers.
131 |     use_cuda = args.use_cuda and torch.cuda.is_available()
132 |     if use_cuda:
133 |         gpu_ids = args.gpu_ids if args.gpu_ids is not None\
134 |                   else range(torch.cuda.device_count())
135 |         device_ids = [_ for _ in gpu_ids]
136 |         data_parallel = len(device_ids) > 1
137 |         device = 'cuda:%d' % device_ids[0]
138 |         torch.cuda.set_device(device_ids[0])
139 |         logging.info("Using CUDA devices: %s" % str(device_ids))
140 |     else:
141 |         data_parallel = False
142 |         device = torch.device('cpu')
143 |         logging.info("Using device: CPU")
144 | 
145 |     # Set multiprocessing params
146 |     num_workers = min(multiprocessing.cpu_count(), args.n_workers)
147 |     kwargs = {
148 |         'num_workers': num_workers,
149 |         'pin_memory': True
150 |     } if use_cuda else {}
151 | 
152 |     # Set up data loaders
153 |     #print(args.batch_size, args.eval_batch_size)
154 |     train_loader = torch.utils.data.DataLoader(
155 |         data_train, batch_size=args.batch_size, shuffle=True,
156 |         collate_fn=data_train.collate_fn, **kwargs)
157 |     val_loader = torch.utils.data.DataLoader(
158 |         data_val, batch_size=args.eval_batch_size, collate_fn=data_val.collate_fn,
159 |         **kwargs)
160 | 
161 |     # Set up model
162 |     model = network.Net(**args.model_params)
163 | 
164 |     # Add graph to tensorboard with example train samples
165 |     # _mixed, _label, _ = next(iter(val_loader))
166 |     # args.writer.add_graph(model, (_mixed, _label))
167 | 
168 |     if use_cuda and data_parallel:
169 |         model = nn.DataParallel(model, device_ids=device_ids)
170 |         logging.info("Using data parallel model")
171 |     model.to(device)
172 | 
173 |     # Set up the optimizer
174 |     logging.info("Initializing optimizer with %s" % str(args.optim))
175 |     optimizer = network.optimizer(model, **args.optim, data_parallel=data_parallel)
176 |     logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))
177 | 
178 |     lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
179 |         optimizer, **args.lr_sched)
180 |     logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
181 |                  % (args.fix_lr_epochs, str(args.lr_sched)))
182 | 
183 |     base_metric = args.base_metric
184 |     train_metrics = {}
185 |     val_metrics = {}
186 | 
187 |     # Load the model if `args.start_epoch` is greater than 0. This will load the
188 |     # model from epoch = `args.start_epoch - 1`
189 |     assert args.start_epoch >= 0, "start_epoch must be greater than or equal to 0."
190 | if args.start_epoch > 0: 191 | checkpoint_path = os.path.join(args.exp_dir, 192 | '%d.pt' % (args.start_epoch - 1)) 193 | _, train_metrics, val_metrics = utils.load_checkpoint( 194 | checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler, 195 | data_parallel=data_parallel) 196 | logging.info("Loaded checkpoint from %s" % checkpoint_path) 197 | logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer)) 198 | 199 | # Training loop 200 | try: 201 | torch.autograd.set_detect_anomaly(args.detect_anomaly) 202 | for epoch in range(args.start_epoch, args.epochs + 1): 203 | logging.info("Epoch %d:" % epoch) 204 | checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch) 205 | assert not os.path.exists(checkpoint_file), \ 206 | "Checkpoint file %s already exists" % checkpoint_file 207 | #print("---- begin trianivg") 208 | curr_train_metrics = train_epoch(model, device, optimizer, 209 | train_loader, args.n_train_items, 210 | epoch=epoch, writer=args.writer) 211 | #raise KeyboardInterrupt 212 | curr_test_metrics = test_epoch(model, device, val_loader, 213 | args.n_test_items, network.loss, 214 | network.metrics, epoch=epoch, 215 | writer=args.writer) 216 | # LR scheduler 217 | if epoch >= args.fix_lr_epochs: 218 | lr_scheduler.step(curr_test_metrics[base_metric]) 219 | logging.info( 220 | "LR after scheduling step: %s" % 221 | [_['lr'] for _ in optimizer.param_groups]) 222 | 223 | # Write metrics to tensorboard 224 | args.writer.add_scalars('Train', curr_train_metrics, epoch) 225 | args.writer.add_scalars('Val', curr_test_metrics, epoch) 226 | args.writer.flush() 227 | 228 | for k in curr_train_metrics.keys(): 229 | if not k in train_metrics: 230 | train_metrics[k] = [curr_train_metrics[k]] 231 | else: 232 | train_metrics[k].append(curr_train_metrics[k]) 233 | 234 | for k in curr_test_metrics.keys(): 235 | if not k in val_metrics: 236 | val_metrics[k] = [curr_test_metrics[k]] 237 | else: 238 | val_metrics[k].append(curr_test_metrics[k]) 239 | 240 | if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]: 241 | logging.info("Found best validation %s!" 
                             % base_metric)
242 | 
243 |             utils.save_checkpoint(
244 |                 checkpoint_file, epoch, model, optimizer, lr_scheduler,
245 |                 train_metrics, val_metrics, data_parallel)
246 |             logging.info("Saved checkpoint at %s" % checkpoint_file)
247 | 
248 |             utils.save_graph(train_metrics, val_metrics, args.exp_dir)
249 | 
250 |         return train_metrics, val_metrics
251 | 
252 | 
253 |     except KeyboardInterrupt:
254 |         print("Interrupted")
255 |     except Exception as _: # pylint: disable=broad-except
256 |         import traceback # pylint: disable=import-outside-toplevel
257 |         traceback.print_exc()
258 | 
259 | 
260 | if __name__ == '__main__':
261 |     parser = argparse.ArgumentParser()
262 |     # Data Params
263 |     parser.add_argument('exp_dir', type=str,
264 |                         default='./experiments/fsd_mask_label_mult',
265 |                         help="Path to save checkpoints and logs.")
266 | 
267 |     parser.add_argument('--n_train_items', type=int, default=None,
268 |                         help="Number of items to train on in each epoch")
269 |     parser.add_argument('--n_test_items', type=int, default=None,
270 |                         help="Number of items to test.")
271 |     parser.add_argument('--start_epoch', type=int, default=0,
272 |                         help="Start epoch")
273 |     parser.add_argument('--pretrain_path', type=str,
274 |                         help="Path to pretrained weights")
275 |     parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
276 |                         help="Whether to use cuda")
277 |     parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
278 |                         help="List of GPU ids used for training. "
279 |                              "Eg., --gpu_ids 2 4. All GPUs are used by default.")
280 |     parser.add_argument('--detect_anomaly', dest='detect_anomaly',
281 |                         action='store_true',
282 |                         help="Whether to enable autograd anomaly detection")
283 |     parser.add_argument('--wandb', dest='wandb', action='store_true',
284 |                         help="Whether to sync tensorboard to wandb")
285 | 
286 |     args = parser.parse_args()
287 | 
288 |     # Set the random seed for reproducible experiments
289 |     torch.manual_seed(230)
290 |     random.seed(230)
291 |     np.random.seed(230)
292 |     if args.use_cuda:
293 |         torch.cuda.manual_seed(230)
294 | 
295 |     # Set up checkpoints
296 |     if not os.path.exists(args.exp_dir):
297 |         os.makedirs(args.exp_dir)
298 | 
299 |     utils.set_logger(os.path.join(args.exp_dir, 'train.log'))
300 | 
301 |     # Load model and training params
302 |     params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
303 |     for k, v in params.__dict__.items():
304 |         if k in vars(args):
305 |             logging.warning("Argument %s is overwritten by config file." % k)
306 |         vars(args)[k] = v
307 | 
308 |     # Initialize tensorboard writer
309 |     tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
310 |     args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
311 |     if args.wandb:
312 |         import wandb
313 |         wandb.init(
314 |             project='Semaudio', sync_tensorboard=True,
315 |             dir=tensorboard_dir, name=os.path.basename(args.exp_dir))
316 | 
317 |     exec("import %s as network" % args.model)
318 |     logging.info("Imported the model from '%s'." % args.model)
319 | 
320 |     train(args)
321 | 
322 |     args.writer.close()
323 |     if args.wandb:
324 |         wandb.finish()
325 | 
--------------------------------------------------------------------------------
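The snippet below is a minimal inference sketch, not a file from the repository: it assumes the 41-class one-hot label layout used in the `__main__` test of `src/training/dcc_tf_binaural.py`, randomly initialized weights (pass `pretrained_path` to start from a saved checkpoint), and an arbitrarily chosen target class index.

```python
# Hypothetical usage sketch for the binaural target sound extraction model.
import torch
from src.training.dcc_tf_binaural import Net

model = Net(label_len=41)            # label_len must match the trained label set
model.eval()

mixture = torch.randn(1, 2, 44100)   # [B, n_mics, T]: 1 s of binaural audio at 44.1 kHz
label = torch.zeros(1, 41)
label[0, 3] = 1.0                    # one-hot target class (index 3 is arbitrary)

with torch.no_grad():
    out = model({'mixture': mixture, 'label_vector': label})

print(out['x'].shape)                # [1, 2, 44100]: binaural estimate of the target sound
```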