├── README.md ├── executor ├── __init__.py ├── executor.py └── feature.py ├── nnet ├── __init__.py └── early_exit_transformer.py ├── separate.py └── utils ├── __init__.py ├── audio_util.py ├── mvdr_util.py └── overlapped_speech_7ch.scp /README.md: -------------------------------------------------------------------------------- 1 | # Multi-channel Continuous Speech Separation with Early Exit Transformer 2 | 3 | ## Introduction 4 | 5 | We propose an early exit mechanism for Transformer-based multi-channel speech separation, which addresses the “overthinking” problem and accelerates inference at the same time. 6 | 7 | For a detailed description and experimental results, please refer to our paper: [Don't shoot butterfly with rifles: Multi-channel Continuous Speech Separation with Early Exit Transformer](https://arxiv.org/abs/2010.12180) (Accepted by ICASSP 2021). 8 | 9 | ## Environment 10 | Python 3.6.9, PyTorch 1.7.1 11 | 12 | ## Get Started 13 | 1. Download the overlapped speech of the [LibriCSS dataset](https://github.com/chenzhuo1011/libri_css). 14 | 15 | ```bash 16 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1PdloA-V8HGxkRu9MnT35_civpc3YXJsT' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1PdloA-V8HGxkRu9MnT35_civpc3YXJsT" -O overlapped_speech.zip && rm -rf /tmp/cookies.txt && unzip overlapped_speech.zip && rm overlapped_speech.zip 17 | ``` 18 | 19 | 2. Download the Early Exit Transformer separation model checkpoints. 20 | 21 | ```bash 22 | wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1bK_0jj4yQjCJUOX-Bd8x_1PJNQL8UvfZ' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1bK_0jj4yQjCJUOX-Bd8x_1PJNQL8UvfZ" -O checkpoints.zip && rm -rf /tmp/cookies.txt && unzip checkpoints.zip && rm checkpoints.zip 23 | ``` 24 | 25 | 3. Run the separation.
26 | 27 | ```bash 28 | export MODEL_NAME=EETransformer 29 | export EE_THRESHOLD=0 30 | python3 separate.py \ 31 | --checkpoint checkpoints/$MODEL_NAME \ 32 | --mix-scp utils/overlapped_speech_7ch.scp \ 33 | --dump-dir separated_speech/7ch/utterances_with_${MODEL_NAME}_eet${EE_THRESHOLD} \ 34 | --device-id 0 \ 35 | --num_spks 2 \ 36 | --mvdr True \ 37 | --early_exit_threshold $EE_THRESHOLD 38 | ``` 39 | 40 | The separated speech can be found in the directory 'separated_speech/7ch/utterances_with_${MODEL_NAME}_eet${EE_THRESHOLD}' 41 | 42 | ## Citation 43 | If you find our work useful, please cite [our paper](https://arxiv.org/abs/2010.12180): 44 | ```bibtex 45 | @inproceedings{CSS_with_EETransformer, 46 | title={Don’t shoot butterfly with rifles: Multi-channel continuous speech separation with early exit transformer}, 47 | author={Chen, Sanyuan and Wu, Yu and Chen, Zhuo and Yoshioka, Takuya and Liu, Shujie and Li, Jinyu and Yu, Xiangzhan}, 48 | booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)}, 49 | pages={6139--6143}, 50 | year={2021}, 51 | organization={IEEE} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /executor/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sanyuan-Chen/CSS_with_EETransformer/377904067835e442c50f7db019397ac24e9cdf49/executor/__init__.py -------------------------------------------------------------------------------- /executor/executor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch as th 4 | import torch.nn as nn 5 | from pathlib import Path 6 | from .feature import FeatureExtractor 7 | 8 | 9 | class Executor(nn.Module): 10 | """ 11 | Executor is a class to handle feature extraction 12 | and forward process of the separation networks. 
13 | """ 14 | def __init__(self, nnet, extractor_kwargs=None, get_mask=True): 15 | super(Executor, self).__init__() 16 | self.nnet = nnet 17 | self.inference_time = [] 18 | self.extractor = FeatureExtractor( 19 | **extractor_kwargs) if extractor_kwargs else None 20 | self.frame_len = extractor_kwargs['frame_len'] if extractor_kwargs else None 21 | self.frame_hop = extractor_kwargs['frame_hop'] if extractor_kwargs else None 22 | self.get_mask = get_mask 23 | 24 | def resume(self, checkpoint): 25 | """ 26 | Resume from checkpoint 27 | """ 28 | if not Path(checkpoint).exists(): 29 | raise FileNotFoundError( 30 | f"Could not find resume checkpoint: {checkpoint}") 31 | cpt = th.load(checkpoint, map_location="cpu") 32 | self.load_state_dict(cpt["model_state_dict"]) 33 | return cpt["epoch"] 34 | 35 | def _compute_feats(self, egs): 36 | """ 37 | Compute features: N x F x T 38 | """ 39 | if not self.extractor: 40 | raise RuntimeError("self.extractor is None, " 41 | "do not need to compute features") 42 | mag, pha, f = self.extractor(egs["mix"]) 43 | return mag, pha, f 44 | 45 | def forward(self, egs, early_exit_threshold=0, record=False): 46 | mag, pha, f = self._compute_feats(egs) 47 | 48 | if record: 49 | start_event = th.cuda.Event(enable_timing=True) 50 | end_event = th.cuda.Event(enable_timing=True) 51 | start_event.record() 52 | th.cuda.synchronize() 53 | out = self.nnet(f, early_exit_threshold=early_exit_threshold) 54 | if record: 55 | end_event.record() 56 | th.cuda.synchronize() 57 | self.inference_time += [start_event.elapsed_time(end_event)] 58 | 59 | if self.get_mask: 60 | return out 61 | else: 62 | return [self.extractor.istft(m * mag, pha) for m in out] 63 | -------------------------------------------------------------------------------- /executor/feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Implementation of front-end feature via PyTorch 5 | """ 6 | 7 | import math 8 | import torch as th 9 | 10 | from collections.abc import Sequence 11 | 12 | import torch.nn.functional as F 13 | import torch.nn as nn 14 | 15 | EPSILON = th.finfo(th.float32).eps 16 | MATH_PI = math.pi 17 | 18 | 19 | def init_kernel(frame_len, 20 | frame_hop, 21 | normalize=True, 22 | round_pow_of_two=True, 23 | window="sqrt_hann"): 24 | if window != "sqrt_hann" and window != "hann": 25 | raise RuntimeError("Now only support sqrt hanning window or hann window") 26 | # FFT points 27 | N = 2**math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len 28 | # window 29 | W = th.hann_window(frame_len) 30 | if window == "sqrt_hann": 31 | W = W**0.5 32 | # scale factor to make same magnitude after iSTFT 33 | if window == "sqrt_hann" and normalize: 34 | S = 0.5 * (N * N / frame_hop)**0.5 35 | else: 36 | S = 1 37 | # F x N/2+1 x 2 38 | K = th.rfft(th.eye(N) / S, 1)[:frame_len] 39 | # 2 x N/2+1 x F 40 | K = th.transpose(K, 0, 2) * W 41 | # N+2 x 1 x F 42 | K = th.reshape(K, (N + 2, 1, frame_len)) 43 | return K 44 | 45 | 46 | class STFTBase(nn.Module): 47 | """ 48 | Base layer for (i)STFT 49 | NOTE: 50 | 1) Recommend sqrt_hann window with 2**N frame length, because it 51 | could achieve perfect reconstruction after overlap-add 52 | 2) Now haven't consider padding problems yet 53 | """ 54 | def __init__(self, 55 | frame_len, 56 | frame_hop, 57 | window="sqrt_hann", 58 | normalize=True, 59 | round_pow_of_two=True): 60 | super(STFTBase, self).__init__() 61 | K = init_kernel(frame_len, 62 | frame_hop, 63 | 
round_pow_of_two=round_pow_of_two, 64 | window=window) 65 | self.K = nn.Parameter(K, requires_grad=False) 66 | self.stride = frame_hop 67 | self.window = window 68 | self.normalize = normalize 69 | self.num_bins = self.K.shape[0] // 2 70 | if window == "hann": 71 | self.conjugate = True 72 | else: 73 | self.conjugate = False 74 | 75 | def extra_repr(self): 76 | return (f"window={self.window}, stride={self.stride}, " + 77 | f"kernel_size={self.K.shape[0]}x{self.K.shape[2]}, " + 78 | f"normalize={self.normalize}") 79 | 80 | 81 | class STFT(STFTBase): 82 | """ 83 | Short-time Fourier Transform as a Layer 84 | """ 85 | def __init__(self, *args, **kwargs): 86 | super(STFT, self).__init__(*args, **kwargs) 87 | 88 | def forward(self, x, cplx=False): 89 | """ 90 | Accept (single or multiple channel) raw waveform and output magnitude and phase 91 | args 92 | x: input signal, N x C x S or N x S 93 | return 94 | m: magnitude, N x C x F x T or N x F x T 95 | p: phase, N x C x F x T or N x F x T 96 | """ 97 | if x.dim() not in [2, 3]: 98 | raise RuntimeError( 99 | "{} expect 2D/3D tensor, but got {:d}D signal".format( 100 | self.__name__, x.dim())) 101 | # if N x S, reshape N x 1 x S 102 | if x.dim() == 2: 103 | x = th.unsqueeze(x, 1) 104 | # N x 2F x T 105 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 106 | # N x F x T 107 | r, i = th.chunk(c, 2, dim=1) 108 | if self.conjugate: 109 | # to match with science pipeline, we need to do conjugate 110 | i = -i 111 | # else reshape NC x 1 x S 112 | else: 113 | N, C, S = x.shape 114 | x = x.view(N * C, 1, S) 115 | # NC x 2F x T 116 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 117 | # N x C x 2F x T 118 | c = c.view(N, C, -1, c.shape[-1]) 119 | # N x C x F x T 120 | r, i = th.chunk(c, 2, dim=2) 121 | if self.conjugate: 122 | # to match with science pipeline, we need to do conjugate 123 | i = -i 124 | if cplx: 125 | return r, i 126 | m = (r**2 + i**2)**0.5 127 | p = th.atan2(i, r) 128 | return m, p 129 | 130 | 131 | class iSTFT(STFTBase): 132 | """ 133 | Inverse Short-time Fourier Transform as a Layer 134 | """ 135 | def __init__(self, *args, **kwargs): 136 | super(iSTFT, self).__init__(*args, **kwargs) 137 | 138 | def forward(self, m, p, cplx=False, squeeze=False): 139 | """ 140 | Accept phase & magnitude and output raw waveform 141 | args 142 | m, p: N x F x T 143 | return 144 | s: N x S 145 | """ 146 | if p.dim() != m.dim() or p.dim() not in [2, 3]: 147 | raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format( 148 | p.dim())) 149 | # if F x T, reshape 1 x F x T 150 | if p.dim() == 2: 151 | p = th.unsqueeze(p, 0) 152 | m = th.unsqueeze(m, 0) 153 | if cplx: 154 | # N x 2F x T 155 | c = th.cat([m, p], dim=1) 156 | else: 157 | r = m * th.cos(p) 158 | i = m * th.sin(p) 159 | # N x 2F x T 160 | c = th.cat([r, i], dim=1) 161 | # N x 2F x T 162 | s = F.conv_transpose1d(c, self.K, stride=self.stride, padding=0) 163 | # N x S 164 | s = s.squeeze(1) 165 | if squeeze: 166 | s = th.squeeze(s) 167 | return s 168 | 169 | 170 | class IPDFeature(nn.Module): 171 | """ 172 | Compute inter-channel phase difference 173 | """ 174 | def __init__(self, 175 | ipd_index="1,0;2,0;3,0;4,0;5,0;6,0", 176 | cos=True, 177 | sin=False, 178 | ipd_mean_normalize_version=2, 179 | ipd_mean_normalize=True): 180 | super(IPDFeature, self).__init__() 181 | split_index = lambda sstr: [ 182 | tuple(map(int, p.split(","))) for p in sstr.split(";") 183 | ] 184 | # ipd index 185 | pair = split_index(ipd_index) 186 | self.index_l = [t[0] for t in pair] 187 | self.index_r = 
[t[1] for t in pair] 188 | self.ipd_index = ipd_index 189 | self.cos = cos 190 | self.sin = sin 191 | self.ipd_mean_normalize=ipd_mean_normalize 192 | self.ipd_mean_normalize_version=ipd_mean_normalize_version 193 | self.num_pairs = len(pair) * 2 if cos and sin else len(pair) 194 | 195 | def extra_repr(self): 196 | return f"ipd_index={self.ipd_index}, cos={self.cos}, sin={self.sin}" 197 | 198 | def forward(self, p): 199 | """ 200 | Accept multi-channel phase and output inter-channel phase difference 201 | args 202 | p: phase matrix, N x C x F x T 203 | return 204 | ipd: N x MF x T 205 | """ 206 | if p.dim() not in [3, 4]: 207 | raise RuntimeError( 208 | "{} expect 3/4D tensor, but got {:d} instead".format( 209 | self.__name__, p.dim())) 210 | # C x F x T => 1 x C x F x T 211 | if p.dim() == 3: 212 | p = p.unsqueeze(0) 213 | N, _, _, T = p.shape 214 | pha_dif = p[:, self.index_l] - p[:, self.index_r] 215 | if self.ipd_mean_normalize: 216 | yr = th.cos(pha_dif) 217 | yi = th.sin(pha_dif) 218 | yrm = yr.mean(-1, keepdim=True) 219 | yim = yi.mean(-1, keepdim=True) 220 | if self.ipd_mean_normalize_version == 1: 221 | pha_dif = th.atan2(yi - yim, yr - yrm) 222 | elif self.ipd_mean_normalize_version == 2: 223 | pha_dif_mean = th.atan2(yim, yrm) 224 | pha_dif -= pha_dif_mean 225 | elif self.ipd_mean_normalize_version == 3: 226 | pha_dif_mean = pha_dif.mean(-1, keepdim=True) 227 | pha_dif -= pha_dif_mean 228 | else: 229 | # we only support version 1, 2 and 3 230 | raise RuntimeError( 231 | "{} expect ipd_mean_normalization version 1 or version 2, but got {:d} instead".format( 232 | self.__name__, self.ipd_mean_normalize_version)) 233 | 234 | if self.cos: 235 | # N x M x F x T 236 | ipd = th.cos(pha_dif) 237 | if self.sin: 238 | # N x M x 2F x T, along frequency axis 239 | ipd = th.cat([ipd, th.sin(pha_dif)], 2) 240 | else: 241 | # th.fmod behaves differently from np.mod for the input that is less than -math.pi 242 | # i believe it is a bug 243 | # so we need to ensure it is larger than -math.pi by adding an extra 6 * math.pi 244 | #ipd = th.fmod(pha_dif + math.pi, 2 * math.pi) - math.pi 245 | ipd = pha_dif 246 | # N x MF x T 247 | ipd = ipd.view(N, -1, T) 248 | # N x MF x T 249 | return ipd 250 | 251 | 252 | class AngleFeature(nn.Module): 253 | """ 254 | Compute angle/directional feature 255 | 1) num_doas == 1: we known the DoA of the target speaker 256 | 2) num_doas != 1: we do not have that prior, so we sampled #num_doas DoAs 257 | and compute on each directions 258 | """ 259 | def __init__(self, 260 | geometric="princeton", 261 | sr=16000, 262 | velocity=340, 263 | num_bins=257, 264 | num_doas=1, 265 | af_index="1,0;2,0;3,0;4,0;5,0;6,0"): 266 | super(AngleFeature, self).__init__() 267 | if geometric not in ["princeton"]: 268 | raise RuntimeError( 269 | "Unsupported array geometric: {}".format(geometric)) 270 | self.geometric = geometric 271 | self.sr = sr 272 | self.num_bins = num_bins 273 | self.num_doas = num_doas 274 | self.velocity = velocity 275 | split_index = lambda sstr: [ 276 | tuple(map(int, p.split(","))) for p in sstr.split(";") 277 | ] 278 | # ipd index 279 | pair = split_index(af_index) 280 | self.index_l = [t[0] for t in pair] 281 | self.index_r = [t[1] for t in pair] 282 | self.af_index = af_index 283 | omega = th.tensor( 284 | [math.pi * sr * f / (num_bins - 1) for f in range(num_bins)]) 285 | # 1 x F 286 | self.omega = nn.Parameter(omega[None, :], requires_grad=False) 287 | 288 | def _oracle_phase_delay(self, doa): 289 | """ 290 | Compute oracle phase delay given DoA 291 | args 
292 | doa: N 293 | return 294 | phi: N x C x F or N x D x C x F 295 | """ 296 | device = doa.device 297 | if self.num_doas != 1: 298 | # doa is a unused, fake parameter 299 | N = doa.shape[0] 300 | # N x D 301 | doa = th.linspace(0, MATH_PI * 2, self.num_doas + 1, 302 | device=device)[:-1].repeat(N, 1) 303 | # for princeton 304 | # M = 7, R = 0.0425, treat M_0 as (0, 0) 305 | # *3 *2 306 | # 307 | # *4 *0 *1 308 | # 309 | # *5 *6 310 | if self.geometric == "princeton": 311 | R = 0.0425 312 | zero = th.zeros_like(doa) 313 | # N x 7 or N x D x 7 314 | tau = R * th.stack([ 315 | zero, -th.cos(doa), -th.cos(MATH_PI / 3 - doa), 316 | -th.cos(2 * MATH_PI / 3 - doa), 317 | th.cos(doa), 318 | th.cos(MATH_PI / 3 - doa), 319 | th.cos(2 * MATH_PI / 3 - doa) 320 | ], 321 | dim=-1) / self.velocity 322 | # (Nx7x1) x (1xF) => Nx7xF or (NxDx7x1) x (1xF) => NxDx7xF 323 | phi = th.matmul(tau.unsqueeze(-1), -self.omega) 324 | return phi 325 | else: 326 | return None 327 | 328 | def extra_repr(self): 329 | return ( 330 | f"geometric={self.geometric}, af_index={self.af_index}, " + 331 | f"sr={self.sr}, num_bins={self.num_bins}, velocity={self.velocity}, " 332 | + f"known_doa={self.num_doas == 1}") 333 | 334 | def _compute_af(self, ipd, doa): 335 | """ 336 | Compute angle feature 337 | args 338 | ipd: N x C x F x T 339 | doa: DoA of the target speaker (if we known that), N 340 | or N x D (we do not known that, sampling D DoAs instead) 341 | return 342 | af: N x F x T or N x D x F x T 343 | """ 344 | # N x C x F or N x D x C x F 345 | d = self._oracle_phase_delay(doa) 346 | d = d.unsqueeze(-1) 347 | if self.num_doas == 1: 348 | dif = d[:, self.index_l] - d[:, self.index_r] 349 | # N x C x F x T 350 | af = th.cos(ipd - dif) 351 | # on channel dimention (mean or sum) 352 | af = th.mean(af, dim=1) 353 | else: 354 | # N x D x C x F x 1 355 | dif = d[:, :, self.index_l] - d[:, :, self.index_r] 356 | # N x D x C x F x T 357 | af = th.cos(ipd.unsqueeze(1) - dif) 358 | # N x D x F x T 359 | af = th.mean(af, dim=2) 360 | return af 361 | 362 | def forward(self, p, doa): 363 | """ 364 | Accept doa of the speaker & multi-channel phase, output angle feature 365 | args 366 | doa: DoA of target/each speaker, N or [N, ...] 367 | p: phase matrix, N x C x F x T 368 | return 369 | af: angle feature, N x F* x T or N x D x F x T (known_doa=False) 370 | """ 371 | if p.dim() not in [3, 4]: 372 | raise RuntimeError( 373 | "{} expect 3/4D tensor, but got {:d} instead".format( 374 | self.__name__, p.dim())) 375 | # C x F x T => 1 x C x F x T 376 | if p.dim() == 3: 377 | p = p.unsqueeze(0) 378 | ipd = p[:, self.index_l] - p[:, self.index_r] 379 | 380 | if isinstance(doa, Sequence): 381 | if self.num_doas != 1: 382 | raise RuntimeError("known_doa=False, no need to pass " 383 | "doa as a Sequence object") 384 | # [N x F x T or N x D x F x T, ...] 
385 | af = [self._compute_af(ipd, spk_doa) for spk_doa in doa] 386 | # N x F x T => N x F* x T 387 | af = th.cat(af, 1) 388 | else: 389 | # N x F x T or N x D x F x T 390 | af = self._compute_af(ipd, doa) 391 | return af 392 | 393 | 394 | class FeatureExtractor(nn.Module): 395 | """ 396 | A PyTorch module to handle spectral & spatial features 397 | """ 398 | def __init__(self, 399 | frame_len=512, 400 | frame_hop=256, 401 | normalize=True, 402 | round_pow_of_two=True, 403 | num_spks=2, 404 | log_spectrogram=True, 405 | mvn_spectrogram=True, 406 | ipd_mean_normalize=True, 407 | ipd_mean_normalize_version=2, 408 | window="sqrt_hann", 409 | ext_af=0, 410 | ipd_cos=True, 411 | ipd_sin=False, 412 | ipd_index="1,4;2,5;3,6", 413 | ang_index="1,0;2,0;3,0;4,0;5,0;6,0" 414 | ): 415 | super(FeatureExtractor, self).__init__() 416 | # forward STFT 417 | self.forward_stft = STFT(frame_len, 418 | frame_hop, 419 | normalize=normalize, 420 | window=window, 421 | round_pow_of_two=round_pow_of_two) 422 | self.inverse_stft = iSTFT(frame_len, 423 | frame_hop, 424 | normalize=normalize, 425 | round_pow_of_two=round_pow_of_two) 426 | self.has_spatial = False 427 | num_bins = self.forward_stft.num_bins 428 | self.feature_dim = num_bins 429 | self.num_bins = num_bins 430 | self.num_spks = num_spks 431 | # add extra angle feature 432 | self.ext_af = ext_af 433 | 434 | # IPD or not 435 | self.ipd_extractor = None 436 | if ipd_index: 437 | self.ipd_extractor = IPDFeature(ipd_index, 438 | cos=ipd_cos, 439 | sin=ipd_sin, 440 | ipd_mean_normalize_version=ipd_mean_normalize_version, 441 | ipd_mean_normalize=ipd_mean_normalize) 442 | self.feature_dim += self.ipd_extractor.num_pairs * num_bins 443 | self.has_spatial = True 444 | # AF or not 445 | self.ang_extractor = None 446 | if ang_index: 447 | self.ang_extractor = AngleFeature( 448 | num_bins=num_bins, 449 | num_doas=1, # must known the DoA 450 | af_index=ang_index) 451 | self.feature_dim += num_bins * self.num_spks * (1 + self.ext_af) 452 | self.has_spatial = True 453 | # BN or not 454 | self.mvn_mag = mvn_spectrogram 455 | # apply log or not 456 | self.log_mag = log_spectrogram 457 | 458 | def _check_args(self, x, doa): 459 | if x.dim() == 2 and self.has_spatial: 460 | raise RuntimeError("Got 2D (single channel) input and can " 461 | "not extract spatial features") 462 | if self.ang_extractor is None and doa: 463 | raise RuntimeError("DoA is given and AF extractor " 464 | "is not initialized") 465 | if self.ang_extractor and doa is None: 466 | raise RuntimeError("AF extractor is initialized, but DoA is None") 467 | num_af = self.num_spks * (self.ext_af + 1) 468 | if isinstance(doa, Sequence) and len(doa) != num_af: 469 | raise RuntimeError("Number of DoA do not match the " + 470 | f"speaker number: {len(doa):d} vs {num_af:d}") 471 | 472 | def stft(self, x, cplx=False): 473 | return self.forward_stft(x, cplx=cplx) 474 | 475 | def istft(self, m, p, cplx=False): 476 | return self.inverse_stft(m, p, cplx=cplx) 477 | 478 | def compute_spectra(self, x): 479 | """ 480 | Compute spectra features 481 | args 482 | x: N x C x S (multi-channel) or N x S (single channel) 483 | return: 484 | mag & pha: N x F x T or N x C x F x T 485 | feature: N x * x T 486 | """ 487 | # mag & pha: N x C x F x T or N x F x T 488 | mag, pha = self.forward_stft(x) 489 | # ch0: N x F x T 490 | if mag.dim() == 4: 491 | f = th.clamp(mag[:, 0], min=EPSILON) 492 | else: 493 | f = th.clamp(mag, min=EPSILON) 494 | # log 495 | if self.log_mag: 496 | f = th.log(f) 497 | # mvn 498 | if self.mvn_mag: 499 | # f 
= self.mvn_mag(f) 500 | f = (f - f.mean(-1, keepdim=True)) / (f.std(-1, keepdim=True) + 501 | EPSILON) 502 | return mag, pha, f 503 | 504 | def compute_spatial(self, x, doa=None, pha=None): 505 | """ 506 | Compute spatial features 507 | args 508 | x: N x C x S (multi-channel) 509 | pha: N x C x F x T 510 | return 511 | feature: N x * x T 512 | """ 513 | if pha is None: 514 | self._check_args(x, doa) 515 | # mag & pha: N x C x F x T 516 | _, pha = self.forward_stft(x) 517 | else: 518 | if pha.dim() != 4: 519 | raise RuntimeError("Expect phase matrix a 4D tensor, " + 520 | f"got {pha.dim()} instead") 521 | feature = [] 522 | if self.has_spatial: 523 | if self.ipd_extractor: 524 | # N x C x F x T => N x MF x T 525 | ipd = self.ipd_extractor(pha) 526 | feature.append(ipd) 527 | if self.ang_extractor: 528 | # N x C x F x T => N x F* x T 529 | ang = self.ang_extractor(pha, doa) 530 | feature.append(ang) 531 | else: 532 | raise RuntimeError("No spatial features are configured") 533 | # N x * x T 534 | feature = th.cat(feature, 1) 535 | return feature 536 | 537 | def forward(self, x, doa=None, ref_channel=0): 538 | """ 539 | args 540 | x: N x C x S (multi-channel) or N x S (single channel) 541 | doa: N or [N, ...] (for each speaker) 542 | return: 543 | mag & pha: N x F x T (if ref_channel is not None), N x C x F x T 544 | feature: N x * x T 545 | """ 546 | self._check_args(x, doa) 547 | # mag & pha: N x C x F x T or N x F x T 548 | mag, pha, f = self.compute_spectra(x) 549 | feature = [f] 550 | if self.has_spatial: 551 | spatial = self.compute_spatial(x, pha=pha, doa=doa) 552 | feature.append(spatial) 553 | # N x * x T 554 | feature = th.cat(feature, 1) 555 | if mag.dim() == 4 and ref_channel is not None: 556 | return mag[:, ref_channel], pha[:, ref_channel], feature 557 | else: 558 | return mag, pha, feature -------------------------------------------------------------------------------- /nnet/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import early_exit_transformer 2 | 3 | supported_nnet = { 4 | "EETransformer": early_exit_transformer.EETransformerCSS, 5 | } 6 | -------------------------------------------------------------------------------- /nnet/early_exit_transformer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import math 4 | import torch.nn as nn 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class RelativePositionalEncoding(torch.nn.Module): 10 | def __init__(self, d_model, maxlen=1000, embed_v=False): 11 | super(RelativePositionalEncoding, self).__init__() 12 | self.d_model = d_model 13 | self.maxlen = maxlen 14 | self.pe_k = torch.nn.Embedding(2*maxlen, d_model) 15 | if embed_v: 16 | self.pe_v = torch.nn.Embedding(2*maxlen, d_model) 17 | self.embed_v = embed_v 18 | 19 | def forward(self, pos_seq): 20 | pos_seq.clamp_(-self.maxlen, self.maxlen - 1) 21 | pos_seq = pos_seq + self.maxlen 22 | if self.embed_v: 23 | return self.pe_k(pos_seq), self.pe_v(pos_seq) 24 | else: 25 | return self.pe_k(pos_seq), None 26 | 27 | 28 | class PositionwiseFeedForward(torch.nn.Module): 29 | """Positionwise feed forward 30 | 31 | :param int idim: input dimension 32 | :param int hidden_units: number of hidden units 33 | :param float dropout_rate: dropout rate 34 | """ 35 | 36 | def __init__(self, idim, hidden_units, dropout_rate): 37 | super(PositionwiseFeedForward, self).__init__() 38 | self.w_1 = torch.nn.Linear(idim, hidden_units) 39 | self.w_2 = torch.nn.Linear(hidden_units, idim) 40 | self.dropout = torch.nn.Dropout(dropout_rate) 41 | 42 | def forward(self, x): 43 | return self.w_2(self.dropout(torch.relu(self.w_1(x)))) 44 | 45 | 46 | class MultiHeadedAttention(nn.Module): 47 | """Multi-Head Attention layer. 48 | 49 | :param int n_head: the number of head s 50 | :param int n_feat: the number of features 51 | :param float dropout_rate: dropout rate 52 | 53 | """ 54 | 55 | def __init__(self, n_head, n_feat, dropout_rate): 56 | """Construct an MultiHeadedAttention object.""" 57 | super(MultiHeadedAttention, self).__init__() 58 | assert n_feat % n_head == 0 59 | # We assume d_v always equals d_k 60 | self.d_k = n_feat // n_head 61 | self.h = n_head 62 | self.linear_q = nn.Linear(n_feat, n_feat) 63 | self.linear_k = nn.Linear(n_feat, n_feat) 64 | self.linear_v = nn.Linear(n_feat, n_feat) 65 | 66 | self.linear_out = nn.Linear(n_feat, n_feat) 67 | self.attn = None 68 | self.dropout = nn.Dropout(p=dropout_rate) 69 | 70 | 71 | def forward(self, query, key, value, pos_k, pos_v, mask): 72 | """Compute 'Scaled Dot Product Attention'. 
73 | 74 | :param torch.Tensor query: (batch, time1, size) 75 | :param torch.Tensor key: (batch, time2, size) 76 | :param torch.Tensor value: (batch, time2, size) 77 | :param torch.Tensor mask: (batch, time1, time2) 78 | :return torch.Tensor: attentioned and transformed `value` (batch, time1, d_model) 79 | weighted by the query dot key attention (batch, head, time1, time2) 80 | """ 81 | n_batch = query.size(0) 82 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) #(b, t, d) 83 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) #(b, t, d) 84 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 85 | q = q.transpose(1, 2) 86 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 87 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 88 | A = torch.matmul(q, k.transpose(-2, -1)) 89 | if pos_k is not None: 90 | reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0,1) 91 | B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) 92 | B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) 93 | scores = (A + B) / math.sqrt(self.d_k) 94 | else: 95 | scores = A / math.sqrt(self.d_k) 96 | if mask is not None: 97 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) 98 | min_value = float(np.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) 99 | scores = scores.masked_fill(mask, min_value) 100 | self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) 101 | else: 102 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 103 | 104 | p_attn = self.dropout(self.attn) 105 | x = torch.matmul(p_attn, v) # (batch, head, time1, d_k) 106 | if pos_v is not None: 107 | reshape_attn = p_attn.contiguous().view(n_batch * self.h, pos_v.size(0), pos_v.size(1)).transpose(0,1) #(t1, bh, t2) 108 | 109 | attn_v = torch.matmul(reshape_attn, pos_v).transpose(0,1).contiguous().view(n_batch, self.h, pos_v.size(0), self.d_k) 110 | x = x + attn_v 111 | x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) # (batch, time1, d_model) 112 | return self.linear_out(x) # (batch, time1, d_model) 113 | 114 | 115 | class EncoderLayer(nn.Module): 116 | """Encoder layer module. 117 | 118 | :param int size: input dim 119 | :param espnet.nets.pytorch_backend.transformer.attention.MultiHeadedAttention self_attn: self attention module 120 | :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.PositionwiseFeedForward feed_forward: 121 | feed forward module 122 | :param float dropout_rate: dropout rate 123 | :param bool normalize_before: whether to use layer_norm before the first block 124 | :param bool concat_after: whether to concat attention layer's input and output 125 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) 126 | if False, no additional linear will be applied. i.e. 
x -> x + att(x) 127 | """ 128 | 129 | def __init__(self, size, self_attn, feed_forward, dropout_rate, 130 | normalize_before=True, concat_after=False, attention_heads=8): 131 | """Construct an EncoderLayer object.""" 132 | super(EncoderLayer, self).__init__() 133 | self.self_attn = self_attn 134 | self.feed_forward = feed_forward 135 | self.norm1 = torch.nn.LayerNorm(size, eps=1e-12) 136 | self.norm2 = torch.nn.LayerNorm(size, eps=1e-12) 137 | self.norm_k = torch.nn.LayerNorm(size//attention_heads, eps=1e-12); self.norm_v = torch.nn.LayerNorm(size//attention_heads, eps=1e-12) 138 | self.dropout = nn.Dropout(dropout_rate) 139 | self.size = size 140 | self.normalize_before = normalize_before 141 | self.concat_after = concat_after 142 | if self.concat_after: 143 | self.concat_linear = nn.Linear(size + size, size) 144 | 145 | def forward(self, x, pos_k, pos_v, mask): 146 | """Compute encoded features. 147 | 148 | :param torch.Tensor x: encoded source features (batch, max_time_in, size) 149 | :param torch.Tensor mask: mask for x (batch, max_time_in) 150 | :rtype: Tuple[torch.Tensor, torch.Tensor] 151 | """ 152 | residual = x 153 | if self.normalize_before: 154 | x = self.norm1(x) 155 | if pos_k is not None: 156 | pos_k = self.norm_k(pos_k) 157 | if pos_v is not None: 158 | pos_v = self.norm_v(pos_v) 159 | if self.concat_after: 160 | x_concat = torch.cat((x, self.self_attn(x, x, x, pos_k, pos_v, mask)), dim=-1) 161 | x = residual + self.concat_linear(x_concat) 162 | else: 163 | x = residual + self.dropout(self.self_attn(x, x, x, pos_k, pos_v, mask)) 164 | if not self.normalize_before: 165 | x = self.norm1(x) 166 | 167 | residual = x 168 | if self.normalize_before: 169 | x = self.norm2(x) 170 | x = residual + self.dropout(self.feed_forward(x)) 171 | if not self.normalize_before: 172 | x = self.norm2(x) 173 | 174 | return x, mask 175 | 176 | 177 | class EETransformerEncoder(torch.nn.Module): 178 | """Early Exit Transformer encoder module. 179 | 180 | :param int idim: input dim 181 | :param int attention_dim: dimension of attention 182 | :param int attention_heads: the number of heads of multi head attention 183 | :param int linear_units: the number of units of position-wise feed forward 184 | :param int num_blocks: the number of encoder blocks 185 | :param float dropout_rate: dropout rate 186 | :param float attention_dropout_rate: dropout rate in attention 187 | :param float positional_dropout_rate: dropout rate after adding positional encoding 188 | :param bool normalize_before: whether to use layer_norm before the first block 189 | :param bool concat_after: whether to concat attention layer's input and output 190 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x))) 191 | if False, no additional linear will be applied. i.e.
x -> x + att(x) 192 | """ 193 | 194 | def __init__(self, 195 | idim=1799, 196 | attention_dim=256, 197 | attention_heads=4, 198 | linear_units=2048, 199 | num_blocks=16, 200 | dropout_rate=0.1, 201 | positional_dropout_rate=0.1, 202 | attention_dropout_rate=0.0, 203 | relative_pos_emb=True, 204 | normalize_before=True, 205 | concat_after=False, 206 | exit_classifiers=None): 207 | super(EETransformerEncoder, self).__init__() 208 | 209 | self.embed = torch.nn.Sequential( 210 | torch.nn.Linear(idim, attention_dim), 211 | torch.nn.LayerNorm(attention_dim), 212 | torch.nn.Dropout(dropout_rate), 213 | torch.nn.ReLU(), 214 | ) 215 | 216 | if relative_pos_emb: 217 | self.pos_emb = RelativePositionalEncoding(attention_dim // attention_heads, 1000) 218 | else: 219 | self.pos_emb = None 220 | 221 | self.encoders = torch.nn.Sequential(*[EncoderLayer( 222 | attention_dim, 223 | MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate), 224 | PositionwiseFeedForward(attention_dim, linear_units, dropout_rate), 225 | dropout_rate, 226 | normalize_before, 227 | concat_after, 228 | attention_heads 229 | ) for _ in range(num_blocks)]) 230 | 231 | self.dropout_layer = torch.nn.Dropout(p=positional_dropout_rate) 232 | self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None 233 | 234 | self.exit_classifiers = exit_classifiers 235 | self.inference_exit_layers = [] 236 | 237 | def forward(self, xs, masks, early_exit_threshold=0): 238 | xs = self.embed(xs) 239 | 240 | if self.pos_emb is not None: 241 | x_len = xs.shape[1] 242 | pos_seq = torch.arange(0, x_len).long().to(xs.device) 243 | pos_seq = pos_seq[:, None] - pos_seq[None, :] 244 | pos_k, pos_v = self.pos_emb(pos_seq) 245 | else: 246 | pos_k, pos_v = None, None 247 | 248 | xs = self.dropout_layer(xs) 249 | 250 | if self.training: 251 | # during training, return the estimated masks of all the layers 252 | results = [] 253 | for i, layer in enumerate(self.encoders): 254 | xs, _ = layer(xs, pos_k, pos_v, masks) 255 | output = self.after_norm(xs) if self.after_norm else xs 256 | output = self.exit_classifiers[i](output) 257 | output = torch.sigmoid(output) 258 | results.append(output) 259 | else: 260 | # We dynamically stop the inference if the predictions from two consecutive layers are sufficiently similar 261 | last_predicts = None 262 | calculated_layer_num = 0 263 | for i, layer in enumerate(self.encoders): 264 | calculated_layer_num += 1 265 | xs, _ = layer(xs, pos_k, pos_v, masks) 266 | output = self.after_norm(xs) if self.after_norm else xs 267 | logits = self.exit_classifiers[i](output) 268 | predicts = torch.sigmoid(logits) 269 | predicts = predicts.detach() 270 | if (last_predicts is not None) and torch.dist(last_predicts, predicts, p=2) / last_predicts.shape[1] / last_predicts.shape[2] < early_exit_threshold: 271 | last_predicts = predicts 272 | break 273 | else: 274 | last_predicts = predicts 275 | results = [last_predicts] 276 | self.inference_exit_layers.append(calculated_layer_num) 277 | return results, masks 278 | 279 | 280 | default_encoder_conf = { 281 | "attention_dim": 256, 282 | "attention_heads": 4, 283 | "linear_units": 2048, 284 | "num_blocks": 16, 285 | "dropout_rate": 0.1, 286 | "positional_dropout_rate": 0.1, 287 | "attention_dropout_rate": 0.0, 288 | "relative_pos_emb": True, 289 | "normalize_before": True, 290 | "concat_after": False, 291 | } 292 | 293 | 294 | class EETransformerCSS(nn.Module): 295 | """ 296 | Early Exit Transformer speech separation model 297 | """ 298 | def 
__init__(self, 299 | stats_file=None, 300 | in_features=1799, 301 | num_bins=257, 302 | num_spks=2, 303 | num_nois=1, 304 | transformer_conf=default_encoder_conf): 305 | super(EETransformerCSS, self).__init__() 306 | 307 | # input normalization layer 308 | if stats_file is not None: 309 | stats = np.load(stats_file) 310 | self.input_bias = torch.from_numpy( 311 | np.tile(np.expand_dims(-stats['mean'].astype(np.float32), axis=0), (1, 1, 1))) 312 | self.input_scale = torch.from_numpy( 313 | np.tile(np.expand_dims(1 / np.sqrt(stats['variance'].astype(np.float32)), axis=0), (1, 1, 1))) 314 | self.input_bias = nn.Parameter(self.input_bias, requires_grad=False) 315 | self.input_scale = nn.Parameter(self.input_scale, requires_grad=False) 316 | else: 317 | self.input_bias = torch.zeros(1, 1, in_features) 318 | self.input_scale = torch.ones(1, 1, in_features) 319 | self.input_bias = nn.Parameter(self.input_bias, requires_grad=False) 320 | self.input_scale = nn.Parameter(self.input_scale, requires_grad=False) 321 | 322 | self.num_bins = num_bins 323 | self.num_spks = num_spks 324 | self.num_nois = num_nois 325 | self.linear = nn.ModuleList([nn.Linear(transformer_conf["attention_dim"], num_bins * (num_spks + num_nois)) 326 | for _ in range(transformer_conf['num_blocks'])]) 327 | 328 | # Transformers 329 | self.transformer = EETransformerEncoder(in_features, **transformer_conf, exit_classifiers=self.linear) 330 | 331 | def forward(self, f, early_exit_threshold=0): 332 | """ 333 | args 334 | f: N x * x T 335 | return 336 | m: [N x F x T, ...] 337 | """ 338 | # N x * x T => N x T x * 339 | f = f.transpose(1, 2) 340 | 341 | # global feature normalization 342 | f = f + self.input_bias 343 | f = f * self.input_scale 344 | 345 | m_list, _ = self.transformer(f, masks=None, early_exit_threshold=early_exit_threshold) 346 | res_m = [] 347 | for m in m_list: 348 | # N x T x F => N x F x T 349 | m = m.transpose(1, 2) 350 | m = torch.chunk(m, self.num_spks + self.num_nois, 1) 351 | res_m.append(m) 352 | 353 | if not self.training: 354 | assert len(res_m) == 1 355 | res_m = res_m[0] 356 | 357 | return res_m 358 | 359 | -------------------------------------------------------------------------------- /separate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import yaml 4 | import argparse 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | from pathlib import Path 10 | 11 | from nnet import supported_nnet 12 | from executor.executor import Executor 13 | from utils.audio_util import WaveReader, write_wav 14 | from utils.mvdr_util import make_mvdr 15 | 16 | 17 | class EgsReader(object): 18 | """ 19 | Egs reader 20 | """ 21 | def __init__(self, 22 | mix_scp, 23 | sr=16000): 24 | self.mix_reader = WaveReader(mix_scp, sr=sr) 25 | 26 | def __len__(self): 27 | return len(self.mix_reader) 28 | 29 | def __iter__(self): 30 | for key, mix in self.mix_reader: 31 | egs = dict() 32 | egs["mix"] = mix 33 | yield key, egs 34 | 35 | 36 | class Separator(object): 37 | """ 38 | A simple wrapper for speech separation 39 | """ 40 | def __init__(self, cpt_dir, get_mask=False, device_id=-1): 41 | # load executor 42 | cpt_dir = Path(cpt_dir) 43 | self.get_mask = get_mask 44 | self.executor = self._load_executor(cpt_dir) 45 | cpt_ptr = cpt_dir / "best.pt.tar" 46 | epoch = self.executor.resume(cpt_ptr.as_posix()) 47 | print(f"Load checkpoint at {cpt_dir}, on epoch {epoch}") 48 | print(f"Nnet summary: {self.executor}") 49 | if device_id < 0: 50 | self.device = 
th.device("cpu") 51 | else: 52 | self.device = th.device(f"cuda:{device_id:d}") 53 | self.executor.to(self.device) 54 | self.executor.eval() 55 | 56 | def separate(self, egs, early_exit_threshold=0): 57 | """ 58 | Do separation 59 | """ 60 | egs["mix"] = th.from_numpy(egs["mix"][None, :]).to(self.device, non_blocking=True) 61 | with th.no_grad(): 62 | spks = self.executor(egs, early_exit_threshold=early_exit_threshold, record=True) 63 | spks = [s.detach().squeeze().cpu().numpy() for s in spks] 64 | return spks 65 | 66 | def _load_executor(self, cpt_dir): 67 | """ 68 | Load executor from checkpoint 69 | """ 70 | with open(cpt_dir / "train.yaml", "r") as f: 71 | conf = yaml.load(f, Loader=yaml.FullLoader) 72 | nnet_type = conf["nnet_type"] 73 | if nnet_type not in supported_nnet: 74 | raise RuntimeError(f"Unknown network type: {nnet_type}") 75 | nnet = supported_nnet[nnet_type](**conf["nnet_conf"]) 76 | executor = Executor(nnet, extractor_kwargs=conf["extractor_conf"], get_mask=self.get_mask) 77 | return executor 78 | 79 | 80 | 81 | def run(args): 82 | # egs reader 83 | egs_reader = EgsReader(args.mix_scp, sr=args.sr) 84 | # separator 85 | seperator = Separator(args.checkpoint, device_id=args.device_id, get_mask=args.mvdr) 86 | 87 | dump_dir = Path(args.dump_dir) 88 | dump_dir.mkdir(exist_ok=True, parents=True) 89 | 90 | print(f"Start Separation " + ("w/ mvdr" if args.mvdr else "w/o mvdr")) 91 | for key, egs in egs_reader: 92 | print(f"Processing utterance {key}...") 93 | mixed = egs["mix"] 94 | spks = seperator.separate(egs, early_exit_threshold=args.early_exit_threshold) 95 | 96 | if args.mvdr: 97 | res1, res2 = make_mvdr(np.asfortranarray(mixed.T), spks) 98 | spks = [res1, res2] 99 | 100 | for i, s in enumerate(spks): 101 | if i < args.num_spks: 102 | write_wav(dump_dir / f"{key}_{i}.wav", s * 0.9 / np.max(np.abs(s))) 103 | 104 | print(f"Exit layers: {seperator.executor.nnet.transformer.inference_exit_layers}") 105 | print(f"Avg. exit layer: {sum(seperator.executor.nnet.transformer.inference_exit_layers) * 1.0 / len(seperator.executor.nnet.transformer.inference_exit_layers)}") 106 | print(f"Inference times: {seperator.executor.inference_time}") 107 | print(f"Avg. 
inference time: {sum(seperator.executor.inference_time) * 1.0 / len(seperator.executor.inference_time)}") 108 | print(f"Processed {len(egs_reader)} utterances done") 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser( 113 | description="Command to do speech separation", 114 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 115 | parser.add_argument("--checkpoint", type=str, help="Directory of checkpoint") 116 | parser.add_argument("--mix-scp", 117 | type=str, 118 | required=True, 119 | help="Rspecifier for mixed audio") 120 | parser.add_argument("--num_spks", 121 | type=int, 122 | default=2, 123 | help="Number of the speakers") 124 | parser.add_argument("--device-id", 125 | type=int, 126 | default=-1, 127 | help="GPU-id to offload model to, -1 means running on CPU") 128 | parser.add_argument("--sr", 129 | type=int, 130 | default=16000, 131 | help="Sample rate for mixture input") 132 | parser.add_argument("--dump-dir", 133 | type=str, 134 | default="sep", 135 | help="Directory to dump separated speakers") 136 | parser.add_argument("--mvdr", 137 | type=bool, 138 | default=False, 139 | help="apply mvdr") 140 | parser.add_argument("--early_exit_threshold", 141 | type=float, 142 | default=0, 143 | help="Threshold for the early exit mechanism") 144 | args = parser.parse_args() 145 | run(args) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sanyuan-Chen/CSS_with_EETransformer/377904067835e442c50f7db019397ac24e9cdf49/utils/__init__.py -------------------------------------------------------------------------------- /utils/audio_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import soundfile as sf 4 | import scipy.io.wavfile as wf 5 | 6 | MAX_INT16 = np.iinfo(np.int16).max 7 | EPSILON = np.finfo(np.float32).eps 8 | 9 | 10 | def _parse_script(scp_path, 11 | value_processor=lambda x: x, 12 | num_tokens=2, 13 | restrict=True): 14 | """ 15 | Parse kaldi's script(.scp) file 16 | If num_tokens >= 2, function will check token number 17 | """ 18 | scp_dict = dict() 19 | line = 0 20 | with open(scp_path, "r") as f: 21 | for raw_line in f: 22 | scp_tokens = raw_line.strip().split() 23 | line += 1 24 | if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or ( 25 | restrict and len(scp_tokens) < 2): 26 | raise RuntimeError( 27 | "For {}, format error in line[{:d}]: {}".format( 28 | scp_path, line, raw_line)) 29 | if num_tokens == 2: 30 | key, value = scp_tokens 31 | else: 32 | key, value = scp_tokens[0], scp_tokens[1:] 33 | if key in scp_dict: 34 | raise ValueError("Duplicated key \'{0}\' exists in {1}".format( 35 | key, scp_path)) 36 | scp_dict[key] = value_processor(value) 37 | return scp_dict 38 | 39 | 40 | class BaseReader(object): 41 | """ 42 | BaseReader Class 43 | """ 44 | def __init__(self, scp_rspecifier, **kwargs): 45 | self.index_dict = _parse_script(scp_rspecifier, **kwargs) 46 | self.index_keys = list(self.index_dict.keys()) 47 | 48 | def _load(self, key): 49 | # return path 50 | return self.index_dict[key] 51 | 52 | # number of utterance 53 | def __len__(self): 54 | return len(self.index_dict) 55 | 56 | # avoid key error 57 | def __contains__(self, key): 58 | return key in self.index_dict 59 | 60 | # sequential index 61 | def __iter__(self): 62 | for key in self.index_keys: 63 | yield key, self._load(key) 64 
| 65 | 66 | class WaveReader(BaseReader): 67 | """ 68 | Sequential/Random Reader for single channel wave 69 | Format of wav.scp follows Kaldi's definition: 70 | key1 /path/to/wav 71 | ... 72 | """ 73 | def __init__(self, wav_scp, sr=16000, normalize=True): 74 | super(WaveReader, self).__init__(wav_scp) 75 | self.sr = sr 76 | self.normalize = normalize 77 | 78 | def _load(self, key): 79 | # return C x N or N 80 | sr, samps = read_wav(self.index_dict[key], 81 | normalize=self.normalize, 82 | return_rate=True) 83 | # if given samp_rate, check it 84 | if self.sr is not None and sr != self.sr: 85 | raise RuntimeError("Sample rate mismatch: {:d} vs {:d}".format( 86 | sr, self.sr)) 87 | 88 | return samps 89 | 90 | 91 | def read_wav(fname, beg=None, end=None, normalize=True, return_rate=False): 92 | """ 93 | Read wave files using scipy.io.wavfile(support multi-channel) 94 | """ 95 | # samps_int16: N x C or N 96 | # N: number of samples 97 | # C: number of channels 98 | if beg is not None: 99 | samps_int16, samp_rate = sf.read(fname, 100 | start=beg, 101 | stop=end, 102 | dtype="int16") 103 | else: 104 | samp_rate, samps_int16 = wf.read(fname) 105 | # N x C => C x N 106 | samps = samps_int16.astype(np.float32) 107 | # tranpose because I used to put channel axis first 108 | if samps.ndim != 1: 109 | samps = np.transpose(samps) 110 | # normalize like MATLAB and librosa 111 | if normalize: 112 | samps = samps / MAX_INT16 113 | if return_rate: 114 | return samp_rate, samps 115 | return samps 116 | 117 | 118 | def write_wav(fname, samps, sr=16000, normalize=True): 119 | """ 120 | Write wav files in int16, support single/multi-channel 121 | """ 122 | if normalize: 123 | samps = samps * MAX_INT16 124 | # scipy.io.wavfile.write could write single/multi-channel files 125 | # for multi-channel, accept ndarray [Nsamples, Nchannels] 126 | if samps.ndim != 1 and samps.shape[0] < samps.shape[1]: 127 | samps = np.transpose(samps) 128 | samps = np.squeeze(samps) 129 | # same as MATLAB and kaldi 130 | samps_int16 = samps.astype(np.int16) 131 | fdir = os.path.dirname(fname) 132 | if fdir: 133 | os.makedirs(fdir, exist_ok=True) 134 | # NOTE: librosa 0.6.0 seems could not write non-float narray 135 | # so use scipy.io.wavfile instead 136 | wf.write(fname, sr, samps_int16) 137 | 138 | -------------------------------------------------------------------------------- /utils/mvdr_util.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import numpy as np 3 | 4 | 5 | def make_wta(result_mask): 6 | noise_mask = result_mask[2] 7 | if len(result_mask) == 4: 8 | noise_mask += result_mask[3] 9 | mask = np.stack((result_mask[0], result_mask[1],noise_mask)) 10 | mask_max = np.amax(mask, axis=0, keepdims=True) 11 | mask = np.where(mask==mask_max, mask, 1e-10) 12 | return mask 13 | 14 | 15 | def make_mvdr(s,result): 16 | mask=make_wta(result) 17 | M=[] 18 | for i in range(7): 19 | st=librosa.core.stft(s[:,i],n_fft=512,hop_length=256) 20 | M.append(st) 21 | M=np.asarray(M) 22 | 23 | L=np.min([mask.shape[-1],M.shape[-1]]) 24 | M=M[:,:,:L] 25 | 26 | mask=mask[:,:,:L] 27 | 28 | tgt_scm,_=get_mask_scm(M,mask[0]) 29 | itf_scm,_=get_mask_scm(M,mask[1]) 30 | noi_scm,_=get_mask_scm(M,mask[2]) 31 | 32 | coef=calc_bfcoeffs(noi_scm+itf_scm,tgt_scm) 33 | res=get_bf(M,coef) 34 | res1=librosa.istft(res,hop_length=256) 35 | 36 | coef=calc_bfcoeffs(noi_scm+tgt_scm,itf_scm) 37 | res=get_bf(M,coef) 38 | res2=librosa.istft(res,hop_length=256) 39 | 40 | return res1, res2 41 | 42 | 43 | def 
get_mask_scm(mix,mask): 44 | Ri = np.einsum('FT,FTM,FTm->FMm', mask, mix.transpose(1,2,0), mix.transpose(1,2,0).conj()) 45 | t1=np.eye(7) 46 | t2=t1[np.newaxis,:,:] 47 | Ri+=1e-15*t2 48 | return Ri,np.sum(mask) 49 | 50 | 51 | def calc_bfcoeffs(noi_scm,tgt_scm): 52 | # Calculate BF coeffs. 53 | num = np.linalg.solve(noi_scm, tgt_scm) 54 | den = np.trace(num, axis1=-2, axis2=-1)[..., np.newaxis, np.newaxis] 55 | den[0]+=1e-15 56 | W = (num / den)[..., 0] 57 | return W 58 | 59 | 60 | def get_bf(mix,W): 61 | c,f,t=mix.shape 62 | return np.sum(W.reshape(f,c,1).conj()*mix.transpose(1,0,2),axis=1) 63 | --------------------------------------------------------------------------------
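
A note on the `--early_exit_threshold` flag used in the README command and in `separate.py`: at inference time, `EETransformerEncoder.forward` in `nnet/early_exit_transformer.py` stops evaluating encoder layers once the sigmoid mask predicted after a layer is sufficiently close to the prediction from the previous layer (an L2 distance divided by the time and feature dimensions). The sketch below is a minimal, standalone illustration of that stopping rule only; the `early_exit_should_stop` helper, the random dummy layer outputs, and the threshold value are hypothetical stand-ins and are not part of the repository code.

```python
import torch


def early_exit_should_stop(prev, cur, threshold):
    """Stopping rule mirroring EETransformerEncoder.forward at inference:
    exit when the L2 distance between consecutive layers' mask estimates,
    divided by the time and feature dimensions, falls below the threshold."""
    if prev is None:
        return False
    # prev, cur: N x T x F sigmoid mask estimates from consecutive layers
    return torch.dist(prev, cur, p=2) / cur.shape[1] / cur.shape[2] < threshold


# Dummy per-layer mask estimates (illustration only): the first 8 layers
# differ a lot, the remaining layers barely change, so the loop exits early.
torch.manual_seed(0)
layer_outputs = [torch.sigmoid(torch.randn(1, 100, 257)) for _ in range(16)]
for i in range(8, 16):
    layer_outputs[i] = layer_outputs[7] + 1e-6 * torch.randn_like(layer_outputs[7])

threshold = 3e-4  # plays the role of --early_exit_threshold
prev = None
for depth, cur in enumerate(layer_outputs, start=1):
    if early_exit_should_stop(prev, cur, threshold):
        print(f"early exit after layer {depth}")
        break
    prev = cur
else:
    print("all layers evaluated")
```

With the README's `EE_THRESHOLD=0`, the distance is never below the threshold, so every layer is evaluated; larger values let utterances exit earlier, which is the speed/accuracy trade-off this flag controls.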