├── cos ├── __init__.py ├── helpers │ ├── __init__.py │ ├── constants.py │ ├── eval_utils.py │ ├── visualization.py │ ├── ibm.py │ ├── irm.py │ ├── utils.py │ └── mwf.py ├── inference │ ├── __init__.py │ ├── evaluate_synthetic.py │ └── separation_by_localization.py ├── training │ ├── __init__.py │ ├── data_augmentation.py │ ├── synthetic_dataset.py │ ├── network.py │ └── train.py └── generate_dataset.py ├── checkpoints └── info.txt ├── outputs └── info.txt ├── requirements.txt ├── .gitignore ├── LICENSE ├── RunningTips.md └── README.md /cos/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cos/helpers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cos/inference/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cos/training/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /checkpoints/info.txt: -------------------------------------------------------------------------------- 1 | pt checkpoint files go here -------------------------------------------------------------------------------- /outputs/info.txt: -------------------------------------------------------------------------------- 1 | Output data gets written here -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ### Requirements 2 | librosa 3 | torch >= 1.3.0 4 | soundfile 5 | scipy 6 | pyroomacoustics==0.3.1 7 | matplotlib 8 | mir_eval 9 | tqdm 10 | numpy 11 | pysndfx 12 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Dot underscore 7 | ._* 8 | 9 | # Outputs 10 | 11 | # Pytorch 12 | *.pt 13 | 14 | # Audio files 15 | *.wav 16 | *.mp4 17 | *.flac 18 | *.mp3 19 | 20 | # Data Files 21 | *.npy 22 | -------------------------------------------------------------------------------- /cos/training/data_augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pysndfx import AudioEffectsChain 4 | 5 | class RandomAudioPerturbation(object): 6 | """Randomly perturb audio samples""" 7 | 8 | def __call__(self, data): 9 | """ 10 | Data must be mics x T numpy array 11 | """ 12 | highshelf_gain = np.random.normal(0, 2) 13 | lowshelf_gain = np.random.normal(0, 2) 14 | noise_amount = np.random.uniform(0, 0.001) 15 | 16 | fx = ( 17 | AudioEffectsChain() 18 | .highshelf(gain=highshelf_gain) 19 | .lowshelf(gain=lowshelf_gain) 20 | ) 21 | 22 | for i in range(data.shape[0]): 23 | data[i] = fx(data[i]) 24 | data[i] += np.random.uniform(-noise_amount, noise_amount, size=data[i].shape) 25 | return data 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 
Teerapat Jenrungrot, Vivek Jayaram, Steve Seitz, and Ira Kemelmacher-Shlizerman 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /RunningTips.md: -------------------------------------------------------------------------------- 1 | # Tips for Running on Real Data 2 | 3 | ## Hardware 4 | We have only tested this network with circular microphone arrays. It might be possible to train a new model for ad-hoc or linear mic arrays. The pretrained model works with the 4-mic [Seeed ReSpeaker Mic Array v2.0](https://wiki.seeedstudio.com/ReSpeaker_Mic_Array_v2.0/). 5 | 6 | ## Capturing Data 7 | ### Positioning 8 | For best results, the sources should not be too far from the microphone. Our model was trained with sources between 1-4 meters from the microphone, but the best results are 1-2m away. Extreme far field (>4m) has not been explored. Ideally the sources should be at roughly the same elevation angle as the microphone. For example, if the microphone is on the ground and the sources are standing, the assumptions in the pre-shift formulas no longer hold. 9 | ### Obstructions 10 | The point source model breaks down if there is not a direct line of sight between the source and the microphone, for example if there is a laptop between the voice and mic. 11 | ## Hyperparameters 12 | If the sources are completely stationary, you can increase the `--duration` flag, which processes larger chunks at a time. This improves performance and reduces boundary effects. At the top of `cos/inference/separation_by_localization.py` are additional parameters. Tweak `ENERGY_CUTOFF` to more aggressively keep or reject sources. For more aggressive non-max suppression, you can reduce `NMS_SIMILARITY_SDR`, which keeps an additional source only if its SDR with respect to an existing source is lower than this parameter. 13 | 14 | ## Post Processing 15 | After separating voices from each other, any type of single-channel post processing can be run on the output. We found that it was useful to run a low-pass filter on the output. The network requires a 44.1kHz sampling rate for localization and time differences, but can sometimes produce artifacts in these high frequency ranges. Because human voice doesn't contain many frequencies above ~6kHz, these frequencies can simply be cut. 16 | 17 | # Tips for Training on Real Data 18 | 19 | We strongly recommend training on data collected from your specific microphone.
Due to the 2020 situation, we could not actually record a variety of real background sounds, a variety of environments, or a variety real speakers with our mic array. We expect the network performance to improve with more training data, even beyond the pretrained models we have provided. If you train on new data, do not train from scratch: instead fine-tune from our existing weights even if the number of channels is different. Training from scratch often results in the network outputting silence everywhere. In our experience, it is best to jointly train on real and synthetic data. 20 | -------------------------------------------------------------------------------- /cos/helpers/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of constants that probably should not be changed 3 | """ 4 | 5 | import numpy as np 6 | 7 | # Universal Constants 8 | SPEED_OF_SOUND = 343.0 # m/s 9 | FAR_FIELD_RADIUS = 3.0 # meters, assume larger the mic array radius 10 | 11 | # Algorithmic Constants 12 | ALL_WINDOW_SIZES = [ 13 | np.pi / 2, # 90 degrees 14 | np.pi / 4, # 45 degrees 15 | np.pi / 8, # 22.5 degrees 16 | np.pi / 16, # 11.25 degrees 17 | np.pi / 32, # 5.625 degrees 18 | ] 19 | 20 | def get_mic_diagram(): 21 | """ 22 | A simple hard-coded mic figure for matplotlib 23 | """ 24 | import matplotlib 25 | matplotlib.use("Agg") 26 | mic_verts = np.array([[24. , 28. ], 27 | [27.31, 28. ], 28 | [29.98, 25.31], 29 | [29.98, 22. ], 30 | [30. , 10. ], 31 | [30. , 6.68], 32 | [27.32, 4. ], 33 | [24. , 4. ], 34 | [20.69, 4. ], 35 | [18. , 6.68], 36 | [18. , 10. ], 37 | [18. , 22. ], 38 | [18. , 25.31], 39 | [20.69, 28. ], 40 | [24. , 28. ], 41 | [24. , 28. ], 42 | [34.6 , 22. ], 43 | [34.6 , 28. ], 44 | [29.53, 32.2 ], 45 | [24. , 32.2 ], 46 | [18.48, 32.2 ], 47 | [13.4 , 28. ], 48 | [13.4 , 22. ], 49 | [10. , 22. ], 50 | [10. , 28.83], 51 | [15.44, 34.47], 52 | [22. , 35.44], 53 | [22. , 42. ], 54 | [26. , 42. ], 55 | [26. , 35.44], 56 | [32.56, 34.47], 57 | [38. , 28.83], 58 | [38. , 22. ], 59 | [34.6 , 22. ], 60 | [34.6 , 22. ]]) 61 | mic_verts[:,1] = (48 - mic_verts[:,1]) - 24 62 | mic_verts[:,0] -= 24 63 | 64 | mic_verts[:,0] /= 240 65 | mic_verts[:,1] /= 240 66 | 67 | mic_verts *= 10 68 | 69 | mic_codes = np.array([ 1, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 79, 1, 70 | 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2, 71 | 79], dtype=np.uint8) 72 | 73 | mic = matplotlib.path.Path(mic_verts, mic_codes) 74 | return mic 75 | -------------------------------------------------------------------------------- /cos/helpers/eval_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import math 3 | 4 | import numpy as np 5 | 6 | from mir_eval.separation import bss_eval_sources 7 | 8 | from cos.helpers.utils import angular_distance 9 | 10 | def si_sdr(estimated_signal, reference_signals, scaling=True): 11 | """ 12 | This is a scale invariant SDR. 
See https://arxiv.org/pdf/1811.02508.pdf 13 | or https://github.com/sigsep/bsseval/issues/3 for the motivation and 14 | explanation 15 | 16 | Input: 17 | estimated_signal and reference signals are (N,) numpy arrays 18 | 19 | Returns: SI-SDR as scalar 20 | """ 21 | Rss = np.dot(reference_signals, reference_signals) 22 | this_s = reference_signals 23 | 24 | if scaling: 25 | # get the scaling factor for clean sources 26 | a = np.dot(this_s, estimated_signal) / Rss 27 | else: 28 | a = 1 29 | 30 | e_true = a * this_s 31 | e_res = estimated_signal - e_true 32 | 33 | Sss = (e_true**2).sum() 34 | Snn = (e_res**2).sum() 35 | 36 | SDR = 10 * math.log10(Sss/Snn) 37 | 38 | return SDR 39 | 40 | 41 | def compute_sdr(gt, output, single_channel=False): 42 | assert(gt.shape == output.shape) 43 | per_channel_sdr = [] 44 | 45 | channels = [0] if single_channel else range(gt.shape[0]) 46 | for channel_idx in channels: 47 | # sdr, _, _, _ = bss_eval_sources(gt[channel_idx], output[channel_idx]) 48 | sdr = si_sdr(output[channel_idx], gt[channel_idx]) 49 | per_channel_sdr.append(sdr) 50 | 51 | return np.array(per_channel_sdr).mean() 52 | 53 | 54 | 55 | def find_best_permutation_prec_recall(gt, output, acceptable_window=np.pi / 18): 56 | """ 57 | Finds the best permutation for evaluation. 58 | Then uses that to find the precision and recall 59 | 60 | Inputs: 61 | gt, output: list of sources. lengths may differ 62 | 63 | Returns: Permutation that matches outputs to gt along with tp, fn and fp 64 | """ 65 | n = max(len(gt), len(output)) 66 | 67 | if len(gt) > len(output): 68 | output += [np.inf] * (n - len(output)) 69 | elif len(output) > len(gt): 70 | gt += [np.inf] * (n - len(gt)) 71 | 72 | best_perm = None 73 | best_inliers = -1 74 | for perm in itertools.permutations(range(n)): 75 | curr_inliers = 0 76 | for idx1, idx2 in enumerate(perm): 77 | if angular_distance(gt[idx1], output[idx2]) < acceptable_window: 78 | curr_inliers += 1 79 | 80 | if curr_inliers > best_inliers: 81 | best_inliers = curr_inliers 82 | best_perm = list(perm) 83 | 84 | return localization_precision_recall(best_perm, gt, output, acceptable_window) 85 | 86 | 87 | def localization_precision_recall(permutation, gt, output, acceptable_window=np.pi/18): 88 | tp, fn, fp = 0, 0, 0 89 | for idx1, idx2 in enumerate(permutation): 90 | if angular_distance(gt[idx1], output[idx2]) < acceptable_window: 91 | tp += 1 92 | elif gt[idx1] == np.inf: 93 | fp += 1 94 | elif output[idx2] == np.inf: 95 | fn += 1 96 | else: 97 | fn += 1 98 | fp += 1 99 | 100 | return permutation, (tp, fn, fp) 101 | -------------------------------------------------------------------------------- /cos/helpers/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | 5 | import cos.helpers.utils as utils 6 | from cos.helpers.constants import get_mic_diagram 7 | 8 | 9 | 10 | def draw_diagram(voice_positions, candidate_angles, angle_window_size, output_file): 11 | """ 12 | Draws the setup of all the voices in space, and colored triangles for the beams 13 | """ 14 | import matplotlib 15 | matplotlib.use("Agg") 16 | import matplotlib.pyplot as plt 17 | from matplotlib.patches import Circle, Wedge, Polygon 18 | from matplotlib.collections import PatchCollection 19 | matplotlib.style.use('ggplot') 20 | 21 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] 22 | colors = [colors[0], colors[3], colors[2], colors[1]] 23 | 24 | fig, ax = plt.subplots() 25 | ax.set(xlim=(-5, 5), ylim = (-5, 5)) 26 | 
ax.set_aspect("equal") 27 | ax.annotate('X', xy=(-5,0), xytext=(5.2,0), 28 | arrowprops={'arrowstyle': '<|-', "color": "black", "linewidth":3}, va='center', fontsize=20) 29 | ax.annotate('Y', xy=(0,-5), xytext=(-0.25, 5.5), 30 | arrowprops={'arrowstyle': '<|-', "color": "black", "linewidth":3}, va='center', fontsize=20) 31 | 32 | plt.tick_params(axis='both', 33 | which='both', bottom='off', 34 | top='off', labelbottom='off', right='off', left='off', labelleft='off' 35 | ) 36 | 37 | for pos in voice_positions: 38 | if pos[0] != 0.0: 39 | a_circle = plt.Circle((pos[0], pos[1]), 0.3, color='b', fill=False) 40 | ax.add_artist(a_circle) 41 | 42 | patches = [] 43 | for idx, target_angle in enumerate(candidate_angles): 44 | vertices = angle_to_triangle(target_angle, angle_window_size) * 4.96 45 | 46 | ax.fill(vertices[:, 0], vertices[:, 1], 47 | edgecolor='black', linewidth=2, alpha=0.6) 48 | 49 | mic = get_mic_diagram() 50 | patch = matplotlib.patches.PathPatch(mic, fill=True, facecolor='black') 51 | ax.add_patch(patch) 52 | ax.tick_params(axis='both', which='both', labelcolor="white", colors="white") 53 | plt.savefig(output_file) 54 | 55 | 56 | def angle_to_triangle(target_angle, angle_window_size): 57 | """ 58 | Takes a target angle and window size and returns a 59 | triangle corresponding to that pie slice 60 | """ 61 | first_point = [0,0] # Always start at the origiin 62 | second_point = angle_to_point(utils.convert_angular_range(target_angle - angle_window_size/2)) 63 | third_point = angle_to_point(utils.convert_angular_range(target_angle + angle_window_size/2)) 64 | 65 | return(np.array([first_point, second_point, third_point])) 66 | 67 | 68 | def angle_to_point(angle): 69 | """Angle must be -pi to pi""" 70 | if -np.pi <= angle < -3*np.pi/4: 71 | return[-1, -np.tan(angle + np.pi)] 72 | 73 | elif -3*np.pi/4 <= angle < -np.pi/2: 74 | return[-np.tan(-np.pi/2 - angle), -1] 75 | 76 | elif -np.pi/2 <= angle < -np.pi/4: 77 | return[np.tan(angle + np.pi/2), -1] 78 | 79 | elif -np.pi/4 <= angle < 0: 80 | return[1, -np.tan(-angle)] 81 | 82 | elif 0 <= angle < np.pi / 4: 83 | return [1, np.tan(angle)] 84 | 85 | elif np.pi/4 <= angle < np.pi/2: 86 | return [np.tan(np.pi/2 - angle), 1] 87 | 88 | elif np.pi/2 <= angle <= 3*np.pi/4: 89 | return [-np.tan(angle - np.pi/2), 1] 90 | 91 | elif 3*np.pi/4 < angle <= np.pi: 92 | return [-1, np.tan(np.pi - angle)] 93 | -------------------------------------------------------------------------------- /cos/helpers/ibm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing.dummy as mp 4 | import os 5 | 6 | from pathlib import Path 7 | 8 | import librosa 9 | import numpy as np 10 | import tqdm 11 | 12 | from scipy.signal import stft, istft 13 | 14 | from cos.helpers.eval_utils import compute_sdr 15 | from cos.inference.evaluate_synthetic import get_items 16 | from cos.helpers.utils import check_valid_dir 17 | 18 | def compute_ibm(gt, mix, alpha, theta=0.5): 19 | """ 20 | Computes the Ideal Binary Mask SI-SDR 21 | gt: (n_voices, n_channels, t) 22 | mix: (n_channels, t) 23 | """ 24 | n_voices = gt.shape[0] 25 | nfft = 2048 26 | eps = np.finfo(np.float).eps 27 | 28 | N = mix.shape[-1] # number of samples 29 | X = stft(mix, nperseg=nfft)[2] 30 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame) 31 | 32 | # perform separation 33 | estimates = [] 34 | for gt_idx in range(n_voices): 35 | # compute STFT of target source 36 | Yj = stft(gt[gt_idx], nperseg=nfft)[2] 37 | 38 | # Create 
binary Mask 39 | mask = np.divide(np.abs(Yj)**alpha, (eps + np.abs(X) ** alpha)) 40 | mask[np.where(mask >= theta)] = 1 41 | mask[np.where(mask < theta)] = 0 42 | 43 | Yj = np.multiply(X, mask) 44 | target_estimate = istft(Yj)[1][:,:N] 45 | 46 | estimates.append(target_estimate) 47 | 48 | estimates = np.array(estimates) # (nvoice, 6, 6*sr) 49 | 50 | # eval 51 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr) 52 | eval_gt = gt # (nvoice, 6, 6*sr) 53 | eval_est = estimates 54 | 55 | SDR_in = [] 56 | SDR_out = [] 57 | for i in range(n_voices): 58 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=True)) # scalar 59 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=True)) # scalar 60 | 61 | output = np.array([SDR_in, SDR_out]) # (2, nvoice) 62 | 63 | return output 64 | 65 | 66 | def main(args): 67 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*'))) 68 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)] 69 | 70 | all_input_sdr = [0] * len(all_dirs) 71 | all_output_sdr = [0] * len(all_dirs) 72 | 73 | def evaluate_dir(idx): 74 | curr_dir = all_dirs[idx] 75 | # Loads the data 76 | mixed_data, gt = get_items(curr_dir, args) 77 | gt = np.array([x.data for x in gt]) 78 | output = compute_ibm(gt, mixed_data, alpha=args.alpha) 79 | all_input_sdr[idx] = output[0] 80 | all_output_sdr[idx] = output[1] 81 | print("Running median SDRi: ", 82 | np.median(np.array(all_output_sdr[:idx+1]) - np.array(all_input_sdr[:idx+1]))) 83 | 84 | pool = mp.Pool(args.n_workers) 85 | with tqdm.tqdm(total=len(all_dirs)) as pbar: 86 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))): 87 | pbar.update() 88 | 89 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs)) 90 | pool.close() 91 | pool.join() 92 | 93 | print("Median SI-SDRi: ", 94 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten())) 95 | 96 | np.save("IBM_{}voices_{}kHz.npy".format(args.n_voices, args.sr), 97 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()])) 98 | 99 | 100 | if __name__ == '__main__': 101 | parser = argparse.ArgumentParser() 102 | parser.add_argument('input_dir', type=str, help="Path to the input dir") 103 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate") 104 | parser.add_argument('--n_channels', 105 | type=int, 106 | default=2, 107 | help="Number of channels") 108 | parser.add_argument('--n_workers', 109 | type=int, 110 | default=8, 111 | help="Number of parallel workers") 112 | parser.add_argument('--n_voices', 113 | type=int, 114 | default=2, 115 | help="Number of voices in the dataset") 116 | parser.add_argument('--alpha', 117 | type=int, 118 | default=1, 119 | help="See the original SigSep code for an explanation") 120 | args = parser.parse_args() 121 | 122 | main(args) 123 | 124 | 125 | -------------------------------------------------------------------------------- /cos/helpers/irm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing.dummy as mp 4 | import os 5 | 6 | from pathlib import Path 7 | 8 | import librosa 9 | import numpy as np 10 | import tqdm 11 | 12 | from scipy.signal import stft, istft 13 | 14 | from cos.helpers.eval_utils import compute_sdr 15 | from cos.inference.evaluate_synthetic import get_items 16 | from cos.helpers.utils import check_valid_dir 17 | 18 | 19 | def compute_irm(gt, mix, 
alpha): 20 | """ 21 | Computes the Ideal Ratio Mask SI-SDR 22 | gt: (n_voices, n_channels, t) 23 | mix: (n_channels, t) 24 | """ 25 | n_voices = gt.shape[0] 26 | nfft = 2048 27 | hop = 1024 28 | eps = np.finfo(np.float).eps 29 | 30 | N = mix.shape[-1] # number of samples 31 | X = stft(mix, nperseg=nfft)[2] 32 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame) 33 | 34 | # Compute sources spectrograms 35 | P = [] 36 | for gt_idx in range(n_voices): 37 | P.append(np.abs(stft(gt[gt_idx], nperseg=nfft)[2]) ** alpha) 38 | 39 | # compute model as the sum of spectrograms 40 | model = eps 41 | for gt_idx in range(n_voices): 42 | model += P[gt_idx] 43 | 44 | # perform separation 45 | estimates = [] 46 | for gt_idx in range(n_voices): 47 | # Create a ratio Mask 48 | mask = np.divide(np.abs(P[gt_idx]), model) 49 | 50 | # apply mask 51 | Yj = np.multiply(X, mask) 52 | 53 | target_estimate = istft(Yj)[1][:,:N] 54 | 55 | estimates.append(target_estimate) 56 | 57 | estimates = np.array(estimates) # (nvoice, 6, 6*sr) 58 | 59 | # eval 60 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr) 61 | eval_gt = gt # (nvoice, 6, 6*sr) 62 | eval_est = estimates 63 | 64 | SDR_in = [] 65 | SDR_out = [] 66 | for i in range(n_voices): 67 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=True)) # scalar 68 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=True)) # scalar 69 | 70 | output = np.array([SDR_in, SDR_out]) # (2, nvoice) 71 | 72 | return output 73 | 74 | def main(args): 75 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*'))) 76 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)] 77 | 78 | all_input_sdr = [0] * len(all_dirs) 79 | all_output_sdr = [0] * len(all_dirs) 80 | 81 | def evaluate_dir(idx): 82 | curr_dir = all_dirs[idx] 83 | # Loads the data 84 | mixed_data, gt = get_items(curr_dir, args) 85 | gt = np.array([x.data for x in gt]) 86 | output = compute_irm(gt, mixed_data, alpha=args.alpha) 87 | all_input_sdr[idx] = output[0] 88 | all_output_sdr[idx] = output[1] 89 | 90 | pool = mp.Pool(args.n_workers) 91 | with tqdm.tqdm(total=len(all_dirs)) as pbar: 92 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))): 93 | pbar.update() 94 | 95 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs)) 96 | pool.close() 97 | pool.join() 98 | 99 | print("Median SI-SDRi: ", 100 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten())) 101 | 102 | np.save("IRM_{}voices_{}kHz.npy".format(args.n_voices, args.sr), 103 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()])) 104 | 105 | 106 | if __name__ == '__main__': 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument('input_dir', type=str, help="Path to the input dir") 109 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate") 110 | parser.add_argument('--n_channels', 111 | type=int, 112 | default=2, 113 | help="Number of channels") 114 | parser.add_argument('--n_workers', 115 | type=int, 116 | default=8, 117 | help="Number of parallel workers") 118 | parser.add_argument('--n_voices', 119 | type=int, 120 | default=2, 121 | help="Number of voices in the dataset") 122 | parser.add_argument('--alpha', 123 | type=int, 124 | default=1, 125 | help="See the original SigSep code for an explanation") 126 | args = parser.parse_args() 127 | 128 | main(args) 129 | -------------------------------------------------------------------------------- 
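The oracle baselines in `cos/helpers/ibm.py` and `cos/helpers/irm.py` above (and `cos/helpers/mwf.py` below) are standalone scripts. Going by the argparse definitions in these files, a baseline run on a dataset rendered by `cos/generate_dataset.py` might look like the following sketch; the dataset path is a placeholder, and the flag values should match however the data was rendered (the script defaults of `--sr 22050` and `--n_channels 2` differ from the 6-mic, 44.1 kHz setup used in the README evaluation).
```
# Hypothetical invocation of the IRM oracle baseline; flags follow the
# argparse definitions in irm.py (input_dir is positional).
python cos/helpers/irm.py /path/to/rendered_data \
    --sr 44100 \
    --n_channels 6 \
    --n_voices 2 \
    --n_workers 8
```
Each of the three scripts prints the median SI-SDRi over the dataset and saves the per-example input/output SDRs to an `.npy` file (for the values above, `IRM_2voices_44100kHz.npy`).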
/cos/helpers/utils.py: -------------------------------------------------------------------------------- 1 | """A collection of useful helper functions""" 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from pathlib import Path 7 | 8 | from cos.helpers.constants import SPEED_OF_SOUND 9 | 10 | 11 | def shift_mixture(input_data, target_position, mic_radius, sr, inverse=False): 12 | """ 13 | Shifts the input according to the voice position. This 14 | lines up the voice samples in the time domain coming from a target_angle 15 | Args: 16 | input_data - M x T numpy array or torch tensor 17 | target_position - The location where the data should be aligned 18 | mic_radius - In meters. The number of mics is inferred from 19 | the input_Data 20 | sr - Sample Rate in samples/sec 21 | inverse - Whether to align or undo a previous alignment 22 | 23 | Returns: shifted data and a list of the shifts 24 | """ 25 | # elevation_angle = 0.0 * np.pi / 180 26 | # target_height = 3.0 * np.tan(elevation_angle) 27 | # target_position = np.append(target_position, target_height) 28 | 29 | num_channels = input_data.shape[0] 30 | 31 | # Must match exactly the generated or captured data 32 | mic_array = [[ 33 | mic_radius * np.cos(2 * np.pi / num_channels * i), 34 | mic_radius * np.sin(2 * np.pi / num_channels * i), 35 | ] for i in range(num_channels)] 36 | 37 | # Mic 0 is the canonical position 38 | distance_mic0 = np.linalg.norm(mic_array[0] - target_position) 39 | shifts = [0] 40 | 41 | # Check if numpy or torch 42 | if isinstance(input_data, np.ndarray): 43 | shift_fn = np.roll 44 | elif isinstance(input_data, torch.Tensor): 45 | shift_fn = torch.roll 46 | else: 47 | raise TypeError("Unknown input data type: {}".format(type(input_data))) 48 | 49 | # Shift each channel of the mixture to align with mic0 50 | for channel_idx in range(1, num_channels): 51 | distance = np.linalg.norm(mic_array[channel_idx] - target_position) 52 | distance_diff = distance - distance_mic0 53 | shift_time = distance_diff / SPEED_OF_SOUND 54 | shift_samples = int(round(sr * shift_time)) 55 | if inverse: 56 | input_data[channel_idx] = shift_fn(input_data[channel_idx], 57 | shift_samples) 58 | else: 59 | input_data[channel_idx] = shift_fn(input_data[channel_idx], 60 | -shift_samples) 61 | shifts.append(shift_samples) 62 | 63 | return input_data, shifts 64 | 65 | 66 | def angular_distance(angle1, angle2): 67 | """ 68 | Computes the distance in radians betwen angle1 and angle2. 
69 | We assume they are between -pi and pi 70 | """ 71 | d1 = abs(angle1 - angle2) 72 | d2 = abs(angle1 - angle2 + 2 * np.pi) 73 | d3 = abs(angle2 - angle1 + 2 * np.pi) 74 | 75 | return min(d1, d2, d3) 76 | 77 | def get_starting_angles(window_size): 78 | """Returns the list of target angles for a window size""" 79 | divisor = int(round(2 * np.pi / window_size)) 80 | return np.array(list(range(-divisor + 1, divisor, 2))) * np.pi / divisor 81 | 82 | 83 | def to_categorical(index: int, num_classes: int): 84 | """Creates a 1-hot encoded np array""" 85 | data = np.zeros((num_classes)) 86 | data[index] = 1 87 | return data 88 | 89 | def convert_angular_range(angle: float): 90 | """Converts an arbitrary angle to the range [-pi pi]""" 91 | corrected_angle = angle % (2 * np.pi) 92 | if corrected_angle > np.pi: 93 | corrected_angle -= (2 * np.pi) 94 | 95 | return corrected_angle 96 | 97 | def trim_silence(audio, window_size=22050, cutoff=0.001): 98 | """Trims all silence within an audio file""" 99 | idx = 0 100 | new_audio = [] 101 | while idx * window_size < audio.shape[1]: 102 | segment = audio[:, idx*window_size:(idx+1)*window_size] 103 | if segment.std() > cutoff: 104 | new_audio.append(segment) 105 | idx += 1 106 | 107 | return np.concatenate(new_audio, axis=1) 108 | 109 | 110 | def check_valid_dir(dir, requires_n_voices=2): 111 | """Checks that there is at least n voices""" 112 | if len(list(Path(dir).glob('*_voice00.wav'))) < 1: 113 | return False 114 | 115 | if requires_n_voices == 2: 116 | if len(list(Path(dir).glob('*_voice01.wav'))) < 1: 117 | return False 118 | 119 | if requires_n_voices == 3: 120 | if len(list(Path(dir).glob('*_voice02.wav'))) < 1: 121 | return False 122 | 123 | if requires_n_voices == 4: 124 | if len(list(Path(dir).glob('*_voice03.wav'))) < 1: 125 | return False 126 | 127 | if len(list(Path(dir).glob('metadata.json'))) < 1: 128 | return False 129 | 130 | return True 131 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # The Cone of Silence: Speech Separation by Localization 2 | ![alt text](https://public-static-files.s3-us-west-1.amazonaws.com/real_teaser_2.png) 3 | 4 | ## Authors 5 | [Teerapat Jenrungrot](https://mjenrungrot.com/)**\***, [Vivek Jayaram](http://www.vivekjayaram.com/)**\***, [Steve Seitz](https://homes.cs.washington.edu/~seitz/), and [Ira Kemelmacher-Shlizerman](https://homes.cs.washington.edu/~kemelmi/)
6 | *\*Co-First Authors*
7 | University of Washington 8 | 9 | ## [Project Page](http://grail.cs.washington.edu/projects/cone-of-silence/) 10 | Video and audio demos are available at the project page. 11 | 12 | ### [Paper](https://arxiv.org/pdf/2010.06007.pdf) 13 | 34th Conference on Neural Information Processing Systems, NeurIPS 2020. (Oral) 14 | 15 | ### Blog Post - Coming Soon 16 | 17 | ### Summary 18 | Our method performs source separation and localization for human speakers. Key features include handling an arbitrary number of speakers and moving speakers with a single network. This code allows you to run and evaluate our method on synthetically rendered data. If you have a multi-microphone array, you can also obtain real results like the ones in our demo video. 19 | 20 | ## Getting Started 21 | Clone the repository: 22 | ``` 23 | git clone https://github.com/vivjay30/Cone-of-Silence.git 24 | cd Cone-of-Silence 25 | export PYTHONPATH=$PYTHONPATH:`pwd` 26 | ``` 27 | 28 | Make sure all the requirements in requirements.txt are installed. We tested the code with torch 1.3.0, librosa 0.7.0, and CUDA 10.0. 29 | 30 | Download Pretrained Models: [Here](https://drive.google.com/drive/folders/1YeuHPvqmaPMGvcSOb9J-hnLDYSbK1S2c?usp=sharing). If you're working in a command-line environment, we recommend using [gdown](https://github.com/wkentaro/gdown) to download the checkpoint files. 31 | 32 | ``` 33 | cd checkpoints 34 | gdown --id 1OcLxp0s_TN78iKaFrLAqjIoTKeOTUgKw # Download realdata_4mics_.03231m_44100kHz.pt 35 | gdown --id 18dpUnng_8ZUlDrQsg5VymypFnFlQBPIp # Download synthetic_6mics_.0725m_44100kHz.pt 36 | ``` 37 | 38 | ## Quickstart: Running on Real Data 39 | You can easily produce results like those in our demo videos. Our pre-trained real models work with the 4-mic [Seeed ReSpeaker Mic Array v2.0](https://wiki.seeedstudio.com/ReSpeaker_Mic_Array_v2.0/). We even provide a sample 4-channel file for you to run [Here](https://drive.google.com/drive/folders/1YeuHPvqmaPMGvcSOb9J-hnLDYSbK1S2c?usp=sharing). When you capture the data, it must be an m-channel recording, where m is the number of microphones. Run the full command as shown below. For moving sources, reduce the duration flag to 1.5 and add `--moving` to stop the search at a coarse window. 40 | ``` 41 | python cos/inference/separation_by_localization.py \ 42 | /path/to/model.pt \ 43 | /path/to/input_file.wav \ 44 | outputs/some_dirname/ \ 45 | --n_channels 4 \ 46 | --sr 44100 \ 47 | --mic_radius .03231 \ 48 | --use_cuda 49 | ``` 50 | 51 | ## Rendering Synthetic Spatial Data 52 | For training and evaluation, we use synthetically rendered spatial data. We place the voices in a virtual room and render the arrival times, level differences, and reverb using pyroomacoustics. We used the VCTK dataset, but any voice dataset would work. An example command is below: 53 | ``` 54 | python cos/generate_dataset.py \ 55 | /path/to/VCTK/data \ 56 | ./outputs/somename \ 57 | --input_background_path any_bg_audio.wav \ 58 | --n_voices 2 \ 59 | --n_outputs 1000 \ 60 | --mic_radius {radius} \ 61 | --n_mics {M} 62 | ``` 63 | 64 | ## Training on Synthetic Data 65 | Below is an example command to train on the rendered data. You need to replace the training and testing dirs with the path to the generated datasets from above. We highly recommend initializing with a pre-trained model (even if the number of mics is different) and not training from scratch.
66 | ``` 67 | python cos/training/train.py \ 68 | ./generated/train_dir \ 69 | ./generated/test_dir \ 70 | --name experiment_name \ 71 | --checkpoints_dir ./checkpoints \ 72 | --pretrain_path ./path/to/pretrained.pt \ 73 | --batch_size 8 \ 74 | --mic_radius {radius} \ 75 | --n_mics {M} \ 76 | --use_cuda 77 | ``` 78 | __Note__: The training code expects you to have `sox` installed. The easiest way to install it is using conda as follows: `conda install -c conda-forge -y sox`. 79 | 80 | ## Training on Real Data 81 | For those looking to improve on the pretrained models, we recommend gathering a lot more real data. We did not have the ability to gather very accurately positioned real data in a proper sound chamber. By training with a lot more real data, the results will almost certainly improve. All you have to do is create synthetic composites of speakers in the same format as the synthetic data, and run the same training script. 82 | 83 | ## Evaluation 84 | For the synthetic data and evaluation, we use a setup of 6 mics in a circle of radius 7.25 cm. The following are instructions to obtain results on mixtures of N voices and no backgrounds. First generate a synthetic dataset with the microphone setup specified previously with ```--n_voices 8``` from the test set of VCTK. Then run the following script: 85 | 86 | ``` 87 | python cos/inference/evaluate_synthetic.py \ 88 | /path/to/rendered_data/ \ 89 | /path/to/model.pt \ 90 | --n_channels 6 \ 91 | --mic_radius .0725 \ 92 | --sr 44100 \ 93 | --use_cuda \ 94 | --n_workers 1 \ 95 | --n_voices {N} 96 | ``` 97 | 98 | Add ```--prec_recall``` separately to get the precision and recall. 99 | 100 | | Number of Speakers N | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 101 | |----------------------|-------|-------|-------|-------|-------|-------|-------| 102 | | Median SI-SDRi (dB) | 13.9 | 13.2 | 12.2 | 10.8 | 9.1 | 7.2 | 6.3 | 103 | | Median Angular Error (°) | 2.0 | 2.3 | 2.7 | 3.5 | 4.4 | 5.2 | 6.3 | 104 | | Precision | 0.947 | 0.936 | 0.897 | 0.912 | 0.932 | 0.936 | 0.966 | 105 | | Recall | 0.979 | 0.972 | 0.915 | 0.898 | 0.859 | 0.825 | 0.785 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /cos/helpers/mwf.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import multiprocessing.dummy as mp 5 | import os 6 | 7 | from pathlib import Path 8 | 9 | import librosa 10 | import numpy as np 11 | import tqdm 12 | 13 | from scipy.signal import stft, istft 14 | 15 | from cos.helpers.eval_utils import compute_sdr 16 | from cos.inference.evaluate_synthetic import get_items 17 | from cos.helpers.utils import check_valid_dir 18 | 19 | 20 | def invert(M,eps): 21 | """Inverting matrices M (matrices are the two last dimensions).
22 | This is assuming that these are 2x2 matrices, using the explicit 23 | inversion formula available in that case.""" 24 | invDet = 1.0/(eps + M[...,0,0]*M[...,1,1] - M[...,0,1]*M[...,1,0]) 25 | invM = np.zeros(M.shape,dtype='complex') 26 | invM[...,0,0] = invDet*M[...,1,1] 27 | invM[...,1,0] = -invDet*M[...,1,0] 28 | invM[...,0,1] = -invDet*M[...,0,1] 29 | invM[...,1,1] = invDet*M[...,0,0] 30 | return invM 31 | 32 | 33 | def compute_mwf(gt, mix): 34 | """ 35 | Computes the Ideal Ratio Mask SI-SDR 36 | gt: (n_voices, n_channels, t) 37 | mix: (n_channels, t) 38 | """ 39 | n_voices = gt.shape[0] 40 | nfft = 2048 41 | hop = 1024 42 | eps = np.finfo(np.float).eps 43 | 44 | N = mix.shape[-1] # number of samples 45 | X = stft(mix, nperseg=nfft)[2] 46 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame) 47 | 48 | # Allocate variables P: PSD, R: Spatial Covarianc Matrices 49 | P = [] 50 | R = [] 51 | for gt_idx in range(n_voices): 52 | # compute STFT of target source 53 | Yj = stft(gt[gt_idx], nperseg=nfft)[2] 54 | 55 | # Learn Power Spectral Density and spatial covariance matrix 56 | #----------------------------------------------------------- 57 | 58 | # 1/ compute observed covariance for source 59 | Rjj = np.zeros((F,T,I,I), dtype='complex') 60 | for (i1,i2) in itertools.product(range(I),range(I)): 61 | Rjj[...,i1,i2] = Yj[i1,...]*np.conj(Yj[i2,...]) 62 | 63 | # 2/ compute first naive estimate of the source spectrogram as the 64 | # average of spectrogram over channels 65 | P.append(np.mean(np.abs(Yj)**2,axis=0)) 66 | 67 | # 3/ take the spatial covariance matrix as the average of 68 | # the observed Rjj weighted Rjj by 1/Pj. This is because the 69 | # covariance is modeled as Pj Rj 70 | R.append(np.mean(Rjj / (eps+P[-1][...,None,None]), axis = 1)) 71 | 72 | # add some regularization to this estimate: normalize and add small 73 | # identify matrix, so we are sure it behaves well numerically. 74 | R[-1] = R[-1] * I/ np.trace(R[-1]) + eps * np.tile(np.eye(I,dtype='complex64')[None,...],(F,1,1)) 75 | 76 | # 4/ Now refine the power spectral density estimate. This is to better 77 | # estimate the PSD in case the source has some correlations between 78 | # channels. 79 | 80 | # invert Rj 81 | Rj_inv = invert(R[-1],eps) 82 | 83 | # now compute the PSD 84 | P[-1]=0 85 | for (i1,i2) in itertools.product(range(I),range(I)): 86 | P[-1] += 1./I*np.real(Rj_inv[:,i1,i2][:,None]*Rjj[...,i2,i1]) 87 | 88 | # All parameters are estimated. compute the mix covariance matrix as 89 | # the sum of the sources covariances. 90 | Cxx = 0 91 | for gt_idx in range(n_voices): 92 | Cxx += P[gt_idx][...,None,None]*R[gt_idx][:,None,...] 93 | # we need its inverse for computing the Wiener filter 94 | invCxx = invert(Cxx,eps) 95 | 96 | # perform separation 97 | estimates = [] 98 | for gt_idx in range(n_voices): 99 | # computes multichannel Wiener gain as Pj Rj invCxx 100 | G = np.zeros(invCxx.shape,dtype='complex64') 101 | SR = P[gt_idx][...,None,None]*R[gt_idx][:,None,...] 102 | for (i1,i2,i3) in itertools.product(range(I),range(I),range(I)): 103 | G[...,i1,i2] += SR[...,i1,i3]*invCxx[...,i3,i2] 104 | SR = 0 #free memory 105 | 106 | # separates by (matrix-)multiplying this gain with the mix. 
107 | Yj=0 108 | for i in range(I): 109 | Yj+=G[...,i]*X[i,...,None] 110 | Yj = np.rollaxis(Yj,-1) #gets channels back in first position 111 | 112 | # inverte to time domain 113 | target_estimate = istft(Yj)[1][:,:N] 114 | 115 | estimates.append(target_estimate) 116 | 117 | estimates = np.array(estimates) # (nvoice, 6, 6*sr) 118 | # eval 119 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr) 120 | eval_gt = gt # (nvoice, 6, 6*sr) 121 | eval_est = estimates 122 | 123 | SDR_in = [] 124 | SDR_out = [] 125 | for i in range(n_voices): 126 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=False)) # scalar 127 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=False)) # scalar 128 | 129 | output = np.array([SDR_in, SDR_out]) # (2, nvoice) 130 | return output 131 | 132 | 133 | def main(args): 134 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*'))) 135 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)] 136 | 137 | all_input_sdr = [0] * len(all_dirs) 138 | all_output_sdr = [0] * len(all_dirs) 139 | 140 | def evaluate_dir(idx): 141 | curr_dir = all_dirs[idx] 142 | # Loads the data 143 | mixed_data, gt = get_items(curr_dir, args) 144 | gt = np.array([x.data for x in gt]) 145 | output = compute_mwf(gt, mixed_data) 146 | all_input_sdr[idx] = output[0] 147 | all_output_sdr[idx] = output[1] 148 | 149 | pool = mp.Pool(args.n_workers) 150 | with tqdm.tqdm(total=len(all_dirs)) as pbar: 151 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))): 152 | pbar.update() 153 | 154 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs)) 155 | pool.close() 156 | pool.join() 157 | 158 | print("Median SI-SDRi: ", 159 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten())) 160 | 161 | np.save("MWF_{}voices_{}kHz.npy".format(args.n_voices, args.sr), 162 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()])) 163 | 164 | 165 | if __name__ == '__main__': 166 | parser = argparse.ArgumentParser() 167 | parser.add_argument('input_dir', type=str, help="Path to the input dir") 168 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate") 169 | parser.add_argument('--n_channels', 170 | type=int, 171 | default=2, 172 | help="Number of channels") 173 | parser.add_argument('--n_workers', 174 | type=int, 175 | default=8, 176 | help="Number of parallel workers") 177 | parser.add_argument('--n_voices', 178 | type=int, 179 | default=2, 180 | help="Number of voices in the dataset") 181 | parser.add_argument('--alpha', 182 | type=int, 183 | default=1, 184 | help="See the original SigSep code for an explanation") 185 | args = parser.parse_args() 186 | 187 | main(args) 188 | -------------------------------------------------------------------------------- /cos/training/synthetic_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Torch dataset object for synthetically rendered 3 | spatial data 4 | """ 5 | import json 6 | import random 7 | 8 | from typing import Tuple 9 | from pathlib import Path 10 | 11 | import torch 12 | import numpy as np 13 | import librosa 14 | 15 | import cos.helpers.utils as utils 16 | from cos.helpers.constants import ALL_WINDOW_SIZES, FAR_FIELD_RADIUS 17 | from cos.training.data_augmentation import RandomAudioPerturbation 18 | 19 | 20 | class SyntheticDataset(torch.utils.data.Dataset): 21 | """ 22 | Synthetic Dataset of mixed waveforms 
and their corresponding ground truth waveforms 23 | recorded at different microphone. 24 | 25 | Data format is a pair of Tensors containing mixed waveforms and 26 | ground truth waveforms respectively. The tensor's dimension is formatted 27 | as (n_microphone, duration). 28 | """ 29 | def __init__(self, input_dir, n_mics=6, sr=44100, perturb_prob=0.0, 30 | window_idx=-1, negatives=0.2, mic_radius=0.0725): 31 | super().__init__() 32 | self.dirs = sorted(list(Path(input_dir).glob('[0-9]*'))) 33 | 34 | # Physical params 35 | self.n_mics = n_mics 36 | self.sr = sr 37 | self.mic_radius = mic_radius 38 | 39 | # Data augmentation 40 | self.perturb_prob = perturb_prob 41 | 42 | # Training params 43 | self.negatives = negatives # Fraction of negatives in training 44 | self.window_idx = window_idx # Set to -1 to pick randomly 45 | 46 | def __len__(self) -> int: 47 | return len(self.dirs) 48 | 49 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: 50 | """ 51 | Returns: 52 | mixed_data - M x T 53 | target_voice_data - M x T 54 | window_idx_one_hot - 1-D 55 | """ 56 | num_windows = len(ALL_WINDOW_SIZES) 57 | if self.window_idx == -1: 58 | curr_window_idx = np.random.randint(0, 5) 59 | else: 60 | curr_window_idx = self.window_idx 61 | 62 | curr_window_size = ALL_WINDOW_SIZES[curr_window_idx] 63 | candidate_angles = utils.get_starting_angles(curr_window_size) 64 | 65 | curr_dir = self.dirs[idx] 66 | 67 | # Get metadata 68 | with open(Path(curr_dir) / 'metadata.json') as json_file: 69 | metadata = json.load(json_file) 70 | 71 | # Random split of negatives and positives 72 | if np.random.uniform() < self.negatives: 73 | target_angle = self.get_negative_region(metadata, candidate_angles) 74 | else: 75 | target_angle = get_positive_region(metadata, candidate_angles) 76 | 77 | all_sources, target_voice_data = self.get_mixture_and_gt( 78 | metadata, curr_dir, target_angle, curr_window_size) 79 | 80 | # Mixture 81 | all_sources = torch.stack(all_sources, dim=0) 82 | mixed_data = torch.sum(all_sources, dim=0) 83 | 84 | # GTs 85 | target_voice_data = torch.stack(target_voice_data, dim=0) 86 | target_voice_data = torch.sum(target_voice_data, dim=0) 87 | 88 | window_idx_one_hot = torch.tensor( 89 | utils.to_categorical(curr_window_idx, num_windows)).float() 90 | 91 | return (mixed_data, target_voice_data, window_idx_one_hot) 92 | 93 | def get_negative_region(self, metadata, candidate_angles): 94 | """Chooses a target angle which is adjacent to a voice region""" 95 | # Choose a random voice 96 | voice_keys = [x for x in metadata if "voice" in x] 97 | random_key = random.choice(voice_keys) 98 | voice_pos = np.array(metadata[random_key]["position"]) 99 | voice_angle = np.arctan2(voice_pos[1], voice_pos[0]) 100 | angle_idx = (np.abs(candidate_angles - voice_angle)).argmin() 101 | 102 | # Non uniform distribution to prefer regions close to a voice 103 | p = np.zeros_like(candidate_angles) 104 | for i in range(p.shape[0]): 105 | if i == angle_idx: 106 | # Can't choose the positive region 107 | p[i] = 0 108 | else: 109 | # Regions close to the voice are weighted more 110 | dist = min(abs(i - angle_idx), 111 | (len(candidate_angles) - angle_idx + i)) 112 | p[i] = 1 / (dist) 113 | 114 | p /= p.sum() 115 | 116 | # Make sure we choose a region with different per-channel shifts from the voice 117 | matching_shift = True 118 | _, true_shift = utils.shift_mixture(np.zeros( 119 | (self.n_mics, 10)), voice_pos, self.mic_radius, self.sr) 120 | while matching_shift: 121 | target_angle = 
np.random.choice(candidate_angles, p=p) 122 | random_pos = np.array([ 123 | FAR_FIELD_RADIUS * np.cos(target_angle), 124 | FAR_FIELD_RADIUS * np.sin(target_angle) 125 | ]) 126 | _, curr_shift = utils.shift_mixture(np.zeros( 127 | (self.n_mics, 10)), random_pos, self.mic_radius, self.sr) 128 | if true_shift != curr_shift: 129 | matching_shift = False 130 | 131 | return target_angle 132 | 133 | def get_mixture_and_gt(self, metadata, curr_dir, target_angle, 134 | curr_window_size): 135 | """ 136 | Given a target angle and window size, this function figures out 137 | the voices inside the region and returns them as GT waveforms 138 | """ 139 | target_pos = np.array([ 140 | FAR_FIELD_RADIUS * np.cos(target_angle), 141 | FAR_FIELD_RADIUS * np.sin(target_angle) 142 | ]) 143 | random_perturb = RandomAudioPerturbation() 144 | 145 | # Iterate over different sources 146 | all_sources = [] 147 | target_voice_data = [] 148 | for key in metadata.keys(): 149 | gt_audio_files = sorted( 150 | list(Path(curr_dir).rglob("*" + key + ".wav"))) 151 | assert len(gt_audio_files) > 0, "No files found in {}".format( 152 | curr_dir) 153 | gt_waveforms = [] 154 | 155 | # Iterate over different mics 156 | for _, gt_audio_file in enumerate(gt_audio_files): 157 | gt_waveform, _ = librosa.core.load(gt_audio_file, self.sr, 158 | mono=True) 159 | gt_waveforms.append(torch.from_numpy(gt_waveform)) 160 | 161 | shifted_gt, _ = utils.shift_mixture(np.stack(gt_waveforms), 162 | target_pos, 163 | self.mic_radius, self.sr) 164 | 165 | # Data augmentation 166 | if np.random.uniform() < self.perturb_prob: 167 | perturbed_source = torch.tensor( 168 | random_perturb(shifted_gt)).float() 169 | else: 170 | perturbed_source = torch.tensor(shifted_gt).float() 171 | 172 | all_sources.append(perturbed_source) 173 | 174 | # Check which foregrounds are in the angle of interest 175 | if "bg" in key: 176 | continue 177 | 178 | locs_voice = metadata[key]['position'] 179 | voice_angle = np.arctan2(locs_voice[1], locs_voice[0]) 180 | 181 | # Voice is inside our target area. Need to save for ground truth 182 | if abs(voice_angle - target_angle) < (curr_window_size / 2): 183 | target_voice_data.append( 184 | perturbed_source.view(perturbed_source.shape[0], 185 | perturbed_source.shape[1])) 186 | 187 | # Train with front back confusion for 2 mics 188 | elif self.n_mics == 2 and abs(-voice_angle - target_angle) < ( 189 | curr_window_size / 2): 190 | target_voice_data.append( 191 | perturbed_source.view(perturbed_source.shape[0], 192 | perturbed_source.shape[1])) 193 | 194 | # Voice is not within our region. 
Add silence 195 | else: 196 | target_voice_data.append( 197 | torch.zeros((perturbed_source.shape[0], 198 | perturbed_source.shape[1]))) 199 | 200 | return all_sources, target_voice_data 201 | 202 | 203 | def get_positive_region(metadata, candidate_angles): 204 | """Chooses a target angle containing a voice region""" 205 | # Choose a random voice 206 | voice_keys = [x for x in metadata if "voice" in x] 207 | random_key = random.choice(voice_keys) 208 | voice_pos = metadata[random_key]["position"] 209 | voice_pos = np.array(voice_pos) 210 | voice_angle = np.arctan2(voice_pos[1], voice_pos[0]) 211 | 212 | # Get the sector closest to that voice 213 | angle_idx = (np.abs(candidate_angles - voice_angle)).argmin() 214 | target_angle = candidate_angles[angle_idx] 215 | 216 | return target_angle 217 | -------------------------------------------------------------------------------- /cos/training/network.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def rescale_conv(conv, reference): 8 | """ 9 | Rescale a convolutional module with `reference`. 10 | """ 11 | std = conv.weight.std().detach() 12 | scale = (std / reference)**0.5 13 | conv.weight.data /= scale 14 | if conv.bias is not None: 15 | conv.bias.data /= scale 16 | 17 | 18 | def rescale_module(module, reference): 19 | """ 20 | Rescale a module with `reference`. 21 | """ 22 | for sub in module.modules(): 23 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)): 24 | rescale_conv(sub, reference) 25 | 26 | 27 | def center_trim(tensor, reference): 28 | """ 29 | Trim a tensor to match with the dimension of `reference`. 30 | """ 31 | if hasattr(reference, "size"): 32 | reference = reference.size(-1) 33 | diff = tensor.size(-1) - reference 34 | if diff < 0: 35 | raise ValueError("tensor must be larger than reference") 36 | if diff: 37 | tensor = tensor[..., diff // 2:-(diff - diff // 2)] 38 | return tensor 39 | 40 | 41 | def left_trim(tensor, reference): 42 | """ 43 | Trim a tensor to match with the dimension of `reference`. Trims only the end. 44 | """ 45 | if hasattr(reference, "size"): 46 | reference = reference.size(-1) 47 | diff = tensor.size(-1) - reference 48 | if diff < 0: 49 | raise ValueError("tensor must be larger than reference") 50 | if diff: 51 | tensor = tensor[..., 0:-diff] 52 | return tensor 53 | 54 | def normalize_input(data): 55 | """ 56 | Normalizes the input to have mean 0 std 1 for each input 57 | Inputs: 58 | data - torch.tensor of size batch x n_mics x n_samples 59 | """ 60 | data = (data * 2**15).round() / 2**15 61 | ref = data.mean(1) # Average across the n microphones 62 | means = ref.mean(1).unsqueeze(1).unsqueeze(2) 63 | stds = ref.std(1).unsqueeze(1).unsqueeze(2) 64 | data = (data - means) / stds 65 | 66 | return data, means, stds 67 | 68 | def unnormalize_input(data, means, stds): 69 | """ 70 | Unnormalizes the step done in the previous function 71 | """ 72 | data = (data * stds.unsqueeze(3) + means.unsqueeze(3)) 73 | return data 74 | 75 | 76 | class CoSNetwork(nn.Module): 77 | """ 78 | Cone of Silence network based on the Demucs network for audio source separation. 
79 | """ 80 | def __init__( 81 | self, 82 | n_audio_channels: int = 4, # pylint: disable=redefined-outer-name 83 | window_conditioning_size: int = 5, 84 | kernel_size: int = 8, 85 | stride: int = 4, 86 | context: int = 3, 87 | depth: int = 6, 88 | channels: int = 64, 89 | growth: float = 2.0, 90 | lstm_layers: int = 2, 91 | rescale: float = 0.1): # pylint: disable=redefined-outer-name 92 | super().__init__() 93 | self.n_audio_channels = n_audio_channels 94 | self.window_conditioning_size = window_conditioning_size 95 | self.kernel_size = kernel_size 96 | self.stride = stride 97 | self.context = context 98 | self.depth = depth 99 | self.channels = channels 100 | self.growth = growth 101 | self.lstm_layers = lstm_layers 102 | self.rescale = rescale 103 | 104 | self.encoder = nn.ModuleList() # Source encoder 105 | self.decoder = nn.ModuleList() # Audio output decoder 106 | 107 | activation = nn.GLU(dim=1) 108 | 109 | in_channels = n_audio_channels # Number of input channels 110 | 111 | # Wave U-Net structure 112 | for index in range(depth): 113 | encode = nn.ModuleDict() 114 | encode["conv1"] = nn.Conv1d(in_channels, channels, kernel_size, 115 | stride) 116 | encode["relu"] = nn.ReLU() 117 | 118 | encode["conv2"] = nn.Conv1d(channels, 2 * channels, 1) 119 | encode["activation"] = activation 120 | 121 | encode["gc_embed1"] = nn.Conv1d(self.window_conditioning_size, channels, 1) 122 | encode["gc_embed2"] = nn.Conv1d(self.window_conditioning_size, 2 * channels, 1) 123 | 124 | self.encoder.append(encode) 125 | 126 | decode = nn.ModuleDict() 127 | if index > 0: 128 | out_channels = in_channels 129 | else: 130 | out_channels = 2 * n_audio_channels 131 | 132 | decode["conv1"] = nn.Conv1d(channels, 2 * channels, context) 133 | decode["activation"] = activation 134 | decode["conv2"] = nn.ConvTranspose1d(channels, out_channels, 135 | kernel_size, stride) 136 | 137 | decode["gc_embed1"] = nn.Conv1d(self.window_conditioning_size, 2 * channels, 1) 138 | decode["gc_embed2"] = nn.Conv1d(self.window_conditioning_size, out_channels, 1) 139 | 140 | if index > 0: 141 | decode["relu"] = nn.ReLU() 142 | self.decoder.insert(0, 143 | decode) # Put it at the front, reverse order 144 | 145 | in_channels = channels 146 | channels = int(growth * channels) 147 | 148 | # Bi-directional LSTM for the bottleneck layer 149 | channels = in_channels 150 | self.lstm = nn.LSTM(bidirectional=True, 151 | num_layers=lstm_layers, 152 | hidden_size=channels, 153 | input_size=channels) 154 | self.lstm_linear = nn.Linear(2 * channels, channels) 155 | 156 | rescale_module(self, reference=rescale) 157 | 158 | def forward(self, mix: torch.Tensor, angle_conditioning: torch.Tensor): # pylint: disable=arguments-differ 159 | """ 160 | Forward pass. Note that in our current work the use of `locs` is disregarded. 161 | 162 | Args: 163 | mix (torch.Tensor) - An input recording of size `(batch_size, n_mics, time)`. 
164 | 165 | Output: 166 | x - A source separation output at every microphone 167 | """ 168 | x = mix 169 | saved = [x] 170 | 171 | # Encoder 172 | for encode in self.encoder: 173 | x = encode["conv1"](x) # Conv 1d 174 | embedding = encode["gc_embed1"](angle_conditioning.unsqueeze(2)) 175 | 176 | x = encode["relu"](x + embedding) 177 | x = encode["conv2"](x) 178 | 179 | embedding2 = encode["gc_embed2"](angle_conditioning.unsqueeze(2)) 180 | x = encode["activation"](x + embedding2) 181 | saved.append(x) 182 | 183 | # Bi-directional LSTM at the bottleneck layer 184 | x = x.permute(2, 0, 1) # prep input for LSTM 185 | self.lstm.flatten_parameters() # to improve memory usage. 186 | x = self.lstm(x)[0] 187 | x = self.lstm_linear(x) 188 | x = x.permute(1, 2, 0) 189 | 190 | # Source decoder 191 | for decode in self.decoder: 192 | skip = center_trim(saved.pop(-1), x) 193 | x = x + skip 194 | 195 | x = decode["conv1"](x) 196 | embedding = decode["gc_embed1"](angle_conditioning.unsqueeze(2)) 197 | x = decode["activation"](x + embedding) 198 | x = decode["conv2"](x) 199 | embedding2 = decode["gc_embed2"](angle_conditioning.unsqueeze(2)) 200 | if "relu" in decode: 201 | x = decode["relu"](x + embedding2) 202 | 203 | # Reformat the output 204 | x = x.view(x.size(0), 2, self.n_audio_channels, x.size(-1)) 205 | 206 | return x 207 | 208 | def loss(self, voice_signals, gt_voice_signals): 209 | """Simple L1 loss between voice and gt""" 210 | return F.l1_loss(voice_signals, gt_voice_signals) 211 | 212 | def valid_length(self, length: int) -> int: # pylint: disable=redefined-outer-name 213 | """ 214 | Find the length of the input to the network such that the output's length is 215 | equal to the given `length`. 216 | """ 217 | for _ in range(self.depth): 218 | length = math.ceil((length - self.kernel_size) / self.stride) + 1 219 | length = max(1, length) 220 | length += self.context - 1 221 | 222 | for _ in range(self.depth): 223 | length = (length - 1) * self.stride + self.kernel_size 224 | 225 | return int(length) 226 | 227 | 228 | def load_pretrain(model, state_dict): # pylint: disable=redefined-outer-name 229 | """Loads the pretrained keys in state_dict into model""" 230 | for key in state_dict.keys(): 231 | try: 232 | _ = model.load_state_dict({key: state_dict[key]}, strict=False) 233 | print("Loaded {} (shape = {}) from the pretrained model".format( 234 | key, state_dict[key].shape)) 235 | except Exception as e: 236 | print("Failed to load {}".format(key)) 237 | print(e) 238 | -------------------------------------------------------------------------------- /cos/generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import argparse 5 | import json 6 | from typing import List 7 | from pathlib import Path 8 | import tqdm 9 | import random 10 | 11 | import multiprocessing.dummy as mp 12 | 13 | import numpy as np 14 | import librosa 15 | import pyroomacoustics as pra 16 | import soundfile as sf 17 | 18 | # Mean and STD of the signal peak 19 | FG_VOL_MIN = 0.15 20 | FG_VOL_MAX = 0.4 21 | 22 | BG_VOL_MIN = 0.2 23 | BG_VOL_MAX = 0.5 24 | 25 | 26 | def generate_mic_array(room, mic_radius: float, n_mics: int): 27 | """ 28 | Generate a list of Microphone objects 29 | 30 | Radius = 50th percentile of men Bitragion breadth 31 | (https://en.wikipedia.org/wiki/Human_head) 32 | """ 33 | R = pra.circular_2D_array(center=[0., 0.], M=n_mics, phi0=0, radius=mic_radius) 34 | room.add_microphone_array(pra.MicrophoneArray(R, room.fs)) 35 | 36 | 
37 | def handle_error(e): 38 | print(e) 39 | 40 | 41 | def get_voices(args): 42 | # Make sure we dont get an empty sequence 43 | success = False 44 | while not success: 45 | voice_files = random.sample(args.all_voices, args.n_voices) 46 | # Save the identity also. This is VCTK specific 47 | success = True 48 | voices_data = [] 49 | for voice_file in voice_files: 50 | voice_identity = str(voice_file).split("/")[-1].split("_")[0] 51 | voice, _ = librosa.core.load(voice_file, sr=args.sr, mono=True) 52 | voice, _ = librosa.effects.trim(voice) 53 | if voice.std() == 0: 54 | success = False 55 | voices_data.append((voice, voice_identity)) 56 | 57 | return voices_data 58 | 59 | 60 | def generate_sample(args: argparse.Namespace, bg: np.ndarray, idx: int) -> int: 61 | """ 62 | Generate a single sample. Return 0 on success. 63 | 64 | Steps: 65 | - [1] Load voice 66 | - [2] Sample background with the same length as voice. 67 | - [3] Pick background location 68 | - [4] Create a scene 69 | - [5] Render sound 70 | - [6] Save metadata 71 | """ 72 | # [1] load voice 73 | output_prefix_dir = os.path.join(args.output_path, '{:05d}'.format(idx)) 74 | Path(output_prefix_dir).mkdir(parents=True, exist_ok=True) 75 | 76 | voices_data = get_voices(args) 77 | 78 | # [2] 79 | total_samples = int(args.duration * args.sr) 80 | if bg is not None: 81 | bg_length = len(bg) 82 | bg_start_idx = np.random.randint(bg_length - total_samples) 83 | sample_bg = bg[bg_start_idx:bg_start_idx + total_samples] 84 | 85 | # Generate room parameters, each scene has a random room and absorption 86 | left_wall = np.random.uniform(low=-20, high=-15) 87 | right_wall = np.random.uniform(low=15, high=20) 88 | top_wall = np.random.uniform(low=15, high=20) 89 | bottom_wall = np.random.uniform(low=-20, high=-15) 90 | absorption = np.random.uniform(low=0.1, high=0.99) 91 | corners = np.array([[left_wall, bottom_wall], [left_wall, top_wall], 92 | [ right_wall, top_wall], [right_wall, bottom_wall]]).T 93 | 94 | # FG 95 | all_fg_signals = [] 96 | voice_positions = [] 97 | for voice_idx in range(args.n_voices): 98 | # Need to re-generate room to save GT. 
Could probably be optimized 99 | room = pra.Room.from_corners(corners, 100 | fs=args.sr, 101 | max_order=10, 102 | absorption=absorption) 103 | mic_array = generate_mic_array(room, args.mic_radius, args.n_mics) 104 | 105 | voice_radius = np.random.uniform(low=1.0, high=5.0) 106 | voice_theta = np.random.uniform(low=0, high=2 * np.pi) 107 | voice_loc = [ 108 | voice_radius * np.cos(voice_theta), 109 | voice_radius * np.sin(voice_theta) 110 | ] 111 | 112 | voice_positions.append(voice_loc) 113 | room.add_source(voice_loc, signal=voices_data[voice_idx][0]) 114 | 115 | room.image_source_model(use_libroom=True) 116 | room.simulate() 117 | fg_signals = room.mic_array.signals[:, :total_samples] 118 | fg_target = np.random.uniform(FG_VOL_MIN, FG_VOL_MAX) 119 | fg_signals = fg_signals * fg_target / abs(fg_signals).max() 120 | all_fg_signals.append(fg_signals) 121 | 122 | # BG 123 | if bg is not None: 124 | bg_radius = np.random.uniform(low=10.0, high=20.0) 125 | bg_theta = np.random.uniform(low=0, high=2 * np.pi) 126 | bg_loc = [bg_radius * np.cos(bg_theta), bg_radius * np.sin(bg_theta)] 127 | 128 | # Bg should be further away to be diffuse 129 | left_wall = np.random.uniform(low=-40, high=-20) 130 | right_wall = np.random.uniform(low=20, high=40) 131 | top_wall = np.random.uniform(low=20, high=40) 132 | bottom_wall = np.random.uniform(low=-40, high=-20) 133 | corners = np.array([[left_wall, bottom_wall], [left_wall, top_wall], 134 | [ right_wall, top_wall], [right_wall, bottom_wall]]).T 135 | absorption = np.random.uniform(low=0.5, high=0.99) 136 | room = pra.Room.from_corners(corners, 137 | fs=args.sr, 138 | max_order=10, 139 | absorption=absorption) 140 | mic_array = generate_mic_array(room, args.mic_radius, args.n_mics) 141 | room.add_source(bg_loc, signal=sample_bg) 142 | 143 | room.image_source_model(use_libroom=True) 144 | room.simulate() 145 | bg_signals = room.mic_array.signals[:, :total_samples] 146 | bg_target = np.random.uniform(BG_VOL_MIN, BG_VOL_MAX) 147 | bg_signals = bg_signals * bg_target / abs(bg_signals).max() 148 | 149 | # Save 150 | for mic_idx in range(args.n_mics): 151 | output_prefix = str( 152 | Path(output_prefix_dir) / "mic{:02d}_".format(mic_idx)) 153 | 154 | # Save FG 155 | all_fg_buffer = np.zeros((total_samples)) 156 | for voice_idx in range(args.n_voices): 157 | curr_fg_buffer = np.pad(all_fg_signals[voice_idx][mic_idx], 158 | (0, total_samples))[:total_samples] 159 | sf.write(output_prefix + "voice{:02d}.wav".format(voice_idx), 160 | curr_fg_buffer, args.sr) 161 | all_fg_buffer += curr_fg_buffer 162 | 163 | if bg is not None: 164 | bg_buffer = np.pad(bg_signals[mic_idx], 165 | (0, total_samples))[:total_samples] 166 | sf.write(output_prefix + "bg.wav", bg_buffer, args.sr) 167 | 168 | sf.write(output_prefix + "mixed.wav", all_fg_buffer + bg_buffer, 169 | args.sr) 170 | else: 171 | sf.write(output_prefix + "mixed.wav", all_fg_buffer, 172 | args.sr) 173 | 174 | # [6] 175 | metadata = {} 176 | for voice_idx in range(args.n_voices): 177 | metadata['voice{:02d}'.format(voice_idx)] = { 178 | 'position': voice_positions[voice_idx], 179 | 'speaker_id': voices_data[voice_idx][1] 180 | } 181 | 182 | if bg is not None: 183 | metadata['bg'] = {'position': bg_loc} 184 | 185 | metadata_file = str(Path(output_prefix_dir) / "metadata.json") 186 | with open(metadata_file, "w") as f: 187 | json.dump(metadata, f, indent=4) 188 | 189 | 190 | def main(args: argparse.Namespace): 191 | np.random.seed(args.seed) 192 | 193 | # Preload background to save time 194 | if 
args.input_background_path: 195 | background, _ = librosa.core.load(args.input_background_path, 196 | sr=args.sr, 197 | mono=True) 198 | else: 199 | background = None 200 | 201 | all_voices = Path(args.input_voice_dir).rglob('*.wav') 202 | args.all_voices = list(all_voices) 203 | if len(args.all_voices) == 0: 204 | raise ValueError("No voice files found") 205 | 206 | pbar = tqdm.tqdm(total=args.n_outputs) 207 | pool = mp.Pool(args.n_workers) 208 | callback_fn = lambda _: pbar.update() 209 | for i in range(args.n_outputs): 210 | pool.apply_async(generate_sample, 211 | args=(args, background, i), 212 | callback=callback_fn, 213 | error_callback=handle_error) 214 | pool.close() 215 | pool.join() 216 | pbar.close() 217 | 218 | 219 | if __name__ == '__main__': 220 | parser = argparse.ArgumentParser() 221 | parser.add_argument('input_voice_dir', 222 | type=str, 223 | help="Directory with voice wav files") 224 | parser.add_argument('output_path', type=str, help="Output directory to write the synthetic dataset") 225 | parser.add_argument('--input_background_path', 226 | type=str) 227 | parser.add_argument('--n_mics', type=int, default=4) 228 | parser.add_argument('--mic_radius', 229 | default=.03231, 230 | type=float, 231 | help="Radius of the mic array in meters") 232 | parser.add_argument('--n_voices', type=int, default=4) 233 | parser.add_argument('--n_outputs', type=int, default=10000) 234 | parser.add_argument('--n_workers', type=int, default=8) 235 | parser.add_argument('--seed', type=int, default=42) 236 | parser.add_argument('--sr', type=int, default=44100) 237 | parser.add_argument('--duration', type=float, default=3.0) 238 | main(parser.parse_args()) 239 | -------------------------------------------------------------------------------- /cos/training/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main training script for training on synthetic data 3 | """ 4 | 5 | import argparse 6 | import multiprocessing 7 | import os 8 | 9 | from typing import Dict, List, Tuple, Optional # pylint: disable=unused-import 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.optim as optim 16 | import torch.nn.functional as F 17 | import tqdm # pylint: disable=unused-import 18 | 19 | from cos.training.synthetic_dataset import SyntheticDataset 20 | from cos.training.network import CoSNetwork, \ 21 | center_trim, load_pretrain, \ 22 | normalize_input, unnormalize_input 23 | 24 | 25 | def train_epoch(model: nn.Module, device: torch.device, 26 | optimizer: optim.Optimizer, 27 | train_loader: torch.utils.data.dataloader.DataLoader, 28 | epoch: int, log_interval: int = 20) -> float: 29 | """ 30 | Train a single epoch. 31 | """ 32 | # Set the model to training. 
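    # Each batch from the loader is (mixed waveform, ground-truth voice signals,
    # conditioning window index). The mixture is normalized per batch item, padded up to
    # model.valid_length() so the strided encoder/decoder shapes line up, run through the
    # network, center-trimmed back to the input length, and un-normalized before the L1
    # loss is computed against the ground-truth voices.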
33 | model.train() 34 | 35 | # Training loop 36 | losses = [] 37 | interval_losses = [] 38 | 39 | for batch_idx, (data, label_voice_signals, 40 | window_idx) in enumerate(train_loader): 41 | data = data.to(device) 42 | label_voice_signals = label_voice_signals.to(device) 43 | window_idx = window_idx.to(device) 44 | 45 | # Normalize input, each batch item separately 46 | data, means, stds = normalize_input(data) 47 | 48 | # Reset grad 49 | optimizer.zero_grad() 50 | 51 | # Run through the model 52 | valid_length = model.valid_length(data.shape[-1]) 53 | delta = valid_length - data.shape[-1] 54 | padded = F.pad(data, (delta // 2, delta - delta // 2)) 55 | 56 | output_signal = model(padded, window_idx) 57 | output_signal = center_trim(output_signal, data) 58 | 59 | # Un-normalize 60 | output_signal = unnormalize_input(output_signal, means, stds) 61 | output_voices = output_signal[:, 0] 62 | 63 | loss = model.loss(output_voices, label_voice_signals) 64 | 65 | interval_losses.append(loss.item()) 66 | 67 | # Backpropagation 68 | loss.backward() 69 | 70 | # Gradient clipping 71 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) 72 | 73 | # Update the weights 74 | optimizer.step() 75 | 76 | # Print the loss 77 | if batch_idx % log_interval == 0: 78 | print("Train Epoch: {} [{}/{} ({:.0f}%)] \t Loss: {:.6f}".format( 79 | epoch, batch_idx * len(data), len(train_loader.dataset), 80 | 100. * batch_idx / len(train_loader), 81 | np.mean(interval_losses))) 82 | 83 | losses.extend(interval_losses) 84 | interval_losses = [] 85 | 86 | return np.mean(losses) 87 | 88 | 89 | def test_epoch(model: nn.Module, device: torch.device, 90 | test_loader: torch.utils.data.dataloader.DataLoader, 91 | log_interval: int = 20) -> float: 92 | """ 93 | Evaluate the network. 94 | """ 95 | model.eval() 96 | test_loss = 0 97 | 98 | with torch.no_grad(): 99 | for batch_idx, (data, label_voice_signals, 100 | window_idx) in enumerate(test_loader): 101 | data = data.to(device) 102 | label_voice_signals = label_voice_signals.to(device) 103 | window_idx = window_idx.to(device) 104 | 105 | # Normalize input, each batch item separately 106 | data, means, stds = normalize_input(data) 107 | 108 | valid_length = model.valid_length(data.shape[-1]) 109 | delta = valid_length - data.shape[-1] 110 | padded = F.pad(data, (delta // 2, delta - delta // 2)) 111 | 112 | # Run through the model 113 | output_signal = model(padded, window_idx) 114 | output_signal = center_trim(output_signal, data) 115 | 116 | # Un-normalize 117 | output_signal = unnormalize_input(output_signal, means, stds) 118 | output_voices = output_signal[:, 0] 119 | 120 | loss = model.loss(output_voices, label_voice_signals) 121 | test_loss += loss.item() 122 | 123 | if batch_idx % log_interval == 0: 124 | print("Loss: {}".format(loss)) 125 | 126 | test_loss /= len(test_loader) 127 | print("\nTest set: Average Loss: {:.4f}\n".format(test_loss)) 128 | 129 | return test_loss 130 | 131 | 132 | def train(args: argparse.Namespace): 133 | """ 134 | Train the network. 135 | """ 136 | # Load dataset 137 | data_train = SyntheticDataset(args.train_dir, n_mics=args.n_mics, 138 | sr=args.sr, perturb_prob=1.0, 139 | mic_radius=args.mic_radius) 140 | data_test = SyntheticDataset(args.test_dir, n_mics=args.n_mics, 141 | sr=args.sr, mic_radius=args.mic_radius) 142 | 143 | # Set up the device and workers. 
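    # CUDA is used only when --use_cuda is passed and a GPU is actually visible.
    # Checkpoints are written after every epoch to <checkpoints_dir>/<name>/<epoch>.pt;
    # passing --start_epoch N resumes training from epoch N by loading "<N-1>.pt", which
    # is looked up directly under checkpoints_dir (without the experiment name).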
144 | use_cuda = args.use_cuda and torch.cuda.is_available() 145 | device = torch.device('cuda' if use_cuda else 'cpu') 146 | print("Using device {}".format('cuda' if use_cuda else 'cpu')) 147 | 148 | # Set multiprocessing params 149 | num_workers = min(multiprocessing.cpu_count(), args.n_workers) 150 | kwargs = { 151 | 'num_workers': num_workers, 152 | 'pin_memory': True 153 | } if use_cuda else {} 154 | 155 | # Set up data loaders 156 | train_loader = torch.utils.data.DataLoader(data_train, 157 | batch_size=args.batch_size, 158 | shuffle=True, **kwargs) 159 | test_loader = torch.utils.data.DataLoader(data_test, 160 | batch_size=args.batch_size, 161 | **kwargs) 162 | 163 | # Set up model 164 | model = CoSNetwork(n_audio_channels=args.n_mics) 165 | model.to(device) 166 | 167 | # Set up checkpoints 168 | if not os.path.exists(os.path.join(args.checkpoints_dir, args.name)): 169 | os.makedirs(os.path.join(args.checkpoints_dir, args.name)) 170 | 171 | # Set up the optimizer 172 | optimizer = optim.Adam(model.parameters(), lr=args.lr, 173 | weight_decay=args.decay) 174 | 175 | # Load pretrain 176 | if args.pretrain_path: 177 | state_dict = torch.load(args.pretrain_path) 178 | load_pretrain(model, state_dict) 179 | 180 | # Load the model if `args.start_epoch` is greater than 0. This will load the model from 181 | # epoch = `args.start_epoch - 1` 182 | if args.start_epoch is not None: 183 | assert args.start_epoch > 0, "start_epoch must be greater than 0." 184 | start_epoch = args.start_epoch 185 | checkpoint_path = Path( 186 | args.checkpoints_dir) / "{}.pt".format(start_epoch - 1) 187 | state_dict = torch.load(checkpoint_path) 188 | model.load_state_dict(state_dict) 189 | else: 190 | start_epoch = 0 191 | 192 | # Loss values 193 | train_losses = [] 194 | test_losses = [] 195 | 196 | # Training loop 197 | try: 198 | for epoch in range(start_epoch, args.epochs + 1): 199 | train_loss = train_epoch(model, device, optimizer, train_loader, 200 | epoch, args.print_interval) 201 | torch.save( 202 | model.state_dict(), 203 | os.path.join(args.checkpoints_dir, args.name, 204 | "{}.pt".format(epoch))) 205 | print("Done with training, going to testing") 206 | test_loss = test_epoch(model, device, test_loader, 207 | args.print_interval) 208 | train_losses.append((epoch, train_loss)) 209 | test_losses.append((epoch, test_loss)) 210 | 211 | return train_losses, test_losses 212 | 213 | except KeyboardInterrupt: 214 | print("Interrupted") 215 | except Exception as _: # pylint: disable=broad-except 216 | import traceback # pylint: disable=import-outside-toplevel 217 | traceback.print_exc() 218 | 219 | 220 | if __name__ == '__main__': 221 | parser = argparse.ArgumentParser() 222 | # Data Params 223 | parser.add_argument('train_dir', type=str, 224 | help="Path to the training dataset") 225 | parser.add_argument('test_dir', type=str, 226 | help="Path to the testing dataset") 227 | parser.add_argument('--name', type=str, default="multimic_experiment", 228 | help="Name of the experiment") 229 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', 230 | help="Path to the checkpoints") 231 | parser.add_argument('--batch_size', type=int, default=8, 232 | help="Batch size") 233 | 234 | # Physical Params 235 | parser.add_argument('--n_mics', type=int, default=4, 236 | help="Number of mics (also number of channels)") 237 | parser.add_argument('--mic_radius', default=.03231, type=float, 238 | help="Radius in meters of the mic array") 239 | 240 | # Training Params 241 | parser.add_argument('--epochs', 
type=int, default=100, 242 | help="Number of epochs") 243 | parser.add_argument('--lr', type=float, default=3e-4, help="learning rate") 244 | parser.add_argument('--sr', type=int, default=44100, help="Sampling rate") 245 | parser.add_argument('--decay', type=float, default=0, help="Weight decay") 246 | parser.add_argument('--n_workers', type=int, default=16, 247 | help="Number of parallel workers") 248 | parser.add_argument('--print_interval', type=int, default=20, 249 | help="Logging interval") 250 | 251 | parser.add_argument('--start_epoch', type=int, default=None, 252 | help="Start epoch") 253 | parser.add_argument('--pretrain_path', type=str, 254 | help="Path to pretrained weights") 255 | parser.add_argument('--use_cuda', dest='use_cuda', action='store_true', 256 | help="Whether to use cuda") 257 | 258 | train(parser.parse_args()) 259 | -------------------------------------------------------------------------------- /cos/inference/evaluate_synthetic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from pathlib import Path 6 | 7 | import torch 8 | import numpy as np 9 | import librosa 10 | import soundfile as sf 11 | import tqdm 12 | 13 | from cos.helpers.eval_utils import find_best_permutation_prec_recall, compute_sdr 14 | from cos.helpers.utils import angular_distance, check_valid_dir 15 | from cos.training.network import CoSNetwork 16 | from cos.inference.separation_by_localization import run_separation, CandidateVoice 17 | 18 | import multiprocessing.dummy as mp 19 | from multiprocessing import Lock 20 | 21 | 22 | def get_items(curr_dir, args): 23 | """ 24 | This is a modified version of the SpatialAudioDataset DataLoader 25 | """ 26 | with open(Path(curr_dir) / 'metadata.json') as json_file: 27 | json_data = json.load(json_file) 28 | 29 | num_voices = args.n_voices 30 | mic_files = sorted(list(Path(curr_dir).rglob('*mixed.wav'))) 31 | 32 | # All voice signals 33 | keys = ["voice{:02}".format(i) for i in range(num_voices)] 34 | 35 | # Comment out this line to do voice only, no bg 36 | if "bg" in json_data: 37 | keys.append("bg") 38 | """ 39 | Loading the sources 40 | """ 41 | # Iterate over different sources 42 | all_sources = [] 43 | target_voice_data = [] 44 | voice_positions = [] 45 | for key in keys: 46 | gt_audio_files = sorted(list(Path(curr_dir).rglob("*" + key + ".wav"))) 47 | assert (len(gt_audio_files) > 0) 48 | gt_waveforms = [] 49 | 50 | # Iterate over different mics 51 | for _, gt_audio_file in enumerate(gt_audio_files): 52 | gt_waveform, _ = librosa.core.load(gt_audio_file, args.sr, 53 | mono=True) 54 | gt_waveforms.append(gt_waveform) 55 | 56 | single_source = np.stack(gt_waveforms) 57 | all_sources.append(single_source) 58 | locs_voice = np.arctan2(json_data[key]["position"][1], 59 | json_data[key]["position"][0]) 60 | voice_positions.append(locs_voice) 61 | 62 | all_sources = np.stack(all_sources) # n voices x n mics x n samples 63 | mixed_data = np.sum(all_sources, axis=0) # n mics x n samples 64 | 65 | gt = [ 66 | CandidateVoice(voice_positions[i], None, all_sources[i]) 67 | for i in range(num_voices) 68 | ] 69 | 70 | return mixed_data, gt 71 | 72 | 73 | def main(args): 74 | args.moving = False 75 | device = torch.device('cuda') if args.use_cuda else torch.device('cpu') 76 | 77 | args.device = device 78 | model = CoSNetwork(n_audio_channels=args.n_channels) 79 | model.load_state_dict(torch.load(args.model_checkpoint), strict=True) 80 | model.train = False 81 | 
model.to(device) 82 | 83 | all_dirs = sorted(list(Path(args.test_dir).glob('[0-9]*'))) 84 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)] 85 | 86 | if args.prec_recall and args.oracle_position: 87 | raise(ValueError("Either specify prec recall or oracle position")) 88 | 89 | if args.prec_recall: 90 | # True positives, false negatives, false positives 91 | all_tp, all_fn, all_fp = [], [], [] 92 | 93 | else: 94 | # Placeholders to support multiprocessing 95 | all_angle_errors = [0] * len(all_dirs) 96 | all_input_sdr = [0] * len(all_dirs) 97 | all_output_sdr = [0] * len(all_dirs) 98 | 99 | gpu_lock = Lock() 100 | 101 | def evaluate_dir(idx): 102 | if args.debug: 103 | curr_writing_dir = "{:05d}".format(idx) 104 | if not os.path.exists(curr_writing_dir): 105 | os.makedirs(curr_writing_dir) 106 | args.writing_dir = curr_writing_dir 107 | 108 | curr_dir = all_dirs[idx] 109 | 110 | # Loads the data 111 | mixed_data, gt = get_items(curr_dir, args) 112 | 113 | # Prevents CUDA out of memory 114 | gpu_lock.acquire() 115 | if args.prec_recall: 116 | # Case where we don't know the number of sources 117 | candidate_voices = run_separation(mixed_data, model, args) 118 | 119 | # Case where we know the number of sources 120 | else: 121 | # Normal run 122 | if not args.oracle_position: 123 | candidate_voices = run_separation(mixed_data, model, args, 0.005) 124 | # In order to compute SDR or angle error, the number of outputs must match gt 125 | # We set a very low threshold to ensure we get the correct number of outputs 126 | if args.oracle_position or len(candidate_voices) < len(gt): 127 | print("Had to go again\n") 128 | candidate_voices = run_separation(mixed_data, model, args, 0.000001) 129 | 130 | # Use the GT positions to find the best sources 131 | if args.oracle_position: 132 | trimmed_voices = [] 133 | for gt_idx in range(args.n_voices): 134 | best_idx = np.argmin(np.array([angular_distance(x.angle, 135 | gt[gt_idx].angle) for x in candidate_voices])) 136 | trimmed_voices.append(candidate_voices[best_idx]) 137 | candidate_voices = trimmed_voices 138 | 139 | # Take the top N voices 140 | else: 141 | candidate_voices = candidate_voices[:args.n_voices] 142 | if len(candidate_voices) != len(gt): 143 | print(f"Not enough outputs for dir {curr_dir}. Lower threshold to evaluate.") 144 | return 145 | 146 | if args.debug: 147 | sf.write(os.path.join(args.writing_dir, "mixed.wav"), 148 | mixed_data[0], 149 | args.sr) 150 | for voice in candidate_voices: 151 | fname = "out_angle{:.2f}.wav".format( 152 | voice.angle * 180 / np.pi) 153 | sf.write(os.path.join(args.writing_dir, fname), voice.data[0], 154 | args.sr) 155 | 156 | gpu_lock.release() 157 | curr_angle_errors = [] 158 | curr_input_sdr = [] 159 | curr_output_sdr = [] 160 | 161 | best_permutation, (tp, fn, fp) = find_best_permutation_prec_recall( 162 | [x.angle for x in gt], [x.angle for x in candidate_voices]) 163 | 164 | if args.prec_recall: 165 | all_tp.append(tp) 166 | all_fn.append(fn) 167 | all_fp.append(fp) 168 | 169 | # Evaluate SDR and Angular Error 170 | else: 171 | for gt_idx, output_idx in enumerate(best_permutation): 172 | angle_error = angular_distance(candidate_voices[output_idx].angle, 173 | gt[gt_idx].angle) 174 | # print(angle_error) 175 | curr_angle_errors.append(angle_error) 176 | 177 | # To speed up we only evaluate channel 0. 
For rigorous results 178 | # set that to false 179 | input_sdr = compute_sdr(gt[gt_idx].data, mixed_data, 180 | single_channel=True) 181 | output_sdr = compute_sdr(gt[gt_idx].data, 182 | candidate_voices[output_idx].data, single_channel=True) 183 | 184 | curr_input_sdr.append(input_sdr) 185 | curr_output_sdr.append(output_sdr) 186 | 187 | # print(curr_input_sdr) 188 | # print(curr_output_sdr) 189 | 190 | all_angle_errors[idx] = curr_angle_errors 191 | all_input_sdr[idx] = curr_input_sdr 192 | all_output_sdr[idx] = curr_output_sdr 193 | 194 | # print("Running median angle error: {}".format(np.median(np.array(all_angle_errors[:idx+1])) * 180 / np.pi)) 195 | # print("Running median SDRi: ", 196 | # np.median(np.array(all_output_sdr[:idx+1]) - np.array(all_input_sdr[:idx+1]))) 197 | 198 | pool = mp.Pool(args.n_workers) 199 | with tqdm.tqdm(total=len(all_dirs)) as pbar: 200 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))): 201 | pbar.update() 202 | pool.close() 203 | pool.join() 204 | 205 | 206 | # Print and save the outputs 207 | if args.prec_recall: 208 | print("Overall Precision: {} Recall: {}".format( 209 | sum(all_tp) / (sum(all_tp) + sum(all_fp)), 210 | sum(all_tp) / (sum(all_tp) + sum(all_fn)))) 211 | 212 | else: 213 | print("Median Angular Error: ", np.median(np.array(all_angle_errors)) * 180 / np.pi) 214 | print("Median SDRi: ", 215 | np.median(np.array(all_output_sdr) - np.array(all_input_sdr))) 216 | # Uncomment to save the data for visualization 217 | # np.save("angleerror_{}voices_{}kHz.npy".format(args.n_voices, args.sr), 218 | # np.array(all_angle_errors).flatten()) 219 | # np.save("SDR_{}voices_{}kHz.npy".format(args.n_voices, args.sr), 220 | # np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()])) 221 | 222 | 223 | 224 | if __name__ == '__main__': 225 | parser = argparse.ArgumentParser() 226 | parser.add_argument('test_dir', type=str, 227 | help="Path to the testing directory") 228 | parser.add_argument('model_checkpoint', type=str, 229 | help="Path to the model file") 230 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate") 231 | parser.add_argument('--n_channels', type=int, default=2, 232 | help="Number of channels") 233 | parser.add_argument('--use_cuda', dest='use_cuda', action='store_true', 234 | help="Whether to use cuda") 235 | parser.add_argument('--debug', action='store_true', help="Save outputs") 236 | parser.add_argument('--mic_radius', default=.0725, type=float, 237 | help="To do") 238 | parser.add_argument('--n_workers', default=8, type=int, 239 | help="Multiprocessing") 240 | parser.add_argument( 241 | '--n_voices', default=2, type=int, help= 242 | "Number of voices in the GT scenarios. 
\ 243 | Useful so you can re-use the same dataset with different number of fg sources" 244 | ) 245 | parser.add_argument( 246 | '--prec_recall', action='store_true', help= 247 | "To compute precision and recall, we don't let the network know the number of sources" 248 | ) 249 | parser.add_argument( 250 | '--oracle_position', action='store_true', help= 251 | "Compute the separation results if you know the GT positions" 252 | ) 253 | 254 | print(parser.parse_args()) 255 | main(parser.parse_args()) 256 | -------------------------------------------------------------------------------- /cos/inference/separation_by_localization.py: -------------------------------------------------------------------------------- 1 | """ 2 | The main separation by localization inference algorithm 3 | """ 4 | 5 | import argparse 6 | import os 7 | 8 | from collections import namedtuple 9 | 10 | import librosa 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import soundfile as sf 15 | 16 | import cos.helpers.utils as utils 17 | 18 | from cos.helpers.constants import ALL_WINDOW_SIZES, \ 19 | FAR_FIELD_RADIUS 20 | from cos.helpers.visualization import draw_diagram 21 | from cos.training.network import CoSNetwork, center_trim, \ 22 | normalize_input, unnormalize_input 23 | from cos.helpers.eval_utils import si_sdr 24 | 25 | # Constants which may be tweaked based on your setup 26 | ENERGY_CUTOFF = 0.002 27 | NMS_RADIUS = np.pi / 4 28 | NMS_SIMILARITY_SDR = -7.0 # SDR cutoff for different candidates 29 | 30 | CandidateVoice = namedtuple("CandidateVoice", ["angle", "energy", "data"]) 31 | 32 | 33 | def nms(candidate_voices, nms_cutoff): 34 | """ 35 | Runs non-max suppression on the candidate voices 36 | """ 37 | final_proposals = [] 38 | initial_proposals = candidate_voices 39 | 40 | while len(initial_proposals) > 0: 41 | new_initial_proposals = [] 42 | sorted_candidates = sorted(initial_proposals, 43 | key=lambda x: x[1], 44 | reverse=True) 45 | 46 | # Choose the loudest voice 47 | best_candidate_voice = sorted_candidates[0] 48 | final_proposals.append(best_candidate_voice) 49 | sorted_candidates.pop(0) 50 | 51 | # See if any of the rest should be removed 52 | for candidate_voice in sorted_candidates: 53 | different_locations = utils.angular_distance( 54 | candidate_voice.angle, best_candidate_voice.angle) > NMS_RADIUS 55 | 56 | # different_content = abs( 57 | # candidate_voice.data - 58 | # best_candidate_voice.data).mean() > nms_cutoff 59 | 60 | different_content = si_sdr( 61 | candidate_voice.data[0], 62 | best_candidate_voice.data[0]) < nms_cutoff 63 | 64 | if different_locations or different_content: 65 | new_initial_proposals.append(candidate_voice) 66 | 67 | initial_proposals = new_initial_proposals 68 | 69 | return final_proposals 70 | 71 | 72 | def forward_pass(model, target_angle, mixed_data, conditioning_label, args): 73 | """ 74 | Runs the network on the mixed_data 75 | with the candidate region given by voice 76 | """ 77 | target_pos = np.array([ 78 | FAR_FIELD_RADIUS * np.cos(target_angle), 79 | FAR_FIELD_RADIUS * np.sin(target_angle) 80 | ]) 81 | 82 | data, _ = utils.shift_mixture( 83 | torch.tensor(mixed_data).to(args.device), target_pos, args.mic_radius, 84 | args.sr) 85 | data = data.float().unsqueeze(0) # Batch size is 1 86 | 87 | # Normalize input 88 | data, means, stds = normalize_input(data) 89 | 90 | # Run through the model 91 | valid_length = model.valid_length(data.shape[-1]) 92 | delta = valid_length - data.shape[-1] 93 | padded = F.pad(data, (delta // 2, delta 
- delta // 2)) 94 | 95 | output_signal = model(padded, conditioning_label) 96 | output_signal = center_trim(output_signal, data) 97 | 98 | output_signal = unnormalize_input(output_signal, means, stds) 99 | output_voices = output_signal[:, 0] # batch x n_mics x n_samples 100 | 101 | output_np = output_voices.detach().cpu().numpy()[0] 102 | energy = librosa.feature.rms(output_np).mean() 103 | 104 | return output_np, energy 105 | 106 | 107 | def run_separation(mixed_data, model, args, 108 | energy_cutoff=ENERGY_CUTOFF, 109 | nms_cutoff=NMS_SIMILARITY_SDR): # yapf: disable 110 | """ 111 | The main separation by localization algorithm 112 | """ 113 | # Get the initial candidates 114 | num_windows = len(ALL_WINDOW_SIZES) if not args.moving else 3 115 | starting_angles = utils.get_starting_angles(ALL_WINDOW_SIZES[0]) 116 | candidate_voices = [CandidateVoice(x, None, None) for x in starting_angles] 117 | 118 | # All steps of the binary search 119 | for window_idx in range(num_windows): 120 | if args.debug: 121 | print("---------") 122 | conditioning_label = torch.tensor(utils.to_categorical( 123 | window_idx, 5)).float().to(args.device).unsqueeze(0) 124 | 125 | curr_window_size = ALL_WINDOW_SIZES[window_idx] 126 | new_candidate_voices = [] 127 | 128 | # Iterate over all the potential locations 129 | for voice in candidate_voices: 130 | output, energy = forward_pass(model, voice.angle, mixed_data, 131 | conditioning_label, args) 132 | 133 | if args.debug: 134 | print("Angle {:.2f} energy {}".format(voice.angle, energy)) 135 | fname = "out{}_angle{:.2f}.wav".format( 136 | window_idx, voice.angle * 180 / np.pi) 137 | # sf.write(os.path.join(args.writing_dir, fname), output[0], 138 | # args.sr) 139 | 140 | # If there was something there 141 | if energy > energy_cutoff: 142 | 143 | # We're done searching so undo the shifts 144 | if window_idx == num_windows - 1: 145 | target_pos = np.array([ 146 | FAR_FIELD_RADIUS * np.cos(voice.angle), 147 | FAR_FIELD_RADIUS * np.sin(voice.angle) 148 | ]) 149 | unshifted_output, _ = utils.shift_mixture(output, 150 | target_pos, 151 | args.mic_radius, 152 | args.sr, 153 | inverse=True) 154 | 155 | new_candidate_voices.append( 156 | CandidateVoice(voice.angle, energy, unshifted_output)) 157 | 158 | # Split region and recurse. 
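                # Each region that clears the energy cutoff spawns candidate angles offset
                # by a fraction of the current window size; those candidates are re-queried
                # on the next pass with the next conditioning window in ALL_WINDOW_SIZES.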
159 | # You can either split strictly (fourths) 160 | # or with some redundancy (thirds) 161 | else: 162 | # new_candidate_voices.append( 163 | # CandidateVoice( 164 | # voice.angle + curr_window_size / 3, 165 | # energy, output)) 166 | # new_candidate_voices.append( 167 | # CandidateVoice( 168 | # voice.angle - curr_window_size / 3, 169 | # energy, output)) 170 | # new_candidate_voices.append( 171 | # CandidateVoice( 172 | # voice.angle, 173 | # energy, output)) 174 | new_candidate_voices.append( 175 | CandidateVoice( 176 | voice.angle + curr_window_size / 4, 177 | energy, output)) 178 | new_candidate_voices.append( 179 | CandidateVoice( 180 | voice.angle - curr_window_size / 4, 181 | energy, output)) 182 | 183 | candidate_voices = new_candidate_voices 184 | 185 | # Run NMS on the final output and return 186 | return nms(candidate_voices, nms_cutoff) 187 | 188 | 189 | def main(args): 190 | device = torch.device('cuda') if args.use_cuda else torch.device('cpu') 191 | 192 | args.device = device 193 | model = CoSNetwork(n_audio_channels=args.n_channels) 194 | model.load_state_dict(torch.load(args.model_checkpoint), strict=True, map_location=args.device) 195 | model.train = False 196 | model.to(device) 197 | 198 | if not os.path.exists(args.output_dir): 199 | os.makedirs(args.output_dir) 200 | 201 | mixed_data = librosa.core.load(args.input_file, mono=False, sr=args.sr)[0] 202 | assert mixed_data.shape[0] == args.n_channels 203 | 204 | temporal_chunk_size = int(args.sr * args.duration) 205 | num_chunks = (mixed_data.shape[1] // temporal_chunk_size) + 1 206 | 207 | for chunk_idx in range(num_chunks): 208 | curr_writing_dir = os.path.join(args.output_dir, 209 | "{:03d}".format(chunk_idx)) 210 | if not os.path.exists(curr_writing_dir): 211 | os.makedirs(curr_writing_dir) 212 | 213 | args.writing_dir = curr_writing_dir 214 | curr_mixed_data = mixed_data[:, (chunk_idx * 215 | temporal_chunk_size):(chunk_idx + 1) * 216 | temporal_chunk_size] 217 | 218 | output_voices = run_separation(curr_mixed_data, model, args) 219 | for voice in output_voices: 220 | fname = "output_angle{:.2f}.wav".format( 221 | voice.angle * 180 / np.pi) 222 | sf.write(os.path.join(args.writing_dir, fname), voice.data[0], 223 | args.sr) 224 | 225 | candidate_angles = [voice.angle for voice in output_voices] 226 | diagram_window_angle = ALL_WINDOW_SIZES[2] if args.moving else ALL_WINDOW_SIZES[-1] 227 | draw_diagram([], candidate_angles, 228 | diagram_window_angle, 229 | os.path.join(args.writing_dir, "positions.png".format(chunk_idx))) 230 | 231 | if __name__ == '__main__': 232 | parser = argparse.ArgumentParser() 233 | parser.add_argument('model_checkpoint', 234 | type=str, 235 | help="Path to the model file") 236 | parser.add_argument('input_file', type=str, help="Path to the input file") 237 | parser.add_argument('output_dir', 238 | type=str, 239 | help="Path to write the outputs") 240 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate") 241 | parser.add_argument('--n_channels', 242 | type=int, 243 | default=2, 244 | help="Number of channels") 245 | parser.add_argument('--use_cuda', 246 | dest='use_cuda', 247 | action='store_true', 248 | help="Whether to use cuda") 249 | parser.add_argument('--debug', 250 | action='store_true', 251 | help="Save intermediate outputs") 252 | parser.add_argument('--mic_radius', 253 | default=.03231, 254 | type=float, 255 | help="Radius of the mic array") 256 | parser.add_argument('--duration', 257 | default=3.0, 258 | type=float, 259 | help="Seconds of input to the 
network") 260 | parser.add_argument('--moving', 261 | action='store_true', 262 | help="If the sources are moving, stop the search at a coarser window") 263 | main(parser.parse_args()) 264 | --------------------------------------------------------------------------------