├── README.md
├── executor
│   ├── __init__.py
│   ├── executor.py
│   └── feature.py
├── nnet
│   ├── __init__.py
│   └── conformer.py
├── separate.py
└── utils
    ├── __init__.py
    ├── audio_util.py
    ├── mvdr_util.py
    ├── overlapped_speech_1ch.scp
    └── overlapped_speech_7ch.scp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Continuous Speech Separation with Conformer

## Introduction

We examine the use of the Conformer architecture for continuous speech separation.
Conformer allows the separation model to efficiently capture both local and global context information, which is helpful for speech separation.
Experimental results on the LibriCSS dataset show that the Conformer separation model achieves state-of-the-art results in both the single-channel and multi-channel settings.

For a detailed description and experimental results, please refer to our paper: [Continuous Speech Separation with Conformer](https://arxiv.org/abs/2008.05773) (accepted by ICASSP 2021).

## Environment
Python 3.6.9, PyTorch 1.7.1

## Get Started
1. Download the overlapped speech of the [LibriCSS dataset](https://github.com/chenzhuo1011/libri_css).

    - from Google Drive:
    ```bash
    wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1PdloA-V8HGxkRu9MnT35_civpc3YXJsT' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1PdloA-V8HGxkRu9MnT35_civpc3YXJsT" -O overlapped_speech.zip && rm -rf /tmp/cookies.txt && unzip overlapped_speech.zip && rm overlapped_speech.zip
    ```
    - from [Microsoft Azure Storage](https://valle.blob.core.windows.net/share/CSS_with_Conformer/overlapped_speech.zip?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D):
    ```bash
    wget "https://valle.blob.core.windows.net/share/CSS_with_Conformer/overlapped_speech.zip?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D" -O overlapped_speech.zip && unzip overlapped_speech.zip && rm overlapped_speech.zip
    ```

2. Download the Conformer separation models.

    - from [Microsoft Azure Storage](https://valle.blob.core.windows.net/share/CSS_with_Conformer/checkpoints.zip?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D):
    ```bash
    wget "https://valle.blob.core.windows.net/share/CSS_with_Conformer/checkpoints.zip?sv=2020-08-04&st=2023-03-01T07%3A51%3A05Z&se=2033-03-02T07%3A51%3A00Z&sr=c&sp=rl&sig=QJXmSJG9DbMKf48UDIU1MfzIro8HQOf3sqlNXiflY1I%3D" -O checkpoints.zip && unzip checkpoints.zip && rm checkpoints.zip
    ```

3. Run the separation.
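    Both run commands below read their input list from `--mix-scp`. The `.scp` files shipped in `utils/` follow the Kaldi `wav.scp` convention documented in `utils/audio_util.py` (one `<key> <wav-path>` pair per line), so you can also run the models on your own recordings with a list such as the following (keys and paths are illustrative):

    ```
    utt1 /path/to/overlapped_speech/utt1.wav
    utt2 /path/to/overlapped_speech/utt2.wav
    ```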

    3.1 Single-channel separation

    ```bash
    export MODEL_NAME=1ch_conformer_base
    python3 separate.py \
        --checkpoint checkpoints/$MODEL_NAME \
        --mix-scp utils/overlapped_speech_1ch.scp \
        --dump-dir separated_speech/monaural/utterances_with_$MODEL_NAME \
        --device-id 0 \
        --num_spks 2
    ```

    The separated speech can be found in the directory `separated_speech/monaural/utterances_with_$MODEL_NAME`.

    3.2 Seven-channel separation

    ```bash
    export MODEL_NAME=conformer_base
    python3 separate.py \
        --checkpoint checkpoints/$MODEL_NAME \
        --mix-scp utils/overlapped_speech_7ch.scp \
        --dump-dir separated_speech/7ch/utterances_with_$MODEL_NAME \
        --device-id 0 \
        --num_spks 2 \
        --mvdr True
    ```

    The separated speech can be found in the directory `separated_speech/7ch/utterances_with_$MODEL_NAME`.

## Citation
If you find our work useful, please cite [our paper](https://arxiv.org/abs/2008.05773):
```bibtex
@inproceedings{CSS_with_Conformer,
  title={Continuous speech separation with conformer},
  author={Chen, Sanyuan and Wu, Yu and Chen, Zhuo and Wu, Jian and Li, Jinyu and Yoshioka, Takuya and Wang, Chengyi and Liu, Shujie and Zhou, Ming},
  booktitle={ICASSP 2021-2021 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  pages={5749--5753},
  year={2021},
  organization={IEEE}
}
```

--------------------------------------------------------------------------------
/executor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sanyuan-Chen/CSS_with_Conformer/acdda14d370ed4197649e8a12a39ab67989d59af/executor/__init__.py
--------------------------------------------------------------------------------
/executor/executor.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import torch as th
import torch.nn as nn
from pathlib import Path
from .feature import FeatureExtractor


class Executor(nn.Module):
    """
    Executor handles feature extraction
    and the forward pass of the separation network.
13 | """ 14 | def __init__(self, nnet, extractor_kwargs=None, get_mask=True): 15 | super(Executor, self).__init__() 16 | self.nnet = nnet 17 | self.extractor = FeatureExtractor( 18 | **extractor_kwargs) if extractor_kwargs else None 19 | self.frame_len = extractor_kwargs['frame_len'] if extractor_kwargs else None 20 | self.frame_hop = extractor_kwargs['frame_hop'] if extractor_kwargs else None 21 | self.get_mask = get_mask 22 | 23 | def resume(self, checkpoint): 24 | """ 25 | Resume from checkpoint 26 | """ 27 | if not Path(checkpoint).exists(): 28 | raise FileNotFoundError( 29 | f"Could not find resume checkpoint: {checkpoint}") 30 | cpt = th.load(checkpoint, map_location="cpu") 31 | self.load_state_dict(cpt["model_state_dict"]) 32 | return cpt["epoch"] 33 | 34 | def _compute_feats(self, egs): 35 | """ 36 | Compute features: N x F x T 37 | """ 38 | if not self.extractor: 39 | raise RuntimeError("self.extractor is None, " 40 | "do not need to compute features") 41 | mag, pha, f = self.extractor(egs["mix"]) 42 | return mag, pha, f 43 | 44 | def forward(self, egs): 45 | mag, pha, f = self._compute_feats(egs) 46 | out = self.nnet(f) 47 | if self.get_mask: 48 | return out 49 | else: 50 | return [self.extractor.istft(m * mag, pha) for m in out] 51 | -------------------------------------------------------------------------------- /executor/feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | """ 4 | Implementation of front-end feature via PyTorch 5 | """ 6 | 7 | import math 8 | import torch as th 9 | 10 | from collections.abc import Sequence 11 | 12 | import torch.nn.functional as F 13 | import torch.nn as nn 14 | 15 | EPSILON = th.finfo(th.float32).eps 16 | MATH_PI = math.pi 17 | 18 | 19 | def init_kernel(frame_len, 20 | frame_hop, 21 | normalize=True, 22 | round_pow_of_two=True, 23 | window="sqrt_hann"): 24 | if window != "sqrt_hann" and window != "hann": 25 | raise RuntimeError("Now only support sqrt hanning window or hann window") 26 | # FFT points 27 | N = 2**math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len 28 | # window 29 | W = th.hann_window(frame_len) 30 | if window == "sqrt_hann": 31 | W = W**0.5 32 | # scale factor to make same magnitude after iSTFT 33 | if window == "sqrt_hann" and normalize: 34 | S = 0.5 * (N * N / frame_hop)**0.5 35 | else: 36 | S = 1 37 | # F x N/2+1 x 2 38 | K = th.rfft(th.eye(N) / S, 1)[:frame_len] 39 | # 2 x N/2+1 x F 40 | K = th.transpose(K, 0, 2) * W 41 | # N+2 x 1 x F 42 | K = th.reshape(K, (N + 2, 1, frame_len)) 43 | return K 44 | 45 | 46 | class STFTBase(nn.Module): 47 | """ 48 | Base layer for (i)STFT 49 | NOTE: 50 | 1) Recommend sqrt_hann window with 2**N frame length, because it 51 | could achieve perfect reconstruction after overlap-add 52 | 2) Now haven't consider padding problems yet 53 | """ 54 | def __init__(self, 55 | frame_len, 56 | frame_hop, 57 | window="sqrt_hann", 58 | normalize=True, 59 | round_pow_of_two=True): 60 | super(STFTBase, self).__init__() 61 | K = init_kernel(frame_len, 62 | frame_hop, 63 | round_pow_of_two=round_pow_of_two, 64 | window=window) 65 | self.K = nn.Parameter(K, requires_grad=False) 66 | self.stride = frame_hop 67 | self.window = window 68 | self.normalize = normalize 69 | self.num_bins = self.K.shape[0] // 2 70 | if window == "hann": 71 | self.conjugate = True 72 | else: 73 | self.conjugate = False 74 | 75 | def extra_repr(self): 76 | return (f"window={self.window}, stride={self.stride}, " + 77 | 
f"kernel_size={self.K.shape[0]}x{self.K.shape[2]}, " + 78 | f"normalize={self.normalize}") 79 | 80 | 81 | class STFT(STFTBase): 82 | """ 83 | Short-time Fourier Transform as a Layer 84 | """ 85 | def __init__(self, *args, **kwargs): 86 | super(STFT, self).__init__(*args, **kwargs) 87 | 88 | def forward(self, x, cplx=False): 89 | """ 90 | Accept (single or multiple channel) raw waveform and output magnitude and phase 91 | args 92 | x: input signal, N x C x S or N x S 93 | return 94 | m: magnitude, N x C x F x T or N x F x T 95 | p: phase, N x C x F x T or N x F x T 96 | """ 97 | if x.dim() not in [2, 3]: 98 | raise RuntimeError( 99 | "{} expect 2D/3D tensor, but got {:d}D signal".format( 100 | self.__name__, x.dim())) 101 | # if N x S, reshape N x 1 x S 102 | if x.dim() == 2: 103 | x = th.unsqueeze(x, 1) 104 | # N x 2F x T 105 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 106 | # N x F x T 107 | r, i = th.chunk(c, 2, dim=1) 108 | if self.conjugate: 109 | # to match with science pipeline, we need to do conjugate 110 | i = -i 111 | # else reshape NC x 1 x S 112 | else: 113 | N, C, S = x.shape 114 | x = x.view(N * C, 1, S) 115 | # NC x 2F x T 116 | c = F.conv1d(x, self.K, stride=self.stride, padding=0) 117 | # N x C x 2F x T 118 | c = c.view(N, C, -1, c.shape[-1]) 119 | # N x C x F x T 120 | r, i = th.chunk(c, 2, dim=2) 121 | if self.conjugate: 122 | # to match with science pipeline, we need to do conjugate 123 | i = -i 124 | if cplx: 125 | return r, i 126 | m = (r**2 + i**2)**0.5 127 | p = th.atan2(i, r) 128 | return m, p 129 | 130 | 131 | class iSTFT(STFTBase): 132 | """ 133 | Inverse Short-time Fourier Transform as a Layer 134 | """ 135 | def __init__(self, *args, **kwargs): 136 | super(iSTFT, self).__init__(*args, **kwargs) 137 | 138 | def forward(self, m, p, cplx=False, squeeze=False): 139 | """ 140 | Accept phase & magnitude and output raw waveform 141 | args 142 | m, p: N x F x T 143 | return 144 | s: N x S 145 | """ 146 | if p.dim() != m.dim() or p.dim() not in [2, 3]: 147 | raise RuntimeError("Expect 2D/3D tensor, but got {:d}D".format( 148 | p.dim())) 149 | # if F x T, reshape 1 x F x T 150 | if p.dim() == 2: 151 | p = th.unsqueeze(p, 0) 152 | m = th.unsqueeze(m, 0) 153 | if cplx: 154 | # N x 2F x T 155 | c = th.cat([m, p], dim=1) 156 | else: 157 | r = m * th.cos(p) 158 | i = m * th.sin(p) 159 | # N x 2F x T 160 | c = th.cat([r, i], dim=1) 161 | # N x 2F x T 162 | s = F.conv_transpose1d(c, self.K, stride=self.stride, padding=0) 163 | # N x S 164 | s = s.squeeze(1) 165 | if squeeze: 166 | s = th.squeeze(s) 167 | return s 168 | 169 | 170 | class IPDFeature(nn.Module): 171 | """ 172 | Compute inter-channel phase difference 173 | """ 174 | def __init__(self, 175 | ipd_index="1,0;2,0;3,0;4,0;5,0;6,0", 176 | cos=True, 177 | sin=False, 178 | ipd_mean_normalize_version=2, 179 | ipd_mean_normalize=True): 180 | super(IPDFeature, self).__init__() 181 | split_index = lambda sstr: [ 182 | tuple(map(int, p.split(","))) for p in sstr.split(";") 183 | ] 184 | # ipd index 185 | pair = split_index(ipd_index) 186 | self.index_l = [t[0] for t in pair] 187 | self.index_r = [t[1] for t in pair] 188 | self.ipd_index = ipd_index 189 | self.cos = cos 190 | self.sin = sin 191 | self.ipd_mean_normalize=ipd_mean_normalize 192 | self.ipd_mean_normalize_version=ipd_mean_normalize_version 193 | self.num_pairs = len(pair) * 2 if cos and sin else len(pair) 194 | 195 | def extra_repr(self): 196 | return f"ipd_index={self.ipd_index}, cos={self.cos}, sin={self.sin}" 197 | 198 | def forward(self, p): 199 | 
""" 200 | Accept multi-channel phase and output inter-channel phase difference 201 | args 202 | p: phase matrix, N x C x F x T 203 | return 204 | ipd: N x MF x T 205 | """ 206 | if p.dim() not in [3, 4]: 207 | raise RuntimeError( 208 | "{} expect 3/4D tensor, but got {:d} instead".format( 209 | self.__name__, p.dim())) 210 | # C x F x T => 1 x C x F x T 211 | if p.dim() == 3: 212 | p = p.unsqueeze(0) 213 | N, _, _, T = p.shape 214 | pha_dif = p[:, self.index_l] - p[:, self.index_r] 215 | if self.ipd_mean_normalize: 216 | yr = th.cos(pha_dif) 217 | yi = th.sin(pha_dif) 218 | yrm = yr.mean(-1, keepdim=True) 219 | yim = yi.mean(-1, keepdim=True) 220 | if self.ipd_mean_normalize_version == 1: 221 | pha_dif = th.atan2(yi - yim, yr - yrm) 222 | elif self.ipd_mean_normalize_version == 2: 223 | pha_dif_mean = th.atan2(yim, yrm) 224 | pha_dif -= pha_dif_mean 225 | elif self.ipd_mean_normalize_version == 3: 226 | pha_dif_mean = pha_dif.mean(-1, keepdim=True) 227 | pha_dif -= pha_dif_mean 228 | else: 229 | # we only support version 1, 2 and 3 230 | raise RuntimeError( 231 | "{} expect ipd_mean_normalization version 1 or version 2, but got {:d} instead".format( 232 | self.__name__, self.ipd_mean_normalize_version)) 233 | 234 | if self.cos: 235 | # N x M x F x T 236 | ipd = th.cos(pha_dif) 237 | if self.sin: 238 | # N x M x 2F x T, along frequency axis 239 | ipd = th.cat([ipd, th.sin(pha_dif)], 2) 240 | else: 241 | # th.fmod behaves differently from np.mod for the input that is less than -math.pi 242 | # i believe it is a bug 243 | # so we need to ensure it is larger than -math.pi by adding an extra 6 * math.pi 244 | #ipd = th.fmod(pha_dif + math.pi, 2 * math.pi) - math.pi 245 | ipd = pha_dif 246 | # N x MF x T 247 | ipd = ipd.view(N, -1, T) 248 | # N x MF x T 249 | return ipd 250 | 251 | 252 | class AngleFeature(nn.Module): 253 | """ 254 | Compute angle/directional feature 255 | 1) num_doas == 1: we known the DoA of the target speaker 256 | 2) num_doas != 1: we do not have that prior, so we sampled #num_doas DoAs 257 | and compute on each directions 258 | """ 259 | def __init__(self, 260 | geometric="princeton", 261 | sr=16000, 262 | velocity=340, 263 | num_bins=257, 264 | num_doas=1, 265 | af_index="1,0;2,0;3,0;4,0;5,0;6,0"): 266 | super(AngleFeature, self).__init__() 267 | if geometric not in ["princeton"]: 268 | raise RuntimeError( 269 | "Unsupported array geometric: {}".format(geometric)) 270 | self.geometric = geometric 271 | self.sr = sr 272 | self.num_bins = num_bins 273 | self.num_doas = num_doas 274 | self.velocity = velocity 275 | split_index = lambda sstr: [ 276 | tuple(map(int, p.split(","))) for p in sstr.split(";") 277 | ] 278 | # ipd index 279 | pair = split_index(af_index) 280 | self.index_l = [t[0] for t in pair] 281 | self.index_r = [t[1] for t in pair] 282 | self.af_index = af_index 283 | omega = th.tensor( 284 | [math.pi * sr * f / (num_bins - 1) for f in range(num_bins)]) 285 | # 1 x F 286 | self.omega = nn.Parameter(omega[None, :], requires_grad=False) 287 | 288 | def _oracle_phase_delay(self, doa): 289 | """ 290 | Compute oracle phase delay given DoA 291 | args 292 | doa: N 293 | return 294 | phi: N x C x F or N x D x C x F 295 | """ 296 | device = doa.device 297 | if self.num_doas != 1: 298 | # doa is a unused, fake parameter 299 | N = doa.shape[0] 300 | # N x D 301 | doa = th.linspace(0, MATH_PI * 2, self.num_doas + 1, 302 | device=device)[:-1].repeat(N, 1) 303 | # for princeton 304 | # M = 7, R = 0.0425, treat M_0 as (0, 0) 305 | # *3 *2 306 | # 307 | # *4 *0 *1 308 | # 309 | 
        #           *5    *6
        if self.geometric == "princeton":
            R = 0.0425
            zero = th.zeros_like(doa)
            # N x 7 or N x D x 7
            tau = R * th.stack([
                zero, -th.cos(doa), -th.cos(MATH_PI / 3 - doa),
                -th.cos(2 * MATH_PI / 3 - doa),
                th.cos(doa),
                th.cos(MATH_PI / 3 - doa),
                th.cos(2 * MATH_PI / 3 - doa)
            ],
                              dim=-1) / self.velocity
            # (Nx7x1) x (1xF) => Nx7xF or (NxDx7x1) x (1xF) => NxDx7xF
            phi = th.matmul(tau.unsqueeze(-1), -self.omega)
            return phi
        else:
            return None

    def extra_repr(self):
        return (
            f"geometric={self.geometric}, af_index={self.af_index}, " +
            f"sr={self.sr}, num_bins={self.num_bins}, velocity={self.velocity}, "
            + f"known_doa={self.num_doas == 1}")

    def _compute_af(self, ipd, doa):
        """
        Compute the angle feature
        args
            ipd: N x C x F x T
            doa: DoA of the target speaker (if we know it), N
                 or N x D (if we do not; D DoAs are sampled instead)
        return
            af: N x F x T or N x D x F x T
        """
        # N x C x F or N x D x C x F
        d = self._oracle_phase_delay(doa)
        d = d.unsqueeze(-1)
        if self.num_doas == 1:
            dif = d[:, self.index_l] - d[:, self.index_r]
            # N x C x F x T
            af = th.cos(ipd - dif)
            # on the channel dimension (mean or sum)
            af = th.mean(af, dim=1)
        else:
            # N x D x C x F x 1
            dif = d[:, :, self.index_l] - d[:, :, self.index_r]
            # N x D x C x F x T
            af = th.cos(ipd.unsqueeze(1) - dif)
            # N x D x F x T
            af = th.mean(af, dim=2)
        return af

    def forward(self, p, doa):
        """
        Accept the DoA of the speaker(s) & multi-channel phase, output the angle feature
        args
            doa: DoA of target/each speaker, N or [N, ...]
            p: phase matrix, N x C x F x T
        return
            af: angle feature, N x F* x T or N x D x F x T (known_doa=False)
        """
        if p.dim() not in [3, 4]:
            raise RuntimeError(
                "{} expects a 3/4D tensor, but got {:d} instead".format(
                    self.__class__.__name__, p.dim()))
        # C x F x T => 1 x C x F x T
        if p.dim() == 3:
            p = p.unsqueeze(0)
        ipd = p[:, self.index_l] - p[:, self.index_r]

        if isinstance(doa, Sequence):
            if self.num_doas != 1:
                raise RuntimeError("known_doa=False, do not pass "
                                   "doa as a Sequence object")
            # [N x F x T or N x D x F x T, ...]
            af = [self._compute_af(ipd, spk_doa) for spk_doa in doa]
            # N x F x T => N x F* x T
            af = th.cat(af, 1)
        else:
            # N x F x T or N x D x F x T
            af = self._compute_af(ipd, doa)
        return af


class FeatureExtractor(nn.Module):
    """
    A PyTorch module to handle spectral & spatial features
    """
    def __init__(self,
                 frame_len=512,
                 frame_hop=256,
                 normalize=True,
                 round_pow_of_two=True,
                 num_spks=2,
                 log_spectrogram=True,
                 mvn_spectrogram=True,
                 ipd_mean_normalize=True,
                 ipd_mean_normalize_version=2,
                 window="sqrt_hann",
                 ext_af=0,
                 ipd_cos=True,
                 ipd_sin=False,
                 ipd_index="1,4;2,5;3,6",
                 ang_index="1,0;2,0;3,0;4,0;5,0;6,0"
                 ):
        super(FeatureExtractor, self).__init__()
        # forward STFT
        self.forward_stft = STFT(frame_len,
                                 frame_hop,
                                 normalize=normalize,
                                 window=window,
                                 round_pow_of_two=round_pow_of_two)
        self.inverse_stft = iSTFT(frame_len,
                                  frame_hop,
                                  normalize=normalize,
                                  round_pow_of_two=round_pow_of_two)
        self.has_spatial = False
        num_bins = self.forward_stft.num_bins
        self.feature_dim = num_bins
        self.num_bins = num_bins
        self.num_spks = num_spks
        # add extra angle features
        self.ext_af = ext_af

        # IPD or not
        self.ipd_extractor = None
        if ipd_index:
            self.ipd_extractor = IPDFeature(ipd_index,
                                            cos=ipd_cos,
                                            sin=ipd_sin,
                                            ipd_mean_normalize_version=ipd_mean_normalize_version,
                                            ipd_mean_normalize=ipd_mean_normalize)
            self.feature_dim += self.ipd_extractor.num_pairs * num_bins
            self.has_spatial = True
        # AF or not
        self.ang_extractor = None
        if ang_index:
            self.ang_extractor = AngleFeature(
                num_bins=num_bins,
                num_doas=1,  # the DoA must be known
                af_index=ang_index)
            self.feature_dim += num_bins * self.num_spks * (1 + self.ext_af)
            self.has_spatial = True
        # MVN or not
        self.mvn_mag = mvn_spectrogram
        # apply log or not
        self.log_mag = log_spectrogram

    def _check_args(self, x, doa):
        if x.dim() == 2 and self.has_spatial:
            raise RuntimeError("Got a 2D (single-channel) input and can "
                               "not extract spatial features")
        if self.ang_extractor is None and doa:
            raise RuntimeError("DoA is given but the AF extractor "
                               "is not initialized")
        if self.ang_extractor and doa is None:
            raise RuntimeError("The AF extractor is initialized, but DoA is None")
        num_af = self.num_spks * (self.ext_af + 1)
        if isinstance(doa, Sequence) and len(doa) != num_af:
            raise RuntimeError("Number of DoAs does not match the " +
                               f"speaker number: {len(doa):d} vs {num_af:d}")

    def stft(self, x, cplx=False):
        return self.forward_stft(x, cplx=cplx)

    def istft(self, m, p, cplx=False):
        return self.inverse_stft(m, p, cplx=cplx)

    def compute_spectra(self, x):
        """
        Compute spectral features
        args
            x: N x C x S (multi-channel) or N x S (single-channel)
        return:
            mag & pha: N x F x T or N x C x F x T
            feature: N x * x T
        """
        # mag & pha: N x C x F x T or N x F x T
        mag, pha = self.forward_stft(x)
        # ch0: N x F x T
        if mag.dim() == 4:
            f = th.clamp(mag[:, 0], min=EPSILON)
        else:
            f = th.clamp(mag, min=EPSILON)
        # log
        if self.log_mag:
            f = th.log(f)
        # mvn
        if self.mvn_mag:
            f = (f - f.mean(-1, keepdim=True)) / (f.std(-1, keepdim=True) +
                                                  EPSILON)
        return mag, pha, f

    def compute_spatial(self, x, doa=None, pha=None):
        """
        Compute spatial features
        args
            x: N x C x S (multi-channel)
            pha: N x C x F x T
        return
            feature: N x * x T
        """
        if pha is None:
            self._check_args(x, doa)
            # mag & pha: N x C x F x T
            _, pha = self.forward_stft(x)
        else:
            if pha.dim() != 4:
                raise RuntimeError("Expect the phase matrix to be a 4D tensor, " +
                                   f"got {pha.dim()} instead")
        feature = []
        if self.has_spatial:
            if self.ipd_extractor:
                # N x C x F x T => N x MF x T
                ipd = self.ipd_extractor(pha)
                feature.append(ipd)
            if self.ang_extractor:
                # N x C x F x T => N x F* x T
                ang = self.ang_extractor(pha, doa)
                feature.append(ang)
        else:
            raise RuntimeError("No spatial features are configured")
        # N x * x T
        feature = th.cat(feature, 1)
        return feature

    def forward(self, x, doa=None, ref_channel=0):
        """
        args
            x: N x C x S (multi-channel) or N x S (single-channel)
            doa: N or [N, ...] (for each speaker)
        return:
            mag & pha: N x F x T (if ref_channel is not None), otherwise N x C x F x T
            feature: N x * x T
        """
        self._check_args(x, doa)
        # mag & pha: N x C x F x T or N x F x T
        mag, pha, f = self.compute_spectra(x)
        feature = [f]
        if self.has_spatial:
            spatial = self.compute_spatial(x, pha=pha, doa=doa)
            feature.append(spatial)
        # N x * x T
        feature = th.cat(feature, 1)
        if mag.dim() == 4 and ref_channel is not None:
            return mag[:, ref_channel], pha[:, ref_channel], feature
        else:
            return mag, pha, feature

--------------------------------------------------------------------------------
/nnet/__init__.py:
--------------------------------------------------------------------------------

from . import conformer

supported_nnet = {
    "conformer": conformer.ConformerCSS,
}

--------------------------------------------------------------------------------
/nnet/conformer.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""Implementation of the Conformer speech separation model"""

import math
import numpy
import torch
from torch import nn


class RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()

        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq):
        pos_seq.clamp_(-self.maxlen, self.maxlen - 1)
        pos_seq = pos_seq + self.maxlen
        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None


class MultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer.

    :param int n_head: the number of heads
    :param int n_feat: the number of features
    :param float dropout_rate: dropout rate

    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct a MultiHeadedAttention object."""
        super(MultiHeadedAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.layer_norm = nn.LayerNorm(n_feat)
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)

        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, pos_k, mask):
        """Compute 'Scaled Dot Product Attention'.

        :param torch.Tensor x: input features (batch, time1, d_model)
        :param torch.Tensor pos_k: relative positional key embeddings
        :param torch.Tensor mask: (batch, time1, time2)
        :return torch.Tensor: attended and transformed `value` (batch, time1, d_model)
            weighted by the query-key attention (batch, head, time1, time2)
        """
        n_batch = x.size(0)
        x = self.layer_norm(x)
        q = self.linear_q(x).view(n_batch, -1, self.h, self.d_k)  # (b, t, d)
        k = self.linear_k(x).view(n_batch, -1, self.h, self.d_k)  # (b, t, d)
        v = self.linear_v(x).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        A = torch.matmul(q, k.transpose(-2, -1))
        reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1)
        if pos_k is not None:
            B = torch.matmul(reshape_q, pos_k.transpose(-2, -1))
            B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1))
            scores = (A + B) / math.sqrt(self.d_k)
        else:
            scores = A / math.sqrt(self.d_k)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
            min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)

        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, v)  # (batch, head, time1, d_k)
        x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)  # (batch, time1, d_model)
        return self.dropout(self.linear_out(x))  # (batch, time1, d_model)


class ConvModule(nn.Module):
    def __init__(self, input_dim, kernel_size, dropout_rate, causal=False):
        super(ConvModule, self).__init__()
        self.layer_norm = nn.LayerNorm(input_dim)

        self.pw_conv_1 = nn.Conv2d(1, 2, 1, 1, 0)
        self.glu_act = torch.nn.Sigmoid()
        self.causal = causal
        if causal:
            self.dw_conv_1d = nn.Conv1d(input_dim, input_dim, kernel_size, 1, padding=(kernel_size - 1), groups=input_dim)
        else:
            self.dw_conv_1d = nn.Conv1d(input_dim, input_dim, kernel_size, 1, padding=(kernel_size - 1) // 2, groups=input_dim)
        self.BN = nn.BatchNorm1d(input_dim)
        self.act = nn.ReLU()
        self.pw_conv_2 = nn.Conv2d(1, 1, 1, 1, 0)
        self.dropout = nn.Dropout(dropout_rate)
        self.kernel_size = kernel_size

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.layer_norm(x)
        x = self.pw_conv_1(x)
        x = x[:, 0] * self.glu_act(x[:, 1])
        x = x.permute([0, 2, 1])
        x = self.dw_conv_1d(x)
        if self.causal:
            x = x[:, :, :-(self.kernel_size - 1)]
        x = self.BN(x)
        x = self.act(x)
        x = x.unsqueeze(1).permute([0, 1, 3, 2])
        x = self.pw_conv_2(x)
        x = self.dropout(x).squeeze(1)
        return x


class FeedForward(nn.Module):
    def __init__(self, d_model, d_inner, dropout_rate):
        super(FeedForward, self).__init__()

        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        self.net = nn.Sequential(
            nn.Linear(d_model, d_inner),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate)
        )

    def forward(self, x):
        x = self.layer_norm(x)
        out = self.net(x)

        return out


class EncoderLayer(nn.Module):
    """Encoder layer module.

    :param int d_model: attention vector size
    :param int n_head: number of heads
    :param int d_ffn: feedforward size
    :param int kernel_size: CNN kernel size; it must be odd
    :param float dropout_rate: dropout rate
    """

    def __init__(self, d_model, n_head, d_ffn, kernel_size, dropout_rate, causal=False):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.feed_forward_in = FeedForward(d_model, d_ffn, dropout_rate)
        self.self_attn = MultiHeadedAttention(n_head, d_model, dropout_rate)
        self.conv = ConvModule(d_model, kernel_size, dropout_rate, causal=causal)
        self.feed_forward_out = FeedForward(d_model, d_ffn, dropout_rate)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, pos_k, mask):
        """Compute encoded features.

        :param torch.Tensor x: encoded source features (batch, max_time_in, size)
        :param torch.Tensor mask: mask for x (batch, max_time_in)
        :rtype: torch.Tensor
        """
        x = x + 0.5 * self.feed_forward_in(x)
        x = x + self.self_attn(x, pos_k, mask)
        x = x + self.conv(x)
        x = x + 0.5 * self.feed_forward_out(x)

        out = self.layer_norm(x)

        return out


class ConformerEncoder(nn.Module):
    """Conformer Encoder https://arxiv.org/abs/2005.08100
    """
    def __init__(self,
                 idim=257,
                 attention_dim=256,
                 attention_heads=4,
                 linear_units=1024,
                 num_blocks=16,
                 kernel_size=33,
                 dropout_rate=0.1,
                 causal=False,
                 relative_pos_emb=True
                 ):
        super(ConformerEncoder, self).__init__()

        self.embed = torch.nn.Sequential(
            torch.nn.Linear(idim, attention_dim),
            torch.nn.LayerNorm(attention_dim),
            torch.nn.Dropout(dropout_rate),
            torch.nn.ReLU(),
        )

        if relative_pos_emb:
            self.pos_emb = RelativePositionalEncoding(attention_dim // attention_heads, 1000, False)
        else:
            self.pos_emb = None

        self.encoders = torch.nn.Sequential(*[EncoderLayer(
            attention_dim,
            attention_heads,
            linear_units,
            kernel_size,
            dropout_rate,
            causal=causal
        ) for _ in range(num_blocks)])

    def forward(self, xs, masks):
        xs = self.embed(xs)

        if self.pos_emb is not None:
            x_len = xs.shape[1]
            pos_seq = torch.arange(0, x_len).long().to(xs.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, _ = self.pos_emb(pos_seq)
        else:
            pos_k = None
        for layer in self.encoders:
            xs = layer(xs, pos_k, masks)

        return xs, masks


default_encoder_conf = {
    "attention_dim": 256,
    "attention_heads": 4,
    "linear_units": 1024,
    "num_blocks": 16,
    "kernel_size": 33,
    "dropout_rate": 0.1,
    "relative_pos_emb": True
}


class ConformerCSS(nn.Module):
    """
    Conformer speech separation model
    """
    def __init__(self,
                 stats_file=None,
                 in_features=257,
                 num_bins=257,
                 num_spks=2,
                 num_nois=1,
                 conformer_conf=default_encoder_conf):
        super(ConformerCSS, self).__init__()

        # input normalization layer
        if stats_file is not None:
            stats = numpy.load(stats_file)
            self.input_bias = torch.from_numpy(numpy.tile(numpy.expand_dims(-stats['mean'].astype(numpy.float32), axis=0), (1, 1, 1)))
            self.input_scale = torch.from_numpy(numpy.tile(numpy.expand_dims(1 / numpy.sqrt(stats['variance'].astype(numpy.float32)), axis=0), (1, 1, 1)))
            self.input_bias = nn.Parameter(self.input_bias, requires_grad=False)
            self.input_scale = nn.Parameter(self.input_scale, requires_grad=False)
        else:
            self.input_bias = torch.zeros(1, 1, in_features)
            self.input_scale = torch.ones(1, 1, in_features)
            self.input_bias = nn.Parameter(self.input_bias, requires_grad=False)
            self.input_scale = nn.Parameter(self.input_scale, requires_grad=False)

        # Conformer encoder
        self.conformer = ConformerEncoder(in_features, **conformer_conf)

        self.num_bins = num_bins
        self.num_spks = num_spks
        self.num_nois = num_nois
        self.linear = nn.Linear(conformer_conf["attention_dim"], num_bins * (num_spks + num_nois))

    def forward(self, f):
288 | """ 289 | args 290 | f: N x * x T 291 | return 292 | m: [N x F x T, ...] 293 | """ 294 | # N x * x T => N x T x * 295 | f = f.transpose(1, 2) 296 | 297 | # global feature normalization 298 | f = f + self.input_bias 299 | f = f * self.input_scale 300 | 301 | f, _ = self.conformer(f, masks=None) 302 | m = self.linear(f) 303 | 304 | m = torch.sigmoid(m) 305 | 306 | # N x T x F => N x F x T 307 | m = m.transpose(1, 2) 308 | if self.num_spks > 1: 309 | m = torch.chunk(m, self.num_spks + self.num_nois, 1) 310 | return m 311 | -------------------------------------------------------------------------------- /separate.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import yaml 4 | import argparse 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | from pathlib import Path 10 | 11 | from nnet import supported_nnet 12 | from executor.executor import Executor 13 | from utils.audio_util import WaveReader, write_wav 14 | from utils.mvdr_util import make_mvdr 15 | 16 | 17 | class EgsReader(object): 18 | """ 19 | Egs reader 20 | """ 21 | def __init__(self, 22 | mix_scp, 23 | sr=16000): 24 | self.mix_reader = WaveReader(mix_scp, sr=sr) 25 | 26 | def __len__(self): 27 | return len(self.mix_reader) 28 | 29 | def __iter__(self): 30 | for key, mix in self.mix_reader: 31 | egs = dict() 32 | egs["mix"] = mix 33 | yield key, egs 34 | 35 | 36 | class Separator(object): 37 | """ 38 | A simple wrapper for speech separation 39 | """ 40 | def __init__(self, cpt_dir, get_mask=False, device_id=-1): 41 | # load executor 42 | cpt_dir = Path(cpt_dir) 43 | self.get_mask = get_mask 44 | self.executor = self._load_executor(cpt_dir) 45 | cpt_ptr = cpt_dir / "best.pt.tar" 46 | epoch = self.executor.resume(cpt_ptr.as_posix()) 47 | print(f"Load checkpoint at {cpt_dir}, on epoch {epoch}") 48 | print(f"Nnet summary: {self.executor}") 49 | if device_id < 0: 50 | self.device = th.device("cpu") 51 | else: 52 | self.device = th.device(f"cuda:{device_id:d}") 53 | self.executor.to(self.device) 54 | self.executor.eval() 55 | 56 | def separate(self, egs): 57 | """ 58 | Do separation 59 | """ 60 | egs["mix"] = th.from_numpy(egs["mix"][None, :]).to(self.device, non_blocking=True) 61 | with th.no_grad(): 62 | spks = self.executor(egs) 63 | spks = [s.detach().squeeze().cpu().numpy() for s in spks] 64 | return spks 65 | 66 | def _load_executor(self, cpt_dir): 67 | """ 68 | Load executor from checkpoint 69 | """ 70 | with open(cpt_dir / "train.yaml", "r") as f: 71 | conf = yaml.load(f, Loader=yaml.FullLoader) 72 | nnet_type = conf["nnet_type"] 73 | if nnet_type not in supported_nnet: 74 | raise RuntimeError(f"Unknown network type: {nnet_type}") 75 | nnet = supported_nnet[nnet_type](**conf["nnet_conf"]) 76 | executor = Executor(nnet, extractor_kwargs=conf["extractor_conf"], get_mask=self.get_mask) 77 | return executor 78 | 79 | 80 | 81 | def run(args): 82 | # egs reader 83 | egs_reader = EgsReader(args.mix_scp, sr=args.sr) 84 | # separator 85 | seperator = Separator(args.checkpoint, device_id=args.device_id, get_mask=args.mvdr) 86 | 87 | dump_dir = Path(args.dump_dir) 88 | dump_dir.mkdir(exist_ok=True, parents=True) 89 | 90 | print(f"Start Separation " + ("w/ mvdr" if args.mvdr else "w/o mvdr")) 91 | for key, egs in egs_reader: 92 | print(f"Processing utterance {key}...") 93 | mixed = egs["mix"] 94 | spks = seperator.separate(egs) 95 | 96 | if args.mvdr: 97 | res1, res2 = make_mvdr(np.asfortranarray(mixed.T), spks) 98 | spks = [res1, res2] 99 | 100 | for i, s in 
            if i < args.num_spks:
                write_wav(dump_dir / f"{key}_{i}.wav",
                          s * 0.9 / np.max(np.abs(s)))

    print(f"Processed {len(egs_reader)} utterances")


def str2bool(value):
    # argparse's type=bool treats every non-empty string (even "False") as
    # True, so parse the --mvdr flag value explicitly
    return str(value).lower() in ("true", "1", "yes")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Command to do speech separation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--checkpoint", type=str, help="Directory of the checkpoint")
    parser.add_argument("--mix-scp",
                        type=str,
                        required=True,
                        help="Rspecifier for the mixed audio")
    parser.add_argument("--num_spks",
                        type=int,
                        default=2,
                        help="Number of speakers")
    parser.add_argument("--device-id",
                        type=int,
                        default=-1,
                        help="GPU id to offload the model to; -1 means "
                             "running on CPU")
    parser.add_argument("--sr",
                        type=int,
                        default=16000,
                        help="Sample rate of the mixture input")
    parser.add_argument("--dump-dir",
                        type=str,
                        default="sep",
                        help="Directory to dump the separated speakers")
    parser.add_argument("--mvdr",
                        type=str2bool,
                        default=False,
                        help="Apply MVDR beamforming")
    args = parser.parse_args()
    run(args)

--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sanyuan-Chen/CSS_with_Conformer/acdda14d370ed4197649e8a12a39ab67989d59af/utils/__init__.py
--------------------------------------------------------------------------------
/utils/audio_util.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import soundfile as sf
import scipy.io.wavfile as wf

MAX_INT16 = np.iinfo(np.int16).max
EPSILON = np.finfo(np.float32).eps


def _parse_script(scp_path,
                  value_processor=lambda x: x,
                  num_tokens=2,
                  restrict=True):
    """
    Parse Kaldi script (.scp) files
    If num_tokens >= 2, the function checks the token number
    """
    scp_dict = dict()
    line = 0
    with open(scp_path, "r") as f:
        for raw_line in f:
            scp_tokens = raw_line.strip().split()
            line += 1
            if (num_tokens >= 2 and len(scp_tokens) != num_tokens) or (
                    restrict and len(scp_tokens) < 2):
                raise RuntimeError(
                    "For {}, format error in line[{:d}]: {}".format(
                        scp_path, line, raw_line))
            if num_tokens == 2:
                key, value = scp_tokens
            else:
                key, value = scp_tokens[0], scp_tokens[1:]
            if key in scp_dict:
                raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
                    key, scp_path))
            scp_dict[key] = value_processor(value)
    return scp_dict


class BaseReader(object):
    """
    BaseReader Class
    """
    def __init__(self, scp_rspecifier, **kwargs):
        self.index_dict = _parse_script(scp_rspecifier, **kwargs)
        self.index_keys = list(self.index_dict.keys())

    def _load(self, key):
        # return path
        return self.index_dict[key]

    # number of utterances
    def __len__(self):
        return len(self.index_dict)

    # avoid key error
    def __contains__(self, key):
        return key in self.index_dict

    # sequential index
    def __iter__(self):
        for key in self.index_keys:
            yield key, self._load(key)


class WaveReader(BaseReader):
    """
    Sequential/Random Reader for single-channel wave
    Format of wav.scp follows Kaldi's definition:
        key1 /path/to/wav
        ...
    """
    def __init__(self, wav_scp, sr=16000, normalize=True):
        super(WaveReader, self).__init__(wav_scp)
        self.sr = sr
        self.normalize = normalize

    def _load(self, key):
        # return C x N or N
        sr, samps = read_wav(self.index_dict[key],
                             normalize=self.normalize,
                             return_rate=True)
        # if given samp_rate, check it
        if self.sr is not None and sr != self.sr:
            raise RuntimeError("Sample rate mismatch: {:d} vs {:d}".format(
                sr, self.sr))

        return samps


def read_wav(fname, beg=None, end=None, normalize=True, return_rate=False):
    """
    Read wave files via scipy.io.wavfile, or via soundfile when a segment
    is requested (supports multi-channel audio)
    """
    # samps_int16: N x C or N
    # N: number of samples
    # C: number of channels
    if beg is not None:
        samps_int16, samp_rate = sf.read(fname,
                                         start=beg,
                                         stop=end,
                                         dtype="int16")
    else:
        samp_rate, samps_int16 = wf.read(fname)
    # N x C => C x N
    samps = samps_int16.astype(np.float32)
    # transpose because we put the channel axis first
    if samps.ndim != 1:
        samps = np.transpose(samps)
    # normalize like MATLAB and librosa
    if normalize:
        samps = samps / MAX_INT16
    if return_rate:
        return samp_rate, samps
    return samps


def write_wav(fname, samps, sr=16000, normalize=True):
    """
    Write wav files in int16, supports single-/multi-channel
    """
    if normalize:
        samps = samps * MAX_INT16
    # scipy.io.wavfile.write can write single-/multi-channel files
    # for multi-channel, it accepts an ndarray [Nsamples, Nchannels]
    if samps.ndim != 1 and samps.shape[0] < samps.shape[1]:
        samps = np.transpose(samps)
    samps = np.squeeze(samps)
    # same as MATLAB and kaldi
    samps_int16 = samps.astype(np.int16)
    fdir = os.path.dirname(fname)
    if fdir:
        os.makedirs(fdir, exist_ok=True)
    # NOTE: librosa 0.6.0 seems unable to write non-float ndarrays,
    # so use scipy.io.wavfile instead
    wf.write(fname, sr, samps_int16)

--------------------------------------------------------------------------------
/utils/mvdr_util.py:
--------------------------------------------------------------------------------
import librosa
import numpy as np


def make_wta(result_mask):
    # winner-take-all masks: merge the noise masks, then keep, per T-F bin,
    # only the mask with the largest value
    noise_mask = result_mask[2]
    if len(result_mask) == 4:
        noise_mask += result_mask[3]
    mask = np.stack((result_mask[0], result_mask[1], noise_mask))
    mask_max = np.amax(mask, axis=0, keepdims=True)
    mask = np.where(mask == mask_max, mask, 1e-10)
    return mask


def make_mvdr(s, result):
    mask = make_wta(result)
    M = []
    for i in range(7):
        st = librosa.core.stft(s[:, i], n_fft=512, hop_length=256)
        M.append(st)
    M = np.asarray(M)

    L = np.min([mask.shape[-1], M.shape[-1]])
    M = M[:, :, :L]

    mask = mask[:, :, :L]

    tgt_scm, _ = get_mask_scm(M, mask[0])
    itf_scm, _ = get_mask_scm(M, mask[1])
    noi_scm, _ = get_mask_scm(M, mask[2])

    coef = calc_bfcoeffs(noi_scm + itf_scm, tgt_scm)
    res = get_bf(M, coef)
    res1 = librosa.istft(res, hop_length=256)

    coef = calc_bfcoeffs(noi_scm + tgt_scm, itf_scm)
    res = get_bf(M, coef)
    res2 = librosa.istft(res, hop_length=256)

    return res1, res2


def get_mask_scm(mix, mask):
    # mask-weighted spatial covariance matrix, F x M x M
    Ri = np.einsum('FT,FTM,FTm->FMm', mask, mix.transpose(1, 2, 0),
                   mix.transpose(1, 2, 0).conj())
    # small diagonal loading to keep the SCM invertible
    Ri += 1e-15 * np.eye(7)[np.newaxis, :, :]
    return Ri, np.sum(mask)


def calc_bfcoeffs(noi_scm, tgt_scm):
    # Calculate BF coeffs.
    num = np.linalg.solve(noi_scm, tgt_scm)
    den = np.trace(num, axis1=-2, axis2=-1)[..., np.newaxis, np.newaxis]
    den[0] += 1e-15
    W = (num / den)[..., 0]
    return W


def get_bf(mix, W):
    c, f, t = mix.shape
    return np.sum(W.reshape(f, c, 1).conj() * mix.transpose(1, 0, 2), axis=1)

--------------------------------------------------------------------------------
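For reference, the separation + MVDR pipeline above can also be driven from Python instead of the `separate.py` command line. Below is a minimal sketch that simply mirrors what `run()` in `separate.py` does; the checkpoint directory and output paths are illustrative:

```python
import numpy as np
from pathlib import Path

from separate import EgsReader, Separator
from utils.audio_util import write_wav
from utils.mvdr_util import make_mvdr

# 7-channel model; request masks so MVDR can be applied on top
reader = EgsReader("utils/overlapped_speech_7ch.scp", sr=16000)
separator = Separator("checkpoints/conformer_base", get_mask=True, device_id=-1)  # -1 = CPU

dump_dir = Path("separated_speech/7ch_demo")
dump_dir.mkdir(parents=True, exist_ok=True)

for key, egs in reader:
    mixed = egs["mix"]               # C x S multi-channel waveform
    masks = separator.separate(egs)  # T-F masks for 2 speakers (+ noise)
    res1, res2 = make_mvdr(np.asfortranarray(mixed.T), masks)
    for i, s in enumerate((res1, res2)):
        write_wav(dump_dir / f"{key}_{i}.wav", s * 0.9 / np.max(np.abs(s)))
```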