├── cos
│   ├── __init__.py
│   ├── helpers
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── eval_utils.py
│   │   ├── visualization.py
│   │   ├── ibm.py
│   │   ├── irm.py
│   │   ├── utils.py
│   │   └── mwf.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── evaluate_synthetic.py
│   │   └── separation_by_localization.py
│   ├── training
│   │   ├── __init__.py
│   │   ├── data_augmentation.py
│   │   ├── synthetic_dataset.py
│   │   ├── network.py
│   │   └── train.py
│   └── generate_dataset.py
├── checkpoints
│   └── info.txt
├── outputs
│   └── info.txt
├── requirements.txt
├── .gitignore
├── LICENSE
├── RunningTips.md
└── README.md
/cos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cos/helpers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cos/inference/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cos/training/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/checkpoints/info.txt:
--------------------------------------------------------------------------------
1 | pt checkpoint files go here
--------------------------------------------------------------------------------
/outputs/info.txt:
--------------------------------------------------------------------------------
1 | Output data gets written here
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ### Requirements
2 | librosa
3 | torch >= 1.3.0
4 | soundfile
5 | scipy
6 | pyroomacoustics==0.3.1
7 | matplotlib
8 | mir_eval
9 | tqdm
10 | numpy
11 | pysndfx
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Dot underscore
7 | ._*
8 |
9 | # Outputs
10 |
11 | # Pytorch
12 | *.pt
13 |
14 | # Audio files
15 | *.wav
16 | *.mp4
17 | *.flac
18 | *.mp3
19 |
20 | # Data Files
21 | *.npy
22 |
--------------------------------------------------------------------------------
/cos/training/data_augmentation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pysndfx import AudioEffectsChain
4 |
5 | class RandomAudioPerturbation(object):
6 | """Randomly perturb audio samples"""
7 |
8 | def __call__(self, data):
9 | """
10 |         Data must be a (mics, T) numpy array; the array is perturbed in place and returned.
11 | """
12 | highshelf_gain = np.random.normal(0, 2)
13 | lowshelf_gain = np.random.normal(0, 2)
14 | noise_amount = np.random.uniform(0, 0.001)
15 |
16 | fx = (
17 | AudioEffectsChain()
18 | .highshelf(gain=highshelf_gain)
19 | .lowshelf(gain=lowshelf_gain)
20 | )
21 |
22 | for i in range(data.shape[0]):
23 | data[i] = fx(data[i])
24 | data[i] += np.random.uniform(-noise_amount, noise_amount, size=data[i].shape)
25 | return data
26 |
--------------------------------------------------------------------------------
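A minimal usage sketch for `RandomAudioPerturbation` above. It assumes `pysndfx` and a working `sox` binary are installed (see the training note in README.md); the input array is hypothetical:

```
import numpy as np
from cos.training.data_augmentation import RandomAudioPerturbation

perturb = RandomAudioPerturbation()
audio = 0.1 * np.random.randn(4, 44100)  # hypothetical 4-mic, 1-second recording
augmented = perturb(audio)               # same shape: random EQ tilt plus low-level noise
print(augmented.shape)                   # (4, 44100)
```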
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Teerapat Jenrungrot, Vivek Jayaram, Steve Seitz, and Ira Kemelmacher-Shlizerman
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/RunningTips.md:
--------------------------------------------------------------------------------
1 | # Tips for Running on Real Data
2 |
3 | ## Hardware
4 | We have only tested this network with circular microphone arrays. It might be possible to train a new model for ad-hoc or linear mic arrays. The pretrained model works with the 4-mic [Seeed ReSpeaker Mic Array v2.0](https://wiki.seeedstudio.com/ReSpeaker_Mic_Array_v2.0/).
5 |
6 | ## Capturing Data
7 | ### Positioning
8 | For best results, the sources should not be too far from the microphone. Our model was trained with sources between 1-4 meters from the microphone, but the best results are obtained 1-2 m away. Extreme far field (>4 m) has not been explored. Ideally the sources should be at roughly the same elevation angle as the microphone; for example, if the microphone is on the ground and the sources are standing, the assumptions in the pre-shift formulas no longer hold.
9 | ### Obstructions
10 | The point source model breaks down if there is not a direct line of sight between the source and the microphone, for example if there is a laptop between the voice and mic.
11 | ## Hyperparameters
12 | If the sources are completely stationary, you can increase the `--duration` flag, which processes larger chunks at a time. This improves performance and reduces boundary effects. Additional parameters are at the top of `cos/inference/separation_by_localization.py`. Tweak `ENERGY_CUTOFF` to more aggressively keep or reject sources. For more aggressive non-max suppression, you can reduce `NMS_SIMILARITY_SDR`, which keeps an additional source only if its SDR relative to an existing source is lower than this parameter.
13 |
14 | ## Post Processing
15 | After separating the voices from each other, any type of single-channel post processing can be run on the output. We found it useful to run a low-pass filter on the output (see the sketch below). The network requires a 44.1 kHz sampling rate for localization and time differences, but it can sometimes produce artifacts in these high frequency ranges. Because human voice doesn't contain much energy above ~6 kHz, these frequencies can simply be cut.
16 |
17 | # Tips for Training on Real Data
18 |
19 | We strongly recommend training on data collected from your specific microphone. Due to the 2020 situation, we could not actually record a variety of real background sounds, environments, or real speakers with our mic array. We expect the network performance to improve with more training data, even beyond the pretrained models we have provided. If you train on new data, do not train from scratch: instead, fine-tune from our existing weights even if the number of channels is different. Training from scratch often results in the network outputting silence everywhere. In our experience, it is best to jointly train on real and synthetic data.
20 |
--------------------------------------------------------------------------------
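A minimal sketch of the low-pass post-processing step described in the Post Processing section of RunningTips.md above, assuming `scipy` and `soundfile` are available; the 6 kHz cutoff and the file names are illustrative assumptions, not part of the original pipeline:

```
import soundfile as sf
from scipy.signal import butter, sosfiltfilt

audio, sr = sf.read("outputs/some_dirname/output_voice00.wav")  # hypothetical output file
sos = butter(8, 6000, btype="lowpass", fs=sr, output="sos")     # ~6 kHz cutoff
filtered = sosfiltfilt(sos, audio, axis=0)                      # zero-phase low-pass
sf.write("output_voice00_lowpass.wav", filtered, sr)
```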
/cos/helpers/constants.py:
--------------------------------------------------------------------------------
1 | """
2 | A collection of constants that probably should not be changed
3 | """
4 |
5 | import numpy as np
6 |
7 | # Universal Constants
8 | SPEED_OF_SOUND = 343.0 # m/s
9 | FAR_FIELD_RADIUS = 3.0  # meters, assumed to be larger than the mic array radius
10 |
11 | # Algorithmic Constants
12 | ALL_WINDOW_SIZES = [
13 | np.pi / 2, # 90 degrees
14 | np.pi / 4, # 45 degrees
15 | np.pi / 8, # 22.5 degrees
16 | np.pi / 16, # 11.25 degrees
17 | np.pi / 32, # 5.625 degrees
18 | ]
19 |
20 | def get_mic_diagram():
21 | """
22 | A simple hard-coded mic figure for matplotlib
23 | """
24 |     import matplotlib.path  # binds "matplotlib" and ensures the path submodule is loaded
25 | matplotlib.use("Agg")
26 | mic_verts = np.array([[24. , 28. ],
27 | [27.31, 28. ],
28 | [29.98, 25.31],
29 | [29.98, 22. ],
30 | [30. , 10. ],
31 | [30. , 6.68],
32 | [27.32, 4. ],
33 | [24. , 4. ],
34 | [20.69, 4. ],
35 | [18. , 6.68],
36 | [18. , 10. ],
37 | [18. , 22. ],
38 | [18. , 25.31],
39 | [20.69, 28. ],
40 | [24. , 28. ],
41 | [24. , 28. ],
42 | [34.6 , 22. ],
43 | [34.6 , 28. ],
44 | [29.53, 32.2 ],
45 | [24. , 32.2 ],
46 | [18.48, 32.2 ],
47 | [13.4 , 28. ],
48 | [13.4 , 22. ],
49 | [10. , 22. ],
50 | [10. , 28.83],
51 | [15.44, 34.47],
52 | [22. , 35.44],
53 | [22. , 42. ],
54 | [26. , 42. ],
55 | [26. , 35.44],
56 | [32.56, 34.47],
57 | [38. , 28.83],
58 | [38. , 22. ],
59 | [34.6 , 22. ],
60 | [34.6 , 22. ]])
61 | mic_verts[:,1] = (48 - mic_verts[:,1]) - 24
62 | mic_verts[:,0] -= 24
63 |
64 | mic_verts[:,0] /= 240
65 | mic_verts[:,1] /= 240
66 |
67 | mic_verts *= 10
68 |
69 | mic_codes = np.array([ 1, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 79, 1,
70 | 4, 4, 4, 4, 4, 4, 2, 4, 4, 4, 2, 2, 2, 4, 4, 4, 2,
71 | 79], dtype=np.uint8)
72 |
73 | mic = matplotlib.path.Path(mic_verts, mic_codes)
74 | return mic
75 |
--------------------------------------------------------------------------------
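A small sketch showing how the coarse-to-fine window sizes in `ALL_WINDOW_SIZES` above translate into candidate target angles via `cos.helpers.utils.get_starting_angles`; the loop itself is only illustrative:

```
import numpy as np
from cos.helpers.constants import ALL_WINDOW_SIZES
from cos.helpers.utils import get_starting_angles

for window_size in ALL_WINDOW_SIZES:
    angles = get_starting_angles(window_size)
    print("window {:6.2f} deg -> {:2d} candidate regions".format(
        np.degrees(window_size), len(angles)))
# 90 deg -> 4 regions, 45 deg -> 8, ..., 5.625 deg -> 64
```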
/cos/helpers/eval_utils.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import math
3 |
4 | import numpy as np
5 |
6 | from mir_eval.separation import bss_eval_sources
7 |
8 | from cos.helpers.utils import angular_distance
9 |
10 | def si_sdr(estimated_signal, reference_signals, scaling=True):
11 | """
12 | This is a scale invariant SDR. See https://arxiv.org/pdf/1811.02508.pdf
13 | or https://github.com/sigsep/bsseval/issues/3 for the motivation and
14 | explanation
15 |
16 | Input:
17 | estimated_signal and reference signals are (N,) numpy arrays
18 |
19 | Returns: SI-SDR as scalar
20 | """
21 | Rss = np.dot(reference_signals, reference_signals)
22 | this_s = reference_signals
23 |
24 | if scaling:
25 | # get the scaling factor for clean sources
26 | a = np.dot(this_s, estimated_signal) / Rss
27 | else:
28 | a = 1
29 |
30 | e_true = a * this_s
31 | e_res = estimated_signal - e_true
32 |
33 | Sss = (e_true**2).sum()
34 | Snn = (e_res**2).sum()
35 |
36 | SDR = 10 * math.log10(Sss/Snn)
37 |
38 | return SDR
39 |
40 |
41 | def compute_sdr(gt, output, single_channel=False):
42 | assert(gt.shape == output.shape)
43 | per_channel_sdr = []
44 |
45 | channels = [0] if single_channel else range(gt.shape[0])
46 | for channel_idx in channels:
47 | # sdr, _, _, _ = bss_eval_sources(gt[channel_idx], output[channel_idx])
48 | sdr = si_sdr(output[channel_idx], gt[channel_idx])
49 | per_channel_sdr.append(sdr)
50 |
51 | return np.array(per_channel_sdr).mean()
52 |
53 |
54 |
55 | def find_best_permutation_prec_recall(gt, output, acceptable_window=np.pi / 18):
56 | """
57 | Finds the best permutation for evaluation.
58 | Then uses that to find the precision and recall
59 |
60 | Inputs:
61 |         gt, output: lists of source angles in radians; lengths may differ
62 |
63 | Returns: Permutation that matches outputs to gt along with tp, fn and fp
64 | """
65 | n = max(len(gt), len(output))
66 |
67 | if len(gt) > len(output):
68 | output += [np.inf] * (n - len(output))
69 | elif len(output) > len(gt):
70 | gt += [np.inf] * (n - len(gt))
71 |
72 | best_perm = None
73 | best_inliers = -1
74 | for perm in itertools.permutations(range(n)):
75 | curr_inliers = 0
76 | for idx1, idx2 in enumerate(perm):
77 | if angular_distance(gt[idx1], output[idx2]) < acceptable_window:
78 | curr_inliers += 1
79 |
80 | if curr_inliers > best_inliers:
81 | best_inliers = curr_inliers
82 | best_perm = list(perm)
83 |
84 | return localization_precision_recall(best_perm, gt, output, acceptable_window)
85 |
86 |
87 | def localization_precision_recall(permutation, gt, output, acceptable_window=np.pi/18):
88 | tp, fn, fp = 0, 0, 0
89 | for idx1, idx2 in enumerate(permutation):
90 | if angular_distance(gt[idx1], output[idx2]) < acceptable_window:
91 | tp += 1
92 | elif gt[idx1] == np.inf:
93 | fp += 1
94 | elif output[idx2] == np.inf:
95 | fn += 1
96 | else:
97 | fn += 1
98 | fp += 1
99 |
100 | return permutation, (tp, fn, fp)
101 |
--------------------------------------------------------------------------------
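A toy example of the evaluation helpers above; the angles and signals are made up, and note that `find_best_permutation_prec_recall` pads its input lists with `np.inf` in place, so pass copies if you want to reuse them:

```
import numpy as np
from cos.helpers.eval_utils import si_sdr, find_best_permutation_prec_recall

# One correctly localized source, one missed, one spurious detection
gt_angles = [0.0, np.pi / 2]
est_angles = [np.pi / 2 + 0.01, 2.0]
_, (tp, fn, fp) = find_best_permutation_prec_recall(list(gt_angles), list(est_angles))
print(tp, fn, fp)  # 1 1 1

# SI-SDR of a noisy estimate against its reference
ref = np.random.randn(44100)
est = ref + 0.1 * np.random.randn(44100)
print(si_sdr(est, ref))  # roughly 20 dB
```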
/cos/helpers/visualization.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 |
5 | import cos.helpers.utils as utils
6 | from cos.helpers.constants import get_mic_diagram
7 |
8 |
9 |
10 | def draw_diagram(voice_positions, candidate_angles, angle_window_size, output_file):
11 | """
12 | Draws the setup of all the voices in space, and colored triangles for the beams
13 | """
14 | import matplotlib
15 | matplotlib.use("Agg")
16 | import matplotlib.pyplot as plt
17 | from matplotlib.patches import Circle, Wedge, Polygon
18 | from matplotlib.collections import PatchCollection
19 | matplotlib.style.use('ggplot')
20 |
21 | colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
22 | colors = [colors[0], colors[3], colors[2], colors[1]]
23 |
24 | fig, ax = plt.subplots()
25 | ax.set(xlim=(-5, 5), ylim = (-5, 5))
26 | ax.set_aspect("equal")
27 | ax.annotate('X', xy=(-5,0), xytext=(5.2,0),
28 | arrowprops={'arrowstyle': '<|-', "color": "black", "linewidth":3}, va='center', fontsize=20)
29 | ax.annotate('Y', xy=(0,-5), xytext=(-0.25, 5.5),
30 | arrowprops={'arrowstyle': '<|-', "color": "black", "linewidth":3}, va='center', fontsize=20)
31 |
32 | plt.tick_params(axis='both',
33 | which='both', bottom='off',
34 | top='off', labelbottom='off', right='off', left='off', labelleft='off'
35 | )
36 |
37 | for pos in voice_positions:
38 | if pos[0] != 0.0:
39 | a_circle = plt.Circle((pos[0], pos[1]), 0.3, color='b', fill=False)
40 | ax.add_artist(a_circle)
41 |
42 | patches = []
43 | for idx, target_angle in enumerate(candidate_angles):
44 | vertices = angle_to_triangle(target_angle, angle_window_size) * 4.96
45 |
46 | ax.fill(vertices[:, 0], vertices[:, 1],
47 | edgecolor='black', linewidth=2, alpha=0.6)
48 |
49 | mic = get_mic_diagram()
50 | patch = matplotlib.patches.PathPatch(mic, fill=True, facecolor='black')
51 | ax.add_patch(patch)
52 | ax.tick_params(axis='both', which='both', labelcolor="white", colors="white")
53 | plt.savefig(output_file)
54 |
55 |
56 | def angle_to_triangle(target_angle, angle_window_size):
57 | """
58 | Takes a target angle and window size and returns a
59 | triangle corresponding to that pie slice
60 | """
61 |     first_point = [0, 0]  # Always start at the origin
62 | second_point = angle_to_point(utils.convert_angular_range(target_angle - angle_window_size/2))
63 | third_point = angle_to_point(utils.convert_angular_range(target_angle + angle_window_size/2))
64 |
65 | return(np.array([first_point, second_point, third_point]))
66 |
67 |
68 | def angle_to_point(angle):
69 | """Angle must be -pi to pi"""
70 | if -np.pi <= angle < -3*np.pi/4:
71 | return[-1, -np.tan(angle + np.pi)]
72 |
73 | elif -3*np.pi/4 <= angle < -np.pi/2:
74 | return[-np.tan(-np.pi/2 - angle), -1]
75 |
76 | elif -np.pi/2 <= angle < -np.pi/4:
77 | return[np.tan(angle + np.pi/2), -1]
78 |
79 | elif -np.pi/4 <= angle < 0:
80 | return[1, -np.tan(-angle)]
81 |
82 | elif 0 <= angle < np.pi / 4:
83 | return [1, np.tan(angle)]
84 |
85 | elif np.pi/4 <= angle < np.pi/2:
86 | return [np.tan(np.pi/2 - angle), 1]
87 |
88 | elif np.pi/2 <= angle <= 3*np.pi/4:
89 | return [-np.tan(angle - np.pi/2), 1]
90 |
91 | elif 3*np.pi/4 < angle <= np.pi:
92 | return [-1, np.tan(np.pi - angle)]
93 |
--------------------------------------------------------------------------------
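A quick check of `angle_to_point` above: it projects an angle onto the boundary of the unit square so that the triangle ("pie slice") vertices used by `draw_diagram` preserve the angle exactly. The test values are arbitrary:

```
import numpy as np
from cos.helpers.visualization import angle_to_point

for angle in [-2.5, -1.0, 0.3, 2.0]:
    x, y = angle_to_point(angle)
    assert np.isclose(np.arctan2(y, x), angle)   # direction is preserved
    assert np.isclose(max(abs(x), abs(y)), 1.0)  # point lies on the unit square
```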
/cos/helpers/ibm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import multiprocessing.dummy as mp
4 | import os
5 |
6 | from pathlib import Path
7 |
8 | import librosa
9 | import numpy as np
10 | import tqdm
11 |
12 | from scipy.signal import stft, istft
13 |
14 | from cos.helpers.eval_utils import compute_sdr
15 | from cos.inference.evaluate_synthetic import get_items
16 | from cos.helpers.utils import check_valid_dir
17 |
18 | def compute_ibm(gt, mix, alpha, theta=0.5):
19 | """
20 | Computes the Ideal Binary Mask SI-SDR
21 | gt: (n_voices, n_channels, t)
22 | mix: (n_channels, t)
23 | """
24 | n_voices = gt.shape[0]
25 | nfft = 2048
26 |     eps = np.finfo(float).eps
27 |
28 | N = mix.shape[-1] # number of samples
29 | X = stft(mix, nperseg=nfft)[2]
30 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame)
31 |
32 | # perform separation
33 | estimates = []
34 | for gt_idx in range(n_voices):
35 | # compute STFT of target source
36 | Yj = stft(gt[gt_idx], nperseg=nfft)[2]
37 |
38 | # Create binary Mask
39 | mask = np.divide(np.abs(Yj)**alpha, (eps + np.abs(X) ** alpha))
40 | mask[np.where(mask >= theta)] = 1
41 | mask[np.where(mask < theta)] = 0
42 |
43 | Yj = np.multiply(X, mask)
44 | target_estimate = istft(Yj)[1][:,:N]
45 |
46 | estimates.append(target_estimate)
47 |
48 | estimates = np.array(estimates) # (nvoice, 6, 6*sr)
49 |
50 | # eval
51 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr)
52 | eval_gt = gt # (nvoice, 6, 6*sr)
53 | eval_est = estimates
54 |
55 | SDR_in = []
56 | SDR_out = []
57 | for i in range(n_voices):
58 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=True)) # scalar
59 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=True)) # scalar
60 |
61 | output = np.array([SDR_in, SDR_out]) # (2, nvoice)
62 |
63 | return output
64 |
65 |
66 | def main(args):
67 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*')))
68 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)]
69 |
70 | all_input_sdr = [0] * len(all_dirs)
71 | all_output_sdr = [0] * len(all_dirs)
72 |
73 | def evaluate_dir(idx):
74 | curr_dir = all_dirs[idx]
75 | # Loads the data
76 | mixed_data, gt = get_items(curr_dir, args)
77 | gt = np.array([x.data for x in gt])
78 | output = compute_ibm(gt, mixed_data, alpha=args.alpha)
79 | all_input_sdr[idx] = output[0]
80 | all_output_sdr[idx] = output[1]
81 | print("Running median SDRi: ",
82 | np.median(np.array(all_output_sdr[:idx+1]) - np.array(all_input_sdr[:idx+1])))
83 |
84 | pool = mp.Pool(args.n_workers)
85 | with tqdm.tqdm(total=len(all_dirs)) as pbar:
86 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))):
87 | pbar.update()
88 |
89 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs))
90 | pool.close()
91 | pool.join()
92 |
93 | print("Median SI-SDRi: ",
94 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten()))
95 |
96 | np.save("IBM_{}voices_{}kHz.npy".format(args.n_voices, args.sr),
97 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()]))
98 |
99 |
100 | if __name__ == '__main__':
101 | parser = argparse.ArgumentParser()
102 | parser.add_argument('input_dir', type=str, help="Path to the input dir")
103 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate")
104 | parser.add_argument('--n_channels',
105 | type=int,
106 | default=2,
107 | help="Number of channels")
108 | parser.add_argument('--n_workers',
109 | type=int,
110 | default=8,
111 | help="Number of parallel workers")
112 | parser.add_argument('--n_voices',
113 | type=int,
114 | default=2,
115 | help="Number of voices in the dataset")
116 | parser.add_argument('--alpha',
117 | type=int,
118 | default=1,
119 | help="See the original SigSep code for an explanation")
120 | args = parser.parse_args()
121 |
122 | main(args)
123 |
124 |
125 |
--------------------------------------------------------------------------------
/cos/helpers/irm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import multiprocessing.dummy as mp
4 | import os
5 |
6 | from pathlib import Path
7 |
8 | import librosa
9 | import numpy as np
10 | import tqdm
11 |
12 | from scipy.signal import stft, istft
13 |
14 | from cos.helpers.eval_utils import compute_sdr
15 | from cos.inference.evaluate_synthetic import get_items
16 | from cos.helpers.utils import check_valid_dir
17 |
18 |
19 | def compute_irm(gt, mix, alpha):
20 | """
21 | Computes the Ideal Ratio Mask SI-SDR
22 | gt: (n_voices, n_channels, t)
23 | mix: (n_channels, t)
24 | """
25 | n_voices = gt.shape[0]
26 | nfft = 2048
27 | hop = 1024
28 |     eps = np.finfo(float).eps
29 |
30 | N = mix.shape[-1] # number of samples
31 | X = stft(mix, nperseg=nfft)[2]
32 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame)
33 |
34 | # Compute sources spectrograms
35 | P = []
36 | for gt_idx in range(n_voices):
37 | P.append(np.abs(stft(gt[gt_idx], nperseg=nfft)[2]) ** alpha)
38 |
39 | # compute model as the sum of spectrograms
40 | model = eps
41 | for gt_idx in range(n_voices):
42 | model += P[gt_idx]
43 |
44 | # perform separation
45 | estimates = []
46 | for gt_idx in range(n_voices):
47 | # Create a ratio Mask
48 | mask = np.divide(np.abs(P[gt_idx]), model)
49 |
50 | # apply mask
51 | Yj = np.multiply(X, mask)
52 |
53 | target_estimate = istft(Yj)[1][:,:N]
54 |
55 | estimates.append(target_estimate)
56 |
57 | estimates = np.array(estimates) # (nvoice, 6, 6*sr)
58 |
59 | # eval
60 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr)
61 | eval_gt = gt # (nvoice, 6, 6*sr)
62 | eval_est = estimates
63 |
64 | SDR_in = []
65 | SDR_out = []
66 | for i in range(n_voices):
67 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=True)) # scalar
68 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=True)) # scalar
69 |
70 | output = np.array([SDR_in, SDR_out]) # (2, nvoice)
71 |
72 | return output
73 |
74 | def main(args):
75 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*')))
76 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)]
77 |
78 | all_input_sdr = [0] * len(all_dirs)
79 | all_output_sdr = [0] * len(all_dirs)
80 |
81 | def evaluate_dir(idx):
82 | curr_dir = all_dirs[idx]
83 | # Loads the data
84 | mixed_data, gt = get_items(curr_dir, args)
85 | gt = np.array([x.data for x in gt])
86 | output = compute_irm(gt, mixed_data, alpha=args.alpha)
87 | all_input_sdr[idx] = output[0]
88 | all_output_sdr[idx] = output[1]
89 |
90 | pool = mp.Pool(args.n_workers)
91 | with tqdm.tqdm(total=len(all_dirs)) as pbar:
92 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))):
93 | pbar.update()
94 |
95 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs))
96 | pool.close()
97 | pool.join()
98 |
99 | print("Median SI-SDRi: ",
100 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten()))
101 |
102 | np.save("IRM_{}voices_{}kHz.npy".format(args.n_voices, args.sr),
103 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()]))
104 |
105 |
106 | if __name__ == '__main__':
107 | parser = argparse.ArgumentParser()
108 | parser.add_argument('input_dir', type=str, help="Path to the input dir")
109 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate")
110 | parser.add_argument('--n_channels',
111 | type=int,
112 | default=2,
113 | help="Number of channels")
114 | parser.add_argument('--n_workers',
115 | type=int,
116 | default=8,
117 | help="Number of parallel workers")
118 | parser.add_argument('--n_voices',
119 | type=int,
120 | default=2,
121 | help="Number of voices in the dataset")
122 | parser.add_argument('--alpha',
123 | type=int,
124 | default=1,
125 | help="See the original SigSep code for an explanation")
126 | args = parser.parse_args()
127 |
128 | main(args)
129 |
--------------------------------------------------------------------------------
/cos/helpers/utils.py:
--------------------------------------------------------------------------------
1 | """A collection of useful helper functions"""
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from pathlib import Path
7 |
8 | from cos.helpers.constants import SPEED_OF_SOUND
9 |
10 |
11 | def shift_mixture(input_data, target_position, mic_radius, sr, inverse=False):
12 | """
13 | Shifts the input according to the voice position. This
14 | lines up the voice samples in the time domain coming from a target_angle
15 | Args:
16 | input_data - M x T numpy array or torch tensor
17 | target_position - The location where the data should be aligned
18 | mic_radius - In meters. The number of mics is inferred from
19 |                      the input_data
20 | sr - Sample Rate in samples/sec
21 | inverse - Whether to align or undo a previous alignment
22 |
23 | Returns: shifted data and a list of the shifts
24 | """
25 | # elevation_angle = 0.0 * np.pi / 180
26 | # target_height = 3.0 * np.tan(elevation_angle)
27 | # target_position = np.append(target_position, target_height)
28 |
29 | num_channels = input_data.shape[0]
30 |
31 | # Must match exactly the generated or captured data
32 | mic_array = [[
33 | mic_radius * np.cos(2 * np.pi / num_channels * i),
34 | mic_radius * np.sin(2 * np.pi / num_channels * i),
35 | ] for i in range(num_channels)]
36 |
37 | # Mic 0 is the canonical position
38 | distance_mic0 = np.linalg.norm(mic_array[0] - target_position)
39 | shifts = [0]
40 |
41 | # Check if numpy or torch
42 | if isinstance(input_data, np.ndarray):
43 | shift_fn = np.roll
44 | elif isinstance(input_data, torch.Tensor):
45 | shift_fn = torch.roll
46 | else:
47 | raise TypeError("Unknown input data type: {}".format(type(input_data)))
48 |
49 | # Shift each channel of the mixture to align with mic0
50 | for channel_idx in range(1, num_channels):
51 | distance = np.linalg.norm(mic_array[channel_idx] - target_position)
52 | distance_diff = distance - distance_mic0
53 | shift_time = distance_diff / SPEED_OF_SOUND
54 | shift_samples = int(round(sr * shift_time))
55 | if inverse:
56 | input_data[channel_idx] = shift_fn(input_data[channel_idx],
57 | shift_samples)
58 | else:
59 | input_data[channel_idx] = shift_fn(input_data[channel_idx],
60 | -shift_samples)
61 | shifts.append(shift_samples)
62 |
63 | return input_data, shifts
64 |
65 |
66 | def angular_distance(angle1, angle2):
67 | """
68 |     Computes the distance in radians between angle1 and angle2.
69 | We assume they are between -pi and pi
70 | """
71 | d1 = abs(angle1 - angle2)
72 | d2 = abs(angle1 - angle2 + 2 * np.pi)
73 | d3 = abs(angle2 - angle1 + 2 * np.pi)
74 |
75 | return min(d1, d2, d3)
76 |
77 | def get_starting_angles(window_size):
78 | """Returns the list of target angles for a window size"""
79 | divisor = int(round(2 * np.pi / window_size))
80 | return np.array(list(range(-divisor + 1, divisor, 2))) * np.pi / divisor
81 |
82 |
83 | def to_categorical(index: int, num_classes: int):
84 | """Creates a 1-hot encoded np array"""
85 | data = np.zeros((num_classes))
86 | data[index] = 1
87 | return data
88 |
89 | def convert_angular_range(angle: float):
90 | """Converts an arbitrary angle to the range [-pi pi]"""
91 | corrected_angle = angle % (2 * np.pi)
92 | if corrected_angle > np.pi:
93 | corrected_angle -= (2 * np.pi)
94 |
95 | return corrected_angle
96 |
97 | def trim_silence(audio, window_size=22050, cutoff=0.001):
98 | """Trims all silence within an audio file"""
99 | idx = 0
100 | new_audio = []
101 | while idx * window_size < audio.shape[1]:
102 | segment = audio[:, idx*window_size:(idx+1)*window_size]
103 | if segment.std() > cutoff:
104 | new_audio.append(segment)
105 | idx += 1
106 |
107 | return np.concatenate(new_audio, axis=1)
108 |
109 |
110 | def check_valid_dir(dir, requires_n_voices=2):
111 | """Checks that there is at least n voices"""
112 | if len(list(Path(dir).glob('*_voice00.wav'))) < 1:
113 | return False
114 |
115 | if requires_n_voices == 2:
116 | if len(list(Path(dir).glob('*_voice01.wav'))) < 1:
117 | return False
118 |
119 | if requires_n_voices == 3:
120 | if len(list(Path(dir).glob('*_voice02.wav'))) < 1:
121 | return False
122 |
123 | if requires_n_voices == 4:
124 | if len(list(Path(dir).glob('*_voice03.wav'))) < 1:
125 | return False
126 |
127 | if len(list(Path(dir).glob('metadata.json'))) < 1:
128 | return False
129 |
130 | return True
131 |
--------------------------------------------------------------------------------
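A small sketch of what `shift_mixture` above does: an impulse rendered with the physically correct per-channel delays for a source position is realigned so that every channel matches mic 0. The mic geometry mirrors the synthetic 6-mic setup; the source position is a hypothetical value:

```
import numpy as np
from cos.helpers.utils import shift_mixture

sr, mic_radius, n_mics = 44100, 0.0725, 6
target_pos = np.array([2.0, 1.0])          # hypothetical source position in meters

# Start from perfectly aligned impulses, then apply the physical delays (inverse=True)
aligned = np.zeros((n_mics, 2000))
aligned[:, 1000] = 1.0
delayed, shifts = shift_mixture(aligned.copy(), target_pos, mic_radius, sr, inverse=True)
print(shifts)                              # per-channel delays in samples, mic 0 is 0

# Re-aligning the delayed mixture recovers the original impulses
realigned, _ = shift_mixture(delayed, target_pos, mic_radius, sr)
print(np.array_equal(realigned, aligned))  # True
```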
/README.md:
--------------------------------------------------------------------------------
1 | # The Cone of Silence: Speech Separation by Localization
2 | 
3 |
4 | ## Authors
5 | [Teerapat Jenrungrot](https://mjenrungrot.com/)**\***, [Vivek Jayaram](http://www.vivekjayaram.com/)**\***, [Steve Seitz](https://homes.cs.washington.edu/~seitz/), and [Ira Kemelmacher-Shlizerman](https://homes.cs.washington.edu/~kemelmi/)
6 | *\*Co-First Authors*
7 | University of Washington
8 |
9 | ## [Project Page](http://grail.cs.washington.edu/projects/cone-of-silence/)
10 | Video and audio demos are available at the project page
11 |
12 | ### [Paper](https://arxiv.org/pdf/2010.06007.pdf)
13 | 34th Conference on Neural Information Processing Systems, NeurIPS 2020. (Oral)
14 |
15 | ### Blog Post - Coming Soon
16 |
17 | ### Summary
18 | Our method performs source separation and localization for human speakers. Key features include handling an arbitrary number of speakers and moving speakers with a single network. This code allows you to run and evaluate our method on synthetically rendered data. If you have a multi-microphone array, you can also obtain real results like the ones in our demo video.
19 |
20 | ## Getting Started
21 | Clone the repository:
22 | ```
23 | git clone https://github.com/vivjay30/Cone-of-Silence.git
24 | cd Cone-of-Silence
25 | export PYTHONPATH=$PYTHONPATH:`pwd`
26 | ```
27 |
28 | Make sure all the requirements in requirements.txt are installed. We tested the code with torch 1.3.0, librosa 0.7.0, and CUDA 10.0.
29 |
30 | Download Pretrained Models: [Here](https://drive.google.com/drive/folders/1YeuHPvqmaPMGvcSOb9J-hnLDYSbK1S2c?usp=sharing). If you're working in a command-line environment, we recommend using [gdown](https://github.com/wkentaro/gdown) to download the checkpoint files.
31 |
32 | ```
33 | cd checkpoints
34 | gdown --id 1OcLxp0s_TN78iKaFrLAqjIoTKeOTUgKw # Download realdata_4mics_.03231m_44100kHz.pt
35 | gdown --id 18dpUnng_8ZUlDrQsg5VymypFnFlQBPIp # Download synthetic_6mics_.0725m_44100kHz.pt
36 | ```
37 |
38 | ## Quickstart: Running on Real Data
39 | You can easily produce results like those in our demo videos. Our pre-trained real models work with the 4-mic [Seeed ReSpeaker Mic Array v2.0](https://wiki.seeedstudio.com/ReSpeaker_Mic_Array_v2.0/). We even provide a sample 4-channel file for you to run [here](https://drive.google.com/drive/folders/1YeuHPvqmaPMGvcSOb9J-hnLDYSbK1S2c?usp=sharing). When you capture the data, it must be an m-channel recording. Run the full command like the one below. For moving sources, reduce the duration flag to 1.5 and add `--moving` to stop the search at a coarse window.
40 | ```
41 | python cos/inference/separation_by_localization.py \
42 | /path/to/model.pt \
43 | /path/to/input_file.wav \
44 | outputs/some_dirname/ \
45 | --n_channels 4 \
46 | --sr 44100 \
47 | --mic_radius .03231 \
48 | --use_cuda
49 | ```
50 |
51 | ## Rendering Synthetic Spatial Data
52 | For training and evaluation, we use synthetically rendered spatial data. We place the voices in a virtual room and render the arrival times, level differences, and reverb using pyroomacoustics. We used the VCTK dataset, but any voice dataset would work. An example command is below:
53 | ```
54 | python cos/generate_dataset.py \
55 | /path/to/VCTK/data \
56 | ./outputs/somename \
57 | --input_background_path any_bg_audio.wav \
58 | --n_voices 2 \
59 | --n_outputs 1000 \
60 | --mic_radius {radius} \
61 | --n_mics {M}
62 | ```
63 |
64 | ## Training on Synthetic Data
65 | Below is an example command to train on the rendered data. You need to replace the training and testing dirs with the paths to the generated datasets from above. We highly recommend initializing with a pre-trained model (even if the number of mics is different) and not training from scratch.
66 | ```
67 | python cos/training/train.py \
68 | ./generated/train_dir \
69 | ./generated/test_dir \
70 | --name experiment_name \
71 | --checkpoints_dir ./checkpoints \
72 | --pretrain_path ./path/to/pretrained.pt \
73 | --batch_size 8 \
74 | --mic_radius {radius} \
75 | --n_mics {M} \
76 | --use_cuda
77 | ```
78 | __Note__: The training code expects you to have `sox` installed. The easiest way to install is to install it using conda as follows: `conda install -c conda-forge -y sox`.
79 |
80 | ## Training on Real Data
81 | For those looking to improve on the pretrained models, we recommend gathering a lot more real data. We did not have the ability to gather very accurately positioned real data in a proper sound chamber. By training with a lot more real data, the results will almost certainly improve. All you have to do is create synthetic composites of speakers in the same format as the synthetic data, and run the same training script.
82 |
83 | ## Evaluation
84 | For the synthetic data and evaluation, we use a setup of 6 mics in a circle of radius 7.25 cm. The following instructions obtain results on mixtures of N voices and no backgrounds. First, generate a synthetic dataset with the microphone setup specified above, using ```--n_voices 8``` and the test set of VCTK. Then run the following script:
85 |
86 | ```
87 | python cos/inference/evaluate_synthetic.py \
88 | /path/to/rendered_data/ \
89 | /path/to/model.pt \
90 | --n_channels 6 \
91 | --mic_radius .0725 \
92 | --sr 44100 \
93 | --use_cuda \
94 | --n_workers 1 \
95 | --n_voices {N}
96 | ```
97 |
98 | Add ```--prec_recall``` separately to get the precision and recall.
99 |
100 | | Number of Speakers N | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
101 | |----------------------|-------|-------|-------|-------|-------|-------|-------|
102 | | Median SI-SDRi (dB) | 13.9 | 13.2 | 12.2 | 10.8 | 9.1 | 7.2 | 6.3 |
103 | | Median Angular Error | 2.0 | 2.3 | 2.7 | 3.5 | 4.4 | 5.2 | 6.3 |
104 | | Precision | 0.947 | 0.936 | 0.897 | 0.912 | 0.932 | 0.936 | 0.966 |
105 | | Recall               | 0.979 | 0.972 | 0.915 | 0.898 | 0.859 | 0.825 | 0.785 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
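Before running the quickstart command in the README above, it can help to confirm the capture really is a multi-channel recording at the expected rate. A minimal check with `soundfile` (already in requirements.txt), using a hypothetical file name:

```
import soundfile as sf

audio, sr = sf.read("/path/to/input_file.wav")  # hypothetical capture
print(audio.shape, sr)   # expect (T, 4) and 44100 for the pretrained ReSpeaker model
```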
/cos/helpers/mwf.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import multiprocessing.dummy as mp
5 | import os
6 |
7 | from pathlib import Path
8 |
9 | import librosa
10 | import numpy as np
11 | import tqdm
12 |
13 | from scipy.signal import stft, istft
14 |
15 | from cos.helpers.eval_utils import compute_sdr
16 | from cos.inference.evaluate_synthetic import get_items
17 | from cos.helpers.utils import check_valid_dir
18 |
19 |
20 | def invert(M,eps):
21 | """"inverting matrices M (matrices are the two last dimensions).
22 | This is assuming that these are 2x2 matrices, using the explicit
23 | inversion formula available in that case."""
24 | invDet = 1.0/(eps + M[...,0,0]*M[...,1,1] - M[...,0,1]*M[...,1,0])
25 | invM = np.zeros(M.shape,dtype='complex')
26 | invM[...,0,0] = invDet*M[...,1,1]
27 | invM[...,1,0] = -invDet*M[...,1,0]
28 | invM[...,0,1] = -invDet*M[...,0,1]
29 | invM[...,1,1] = invDet*M[...,0,0]
30 | return invM
31 |
32 |
33 | def compute_mwf(gt, mix):
34 | """
35 |     Computes the Multichannel Wiener Filter (MWF) SI-SDR
36 | gt: (n_voices, n_channels, t)
37 | mix: (n_channels, t)
38 | """
39 | n_voices = gt.shape[0]
40 | nfft = 2048
41 | hop = 1024
42 |     eps = np.finfo(float).eps
43 |
44 | N = mix.shape[-1] # number of samples
45 | X = stft(mix, nperseg=nfft)[2]
46 | (I, F, T) = X.shape # (6, nfft//2 +1, n_frame)
47 |
48 |     # Allocate variables P: PSD, R: Spatial Covariance Matrices
49 | P = []
50 | R = []
51 | for gt_idx in range(n_voices):
52 | # compute STFT of target source
53 | Yj = stft(gt[gt_idx], nperseg=nfft)[2]
54 |
55 | # Learn Power Spectral Density and spatial covariance matrix
56 | #-----------------------------------------------------------
57 |
58 | # 1/ compute observed covariance for source
59 | Rjj = np.zeros((F,T,I,I), dtype='complex')
60 | for (i1,i2) in itertools.product(range(I),range(I)):
61 | Rjj[...,i1,i2] = Yj[i1,...]*np.conj(Yj[i2,...])
62 |
63 | # 2/ compute first naive estimate of the source spectrogram as the
64 | # average of spectrogram over channels
65 | P.append(np.mean(np.abs(Yj)**2,axis=0))
66 |
67 | # 3/ take the spatial covariance matrix as the average of
68 |         # the observed Rjj weighted by 1/Pj. This is because the
69 | # covariance is modeled as Pj Rj
70 | R.append(np.mean(Rjj / (eps+P[-1][...,None,None]), axis = 1))
71 |
72 | # add some regularization to this estimate: normalize and add small
73 |         # identity matrix, so we are sure it behaves well numerically.
74 | R[-1] = R[-1] * I/ np.trace(R[-1]) + eps * np.tile(np.eye(I,dtype='complex64')[None,...],(F,1,1))
75 |
76 | # 4/ Now refine the power spectral density estimate. This is to better
77 | # estimate the PSD in case the source has some correlations between
78 | # channels.
79 |
80 | # invert Rj
81 | Rj_inv = invert(R[-1],eps)
82 |
83 | # now compute the PSD
84 | P[-1]=0
85 | for (i1,i2) in itertools.product(range(I),range(I)):
86 | P[-1] += 1./I*np.real(Rj_inv[:,i1,i2][:,None]*Rjj[...,i2,i1])
87 |
88 | # All parameters are estimated. compute the mix covariance matrix as
89 | # the sum of the sources covariances.
90 | Cxx = 0
91 | for gt_idx in range(n_voices):
92 | Cxx += P[gt_idx][...,None,None]*R[gt_idx][:,None,...]
93 | # we need its inverse for computing the Wiener filter
94 | invCxx = invert(Cxx,eps)
95 |
96 | # perform separation
97 | estimates = []
98 | for gt_idx in range(n_voices):
99 | # computes multichannel Wiener gain as Pj Rj invCxx
100 | G = np.zeros(invCxx.shape,dtype='complex64')
101 | SR = P[gt_idx][...,None,None]*R[gt_idx][:,None,...]
102 | for (i1,i2,i3) in itertools.product(range(I),range(I),range(I)):
103 | G[...,i1,i2] += SR[...,i1,i3]*invCxx[...,i3,i2]
104 | SR = 0 #free memory
105 |
106 | # separates by (matrix-)multiplying this gain with the mix.
107 | Yj=0
108 | for i in range(I):
109 | Yj+=G[...,i]*X[i,...,None]
110 | Yj = np.rollaxis(Yj,-1) #gets channels back in first position
111 |
112 |         # invert to time domain
113 | target_estimate = istft(Yj)[1][:,:N]
114 |
115 | estimates.append(target_estimate)
116 |
117 | estimates = np.array(estimates) # (nvoice, 6, 6*sr)
118 | # eval
119 | eval_mix = np.repeat(mix[np.newaxis, :, :], n_voices, axis=0) # (nvoice, 6, 6*sr)
120 | eval_gt = gt # (nvoice, 6, 6*sr)
121 | eval_est = estimates
122 |
123 | SDR_in = []
124 | SDR_out = []
125 | for i in range(n_voices):
126 | SDR_in.append(compute_sdr(eval_gt[i], eval_mix[i], single_channel=False)) # scalar
127 | SDR_out.append(compute_sdr(eval_gt[i], eval_est[i], single_channel=False)) # scalar
128 |
129 | output = np.array([SDR_in, SDR_out]) # (2, nvoice)
130 | return output
131 |
132 |
133 | def main(args):
134 | all_dirs = sorted(list(Path(args.input_dir).glob('[0-9]*')))
135 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)]
136 |
137 | all_input_sdr = [0] * len(all_dirs)
138 | all_output_sdr = [0] * len(all_dirs)
139 |
140 | def evaluate_dir(idx):
141 | curr_dir = all_dirs[idx]
142 | # Loads the data
143 | mixed_data, gt = get_items(curr_dir, args)
144 | gt = np.array([x.data for x in gt])
145 | output = compute_mwf(gt, mixed_data)
146 | all_input_sdr[idx] = output[0]
147 | all_output_sdr[idx] = output[1]
148 |
149 | pool = mp.Pool(args.n_workers)
150 | with tqdm.tqdm(total=len(all_dirs)) as pbar:
151 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))):
152 | pbar.update()
153 |
154 | # tqdm.tqdm(pool.imap(evaluate_dir, range(len(all_dirs))), total=len(all_dirs))
155 | pool.close()
156 | pool.join()
157 |
158 | print("Median SI-SDRi: ",
159 | np.median(np.array(all_output_sdr).flatten() - np.array(all_input_sdr).flatten()))
160 |
161 | np.save("MWF_{}voices_{}kHz.npy".format(args.n_voices, args.sr),
162 | np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()]))
163 |
164 |
165 | if __name__ == '__main__':
166 | parser = argparse.ArgumentParser()
167 | parser.add_argument('input_dir', type=str, help="Path to the input dir")
168 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate")
169 | parser.add_argument('--n_channels',
170 | type=int,
171 | default=2,
172 | help="Number of channels")
173 | parser.add_argument('--n_workers',
174 | type=int,
175 | default=8,
176 | help="Number of parallel workers")
177 | parser.add_argument('--n_voices',
178 | type=int,
179 | default=2,
180 | help="Number of voices in the dataset")
181 | parser.add_argument('--alpha',
182 | type=int,
183 | default=1,
184 | help="See the original SigSep code for an explanation")
185 | args = parser.parse_args()
186 |
187 | main(args)
188 |
--------------------------------------------------------------------------------
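A quick numerical check of the explicit 2x2 `invert` helper above against `numpy.linalg.inv`, on a random batch of complex matrices (assuming the `cos` package is importable; the batch shape is arbitrary):

```
import numpy as np
from cos.helpers.mwf import invert

rng = np.random.RandomState(0)
M = rng.randn(5, 7, 2, 2) + 1j * rng.randn(5, 7, 2, 2)  # batch of 2x2 complex matrices
eps = np.finfo(np.float64).eps
print(np.allclose(invert(M, eps), np.linalg.inv(M)))     # True (up to the eps regularization)
```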
/cos/training/synthetic_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Torch dataset object for synthetically rendered
3 | spatial data
4 | """
5 | import json
6 | import random
7 |
8 | from typing import Tuple
9 | from pathlib import Path
10 |
11 | import torch
12 | import numpy as np
13 | import librosa
14 |
15 | import cos.helpers.utils as utils
16 | from cos.helpers.constants import ALL_WINDOW_SIZES, FAR_FIELD_RADIUS
17 | from cos.training.data_augmentation import RandomAudioPerturbation
18 |
19 |
20 | class SyntheticDataset(torch.utils.data.Dataset):
21 | """
22 | Synthetic Dataset of mixed waveforms and their corresponding ground truth waveforms
23 |     recorded at different microphones.
24 |
25 | Data format is a pair of Tensors containing mixed waveforms and
26 | ground truth waveforms respectively. The tensor's dimension is formatted
27 | as (n_microphone, duration).
28 | """
29 | def __init__(self, input_dir, n_mics=6, sr=44100, perturb_prob=0.0,
30 | window_idx=-1, negatives=0.2, mic_radius=0.0725):
31 | super().__init__()
32 | self.dirs = sorted(list(Path(input_dir).glob('[0-9]*')))
33 |
34 | # Physical params
35 | self.n_mics = n_mics
36 | self.sr = sr
37 | self.mic_radius = mic_radius
38 |
39 | # Data augmentation
40 | self.perturb_prob = perturb_prob
41 |
42 | # Training params
43 | self.negatives = negatives # Fraction of negatives in training
44 | self.window_idx = window_idx # Set to -1 to pick randomly
45 |
46 | def __len__(self) -> int:
47 | return len(self.dirs)
48 |
49 | def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
50 | """
51 | Returns:
52 | mixed_data - M x T
53 | target_voice_data - M x T
54 | window_idx_one_hot - 1-D
55 | """
56 | num_windows = len(ALL_WINDOW_SIZES)
57 | if self.window_idx == -1:
58 |             curr_window_idx = np.random.randint(0, num_windows)
59 | else:
60 | curr_window_idx = self.window_idx
61 |
62 | curr_window_size = ALL_WINDOW_SIZES[curr_window_idx]
63 | candidate_angles = utils.get_starting_angles(curr_window_size)
64 |
65 | curr_dir = self.dirs[idx]
66 |
67 | # Get metadata
68 | with open(Path(curr_dir) / 'metadata.json') as json_file:
69 | metadata = json.load(json_file)
70 |
71 | # Random split of negatives and positives
72 | if np.random.uniform() < self.negatives:
73 | target_angle = self.get_negative_region(metadata, candidate_angles)
74 | else:
75 | target_angle = get_positive_region(metadata, candidate_angles)
76 |
77 | all_sources, target_voice_data = self.get_mixture_and_gt(
78 | metadata, curr_dir, target_angle, curr_window_size)
79 |
80 | # Mixture
81 | all_sources = torch.stack(all_sources, dim=0)
82 | mixed_data = torch.sum(all_sources, dim=0)
83 |
84 | # GTs
85 | target_voice_data = torch.stack(target_voice_data, dim=0)
86 | target_voice_data = torch.sum(target_voice_data, dim=0)
87 |
88 | window_idx_one_hot = torch.tensor(
89 | utils.to_categorical(curr_window_idx, num_windows)).float()
90 |
91 | return (mixed_data, target_voice_data, window_idx_one_hot)
92 |
93 | def get_negative_region(self, metadata, candidate_angles):
94 | """Chooses a target angle which is adjacent to a voice region"""
95 | # Choose a random voice
96 | voice_keys = [x for x in metadata if "voice" in x]
97 | random_key = random.choice(voice_keys)
98 | voice_pos = np.array(metadata[random_key]["position"])
99 | voice_angle = np.arctan2(voice_pos[1], voice_pos[0])
100 | angle_idx = (np.abs(candidate_angles - voice_angle)).argmin()
101 |
102 | # Non uniform distribution to prefer regions close to a voice
103 | p = np.zeros_like(candidate_angles)
104 | for i in range(p.shape[0]):
105 | if i == angle_idx:
106 | # Can't choose the positive region
107 | p[i] = 0
108 | else:
109 | # Regions close to the voice are weighted more
110 | dist = min(abs(i - angle_idx),
111 | (len(candidate_angles) - angle_idx + i))
112 | p[i] = 1 / (dist)
113 |
114 | p /= p.sum()
115 |
116 | # Make sure we choose a region with different per-channel shifts from the voice
117 | matching_shift = True
118 | _, true_shift = utils.shift_mixture(np.zeros(
119 | (self.n_mics, 10)), voice_pos, self.mic_radius, self.sr)
120 | while matching_shift:
121 | target_angle = np.random.choice(candidate_angles, p=p)
122 | random_pos = np.array([
123 | FAR_FIELD_RADIUS * np.cos(target_angle),
124 | FAR_FIELD_RADIUS * np.sin(target_angle)
125 | ])
126 | _, curr_shift = utils.shift_mixture(np.zeros(
127 | (self.n_mics, 10)), random_pos, self.mic_radius, self.sr)
128 | if true_shift != curr_shift:
129 | matching_shift = False
130 |
131 | return target_angle
132 |
133 | def get_mixture_and_gt(self, metadata, curr_dir, target_angle,
134 | curr_window_size):
135 | """
136 | Given a target angle and window size, this function figures out
137 | the voices inside the region and returns them as GT waveforms
138 | """
139 | target_pos = np.array([
140 | FAR_FIELD_RADIUS * np.cos(target_angle),
141 | FAR_FIELD_RADIUS * np.sin(target_angle)
142 | ])
143 | random_perturb = RandomAudioPerturbation()
144 |
145 | # Iterate over different sources
146 | all_sources = []
147 | target_voice_data = []
148 | for key in metadata.keys():
149 | gt_audio_files = sorted(
150 | list(Path(curr_dir).rglob("*" + key + ".wav")))
151 | assert len(gt_audio_files) > 0, "No files found in {}".format(
152 | curr_dir)
153 | gt_waveforms = []
154 |
155 | # Iterate over different mics
156 | for _, gt_audio_file in enumerate(gt_audio_files):
157 | gt_waveform, _ = librosa.core.load(gt_audio_file, self.sr,
158 | mono=True)
159 | gt_waveforms.append(torch.from_numpy(gt_waveform))
160 |
161 | shifted_gt, _ = utils.shift_mixture(np.stack(gt_waveforms),
162 | target_pos,
163 | self.mic_radius, self.sr)
164 |
165 | # Data augmentation
166 | if np.random.uniform() < self.perturb_prob:
167 | perturbed_source = torch.tensor(
168 | random_perturb(shifted_gt)).float()
169 | else:
170 | perturbed_source = torch.tensor(shifted_gt).float()
171 |
172 | all_sources.append(perturbed_source)
173 |
174 | # Check which foregrounds are in the angle of interest
175 | if "bg" in key:
176 | continue
177 |
178 | locs_voice = metadata[key]['position']
179 | voice_angle = np.arctan2(locs_voice[1], locs_voice[0])
180 |
181 | # Voice is inside our target area. Need to save for ground truth
182 | if abs(voice_angle - target_angle) < (curr_window_size / 2):
183 | target_voice_data.append(
184 | perturbed_source.view(perturbed_source.shape[0],
185 | perturbed_source.shape[1]))
186 |
187 | # Train with front back confusion for 2 mics
188 | elif self.n_mics == 2 and abs(-voice_angle - target_angle) < (
189 | curr_window_size / 2):
190 | target_voice_data.append(
191 | perturbed_source.view(perturbed_source.shape[0],
192 | perturbed_source.shape[1]))
193 |
194 | # Voice is not within our region. Add silence
195 | else:
196 | target_voice_data.append(
197 | torch.zeros((perturbed_source.shape[0],
198 | perturbed_source.shape[1])))
199 |
200 | return all_sources, target_voice_data
201 |
202 |
203 | def get_positive_region(metadata, candidate_angles):
204 | """Chooses a target angle containing a voice region"""
205 | # Choose a random voice
206 | voice_keys = [x for x in metadata if "voice" in x]
207 | random_key = random.choice(voice_keys)
208 | voice_pos = metadata[random_key]["position"]
209 | voice_pos = np.array(voice_pos)
210 | voice_angle = np.arctan2(voice_pos[1], voice_pos[0])
211 |
212 | # Get the sector closest to that voice
213 | angle_idx = (np.abs(candidate_angles - voice_angle)).argmin()
214 | target_angle = candidate_angles[angle_idx]
215 |
216 | return target_angle
217 |
--------------------------------------------------------------------------------
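A minimal sketch of wrapping `SyntheticDataset` above in a PyTorch `DataLoader`; the dataset path is a placeholder (it echoes the `./generated/train_dir` example in README.md) and the physical parameters simply repeat the defaults:

```
import torch
from cos.training.synthetic_dataset import SyntheticDataset

dataset = SyntheticDataset("./generated/train_dir",   # placeholder rendered dataset
                           n_mics=6, sr=44100, mic_radius=0.0725,
                           perturb_prob=0.5)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

mixed, target, window_one_hot = next(iter(loader))
# mixed, target: (batch, n_mics, T); window_one_hot: (batch, 5)
```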
/cos/training/network.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def rescale_conv(conv, reference):
8 | """
9 | Rescale a convolutional module with `reference`.
10 | """
11 | std = conv.weight.std().detach()
12 | scale = (std / reference)**0.5
13 | conv.weight.data /= scale
14 | if conv.bias is not None:
15 | conv.bias.data /= scale
16 |
17 |
18 | def rescale_module(module, reference):
19 | """
20 | Rescale a module with `reference`.
21 | """
22 | for sub in module.modules():
23 | if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
24 | rescale_conv(sub, reference)
25 |
26 |
27 | def center_trim(tensor, reference):
28 | """
29 | Trim a tensor to match with the dimension of `reference`.
30 | """
31 | if hasattr(reference, "size"):
32 | reference = reference.size(-1)
33 | diff = tensor.size(-1) - reference
34 | if diff < 0:
35 | raise ValueError("tensor must be larger than reference")
36 | if diff:
37 | tensor = tensor[..., diff // 2:-(diff - diff // 2)]
38 | return tensor
39 |
40 |
41 | def left_trim(tensor, reference):
42 | """
43 | Trim a tensor to match with the dimension of `reference`. Trims only the end.
44 | """
45 | if hasattr(reference, "size"):
46 | reference = reference.size(-1)
47 | diff = tensor.size(-1) - reference
48 | if diff < 0:
49 | raise ValueError("tensor must be larger than reference")
50 | if diff:
51 | tensor = tensor[..., 0:-diff]
52 | return tensor
53 |
54 | def normalize_input(data):
55 | """
56 | Normalizes the input to have mean 0 std 1 for each input
57 | Inputs:
58 | data - torch.tensor of size batch x n_mics x n_samples
59 | """
60 | data = (data * 2**15).round() / 2**15
61 | ref = data.mean(1) # Average across the n microphones
62 | means = ref.mean(1).unsqueeze(1).unsqueeze(2)
63 | stds = ref.std(1).unsqueeze(1).unsqueeze(2)
64 | data = (data - means) / stds
65 |
66 | return data, means, stds
67 |
68 | def unnormalize_input(data, means, stds):
69 | """
70 | Unnormalizes the step done in the previous function
71 | """
72 | data = (data * stds.unsqueeze(3) + means.unsqueeze(3))
73 | return data
74 |
75 |
76 | class CoSNetwork(nn.Module):
77 | """
78 | Cone of Silence network based on the Demucs network for audio source separation.
79 | """
80 | def __init__(
81 | self,
82 | n_audio_channels: int = 4, # pylint: disable=redefined-outer-name
83 | window_conditioning_size: int = 5,
84 | kernel_size: int = 8,
85 | stride: int = 4,
86 | context: int = 3,
87 | depth: int = 6,
88 | channels: int = 64,
89 | growth: float = 2.0,
90 | lstm_layers: int = 2,
91 | rescale: float = 0.1): # pylint: disable=redefined-outer-name
92 | super().__init__()
93 | self.n_audio_channels = n_audio_channels
94 | self.window_conditioning_size = window_conditioning_size
95 | self.kernel_size = kernel_size
96 | self.stride = stride
97 | self.context = context
98 | self.depth = depth
99 | self.channels = channels
100 | self.growth = growth
101 | self.lstm_layers = lstm_layers
102 | self.rescale = rescale
103 |
104 | self.encoder = nn.ModuleList() # Source encoder
105 | self.decoder = nn.ModuleList() # Audio output decoder
106 |
107 | activation = nn.GLU(dim=1)
108 |
109 | in_channels = n_audio_channels # Number of input channels
110 |
111 | # Wave U-Net structure
112 | for index in range(depth):
113 | encode = nn.ModuleDict()
114 | encode["conv1"] = nn.Conv1d(in_channels, channels, kernel_size,
115 | stride)
116 | encode["relu"] = nn.ReLU()
117 |
118 | encode["conv2"] = nn.Conv1d(channels, 2 * channels, 1)
119 | encode["activation"] = activation
120 |
121 | encode["gc_embed1"] = nn.Conv1d(self.window_conditioning_size, channels, 1)
122 | encode["gc_embed2"] = nn.Conv1d(self.window_conditioning_size, 2 * channels, 1)
123 |
124 | self.encoder.append(encode)
125 |
126 | decode = nn.ModuleDict()
127 | if index > 0:
128 | out_channels = in_channels
129 | else:
130 | out_channels = 2 * n_audio_channels
131 |
132 | decode["conv1"] = nn.Conv1d(channels, 2 * channels, context)
133 | decode["activation"] = activation
134 | decode["conv2"] = nn.ConvTranspose1d(channels, out_channels,
135 | kernel_size, stride)
136 |
137 | decode["gc_embed1"] = nn.Conv1d(self.window_conditioning_size, 2 * channels, 1)
138 | decode["gc_embed2"] = nn.Conv1d(self.window_conditioning_size, out_channels, 1)
139 |
140 | if index > 0:
141 | decode["relu"] = nn.ReLU()
142 | self.decoder.insert(0,
143 | decode) # Put it at the front, reverse order
144 |
145 | in_channels = channels
146 | channels = int(growth * channels)
147 |
148 | # Bi-directional LSTM for the bottleneck layer
149 | channels = in_channels
150 | self.lstm = nn.LSTM(bidirectional=True,
151 | num_layers=lstm_layers,
152 | hidden_size=channels,
153 | input_size=channels)
154 | self.lstm_linear = nn.Linear(2 * channels, channels)
155 |
156 | rescale_module(self, reference=rescale)
157 |
158 | def forward(self, mix: torch.Tensor, angle_conditioning: torch.Tensor): # pylint: disable=arguments-differ
159 | """
160 |         Forward pass, conditioned on `angle_conditioning`, a one-hot encoding of the window size.
161 |
162 | Args:
163 | mix (torch.Tensor) - An input recording of size `(batch_size, n_mics, time)`.
164 |
165 | Output:
166 | x - A source separation output at every microphone
167 | """
168 | x = mix
169 | saved = [x]
170 |
171 | # Encoder
172 | for encode in self.encoder:
173 | x = encode["conv1"](x) # Conv 1d
174 | embedding = encode["gc_embed1"](angle_conditioning.unsqueeze(2))
175 |
176 | x = encode["relu"](x + embedding)
177 | x = encode["conv2"](x)
178 |
179 | embedding2 = encode["gc_embed2"](angle_conditioning.unsqueeze(2))
180 | x = encode["activation"](x + embedding2)
181 | saved.append(x)
182 |
183 | # Bi-directional LSTM at the bottleneck layer
184 | x = x.permute(2, 0, 1) # prep input for LSTM
185 | self.lstm.flatten_parameters() # to improve memory usage.
186 | x = self.lstm(x)[0]
187 | x = self.lstm_linear(x)
188 | x = x.permute(1, 2, 0)
189 |
190 | # Source decoder
191 | for decode in self.decoder:
192 | skip = center_trim(saved.pop(-1), x)
193 | x = x + skip
194 |
195 | x = decode["conv1"](x)
196 | embedding = decode["gc_embed1"](angle_conditioning.unsqueeze(2))
197 | x = decode["activation"](x + embedding)
198 | x = decode["conv2"](x)
199 | embedding2 = decode["gc_embed2"](angle_conditioning.unsqueeze(2))
200 | if "relu" in decode:
201 | x = decode["relu"](x + embedding2)
202 |
203 | # Reformat the output
204 | x = x.view(x.size(0), 2, self.n_audio_channels, x.size(-1))
205 |
206 | return x
207 |
208 | def loss(self, voice_signals, gt_voice_signals):
209 | """Simple L1 loss between voice and gt"""
210 | return F.l1_loss(voice_signals, gt_voice_signals)
211 |
212 | def valid_length(self, length: int) -> int: # pylint: disable=redefined-outer-name
213 | """
214 | Find the length of the input to the network such that the output's length is
215 | equal to the given `length`.
216 | """
217 | for _ in range(self.depth):
218 | length = math.ceil((length - self.kernel_size) / self.stride) + 1
219 | length = max(1, length)
220 | length += self.context - 1
221 |
222 | for _ in range(self.depth):
223 | length = (length - 1) * self.stride + self.kernel_size
224 |
225 | return int(length)
226 |
227 |
228 | def load_pretrain(model, state_dict): # pylint: disable=redefined-outer-name
229 | """Loads the pretrained keys in state_dict into model"""
230 | for key in state_dict.keys():
231 | try:
232 | _ = model.load_state_dict({key: state_dict[key]}, strict=False)
233 | print("Loaded {} (shape = {}) from the pretrained model".format(
234 | key, state_dict[key].shape))
235 | except Exception as e:
236 | print("Failed to load {}".format(key))
237 | print(e)
238 |
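    | # Minimal usage sketch (hypothetical checkpoint path):
    | #   model = CoSNetwork(n_audio_channels=4)
    | #   state_dict = torch.load("checkpoints/pretrained.pt", map_location="cpu")
    | #   load_pretrain(model, state_dict)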
--------------------------------------------------------------------------------
/cos/generate_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import argparse
5 | import json
6 | from typing import List
7 | from pathlib import Path
8 | import tqdm
9 | import random
10 |
11 | import multiprocessing.dummy as mp
12 |
13 | import numpy as np
14 | import librosa
15 | import pyroomacoustics as pra
16 | import soundfile as sf
17 |
18 | # Min/max peak volume (sampled uniformly) for the foreground and background signals
19 | FG_VOL_MIN = 0.15
20 | FG_VOL_MAX = 0.4
21 |
22 | BG_VOL_MIN = 0.2
23 | BG_VOL_MAX = 0.5
24 |
25 |
26 | def generate_mic_array(room, mic_radius: float, n_mics: int):
27 | """
28 | Add a circular microphone array to the room (modifies `room` in place)
29 | 
30 | Radius = 50th percentile of male bitragion breadth
31 | (https://en.wikipedia.org/wiki/Human_head)
32 | """
33 | R = pra.circular_2D_array(center=[0., 0.], M=n_mics, phi0=0, radius=mic_radius)
34 | room.add_microphone_array(pra.MicrophoneArray(R, room.fs))
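    | # Note: `pra.circular_2D_array` spaces the mics evenly on a circle starting at
    | # phi0=0, e.g. n_mics=4 gives mics at 0, 90, 180 and 270 degrees around the origin.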
35 |
36 |
37 | def handle_error(e):
38 | print(e)
39 |
40 |
41 | def get_voices(args):
42 | # Make sure we don't sample an empty/silent clip; retry until every voice has nonzero energy
43 | success = False
44 | while not success:
45 | voice_files = random.sample(args.all_voices, args.n_voices)
46 | # Save the identity also. This is VCTK specific
47 | success = True
48 | voices_data = []
49 | for voice_file in voice_files:
50 | voice_identity = str(voice_file).split("/")[-1].split("_")[0]
51 | voice, _ = librosa.core.load(voice_file, sr=args.sr, mono=True)
52 | voice, _ = librosa.effects.trim(voice)
53 | if voice.std() == 0:
54 | success = False
55 | voices_data.append((voice, voice_identity))
56 |
57 | return voices_data
58 |
59 |
60 | def generate_sample(args: argparse.Namespace, bg: np.ndarray, idx: int) -> int:
61 | """
62 | Generate a single sample. Return 0 on success.
63 |
64 | Steps:
65 | - [1] Load voice
66 | - [2] Sample background with the same length as voice.
67 | - [3] Pick background location
68 | - [4] Create a scene
69 | - [5] Render sound
70 | - [6] Save metadata
71 | """
72 | # [1] load voice
73 | output_prefix_dir = os.path.join(args.output_path, '{:05d}'.format(idx))
74 | Path(output_prefix_dir).mkdir(parents=True, exist_ok=True)
75 |
76 | voices_data = get_voices(args)
77 |
78 | # [2]
79 | total_samples = int(args.duration * args.sr)
80 | if bg is not None:
81 | bg_length = len(bg)
82 | bg_start_idx = np.random.randint(bg_length - total_samples)
83 | sample_bg = bg[bg_start_idx:bg_start_idx + total_samples]
84 |
85 | # Generate room parameters, each scene has a random room and absorption
86 | left_wall = np.random.uniform(low=-20, high=-15)
87 | right_wall = np.random.uniform(low=15, high=20)
88 | top_wall = np.random.uniform(low=15, high=20)
89 | bottom_wall = np.random.uniform(low=-20, high=-15)
90 | absorption = np.random.uniform(low=0.1, high=0.99)
91 | corners = np.array([[left_wall, bottom_wall], [left_wall, top_wall],
92 | [right_wall, top_wall], [right_wall, bottom_wall]]).T
93 |
94 | # FG
95 | all_fg_signals = []
96 | voice_positions = []
97 | for voice_idx in range(args.n_voices):
98 | # Need to re-generate room to save GT. Could probably be optimized
99 | room = pra.Room.from_corners(corners,
100 | fs=args.sr,
101 | max_order=10,
102 | absorption=absorption)
103 | generate_mic_array(room, args.mic_radius, args.n_mics)
104 |
105 | voice_radius = np.random.uniform(low=1.0, high=5.0)
106 | voice_theta = np.random.uniform(low=0, high=2 * np.pi)
107 | voice_loc = [
108 | voice_radius * np.cos(voice_theta),
109 | voice_radius * np.sin(voice_theta)
110 | ]
111 |
112 | voice_positions.append(voice_loc)
113 | room.add_source(voice_loc, signal=voices_data[voice_idx][0])
114 |
115 | room.image_source_model(use_libroom=True)
116 | room.simulate()
117 | fg_signals = room.mic_array.signals[:, :total_samples]
118 | fg_target = np.random.uniform(FG_VOL_MIN, FG_VOL_MAX)
119 | fg_signals = fg_signals * fg_target / abs(fg_signals).max()
120 | all_fg_signals.append(fg_signals)
121 |
122 | # BG
123 | if bg is not None:
124 | bg_radius = np.random.uniform(low=10.0, high=20.0)
125 | bg_theta = np.random.uniform(low=0, high=2 * np.pi)
126 | bg_loc = [bg_radius * np.cos(bg_theta), bg_radius * np.sin(bg_theta)]
127 |
128 | # Bg should be further away to be diffuse
129 | left_wall = np.random.uniform(low=-40, high=-20)
130 | right_wall = np.random.uniform(low=20, high=40)
131 | top_wall = np.random.uniform(low=20, high=40)
132 | bottom_wall = np.random.uniform(low=-40, high=-20)
133 | corners = np.array([[left_wall, bottom_wall], [left_wall, top_wall],
134 | [right_wall, top_wall], [right_wall, bottom_wall]]).T
135 | absorption = np.random.uniform(low=0.5, high=0.99)
136 | room = pra.Room.from_corners(corners,
137 | fs=args.sr,
138 | max_order=10,
139 | absorption=absorption)
140 | generate_mic_array(room, args.mic_radius, args.n_mics)
141 | room.add_source(bg_loc, signal=sample_bg)
142 |
143 | room.image_source_model(use_libroom=True)
144 | room.simulate()
145 | bg_signals = room.mic_array.signals[:, :total_samples]
146 | bg_target = np.random.uniform(BG_VOL_MIN, BG_VOL_MAX)
147 | bg_signals = bg_signals * bg_target / abs(bg_signals).max()
148 |
149 | # Save
150 | for mic_idx in range(args.n_mics):
151 | output_prefix = str(
152 | Path(output_prefix_dir) / "mic{:02d}_".format(mic_idx))
153 |
154 | # Save FG
155 | all_fg_buffer = np.zeros((total_samples))
156 | for voice_idx in range(args.n_voices):
157 | curr_fg_buffer = np.pad(all_fg_signals[voice_idx][mic_idx],
158 | (0, total_samples))[:total_samples]
159 | sf.write(output_prefix + "voice{:02d}.wav".format(voice_idx),
160 | curr_fg_buffer, args.sr)
161 | all_fg_buffer += curr_fg_buffer
162 |
163 | if bg is not None:
164 | bg_buffer = np.pad(bg_signals[mic_idx],
165 | (0, total_samples))[:total_samples]
166 | sf.write(output_prefix + "bg.wav", bg_buffer, args.sr)
167 |
168 | sf.write(output_prefix + "mixed.wav", all_fg_buffer + bg_buffer,
169 | args.sr)
170 | else:
171 | sf.write(output_prefix + "mixed.wav", all_fg_buffer,
172 | args.sr)
173 |
174 | # [6]
175 | metadata = {}
176 | for voice_idx in range(args.n_voices):
177 | metadata['voice{:02d}'.format(voice_idx)] = {
178 | 'position': voice_positions[voice_idx],
179 | 'speaker_id': voices_data[voice_idx][1]
180 | }
181 |
182 | if bg is not None:
183 | metadata['bg'] = {'position': bg_loc}
184 |
185 | metadata_file = str(Path(output_prefix_dir) / "metadata.json")
186 | with open(metadata_file, "w") as f:
187 | json.dump(metadata, f, indent=4)
188 |
189 |
190 | def main(args: argparse.Namespace):
191 | np.random.seed(args.seed)
192 | random.seed(args.seed)  # get_voices() samples files with the stdlib `random` module
193 | # Preload background to save time
194 | if args.input_background_path:
195 | background, _ = librosa.core.load(args.input_background_path,
196 | sr=args.sr,
197 | mono=True)
198 | else:
199 | background = None
200 |
201 | all_voices = Path(args.input_voice_dir).rglob('*.wav')
202 | args.all_voices = list(all_voices)
203 | if len(args.all_voices) == 0:
204 | raise ValueError("No voice files found")
205 |
206 | pbar = tqdm.tqdm(total=args.n_outputs)
207 | pool = mp.Pool(args.n_workers)
208 | callback_fn = lambda _: pbar.update()
209 | for i in range(args.n_outputs):
210 | pool.apply_async(generate_sample,
211 | args=(args, background, i),
212 | callback=callback_fn,
213 | error_callback=handle_error)
214 | pool.close()
215 | pool.join()
216 | pbar.close()
217 |
218 |
219 | if __name__ == '__main__':
220 | parser = argparse.ArgumentParser()
221 | parser.add_argument('input_voice_dir',
222 | type=str,
223 | help="Directory with voice wav files")
224 | parser.add_argument('output_path', type=str, help="Output directory to write the synthetic dataset")
225 | parser.add_argument('--input_background_path',
226 | type=str)
227 | parser.add_argument('--n_mics', type=int, default=4)
228 | parser.add_argument('--mic_radius',
229 | default=.03231,
230 | type=float,
231 | help="Radius of the mic array in meters")
232 | parser.add_argument('--n_voices', type=int, default=4)
233 | parser.add_argument('--n_outputs', type=int, default=10000)
234 | parser.add_argument('--n_workers', type=int, default=8)
235 | parser.add_argument('--seed', type=int, default=42)
236 | parser.add_argument('--sr', type=int, default=44100)
237 | parser.add_argument('--duration', type=float, default=3.0)
238 | main(parser.parse_args())
239 |
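    | # Example invocation (hypothetical paths; expects VCTK-style wav files named
    | # like `p225_001.wav` so the speaker id can be parsed), run from the repo root:
    | #   python -m cos.generate_dataset /data/VCTK/wav48 ./outputs/train \
    | #       --input_background_path /data/noise/ambience.wav --n_voices 2 --n_outputs 100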
--------------------------------------------------------------------------------
/cos/training/train.py:
--------------------------------------------------------------------------------
1 | """
2 | The main training script for training on synthetic data
3 | """
4 |
5 | import argparse
6 | import multiprocessing
7 | import os
8 |
9 | from typing import Dict, List, Tuple, Optional # pylint: disable=unused-import
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | import torch.optim as optim
16 | import torch.nn.functional as F
17 | import tqdm # pylint: disable=unused-import
18 |
19 | from cos.training.synthetic_dataset import SyntheticDataset
20 | from cos.training.network import CoSNetwork, \
21 | center_trim, load_pretrain, \
22 | normalize_input, unnormalize_input
23 |
24 |
25 | def train_epoch(model: nn.Module, device: torch.device,
26 | optimizer: optim.Optimizer,
27 | train_loader: torch.utils.data.dataloader.DataLoader,
28 | epoch: int, log_interval: int = 20) -> float:
29 | """
30 | Train a single epoch.
31 | """
32 | # Set the model to training.
33 | model.train()
34 |
35 | # Training loop
36 | losses = []
37 | interval_losses = []
38 |
39 | for batch_idx, (data, label_voice_signals,
40 | window_idx) in enumerate(train_loader):
41 | data = data.to(device)
42 | label_voice_signals = label_voice_signals.to(device)
43 | window_idx = window_idx.to(device)
44 |
45 | # Normalize input, each batch item separately
46 | data, means, stds = normalize_input(data)
47 |
48 | # Reset grad
49 | optimizer.zero_grad()
50 |
51 | # Run through the model
52 | valid_length = model.valid_length(data.shape[-1])
53 | delta = valid_length - data.shape[-1]
54 | padded = F.pad(data, (delta // 2, delta - delta // 2))
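    | # Pad symmetrically to the nearest valid length so the encoder/decoder strides
    | # line up; the output is center-trimmed back to the original length below.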
55 |
56 | output_signal = model(padded, window_idx)
57 | output_signal = center_trim(output_signal, data)
58 |
59 | # Un-normalize
60 | output_signal = unnormalize_input(output_signal, means, stds)
61 | output_voices = output_signal[:, 0]
62 |
63 | loss = model.loss(output_voices, label_voice_signals)
64 |
65 | interval_losses.append(loss.item())
66 |
67 | # Backpropagation
68 | loss.backward()
69 |
70 | # Gradient clipping
71 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
72 |
73 | # Update the weights
74 | optimizer.step()
75 |
76 | # Print the loss
77 | if batch_idx % log_interval == 0:
78 | print("Train Epoch: {} [{}/{} ({:.0f}%)] \t Loss: {:.6f}".format(
79 | epoch, batch_idx * len(data), len(train_loader.dataset),
80 | 100. * batch_idx / len(train_loader),
81 | np.mean(interval_losses)))
82 |
83 | losses.extend(interval_losses)
84 | interval_losses = []
85 |
86 | return np.mean(losses)
87 |
88 |
89 | def test_epoch(model: nn.Module, device: torch.device,
90 | test_loader: torch.utils.data.dataloader.DataLoader,
91 | log_interval: int = 20) -> float:
92 | """
93 | Evaluate the network.
94 | """
95 | model.eval()
96 | test_loss = 0
97 |
98 | with torch.no_grad():
99 | for batch_idx, (data, label_voice_signals,
100 | window_idx) in enumerate(test_loader):
101 | data = data.to(device)
102 | label_voice_signals = label_voice_signals.to(device)
103 | window_idx = window_idx.to(device)
104 |
105 | # Normalize input, each batch item separately
106 | data, means, stds = normalize_input(data)
107 |
108 | valid_length = model.valid_length(data.shape[-1])
109 | delta = valid_length - data.shape[-1]
110 | padded = F.pad(data, (delta // 2, delta - delta // 2))
111 |
112 | # Run through the model
113 | output_signal = model(padded, window_idx)
114 | output_signal = center_trim(output_signal, data)
115 |
116 | # Un-normalize
117 | output_signal = unnormalize_input(output_signal, means, stds)
118 | output_voices = output_signal[:, 0]
119 |
120 | loss = model.loss(output_voices, label_voice_signals)
121 | test_loss += loss.item()
122 |
123 | if batch_idx % log_interval == 0:
124 | print("Loss: {}".format(loss))
125 |
126 | test_loss /= len(test_loader)
127 | print("\nTest set: Average Loss: {:.4f}\n".format(test_loss))
128 |
129 | return test_loss
130 |
131 |
132 | def train(args: argparse.Namespace):
133 | """
134 | Train the network.
135 | """
136 | # Load dataset
137 | data_train = SyntheticDataset(args.train_dir, n_mics=args.n_mics,
138 | sr=args.sr, perturb_prob=1.0,
139 | mic_radius=args.mic_radius)
140 | data_test = SyntheticDataset(args.test_dir, n_mics=args.n_mics,
141 | sr=args.sr, mic_radius=args.mic_radius)
142 |
143 | # Set up the device and workers.
144 | use_cuda = args.use_cuda and torch.cuda.is_available()
145 | device = torch.device('cuda' if use_cuda else 'cpu')
146 | print("Using device {}".format('cuda' if use_cuda else 'cpu'))
147 |
148 | # Set multiprocessing params
149 | num_workers = min(multiprocessing.cpu_count(), args.n_workers)
150 | kwargs = {
151 | 'num_workers': num_workers,
152 | 'pin_memory': True
153 | } if use_cuda else {}
154 |
155 | # Set up data loaders
156 | train_loader = torch.utils.data.DataLoader(data_train,
157 | batch_size=args.batch_size,
158 | shuffle=True, **kwargs)
159 | test_loader = torch.utils.data.DataLoader(data_test,
160 | batch_size=args.batch_size,
161 | **kwargs)
162 |
163 | # Set up model
164 | model = CoSNetwork(n_audio_channels=args.n_mics)
165 | model.to(device)
166 |
167 | # Set up checkpoints
168 | if not os.path.exists(os.path.join(args.checkpoints_dir, args.name)):
169 | os.makedirs(os.path.join(args.checkpoints_dir, args.name))
170 |
171 | # Set up the optimizer
172 | optimizer = optim.Adam(model.parameters(), lr=args.lr,
173 | weight_decay=args.decay)
174 |
175 | # Load pretrain
176 | if args.pretrain_path:
177 | state_dict = torch.load(args.pretrain_path, map_location=device)
178 | load_pretrain(model, state_dict)
179 |
180 | # Load the model if `args.start_epoch` is greater than 0. This will load the model from
181 | # epoch = `args.start_epoch - 1`
182 | if args.start_epoch is not None:
183 | assert args.start_epoch > 0, "start_epoch must be greater than 0."
184 | start_epoch = args.start_epoch
185 | checkpoint_path = Path(args.checkpoints_dir) / args.name / "{}.pt".format(
186 | start_epoch - 1)
187 | state_dict = torch.load(checkpoint_path, map_location=device)
188 | model.load_state_dict(state_dict)
189 | else:
190 | start_epoch = 0
191 |
192 | # Loss values
193 | train_losses = []
194 | test_losses = []
195 |
196 | # Training loop
197 | try:
198 | for epoch in range(start_epoch, args.epochs + 1):
199 | train_loss = train_epoch(model, device, optimizer, train_loader,
200 | epoch, args.print_interval)
201 | torch.save(
202 | model.state_dict(),
203 | os.path.join(args.checkpoints_dir, args.name,
204 | "{}.pt".format(epoch)))
205 | print("Done with training, going to testing")
206 | test_loss = test_epoch(model, device, test_loader,
207 | args.print_interval)
208 | train_losses.append((epoch, train_loss))
209 | test_losses.append((epoch, test_loss))
210 |
211 | return train_losses, test_losses
212 |
213 | except KeyboardInterrupt:
214 | print("Interrupted")
215 | except Exception as _: # pylint: disable=broad-except
216 | import traceback # pylint: disable=import-outside-toplevel
217 | traceback.print_exc()
218 |
219 |
220 | if __name__ == '__main__':
221 | parser = argparse.ArgumentParser()
222 | # Data Params
223 | parser.add_argument('train_dir', type=str,
224 | help="Path to the training dataset")
225 | parser.add_argument('test_dir', type=str,
226 | help="Path to the testing dataset")
227 | parser.add_argument('--name', type=str, default="multimic_experiment",
228 | help="Name of the experiment")
229 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints',
230 | help="Path to the checkpoints")
231 | parser.add_argument('--batch_size', type=int, default=8,
232 | help="Batch size")
233 |
234 | # Physical Params
235 | parser.add_argument('--n_mics', type=int, default=4,
236 | help="Number of mics (also number of channels)")
237 | parser.add_argument('--mic_radius', default=.03231, type=float,
238 | help="Radius in meters of the mic array")
239 |
240 | # Training Params
241 | parser.add_argument('--epochs', type=int, default=100,
242 | help="Number of epochs")
243 | parser.add_argument('--lr', type=float, default=3e-4, help="learning rate")
244 | parser.add_argument('--sr', type=int, default=44100, help="Sampling rate")
245 | parser.add_argument('--decay', type=float, default=0, help="Weight decay")
246 | parser.add_argument('--n_workers', type=int, default=16,
247 | help="Number of parallel workers")
248 | parser.add_argument('--print_interval', type=int, default=20,
249 | help="Logging interval")
250 |
251 | parser.add_argument('--start_epoch', type=int, default=None,
252 | help="Start epoch")
253 | parser.add_argument('--pretrain_path', type=str,
254 | help="Path to pretrained weights")
255 | parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
256 | help="Whether to use cuda")
257 |
258 | train(parser.parse_args())
259 |
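    | # Example invocation (hypothetical paths; datasets as produced by
    | # cos/generate_dataset.py), run from the repo root:
    | #   python -m cos.training.train ./outputs/train ./outputs/test \
    | #       --name multimic_experiment --n_mics 4 --batch_size 8 --use_cuda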
--------------------------------------------------------------------------------
/cos/inference/evaluate_synthetic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from pathlib import Path
6 |
7 | import torch
8 | import numpy as np
9 | import librosa
10 | import soundfile as sf
11 | import tqdm
12 |
13 | from cos.helpers.eval_utils import find_best_permutation_prec_recall, compute_sdr
14 | from cos.helpers.utils import angular_distance, check_valid_dir
15 | from cos.training.network import CoSNetwork
16 | from cos.inference.separation_by_localization import run_separation, CandidateVoice
17 |
18 | import multiprocessing.dummy as mp
19 | from multiprocessing import Lock
20 |
21 |
22 | def get_items(curr_dir, args):
23 | """
24 | This is a modified version of the SpatialAudioDataset DataLoader
25 | """
26 | with open(Path(curr_dir) / 'metadata.json') as json_file:
27 | json_data = json.load(json_file)
28 |
29 | num_voices = args.n_voices
30 | mic_files = sorted(list(Path(curr_dir).rglob('*mixed.wav')))
31 |
32 | # All voice signals
33 | keys = ["voice{:02}".format(i) for i in range(num_voices)]
34 |
35 | # Comment out this line to do voice only, no bg
36 | if "bg" in json_data:
37 | keys.append("bg")
38 | 
39 | # Load every ground-truth source (all mics) along with its position
40 | 
41 | # Iterate over different sources
42 | all_sources = []
43 | target_voice_data = []
44 | voice_positions = []
45 | for key in keys:
46 | gt_audio_files = sorted(list(Path(curr_dir).rglob("*" + key + ".wav")))
47 | assert (len(gt_audio_files) > 0)
48 | gt_waveforms = []
49 |
50 | # Iterate over different mics
51 | for _, gt_audio_file in enumerate(gt_audio_files):
52 | gt_waveform, _ = librosa.core.load(gt_audio_file, args.sr,
53 | mono=True)
54 | gt_waveforms.append(gt_waveform)
55 |
56 | single_source = np.stack(gt_waveforms)
57 | all_sources.append(single_source)
58 | locs_voice = np.arctan2(json_data[key]["position"][1],
59 | json_data[key]["position"][0])
60 | voice_positions.append(locs_voice)
61 |
62 | all_sources = np.stack(all_sources) # n voices x n mics x n samples
63 | mixed_data = np.sum(all_sources, axis=0) # n mics x n samples
64 |
65 | gt = [
66 | CandidateVoice(voice_positions[i], None, all_sources[i])
67 | for i in range(num_voices)
68 | ]
69 |
70 | return mixed_data, gt
71 |
72 |
73 | def main(args):
74 | args.moving = False
75 | device = torch.device('cuda') if args.use_cuda else torch.device('cpu')
76 |
77 | args.device = device
78 | model = CoSNetwork(n_audio_channels=args.n_channels)
79 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=device), strict=True)
80 | model.eval()
81 | model.to(device)
82 |
83 | all_dirs = sorted(list(Path(args.test_dir).glob('[0-9]*')))
84 | all_dirs = [x for x in all_dirs if check_valid_dir(x, args.n_voices)]
85 |
86 | if args.prec_recall and args.oracle_position:
87 | raise ValueError("Specify at most one of --prec_recall and --oracle_position")
88 |
89 | if args.prec_recall:
90 | # True positives, false negatives, false positives
91 | all_tp, all_fn, all_fp = [], [], []
92 |
93 | else:
94 | # Placeholders to support multiprocessing
95 | all_angle_errors = [0] * len(all_dirs)
96 | all_input_sdr = [0] * len(all_dirs)
97 | all_output_sdr = [0] * len(all_dirs)
98 |
99 | gpu_lock = Lock()
100 |
101 | def evaluate_dir(idx):
102 | if args.debug:
103 | curr_writing_dir = "{:05d}".format(idx)
104 | if not os.path.exists(curr_writing_dir):
105 | os.makedirs(curr_writing_dir)
106 | args.writing_dir = curr_writing_dir
107 |
108 | curr_dir = all_dirs[idx]
109 |
110 | # Loads the data
111 | mixed_data, gt = get_items(curr_dir, args)
112 |
113 | # Prevents CUDA out of memory
114 | gpu_lock.acquire()
115 | if args.prec_recall:
116 | # Case where we don't know the number of sources
117 | candidate_voices = run_separation(mixed_data, model, args)
118 |
119 | # Case where we know the number of sources
120 | else:
121 | # Normal run
122 | if not args.oracle_position:
123 | candidate_voices = run_separation(mixed_data, model, args, 0.005)
124 | # In order to compute SDR or angle error, the number of outputs must match gt
125 | # We set a very low threshold to ensure we get the correct number of outputs
126 | if args.oracle_position or len(candidate_voices) < len(gt):
127 | print("Had to go again\n")
128 | candidate_voices = run_separation(mixed_data, model, args, 0.000001)
129 |
130 | # Use the GT positions to find the best sources
131 | if args.oracle_position:
132 | trimmed_voices = []
133 | for gt_idx in range(args.n_voices):
134 | best_idx = np.argmin(np.array([angular_distance(x.angle,
135 | gt[gt_idx].angle) for x in candidate_voices]))
136 | trimmed_voices.append(candidate_voices[best_idx])
137 | candidate_voices = trimmed_voices
138 |
139 | # Take the top N voices
140 | else:
141 | candidate_voices = candidate_voices[:args.n_voices]
142 | if len(candidate_voices) != len(gt):
143 | print(f"Not enough outputs for dir {curr_dir}. Lower threshold to evaluate.")
144 | return
145 |
146 | if args.debug:
147 | sf.write(os.path.join(args.writing_dir, "mixed.wav"),
148 | mixed_data[0],
149 | args.sr)
150 | for voice in candidate_voices:
151 | fname = "out_angle{:.2f}.wav".format(
152 | voice.angle * 180 / np.pi)
153 | sf.write(os.path.join(args.writing_dir, fname), voice.data[0],
154 | args.sr)
155 |
156 | gpu_lock.release()
157 | curr_angle_errors = []
158 | curr_input_sdr = []
159 | curr_output_sdr = []
160 |
161 | best_permutation, (tp, fn, fp) = find_best_permutation_prec_recall(
162 | [x.angle for x in gt], [x.angle for x in candidate_voices])
163 |
164 | if args.prec_recall:
165 | all_tp.append(tp)
166 | all_fn.append(fn)
167 | all_fp.append(fp)
168 |
169 | # Evaluate SDR and Angular Error
170 | else:
171 | for gt_idx, output_idx in enumerate(best_permutation):
172 | angle_error = angular_distance(candidate_voices[output_idx].angle,
173 | gt[gt_idx].angle)
174 | # print(angle_error)
175 | curr_angle_errors.append(angle_error)
176 |
177 | # To speed up evaluation we only use channel 0. For rigorous results,
178 | # pass single_channel=False below.
179 | input_sdr = compute_sdr(gt[gt_idx].data, mixed_data,
180 | single_channel=True)
181 | output_sdr = compute_sdr(gt[gt_idx].data,
182 | candidate_voices[output_idx].data, single_channel=True)
183 |
184 | curr_input_sdr.append(input_sdr)
185 | curr_output_sdr.append(output_sdr)
186 |
187 | # print(curr_input_sdr)
188 | # print(curr_output_sdr)
189 |
190 | all_angle_errors[idx] = curr_angle_errors
191 | all_input_sdr[idx] = curr_input_sdr
192 | all_output_sdr[idx] = curr_output_sdr
193 |
194 | # print("Running median angle error: {}".format(np.median(np.array(all_angle_errors[:idx+1])) * 180 / np.pi))
195 | # print("Running median SDRi: ",
196 | # np.median(np.array(all_output_sdr[:idx+1]) - np.array(all_input_sdr[:idx+1])))
197 |
198 | pool = mp.Pool(args.n_workers)
199 | with tqdm.tqdm(total=len(all_dirs)) as pbar:
200 | for i, _ in enumerate(pool.imap_unordered(evaluate_dir, range(len(all_dirs)))):
201 | pbar.update()
202 | pool.close()
203 | pool.join()
204 |
205 |
206 | # Print and save the outputs
207 | if args.prec_recall:
208 | print("Overall Precision: {} Recall: {}".format(
209 | sum(all_tp) / (sum(all_tp) + sum(all_fp)),
210 | sum(all_tp) / (sum(all_tp) + sum(all_fn))))
211 |
212 | else:
213 | print("Median Angular Error: ", np.median(np.array(all_angle_errors)) * 180 / np.pi)
214 | print("Median SDRi: ",
215 | np.median(np.array(all_output_sdr) - np.array(all_input_sdr)))
216 | # Uncomment to save the data for visualization
217 | # np.save("angleerror_{}voices_{}kHz.npy".format(args.n_voices, args.sr),
218 | # np.array(all_angle_errors).flatten())
219 | # np.save("SDR_{}voices_{}kHz.npy".format(args.n_voices, args.sr),
220 | # np.array([np.array(all_input_sdr).flatten(), np.array(all_output_sdr).flatten()]))
221 |
222 |
223 |
224 | if __name__ == '__main__':
225 | parser = argparse.ArgumentParser()
226 | parser.add_argument('test_dir', type=str,
227 | help="Path to the testing directory")
228 | parser.add_argument('model_checkpoint', type=str,
229 | help="Path to the model file")
230 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate")
231 | parser.add_argument('--n_channels', type=int, default=2,
232 | help="Number of channels")
233 | parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
234 | help="Whether to use cuda")
235 | parser.add_argument('--debug', action='store_true', help="Save outputs")
236 | parser.add_argument('--mic_radius', default=.0725, type=float,
237 | help="Radius of the mic array in meters")
238 | parser.add_argument('--n_workers', default=8, type=int,
239 | help="Number of parallel worker threads")
240 | parser.add_argument(
241 | '--n_voices', default=2, type=int, help=
242 | "Number of voices in the GT scenarios. \
243 | Useful so you can re-use the same dataset with different number of fg sources"
244 | )
245 | parser.add_argument(
246 | '--prec_recall', action='store_true', help=
247 | "To compute precision and recall, we don't let the network know the number of sources"
248 | )
249 | parser.add_argument(
250 | '--oracle_position', action='store_true', help=
251 | "Compute the separation results if you know the GT positions"
252 | )
253 |
254 | print(parser.parse_args())
255 | main(parser.parse_args())
256 |
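    | # Example invocation (hypothetical paths/checkpoint), run from the repo root:
    | #   python -m cos.inference.evaluate_synthetic ./outputs/test \
    | #       ./checkpoints/multimic_experiment/50.pt --n_channels 4 --n_voices 2 --use_cuda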
--------------------------------------------------------------------------------
/cos/inference/separation_by_localization.py:
--------------------------------------------------------------------------------
1 | """
2 | The main separation by localization inference algorithm
3 | """
4 |
5 | import argparse
6 | import os
7 |
8 | from collections import namedtuple
9 |
10 | import librosa
11 | import torch
12 | import torch.nn.functional as F
13 | import numpy as np
14 | import soundfile as sf
15 |
16 | import cos.helpers.utils as utils
17 |
18 | from cos.helpers.constants import ALL_WINDOW_SIZES, \
19 | FAR_FIELD_RADIUS
20 | from cos.helpers.visualization import draw_diagram
21 | from cos.training.network import CoSNetwork, center_trim, \
22 | normalize_input, unnormalize_input
23 | from cos.helpers.eval_utils import si_sdr
24 |
25 | # Constants which may be tweaked based on your setup
26 | ENERGY_CUTOFF = 0.002
27 | NMS_RADIUS = np.pi / 4
28 | NMS_SIMILARITY_SDR = -7.0 # SDR cutoff for different candidates
29 |
30 | CandidateVoice = namedtuple("CandidateVoice", ["angle", "energy", "data"])
31 |
32 |
33 | def nms(candidate_voices, nms_cutoff):
34 | """
35 | Runs non-max suppression on the candidate voices
36 | """
37 | final_proposals = []
38 | initial_proposals = candidate_voices
39 |
40 | while len(initial_proposals) > 0:
41 | new_initial_proposals = []
42 | sorted_candidates = sorted(initial_proposals,
43 | key=lambda x: x[1],
44 | reverse=True)
45 |
46 | # Choose the loudest voice
47 | best_candidate_voice = sorted_candidates[0]
48 | final_proposals.append(best_candidate_voice)
49 | sorted_candidates.pop(0)
50 |
51 | # See if any of the rest should be removed
52 | for candidate_voice in sorted_candidates:
53 | different_locations = utils.angular_distance(
54 | candidate_voice.angle, best_candidate_voice.angle) > NMS_RADIUS
55 |
56 | # different_content = abs(
57 | # candidate_voice.data -
58 | # best_candidate_voice.data).mean() > nms_cutoff
59 |
60 | different_content = si_sdr(
61 | candidate_voice.data[0],
62 | best_candidate_voice.data[0]) < nms_cutoff
63 |
64 | if different_locations or different_content:
65 | new_initial_proposals.append(candidate_voice)
66 |
67 | initial_proposals = new_initial_proposals
68 |
69 | return final_proposals
70 |
71 |
72 | def forward_pass(model, target_angle, mixed_data, conditioning_label, args):
73 | """
74 | Runs the network on the mixed data for the candidate region centered at
75 | `target_angle`, with the window size given by `conditioning_label`
76 | """
77 | target_pos = np.array([
78 | FAR_FIELD_RADIUS * np.cos(target_angle),
79 | FAR_FIELD_RADIUS * np.sin(target_angle)
80 | ])
81 |
82 | data, _ = utils.shift_mixture(
83 | torch.tensor(mixed_data).to(args.device), target_pos, args.mic_radius,
84 | args.sr)
85 | data = data.float().unsqueeze(0) # Batch size is 1
86 |
87 | # Normalize input
88 | data, means, stds = normalize_input(data)
89 |
90 | # Run through the model
91 | valid_length = model.valid_length(data.shape[-1])
92 | delta = valid_length - data.shape[-1]
93 | padded = F.pad(data, (delta // 2, delta - delta // 2))
94 |
95 | output_signal = model(padded, conditioning_label)
96 | output_signal = center_trim(output_signal, data)
97 |
98 | output_signal = unnormalize_input(output_signal, means, stds)
99 | output_voices = output_signal[:, 0] # batch x n_mics x n_samples
100 |
101 | output_np = output_voices.detach().cpu().numpy()[0]
102 | energy = librosa.feature.rms(y=output_np).mean()
103 |
104 | return output_np, energy
105 |
106 |
107 | def run_separation(mixed_data, model, args,
108 | energy_cutoff=ENERGY_CUTOFF,
109 | nms_cutoff=NMS_SIMILARITY_SDR): # yapf: disable
110 | """
111 | The main separation by localization algorithm
112 | """
113 | # Get the initial candidates
114 | num_windows = len(ALL_WINDOW_SIZES) if not args.moving else 3
115 | starting_angles = utils.get_starting_angles(ALL_WINDOW_SIZES[0])
116 | candidate_voices = [CandidateVoice(x, None, None) for x in starting_angles]
117 |
118 | # All steps of the binary search
119 | for window_idx in range(num_windows):
120 | if args.debug:
121 | print("---------")
122 | conditioning_label = torch.tensor(utils.to_categorical(
123 | window_idx, 5)).float().to(args.device).unsqueeze(0)
124 |
125 | curr_window_size = ALL_WINDOW_SIZES[window_idx]
126 | new_candidate_voices = []
127 |
128 | # Iterate over all the potential locations
129 | for voice in candidate_voices:
130 | output, energy = forward_pass(model, voice.angle, mixed_data,
131 | conditioning_label, args)
132 |
133 | if args.debug:
134 | print("Angle {:.2f} energy {}".format(voice.angle, energy))
135 | fname = "out{}_angle{:.2f}.wav".format(
136 | window_idx, voice.angle * 180 / np.pi)
137 | # sf.write(os.path.join(args.writing_dir, fname), output[0],
138 | # args.sr)
139 |
140 | # If there was something there
141 | if energy > energy_cutoff:
142 |
143 | # We're done searching so undo the shifts
144 | if window_idx == num_windows - 1:
145 | target_pos = np.array([
146 | FAR_FIELD_RADIUS * np.cos(voice.angle),
147 | FAR_FIELD_RADIUS * np.sin(voice.angle)
148 | ])
149 | unshifted_output, _ = utils.shift_mixture(output,
150 | target_pos,
151 | args.mic_radius,
152 | args.sr,
153 | inverse=True)
154 |
155 | new_candidate_voices.append(
156 | CandidateVoice(voice.angle, energy, unshifted_output))
157 |
158 | # Split region and recurse.
159 | # You can either split strictly (fourths)
160 | # or with some redundancy (thirds)
161 | else:
162 | # new_candidate_voices.append(
163 | # CandidateVoice(
164 | # voice.angle + curr_window_size / 3,
165 | # energy, output))
166 | # new_candidate_voices.append(
167 | # CandidateVoice(
168 | # voice.angle - curr_window_size / 3,
169 | # energy, output))
170 | # new_candidate_voices.append(
171 | # CandidateVoice(
172 | # voice.angle,
173 | # energy, output))
174 | new_candidate_voices.append(
175 | CandidateVoice(
176 | voice.angle + curr_window_size / 4,
177 | energy, output))
178 | new_candidate_voices.append(
179 | CandidateVoice(
180 | voice.angle - curr_window_size / 4,
181 | energy, output))
182 |
183 | candidate_voices = new_candidate_voices
184 |
185 | # Run NMS on the final output and return
186 | return nms(candidate_voices, nms_cutoff)
187 |
188 |
189 | def main(args):
190 | device = torch.device('cuda') if args.use_cuda else torch.device('cpu')
191 |
192 | args.device = device
193 | model = CoSNetwork(n_audio_channels=args.n_channels)
194 | model.load_state_dict(torch.load(args.model_checkpoint, map_location=args.device), strict=True)
195 | model.eval()
196 | model.to(device)
197 |
198 | if not os.path.exists(args.output_dir):
199 | os.makedirs(args.output_dir)
200 |
201 | mixed_data = librosa.core.load(args.input_file, mono=False, sr=args.sr)[0]
202 | assert mixed_data.shape[0] == args.n_channels
203 |
204 | temporal_chunk_size = int(args.sr * args.duration)
205 | num_chunks = (mixed_data.shape[1] // temporal_chunk_size) + 1
206 |
207 | for chunk_idx in range(num_chunks):
208 | curr_writing_dir = os.path.join(args.output_dir,
209 | "{:03d}".format(chunk_idx))
210 | if not os.path.exists(curr_writing_dir):
211 | os.makedirs(curr_writing_dir)
212 |
213 | args.writing_dir = curr_writing_dir
214 | curr_mixed_data = mixed_data[:, (chunk_idx *
215 | temporal_chunk_size):(chunk_idx + 1) *
216 | temporal_chunk_size]
217 |
218 | output_voices = run_separation(curr_mixed_data, model, args)
219 | for voice in output_voices:
220 | fname = "output_angle{:.2f}.wav".format(
221 | voice.angle * 180 / np.pi)
222 | sf.write(os.path.join(args.writing_dir, fname), voice.data[0],
223 | args.sr)
224 |
225 | candidate_angles = [voice.angle for voice in output_voices]
226 | diagram_window_angle = ALL_WINDOW_SIZES[2] if args.moving else ALL_WINDOW_SIZES[-1]
227 | draw_diagram([], candidate_angles,
228 | diagram_window_angle,
229 | os.path.join(args.writing_dir, "positions.png"))
230 |
231 | if __name__ == '__main__':
232 | parser = argparse.ArgumentParser()
233 | parser.add_argument('model_checkpoint',
234 | type=str,
235 | help="Path to the model file")
236 | parser.add_argument('input_file', type=str, help="Path to the input file")
237 | parser.add_argument('output_dir',
238 | type=str,
239 | help="Path to write the outputs")
240 | parser.add_argument('--sr', type=int, default=22050, help="Sampling rate")
241 | parser.add_argument('--n_channels',
242 | type=int,
243 | default=2,
244 | help="Number of channels")
245 | parser.add_argument('--use_cuda',
246 | dest='use_cuda',
247 | action='store_true',
248 | help="Whether to use cuda")
249 | parser.add_argument('--debug',
250 | action='store_true',
251 | help="Save intermediate outputs")
252 | parser.add_argument('--mic_radius',
253 | default=.03231,
254 | type=float,
255 | help="Radius of the mic array")
256 | parser.add_argument('--duration',
257 | default=3.0,
258 | type=float,
259 | help="Seconds of input to the network")
260 | parser.add_argument('--moving',
261 | action='store_true',
262 | help="If the sources are moving then stop at a coarse window")
263 | main(parser.parse_args())
264 |
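    | # Example invocation (hypothetical paths; the input wav must have `n_channels`
    | # channels), run from the repo root:
    | #   python -m cos.inference.separation_by_localization \
    | #       ./checkpoints/multimic_experiment/50.pt ./recording_4ch.wav ./outputs/run01 \
    | #       --n_channels 4 --sr 44100 --mic_radius 0.03231 --use_cuda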
--------------------------------------------------------------------------------