├── .gitignore
├── README.md
├── components
│   ├── semantic_extractor
│   │   ├── WavLM.py
│   │   ├── modules.py
│   │   └── ssl_model.py
│   └── simcodec
│       ├── __init__.py
│       ├── model.py
│       └── modules.py
├── configs
│   ├── gense.yaml
│   └── gense_wavlm.yaml
├── fig
│   └── gense.png
├── infer.py
├── models
│   ├── gense.py
│   └── gense_wavlm.py
├── noisy.wav
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GenSE: Generative Speech Enhancement via Language Models using Hierarchical Modeling
The official implementation of GenSE (ICLR 2025).
2 |
3 | We propose GenSE, a comprehensive framework for language-model-based speech enhancement. Speech enhancement is regarded as a conditional language modeling task rather than the continuous signal regression problem adopted in existing works. This is achieved by tokenizing speech signals into semantic tokens using a pre-trained self-supervised model and into acoustic tokens using a custom-designed single-quantizer neural codec model.
4 |
5 |
6 | ![GenSE overview](fig/gense.png)
7 |
8 |
9 | GenSE employs a hierarchical modeling framework with a two-stage process: an N2S transformation front-end, which converts noisy speech into clean semantic tokens, and an S2S generation back-end, which synthesizes clean speech using both semantic tokens and noisy acoustic tokens.
10 |
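For orientation, the sketch below shows how the two stages fit together at inference time. The names (`semantic_tokenizer`, `n2s_lm`, `s2s_lm`) are illustrative only and are not the repo's actual API; see `infer.py` and `models/` for the real implementation.

```
# Conceptual sketch of the hierarchical N2S -> S2S pipeline (illustrative names only).
def enhance(noisy_wav, semantic_tokenizer, codec, n2s_lm, s2s_lm):
    noisy_semantic = semantic_tokenizer(noisy_wav)      # semantic tokens from the SSL model
    noisy_acoustic = codec(noisy_wav)                   # acoustic tokens from the single-quantizer codec
    clean_semantic = n2s_lm.generate(noisy_semantic)    # stage 1 (N2S): noisy -> clean semantic tokens
    clean_acoustic = s2s_lm.generate(clean_semantic, noisy_acoustic)  # stage 2 (S2S): token generation
    return codec.decode(clean_acoustic)                 # synthesize the enhanced waveform
```
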
11 | ## TODO 📝
12 | - [x] Release inference pipeline
13 | - [x] Release pre-trained model
14 | - [ ] Support in Colab
15 | - [ ] More to be added
16 |
17 | ## Getting Started 📥
18 |
19 | ### 1. Prerequisites
20 | 1. PyTorch >= 1.13 and torchaudio >= 0.13
21 | 2. Install the requirements:
22 | ```
23 | conda create -n gense python=3.8
24 | pip install -r requirements.txt
25 | ```
26 |
27 | ### 2. Get Self-supervised Model:
28 | Download the [XLSR model](https://huggingface.co/facebook/wav2vec2-xls-r-300m) and move it to the ckpts dir,
29 | or
30 | download [WavLM Large](https://huggingface.co/microsoft/wavlm-large) to run the WavLM variant instead of the XLSR version.
31 |
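If you use the WavLM variant, the classes bundled in `components/semantic_extractor/WavLM.py` can be loaded roughly as follows. This is a minimal sketch that assumes a fairseq-style checkpoint with `cfg` and `model` keys (as in the original WavLM release) placed under `ckpts/`; the checkpoint filename is hypothetical, and the loading path GenSE actually uses lives in `components/semantic_extractor/ssl_model.py` and the yaml configs.

```
import torch
import torchaudio
from components.semantic_extractor.WavLM import WavLM, WavLMConfig

ckpt = torch.load('ckpts/WavLM-Large.pt', map_location='cpu')  # hypothetical filename
cfg = WavLMConfig(ckpt['cfg'])
model = WavLM(cfg).eval()
model.load_state_dict(ckpt['model'])

wav, sr = torchaudio.load('noisy.wav')      # mono 16 kHz waveform, shape [1, T]
if cfg.normalize:                           # WavLM-Large expects layer-normalized input
    wav = torch.nn.functional.layer_norm(wav, wav.shape)
with torch.no_grad():
    feats, _ = model.extract_features(wav)  # frame-level features, shape [1, T', D]
```
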
32 | ### 3. Pre-trained Model:
33 | Download the pre-trained models from [Hugging Face](https://huggingface.co/yaoxunji/gen-se/tree/main); all checkpoints should be stored in the ckpts dir.
34 |
35 | ### 4. Speech Enhancement:
36 | ```
37 | python infer.py run \
38 | --noisy_path noisy.wav \
39 | --out_path ./enhanced.wav \
40 | --config_path configs/gense.yaml
41 | ```
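To run the WavLM variant instead, pass its config in the same way, e.g. `--config_path configs/gense_wavlm.yaml`.
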
42 | ### 5. SimCodec Copy-Synthesis:
```
import torchaudio

from components.simcodec.model import SimCodec

codec = SimCodec('config.json')   # path to the SimCodec config
codec.load_ckpt('g_00100000')     # path to the SimCodec generator checkpoint
codec = codec.eval()
codec = codec.to('cuda')

wav, sr = torchaudio.load('noisy.wav')  # mono 16 kHz waveform
wav = wav.to('cuda')

code = codec(wav)
print(code.shape)  # [B, L1, 1] discrete acoustic tokens from the single quantizer
syn = codec.decode(code)
print(syn.shape)   # [B, 1, L2] reconstructed waveform
torchaudio.save('copy.wav', syn.detach().cpu().squeeze(0), 16000)
```
56 |
--------------------------------------------------------------------------------
/components/semantic_extractor/WavLM.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # Based on fairseq code bases
7 | # https://github.com/pytorch/fairseq
8 | # --------------------------------------------------------
9 |
10 | import math
11 | import logging
12 | from typing import List, Optional, Tuple
13 |
14 | import sys,os
15 | sys.path.append(os.path.dirname(sys.path[0]))
16 | import numpy as np
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.nn import LayerNorm
22 | from .modules import (
23 | Fp32GroupNorm,
24 | Fp32LayerNorm,
25 | GradMultiply,
26 | MultiheadAttention,
27 | SamePad,
28 | init_bert_params,
29 | get_activation_fn,
30 | TransposeLast,
31 | GLU_Linear,
32 | )
33 |
34 | logger = logging.getLogger(__name__)
35 |
36 |
37 | def compute_mask_indices(
38 | shape: Tuple[int, int],
39 | padding_mask: Optional[torch.Tensor],
40 | mask_prob: float,
41 | mask_length: int,
42 | mask_type: str = "static",
43 | mask_other: float = 0.0,
44 | min_masks: int = 0,
45 | no_overlap: bool = False,
46 | min_space: int = 0,
47 | ) -> np.ndarray:
48 | """
49 | Computes random mask spans for a given shape
50 |
51 | Args:
52 |         shape: the shape for which to compute masks.
53 | should be of size 2 where first element is batch size and 2nd is timesteps
54 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
55 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
56 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
57 | however due to overlaps, the actual number will be smaller (unless no_overlap is True)
58 | mask_type: how to compute mask lengths
59 | static = fixed size
60 | uniform = sample from uniform distribution [mask_other, mask_length*2]
61 | normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
62 |             poisson = sample from poisson distribution with lambda = mask length
63 | min_masks: minimum number of masked spans
64 |         no_overlap: if true, switches to an alternative recursive algorithm that prevents spans from overlapping
65 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
66 | """
67 |
68 | bsz, all_sz = shape
69 | mask = np.full((bsz, all_sz), False)
70 |
71 | all_num_mask = int(
72 | # add a random number for probabilistic rounding
73 | mask_prob * all_sz / float(mask_length)
74 | + np.random.rand()
75 | )
76 |
77 | all_num_mask = max(min_masks, all_num_mask)
78 |
79 | mask_idcs = []
80 | for i in range(bsz):
81 | if padding_mask is not None:
82 | sz = all_sz - padding_mask[i].long().sum().item()
83 | num_mask = int(
84 | # add a random number for probabilistic rounding
85 | mask_prob * sz / float(mask_length)
86 | + np.random.rand()
87 | )
88 | num_mask = max(min_masks, num_mask)
89 | else:
90 | sz = all_sz
91 | num_mask = all_num_mask
92 |
93 | if mask_type == "static":
94 | lengths = np.full(num_mask, mask_length)
95 | elif mask_type == "uniform":
96 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
97 | elif mask_type == "normal":
98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask)
99 | lengths = [max(1, int(round(x))) for x in lengths]
100 | elif mask_type == "poisson":
101 | lengths = np.random.poisson(mask_length, size=num_mask)
102 | lengths = [int(round(x)) for x in lengths]
103 | else:
104 | raise Exception("unknown mask selection " + mask_type)
105 |
106 | if sum(lengths) == 0:
107 | lengths[0] = min(mask_length, sz - 1)
108 |
109 | if no_overlap:
110 | mask_idc = []
111 |
112 | def arrange(s, e, length, keep_length):
113 | span_start = np.random.randint(s, e - length)
114 | mask_idc.extend(span_start + i for i in range(length))
115 |
116 | new_parts = []
117 | if span_start - s - min_space >= keep_length:
118 | new_parts.append((s, span_start - min_space + 1))
119 | if e - span_start - keep_length - min_space > keep_length:
120 | new_parts.append((span_start + length + min_space, e))
121 | return new_parts
122 |
123 | parts = [(0, sz)]
124 | min_length = min(lengths)
125 | for length in sorted(lengths, reverse=True):
126 | lens = np.fromiter(
127 | (e - s if e - s >= length + min_space else 0 for s, e in parts),
128 |                     int,
129 | )
130 | l_sum = np.sum(lens)
131 | if l_sum == 0:
132 | break
133 | probs = lens / np.sum(lens)
134 | c = np.random.choice(len(parts), p=probs)
135 | s, e = parts.pop(c)
136 | parts.extend(arrange(s, e, length, min_length))
137 | mask_idc = np.asarray(mask_idc)
138 | else:
139 | min_len = min(lengths)
140 | if sz - min_len <= num_mask:
141 | min_len = sz - num_mask - 1
142 |
143 | mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
144 |
145 | mask_idc = np.asarray(
146 | [
147 | mask_idc[j] + offset
148 | for j in range(len(mask_idc))
149 | for offset in range(lengths[j])
150 | ]
151 | )
152 |
153 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
154 |
155 | min_len = min([len(m) for m in mask_idcs])
156 | for i, mask_idc in enumerate(mask_idcs):
157 | if len(mask_idc) > min_len:
158 | mask_idc = np.random.choice(mask_idc, min_len, replace=False)
159 | mask[i, mask_idc] = True
160 |
161 | return mask
162 |
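# Illustrative usage (not part of the original WavLM code): sample mask spans for
# a batch of 2 sequences of 100 frames, using the span settings that the default
# WavLMConfig below applies (mask_prob=0.65, mask_length=10, mask_type="static").
def _demo_compute_mask_indices() -> np.ndarray:
    mask = compute_mask_indices(
        shape=(2, 100),
        padding_mask=None,
        mask_prob=0.65,   # approximate fraction of frames covered by masked spans
        mask_length=10,   # fixed span length for mask_type="static"
        min_masks=2,
    )
    # mask is a (2, 100) boolean array; roughly 65% of frames per row are True,
    # usually fewer because sampled spans may overlap.
    return mask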
163 |
164 | class WavLMConfig:
165 | def __init__(self, cfg=None):
166 | self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True)
167 | self.encoder_layers: int = 12 # num encoder layers in the transformer
168 |
169 | self.encoder_embed_dim: int = 768 # encoder embedding dimension
170 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN
171 | self.encoder_attention_heads: int = 12 # num encoder attention heads
172 | self.activation_fn: str = "gelu" # activation function to use
173 |
174 | self.layer_norm_first: bool = False # apply layernorm first in the transformer
175 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...]
176 | self.conv_bias: bool = False # include bias in conv encoder
177 | self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this
178 |
179 | self.normalize: bool = False # normalize input to have 0 mean and unit variance during training
180 |
181 | # dropouts
182 | self.dropout: float = 0.1 # dropout probability for the transformer
183 | self.attention_dropout: float = 0.1 # dropout probability for attention weights
184 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN
185 |         self.encoder_layerdrop: float = 0.0 # probability of dropping a transformer layer
186 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr)
187 | self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr)
188 |
189 | # masking
190 | self.mask_length: int = 10 # mask length
191 | self.mask_prob: float = 0.65 # probability of replacing a token with mask
192 | self.mask_selection: str = "static" # how to choose mask length
193 |         self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
194 | self.no_mask_overlap: bool = False # whether to allow masks to overlap
195 | self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled)
196 |
197 | # channel masking
198 | self.mask_channel_length: int = 10 # length of the mask for features (channels)
199 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0
200 | self.mask_channel_selection: str = "static" # how to choose mask length for channel masking
201 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices
202 | self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap
203 | self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled)
204 |
205 | # positional embeddings
206 | self.conv_pos: int = 128 # number of filters for convolutional positional embeddings
207 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding
208 |
209 | # relative position embedding
210 | self.relative_position_embedding: bool = False # apply relative position embedding
211 | self.num_buckets: int = 320 # number of buckets for relative position embedding
212 | self.max_distance: int = 1280 # maximum distance for relative position embedding
213 | self.gru_rel_pos: bool = False # apply gated relative position embedding
214 |
215 | if cfg is not None:
216 | self.update(cfg)
217 |
218 | def update(self, cfg: dict):
219 | self.__dict__.update(cfg)
220 |
221 |
222 | class WavLM(nn.Module):
223 | def __init__(
224 | self,
225 | cfg: WavLMConfig,
226 | ) -> None:
227 | super().__init__()
228 | logger.info(f"WavLM Config: {cfg.__dict__}")
229 |
230 | self.cfg = cfg
231 | feature_enc_layers = eval(cfg.conv_feature_layers)
232 | self.embed = feature_enc_layers[-1][0]
233 |
234 | self.feature_extractor = ConvFeatureExtractionModel(
235 | conv_layers=feature_enc_layers,
236 | dropout=0.0,
237 | mode=cfg.extractor_mode,
238 | conv_bias=cfg.conv_bias,
239 | )
240 |
241 | self.post_extract_proj = (
242 | nn.Linear(self.embed, cfg.encoder_embed_dim)
243 | if self.embed != cfg.encoder_embed_dim
244 | else None
245 | )
246 |
247 | self.mask_prob = cfg.mask_prob
248 | self.mask_selection = cfg.mask_selection
249 | self.mask_other = cfg.mask_other
250 | self.mask_length = cfg.mask_length
251 | self.no_mask_overlap = cfg.no_mask_overlap
252 | self.mask_min_space = cfg.mask_min_space
253 |
254 | self.mask_channel_prob = cfg.mask_channel_prob
255 | self.mask_channel_selection = cfg.mask_channel_selection
256 | self.mask_channel_other = cfg.mask_channel_other
257 | self.mask_channel_length = cfg.mask_channel_length
258 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
259 | self.mask_channel_min_space = cfg.mask_channel_min_space
260 |
261 | self.dropout_input = nn.Dropout(cfg.dropout_input)
262 | self.dropout_features = nn.Dropout(cfg.dropout_features)
263 |
264 | self.feature_grad_mult = cfg.feature_grad_mult
265 |
266 | self.mask_emb = nn.Parameter(
267 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
268 | )
269 |
270 | self.encoder = TransformerEncoder(cfg)
271 | self.layer_norm = LayerNorm(self.embed)
272 |
273 | def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
274 | """
275 | Computes the output length of the convolutional layers
276 | """
277 |
278 | def _conv_out_length(input_length, kernel_size, stride):
279 | return torch.floor((input_length - kernel_size) / stride + 1)
280 |
281 | conv_cfg_list = eval(self.cfg.conv_feature_layers)
282 |
283 | out_lengths_list = []
284 | for i in range(len(conv_cfg_list)):
285 | input_lengths = _conv_out_length(
286 | input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
287 | )
288 | out_lengths_list.append(input_lengths)
289 |
290 | return input_lengths.to(torch.long), out_lengths_list
291 |
292 | def apply_mask(self, x, padding_mask):
293 | B, T, C = x.shape
294 | if self.mask_prob > 0:
295 | mask_indices = compute_mask_indices(
296 | (B, T),
297 | padding_mask,
298 | self.mask_prob,
299 | self.mask_length,
300 | self.mask_selection,
301 | self.mask_other,
302 | min_masks=2,
303 | no_overlap=self.no_mask_overlap,
304 | min_space=self.mask_min_space,
305 | )
306 | mask_indices = torch.from_numpy(mask_indices).to(x.device)
307 | x[mask_indices] = self.mask_emb
308 | else:
309 | mask_indices = None
310 |
311 | if self.mask_channel_prob > 0:
312 | mask_channel_indices = compute_mask_indices(
313 | (B, C),
314 | None,
315 | self.mask_channel_prob,
316 | self.mask_channel_length,
317 | self.mask_channel_selection,
318 | self.mask_channel_other,
319 | no_overlap=self.no_mask_channel_overlap,
320 | min_space=self.mask_channel_min_space,
321 | )
322 | mask_channel_indices = (
323 | torch.from_numpy(mask_channel_indices)
324 | .to(x.device)
325 | .unsqueeze(1)
326 | .expand(-1, T, -1)
327 | )
328 | x[mask_channel_indices] = 0
329 |
330 | return x, mask_indices
331 |
332 | def forward_padding_mask(
333 | self, features: torch.Tensor, padding_mask: torch.Tensor,
334 | ) -> torch.Tensor:
335 | extra = padding_mask.size(1) % features.size(1)
336 | if extra > 0:
337 | padding_mask = padding_mask[:, :-extra]
338 | padding_mask = padding_mask.view(
339 | padding_mask.size(0), features.size(1), -1
340 | )
341 | padding_mask = padding_mask.all(-1)
342 | return padding_mask
343 |
344 | def sequence_mask(self, sequence_length, max_len=None):
345 | """Create a sequence mask for filtering padding in a sequence tensor.
346 | Args:
347 | sequence_length (torch.tensor): Sequence lengths.
348 | max_len (int, Optional): Maximum sequence length. Defaults to None.
349 | Shapes:
350 | - mask: :math:`[B, T_max]`
351 | """
352 | if max_len is None:
353 | max_len = sequence_length.data.max()
354 | seq_range = torch.arange(max_len,
355 | dtype=sequence_length.dtype,
356 | device=sequence_length.device)
357 | # B x T_max
358 | mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
359 | return mask
360 |
361 | def extract_features(
362 | self,
363 | source: torch.Tensor,
364 | padding_mask: Optional[torch.Tensor] = None,
365 | mask: bool = False,
366 | ret_conv: bool = False,
367 | output_layer: Optional[int] = None,
368 | ret_layer_results: bool = False,
369 | input_length: Optional[torch.Tensor] = None
370 | ):
371 | out_lengths_list = None
372 | if input_length is not None:
373 | out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(input_length)
374 | else:
375 | out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(torch.tensor([source.shape[-1] for _ in range(source.shape[0])]).to(source.device))
376 |
377 | if self.feature_grad_mult > 0:
378 | features = self.feature_extractor(source, input_lengths=input_length, out_lengths_list=out_lengths_list)
379 | if self.feature_grad_mult != 1.0:
380 | features = GradMultiply.apply(features, self.feature_grad_mult)
381 | else:
382 | with torch.no_grad():
383 | features = self.feature_extractor(source)
384 |
385 | features = features.transpose(1, 2)
386 | features = self.layer_norm(features)
387 |
388 | # if padding_mask is not None:
389 | # padding_mask = self.forward_padding_mask(features, padding_mask)
390 |
391 | if self.post_extract_proj is not None:
392 | features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1)
393 | features = self.post_extract_proj(features)
394 | features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1)
395 |
396 |
397 | features = self.dropout_input(features)
398 | # return features
399 |
400 | if mask:
401 | x, mask_indices = self.apply_mask(
402 | features, padding_mask
403 | )
404 | else:
405 | x = features
406 |
407 | # feature: (B, T, D), float
408 | # target: (B, T), long
409 | # x: (B, T, D), float
410 | # padding_mask: (B, T), bool
411 | # mask_indices: (B, T), bool
412 | if source.shape[0] == 1:
413 | padding_mask = None
414 | else:
415 | padding_mask = ~self.sequence_mask(out_conv_lengths)
416 |
417 | x, layer_results = self.encoder(
418 | x,
419 | padding_mask=padding_mask,
420 | layer=None if output_layer is None else output_layer - 1
421 | )
422 |
423 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
424 |
425 | feature = res["features"] if ret_conv else res["x"]
426 | if ret_layer_results:
427 | feature = (feature, res["layer_results"])
428 | return feature, res["padding_mask"]
429 |
430 |
431 | def long_term_modeling(
432 | self,
433 | source: torch.Tensor,
434 | padding_mask: Optional[torch.Tensor] = None,
435 | mask: bool = False,
436 | ret_conv: bool = False,
437 | output_layer: Optional[int] = None,
438 | ret_layer_results: bool = False,
439 | ):
440 |
441 | features = source.transpose(1, 2)
442 | features = self.layer_norm(features)
443 |
444 | if padding_mask is not None:
445 | padding_mask = self.forward_padding_mask(features, padding_mask)
446 |
447 | if self.post_extract_proj is not None:
448 | features = self.post_extract_proj(features)
449 |
450 | features = self.dropout_input(features)
451 |
452 | if mask:
453 | x, mask_indices = self.apply_mask(
454 | features, padding_mask
455 | )
456 | else:
457 | x = features
458 |
459 | # feature: (B, T, D), float
460 | # target: (B, T), long
461 | # x: (B, T, D), float
462 | # padding_mask: (B, T), bool
463 | # mask_indices: (B, T), bool
464 | x, layer_results = self.encoder(
465 | x,
466 | padding_mask=padding_mask,
467 | layer=None if output_layer is None else output_layer - 1
468 | )
469 |
470 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results}
471 |
472 | feature = res["features"] if ret_conv else res["x"]
473 | if ret_layer_results:
474 | feature = (feature, res["layer_results"])
475 | return feature, res["padding_mask"]
476 |
477 |
478 |
479 | class ConvFeatureExtractionModel(nn.Module):
480 | def __init__(
481 | self,
482 | conv_layers: List[Tuple[int, int, int]],
483 | dropout: float = 0.0,
484 | mode: str = "default",
485 | conv_bias: bool = False,
486 | conv_type: str = "default"
487 | ):
488 | super().__init__()
489 |
490 | assert mode in {"default", "layer_norm"}
491 |
492 | def block(
493 | n_in,
494 | n_out,
495 | k,
496 | stride,
497 | is_layer_norm=False,
498 | is_group_norm=False,
499 | conv_bias=False,
500 | ):
501 | def make_conv():
502 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
503 | nn.init.kaiming_normal_(conv.weight)
504 | return conv
505 |
506 | assert (
507 | is_layer_norm and is_group_norm
508 | ) == False, "layer norm and group norm are exclusive"
509 |
510 | if is_layer_norm:
511 | return nn.Sequential(
512 | make_conv(),
513 | nn.Dropout(p=dropout),
514 | nn.Sequential(
515 | TransposeLast(),
516 | Fp32LayerNorm(dim, elementwise_affine=True),
517 | TransposeLast(),
518 | ),
519 | nn.GELU(),
520 | )
521 |             elif is_group_norm:
522 |                 return nn.Sequential(
523 |                     make_conv(),
524 |                     nn.Dropout(p=dropout),
525 |                     Fp32GroupNorm(dim, dim, affine=True),
526 |                     nn.GELU(),
527 |                 )
528 |             else:
529 |                 return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())
530 |
531 | self.conv_type = conv_type
532 | if self.conv_type == "default":
533 | in_d = 1
534 | self.conv_layers = nn.ModuleList()
535 | for i, cl in enumerate(conv_layers):
536 | assert len(cl) == 3, "invalid conv definition: " + str(cl)
537 | (dim, k, stride) = cl
538 |
539 | self.conv_layers.append(
540 | block(
541 | in_d,
542 | dim,
543 | k,
544 | stride,
545 | is_layer_norm=mode == "layer_norm",
546 | is_group_norm=mode == "default" and i == 0,
547 | conv_bias=conv_bias,
548 | )
549 | )
550 | in_d = dim
551 | elif self.conv_type == "conv2d":
552 | in_d = 1
553 | self.conv_layers = nn.ModuleList()
554 | for i, cl in enumerate(conv_layers):
555 | assert len(cl) == 3
556 | (dim, k, stride) = cl
557 |
558 | self.conv_layers.append(
559 | torch.nn.Conv2d(in_d, dim, k, stride)
560 | )
561 | self.conv_layers.append(torch.nn.ReLU())
562 | in_d = dim
563 | elif self.conv_type == "custom":
564 | in_d = 1
565 | idim = 80
566 | self.conv_layers = nn.ModuleList()
567 | for i, cl in enumerate(conv_layers):
568 | assert len(cl) == 3
569 | (dim, k, stride) = cl
570 | self.conv_layers.append(
571 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1)
572 | )
573 | self.conv_layers.append(
574 | torch.nn.LayerNorm([dim, idim])
575 | )
576 | self.conv_layers.append(torch.nn.ReLU())
577 | in_d = dim
578 | if (i + 1) % 2 == 0:
579 | self.conv_layers.append(
580 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True)
581 | )
582 | idim = int(math.ceil(idim / 2))
583 | else:
584 | pass
585 |
586 | def sequence_mask(self, sequence_length, max_len=None):
587 | """Create a sequence mask for filtering padding in a sequence tensor.
588 | Args:
589 | sequence_length (torch.tensor): Sequence lengths.
590 | max_len (int, Optional): Maximum sequence length. Defaults to None.
591 | Shapes:
592 | - mask: :math:`[B, T_max]`
593 | """
594 | if max_len is None:
595 | max_len = sequence_length.data.max()
596 | seq_range = torch.arange(max_len,
597 | dtype=sequence_length.dtype,
598 | device=sequence_length.device)
599 | # B x T_max
600 | mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)
601 | return mask
602 |
603 | def forward(self, x, mask=None, input_lengths=None, out_lengths_list=None):
604 |
605 | # BxT -> BxCxT
606 | x = x.unsqueeze(1)
607 | # if self.conv_type == "custom":
608 | # for conv in self.conv_layers:
609 | # if isinstance(conv, nn.LayerNorm):
610 | # x = x.transpose(1, 2)
611 | # x = conv(x).transpose(1, 2)
612 | # else:
613 | # x = conv(x)
614 | # x = x.transpose(2, 3).contiguous()
615 | # x = x.view(x.size(0), -1, x.size(-1))
616 | # else:
617 |
618 | for idx, conv in enumerate(self.conv_layers):
619 | x = conv(x)
620 | # if idx == 0:
621 | # x = conv(x * self.sequence_mask(input_lengths).unsqueeze(1))
622 | # else:
623 | # if len(out_lengths_list[idx-1]) == 1:
624 | # x = conv(x * self.sequence_mask(out_lengths_list[idx-1]))
625 | # else:
626 | # x = conv(x * self.sequence_mask(out_lengths_list[idx-1]).unsqueeze(1))
627 | # if len(out_lengths_list[idx-1]) == 1:
628 | # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(0))
629 | # else:
630 | # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(1))
631 | # if self.conv_type == "conv2d":
632 | # b, c, t, f = x.size()
633 | # x = x.transpose(2, 3).contiguous().view(b, c * f, t)
634 | return x
635 |
636 |
637 | class TransformerEncoder(nn.Module):
638 | def __init__(self, args):
639 | super().__init__()
640 |
641 | self.dropout = args.dropout
642 | self.embedding_dim = args.encoder_embed_dim
643 |
644 | self.pos_conv = nn.Conv1d(
645 | self.embedding_dim,
646 | self.embedding_dim,
647 | kernel_size=args.conv_pos,
648 | padding=args.conv_pos // 2,
649 | groups=args.conv_pos_groups,
650 | )
651 | dropout = 0
652 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
653 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
654 | nn.init.constant_(self.pos_conv.bias, 0)
655 |
656 | self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
657 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())
658 |
659 | if hasattr(args, "relative_position_embedding"):
660 | self.relative_position_embedding = args.relative_position_embedding
661 | self.num_buckets = args.num_buckets
662 | self.max_distance = args.max_distance
663 | else:
664 | self.relative_position_embedding = False
665 | self.num_buckets = 0
666 | self.max_distance = 0
667 |
668 | self.layers = nn.ModuleList(
669 | [
670 | TransformerSentenceEncoderLayer(
671 | embedding_dim=self.embedding_dim,
672 | ffn_embedding_dim=args.encoder_ffn_embed_dim,
673 | num_attention_heads=args.encoder_attention_heads,
674 | dropout=self.dropout,
675 | attention_dropout=args.attention_dropout,
676 | activation_dropout=args.activation_dropout,
677 | activation_fn=args.activation_fn,
678 | layer_norm_first=args.layer_norm_first,
679 | has_relative_attention_bias=(self.relative_position_embedding and i == 0),
680 | num_buckets=self.num_buckets,
681 | max_distance=self.max_distance,
682 | gru_rel_pos=args.gru_rel_pos,
683 | )
684 | for i in range(args.encoder_layers)
685 | ]
686 | )
687 |
688 | self.layer_norm_first = args.layer_norm_first
689 | self.layer_norm = LayerNorm(self.embedding_dim)
690 | self.layerdrop = args.encoder_layerdrop
691 |
692 | self.apply(init_bert_params)
693 |
694 | def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
695 | x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)
696 |
697 | if self.layer_norm_first and layer is None:
698 | x = self.layer_norm(x)
699 |
700 | return x, layer_results
701 |
702 | def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
703 |
704 | if padding_mask is not None:
705 | x[padding_mask] = 0
706 |
707 | y = x.transpose(1, 2).clone()
708 | x_conv = self.pos_conv(y)
709 | x_conv = x_conv.transpose(1, 2)
710 | x += x_conv
711 |
712 | if not self.layer_norm_first:
713 | x = self.layer_norm(x)
714 |
715 | x = F.dropout(x, p=self.dropout, training=self.training)
716 |
717 | # B x T x C -> T x B x C
718 | x = x.transpose(0, 1)
719 |
720 | layer_results = []
721 | z = None
722 | if tgt_layer is not None:
723 | layer_results.append((x, z))
724 | r = None
725 | pos_bias = None
726 | for i, layer in enumerate(self.layers):
727 | dropout_probability = np.random.random()
728 | if not self.training or (dropout_probability > self.layerdrop):
729 | x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False,
730 | self_attn_mask=streaming_mask, pos_bias=pos_bias)
731 | if tgt_layer is not None:
732 | layer_results.append((x, z))
733 | if i == tgt_layer:
734 | r = x
735 | break
736 |
737 | if r is not None:
738 | x = r
739 |
740 | # T x B x C -> B x T x C
741 | x = x.transpose(0, 1)
742 |
743 | return x, layer_results
744 |
745 |
746 | class TransformerSentenceEncoderLayer(nn.Module):
747 | """
748 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
749 | models.
750 | """
751 |
752 | def __init__(
753 | self,
754 | embedding_dim: float = 768,
755 | ffn_embedding_dim: float = 3072,
756 | num_attention_heads: float = 8,
757 | dropout: float = 0.1,
758 | attention_dropout: float = 0.1,
759 | activation_dropout: float = 0.1,
760 | activation_fn: str = "relu",
761 | layer_norm_first: bool = False,
762 | has_relative_attention_bias: bool = False,
763 | num_buckets: int = 0,
764 | max_distance: int = 0,
765 | rescale_init: bool = False,
766 | gru_rel_pos: bool = False,
767 | ) -> None:
768 |
769 | super().__init__()
770 | # Initialize parameters
771 | self.embedding_dim = embedding_dim
772 | self.dropout = dropout
773 | self.activation_dropout = activation_dropout
774 |
775 | # Initialize blocks
776 | self.activation_name = activation_fn
777 | self.activation_fn = get_activation_fn(activation_fn)
778 | self.self_attn = MultiheadAttention(
779 | self.embedding_dim,
780 | num_attention_heads,
781 | dropout=attention_dropout,
782 | self_attention=True,
783 | has_relative_attention_bias=has_relative_attention_bias,
784 | num_buckets=num_buckets,
785 | max_distance=max_distance,
786 | rescale_init=rescale_init,
787 | gru_rel_pos=gru_rel_pos,
788 | )
789 |
790 | self.dropout1 = nn.Dropout(dropout)
791 | self.dropout2 = nn.Dropout(self.activation_dropout)
792 | self.dropout3 = nn.Dropout(dropout)
793 |
794 | self.layer_norm_first = layer_norm_first
795 |
796 | # layer norm associated with the self attention layer
797 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
798 |
799 | if self.activation_name == "glu":
800 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish")
801 | else:
802 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
803 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
804 |
805 | # layer norm associated with the position wise feed-forward NN
806 | self.final_layer_norm = LayerNorm(self.embedding_dim)
807 |
808 | def forward(
809 | self,
810 | x: torch.Tensor,
811 | self_attn_mask: torch.Tensor = None,
812 | self_attn_padding_mask: torch.Tensor = None,
813 | need_weights: bool = False,
814 | pos_bias=None
815 | ):
816 | """
817 | LayerNorm is applied either before or after the self-attention/ffn
818 |         modules similar to the original Transformer implementation.
819 | """
820 | residual = x
821 |
822 | if self.layer_norm_first:
823 | x = self.self_attn_layer_norm(x)
824 | x, attn, pos_bias = self.self_attn(
825 | query=x,
826 | key=x,
827 | value=x,
828 | key_padding_mask=self_attn_padding_mask,
829 | need_weights=False,
830 | attn_mask=self_attn_mask,
831 | position_bias=pos_bias
832 | )
833 | x = self.dropout1(x)
834 | x = residual + x
835 |
836 | residual = x
837 | x = self.final_layer_norm(x)
838 | if self.activation_name == "glu":
839 | x = self.fc1(x)
840 | else:
841 | x = self.activation_fn(self.fc1(x))
842 | x = self.dropout2(x)
843 | x = self.fc2(x)
844 | x = self.dropout3(x)
845 | x = residual + x
846 | else:
847 | x, attn, pos_bias = self.self_attn(
848 | query=x,
849 | key=x,
850 | value=x,
851 | key_padding_mask=self_attn_padding_mask,
852 | need_weights=need_weights,
853 | attn_mask=self_attn_mask,
854 | position_bias=pos_bias
855 | )
856 |
857 | x = self.dropout1(x)
858 | x = residual + x
859 |
860 | x = self.self_attn_layer_norm(x)
861 |
862 | residual = x
863 | if self.activation_name == "glu":
864 | x = self.fc1(x)
865 | else:
866 | x = self.activation_fn(self.fc1(x))
867 | x = self.dropout2(x)
868 | x = self.fc2(x)
869 | x = self.dropout3(x)
870 | x = residual + x
871 | x = self.final_layer_norm(x)
872 |
873 | return x, attn, pos_bias
874 |
--------------------------------------------------------------------------------
/components/semantic_extractor/modules.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf)
3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm
4 | # Copyright (c) 2021 Microsoft
5 | # Licensed under The MIT License [see LICENSE for details]
6 | # Based on fairseq code bases
7 | # https://github.com/pytorch/fairseq
8 | # --------------------------------------------------------
9 |
10 | import math
11 | import warnings
12 | from typing import Dict, Optional, Tuple
13 | import torch
14 | from torch import Tensor, nn
15 | from torch.nn import Parameter
16 | import torch.nn.functional as F
17 |
18 | class TransposeLast(nn.Module):
19 | def __init__(self, deconstruct_idx=None):
20 | super().__init__()
21 | self.deconstruct_idx = deconstruct_idx
22 |
23 | def forward(self, x):
24 | if self.deconstruct_idx is not None:
25 | x = x[self.deconstruct_idx]
26 | return x.transpose(-2, -1)
27 |
28 |
29 | class Fp32LayerNorm(nn.LayerNorm):
30 | def __init__(self, *args, **kwargs):
31 | super().__init__(*args, **kwargs)
32 |
33 | def forward(self, input):
34 | output = F.layer_norm(
35 | input.float(),
36 | self.normalized_shape,
37 | self.weight.float() if self.weight is not None else None,
38 | self.bias.float() if self.bias is not None else None,
39 | self.eps,
40 | )
41 | return output.type_as(input)
42 |
43 |
44 | class Fp32GroupNorm(nn.GroupNorm):
45 | def __init__(self, *args, **kwargs):
46 | super().__init__(*args, **kwargs)
47 |
48 | def forward(self, input):
49 | output = F.group_norm(
50 | input.float(),
51 | self.num_groups,
52 | self.weight.float() if self.weight is not None else None,
53 | self.bias.float() if self.bias is not None else None,
54 | self.eps,
55 | )
56 | return output.type_as(input)
57 |
58 |
59 | class GradMultiply(torch.autograd.Function):
60 | @staticmethod
61 | def forward(ctx, x, scale):
62 | ctx.scale = scale
63 | res = x.new(x)
64 | return res
65 |
66 | @staticmethod
67 | def backward(ctx, grad):
68 | return grad * ctx.scale, None
69 |
70 |
71 | class SamePad(nn.Module):
72 | def __init__(self, kernel_size, causal=False):
73 | super().__init__()
74 | if causal:
75 | self.remove = kernel_size - 1
76 | else:
77 | self.remove = 1 if kernel_size % 2 == 0 else 0
78 |
79 | def forward(self, x):
80 | if self.remove > 0:
81 | x = x[:, :, : -self.remove]
82 | return x
83 |
84 |
85 | class Swish(nn.Module):
86 | """Swish function
87 | """
88 |
89 | def __init__(self):
90 |         """Construct a Swish activation module."""
91 | super(Swish, self).__init__()
92 | self.act = torch.nn.Sigmoid()
93 |
94 | def forward(self, x):
95 | return x * self.act(x)
96 |
97 |
98 | class GLU_Linear(nn.Module):
99 | def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
100 | super(GLU_Linear, self).__init__()
101 |
102 | self.glu_type = glu_type
103 | self.output_dim = output_dim
104 |
105 | if glu_type == "sigmoid":
106 | self.glu_act = torch.nn.Sigmoid()
107 | elif glu_type == "swish":
108 | self.glu_act = Swish()
109 | elif glu_type == "relu":
110 | self.glu_act = torch.nn.ReLU()
111 | elif glu_type == "gelu":
112 | self.glu_act = torch.nn.GELU()
113 |
114 | if bias_in_glu:
115 | self.linear = nn.Linear(input_dim, output_dim * 2, True)
116 | else:
117 | self.linear = nn.Linear(input_dim, output_dim * 2, False)
118 |
119 | def forward(self, x):
120 | # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case
121 | x = self.linear(x)
122 |
123 | if self.glu_type == "bilinear":
124 | x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2])
125 | else:
126 | x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2]))
127 |
128 | return x
129 |
130 | def gelu_accurate(x):
131 | if not hasattr(gelu_accurate, "_a"):
132 | gelu_accurate._a = math.sqrt(2 / math.pi)
133 | return (
134 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
135 | )
136 |
137 |
138 | def gelu(x: torch.Tensor) -> torch.Tensor:
139 | return torch.nn.functional.gelu(x.float()).type_as(x)
140 |
141 |
142 | def get_activation_fn(activation: str):
143 | """Returns the activation function corresponding to `activation`"""
144 |
145 | if activation == "relu":
146 | return F.relu
147 | elif activation == "gelu":
148 | return gelu
149 | elif activation == "gelu_fast":
150 | warnings.warn(
151 | "--activation-fn=gelu_fast has been renamed to gelu_accurate"
152 | )
153 | return gelu_accurate
154 | elif activation == "gelu_accurate":
155 | return gelu_accurate
156 | elif activation == "tanh":
157 | return torch.tanh
158 | elif activation == "linear":
159 | return lambda x: x
160 | elif activation == "glu":
161 | return lambda x: x
162 | else:
163 | raise RuntimeError("--activation-fn {} not supported".format(activation))
164 |
165 |
166 | def init_bert_params(module):
167 | """
168 | Initialize the weights specific to the BERT Model.
169 | This overrides the default initializations depending on the specified arguments.
170 | 1. If normal_init_linear_weights is set then weights of linear
171 | layer will be initialized using the normal distribution and
172 |        bias will be set to the specified value.
173 | 2. If normal_init_embed_weights is set then weights of embedding
174 | layer will be initialized using the normal distribution.
175 | 3. If normal_init_proj_weights is set then weights of
176 | in_project_weight for MultiHeadAttention initialized using
177 | the normal distribution (to be validated).
178 | """
179 |
180 | def normal_(data):
181 | # with FSDP, module params will be on CUDA, so we cast them back to CPU
182 | # so that the RNG is consistent with and without FSDP
183 | data.copy_(
184 | data.cpu().normal_(mean=0.0, std=0.02).to(data.device)
185 | )
186 |
187 | if isinstance(module, nn.Linear):
188 | normal_(module.weight.data)
189 | if module.bias is not None:
190 | module.bias.data.zero_()
191 | if isinstance(module, nn.Embedding):
192 | normal_(module.weight.data)
193 | if module.padding_idx is not None:
194 | module.weight.data[module.padding_idx].zero_()
195 | if isinstance(module, MultiheadAttention):
196 | normal_(module.q_proj.weight.data)
197 | normal_(module.k_proj.weight.data)
198 | normal_(module.v_proj.weight.data)
199 |
200 |
201 | def quant_noise(module, p, block_size):
202 | """
203 | Wraps modules and applies quantization noise to the weights for
204 | subsequent quantization with Iterative Product Quantization as
205 | described in "Training with Quantization Noise for Extreme Model Compression"
206 |
207 | Args:
208 | - module: nn.Module
209 | - p: amount of Quantization Noise
210 | - block_size: size of the blocks for subsequent quantization with iPQ
211 |
212 | Remarks:
213 | - Module weights must have the right sizes wrt the block size
214 | - Only Linear, Embedding and Conv2d modules are supported for the moment
215 | - For more detail on how to quantize by blocks with convolutional weights,
216 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
217 | - We implement the simplest form of noise here as stated in the paper
218 | which consists in randomly dropping blocks
219 | """
220 |
221 | # if no quantization noise, don't register hook
222 | if p <= 0:
223 | return module
224 |
225 | # supported modules
226 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))
227 |
228 | # test whether module.weight has the right sizes wrt block_size
229 | is_conv = module.weight.ndim == 4
230 |
231 | # 2D matrix
232 | if not is_conv:
233 | assert (
234 | module.weight.size(1) % block_size == 0
235 | ), "Input features must be a multiple of block sizes"
236 |
237 | # 4D matrix
238 | else:
239 | # 1x1 convolutions
240 | if module.kernel_size == (1, 1):
241 | assert (
242 | module.in_channels % block_size == 0
243 | ), "Input channels must be a multiple of block sizes"
244 | # regular convolutions
245 | else:
246 | k = module.kernel_size[0] * module.kernel_size[1]
247 | assert k % block_size == 0, "Kernel size must be a multiple of block size"
248 |
249 | def _forward_pre_hook(mod, input):
250 | # no noise for evaluation
251 | if mod.training:
252 | if not is_conv:
253 | # gather weight and sizes
254 | weight = mod.weight
255 | in_features = weight.size(1)
256 | out_features = weight.size(0)
257 |
258 | # split weight matrix into blocks and randomly drop selected blocks
259 | mask = torch.zeros(
260 | in_features // block_size * out_features, device=weight.device
261 | )
262 | mask.bernoulli_(p)
263 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
264 |
265 | else:
266 | # gather weight and sizes
267 | weight = mod.weight
268 | in_channels = mod.in_channels
269 | out_channels = mod.out_channels
270 |
271 | # split weight matrix into blocks and randomly drop selected blocks
272 | if mod.kernel_size == (1, 1):
273 | mask = torch.zeros(
274 | int(in_channels // block_size * out_channels),
275 | device=weight.device,
276 | )
277 | mask.bernoulli_(p)
278 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
279 | else:
280 | mask = torch.zeros(
281 | weight.size(0), weight.size(1), device=weight.device
282 | )
283 | mask.bernoulli_(p)
284 | mask = (
285 | mask.unsqueeze(2)
286 | .unsqueeze(3)
287 | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
288 | )
289 |
290 | # scale weights and apply mask
291 | mask = mask.to(
292 | torch.bool
293 | ) # x.bool() is not currently supported in TorchScript
294 | s = 1 / (1 - p)
295 | mod.weight.data = s * weight.masked_fill(mask, 0)
296 |
297 | module.register_forward_pre_hook(_forward_pre_hook)
298 | return module
299 |
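# Illustrative usage (not part of the original fairseq code): wrap a linear layer
# so that, during training, random blocks of 8 consecutive input weights are
# dropped with probability 0.25 and the surviving weights are rescaled by 1/(1-p).
def _demo_quant_noise() -> nn.Module:
    return quant_noise(nn.Linear(64, 64), p=0.25, block_size=8)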
300 |
301 | class MultiheadAttention(nn.Module):
302 | """Multi-headed attention.
303 |
304 | See "Attention Is All You Need" for more details.
305 | """
306 |
307 | def __init__(
308 | self,
309 | embed_dim,
310 | num_heads,
311 | kdim=None,
312 | vdim=None,
313 | dropout=0.0,
314 | bias=True,
315 | add_bias_kv=False,
316 | add_zero_attn=False,
317 | self_attention=False,
318 | encoder_decoder_attention=False,
319 | q_noise=0.0,
320 | qn_block_size=8,
321 | has_relative_attention_bias=False,
322 | num_buckets=32,
323 | max_distance=128,
324 | gru_rel_pos=False,
325 | rescale_init=False,
326 | ):
327 | super().__init__()
328 | self.embed_dim = embed_dim
329 | self.kdim = kdim if kdim is not None else embed_dim
330 | self.vdim = vdim if vdim is not None else embed_dim
331 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
332 |
333 | self.num_heads = num_heads
334 | self.dropout_module = nn.Dropout(dropout)
335 |
336 | self.has_relative_attention_bias = has_relative_attention_bias
337 | self.num_buckets = num_buckets
338 | self.max_distance = max_distance
339 | if self.has_relative_attention_bias:
340 | self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
341 |
342 | self.head_dim = embed_dim // num_heads
343 | self.q_head_dim = self.head_dim
344 | self.k_head_dim = self.head_dim
345 | assert (
346 | self.head_dim * num_heads == self.embed_dim
347 | ), "embed_dim must be divisible by num_heads"
348 | self.scaling = self.head_dim ** -0.5
349 |
350 | self.self_attention = self_attention
351 | self.encoder_decoder_attention = encoder_decoder_attention
352 |
353 | assert not self.self_attention or self.qkv_same_dim, (
354 | "Self-attention requires query, key and " "value to be of the same size"
355 | )
356 |
357 | k_bias = True
358 | if rescale_init:
359 | k_bias = False
360 |
361 | k_embed_dim = embed_dim
362 | q_embed_dim = embed_dim
363 |
364 | self.k_proj = quant_noise(
365 | nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size
366 | )
367 | self.v_proj = quant_noise(
368 | nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
369 | )
370 | self.q_proj = quant_noise(
371 | nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size
372 | )
373 |
374 | self.out_proj = quant_noise(
375 | nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
376 | )
377 |
378 | if add_bias_kv:
379 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
380 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
381 | else:
382 | self.bias_k = self.bias_v = None
383 |
384 | self.add_zero_attn = add_zero_attn
385 |
386 | self.gru_rel_pos = gru_rel_pos
387 | if self.gru_rel_pos:
388 | self.grep_linear = nn.Linear(self.q_head_dim, 8)
389 | self.grep_a = nn.Parameter(torch.ones(1, num_heads, 1, 1))
390 |
391 | self.reset_parameters()
392 |
393 | def reset_parameters(self):
394 | if self.qkv_same_dim:
395 | # Empirically observed the convergence to be much better with
396 | # the scaled initialization
397 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
398 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
399 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
400 | else:
401 | nn.init.xavier_uniform_(self.k_proj.weight)
402 | nn.init.xavier_uniform_(self.v_proj.weight)
403 | nn.init.xavier_uniform_(self.q_proj.weight)
404 |
405 | nn.init.xavier_uniform_(self.out_proj.weight)
406 | if self.out_proj.bias is not None:
407 | nn.init.constant_(self.out_proj.bias, 0.0)
408 | if self.bias_k is not None:
409 | nn.init.xavier_normal_(self.bias_k)
410 | if self.bias_v is not None:
411 | nn.init.xavier_normal_(self.bias_v)
412 | if self.has_relative_attention_bias:
413 | nn.init.xavier_normal_(self.relative_attention_bias.weight)
414 |
415 | def _relative_positions_bucket(self, relative_positions, bidirectional=True):
416 | num_buckets = self.num_buckets
417 | max_distance = self.max_distance
418 | relative_buckets = 0
419 |
420 | if bidirectional:
421 | num_buckets = num_buckets // 2
422 | relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
423 | relative_positions = torch.abs(relative_positions)
424 | else:
425 | relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
426 |
427 | max_exact = num_buckets // 2
428 | is_small = relative_positions < max_exact
429 |
430 | relative_postion_if_large = max_exact + (
431 | torch.log(relative_positions.float() / max_exact)
432 | / math.log(max_distance / max_exact)
433 | * (num_buckets - max_exact)
434 | ).to(torch.long)
435 | relative_postion_if_large = torch.min(
436 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1)
437 | )
438 |
439 | relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large)
440 | return relative_buckets
441 |
442 | def compute_bias(self, query_length, key_length):
443 | context_position = torch.arange(query_length, dtype=torch.long)[:, None]
444 | memory_position = torch.arange(key_length, dtype=torch.long)[None, :]
445 | relative_position = memory_position - context_position
446 | relative_position_bucket = self._relative_positions_bucket(
447 | relative_position,
448 | bidirectional=True
449 | )
450 | relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device)
451 | values = self.relative_attention_bias(relative_position_bucket)
452 | values = values.permute([2, 0, 1])
453 | return values
454 |
455 | def forward(
456 | self,
457 | query,
458 | key: Optional[Tensor],
459 | value: Optional[Tensor],
460 | key_padding_mask: Optional[Tensor] = None,
461 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
462 | need_weights: bool = True,
463 | static_kv: bool = False,
464 | attn_mask: Optional[Tensor] = None,
465 | before_softmax: bool = False,
466 | need_head_weights: bool = False,
467 | position_bias: Optional[Tensor] = None
468 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
469 | """Input shape: Time x Batch x Channel
470 |
471 | Args:
472 | key_padding_mask (ByteTensor, optional): mask to exclude
473 | keys that are pads, of shape `(batch, src_len)`, where
474 | padding elements are indicated by 1s.
475 | need_weights (bool, optional): return the attention weights,
476 | averaged over heads (default: False).
477 | attn_mask (ByteTensor, optional): typically used to
478 | implement causal attention, where the mask prevents the
479 | attention from looking forward in time (default: None).
480 | before_softmax (bool, optional): return the raw attention
481 | weights and values before the attention softmax.
482 | need_head_weights (bool, optional): return the attention
483 | weights for each head. Implies *need_weights*. Default:
484 | return the average attention weights over all heads.
485 | """
486 | if need_head_weights:
487 | need_weights = True
488 |
489 | is_tpu = query.device.type == "xla"
490 |
491 | tgt_len, bsz, embed_dim = query.size()
492 | src_len = tgt_len
493 | assert embed_dim == self.embed_dim
494 | assert list(query.size()) == [tgt_len, bsz, embed_dim]
495 | if key is not None:
496 | src_len, key_bsz, _ = key.size()
497 | if not torch.jit.is_scripting():
498 | assert key_bsz == bsz
499 | assert value is not None
500 |                 assert (src_len, bsz) == value.shape[:2]
501 |
502 | if self.has_relative_attention_bias and position_bias is None:
503 | position_bias = self.compute_bias(tgt_len, src_len)
504 | position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len)
505 |
506 | if (
507 | not is_tpu # don't use PyTorch version on TPUs
508 | and incremental_state is None
509 | and not static_kv
510 | # A workaround for quantization to work. Otherwise JIT compilation
511 | # treats bias in linear module as method.
512 | and not torch.jit.is_scripting()
513 | and self.q_head_dim == self.head_dim
514 | ):
515 | assert key is not None and value is not None
516 | assert attn_mask is None
517 |
518 | attn_mask_rel_pos = None
519 | if position_bias is not None:
520 | attn_mask_rel_pos = position_bias
521 | if self.gru_rel_pos:
522 | query_layer = query.transpose(0, 1)
523 | new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1)
524 | query_layer = query_layer.view(*new_x_shape)
525 | query_layer = query_layer.permute(0, 2, 1, 3)
526 | _B, _H, _L, __ = query_layer.size()
527 |
528 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
529 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
530 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
531 | attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
532 |
533 | attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len))
534 | k_proj_bias = self.k_proj.bias
535 | if k_proj_bias is None:
536 | k_proj_bias = torch.zeros_like(self.q_proj.bias)
537 |
538 | x, attn = F.multi_head_attention_forward(
539 | query,
540 | key,
541 | value,
542 | self.embed_dim,
543 | self.num_heads,
544 | torch.empty([0]),
545 | torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
546 | self.bias_k,
547 | self.bias_v,
548 | self.add_zero_attn,
549 | self.dropout_module.p,
550 | self.out_proj.weight,
551 | self.out_proj.bias,
552 | self.training,
553 | # self.training or self.dropout_module.apply_during_inference,
554 | key_padding_mask,
555 | need_weights,
556 | attn_mask_rel_pos,
557 | use_separate_proj_weight=True,
558 | q_proj_weight=self.q_proj.weight,
559 | k_proj_weight=self.k_proj.weight,
560 | v_proj_weight=self.v_proj.weight,
561 | )
562 | return x, attn, position_bias
563 |
564 | if incremental_state is not None:
565 | saved_state = self._get_input_buffer(incremental_state)
566 | if saved_state is not None and "prev_key" in saved_state:
567 | # previous time steps are cached - no need to recompute
568 | # key and value if they are static
569 | if static_kv:
570 | assert self.encoder_decoder_attention and not self.self_attention
571 | key = value = None
572 | else:
573 | saved_state = None
574 |
575 | if self.self_attention:
576 | q = self.q_proj(query)
577 | k = self.k_proj(query)
578 | v = self.v_proj(query)
579 | elif self.encoder_decoder_attention:
580 | # encoder-decoder attention
581 | q = self.q_proj(query)
582 | if key is None:
583 | assert value is None
584 | k = v = None
585 | else:
586 | k = self.k_proj(key)
587 | v = self.v_proj(key)
588 |
589 | else:
590 | assert key is not None and value is not None
591 | q = self.q_proj(query)
592 | k = self.k_proj(key)
593 | v = self.v_proj(value)
594 | q *= self.scaling
595 |
596 | if self.bias_k is not None:
597 | assert self.bias_v is not None
598 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
599 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
600 | if attn_mask is not None:
601 | attn_mask = torch.cat(
602 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
603 | )
604 | if key_padding_mask is not None:
605 | key_padding_mask = torch.cat(
606 | [
607 | key_padding_mask,
608 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
609 | ],
610 | dim=1,
611 | )
612 |
613 | q = (
614 | q.contiguous()
615 | .view(tgt_len, bsz * self.num_heads, self.q_head_dim)
616 | .transpose(0, 1)
617 | )
618 | if k is not None:
619 | k = (
620 | k.contiguous()
621 | .view(-1, bsz * self.num_heads, self.k_head_dim)
622 | .transpose(0, 1)
623 | )
624 | if v is not None:
625 | v = (
626 | v.contiguous()
627 | .view(-1, bsz * self.num_heads, self.head_dim)
628 | .transpose(0, 1)
629 | )
630 |
631 | if saved_state is not None:
632 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
633 | if "prev_key" in saved_state:
634 | _prev_key = saved_state["prev_key"]
635 | assert _prev_key is not None
636 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
637 | if static_kv:
638 | k = prev_key
639 | else:
640 | assert k is not None
641 | k = torch.cat([prev_key, k], dim=1)
642 | src_len = k.size(1)
643 | if "prev_value" in saved_state:
644 | _prev_value = saved_state["prev_value"]
645 | assert _prev_value is not None
646 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
647 | if static_kv:
648 | v = prev_value
649 | else:
650 | assert v is not None
651 | v = torch.cat([prev_value, v], dim=1)
652 | prev_key_padding_mask: Optional[Tensor] = None
653 | if "prev_key_padding_mask" in saved_state:
654 | prev_key_padding_mask = saved_state["prev_key_padding_mask"]
655 | assert k is not None and v is not None
656 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
657 | key_padding_mask=key_padding_mask,
658 | prev_key_padding_mask=prev_key_padding_mask,
659 | batch_size=bsz,
660 | src_len=k.size(1),
661 | static_kv=static_kv,
662 | )
663 |
664 | saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
665 | saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
666 | saved_state["prev_key_padding_mask"] = key_padding_mask
667 | # In this branch incremental_state is never None
668 | assert incremental_state is not None
669 | incremental_state = self._set_input_buffer(incremental_state, saved_state)
670 | assert k is not None
671 | assert k.size(1) == src_len
672 |
673 | # This is part of a workaround to get around fork/join parallelism
674 | # not supporting Optional types.
675 | if key_padding_mask is not None and key_padding_mask.dim() == 0:
676 | key_padding_mask = None
677 |
678 | if key_padding_mask is not None:
679 | assert key_padding_mask.size(0) == bsz
680 | assert key_padding_mask.size(1) == src_len
681 |
682 | if self.add_zero_attn:
683 | assert v is not None
684 | src_len += 1
685 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
686 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
687 | if attn_mask is not None:
688 | attn_mask = torch.cat(
689 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
690 | )
691 | if key_padding_mask is not None:
692 | key_padding_mask = torch.cat(
693 | [
694 | key_padding_mask,
695 | torch.zeros(key_padding_mask.size(0), 1).type_as(
696 | key_padding_mask
697 | ),
698 | ],
699 | dim=1,
700 | )
701 |
702 | attn_weights = torch.bmm(q, k.transpose(1, 2))
703 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
704 |
705 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
706 |
707 | if attn_mask is not None:
708 | attn_mask = attn_mask.unsqueeze(0)
709 | attn_weights += attn_mask
710 |
711 | if key_padding_mask is not None:
712 | # don't attend to padding symbols
713 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
714 | if not is_tpu:
715 | attn_weights = attn_weights.masked_fill(
716 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
717 | float("-inf"),
718 | )
719 | else:
720 | attn_weights = attn_weights.transpose(0, 2)
721 | attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
722 | attn_weights = attn_weights.transpose(0, 2)
723 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
724 |
725 | if before_softmax:
726 | return attn_weights, v, position_bias
727 |
728 | if position_bias is not None:
729 | if self.gru_rel_pos == 1:
730 | query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim)
731 | _B, _H, _L, __ = query_layer.size()
732 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view(
733 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1)
734 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0
735 | position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias
736 |
737 | position_bias = position_bias.view(attn_weights.size())
738 |
739 | attn_weights = attn_weights + position_bias
740 |
741 | attn_weights_float = F.softmax(
742 | attn_weights, dim=-1
743 | )
744 | attn_weights = attn_weights_float.type_as(attn_weights)
745 | attn_probs = self.dropout_module(attn_weights)
746 |
747 | assert v is not None
748 | attn = torch.bmm(attn_probs, v)
749 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
750 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
751 | attn = self.out_proj(attn)
752 | attn_weights: Optional[Tensor] = None
753 | if need_weights:
754 | attn_weights = attn_weights_float.view(
755 | bsz, self.num_heads, tgt_len, src_len
756 | ).transpose(1, 0)
757 | if not need_head_weights:
758 | # average attention weights over heads
759 | attn_weights = attn_weights.mean(dim=0)
760 |
761 | return attn, attn_weights, position_bias
762 |
763 | @staticmethod
764 | def _append_prev_key_padding_mask(
765 | key_padding_mask: Optional[Tensor],
766 | prev_key_padding_mask: Optional[Tensor],
767 | batch_size: int,
768 | src_len: int,
769 | static_kv: bool,
770 | ) -> Optional[Tensor]:
771 | # saved key padding masks have shape (bsz, seq_len)
772 | if prev_key_padding_mask is not None and static_kv:
773 | new_key_padding_mask = prev_key_padding_mask
774 | elif prev_key_padding_mask is not None and key_padding_mask is not None:
775 | new_key_padding_mask = torch.cat(
776 | [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
777 | )
778 | # During incremental decoding, as the padding token enters and
779 | # leaves the frame, there will be a time when prev or current
780 | # is None
781 | elif prev_key_padding_mask is not None:
782 | if src_len > prev_key_padding_mask.size(1):
783 | filler = torch.zeros(
784 | (batch_size, src_len - prev_key_padding_mask.size(1)),
785 | device=prev_key_padding_mask.device,
786 | )
787 | new_key_padding_mask = torch.cat(
788 | [prev_key_padding_mask.float(), filler.float()], dim=1
789 | )
790 | else:
791 | new_key_padding_mask = prev_key_padding_mask.float()
792 | elif key_padding_mask is not None:
793 | if src_len > key_padding_mask.size(1):
794 | filler = torch.zeros(
795 | (batch_size, src_len - key_padding_mask.size(1)),
796 | device=key_padding_mask.device,
797 | )
798 | new_key_padding_mask = torch.cat(
799 | [filler.float(), key_padding_mask.float()], dim=1
800 | )
801 | else:
802 | new_key_padding_mask = key_padding_mask.float()
803 | else:
804 | new_key_padding_mask = prev_key_padding_mask
805 | return new_key_padding_mask
806 |
807 | def _get_input_buffer(
808 | self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
809 | ) -> Dict[str, Optional[Tensor]]:
810 | result = self.get_incremental_state(incremental_state, "attn_state")
811 | if result is not None:
812 | return result
813 | else:
814 | empty_result: Dict[str, Optional[Tensor]] = {}
815 | return empty_result
816 |
817 | def _set_input_buffer(
818 | self,
819 | incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
820 | buffer: Dict[str, Optional[Tensor]],
821 | ):
822 | return self.set_incremental_state(incremental_state, "attn_state", buffer)
823 |
824 | def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
825 | return attn_weights
--------------------------------------------------------------------------------
/components/semantic_extractor/ssl_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import joblib
4 | from components.semantic_extractor.WavLM import WavLM, WavLMConfig
5 |
6 | class ApplyKmeans(nn.Module):
7 | def __init__(self, km_path, device='cuda'):
8 | super(ApplyKmeans, self).__init__()
9 | print(f'Init k-means model from {km_path}')
10 | self.km_model = joblib.load(km_path)
11 | self.C_np = self.km_model.cluster_centers_.transpose()
12 | self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)
13 | self.C = torch.from_numpy(self.C_np).to(device)
14 | self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device)
15 | self.emb = nn.Embedding(num_embeddings=300, embedding_dim=1024)
16 | self.emb.weight.data = self.C.transpose(0, 1)
17 | self.emb.weight.requires_grad = False  # freeze the centroid embedding table
18 |
19 | def forward(self, x, b, t):
20 | if not hasattr(self, 'C'):
21 | self.C = torch.from_numpy(self.C_np).to(x.device)
22 | if not hasattr(self, 'Cnorm'):
23 | self.Cnorm = torch.from_numpy(self.Cnorm_np).to(x.device)
24 | dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm  # squared Euclidean distance to each centroid
25 | tokens = dist.argmin(dim=-1).reshape(b, t)
26 | return tokens
27 |
28 | def get_ssl_model(ckpt_path, km_path, device='cuda', type='xlsr'):
29 | if type == 'xlsr':
30 | print(f'Init xlsr model from {ckpt_path}')
31 | import fairseq
32 | import argparse
33 | task_arg = argparse.Namespace(task='audio_pretraining')
34 | task = fairseq.tasks.setup_task(task_arg)
35 | model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], task=task)
36 | model = model[0]
37 | model.eval()
38 | elif type == 'wavlm':
39 | print(f'Init wavlm model from {ckpt_path}')
40 | cpt = torch.load(ckpt_path, map_location="cpu")
41 | cfg = WavLMConfig(cpt["cfg"])
42 | model = WavLM(cfg)
43 | model.load_state_dict(cpt["model"])
44 | model = model.eval()
45 | model = model.requires_grad_(False)
46 | else:
47 | raise NotImplementedError
48 | km_model = ApplyKmeans(km_path, device)
49 | return model, km_model
50 |
51 |
--------------------------------------------------------------------------------
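Note: the sketch below shows how `ApplyKmeans` turns frame-level SSL features into semantic token ids. The `ckpts/xlsr_km.mdl` path and the random features are placeholders (the configs point to that location), and the 1024-dimensional feature size is assumed to match the layer-6 XLSR/WavLM outputs used in `models/gense.py`.

```
import torch
from components.semantic_extractor.ssl_model import ApplyKmeans

# Load the k-means centroids (placeholder path; see configs/*.yaml for the expected location).
km = ApplyKmeans('ckpts/xlsr_km.mdl', device='cpu')

# Dummy frame-level SSL features of shape (B, T, D); real features come from layer 6 of XLSR/WavLM.
feats = torch.randn(1, 50, 1024).to(km.C.dtype)
b, t, d = feats.shape

# Each frame is assigned to its nearest centroid (argmin of the squared Euclidean distance).
tokens = km(feats.reshape(-1, d), b=b, t=t)
print(tokens.shape)  # (1, 50) integer cluster ids
```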
/components/simcodec/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/components/simcodec/__init__.py
--------------------------------------------------------------------------------
/components/simcodec/model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import torch.nn as nn
4 | from components.simcodec.modules import Encoder, Quantizer, Generator
5 |
6 | class AttrDict(dict):
7 | def __init__(self, *args, **kwargs):
8 | super(AttrDict, self).__init__(*args, **kwargs)
9 | self.__dict__ = self
10 |
11 | class SimCodec(nn.Module):
12 | def __init__(self, config_path):
13 | super(SimCodec, self).__init__()
14 | self.config_path = config_path
15 | with open(self.config_path) as f:
16 | data = f.read()
17 | json_config = json.loads(data)
18 | self.h = AttrDict(json_config)
19 | self.encoder = Encoder(self.h)
20 | self.quantizer = Quantizer(self.h)
21 | self.generator = Generator(self.h)
22 |
23 | def load_ckpt(self, ckpt_path):
24 | ckpt = torch.load(ckpt_path,map_location='cpu')
25 | self.encoder.load_state_dict(ckpt['encoder'])
26 | self.quantizer.load_state_dict(ckpt['quantizer'])
27 | self.generator.load_state_dict(ckpt['generator'])
28 |
29 | def forward(self, x):
30 | batch_size = x.size(0)
31 | if len(x.shape) == 3 and x.shape[-1] == 1:
32 | x = x.squeeze(-1)
33 | c = self.encoder(x)
34 | _, _, c = self.quantizer(c)
35 | c = [code.reshape(batch_size, -1) for code in c]
36 | return torch.stack(c, -1)
37 |
38 | def decode(self, x):
39 | return self.generator(self.quantizer.embed(x))
--------------------------------------------------------------------------------
/components/simcodec/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.utils import weight_norm, remove_weight_norm
5 | from torch.nn import Conv1d, ConvTranspose1d
6 |
7 | LRELU_SLOPE = 0.1
8 | alpha = 1.0
9 |
10 | def get_padding(kernel_size, dilation=1):
11 | return int((kernel_size*dilation - dilation)/2)
12 |
13 | def init_weights(m, mean=0.0, std=0.01):
14 | classname = m.__class__.__name__
15 | if classname.find("Conv") != -1:
16 | m.weight.data.normal_(mean, std)
17 |
18 | class ResBlock1(torch.nn.Module):
19 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
20 | super(ResBlock1, self).__init__()
21 | self.h = h
22 | self.convs1 = nn.ModuleList([
23 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
24 | padding=get_padding(kernel_size, dilation[0]))),
25 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
26 | padding=get_padding(kernel_size, dilation[1]))),
27 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
28 | padding=get_padding(kernel_size, dilation[2])))
29 | ])
30 | self.convs1.apply(init_weights)
31 |
32 | self.convs2 = nn.ModuleList([
33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
34 | padding=get_padding(kernel_size, 1))),
35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
36 | padding=get_padding(kernel_size, 1))),
37 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
38 | padding=get_padding(kernel_size, 1)))
39 | ])
40 | self.convs2.apply(init_weights)
41 | self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers
42 | self.activations = nn.ModuleList([nn.LeakyReLU(LRELU_SLOPE) for _ in range(self.num_layers)])
43 |
44 |
45 | def forward(self, x):
46 | acts1, acts2 = self.activations[::2], self.activations[1::2]
47 | for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
48 | xt = a1(x)
49 | xt = c1(xt)
50 | xt = a2(xt)
51 | xt = c2(xt)
52 | x = xt + x
53 | return x
54 |
55 | def remove_weight_norm(self):
56 | for l in self.convs1:
57 | remove_weight_norm(l)
58 | for l in self.convs2:
59 | remove_weight_norm(l)
60 |
61 |
62 | class Encoder(torch.nn.Module):
63 | def __init__(self, h):
64 | super(Encoder, self).__init__()
65 | self.n_filters = h.en_filters
66 | self.vq_dim = h.vq_dim
67 | self.num_kernels = len(h.resblock_kernel_sizes)
68 | self.num_upsamples = len(h.upsample_rates)
69 | self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples )
70 | self.conv_pre = weight_norm(Conv1d(h.channel, self.n_filters, 7, 1, padding=3))
71 | self.normalize = nn.ModuleList()
72 | resblock = ResBlock1
73 |
74 | self.ups = nn.ModuleList()
75 | for i, (u, k) in enumerate(list(reversed(list(zip(h.upsample_rates, h.upsample_kernel_sizes))))):
76 | self.ups.append(weight_norm(
77 | Conv1d(self.n_filters*(2**i), self.n_filters*(2**(i+1)),
78 | k, u,
79 | padding=((k-u)//2)
80 | )))
81 | self.resblocks = nn.ModuleList()
82 | ch = 1
83 | for i in range(len(self.ups)):
84 | ch = self.n_filters*(2**(i+1))
85 | for j, (k, d) in enumerate(
86 | zip(
87 | list(reversed(h.resblock_kernel_sizes)),
88 | list(reversed(h.resblock_dilation_sizes))
89 | )
90 | ):
91 | self.resblocks.append(resblock(h, ch, k, d))
92 | self.normalize.append(torch.nn.LayerNorm([ch],eps=1e-6,elementwise_affine=True))
93 |
94 | self.activation_post = nn.LeakyReLU(LRELU_SLOPE)
95 | self.conv_post = Conv1d(ch, self.vq_dim, 3, 1, padding=1)
96 | self.ups.apply(init_weights)
97 | self.conv_post.apply(init_weights)
98 |
99 | def forward(self, x):
100 | x = self.conv_pre(x)
101 | for i in range(self.num_upsamples):
102 | x = self.ups[i](x)
103 | xs = None
104 | for j in range(self.num_kernels):
105 | if xs is None:
106 | xs = self.resblocks[i*self.num_kernels+j](x)
107 | xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2)
108 | else:
109 | xs += self.resblocks[i*self.num_kernels+j](x)
110 | xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2)
111 | x = xs / self.num_kernels
112 | x = self.activation_post(x)
113 | x = self.conv_post(x)
114 | return x
115 |
116 | def remove_weight_norm(self):
117 | print('Removing weight norm...')
118 | for l in self.ups:
119 | remove_weight_norm(l)
120 | for l in self.resblocks:
121 | l.remove_weight_norm()
122 | remove_weight_norm(self.conv_pre)
123 |
124 | class Quantizer_module(torch.nn.Module):
125 | def __init__(self, n_e, e_dim):
126 | super(Quantizer_module, self).__init__()
127 | self.embedding = nn.Embedding(n_e, e_dim)
128 | self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e)
129 | self.target = torch.arange(0,n_e)
130 |
131 | def forward(self, x, idx=0):
132 | # squared Euclidean distance from each input vector to every codeword
133 | d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \
134 | - 2 * torch.matmul(x, self.embedding.weight.T)
135 | min_indicies = torch.argmin(d, 1)
136 | z_q = self.embedding(min_indicies)
137 | embed_vec = self.embedding.weight
138 | embed_dis = torch.mm(embed_vec, embed_vec.T) * 3  # scaled codeword similarity matrix
139 | self.target = torch.arange(0, embed_vec.shape[0]).to(x.device)
140 | loss = F.cross_entropy(embed_dis, self.target) * (idx == 0)  # codeword-diversity term, first quantizer only
141 | return z_q, min_indicies, loss
142 |
143 | class Quantizer(torch.nn.Module):
144 | def __init__(self, h):
145 | super(Quantizer, self).__init__()
146 | assert h.vq_dim % h.n_code_groups == 0
147 | self.lm_offset = 0
148 | self.lm_states = None
149 | self.vq_dim = h.vq_dim
150 | self.residul_layer = h.n_q
151 | self.n_code_groups = h.n_code_groups
152 | self.quantizer_modules = nn.ModuleList()
153 | for i in range(self.residul_layer):
154 | self.quantizer_modules.append(nn.ModuleList([
155 | Quantizer_module(h.n_codes, self.vq_dim // h.n_code_groups) for _ in range(h.n_code_groups)
156 | ]))
157 | self.h = h
158 | self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1
159 | self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25
160 |
161 |
162 | def for_one_step(self, xin, idx):
163 | xin = xin.transpose(1, 2)
164 | x = xin.reshape(-1, self.vq_dim)
165 | x = torch.split(x, self.vq_dim // self.h.n_code_groups, dim=-1)
166 | min_indicies = []
167 | z_q = []
168 | all_losses = []
169 | for _x, m in zip(x, self.quantizer_modules[idx]):
170 | _z_q, _min_indicies,_loss = m(_x,idx)
171 | all_losses.append(_loss)
172 | z_q.append(_z_q)
173 | min_indicies.append(_min_indicies)
174 | z_q = torch.cat(z_q, -1).reshape(xin.shape)
175 | z_q = z_q.transpose(1, 2)
176 | all_losses = torch.stack(all_losses)
177 | loss = torch.mean(all_losses)
178 | return z_q, min_indicies, loss
179 |
180 |
181 | def forward(self, xin,bw=-1,mask_id=None):
182 | quantized_out = 0.0
183 | residual = xin
184 | all_losses = []
185 | all_indices = []
186 | if bw<=0:
187 | bw = self.residul_layer
188 | for i in range(bw):
189 | quantized, indices, e_loss = self.for_one_step(residual, i) #
190 | if mask_id is not None:
191 | mask = (
192 | torch.full([xin.shape[0],xin.shape[2],1], fill_value=i, device=xin.device) < mask_id.unsqueeze(2) + 1
193 | )
194 | mask = mask.repeat(1,1,xin.shape[1]).transpose(1,2)
195 | if mask_id is not None:
196 | loss = 0.1 * e_loss + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 * mask) \
197 | + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 * mask )
198 | else:
199 | loss = 0.1 * e_loss \
200 | + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 ) \
201 | + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 )
202 |
203 | quantized = residual + (quantized - residual).detach()
204 | residual = residual - quantized
205 | if mask_id is not None:
206 | quantized_out = quantized_out + quantized * mask
207 | else:
208 | quantized_out = quantized_out + quantized
209 | all_indices.extend(indices) #
210 | all_losses.append(loss)
211 | all_losses = torch.stack(all_losses)
212 | loss = torch.mean(all_losses)
213 | return quantized_out, loss, all_indices
214 |
215 | def embed(self, x , bw=-1):
216 | quantized_out = torch.tensor(0.0, device=x.device)
217 | x = torch.split(x, 1, 2)
218 | if bw <= 0 or bw > self.residul_layer:
219 | bw = self.residul_layer
220 | for i in range(bw):
221 | ret = []
222 | for j in range(self.n_code_groups):
223 | q = x[j+self.n_code_groups*i]
224 | embed = self.quantizer_modules[i][j]
225 | q = embed.embedding(q.squeeze(-1))
226 | ret.append(q)
227 | ret = torch.cat(ret, -1)
228 | quantized_out = quantized_out + ret
229 | return quantized_out.transpose(1, 2)
230 |
231 |
232 | class Generator(torch.nn.Module):
233 | def __init__(self, h):
234 | super(Generator, self).__init__()
235 | self.h = h
236 | self.n_filters = h.de_filters
237 | self.vq_dim = h.vq_dim
238 | self.num_kernels = len(h.resblock_kernel_sizes)
239 | self.num_upsamples = len(h.upsample_rates)
240 | self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples )
241 | self.conv_pre = weight_norm(Conv1d(self.vq_dim, self.upsample_initial_channel, 7, 1, padding=3))
242 | resblock = ResBlock1
243 |
244 |
245 | self.norm = nn.Identity()
246 |
247 | self.ups = nn.ModuleList()
248 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
249 | self.ups.append(weight_norm(
250 | ConvTranspose1d(
251 | self.upsample_initial_channel//(2**i), self.upsample_initial_channel//(2**(i+1)),
252 | k, u,
253 | padding=(k - u )//2,
254 | )
255 | ))
256 | ch = 1
257 | self.resblocks = nn.ModuleList()
258 | for i in range(len(self.ups)):
259 | ch = self.upsample_initial_channel//(2**(i+1))
260 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
261 | self.resblocks.append(resblock(h, ch, k, d))
262 |
263 |
264 | self.activation_post = nn.LeakyReLU(LRELU_SLOPE)
265 | self.conv_post = weight_norm(Conv1d(ch, h.channel, 7, 1, padding=3))
266 | self.ups.apply(init_weights)
267 | self.conv_post.apply(init_weights)
268 |
269 | def forward(self, x):
270 | x = self.norm(x)
271 | x = self.conv_pre(x)
272 |
273 | for i in range(self.num_upsamples):
274 | x = self.ups[i](x)
275 | xs = None
276 | for j in range(self.num_kernels):
277 | if xs is None:
278 | xs = self.resblocks[i*self.num_kernels+j](x)
279 | else:
280 | xs += self.resblocks[i*self.num_kernels+j](x)
281 | x = xs / self.num_kernels
282 | x = self.activation_post(x)
283 | x = self.conv_post(x)
284 | x = torch.tanh(x)
285 |
286 | return x
287 |
288 | def remove_weight_norm(self):
289 | print('Removing weight norm...')
290 | for l in self.ups:
291 | remove_weight_norm(l)
292 | for l in self.resblocks:
293 | l.remove_weight_norm()
294 | remove_weight_norm(self.conv_pre)
295 | remove_weight_norm(self.conv_post)
--------------------------------------------------------------------------------
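Note: a shape-only round-trip sketch of the encoder → quantizer → generator path above. All hyperparameters are made up for illustration (the released codec reads its real values from `ckpts/config.json`, which is not included in the repo), so only the tensor shapes are meaningful, not the audio.

```
import torch
from components.simcodec.model import AttrDict
from components.simcodec.modules import Encoder, Quantizer, Generator

# Illustrative hyperparameters only; the released codec uses the values in ckpts/config.json.
h = AttrDict(
    channel=1, en_filters=16, de_filters=16, vq_dim=256,
    resblock_kernel_sizes=[3], resblock_dilation_sizes=[[1, 3, 5]],
    upsample_rates=[8, 5, 4, 2], upsample_kernel_sizes=[16, 11, 8, 4],
    n_q=1, n_code_groups=1, n_codes=8192,
    codebook_loss_lambda=1.0, commitment_loss_lambda=0.25,
)
enc, vq, gen = Encoder(h), Quantizer(h), Generator(h)

wav = torch.randn(1, 1, 16000)             # (B, channels, samples), 1 s at 16 kHz
latents = enc(wav)                         # (B, vq_dim, T), T ≈ samples / prod(upsample_rates)
quantized, vq_loss, codes = vq(latents)    # codes: one (B*T,) index tensor per group and layer
recon = gen(quantized)                     # (B, channels, ~samples)
print(latents.shape, quantized.shape, recon.shape)
```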
/configs/gense.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | n2s_ckpt_path: ckpts/n2s_xlsr.ckpt
3 | s2s_ckpt_path: ckpts/s2s_xlsr.ckpt
4 | codec_config_path: ckpts/config.json
5 |
6 | model:
7 | hidden_size: 1024
8 | # intermediate_size: 2048
9 | num_hidden_layers: 12
10 | num_attention_heads: 8
11 | n2s_vocab_size: 1027 # 1024 semantic tokens + pad/bos/eos
12 | s2s_vocab_size: 9219 # 8192 acoustic tokens + 1024 semantic tokens + pad/bos/eos
13 | semantic_num: 1024
14 |
15 | ssl_model:
16 | ckpt_path: ckpts/xlsr2_300m.pt
17 | km_path: ckpts/xlsr_km.mdl
18 | type: xlsr
--------------------------------------------------------------------------------
/configs/gense_wavlm.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | n2s_ckpt_path: ckpts/n2s_wavlm.ckpt
3 | s2s_ckpt_path: ckpts/s2s_wavlm.ckpt
4 | codec_config_path: ckpts/config.json
5 |
6 | model:
7 | hidden_size: 1024
8 | # intermediate_size: 2048
9 | num_hidden_layers: 12
10 | num_attention_heads: 8
11 | n2s_vocab_size: 1027 # 1024 semantic tokens + pad/bos/eos
12 | s2s_vocab_size: 9219 # 8192 acoustic tokens + 1024 semantic tokens + pad/bos/eos
13 | semantic_num: 1024
14 |
15 | ssl_model:
16 | ckpt_path: ckpts/WavLM-Large.pt
17 | km_path: ckpts/wavlm_km.mdl
18 | type: wavlm
--------------------------------------------------------------------------------
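Note: the vocabulary sizes in both configs follow the token layout used by `models/gense.py` (ids 0/1/2 are pad/bos/eos, semantic tokens are offset by 3, acoustic tokens by 3 + `semantic_num`). A quick sanity check of the arithmetic, assuming 1024 k-means clusters and an 8192-entry codec codebook as in the comments above:

```
# Token layout assumed by both configs (pad=0, bos=1, eos=2).
num_special = 3
semantic_num = 1024   # k-means clusters
acoustic_num = 8192   # SimCodec codebook entries

n2s_vocab_size = num_special + semantic_num                 # 1027
s2s_vocab_size = num_special + semantic_num + acoustic_num  # 9219
assert (n2s_vocab_size, s2s_vocab_size) == (1027, 9219)
```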
/fig/gense.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/fig/gense.png
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import fire
2 | import torch
3 | import torchaudio
4 | import yaml
5 |
6 | from models.gense import N2S, S2S
7 |
8 | class AttrDict(dict):
9 | def __init__(self, *args, **kwargs):
10 | super(AttrDict, self).__init__(*args, **kwargs)
11 | self.__dict__ = self
12 |
13 | def get_firstchannel_read(path, target_sr=16000):
14 | wav, sr = torchaudio.load(path)
15 | if wav.shape[0] > 1:
16 | wav = wav[0].unsqueeze(0)
17 | if sr != target_sr:
18 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
19 | wav = resampler(wav)
20 | return wav.unsqueeze(0)
21 |
22 |
23 | def run(noisy_path, out_path, config_path):
24 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
25 | with open(config_path, "r") as f:
26 | config = yaml.safe_load(f)
27 | config = AttrDict(config)
28 |
29 | noisy_wav = get_firstchannel_read(noisy_path).to(device)
30 |
31 | n2s_model = N2S(config)
32 | n2s_model.load_state_dict(torch.load(config.path['n2s_ckpt_path'], map_location=device)["state_dict"])
33 | n2s_model = n2s_model.eval()
34 | n2s_model = n2s_model.to(device)
35 |
36 | s2s_model = S2S(config)
37 | s2s_model.load_state_dict(torch.load(config.path['s2s_ckpt_path'], map_location=device)["state_dict"])
38 | s2s_model = s2s_model.eval()
39 | s2s_model = s2s_model.to(device)
40 |
41 | noisy_s, clean_s = n2s_model.generate(noisy_wav)
42 | enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s)
43 | torchaudio.save(out_path, enhanced_wav, sample_rate=16000)
44 |
45 |
46 | if __name__ == "__main__":
47 | fire.Fire(
48 | {
49 | "run": run,
50 | }
51 | )
--------------------------------------------------------------------------------
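Note: besides the `fire` CLI registered above, `run()` can be called directly from Python; a minimal sketch, assuming the pre-trained checkpoints have already been downloaded into `ckpts/` as the config expects:

```
from infer import run

# Enhance the bundled example file with the XLSR-based pipeline.
run(noisy_path='noisy.wav', out_path='enhanced.wav', config_path='configs/gense.yaml')
```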
/models/gense.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from components.semantic_extractor.ssl_model import get_ssl_model
8 | from components.simcodec.model import SimCodec
9 | from transformers import GPT2Config, GPT2LMHeadModel
10 |
11 | class N2S(nn.Module):
12 | def __init__(self, hps):
13 | super().__init__()
14 | self.hps = hps
15 | self.xlsr, self.km = get_ssl_model(**hps.ssl_model)
16 | self.bos = 1
17 | self.eos = 2
18 | self.pad = 0
19 | self.shift_num = 3
20 |
21 | self.lm_conf = GPT2Config(
22 | vocab_size=self.hps.model['n2s_vocab_size'],
23 | n_embd=self.hps.model['hidden_size'],
24 | n_layer=self.hps.model['num_hidden_layers'],
25 | n_head=self.hps.model['num_attention_heads'],
26 | activation_function='gelu_new',
27 | n_positions=2048,
28 | n_ctx=2048,
29 | resid_pdrop=0.1,
30 | embd_pdrop=0.1,
31 | attn_pdrop=0.1,
32 | layer_norm_epsilon=1e-05,
33 | initializer_range=0.02,
34 | summary_type='mean',
35 | summary_use_proj=True,
36 | summary_activation=None,
37 | summary_proj_to_labels=True,
38 | summary_first_dropout=0.1,
39 | bos_token_id=self.bos,
40 | eos_token_id=self.eos,
41 | )
42 | self.lm = GPT2LMHeadModel(self.lm_conf)
43 |
44 | def extract_semantic(self, wavs, num_frames):
45 | padding_size = (0, 100)
46 | wavs = F.pad(wavs, padding_size, "constant", 0)
47 | num_frames += 100
48 | features = self.xlsr.extract_features(wavs, padding_mask=None)
49 | layer_results = features['layer_results'][5]
50 | x, _, _ = layer_results
51 | features = x.transpose(0,1)
52 | b, t, d = features.shape
53 | tokens = self.km(features.reshape(-1, d), b=b, t=t)
54 | return tokens
55 |
56 | def inference(self, token_gen, pos_gen):
57 | predict_len = (token_gen.shape[1] - 1)
58 | truck_length = token_gen.shape[1]
59 |
60 | for j in tqdm(range(predict_len)):
61 | lm_outputs = self.lm(
62 | input_ids=token_gen,
63 | attention_mask=None,
64 | position_ids=pos_gen
65 | )
66 | logits = lm_outputs['logits']
67 | logits[:, :, 0:self.shift_num] = -1e5  # block pad/bos/eos from being sampled
68 | probs = logits[:, -1, :].softmax(dim=-1)
69 |
70 | dist = torch.distributions.categorical.Categorical(probs=probs)
71 |
72 | samples = dist.sample().unsqueeze(1).to(token_gen.device)
73 | token_gen = torch.cat([token_gen, samples], dim=1)
74 | pos_pad = torch.ones(pos_gen.shape[0]) * j
75 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
76 |
77 | return token_gen[:,truck_length:][0]
78 |
79 |
80 | def generate(self, mix):
81 | mix = mix.squeeze(1)
82 | num_frame = torch.LongTensor([mix.shape[1]]).to(mix.device)
83 | token_s = self.extract_semantic(mix, num_frames=num_frame)
84 |
85 | token_s += 3
86 | bos = torch.ones(token_s.shape[0],1).long().to(mix.device)
87 | token_gen = torch.cat([token_s, bos], dim=1)
88 |
89 | pos_gen_id = torch.from_numpy(np.asarray(list(range(token_s.shape[1] + 1)))).to(mix.device)
90 | pos_gen = []
91 | for i in range(token_s.shape[0]):
92 | pos_gen.append(pos_gen_id.unsqueeze(0))
93 | pos_gen = torch.cat(pos_gen, dim=0)
94 |
95 | clean_s = self.inference(token_gen, pos_gen) - self.shift_num
96 | token_s -= self.shift_num
97 | return token_s, clean_s
98 |
99 |
100 | class S2S(nn.Module):
101 | def __init__(self, hps):
102 | super().__init__()
103 | self.hps = hps
104 | self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
105 | self.xlsr, self.km = get_ssl_model(**hps.ssl_model)
106 | self.bos = 1
107 | self.eos = 2
108 | self.pad = 0
109 | self.shift_num = 3 + self.hps.model['semantic_num']
110 | self.lm_conf = GPT2Config(
111 | vocab_size=self.hps.model['s2s_vocab_size'],
112 | n_embd=self.hps.model['hidden_size'],
113 | n_layer=self.hps.model['num_hidden_layers'],
114 | n_head=self.hps.model['num_attention_heads'],
115 | activation_function='gelu_new',
116 | n_positions=4096,
117 | n_ctx=4096,
118 | resid_pdrop=0.1,
119 | embd_pdrop=0.1,
120 | attn_pdrop=0.1,
121 | layer_norm_epsilon=1e-05,
122 | initializer_range=0.02,
123 | summary_type='mean',
124 | summary_use_proj=True,
125 | summary_activation=None,
126 | summary_proj_to_labels=True,
127 | summary_first_dropout=0.1,
128 | bos_token_id=self.bos,
129 | eos_token_id=self.eos,
130 | )
131 | self.lm = GPT2LMHeadModel(self.lm_conf)
132 |
133 | def inference(self, token_gen, pos_gen):
134 | predict_len = int((token_gen.shape[1] - 1) / 2)
135 | truck_length = token_gen.shape[1]
136 | for j in tqdm(range(predict_len)):
137 | lm_outputs = self.lm(
138 | input_ids=token_gen,
139 | attention_mask=None,
140 | position_ids=pos_gen
141 | )
142 | logits = lm_outputs['logits']
143 | logits[:, :, 0:self.shift_num] = -1e5  # restrict sampling to acoustic codec ids
144 | probs = logits[:, -1, :].softmax(dim=-1)
145 | dist = torch.distributions.categorical.Categorical(probs=probs)
146 | samples = dist.sample().unsqueeze(1).to(token_gen.device)
147 | token_gen = torch.cat([token_gen, samples], dim=1)
148 | pos_pad = torch.ones(pos_gen.shape[0]) * (j + 1000)
149 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
150 |
151 | return token_gen[:,truck_length:][0]
152 |
153 | def generate(self, mix, mix_s, clean_s):
154 | mix_a = self.codec_tokenizer(mix).squeeze(-1)
155 | if len(clean_s.shape) == 1:
156 | clean_s = clean_s.unsqueeze(0)
157 |
158 | mix_s += 3
159 | clean_s += 3
160 | mix_a += self.shift_num
161 |
162 | bos = torch.ones(mix_s.shape[0],1).long().to(mix.device)
163 | token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)
164 |
165 | pos_gen_id = torch.from_numpy(np.asarray(list(range(mix_s.shape[1] + clean_s.shape[1] + 1)) + list(range(mix_a.shape[1])))).to(mix.device)
166 | pos_gen = []
167 | for i in range(mix_s.shape[0]):
168 | pos_gen.append(pos_gen_id.unsqueeze(0))
169 | pos_gen = torch.cat(pos_gen, dim=0)
170 |
171 | pre_a = self.inference(token_gen, pos_gen) - self.shift_num
172 | gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu()
173 |
174 | return gen_wav
--------------------------------------------------------------------------------
/models/gense_wavlm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | from components.semantic_extractor.ssl_model import get_ssl_model
8 | from components.simcodec.model import SimCodec
9 | from transformers import GPT2Config, GPT2LMHeadModel
10 |
11 | class N2S(nn.Module):
12 | def __init__(self, hps):
13 | super().__init__()
14 | self.hps = hps
15 | self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
16 | self.bos = 1
17 | self.eos = 2
18 | self.pad = 0
19 | self.shift_num = 3
20 |
21 | self.lm_conf = GPT2Config(
22 | vocab_size=self.hps.model['n2s_vocab_size'],
23 | n_embd=self.hps.model['hidden_size'],
24 | n_layer=self.hps.model['num_hidden_layers'],
25 | n_head=self.hps.model['num_attention_heads'],
26 | activation_function='gelu_new',
27 | n_positions=2048,
28 | n_ctx=2048,
29 | resid_pdrop=0.1,
30 | embd_pdrop=0.1,
31 | attn_pdrop=0.1,
32 | layer_norm_epsilon=1e-05,
33 | initializer_range=0.02,
34 | summary_type='mean',
35 | summary_use_proj=True,
36 | summary_activation=None,
37 | summary_proj_to_labels=True,
38 | summary_first_dropout=0.1,
39 | bos_token_id=self.bos,
40 | eos_token_id=self.eos,
41 | )
42 | self.lm = GPT2LMHeadModel(self.lm_conf)
43 |
44 | def extract_semantic(self, wavs, num_frames):
45 | padding_size = (0, 100)
46 | wavs = F.pad(wavs, padding_size, "constant", 0)
47 | num_frames += 100
48 | features = self.wavlm.extract_features(
49 | wavs,
50 | output_layer=6,
51 | ret_layer_results=False,
52 | input_length=num_frames
53 | )[0]
54 | b, t, d = features.shape
55 | tokens = self.km(features.reshape(-1, d), b=b, t=t)
56 | return tokens
57 |
58 | def inference(self, token_gen, pos_gen):
59 | predict_len = (token_gen.shape[1] - 1)
60 | truck_length = token_gen.shape[1]
61 |
62 | for j in tqdm(range(predict_len)):
63 | lm_outputs = self.lm(
64 | input_ids=token_gen,
65 | attention_mask=None,
66 | position_ids=pos_gen
67 | )
68 | logits = lm_outputs['logits']
69 | logits[:, :, 0:self.shift_num] = -1e5
70 | probs = logits[:, -1, :].softmax(dim=-1)
71 |
72 | dist = torch.distributions.categorical.Categorical(probs=probs)
73 |
74 | samples = dist.sample().unsqueeze(1).to(token_gen.device)
75 | token_gen = torch.cat([token_gen, samples], dim=1)
76 | pos_pad = torch.ones(pos_gen.shape[0]) * j
77 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
78 |
79 | return token_gen[:,truck_length:][0]
80 |
81 |
82 | def generate(self, mix):
83 | mix = mix.squeeze(1)
84 | num_frame = torch.LongTensor([mix.shape[1]]).to(mix.device)
85 | token_s = self.extract_semantic(mix, num_frames=num_frame)
86 |
87 | token_s += 3
88 | bos = torch.ones(token_s.shape[0],1).long().to(mix.device)
89 | token_gen = torch.cat([token_s, bos], dim=1)
90 |
91 | pos_gen_id = torch.from_numpy(np.asarray(list(range(token_s.shape[1] + 1)))).to(mix.device)
92 | pos_gen = []
93 | for i in range(token_s.shape[0]):
94 | pos_gen.append(pos_gen_id.unsqueeze(0))
95 | pos_gen = torch.cat(pos_gen, dim=0)
96 |
97 | clean_s = self.inference(token_gen, pos_gen) - self.shift_num
98 | token_s -= self.shift_num
99 | return token_s, clean_s
100 |
101 |
102 | class S2S(nn.Module):
103 | def __init__(self, hps):
104 | super().__init__()
105 | self.hps = hps
106 | self.codec_tokenizer = SimCodec(hps.path['codec_config_path'])
107 | self.wavlm, self.km = get_ssl_model(**hps.ssl_model)
108 | self.bos = 1
109 | self.eos = 2
110 | self.pad = 0
111 | self.shift_num = 3 + self.hps.model['semantic_num']
112 | self.lm_conf = GPT2Config(
113 | vocab_size=self.hps.model['s2s_vocab_size'],
114 | n_embd=self.hps.model['hidden_size'],
115 | n_layer=self.hps.model['num_hidden_layers'],
116 | n_head=self.hps.model['num_attention_heads'],
117 | activation_function='gelu_new',
118 | n_positions=4096,
119 | n_ctx=4096,
120 | resid_pdrop=0.1,
121 | embd_pdrop=0.1,
122 | attn_pdrop=0.1,
123 | layer_norm_epsilon=1e-05,
124 | initializer_range=0.02,
125 | summary_type='mean',
126 | summary_use_proj=True,
127 | summary_activation=None,
128 | summary_proj_to_labels=True,
129 | summary_first_dropout=0.1,
130 | bos_token_id=self.bos,
131 | eos_token_id=self.eos,
132 | )
133 | self.lm = GPT2LMHeadModel(self.lm_conf)
134 |
135 | def inference(self, token_gen, pos_gen):
136 | predict_len = int((token_gen.shape[1] - 1) / 2)
137 | truck_length = token_gen.shape[1]
138 | for j in tqdm(range(predict_len)):
139 | lm_outputs = self.lm(
140 | input_ids=token_gen,
141 | attention_mask=None,
142 | position_ids=pos_gen
143 | )
144 | logits = lm_outputs['logits']
145 | logits[:, :, 0:self.shift_num] = -1e5
146 | probs = logits[:, -1, :].softmax(dim=-1)
147 | dist = torch.distributions.categorical.Categorical(probs=probs)
148 | samples = dist.sample().unsqueeze(1).to(token_gen.device)
149 | token_gen = torch.cat([token_gen, samples], dim=1)
150 | pos_pad = torch.ones(pos_gen.shape[0]) * (j + 1000)
151 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1)
152 |
153 | return token_gen[:,truck_length:][0]
154 |
155 | def generate(self, mix, mix_s, clean_s):
156 | mix_a = self.codec_tokenizer(mix).squeeze(-1)
157 | if len(clean_s.shape) == 1:
158 | clean_s = clean_s.unsqueeze(0)
159 |
160 | mix_s += 3
161 | clean_s += 3
162 | mix_a += self.shift_num
163 |
164 | bos = torch.ones(mix_s.shape[0],1).long().to(mix.device)
165 | token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)
166 |
167 | pos_gen_id = torch.from_numpy(np.asarray(list(range(mix_s.shape[1] + clean_s.shape[1] + 1)) + list(range(mix_a.shape[1])))).to(mix.device)
168 | pos_gen = []
169 | for i in range(mix_s.shape[0]):
170 | pos_gen.append(pos_gen_id.unsqueeze(0))
171 | pos_gen = torch.cat(pos_gen, dim=0)
172 |
173 | pre_a = self.inference(token_gen, pos_gen) - self.shift_num
174 | gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu()
175 |
176 | return gen_wav
--------------------------------------------------------------------------------
/noisy.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/noisy.wav
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.22.3
2 | torch==1.13.1
3 | torchaudio==0.13.1
4 | fire
5 | PyYAML==6.0.2
6 | joblib==1.4.0
7 | scikit-learn==1.3.2
8 | tqdm
9 | librosa==0.8.0
10 | transformers==4.40.1
11 | fairseq==0.12.2
--------------------------------------------------------------------------------