├── .gitignore
├── README.md
├── components
│   ├── semantic_extractor
│   │   ├── WavLM.py
│   │   ├── modules.py
│   │   └── ssl_model.py
│   └── simcodec
│       ├── __init__.py
│       ├── model.py
│       └── modules.py
├── configs
│   ├── gense.yaml
│   └── gense_wavlm.yaml
├── fig
│   └── gense.png
├── infer.py
├── models
│   ├── gense.py
│   └── gense_wavlm.py
├── noisy.wav
└── requirements.txt

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
**/__pycache__

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# GenSE: Generative Speech Enhancement via Language Models using Hierarchical Modeling
The official implementation of GenSE (ICLR 2025).

We propose GenSE, a comprehensive framework tailored for language-model-based speech enhancement. Speech enhancement is regarded as a conditional language modeling task rather than the continuous signal regression problem formulated in existing works. This is achieved by tokenizing speech signals into semantic tokens with a pre-trained self-supervised model and into acoustic tokens with a custom-designed single-quantizer neural codec model.

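To make the semantic tokens concrete: frame-level features from the SSL model (XLSR or WavLM) are assigned to their nearest k-means centroid, and the centroid indices serve as the discrete semantic tokens, mirroring `ApplyKmeans` in `components/semantic_extractor/ssl_model.py`. The snippet below is a minimal sketch only; the layer index and the k-means path are illustrative placeholders, not the repository's exact settings.

```
import joblib
import torch

@torch.no_grad()
def wav_to_semantic_tokens(ssl_model, km_path, wav, layer=6):
    """Quantize SSL features into semantic token ids (illustrative layer/path)."""
    km = joblib.load(km_path)                                   # sklearn KMeans checkpoint
    centers = torch.from_numpy(km.cluster_centers_.T).to(wav)   # [D, K] centroids
    cnorm = (centers ** 2).sum(0, keepdim=True)                 # [1, K]

    feats, _ = ssl_model.extract_features(wav, output_layer=layer)  # [B, T, D]
    b, t, d = feats.shape
    x = feats.reshape(-1, d)

    # squared Euclidean distance to every centroid; argmin gives the token id
    dist = x.pow(2).sum(1, keepdim=True) - 2 * x @ centers + cnorm
    return dist.argmin(dim=-1).reshape(b, t)                    # [B, T] long
```

A call might look like `tokens = wav_to_semantic_tokens(wavlm, 'ckpts/km.bin', wav)`, where `ckpts/km.bin` stands in for whichever k-means checkpoint you downloaded.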

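Putting the pieces together, here is a minimal sketch of the two-stage flow described in the next paragraph. The `semantic_tokenizer`, `n2s_lm`, and `s2s_lm` callables are hypothetical stand-ins for the components under `models/` and `components/`; the actual entry point is `infer.py` with the configs in `configs/`.

```
import torch

@torch.no_grad()
def enhance(noisy_wav, semantic_tokenizer, codec, n2s_lm, s2s_lm):
    """Conceptual sketch of GenSE's token-based enhancement (not the exact API)."""
    # Tokenize the noisy input: semantic tokens (SSL + k-means) and
    # acoustic tokens (single-quantizer SimCodec).
    noisy_sem = semantic_tokenizer(noisy_wav)   # [B, T_sem] token ids
    noisy_ac = codec(noisy_wav)                 # [B, T_ac, 1] token ids

    # Stage 1 (N2S front-end): a conditional LM maps noisy semantic tokens
    # to clean semantic tokens.
    clean_sem = n2s_lm.generate(noisy_sem)

    # Stage 2 (S2S back-end): a second LM generates clean acoustic tokens,
    # conditioned on the clean semantic tokens and the noisy acoustic tokens.
    clean_ac = s2s_lm.generate(clean_sem, noisy_ac)

    # Decode the acoustic tokens back to a waveform with the codec decoder.
    return codec.decode(clean_ac)               # [B, 1, T_out]
```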

GenSE employs a hierarchical modeling framework with a two-stage process: an N2S transformation front-end, which converts noisy speech into clean semantic tokens, and an S2S generation back-end, which synthesizes clean speech using both the clean semantic tokens and the noisy acoustic tokens.

## TODO 📝
- [x] Release Inference pipeline
- [x] Release pre-trained model
- [ ] Support Colab
- [ ] More to be added

## Getting Started 📥

### 1. Pre-requisites
1. PyTorch >= 1.13 and torchaudio >= 0.13
2. Install requirements
```
conda create -n gense python=3.8
pip install -r requirements.txt
```

### 2. Get Self-supervised Model:
Download the [XLSR model](https://huggingface.co/facebook/wav2vec2-xls-r-300m) and move it to the ckpts dir,
or download [WavLM Large](https://huggingface.co/microsoft/wavlm-large) to run the WavLM variant (configs/gense_wavlm.yaml).

### 3. Pre-trained Model:
Download the pre-trained models from [huggingface](https://huggingface.co/yaoxunji/gen-se/tree/main); all checkpoints should be stored in the ckpts dir.

### 4. Speech Enhancement:
```
python infer.py run \
    --noisy_path noisy.wav \
    --out_path ./enhanced.wav \
    --config_path configs/gense.yaml
```

### 5. SimCodec Copy-syn:
```
import torchaudio
from components.simcodec.model import SimCodec

codec = SimCodec('config.json')
codec.load_ckpt('g_00100000')
codec = codec.eval()
codec = codec.to('cuda')

wav, sr = torchaudio.load('noisy.wav')  # [1, T] waveform; SimCodec expects 16 kHz audio
wav = wav.to('cuda')
code = codec(wav)
print(code.shape) # [B, L1, 1]
syn = codec.decode(code)
print(syn.shape) # [B, 1, L2]
torchaudio.save('copy.wav', syn.detach().cpu().squeeze(0), 16000)
```

--------------------------------------------------------------------------------
/components/semantic_extractor/WavLM.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import logging 12 | from typing import List, Optional, Tuple 13 | 14 | import sys,os 15 | sys.path.append(os.path.dirname(sys.path[0])) 16 | import numpy as np 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import LayerNorm 22 | from .modules import ( 23 | Fp32GroupNorm, 24 | Fp32LayerNorm, 25 | GradMultiply, 26 | MultiheadAttention, 27 | SamePad, 28 | init_bert_params, 29 | get_activation_fn, 30 | TransposeLast, 31 | GLU_Linear, 32 | ) 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | 37 | def compute_mask_indices( 38 | shape: Tuple[int, int], 39 | padding_mask: Optional[torch.Tensor], 40 | mask_prob: float, 41 | mask_length: int, 42 | mask_type: str = "static", 43 | mask_other: float = 0.0, 44 | min_masks: int = 0, 45 | no_overlap: bool = False, 46 | min_space: int = 0, 47 | ) -> np.ndarray: 48 | """ 49 | Computes random mask spans for a given shape 50 | 51 | Args: 52 | shape: the shape for which to compute masks. 
53 | should be of size 2 where first element is batch size and 2nd is timesteps 54 | padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements 55 | mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by 56 | number of timesteps divided by length of mask span to mask approximately this percentage of all elements. 57 | however due to overlaps, the actual number will be smaller (unless no_overlap is True) 58 | mask_type: how to compute mask lengths 59 | static = fixed size 60 | uniform = sample from uniform distribution [mask_other, mask_length*2] 61 | normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element 62 | poisson = sample from possion distribution with lambda = mask length 63 | min_masks: minimum number of masked spans 64 | no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping 65 | min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans 66 | """ 67 | 68 | bsz, all_sz = shape 69 | mask = np.full((bsz, all_sz), False) 70 | 71 | all_num_mask = int( 72 | # add a random number for probabilistic rounding 73 | mask_prob * all_sz / float(mask_length) 74 | + np.random.rand() 75 | ) 76 | 77 | all_num_mask = max(min_masks, all_num_mask) 78 | 79 | mask_idcs = [] 80 | for i in range(bsz): 81 | if padding_mask is not None: 82 | sz = all_sz - padding_mask[i].long().sum().item() 83 | num_mask = int( 84 | # add a random number for probabilistic rounding 85 | mask_prob * sz / float(mask_length) 86 | + np.random.rand() 87 | ) 88 | num_mask = max(min_masks, num_mask) 89 | else: 90 | sz = all_sz 91 | num_mask = all_num_mask 92 | 93 | if mask_type == "static": 94 | lengths = np.full(num_mask, mask_length) 95 | elif mask_type == "uniform": 96 | lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) 97 | elif mask_type == "normal": 98 | lengths = np.random.normal(mask_length, mask_other, size=num_mask) 99 | lengths = [max(1, int(round(x))) for x in lengths] 100 | elif mask_type == "poisson": 101 | lengths = np.random.poisson(mask_length, size=num_mask) 102 | lengths = [int(round(x)) for x in lengths] 103 | else: 104 | raise Exception("unknown mask selection " + mask_type) 105 | 106 | if sum(lengths) == 0: 107 | lengths[0] = min(mask_length, sz - 1) 108 | 109 | if no_overlap: 110 | mask_idc = [] 111 | 112 | def arrange(s, e, length, keep_length): 113 | span_start = np.random.randint(s, e - length) 114 | mask_idc.extend(span_start + i for i in range(length)) 115 | 116 | new_parts = [] 117 | if span_start - s - min_space >= keep_length: 118 | new_parts.append((s, span_start - min_space + 1)) 119 | if e - span_start - keep_length - min_space > keep_length: 120 | new_parts.append((span_start + length + min_space, e)) 121 | return new_parts 122 | 123 | parts = [(0, sz)] 124 | min_length = min(lengths) 125 | for length in sorted(lengths, reverse=True): 126 | lens = np.fromiter( 127 | (e - s if e - s >= length + min_space else 0 for s, e in parts), 128 | np.int, 129 | ) 130 | l_sum = np.sum(lens) 131 | if l_sum == 0: 132 | break 133 | probs = lens / np.sum(lens) 134 | c = np.random.choice(len(parts), p=probs) 135 | s, e = parts.pop(c) 136 | parts.extend(arrange(s, e, length, min_length)) 137 | mask_idc = np.asarray(mask_idc) 138 | else: 139 | min_len = min(lengths) 140 | if sz - min_len <= num_mask: 141 | min_len = sz - num_mask - 1 142 | 143 | 
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) 144 | 145 | mask_idc = np.asarray( 146 | [ 147 | mask_idc[j] + offset 148 | for j in range(len(mask_idc)) 149 | for offset in range(lengths[j]) 150 | ] 151 | ) 152 | 153 | mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) 154 | 155 | min_len = min([len(m) for m in mask_idcs]) 156 | for i, mask_idc in enumerate(mask_idcs): 157 | if len(mask_idc) > min_len: 158 | mask_idc = np.random.choice(mask_idc, min_len, replace=False) 159 | mask[i, mask_idc] = True 160 | 161 | return mask 162 | 163 | 164 | class WavLMConfig: 165 | def __init__(self, cfg=None): 166 | self.extractor_mode: str = "default" # mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with normalize=True) 167 | self.encoder_layers: int = 12 # num encoder layers in the transformer 168 | 169 | self.encoder_embed_dim: int = 768 # encoder embedding dimension 170 | self.encoder_ffn_embed_dim: int = 3072 # encoder embedding dimension for FFN 171 | self.encoder_attention_heads: int = 12 # num encoder attention heads 172 | self.activation_fn: str = "gelu" # activation function to use 173 | 174 | self.layer_norm_first: bool = False # apply layernorm first in the transformer 175 | self.conv_feature_layers: str = "[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2" # string describing convolutional feature extraction layers in form of a python list that contains [(dim, kernel_size, stride), ...] 176 | self.conv_bias: bool = False # include bias in conv encoder 177 | self.feature_grad_mult: float = 1.0 # multiply feature extractor var grads by this 178 | 179 | self.normalize: bool = False # normalize input to have 0 mean and unit variance during training 180 | 181 | # dropouts 182 | self.dropout: float = 0.1 # dropout probability for the transformer 183 | self.attention_dropout: float = 0.1 # dropout probability for attention weights 184 | self.activation_dropout: float = 0.0 # dropout probability after activation in FFN 185 | self.encoder_layerdrop: float = 0.0 # probability of dropping a tarnsformer layer 186 | self.dropout_input: float = 0.0 # dropout to apply to the input (after feat extr) 187 | self.dropout_features: float = 0.0 # dropout to apply to the features (after feat extr) 188 | 189 | # masking 190 | self.mask_length: int = 10 # mask length 191 | self.mask_prob: float = 0.65 # probability of replacing a token with mask 192 | self.mask_selection: str = "static" # how to choose mask length 193 | self.mask_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indicesh 194 | self.no_mask_overlap: bool = False # whether to allow masks to overlap 195 | self.mask_min_space: int = 1 # min space between spans (if no overlap is enabled) 196 | 197 | # channel masking 198 | self.mask_channel_length: int = 10 # length of the mask for features (channels) 199 | self.mask_channel_prob: float = 0.0 # probability of replacing a feature with 0 200 | self.mask_channel_selection: str = "static" # how to choose mask length for channel masking 201 | self.mask_channel_other: float = 0 # secondary mask argument (used for more complex distributions), see help in compute_mask_indices 202 | self.no_mask_channel_overlap: bool = False # whether to allow channel masks to overlap 203 | self.mask_channel_min_space: int = 1 # min space between spans (if no overlap is enabled) 204 | 205 | # positional embeddings 206 | self.conv_pos: int 
= 128 # number of filters for convolutional positional embeddings 207 | self.conv_pos_groups: int = 16 # number of groups for convolutional positional embedding 208 | 209 | # relative position embedding 210 | self.relative_position_embedding: bool = False # apply relative position embedding 211 | self.num_buckets: int = 320 # number of buckets for relative position embedding 212 | self.max_distance: int = 1280 # maximum distance for relative position embedding 213 | self.gru_rel_pos: bool = False # apply gated relative position embedding 214 | 215 | if cfg is not None: 216 | self.update(cfg) 217 | 218 | def update(self, cfg: dict): 219 | self.__dict__.update(cfg) 220 | 221 | 222 | class WavLM(nn.Module): 223 | def __init__( 224 | self, 225 | cfg: WavLMConfig, 226 | ) -> None: 227 | super().__init__() 228 | logger.info(f"WavLM Config: {cfg.__dict__}") 229 | 230 | self.cfg = cfg 231 | feature_enc_layers = eval(cfg.conv_feature_layers) 232 | self.embed = feature_enc_layers[-1][0] 233 | 234 | self.feature_extractor = ConvFeatureExtractionModel( 235 | conv_layers=feature_enc_layers, 236 | dropout=0.0, 237 | mode=cfg.extractor_mode, 238 | conv_bias=cfg.conv_bias, 239 | ) 240 | 241 | self.post_extract_proj = ( 242 | nn.Linear(self.embed, cfg.encoder_embed_dim) 243 | if self.embed != cfg.encoder_embed_dim 244 | else None 245 | ) 246 | 247 | self.mask_prob = cfg.mask_prob 248 | self.mask_selection = cfg.mask_selection 249 | self.mask_other = cfg.mask_other 250 | self.mask_length = cfg.mask_length 251 | self.no_mask_overlap = cfg.no_mask_overlap 252 | self.mask_min_space = cfg.mask_min_space 253 | 254 | self.mask_channel_prob = cfg.mask_channel_prob 255 | self.mask_channel_selection = cfg.mask_channel_selection 256 | self.mask_channel_other = cfg.mask_channel_other 257 | self.mask_channel_length = cfg.mask_channel_length 258 | self.no_mask_channel_overlap = cfg.no_mask_channel_overlap 259 | self.mask_channel_min_space = cfg.mask_channel_min_space 260 | 261 | self.dropout_input = nn.Dropout(cfg.dropout_input) 262 | self.dropout_features = nn.Dropout(cfg.dropout_features) 263 | 264 | self.feature_grad_mult = cfg.feature_grad_mult 265 | 266 | self.mask_emb = nn.Parameter( 267 | torch.FloatTensor(cfg.encoder_embed_dim).uniform_() 268 | ) 269 | 270 | self.encoder = TransformerEncoder(cfg) 271 | self.layer_norm = LayerNorm(self.embed) 272 | 273 | def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): 274 | """ 275 | Computes the output length of the convolutional layers 276 | """ 277 | 278 | def _conv_out_length(input_length, kernel_size, stride): 279 | return torch.floor((input_length - kernel_size) / stride + 1) 280 | 281 | conv_cfg_list = eval(self.cfg.conv_feature_layers) 282 | 283 | out_lengths_list = [] 284 | for i in range(len(conv_cfg_list)): 285 | input_lengths = _conv_out_length( 286 | input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2] 287 | ) 288 | out_lengths_list.append(input_lengths) 289 | 290 | return input_lengths.to(torch.long), out_lengths_list 291 | 292 | def apply_mask(self, x, padding_mask): 293 | B, T, C = x.shape 294 | if self.mask_prob > 0: 295 | mask_indices = compute_mask_indices( 296 | (B, T), 297 | padding_mask, 298 | self.mask_prob, 299 | self.mask_length, 300 | self.mask_selection, 301 | self.mask_other, 302 | min_masks=2, 303 | no_overlap=self.no_mask_overlap, 304 | min_space=self.mask_min_space, 305 | ) 306 | mask_indices = torch.from_numpy(mask_indices).to(x.device) 307 | x[mask_indices] = self.mask_emb 308 | else: 309 | mask_indices = None 310 | 
311 | if self.mask_channel_prob > 0: 312 | mask_channel_indices = compute_mask_indices( 313 | (B, C), 314 | None, 315 | self.mask_channel_prob, 316 | self.mask_channel_length, 317 | self.mask_channel_selection, 318 | self.mask_channel_other, 319 | no_overlap=self.no_mask_channel_overlap, 320 | min_space=self.mask_channel_min_space, 321 | ) 322 | mask_channel_indices = ( 323 | torch.from_numpy(mask_channel_indices) 324 | .to(x.device) 325 | .unsqueeze(1) 326 | .expand(-1, T, -1) 327 | ) 328 | x[mask_channel_indices] = 0 329 | 330 | return x, mask_indices 331 | 332 | def forward_padding_mask( 333 | self, features: torch.Tensor, padding_mask: torch.Tensor, 334 | ) -> torch.Tensor: 335 | extra = padding_mask.size(1) % features.size(1) 336 | if extra > 0: 337 | padding_mask = padding_mask[:, :-extra] 338 | padding_mask = padding_mask.view( 339 | padding_mask.size(0), features.size(1), -1 340 | ) 341 | padding_mask = padding_mask.all(-1) 342 | return padding_mask 343 | 344 | def sequence_mask(self, sequence_length, max_len=None): 345 | """Create a sequence mask for filtering padding in a sequence tensor. 346 | Args: 347 | sequence_length (torch.tensor): Sequence lengths. 348 | max_len (int, Optional): Maximum sequence length. Defaults to None. 349 | Shapes: 350 | - mask: :math:`[B, T_max]` 351 | """ 352 | if max_len is None: 353 | max_len = sequence_length.data.max() 354 | seq_range = torch.arange(max_len, 355 | dtype=sequence_length.dtype, 356 | device=sequence_length.device) 357 | # B x T_max 358 | mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) 359 | return mask 360 | 361 | def extract_features( 362 | self, 363 | source: torch.Tensor, 364 | padding_mask: Optional[torch.Tensor] = None, 365 | mask: bool = False, 366 | ret_conv: bool = False, 367 | output_layer: Optional[int] = None, 368 | ret_layer_results: bool = False, 369 | input_length: Optional[torch.Tensor] = None 370 | ): 371 | out_lengths_list = None 372 | if input_length is not None: 373 | out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(input_length) 374 | else: 375 | out_conv_lengths, out_lengths_list = self._get_feat_extract_output_lengths(torch.tensor([source.shape[-1] for _ in range(source.shape[0])]).to(source.device)) 376 | 377 | if self.feature_grad_mult > 0: 378 | features = self.feature_extractor(source, input_lengths=input_length, out_lengths_list=out_lengths_list) 379 | if self.feature_grad_mult != 1.0: 380 | features = GradMultiply.apply(features, self.feature_grad_mult) 381 | else: 382 | with torch.no_grad(): 383 | features = self.feature_extractor(source) 384 | 385 | features = features.transpose(1, 2) 386 | features = self.layer_norm(features) 387 | 388 | # if padding_mask is not None: 389 | # padding_mask = self.forward_padding_mask(features, padding_mask) 390 | 391 | if self.post_extract_proj is not None: 392 | features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1) 393 | features = self.post_extract_proj(features) 394 | features *= self.sequence_mask(out_conv_lengths).unsqueeze(-1) 395 | 396 | 397 | features = self.dropout_input(features) 398 | # return features 399 | 400 | if mask: 401 | x, mask_indices = self.apply_mask( 402 | features, padding_mask 403 | ) 404 | else: 405 | x = features 406 | 407 | # feature: (B, T, D), float 408 | # target: (B, T), long 409 | # x: (B, T, D), float 410 | # padding_mask: (B, T), bool 411 | # mask_indices: (B, T), bool 412 | if source.shape[0] == 1: 413 | padding_mask = None 414 | else: 415 | padding_mask = 
~self.sequence_mask(out_conv_lengths) 416 | 417 | x, layer_results = self.encoder( 418 | x, 419 | padding_mask=padding_mask, 420 | layer=None if output_layer is None else output_layer - 1 421 | ) 422 | 423 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 424 | 425 | feature = res["features"] if ret_conv else res["x"] 426 | if ret_layer_results: 427 | feature = (feature, res["layer_results"]) 428 | return feature, res["padding_mask"] 429 | 430 | 431 | def long_term_modeling( 432 | self, 433 | source: torch.Tensor, 434 | padding_mask: Optional[torch.Tensor] = None, 435 | mask: bool = False, 436 | ret_conv: bool = False, 437 | output_layer: Optional[int] = None, 438 | ret_layer_results: bool = False, 439 | ): 440 | 441 | features = source.transpose(1, 2) 442 | features = self.layer_norm(features) 443 | 444 | if padding_mask is not None: 445 | padding_mask = self.forward_padding_mask(features, padding_mask) 446 | 447 | if self.post_extract_proj is not None: 448 | features = self.post_extract_proj(features) 449 | 450 | features = self.dropout_input(features) 451 | 452 | if mask: 453 | x, mask_indices = self.apply_mask( 454 | features, padding_mask 455 | ) 456 | else: 457 | x = features 458 | 459 | # feature: (B, T, D), float 460 | # target: (B, T), long 461 | # x: (B, T, D), float 462 | # padding_mask: (B, T), bool 463 | # mask_indices: (B, T), bool 464 | x, layer_results = self.encoder( 465 | x, 466 | padding_mask=padding_mask, 467 | layer=None if output_layer is None else output_layer - 1 468 | ) 469 | 470 | res = {"x": x, "padding_mask": padding_mask, "features": features, "layer_results": layer_results} 471 | 472 | feature = res["features"] if ret_conv else res["x"] 473 | if ret_layer_results: 474 | feature = (feature, res["layer_results"]) 475 | return feature, res["padding_mask"] 476 | 477 | 478 | 479 | class ConvFeatureExtractionModel(nn.Module): 480 | def __init__( 481 | self, 482 | conv_layers: List[Tuple[int, int, int]], 483 | dropout: float = 0.0, 484 | mode: str = "default", 485 | conv_bias: bool = False, 486 | conv_type: str = "default" 487 | ): 488 | super().__init__() 489 | 490 | assert mode in {"default", "layer_norm"} 491 | 492 | def block( 493 | n_in, 494 | n_out, 495 | k, 496 | stride, 497 | is_layer_norm=False, 498 | is_group_norm=False, 499 | conv_bias=False, 500 | ): 501 | def make_conv(): 502 | conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias) 503 | nn.init.kaiming_normal_(conv.weight) 504 | return conv 505 | 506 | assert ( 507 | is_layer_norm and is_group_norm 508 | ) == False, "layer norm and group norm are exclusive" 509 | 510 | if is_layer_norm: 511 | return nn.Sequential( 512 | make_conv(), 513 | nn.Dropout(p=dropout), 514 | nn.Sequential( 515 | TransposeLast(), 516 | Fp32LayerNorm(dim, elementwise_affine=True), 517 | TransposeLast(), 518 | ), 519 | nn.GELU(), 520 | ) 521 | # elif is_group_norm: 522 | # return nn.Sequential( 523 | # make_conv(), 524 | # nn.Dropout(p=dropout), 525 | # Fp32GroupNorm(dim, dim, affine=True), 526 | # nn.GELU(), 527 | # ) 528 | # else: 529 | # return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU()) 530 | 531 | self.conv_type = conv_type 532 | if self.conv_type == "default": 533 | in_d = 1 534 | self.conv_layers = nn.ModuleList() 535 | for i, cl in enumerate(conv_layers): 536 | assert len(cl) == 3, "invalid conv definition: " + str(cl) 537 | (dim, k, stride) = cl 538 | 539 | self.conv_layers.append( 540 | block( 541 | in_d, 542 | dim, 543 | k, 544 | stride, 545 | 
is_layer_norm=mode == "layer_norm", 546 | is_group_norm=mode == "default" and i == 0, 547 | conv_bias=conv_bias, 548 | ) 549 | ) 550 | in_d = dim 551 | elif self.conv_type == "conv2d": 552 | in_d = 1 553 | self.conv_layers = nn.ModuleList() 554 | for i, cl in enumerate(conv_layers): 555 | assert len(cl) == 3 556 | (dim, k, stride) = cl 557 | 558 | self.conv_layers.append( 559 | torch.nn.Conv2d(in_d, dim, k, stride) 560 | ) 561 | self.conv_layers.append(torch.nn.ReLU()) 562 | in_d = dim 563 | elif self.conv_type == "custom": 564 | in_d = 1 565 | idim = 80 566 | self.conv_layers = nn.ModuleList() 567 | for i, cl in enumerate(conv_layers): 568 | assert len(cl) == 3 569 | (dim, k, stride) = cl 570 | self.conv_layers.append( 571 | torch.nn.Conv2d(in_d, dim, k, stride, padding=1) 572 | ) 573 | self.conv_layers.append( 574 | torch.nn.LayerNorm([dim, idim]) 575 | ) 576 | self.conv_layers.append(torch.nn.ReLU()) 577 | in_d = dim 578 | if (i + 1) % 2 == 0: 579 | self.conv_layers.append( 580 | torch.nn.MaxPool2d(2, stride=2, ceil_mode=True) 581 | ) 582 | idim = int(math.ceil(idim / 2)) 583 | else: 584 | pass 585 | 586 | def sequence_mask(self, sequence_length, max_len=None): 587 | """Create a sequence mask for filtering padding in a sequence tensor. 588 | Args: 589 | sequence_length (torch.tensor): Sequence lengths. 590 | max_len (int, Optional): Maximum sequence length. Defaults to None. 591 | Shapes: 592 | - mask: :math:`[B, T_max]` 593 | """ 594 | if max_len is None: 595 | max_len = sequence_length.data.max() 596 | seq_range = torch.arange(max_len, 597 | dtype=sequence_length.dtype, 598 | device=sequence_length.device) 599 | # B x T_max 600 | mask = seq_range.unsqueeze(0) < sequence_length.unsqueeze(1) 601 | return mask 602 | 603 | def forward(self, x, mask=None, input_lengths=None, out_lengths_list=None): 604 | 605 | # BxT -> BxCxT 606 | x = x.unsqueeze(1) 607 | # if self.conv_type == "custom": 608 | # for conv in self.conv_layers: 609 | # if isinstance(conv, nn.LayerNorm): 610 | # x = x.transpose(1, 2) 611 | # x = conv(x).transpose(1, 2) 612 | # else: 613 | # x = conv(x) 614 | # x = x.transpose(2, 3).contiguous() 615 | # x = x.view(x.size(0), -1, x.size(-1)) 616 | # else: 617 | 618 | for idx, conv in enumerate(self.conv_layers): 619 | x = conv(x) 620 | # if idx == 0: 621 | # x = conv(x * self.sequence_mask(input_lengths).unsqueeze(1)) 622 | # else: 623 | # if len(out_lengths_list[idx-1]) == 1: 624 | # x = conv(x * self.sequence_mask(out_lengths_list[idx-1])) 625 | # else: 626 | # x = conv(x * self.sequence_mask(out_lengths_list[idx-1]).unsqueeze(1)) 627 | # if len(out_lengths_list[idx-1]) == 1: 628 | # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(0)) 629 | # else: 630 | # x *= self.sequence_mask(out_lengths_list[idx].unsqueeze(1)) 631 | # if self.conv_type == "conv2d": 632 | # b, c, t, f = x.size() 633 | # x = x.transpose(2, 3).contiguous().view(b, c * f, t) 634 | return x 635 | 636 | 637 | class TransformerEncoder(nn.Module): 638 | def __init__(self, args): 639 | super().__init__() 640 | 641 | self.dropout = args.dropout 642 | self.embedding_dim = args.encoder_embed_dim 643 | 644 | self.pos_conv = nn.Conv1d( 645 | self.embedding_dim, 646 | self.embedding_dim, 647 | kernel_size=args.conv_pos, 648 | padding=args.conv_pos // 2, 649 | groups=args.conv_pos_groups, 650 | ) 651 | dropout = 0 652 | std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim)) 653 | nn.init.normal_(self.pos_conv.weight, mean=0, std=std) 654 | nn.init.constant_(self.pos_conv.bias, 0) 655 | 656 
| self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2) 657 | self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU()) 658 | 659 | if hasattr(args, "relative_position_embedding"): 660 | self.relative_position_embedding = args.relative_position_embedding 661 | self.num_buckets = args.num_buckets 662 | self.max_distance = args.max_distance 663 | else: 664 | self.relative_position_embedding = False 665 | self.num_buckets = 0 666 | self.max_distance = 0 667 | 668 | self.layers = nn.ModuleList( 669 | [ 670 | TransformerSentenceEncoderLayer( 671 | embedding_dim=self.embedding_dim, 672 | ffn_embedding_dim=args.encoder_ffn_embed_dim, 673 | num_attention_heads=args.encoder_attention_heads, 674 | dropout=self.dropout, 675 | attention_dropout=args.attention_dropout, 676 | activation_dropout=args.activation_dropout, 677 | activation_fn=args.activation_fn, 678 | layer_norm_first=args.layer_norm_first, 679 | has_relative_attention_bias=(self.relative_position_embedding and i == 0), 680 | num_buckets=self.num_buckets, 681 | max_distance=self.max_distance, 682 | gru_rel_pos=args.gru_rel_pos, 683 | ) 684 | for i in range(args.encoder_layers) 685 | ] 686 | ) 687 | 688 | self.layer_norm_first = args.layer_norm_first 689 | self.layer_norm = LayerNorm(self.embedding_dim) 690 | self.layerdrop = args.encoder_layerdrop 691 | 692 | self.apply(init_bert_params) 693 | 694 | def forward(self, x, padding_mask=None, streaming_mask=None, layer=None): 695 | x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer) 696 | 697 | if self.layer_norm_first and layer is None: 698 | x = self.layer_norm(x) 699 | 700 | return x, layer_results 701 | 702 | def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None): 703 | 704 | if padding_mask is not None: 705 | x[padding_mask] = 0 706 | 707 | y = x.transpose(1, 2).clone() 708 | x_conv = self.pos_conv(y) 709 | x_conv = x_conv.transpose(1, 2) 710 | x += x_conv 711 | 712 | if not self.layer_norm_first: 713 | x = self.layer_norm(x) 714 | 715 | x = F.dropout(x, p=self.dropout, training=self.training) 716 | 717 | # B x T x C -> T x B x C 718 | x = x.transpose(0, 1) 719 | 720 | layer_results = [] 721 | z = None 722 | if tgt_layer is not None: 723 | layer_results.append((x, z)) 724 | r = None 725 | pos_bias = None 726 | for i, layer in enumerate(self.layers): 727 | dropout_probability = np.random.random() 728 | if not self.training or (dropout_probability > self.layerdrop): 729 | x, z, pos_bias = layer(x, self_attn_padding_mask=padding_mask, need_weights=False, 730 | self_attn_mask=streaming_mask, pos_bias=pos_bias) 731 | if tgt_layer is not None: 732 | layer_results.append((x, z)) 733 | if i == tgt_layer: 734 | r = x 735 | break 736 | 737 | if r is not None: 738 | x = r 739 | 740 | # T x B x C -> B x T x C 741 | x = x.transpose(0, 1) 742 | 743 | return x, layer_results 744 | 745 | 746 | class TransformerSentenceEncoderLayer(nn.Module): 747 | """ 748 | Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained 749 | models. 
750 | """ 751 | 752 | def __init__( 753 | self, 754 | embedding_dim: float = 768, 755 | ffn_embedding_dim: float = 3072, 756 | num_attention_heads: float = 8, 757 | dropout: float = 0.1, 758 | attention_dropout: float = 0.1, 759 | activation_dropout: float = 0.1, 760 | activation_fn: str = "relu", 761 | layer_norm_first: bool = False, 762 | has_relative_attention_bias: bool = False, 763 | num_buckets: int = 0, 764 | max_distance: int = 0, 765 | rescale_init: bool = False, 766 | gru_rel_pos: bool = False, 767 | ) -> None: 768 | 769 | super().__init__() 770 | # Initialize parameters 771 | self.embedding_dim = embedding_dim 772 | self.dropout = dropout 773 | self.activation_dropout = activation_dropout 774 | 775 | # Initialize blocks 776 | self.activation_name = activation_fn 777 | self.activation_fn = get_activation_fn(activation_fn) 778 | self.self_attn = MultiheadAttention( 779 | self.embedding_dim, 780 | num_attention_heads, 781 | dropout=attention_dropout, 782 | self_attention=True, 783 | has_relative_attention_bias=has_relative_attention_bias, 784 | num_buckets=num_buckets, 785 | max_distance=max_distance, 786 | rescale_init=rescale_init, 787 | gru_rel_pos=gru_rel_pos, 788 | ) 789 | 790 | self.dropout1 = nn.Dropout(dropout) 791 | self.dropout2 = nn.Dropout(self.activation_dropout) 792 | self.dropout3 = nn.Dropout(dropout) 793 | 794 | self.layer_norm_first = layer_norm_first 795 | 796 | # layer norm associated with the self attention layer 797 | self.self_attn_layer_norm = LayerNorm(self.embedding_dim) 798 | 799 | if self.activation_name == "glu": 800 | self.fc1 = GLU_Linear(self.embedding_dim, ffn_embedding_dim, "swish") 801 | else: 802 | self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) 803 | self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) 804 | 805 | # layer norm associated with the position wise feed-forward NN 806 | self.final_layer_norm = LayerNorm(self.embedding_dim) 807 | 808 | def forward( 809 | self, 810 | x: torch.Tensor, 811 | self_attn_mask: torch.Tensor = None, 812 | self_attn_padding_mask: torch.Tensor = None, 813 | need_weights: bool = False, 814 | pos_bias=None 815 | ): 816 | """ 817 | LayerNorm is applied either before or after the self-attention/ffn 818 | modules similar to the original Transformer imlementation. 
819 | """ 820 | residual = x 821 | 822 | if self.layer_norm_first: 823 | x = self.self_attn_layer_norm(x) 824 | x, attn, pos_bias = self.self_attn( 825 | query=x, 826 | key=x, 827 | value=x, 828 | key_padding_mask=self_attn_padding_mask, 829 | need_weights=False, 830 | attn_mask=self_attn_mask, 831 | position_bias=pos_bias 832 | ) 833 | x = self.dropout1(x) 834 | x = residual + x 835 | 836 | residual = x 837 | x = self.final_layer_norm(x) 838 | if self.activation_name == "glu": 839 | x = self.fc1(x) 840 | else: 841 | x = self.activation_fn(self.fc1(x)) 842 | x = self.dropout2(x) 843 | x = self.fc2(x) 844 | x = self.dropout3(x) 845 | x = residual + x 846 | else: 847 | x, attn, pos_bias = self.self_attn( 848 | query=x, 849 | key=x, 850 | value=x, 851 | key_padding_mask=self_attn_padding_mask, 852 | need_weights=need_weights, 853 | attn_mask=self_attn_mask, 854 | position_bias=pos_bias 855 | ) 856 | 857 | x = self.dropout1(x) 858 | x = residual + x 859 | 860 | x = self.self_attn_layer_norm(x) 861 | 862 | residual = x 863 | if self.activation_name == "glu": 864 | x = self.fc1(x) 865 | else: 866 | x = self.activation_fn(self.fc1(x)) 867 | x = self.dropout2(x) 868 | x = self.fc2(x) 869 | x = self.dropout3(x) 870 | x = residual + x 871 | x = self.final_layer_norm(x) 872 | 873 | return x, attn, pos_bias 874 | -------------------------------------------------------------------------------- /components/semantic_extractor/modules.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # WavLM: Large-Scale Self-Supervised Pre-training for Full Stack Speech Processing (https://arxiv.org/abs/2110.13900.pdf) 3 | # Github source: https://github.com/microsoft/unilm/tree/master/wavlm 4 | # Copyright (c) 2021 Microsoft 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # Based on fairseq code bases 7 | # https://github.com/pytorch/fairseq 8 | # -------------------------------------------------------- 9 | 10 | import math 11 | import warnings 12 | from typing import Dict, Optional, Tuple 13 | import torch 14 | from torch import Tensor, nn 15 | from torch.nn import Parameter 16 | import torch.nn.functional as F 17 | 18 | class TransposeLast(nn.Module): 19 | def __init__(self, deconstruct_idx=None): 20 | super().__init__() 21 | self.deconstruct_idx = deconstruct_idx 22 | 23 | def forward(self, x): 24 | if self.deconstruct_idx is not None: 25 | x = x[self.deconstruct_idx] 26 | return x.transpose(-2, -1) 27 | 28 | 29 | class Fp32LayerNorm(nn.LayerNorm): 30 | def __init__(self, *args, **kwargs): 31 | super().__init__(*args, **kwargs) 32 | 33 | def forward(self, input): 34 | output = F.layer_norm( 35 | input.float(), 36 | self.normalized_shape, 37 | self.weight.float() if self.weight is not None else None, 38 | self.bias.float() if self.bias is not None else None, 39 | self.eps, 40 | ) 41 | return output.type_as(input) 42 | 43 | 44 | class Fp32GroupNorm(nn.GroupNorm): 45 | def __init__(self, *args, **kwargs): 46 | super().__init__(*args, **kwargs) 47 | 48 | def forward(self, input): 49 | output = F.group_norm( 50 | input.float(), 51 | self.num_groups, 52 | self.weight.float() if self.weight is not None else None, 53 | self.bias.float() if self.bias is not None else None, 54 | self.eps, 55 | ) 56 | return output.type_as(input) 57 | 58 | 59 | class GradMultiply(torch.autograd.Function): 60 | @staticmethod 61 | def forward(ctx, x, scale): 62 | ctx.scale = scale 63 | res = x.new(x) 64 | return res 65 | 66 | 
@staticmethod 67 | def backward(ctx, grad): 68 | return grad * ctx.scale, None 69 | 70 | 71 | class SamePad(nn.Module): 72 | def __init__(self, kernel_size, causal=False): 73 | super().__init__() 74 | if causal: 75 | self.remove = kernel_size - 1 76 | else: 77 | self.remove = 1 if kernel_size % 2 == 0 else 0 78 | 79 | def forward(self, x): 80 | if self.remove > 0: 81 | x = x[:, :, : -self.remove] 82 | return x 83 | 84 | 85 | class Swish(nn.Module): 86 | """Swish function 87 | """ 88 | 89 | def __init__(self): 90 | """Construct an MultiHeadedAttention object.""" 91 | super(Swish, self).__init__() 92 | self.act = torch.nn.Sigmoid() 93 | 94 | def forward(self, x): 95 | return x * self.act(x) 96 | 97 | 98 | class GLU_Linear(nn.Module): 99 | def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True): 100 | super(GLU_Linear, self).__init__() 101 | 102 | self.glu_type = glu_type 103 | self.output_dim = output_dim 104 | 105 | if glu_type == "sigmoid": 106 | self.glu_act = torch.nn.Sigmoid() 107 | elif glu_type == "swish": 108 | self.glu_act = Swish() 109 | elif glu_type == "relu": 110 | self.glu_act = torch.nn.ReLU() 111 | elif glu_type == "gelu": 112 | self.glu_act = torch.nn.GELU() 113 | 114 | if bias_in_glu: 115 | self.linear = nn.Linear(input_dim, output_dim * 2, True) 116 | else: 117 | self.linear = nn.Linear(input_dim, output_dim * 2, False) 118 | 119 | def forward(self, x): 120 | # to be consistent with GLU_Linear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case 121 | x = self.linear(x) 122 | 123 | if self.glu_type == "bilinear": 124 | x = (x[:, :, 0:self.output_dim] * x[:, :, self.output_dim:self.output_dim * 2]) 125 | else: 126 | x = (x[:, :, 0:self.output_dim] * self.glu_act(x[:, :, self.output_dim:self.output_dim * 2])) 127 | 128 | return x 129 | 130 | def gelu_accurate(x): 131 | if not hasattr(gelu_accurate, "_a"): 132 | gelu_accurate._a = math.sqrt(2 / math.pi) 133 | return ( 134 | 0.5 * x * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3)))) 135 | ) 136 | 137 | 138 | def gelu(x: torch.Tensor) -> torch.Tensor: 139 | return torch.nn.functional.gelu(x.float()).type_as(x) 140 | 141 | 142 | def get_activation_fn(activation: str): 143 | """Returns the activation function corresponding to `activation`""" 144 | 145 | if activation == "relu": 146 | return F.relu 147 | elif activation == "gelu": 148 | return gelu 149 | elif activation == "gelu_fast": 150 | warnings.warn( 151 | "--activation-fn=gelu_fast has been renamed to gelu_accurate" 152 | ) 153 | return gelu_accurate 154 | elif activation == "gelu_accurate": 155 | return gelu_accurate 156 | elif activation == "tanh": 157 | return torch.tanh 158 | elif activation == "linear": 159 | return lambda x: x 160 | elif activation == "glu": 161 | return lambda x: x 162 | else: 163 | raise RuntimeError("--activation-fn {} not supported".format(activation)) 164 | 165 | 166 | def init_bert_params(module): 167 | """ 168 | Initialize the weights specific to the BERT Model. 169 | This overrides the default initializations depending on the specified arguments. 170 | 1. If normal_init_linear_weights is set then weights of linear 171 | layer will be initialized using the normal distribution and 172 | bais will be set to the specified value. 173 | 2. If normal_init_embed_weights is set then weights of embedding 174 | layer will be initialized using the normal distribution. 175 | 3. 
If normal_init_proj_weights is set then weights of 176 | in_project_weight for MultiHeadAttention initialized using 177 | the normal distribution (to be validated). 178 | """ 179 | 180 | def normal_(data): 181 | # with FSDP, module params will be on CUDA, so we cast them back to CPU 182 | # so that the RNG is consistent with and without FSDP 183 | data.copy_( 184 | data.cpu().normal_(mean=0.0, std=0.02).to(data.device) 185 | ) 186 | 187 | if isinstance(module, nn.Linear): 188 | normal_(module.weight.data) 189 | if module.bias is not None: 190 | module.bias.data.zero_() 191 | if isinstance(module, nn.Embedding): 192 | normal_(module.weight.data) 193 | if module.padding_idx is not None: 194 | module.weight.data[module.padding_idx].zero_() 195 | if isinstance(module, MultiheadAttention): 196 | normal_(module.q_proj.weight.data) 197 | normal_(module.k_proj.weight.data) 198 | normal_(module.v_proj.weight.data) 199 | 200 | 201 | def quant_noise(module, p, block_size): 202 | """ 203 | Wraps modules and applies quantization noise to the weights for 204 | subsequent quantization with Iterative Product Quantization as 205 | described in "Training with Quantization Noise for Extreme Model Compression" 206 | 207 | Args: 208 | - module: nn.Module 209 | - p: amount of Quantization Noise 210 | - block_size: size of the blocks for subsequent quantization with iPQ 211 | 212 | Remarks: 213 | - Module weights must have the right sizes wrt the block size 214 | - Only Linear, Embedding and Conv2d modules are supported for the moment 215 | - For more detail on how to quantize by blocks with convolutional weights, 216 | see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks" 217 | - We implement the simplest form of noise here as stated in the paper 218 | which consists in randomly dropping blocks 219 | """ 220 | 221 | # if no quantization noise, don't register hook 222 | if p <= 0: 223 | return module 224 | 225 | # supported modules 226 | assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d)) 227 | 228 | # test whether module.weight has the right sizes wrt block_size 229 | is_conv = module.weight.ndim == 4 230 | 231 | # 2D matrix 232 | if not is_conv: 233 | assert ( 234 | module.weight.size(1) % block_size == 0 235 | ), "Input features must be a multiple of block sizes" 236 | 237 | # 4D matrix 238 | else: 239 | # 1x1 convolutions 240 | if module.kernel_size == (1, 1): 241 | assert ( 242 | module.in_channels % block_size == 0 243 | ), "Input channels must be a multiple of block sizes" 244 | # regular convolutions 245 | else: 246 | k = module.kernel_size[0] * module.kernel_size[1] 247 | assert k % block_size == 0, "Kernel size must be a multiple of block size" 248 | 249 | def _forward_pre_hook(mod, input): 250 | # no noise for evaluation 251 | if mod.training: 252 | if not is_conv: 253 | # gather weight and sizes 254 | weight = mod.weight 255 | in_features = weight.size(1) 256 | out_features = weight.size(0) 257 | 258 | # split weight matrix into blocks and randomly drop selected blocks 259 | mask = torch.zeros( 260 | in_features // block_size * out_features, device=weight.device 261 | ) 262 | mask.bernoulli_(p) 263 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_features) 264 | 265 | else: 266 | # gather weight and sizes 267 | weight = mod.weight 268 | in_channels = mod.in_channels 269 | out_channels = mod.out_channels 270 | 271 | # split weight matrix into blocks and randomly drop selected blocks 272 | if mod.kernel_size == (1, 1): 273 | mask = torch.zeros( 274 | 
int(in_channels // block_size * out_channels), 275 | device=weight.device, 276 | ) 277 | mask.bernoulli_(p) 278 | mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels) 279 | else: 280 | mask = torch.zeros( 281 | weight.size(0), weight.size(1), device=weight.device 282 | ) 283 | mask.bernoulli_(p) 284 | mask = ( 285 | mask.unsqueeze(2) 286 | .unsqueeze(3) 287 | .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1]) 288 | ) 289 | 290 | # scale weights and apply mask 291 | mask = mask.to( 292 | torch.bool 293 | ) # x.bool() is not currently supported in TorchScript 294 | s = 1 / (1 - p) 295 | mod.weight.data = s * weight.masked_fill(mask, 0) 296 | 297 | module.register_forward_pre_hook(_forward_pre_hook) 298 | return module 299 | 300 | 301 | class MultiheadAttention(nn.Module): 302 | """Multi-headed attention. 303 | 304 | See "Attention Is All You Need" for more details. 305 | """ 306 | 307 | def __init__( 308 | self, 309 | embed_dim, 310 | num_heads, 311 | kdim=None, 312 | vdim=None, 313 | dropout=0.0, 314 | bias=True, 315 | add_bias_kv=False, 316 | add_zero_attn=False, 317 | self_attention=False, 318 | encoder_decoder_attention=False, 319 | q_noise=0.0, 320 | qn_block_size=8, 321 | has_relative_attention_bias=False, 322 | num_buckets=32, 323 | max_distance=128, 324 | gru_rel_pos=False, 325 | rescale_init=False, 326 | ): 327 | super().__init__() 328 | self.embed_dim = embed_dim 329 | self.kdim = kdim if kdim is not None else embed_dim 330 | self.vdim = vdim if vdim is not None else embed_dim 331 | self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim 332 | 333 | self.num_heads = num_heads 334 | self.dropout_module = nn.Dropout(dropout) 335 | 336 | self.has_relative_attention_bias = has_relative_attention_bias 337 | self.num_buckets = num_buckets 338 | self.max_distance = max_distance 339 | if self.has_relative_attention_bias: 340 | self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) 341 | 342 | self.head_dim = embed_dim // num_heads 343 | self.q_head_dim = self.head_dim 344 | self.k_head_dim = self.head_dim 345 | assert ( 346 | self.head_dim * num_heads == self.embed_dim 347 | ), "embed_dim must be divisible by num_heads" 348 | self.scaling = self.head_dim ** -0.5 349 | 350 | self.self_attention = self_attention 351 | self.encoder_decoder_attention = encoder_decoder_attention 352 | 353 | assert not self.self_attention or self.qkv_same_dim, ( 354 | "Self-attention requires query, key and " "value to be of the same size" 355 | ) 356 | 357 | k_bias = True 358 | if rescale_init: 359 | k_bias = False 360 | 361 | k_embed_dim = embed_dim 362 | q_embed_dim = embed_dim 363 | 364 | self.k_proj = quant_noise( 365 | nn.Linear(self.kdim, k_embed_dim, bias=k_bias), q_noise, qn_block_size 366 | ) 367 | self.v_proj = quant_noise( 368 | nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size 369 | ) 370 | self.q_proj = quant_noise( 371 | nn.Linear(embed_dim, q_embed_dim, bias=bias), q_noise, qn_block_size 372 | ) 373 | 374 | self.out_proj = quant_noise( 375 | nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size 376 | ) 377 | 378 | if add_bias_kv: 379 | self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) 380 | self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) 381 | else: 382 | self.bias_k = self.bias_v = None 383 | 384 | self.add_zero_attn = add_zero_attn 385 | 386 | self.gru_rel_pos = gru_rel_pos 387 | if self.gru_rel_pos: 388 | self.grep_linear = nn.Linear(self.q_head_dim, 8) 389 | self.grep_a = nn.Parameter(torch.ones(1, 
num_heads, 1, 1)) 390 | 391 | self.reset_parameters() 392 | 393 | def reset_parameters(self): 394 | if self.qkv_same_dim: 395 | # Empirically observed the convergence to be much better with 396 | # the scaled initialization 397 | nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) 398 | nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) 399 | nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) 400 | else: 401 | nn.init.xavier_uniform_(self.k_proj.weight) 402 | nn.init.xavier_uniform_(self.v_proj.weight) 403 | nn.init.xavier_uniform_(self.q_proj.weight) 404 | 405 | nn.init.xavier_uniform_(self.out_proj.weight) 406 | if self.out_proj.bias is not None: 407 | nn.init.constant_(self.out_proj.bias, 0.0) 408 | if self.bias_k is not None: 409 | nn.init.xavier_normal_(self.bias_k) 410 | if self.bias_v is not None: 411 | nn.init.xavier_normal_(self.bias_v) 412 | if self.has_relative_attention_bias: 413 | nn.init.xavier_normal_(self.relative_attention_bias.weight) 414 | 415 | def _relative_positions_bucket(self, relative_positions, bidirectional=True): 416 | num_buckets = self.num_buckets 417 | max_distance = self.max_distance 418 | relative_buckets = 0 419 | 420 | if bidirectional: 421 | num_buckets = num_buckets // 2 422 | relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets 423 | relative_positions = torch.abs(relative_positions) 424 | else: 425 | relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions)) 426 | 427 | max_exact = num_buckets // 2 428 | is_small = relative_positions < max_exact 429 | 430 | relative_postion_if_large = max_exact + ( 431 | torch.log(relative_positions.float() / max_exact) 432 | / math.log(max_distance / max_exact) 433 | * (num_buckets - max_exact) 434 | ).to(torch.long) 435 | relative_postion_if_large = torch.min( 436 | relative_postion_if_large, torch.full_like(relative_postion_if_large, num_buckets - 1) 437 | ) 438 | 439 | relative_buckets += torch.where(is_small, relative_positions, relative_postion_if_large) 440 | return relative_buckets 441 | 442 | def compute_bias(self, query_length, key_length): 443 | context_position = torch.arange(query_length, dtype=torch.long)[:, None] 444 | memory_position = torch.arange(key_length, dtype=torch.long)[None, :] 445 | relative_position = memory_position - context_position 446 | relative_position_bucket = self._relative_positions_bucket( 447 | relative_position, 448 | bidirectional=True 449 | ) 450 | relative_position_bucket = relative_position_bucket.to(self.relative_attention_bias.weight.device) 451 | values = self.relative_attention_bias(relative_position_bucket) 452 | values = values.permute([2, 0, 1]) 453 | return values 454 | 455 | def forward( 456 | self, 457 | query, 458 | key: Optional[Tensor], 459 | value: Optional[Tensor], 460 | key_padding_mask: Optional[Tensor] = None, 461 | incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, 462 | need_weights: bool = True, 463 | static_kv: bool = False, 464 | attn_mask: Optional[Tensor] = None, 465 | before_softmax: bool = False, 466 | need_head_weights: bool = False, 467 | position_bias: Optional[Tensor] = None 468 | ) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]: 469 | """Input shape: Time x Batch x Channel 470 | 471 | Args: 472 | key_padding_mask (ByteTensor, optional): mask to exclude 473 | keys that are pads, of shape `(batch, src_len)`, where 474 | padding elements are indicated by 1s. 
475 | need_weights (bool, optional): return the attention weights, 476 | averaged over heads (default: False). 477 | attn_mask (ByteTensor, optional): typically used to 478 | implement causal attention, where the mask prevents the 479 | attention from looking forward in time (default: None). 480 | before_softmax (bool, optional): return the raw attention 481 | weights and values before the attention softmax. 482 | need_head_weights (bool, optional): return the attention 483 | weights for each head. Implies *need_weights*. Default: 484 | return the average attention weights over all heads. 485 | """ 486 | if need_head_weights: 487 | need_weights = True 488 | 489 | is_tpu = query.device.type == "xla" 490 | 491 | tgt_len, bsz, embed_dim = query.size() 492 | src_len = tgt_len 493 | assert embed_dim == self.embed_dim 494 | assert list(query.size()) == [tgt_len, bsz, embed_dim] 495 | if key is not None: 496 | src_len, key_bsz, _ = key.size() 497 | if not torch.jit.is_scripting(): 498 | assert key_bsz == bsz 499 | assert value is not None 500 | assert src_len, bsz == value.shape[:2] 501 | 502 | if self.has_relative_attention_bias and position_bias is None: 503 | position_bias = self.compute_bias(tgt_len, src_len) 504 | position_bias = position_bias.unsqueeze(0).repeat(bsz, 1, 1, 1).view(bsz * self.num_heads, tgt_len, src_len) 505 | 506 | if ( 507 | not is_tpu # don't use PyTorch version on TPUs 508 | and incremental_state is None 509 | and not static_kv 510 | # A workaround for quantization to work. Otherwise JIT compilation 511 | # treats bias in linear module as method. 512 | and not torch.jit.is_scripting() 513 | and self.q_head_dim == self.head_dim 514 | ): 515 | assert key is not None and value is not None 516 | assert attn_mask is None 517 | 518 | attn_mask_rel_pos = None 519 | if position_bias is not None: 520 | attn_mask_rel_pos = position_bias 521 | if self.gru_rel_pos: 522 | query_layer = query.transpose(0, 1) 523 | new_x_shape = query_layer.size()[:-1] + (self.num_heads, -1) 524 | query_layer = query_layer.view(*new_x_shape) 525 | query_layer = query_layer.permute(0, 2, 1, 3) 526 | _B, _H, _L, __ = query_layer.size() 527 | 528 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( 529 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) 530 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 531 | attn_mask_rel_pos = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias 532 | 533 | attn_mask_rel_pos = attn_mask_rel_pos.view((-1, tgt_len, tgt_len)) 534 | k_proj_bias = self.k_proj.bias 535 | if k_proj_bias is None: 536 | k_proj_bias = torch.zeros_like(self.q_proj.bias) 537 | 538 | x, attn = F.multi_head_attention_forward( 539 | query, 540 | key, 541 | value, 542 | self.embed_dim, 543 | self.num_heads, 544 | torch.empty([0]), 545 | torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), 546 | self.bias_k, 547 | self.bias_v, 548 | self.add_zero_attn, 549 | self.dropout_module.p, 550 | self.out_proj.weight, 551 | self.out_proj.bias, 552 | self.training, 553 | # self.training or self.dropout_module.apply_during_inference, 554 | key_padding_mask, 555 | need_weights, 556 | attn_mask_rel_pos, 557 | use_separate_proj_weight=True, 558 | q_proj_weight=self.q_proj.weight, 559 | k_proj_weight=self.k_proj.weight, 560 | v_proj_weight=self.v_proj.weight, 561 | ) 562 | return x, attn, position_bias 563 | 564 | if incremental_state is not None: 565 | saved_state = self._get_input_buffer(incremental_state) 566 | if saved_state is not None and "prev_key" in 
saved_state: 567 | # previous time steps are cached - no need to recompute 568 | # key and value if they are static 569 | if static_kv: 570 | assert self.encoder_decoder_attention and not self.self_attention 571 | key = value = None 572 | else: 573 | saved_state = None 574 | 575 | if self.self_attention: 576 | q = self.q_proj(query) 577 | k = self.k_proj(query) 578 | v = self.v_proj(query) 579 | elif self.encoder_decoder_attention: 580 | # encoder-decoder attention 581 | q = self.q_proj(query) 582 | if key is None: 583 | assert value is None 584 | k = v = None 585 | else: 586 | k = self.k_proj(key) 587 | v = self.v_proj(key) 588 | 589 | else: 590 | assert key is not None and value is not None 591 | q = self.q_proj(query) 592 | k = self.k_proj(key) 593 | v = self.v_proj(value) 594 | q *= self.scaling 595 | 596 | if self.bias_k is not None: 597 | assert self.bias_v is not None 598 | k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) 599 | v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) 600 | if attn_mask is not None: 601 | attn_mask = torch.cat( 602 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 603 | ) 604 | if key_padding_mask is not None: 605 | key_padding_mask = torch.cat( 606 | [ 607 | key_padding_mask, 608 | key_padding_mask.new_zeros(key_padding_mask.size(0), 1), 609 | ], 610 | dim=1, 611 | ) 612 | 613 | q = ( 614 | q.contiguous() 615 | .view(tgt_len, bsz * self.num_heads, self.q_head_dim) 616 | .transpose(0, 1) 617 | ) 618 | if k is not None: 619 | k = ( 620 | k.contiguous() 621 | .view(-1, bsz * self.num_heads, self.k_head_dim) 622 | .transpose(0, 1) 623 | ) 624 | if v is not None: 625 | v = ( 626 | v.contiguous() 627 | .view(-1, bsz * self.num_heads, self.head_dim) 628 | .transpose(0, 1) 629 | ) 630 | 631 | if saved_state is not None: 632 | # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) 633 | if "prev_key" in saved_state: 634 | _prev_key = saved_state["prev_key"] 635 | assert _prev_key is not None 636 | prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim) 637 | if static_kv: 638 | k = prev_key 639 | else: 640 | assert k is not None 641 | k = torch.cat([prev_key, k], dim=1) 642 | src_len = k.size(1) 643 | if "prev_value" in saved_state: 644 | _prev_value = saved_state["prev_value"] 645 | assert _prev_value is not None 646 | prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim) 647 | if static_kv: 648 | v = prev_value 649 | else: 650 | assert v is not None 651 | v = torch.cat([prev_value, v], dim=1) 652 | prev_key_padding_mask: Optional[Tensor] = None 653 | if "prev_key_padding_mask" in saved_state: 654 | prev_key_padding_mask = saved_state["prev_key_padding_mask"] 655 | assert k is not None and v is not None 656 | key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( 657 | key_padding_mask=key_padding_mask, 658 | prev_key_padding_mask=prev_key_padding_mask, 659 | batch_size=bsz, 660 | src_len=k.size(1), 661 | static_kv=static_kv, 662 | ) 663 | 664 | saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim) 665 | saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim) 666 | saved_state["prev_key_padding_mask"] = key_padding_mask 667 | # In this branch incremental_state is never None 668 | assert incremental_state is not None 669 | incremental_state = self._set_input_buffer(incremental_state, saved_state) 670 | assert k is not None 671 | assert k.size(1) == src_len 672 | 673 | # This is part of a workaround to get around fork/join parallelism 674 | # not 
supporting Optional types. 675 | if key_padding_mask is not None and key_padding_mask.dim() == 0: 676 | key_padding_mask = None 677 | 678 | if key_padding_mask is not None: 679 | assert key_padding_mask.size(0) == bsz 680 | assert key_padding_mask.size(1) == src_len 681 | 682 | if self.add_zero_attn: 683 | assert v is not None 684 | src_len += 1 685 | k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) 686 | v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) 687 | if attn_mask is not None: 688 | attn_mask = torch.cat( 689 | [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1 690 | ) 691 | if key_padding_mask is not None: 692 | key_padding_mask = torch.cat( 693 | [ 694 | key_padding_mask, 695 | torch.zeros(key_padding_mask.size(0), 1).type_as( 696 | key_padding_mask 697 | ), 698 | ], 699 | dim=1, 700 | ) 701 | 702 | attn_weights = torch.bmm(q, k.transpose(1, 2)) 703 | attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) 704 | 705 | assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] 706 | 707 | if attn_mask is not None: 708 | attn_mask = attn_mask.unsqueeze(0) 709 | attn_weights += attn_mask 710 | 711 | if key_padding_mask is not None: 712 | # don't attend to padding symbols 713 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 714 | if not is_tpu: 715 | attn_weights = attn_weights.masked_fill( 716 | key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), 717 | float("-inf"), 718 | ) 719 | else: 720 | attn_weights = attn_weights.transpose(0, 2) 721 | attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) 722 | attn_weights = attn_weights.transpose(0, 2) 723 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 724 | 725 | if before_softmax: 726 | return attn_weights, v, position_bias 727 | 728 | if position_bias is not None: 729 | if self.gru_rel_pos == 1: 730 | query_layer = q.view(bsz, self.num_heads, tgt_len, self.q_head_dim) 731 | _B, _H, _L, __ = query_layer.size() 732 | gate_a, gate_b = torch.sigmoid(self.grep_linear(query_layer).view( 733 | _B, _H, _L, 2, 4).sum(-1, keepdim=False)).chunk(2, dim=-1) 734 | gate_a_1 = gate_a * (gate_b * self.grep_a - 1.0) + 2.0 735 | position_bias = gate_a_1.view(bsz * self.num_heads, -1, 1) * position_bias 736 | 737 | position_bias = position_bias.view(attn_weights.size()) 738 | 739 | attn_weights = attn_weights + position_bias 740 | 741 | attn_weights_float = F.softmax( 742 | attn_weights, dim=-1 743 | ) 744 | attn_weights = attn_weights_float.type_as(attn_weights) 745 | attn_probs = self.dropout_module(attn_weights) 746 | 747 | assert v is not None 748 | attn = torch.bmm(attn_probs, v) 749 | assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] 750 | attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 751 | attn = self.out_proj(attn) 752 | attn_weights: Optional[Tensor] = None 753 | if need_weights: 754 | attn_weights = attn_weights_float.view( 755 | bsz, self.num_heads, tgt_len, src_len 756 | ).transpose(1, 0) 757 | if not need_head_weights: 758 | # average attention weights over heads 759 | attn_weights = attn_weights.mean(dim=0) 760 | 761 | return attn, attn_weights, position_bias 762 | 763 | @staticmethod 764 | def _append_prev_key_padding_mask( 765 | key_padding_mask: Optional[Tensor], 766 | prev_key_padding_mask: Optional[Tensor], 767 | batch_size: int, 768 | src_len: int, 769 | static_kv: bool, 770 | ) -> Optional[Tensor]: 771 | # saved key 
padding masks have shape (bsz, seq_len) 772 | if prev_key_padding_mask is not None and static_kv: 773 | new_key_padding_mask = prev_key_padding_mask 774 | elif prev_key_padding_mask is not None and key_padding_mask is not None: 775 | new_key_padding_mask = torch.cat( 776 | [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 777 | ) 778 | # During incremental decoding, as the padding token enters and 779 | # leaves the frame, there will be a time when prev or current 780 | # is None 781 | elif prev_key_padding_mask is not None: 782 | if src_len > prev_key_padding_mask.size(1): 783 | filler = torch.zeros( 784 | (batch_size, src_len - prev_key_padding_mask.size(1)), 785 | device=prev_key_padding_mask.device, 786 | ) 787 | new_key_padding_mask = torch.cat( 788 | [prev_key_padding_mask.float(), filler.float()], dim=1 789 | ) 790 | else: 791 | new_key_padding_mask = prev_key_padding_mask.float() 792 | elif key_padding_mask is not None: 793 | if src_len > key_padding_mask.size(1): 794 | filler = torch.zeros( 795 | (batch_size, src_len - key_padding_mask.size(1)), 796 | device=key_padding_mask.device, 797 | ) 798 | new_key_padding_mask = torch.cat( 799 | [filler.float(), key_padding_mask.float()], dim=1 800 | ) 801 | else: 802 | new_key_padding_mask = key_padding_mask.float() 803 | else: 804 | new_key_padding_mask = prev_key_padding_mask 805 | return new_key_padding_mask 806 | 807 | def _get_input_buffer( 808 | self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] 809 | ) -> Dict[str, Optional[Tensor]]: 810 | result = self.get_incremental_state(incremental_state, "attn_state") 811 | if result is not None: 812 | return result 813 | else: 814 | empty_result: Dict[str, Optional[Tensor]] = {} 815 | return empty_result 816 | 817 | def _set_input_buffer( 818 | self, 819 | incremental_state: Dict[str, Dict[str, Optional[Tensor]]], 820 | buffer: Dict[str, Optional[Tensor]], 821 | ): 822 | return self.set_incremental_state(incremental_state, "attn_state", buffer) 823 | 824 | def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): 825 | return attn_weights -------------------------------------------------------------------------------- /components/semantic_extractor/ssl_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import joblib 4 | from components.semantic_extractor.WavLM import WavLM, WavLMConfig 5 | 6 | class ApplyKmeans(nn.Module): 7 | def __init__(self, km_path, device='cuda'): 8 | super(ApplyKmeans, self).__init__() 9 | print(f'Init k-means model from {km_path}') 10 | self.km_model = joblib.load(km_path) 11 | self.C_np = self.km_model.cluster_centers_.transpose() 12 | self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True) 13 | self.C = torch.from_numpy(self.C_np).to(device) 14 | self.Cnorm = torch.from_numpy(self.Cnorm_np).to(device) 15 | self.emb = nn.Embedding(num_embeddings=300, embedding_dim=1024) 16 | self.emb.weight.data = self.C.transpose(0, 1) 17 | self.emb.weight.require_grad = False 18 | 19 | def forward(self, x, b, t): 20 | if not hasattr(self, 'C'): 21 | self.C = torch.from_numpy(self.C_np).to(x.device) 22 | if not hasattr(self, 'Cnorm'): 23 | self.Cnorm = torch.from_numpy(self.Cnorm_np).to(x.device) 24 | dist = x.pow(2).sum(1, keepdim=True) - 2 * torch.matmul(x, self.C) + self.Cnorm 25 | tokens = dist.argmin(dim=-1).reshape(b, t) 26 | return tokens 27 | 28 | def get_ssl_model(ckpt_path, km_path, device='cuda', type='xlsr'): 29 | if type == 
'xlsr': 30 | print(f'Init xlsr model from {ckpt_path}') 31 | import fairseq 32 | import argparse 33 | task_arg = argparse.Namespace(task='audio_pretraining') 34 | task = fairseq.tasks.setup_task(task_arg) 35 | model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path], task=task) 36 | model = model[0] 37 | model.eval() 38 | elif type == 'wavlm': 39 | print(f'Init wavlm model from {ckpt_path}') 40 | cpt = torch.load(ckpt_path, map_location="cpu") 41 | cfg = WavLMConfig(cpt["cfg"]) 42 | model = WavLM(cfg) 43 | model.load_state_dict(cpt["model"]) 44 | model = model.eval() 45 | model = model.requires_grad_(False) 46 | else: 47 | raise NotImplementedError 48 | km_model = ApplyKmeans(km_path, device) 49 | return model, km_model 50 | 51 | -------------------------------------------------------------------------------- /components/simcodec/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/components/simcodec/__init__.py -------------------------------------------------------------------------------- /components/simcodec/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import torch.nn as nn 4 | from components.simcodec.modules import Encoder, Quantizer, Generator 5 | 6 | class AttrDict(dict): 7 | def __init__(self, *args, **kwargs): 8 | super(AttrDict, self).__init__(*args, **kwargs) 9 | self.__dict__ = self 10 | 11 | class SimCodec(nn.Module): 12 | def __init__(self, config_path): 13 | super(SimCodec, self).__init__() 14 | self.config_path = config_path 15 | with open(self.config_path) as f: 16 | data = f.read() 17 | json_config = json.loads(data) 18 | self.h = AttrDict(json_config) 19 | self.encoder = Encoder(self.h) 20 | self.quantizer = Quantizer(self.h) 21 | self.generator = Generator(self.h) 22 | 23 | def load_ckpt(self, ckpt_path): 24 | ckpt = torch.load(ckpt_path,map_location='cpu') 25 | self.encoder.load_state_dict(ckpt['encoder']) 26 | self.quantizer.load_state_dict(ckpt['quantizer']) 27 | self.generator.load_state_dict(ckpt['generator']) 28 | 29 | def forward(self, x): 30 | batch_size = x.size(0) 31 | if len(x.shape) == 3 and x.shape[-1] == 1: 32 | x = x.squeeze(-1) 33 | c = self.encoder(x) 34 | _, _, c = self.quantizer(c) 35 | c = [code.reshape(batch_size, -1) for code in c] 36 | return torch.stack(c, -1) 37 | 38 | def decode(self, x): 39 | return self.generator(self.quantizer.embed(x)) -------------------------------------------------------------------------------- /components/simcodec/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.utils import weight_norm, remove_weight_norm 5 | from torch.nn import Conv1d, ConvTranspose1d 6 | 7 | LRELU_SLOPE = 0.1 8 | alpha = 1.0 9 | 10 | def get_padding(kernel_size, dilation=1): 11 | return int((kernel_size*dilation - dilation)/2) 12 | 13 | def init_weights(m, mean=0.0, std=0.01): 14 | classname = m.__class__.__name__ 15 | if classname.find("Conv") != -1: 16 | m.weight.data.normal_(mean, std) 17 | 18 | class ResBlock1(torch.nn.Module): 19 | def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): 20 | super(ResBlock1, self).__init__() 21 | self.h = h 22 | self.convs1 = nn.ModuleList([ 23 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], 24 
| padding=get_padding(kernel_size, dilation[0]))), 25 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], 26 | padding=get_padding(kernel_size, dilation[1]))), 27 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], 28 | padding=get_padding(kernel_size, dilation[2]))) 29 | ]) 30 | self.convs1.apply(init_weights) 31 | 32 | self.convs2 = nn.ModuleList([ 33 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 34 | padding=get_padding(kernel_size, 1))), 35 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 36 | padding=get_padding(kernel_size, 1))), 37 | weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, 38 | padding=get_padding(kernel_size, 1))) 39 | ]) 40 | self.convs2.apply(init_weights) 41 | self.num_layers = len(self.convs1) + len(self.convs2) # total number of conv layers 42 | self.activations = nn.ModuleList([nn.LeakyReLU(LRELU_SLOPE) for _ in range(self.num_layers)]) 43 | 44 | 45 | def forward(self, x): 46 | acts1, acts2 = self.activations[::2], self.activations[1::2] 47 | for c1, c2,a1,a2 in zip(self.convs1, self.convs2,acts1,acts2): 48 | xt = a1(x) 49 | xt = c1(xt) 50 | xt = a2(xt) 51 | xt = c2(xt) 52 | x = xt + x 53 | return x 54 | 55 | def remove_weight_norm(self): 56 | for l in self.convs1: 57 | remove_weight_norm(l) 58 | for l in self.convs2: 59 | remove_weight_norm(l) 60 | 61 | 62 | class Encoder(torch.nn.Module): 63 | def __init__(self, h): 64 | super(Encoder, self).__init__() 65 | self.n_filters = h.en_filters 66 | self.vq_dim = h.vq_dim 67 | self.num_kernels = len(h.resblock_kernel_sizes) 68 | self.num_upsamples = len(h.upsample_rates) 69 | self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples ) 70 | self.conv_pre = weight_norm(Conv1d(h.channel, self.n_filters, 7, 1, padding=3)) 71 | self.normalize = nn.ModuleList() 72 | resblock = ResBlock1 73 | 74 | self.ups = nn.ModuleList() 75 | for i, (u, k) in enumerate(list(reversed(list(zip(h.upsample_rates, h.upsample_kernel_sizes))))): 76 | self.ups.append(weight_norm( 77 | Conv1d(self.n_filters*(2**i), self.n_filters*(2**(i+1)), 78 | k, u, 79 | padding=((k-u)//2) 80 | ))) 81 | self.resblocks = nn.ModuleList() 82 | ch = 1 83 | for i in range(len(self.ups)): 84 | ch = self.n_filters*(2**(i+1)) 85 | for j, (k, d) in enumerate( 86 | zip( 87 | list(reversed(h.resblock_kernel_sizes)), 88 | list(reversed(h.resblock_dilation_sizes)) 89 | ) 90 | ): 91 | self.resblocks.append(resblock(h, ch, k, d)) 92 | self.normalize.append(torch.nn.LayerNorm([ch],eps=1e-6,elementwise_affine=True)) 93 | 94 | self.activation_post = nn.LeakyReLU(LRELU_SLOPE) 95 | self.conv_post = Conv1d(ch, self.vq_dim, 3, 1, padding=1) 96 | self.ups.apply(init_weights) 97 | self.conv_post.apply(init_weights) 98 | 99 | def forward(self, x): 100 | x = self.conv_pre(x) 101 | for i in range(self.num_upsamples): 102 | x = self.ups[i](x) 103 | xs = None 104 | for j in range(self.num_kernels): 105 | if xs is None: 106 | xs = self.resblocks[i*self.num_kernels+j](x) 107 | xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2) 108 | else: 109 | xs += self.resblocks[i*self.num_kernels+j](x) 110 | xs = self.normalize[i*self.num_kernels+j](xs.transpose(1,2)).transpose(1,2) 111 | x = xs / self.num_kernels 112 | x = self.activation_post(x) 113 | x = self.conv_post(x) 114 | return x 115 | 116 | def remove_weight_norm(self): 117 | print('Removing weight norm...') 118 | for l in self.ups: 119 | remove_weight_norm(l) 120 | for l in self.resblocks: 121 | 
l.remove_weight_norm() 122 | remove_weight_norm(self.conv_pre) 123 | 124 | class Quantizer_module(torch.nn.Module): 125 | def __init__(self, n_e, e_dim): 126 | super(Quantizer_module, self).__init__() 127 | self.embedding = nn.Embedding(n_e, e_dim) 128 | self.embedding.weight.data.uniform_(-1.0 / n_e, 1.0 / n_e) 129 | self.target = torch.arange(0,n_e) 130 | 131 | def forward(self, x, idx=0): 132 | loss=torch.Tensor([0.0]) 133 | d = torch.sum(x ** 2, 1, keepdim=True) + torch.sum(self.embedding.weight ** 2, 1) \ 134 | - 2 * torch.matmul(x, self.embedding.weight.T) 135 | min_indicies = torch.argmin(d, 1) 136 | z_q = self.embedding(min_indicies) 137 | embed_vec = self.embedding.weight 138 | embed_dis = torch.mm(embed_vec , embed_vec.T)*3 139 | self.target = torch.arange(0,embed_vec.shape[0]).to(x.device) 140 | loss = F.cross_entropy(embed_dis,self.target)*(idx==0) 141 | return z_q, min_indicies,loss 142 | 143 | class Quantizer(torch.nn.Module): 144 | def __init__(self, h): 145 | super(Quantizer, self).__init__() 146 | assert h.vq_dim % h.n_code_groups == 0 147 | self.lm_offset = 0 148 | self.lm_states = None 149 | self.vq_dim = h.vq_dim 150 | self.residul_layer = h.n_q 151 | self.n_code_groups = h.n_code_groups 152 | self.quantizer_modules = nn.ModuleList() 153 | for i in range(self.residul_layer): 154 | self.quantizer_modules.append(nn.ModuleList([ 155 | Quantizer_module(h.n_codes, self.vq_dim // h.n_code_groups) for _ in range(h.n_code_groups) 156 | ])) 157 | self.h = h 158 | self.codebook_loss_lambda = self.h.codebook_loss_lambda # e.g., 1 159 | self.commitment_loss_lambda = self.h.commitment_loss_lambda # e.g., 0.25 160 | 161 | 162 | def for_one_step(self, xin, idx): 163 | xin = xin.transpose(1, 2) 164 | x = xin.reshape(-1, self.vq_dim) 165 | x = torch.split(x, self.vq_dim // self.h.n_code_groups, dim=-1) 166 | min_indicies = [] 167 | z_q = [] 168 | all_losses = [] 169 | for _x, m in zip(x, self.quantizer_modules[idx]): 170 | _z_q, _min_indicies,_loss = m(_x,idx) 171 | all_losses.append(_loss) 172 | z_q.append(_z_q) 173 | min_indicies.append(_min_indicies) 174 | z_q = torch.cat(z_q, -1).reshape(xin.shape) 175 | z_q = z_q.transpose(1, 2) 176 | all_losses = torch.stack(all_losses) 177 | loss = torch.mean(all_losses) 178 | return z_q, min_indicies, loss 179 | 180 | 181 | def forward(self, xin,bw=-1,mask_id=None): 182 | quantized_out = 0.0 183 | residual = xin 184 | all_losses = [] 185 | all_indices = [] 186 | if bw<=0: 187 | bw = self.residul_layer 188 | for i in range(bw): 189 | quantized, indices, e_loss = self.for_one_step(residual, i) # 190 | if mask_id is not None: 191 | mask = ( 192 | torch.full([xin.shape[0],xin.shape[2],1], fill_value=i, device=xin.device) < mask_id.unsqueeze(2) + 1 193 | ) 194 | mask = mask.repeat(1,1,xin.shape[1]).transpose(1,2) 195 | if mask_id is not None: 196 | loss = 0.1 * e_loss + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 * mask) \ 197 | + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 * mask ) 198 | else: 199 | loss = 0.1 * e_loss \ 200 | + self.codebook_loss_lambda * torch.mean((quantized - residual.detach()) ** 2 ) \ 201 | + self.commitment_loss_lambda * torch.mean((quantized.detach() - residual) ** 2 ) 202 | 203 | quantized = residual + (quantized - residual).detach() 204 | residual = residual - quantized 205 | if mask_id is not None: 206 | quantized_out = quantized_out + quantized * mask 207 | else: 208 | quantized_out = quantized_out + quantized 209 | all_indices.extend(indices) # 210 | 
all_losses.append(loss) 211 | all_losses = torch.stack(all_losses) 212 | loss = torch.mean(all_losses) 213 | return quantized_out, loss, all_indices 214 | 215 | def embed(self, x , bw=-1): 216 | quantized_out = torch.tensor(0.0, device=x.device) 217 | x = torch.split(x, 1, 2) 218 | if bw <= 0 or bw > self.residul_layer: 219 | bw = self.residul_layer 220 | for i in range(bw): 221 | ret = [] 222 | for j in range(self.n_code_groups): 223 | q = x[j+self.n_code_groups*i] 224 | embed = self.quantizer_modules[i][j] 225 | q = embed.embedding(q.squeeze(-1)) 226 | ret.append(q) 227 | ret = torch.cat(ret, -1) 228 | quantized_out = quantized_out + ret 229 | return quantized_out.transpose(1, 2) 230 | 231 | 232 | class Generator(torch.nn.Module): 233 | def __init__(self, h): 234 | super(Generator, self).__init__() 235 | self.h = h 236 | self.n_filters = h.de_filters 237 | self.vq_dim = h.vq_dim 238 | self.num_kernels = len(h.resblock_kernel_sizes) 239 | self.num_upsamples = len(h.upsample_rates) 240 | self.upsample_initial_channel = self.n_filters * ( 2**self.num_upsamples ) 241 | self.conv_pre = weight_norm(Conv1d(self.vq_dim, self.upsample_initial_channel, 7, 1, padding=3)) 242 | resblock = ResBlock1 243 | 244 | 245 | self.norm = nn.Identity() 246 | 247 | self.ups = nn.ModuleList() 248 | for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): 249 | self.ups.append(weight_norm( 250 | ConvTranspose1d( 251 | self.upsample_initial_channel//(2**i), self.upsample_initial_channel//(2**(i+1)), 252 | k, u, 253 | padding=(k - u )//2, 254 | ) 255 | )) 256 | ch = 1 257 | self.resblocks = nn.ModuleList() 258 | for i in range(len(self.ups)): 259 | ch = self.upsample_initial_channel//(2**(i+1)) 260 | for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): 261 | self.resblocks.append(resblock(h, ch, k, d)) 262 | 263 | 264 | self.activation_post = nn.LeakyReLU(LRELU_SLOPE) 265 | self.conv_post = weight_norm(Conv1d(ch, h.channel, 7, 1, padding=3)) 266 | self.ups.apply(init_weights) 267 | self.conv_post.apply(init_weights) 268 | 269 | def forward(self, x): 270 | x = self.norm(x) 271 | x = self.conv_pre(x) 272 | 273 | for i in range(self.num_upsamples): 274 | x = self.ups[i](x) 275 | xs = None 276 | for j in range(self.num_kernels): 277 | if xs is None: 278 | xs = self.resblocks[i*self.num_kernels+j](x) 279 | else: 280 | xs += self.resblocks[i*self.num_kernels+j](x) 281 | x = xs / self.num_kernels 282 | x = self.activation_post(x) 283 | x = self.conv_post(x) 284 | x = torch.tanh(x) 285 | 286 | return x 287 | 288 | def remove_weight_norm(self): 289 | print('Removing weight norm...') 290 | for l in self.ups: 291 | remove_weight_norm(l) 292 | for l in self.resblocks: 293 | l.remove_weight_norm() 294 | remove_weight_norm(self.conv_pre) 295 | remove_weight_norm(self.conv_post) -------------------------------------------------------------------------------- /configs/gense.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | n2s_ckpt_path: ckpts/n2s_xlsr.ckpt 3 | s2s_ckpt_path: ckpts/s2s_xlsr.ckpt 4 | codec_config_path: ckpts/config.json 5 | 6 | model: 7 | hidden_size: 1024 8 | # intermediate_size: 2048 9 | num_hidden_layers: 12 10 | num_attention_heads: 8 11 | n2s_vocab_size: 1027 #1024 + 1+1+1 12 | s2s_vocab_size: 9219 #8192 + 1024 + 1+1+1 13 | semantic_num: 1024 14 | 15 | ssl_model: 16 | ckpt_path: ckpts/xlsr2_300m.pt 17 | km_path: ckpts/xlsr_km.mdl 18 | type: xlsr 
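
The vocabulary sizes and shift values in this config encode the token layout used by models/gense.py (pad=0, bos=1, eos=2, then the semantic ids, then the SimCodec ids). Below is a minimal sketch of how the sizes and offsets line up; the constant names are illustrative only and are not defined anywhere in the repo.

```
# Illustrative constants; the repo hard-codes these values in gense.yaml's
# comments and in models/gense.py.
NUM_SPECIAL  = 3     # pad=0, bos=1, eos=2
SEMANTIC_NUM = 1024  # semantic token ids (semantic_num in the config)
ACOUSTIC_NUM = 8192  # SimCodec codebook entries (per the config comment)

n2s_vocab_size = NUM_SPECIAL + SEMANTIC_NUM                 # 1027, N2S LM vocab
s2s_vocab_size = NUM_SPECIAL + SEMANTIC_NUM + ACOUSTIC_NUM  # 9219, S2S LM vocab

semantic_shift = NUM_SPECIAL                 # N2S/S2S add 3 to semantic ids
acoustic_shift = NUM_SPECIAL + SEMANTIC_NUM  # S2S adds 1027 to codec ids

assert (n2s_vocab_size, s2s_vocab_size, acoustic_shift) == (1027, 9219, 1027)
```
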
-------------------------------------------------------------------------------- /configs/gense_wavlm.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | n2s_ckpt_path: ckpts/n2s_wavlm.ckpt 3 | s2s_ckpt_path: ckpts/s2s_wavlm.ckpt 4 | codec_config_path: ckpts/config.json 5 | 6 | model: 7 | hidden_size: 1024 8 | # intermediate_size: 2048 9 | num_hidden_layers: 12 10 | num_attention_heads: 8 11 | n2s_vocab_size: 1027 #1024 + 1+1+1 12 | s2s_vocab_size: 9219 #8192 + 1024 + 1+1+1 13 | semantic_num: 1024 14 | 15 | ssl_model: 16 | ckpt_path: ckpts/WavLM-Large.pt 17 | km_path: ckpts/wavlm_km.mdl 18 | type: wavlm -------------------------------------------------------------------------------- /fig/gense.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/fig/gense.png -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import torch 3 | import torchaudio 4 | import yaml 5 | 6 | from models.gense import N2S, S2S 7 | 8 | class AttrDict(dict): 9 | def __init__(self, *args, **kwargs): 10 | super(AttrDict, self).__init__(*args, **kwargs) 11 | self.__dict__ = self 12 | 13 | def get_firstchannel_read(path, target_sr=16000): 14 | wav, sr = torchaudio.load(path) 15 | if wav.shape[0] > 1: 16 | wav = wav[0].unsqueeze(0) 17 | if sr != target_sr: 18 | resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr) 19 | wav = resampler(wav) 20 | return wav.unsqueeze(0) 21 | 22 | 23 | def run(noisy_path, out_path, config_path): 24 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 25 | with open(config_path, "r") as f: 26 | config = yaml.safe_load(f) 27 | config = AttrDict(config) 28 | 29 | noisy_wav = get_firstchannel_read(noisy_path).to(device) 30 | 31 | n2s_model = N2S(config) 32 | n2s_model.load_state_dict(torch.load(config.path['n2s_ckpt_path'])["state_dict"]) 33 | n2s_model = n2s_model.eval() 34 | n2s_model = n2s_model.to(device) 35 | 36 | s2s_model = S2S(config) 37 | s2s_model.load_state_dict(torch.load(config.path['s2s_ckpt_path'])["state_dict"]) 38 | s2s_model = s2s_model.eval() 39 | s2s_model = s2s_model.to(device) 40 | 41 | noisy_s, clean_s = n2s_model.generate(noisy_wav) 42 | enhanced_wav = s2s_model.generate(noisy_wav, noisy_s, clean_s) 43 | torchaudio.save(out_path, enhanced_wav, sample_rate=16000) 44 | 45 | 46 | if __name__ == "__main__": 47 | fire.Fire( 48 | { 49 | "run": run, 50 | } 51 | ) -------------------------------------------------------------------------------- /models/gense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from components.semantic_extractor.ssl_model import get_ssl_model 8 | from components.simcodec.model import SimCodec 9 | from transformers import GPT2Config, GPT2LMHeadModel 10 | 11 | class N2S(nn.Module): 12 | def __init__(self, hps): 13 | super().__init__() 14 | self.hps = hps 15 | self.xlsr, self.km = get_ssl_model(**hps.ssl_model) 16 | self.bos = 1 17 | self.eos = 2 18 | self.pad = 0 19 | self.shift_num = 3 20 | 21 | self.lm_conf = GPT2Config( 22 | vocab_size=self.hps.model['n2s_vocab_size'], 23 | n_embd=self.hps.model['hidden_size'], 24 | 
n_layer=self.hps.model['num_hidden_layers'], 25 | n_head=self.hps.model['num_attention_heads'], 26 | activation_function='gelu_new', 27 | n_positions=2048, 28 | n_ctx=2048, 29 | resid_pdrop=0.1, 30 | embd_pdrop=0.1, 31 | attn_pdrop=0.1, 32 | layer_norm_epsilon=1e-05, 33 | initializer_range=0.02, 34 | summary_type='mean', 35 | summary_use_proj=True, 36 | summary_activation=None, 37 | summary_proj_to_labels=True, 38 | summary_first_dropout=0.1, 39 | bos_token_id=self.bos, 40 | eos_token_id=self.eos, 41 | ) 42 | self.lm = GPT2LMHeadModel(self.lm_conf) 43 | 44 | def extract_semantic(self, wavs, num_frames): 45 | padding_size = (0, 100) 46 | wavs = F.pad(wavs, padding_size, "constant", 0) 47 | num_frames += 100 48 | features = self.xlsr.extract_features(wavs, padding_mask=None) 49 | layer_results = features['layer_results'][5] 50 | x, _, _ = layer_results 51 | features = x.transpose(0,1) 52 | b, t, d = features.shape 53 | tokens = self.km(features.reshape(-1, d), b=b, t=t) 54 | return tokens 55 | 56 | def inference(self, token_gen, pos_gen): 57 | predict_len = (token_gen.shape[1] - 1) 58 | truck_length = token_gen.shape[1] 59 | 60 | for j in tqdm(range(predict_len)): 61 | lm_outputs = self.lm( 62 | input_ids=token_gen, 63 | attention_mask=None, 64 | position_ids=pos_gen 65 | ) 66 | logits = lm_outputs['logits'] 67 | logits[:, :, 0:self.shift_num] = -1e5 68 | probs = logits[:, -1, :].softmax(dim=-1) 69 | 70 | dist = torch.distributions.categorical.Categorical(probs=probs) 71 | 72 | samples = dist.sample().unsqueeze(1).to(token_gen.device) 73 | token_gen = torch.cat([token_gen, samples], dim=1) 74 | pos_pad = torch.ones(pos_gen.shape[0]) * j 75 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1) 76 | 77 | return token_gen[:,truck_length:][0] 78 | 79 | 80 | def generate(self, mix): 81 | mix = mix.squeeze(1) 82 | num_frame = torch.LongTensor([mix.shape[1]]).to(mix.device) 83 | token_s = self.extract_semantic(mix, num_frames=num_frame) 84 | 85 | token_s += 3 86 | bos = torch.ones(token_s.shape[0],1).long().to(mix.device) 87 | token_gen = torch.cat([token_s, bos], dim=1) 88 | 89 | pos_gen_id = torch.from_numpy(np.asarray(list(range(token_s.shape[1] + 1)))).to(mix.device) 90 | pos_gen = [] 91 | for i in range(token_s.shape[0]): 92 | pos_gen.append(pos_gen_id.unsqueeze(0)) 93 | pos_gen = torch.cat(pos_gen, dim=0) 94 | 95 | clean_s = self.inference(token_gen, pos_gen) - self.shift_num 96 | token_s -= self.shift_num 97 | return token_s, clean_s 98 | 99 | 100 | class S2S(nn.Module): 101 | def __init__(self, hps): 102 | super().__init__() 103 | self.hps = hps 104 | self.codec_tokenizer = SimCodec(hps.path['codec_config_path']) 105 | self.xlsr, self.km = get_ssl_model(**hps.ssl_model) 106 | self.bos = 1 107 | self.eos = 2 108 | self.pad = 0 109 | self.shift_num = 3 + self.hps.model['semantic_num'] 110 | self.lm_conf = GPT2Config( 111 | vocab_size=self.hps.model['s2s_vocab_size'], 112 | n_embd=self.hps.model['hidden_size'], 113 | n_layer=self.hps.model['num_hidden_layers'], 114 | n_head=self.hps.model['num_attention_heads'], 115 | activation_function='gelu_new', 116 | n_positions=4096, 117 | n_ctx=4096, 118 | resid_pdrop=0.1, 119 | embd_pdrop=0.1, 120 | attn_pdrop=0.1, 121 | layer_norm_epsilon=1e-05, 122 | initializer_range=0.02, 123 | summary_type='mean', 124 | summary_use_proj=True, 125 | summary_activation=None, 126 | summary_proj_to_labels=True, 127 | summary_first_dropout=0.1, 128 | bos_token_id=self.bos, 129 | eos_token_id=self.eos, 130 | ) 131 | self.lm = 
GPT2LMHeadModel(self.lm_conf) 132 | 133 | def inference(self, token_gen, pos_gen): 134 | predict_len = int((token_gen.shape[1] - 1) / 2) 135 | truck_length = token_gen.shape[1] 136 | for j in tqdm(range(predict_len)): 137 | lm_outputs = self.lm( 138 | input_ids=token_gen, 139 | attention_mask=None, 140 | position_ids=pos_gen 141 | ) 142 | logits = lm_outputs['logits'] 143 | logits[:, :, 0:self.shift_num] = -1e5 144 | probs = logits[:, -1, :].softmax(dim=-1) 145 | dist = torch.distributions.categorical.Categorical(probs=probs) 146 | samples = dist.sample().unsqueeze(1).to(token_gen.device) 147 | token_gen = torch.cat([token_gen, samples], dim=1) 148 | pos_pad = torch.ones(pos_gen.shape[0]) * (j + 1000) 149 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1) 150 | 151 | return token_gen[:,truck_length:][0] 152 | 153 | def generate(self, mix, mix_s, clean_s): 154 | mix_a = self.codec_tokenizer(mix).squeeze(-1) 155 | if len(clean_s.shape) == 1: 156 | clean_s = clean_s.unsqueeze(0) 157 | 158 | mix_s += 3 159 | clean_s += 3 160 | mix_a += self.shift_num 161 | 162 | bos = torch.ones(mix_s.shape[0],1).long().to(mix.device) 163 | token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1) 164 | 165 | pos_gen_id = torch.from_numpy(np.asarray(list(range(mix_s.shape[1] + clean_s.shape[1] + 1)) + list(range(mix_a.shape[1])))).to(mix.device) 166 | pos_gen = [] 167 | for i in range(mix_s.shape[0]): 168 | pos_gen.append(pos_gen_id.unsqueeze(0)) 169 | pos_gen = torch.cat(pos_gen, dim=0) 170 | 171 | pre_a = self.inference(token_gen, pos_gen) - self.shift_num 172 | gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu() 173 | 174 | return gen_wav -------------------------------------------------------------------------------- /models/gense_wavlm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | from components.semantic_extractor.ssl_model import get_ssl_model 8 | from components.simcodec.model import SimCodec 9 | from transformers import GPT2Config, GPT2LMHeadModel 10 | 11 | class N2S(nn.Module): 12 | def __init__(self, hps): 13 | super().__init__() 14 | self.hps = hps 15 | self.wavlm, self.km = get_ssl_model(**hps.ssl_model) 16 | self.bos = 1 17 | self.eos = 2 18 | self.pad = 0 19 | self.shift_num = 3 20 | 21 | self.lm_conf = GPT2Config( 22 | vocab_size=self.hps.model['n2s_vocab_size'], 23 | n_embd=self.hps.model['hidden_size'], 24 | n_layer=self.hps.model['num_hidden_layers'], 25 | n_head=self.hps.model['num_attention_heads'], 26 | activation_function='gelu_new', 27 | n_positions=2048, 28 | n_ctx=2048, 29 | resid_pdrop=0.1, 30 | embd_pdrop=0.1, 31 | attn_pdrop=0.1, 32 | layer_norm_epsilon=1e-05, 33 | initializer_range=0.02, 34 | summary_type='mean', 35 | summary_use_proj=True, 36 | summary_activation=None, 37 | summary_proj_to_labels=True, 38 | summary_first_dropout=0.1, 39 | bos_token_id=self.bos, 40 | eos_token_id=self.eos, 41 | ) 42 | self.lm = GPT2LMHeadModel(self.lm_conf) 43 | 44 | def extract_semantic(self, wavs, num_frames): 45 | padding_size = (0, 100) 46 | wavs = F.pad(wavs, padding_size, "constant", 0) 47 | num_frames += 100 48 | features = self.wavlm.extract_features( 49 | wavs, 50 | output_layer=6, 51 | ret_layer_results=False, 52 | input_length=num_frames 53 | )[0] 54 | b, t, d = features.shape 55 | tokens = self.km(features.reshape(-1, d), b=b, t=t) 56 | 
return tokens 57 | 58 | def inference(self, token_gen, pos_gen): 59 | predict_len = (token_gen.shape[1] - 1) 60 | truck_length = token_gen.shape[1] 61 | 62 | for j in tqdm(range(predict_len)): 63 | lm_outputs = self.lm( 64 | input_ids=token_gen, 65 | attention_mask=None, 66 | position_ids=pos_gen 67 | ) 68 | logits = lm_outputs['logits'] 69 | logits[:, :, 0:self.shift_num] = -1e5 70 | probs = logits[:, -1, :].softmax(dim=-1) 71 | 72 | dist = torch.distributions.categorical.Categorical(probs=probs) 73 | 74 | samples = dist.sample().unsqueeze(1).to(token_gen.device) 75 | token_gen = torch.cat([token_gen, samples], dim=1) 76 | pos_pad = torch.ones(pos_gen.shape[0]) * j 77 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1) 78 | 79 | return token_gen[:,truck_length:][0] 80 | 81 | 82 | def generate(self, mix): 83 | mix = mix.squeeze(1) 84 | num_frame = torch.LongTensor([mix.shape[1]]).to(mix.device) 85 | token_s = self.extract_semantic(mix, num_frames=num_frame) 86 | 87 | token_s += 3 88 | bos = torch.ones(token_s.shape[0],1).long().to(mix.device) 89 | token_gen = torch.cat([token_s, bos], dim=1) 90 | 91 | pos_gen_id = torch.from_numpy(np.asarray(list(range(token_s.shape[1] + 1)))).to(mix.device) 92 | pos_gen = [] 93 | for i in range(token_s.shape[0]): 94 | pos_gen.append(pos_gen_id.unsqueeze(0)) 95 | pos_gen = torch.cat(pos_gen, dim=0) 96 | 97 | clean_s = self.inference(token_gen, pos_gen) - self.shift_num 98 | token_s -= self.shift_num 99 | return token_s, clean_s 100 | 101 | 102 | class S2S(nn.Module): 103 | def __init__(self, hps): 104 | super().__init__() 105 | self.hps = hps 106 | self.codec_tokenizer = SimCodec(hps.path['codec_config_path']) 107 | self.wavlm, self.km = get_ssl_model(**hps.ssl_model) 108 | self.bos = 1 109 | self.eos = 2 110 | self.pad = 0 111 | self.shift_num = 3 + self.hps.model['semantic_num'] 112 | self.lm_conf = GPT2Config( 113 | vocab_size=self.hps.model['s2s_vocab_size'], 114 | n_embd=self.hps.model['hidden_size'], 115 | n_layer=self.hps.model['num_hidden_layers'], 116 | n_head=self.hps.model['num_attention_heads'], 117 | activation_function='gelu_new', 118 | n_positions=4096, 119 | n_ctx=4096, 120 | resid_pdrop=0.1, 121 | embd_pdrop=0.1, 122 | attn_pdrop=0.1, 123 | layer_norm_epsilon=1e-05, 124 | initializer_range=0.02, 125 | summary_type='mean', 126 | summary_use_proj=True, 127 | summary_activation=None, 128 | summary_proj_to_labels=True, 129 | summary_first_dropout=0.1, 130 | bos_token_id=self.bos, 131 | eos_token_id=self.eos, 132 | ) 133 | self.lm = GPT2LMHeadModel(self.lm_conf) 134 | 135 | def inference(self, token_gen, pos_gen): 136 | predict_len = int((token_gen.shape[1] - 1) / 2) 137 | truck_length = token_gen.shape[1] 138 | for j in tqdm(range(predict_len)): 139 | lm_outputs = self.lm( 140 | input_ids=token_gen, 141 | attention_mask=None, 142 | position_ids=pos_gen 143 | ) 144 | logits = lm_outputs['logits'] 145 | logits[:, :, 0:self.shift_num] = -1e5 146 | probs = logits[:, -1, :].softmax(dim=-1) 147 | dist = torch.distributions.categorical.Categorical(probs=probs) 148 | samples = dist.sample().unsqueeze(1).to(token_gen.device) 149 | token_gen = torch.cat([token_gen, samples], dim=1) 150 | pos_pad = torch.ones(pos_gen.shape[0]) * (j + 1000) 151 | pos_gen = torch.cat([pos_gen, pos_pad.unsqueeze(1).to(token_gen.device).long()], dim=1) 152 | 153 | return token_gen[:,truck_length:][0] 154 | 155 | def generate(self, mix, mix_s, clean_s): 156 | mix_a = self.codec_tokenizer(mix).squeeze(-1) 157 | if len(clean_s.shape) == 1: 
158 | clean_s = clean_s.unsqueeze(0) 159 | 160 | mix_s += 3 161 | clean_s += 3 162 | mix_a += self.shift_num 163 | 164 | bos = torch.ones(mix_s.shape[0],1).long().to(mix.device) 165 | token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1) 166 | 167 | pos_gen_id = torch.from_numpy(np.asarray(list(range(mix_s.shape[1] + clean_s.shape[1] + 1)) + list(range(mix_a.shape[1])))).to(mix.device) 168 | pos_gen = [] 169 | for i in range(mix_s.shape[0]): 170 | pos_gen.append(pos_gen_id.unsqueeze(0)) 171 | pos_gen = torch.cat(pos_gen, dim=0) 172 | 173 | pre_a = self.inference(token_gen, pos_gen) - self.shift_num 174 | gen_wav = self.codec_tokenizer.decode(pre_a.unsqueeze(0).unsqueeze(2)).squeeze(0).cpu() 175 | 176 | return gen_wav -------------------------------------------------------------------------------- /noisy.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yaoxunji/gen-se/bc14cd3b3ae3b131537bfa73501bd504985eecd5/noisy.wav -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.22.3 2 | torch==1.13.1 3 | torchaudio==0.13.1 4 | fire 5 | PyYAML==6.0.2 6 | joblib==1.4.0 7 | scikit-learn==1.3.2 8 | tqdm 9 | librosa==0.8.0 10 | transformers==4.40.1 11 | fairseq==0.12.2 --------------------------------------------------------------------------------
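
The S2S stage in models/gense.py packs several token streams into a single GPT-2 prompt. The sketch below is a toy, self-contained illustration of that layout using dummy lengths and random ids; only the concatenation order, the id offsets, and the position-id ramps mirror the code.

```
import torch

# Toy illustration of the prompt built in S2S.generate (dummy lengths, random ids).
T_sem, T_ac = 4, 6                                  # noisy-semantic / noisy-acoustic lengths
mix_s   = torch.randint(0, 1024, (1, T_sem)) + 3    # noisy semantic ids, shifted past pad/bos/eos
clean_s = torch.randint(0, 1024, (1, T_sem)) + 3    # clean semantic ids from the N2S stage
mix_a   = torch.randint(0, 8192, (1, T_ac)) + 1027  # noisy SimCodec ids, shifted past 3 + 1024
bos     = torch.ones(1, 1).long()

# Prompt order: [noisy semantic | clean semantic | BOS | noisy acoustic]; the LM
# then continues autoregressively with clean acoustic tokens, which SimCodec
# decodes back to a waveform.
token_gen = torch.cat([mix_s, clean_s, bos, mix_a], dim=1)

# Position ids: one ramp over the semantic prefix (+ BOS), then a fresh ramp over
# the acoustic segment; generated steps continue from an offset of 1000
# (see S2S.inference).
pos_gen = torch.tensor(list(range(2 * T_sem + 1)) + list(range(T_ac))).unsqueeze(0)

print(token_gen.shape, pos_gen.shape)  # torch.Size([1, 15]) torch.Size([1, 15])
```
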