├── module ├── __init__.py ├── layers.py ├── losses.py ├── cnhubert.py ├── wavenet.py ├── mel_processing.py ├── commons.py ├── transforms.py ├── quantize.py ├── data_utils.py └── core_vq.py ├── pretrain └── .gitkeep ├── data_conf.py ├── asr ├── __init__.py ├── data_module.py ├── meldataset.py ├── models.py ├── trainer.py └── layers.py ├── img └── magvits.png ├── text ├── cmudict_cache.pickle ├── __init__.py ├── cleaner.py ├── bert.py ├── symbols.py ├── japanese.py ├── english.py ├── chinese.py └── opencpop-strict.txt ├── transformer ├── __init__.py ├── lr_schedulers.py ├── embedding.py ├── transformer.py └── scaling.py ├── requirements.txt ├── configs ├── asr.yml └── vits.json ├── gen_filelist.py ├── README.md ├── resample.py ├── extract_spk_embedding.py ├── extract_ssl.py ├── gen_phonemes.py ├── asr_train.py ├── inference.py ├── extract_duration.py ├── utils.py └── vits_train.py /module/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pretrain/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /data_conf.py: -------------------------------------------------------------------------------- 1 | data_root = 'dataset_raw' 2 | -------------------------------------------------------------------------------- /asr/__init__.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/yl4579/AuxiliaryASR/ -------------------------------------------------------------------------------- /img/magvits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/MagVITS/HEAD/img/magvits.png -------------------------------------------------------------------------------- /text/cmudict_cache.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/innnky/MagVITS/HEAD/text/cmudict_cache.pickle -------------------------------------------------------------------------------- /transformer/__init__.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/modules 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cn2an 2 | jieba 3 | pypinyin 4 | g2p_en 5 | transformers 6 | torch 7 | torchaudio 8 | git+https://github.com/resemble-ai/monotonic_align.git 9 | jiwer 10 | tensorboard -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | from text.symbols import * 2 | 3 | 4 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 5 | 6 | def cleaned_text_to_sequence(cleaned_text): 7 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
8 |     Args:
9 |       cleaned_text: string of phone symbols to convert to a sequence
10 |     Returns:
11 |       List of integers corresponding to the symbols in the text
12 |     '''
13 |     phones = [_symbol_to_id[symbol] for symbol in cleaned_text]
14 |     return phones
15 | 
16 | 
--------------------------------------------------------------------------------
/configs/asr.yml:
--------------------------------------------------------------------------------
1 | log_dir: "logs/asr"
2 | save_freq: 1
3 | device: "cuda"
4 | epochs: 200
5 | batch_size: 16
6 | pretrained_model: 'pretrain/asr.ckpt'
7 | train_data: "Data/train.txt"
8 | val_data: "Data/val.txt"
9 | load_only_params: False
10 | preprocess_parasm:
11 |   sr: 32000
12 |   spect_params:
13 |     n_fft: 2048
14 |     win_length: 2048
15 |     hop_length: 640
16 |   mel_params:
17 |     n_mels: 128
18 | 
19 | model_params:
20 |   input_dim: 128
21 |   hidden_dim: 256
22 |   n_token: 400
23 |   token_embedding_dim: 256
24 | 
25 | optimizer_params:
26 |   lr: 0.0005
27 | 
--------------------------------------------------------------------------------
/gen_filelist.py:
--------------------------------------------------------------------------------
1 | from glob import glob
2 | from random import shuffle
3 | from data_conf import data_root
4 | from tqdm import tqdm
5 | import os
6 | filenames = glob(f"{data_root}/**/*.wav", recursive=True) # [:10]
7 | filenames += glob(f"{data_root}/**/*.mp3", recursive=True) # [:10]
8 | filenames = [f for f in tqdm(filenames)]
9 | shuffle(filenames)
10 | 
11 | val_num = 4
12 | print(len(filenames))
13 | train = filenames[:-val_num]
14 | val = filenames[-val_num:]
15 | train.sort()
16 | val.sort()
17 | 
18 | with open('dump/train_files.list', 'w') as f:
19 |     f.write('\n'.join(train))
20 | with open('dump/val_files.list', 'w') as f:
21 |     f.write('\n'.join(val))
22 | 
23 | 
24 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MagVITS
2 | VITS with phoneme-level prosody modeling based on MaskGIT (WIP)
3 | 
4 | Features: inference speed roughly on par with bert-vits2, and (possibly) better prosody than bert-vits2
5 | 
6 | The code is currently being refactored and may not run yet; running it is not recommended for now.
7 | A Chinese pretrained model will be uploaded soon (data: Genshin Impact Chinese + AISHELL, a little over 200 hours in total).
8 | 
9 | ### Structure
10 | ![](img/magvits.png)
11 | 
12 | ### Acknowledgements
13 | + Thanks to [leng-yue](https://github.com/leng-yue) and [fishaudio](https://github.com/fishaudio) for providing the GPUs
14 | + [VITS](https://github.com/jaywalnut310/vits)
15 | + [MaskGIT](https://github.com/valeoai/Maskgit-pytorch/blob/main/Trainer/vit.py)
16 | + [AuxiliaryASR and styletts2](https://github.com/yl4579/AuxiliaryASR/)
17 | + [MegaTTS](https://arxiv.org/abs/2306.03509)
18 | + [descript-audio-codec](https://github.com/descriptinc/descript-audio-codec)
19 | + [visinger](https://github.com/zhangyongmao/VISinger2)
20 | 
--------------------------------------------------------------------------------
/text/cleaner.py:
--------------------------------------------------------------------------------
1 | from text import chinese, japanese, cleaned_text_to_sequence, symbols, english
2 | 
3 | language_module_map = {
4 |     'zh': chinese,
5 |     "ja": japanese,
6 |     'en': english
7 | }
8 | def clean_text(text, language):
9 |     language_module = language_module_map[language]
10 |     norm_text = language_module.text_normalize(text)
11 |     phones, word2ph = language_module.g2p(norm_text)
12 |     # assert len(phones) == sum(word2ph)
13 |     # assert len(norm_text) == len(word2ph)
14 | 
15 |     for ph in phones:
16 |         assert ph in symbols
17 |     return phones, word2ph, norm_text
18 | 
19 | 
20 | def text_to_sequence(text, language):
21 |     phones, word2ph, norm_text = clean_text(text, language)
22 |     return cleaned_text_to_sequence(phones)
23 | 
24 | if __name__ == '__main__':
25 |     print(clean_text("你好%啊啊啊额、还是到付红四方。", 'zh'))
26 | 
27 | 
28 | 
--------------------------------------------------------------------------------
/module/layers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange
6 | from torch.nn.utils import weight_norm
7 | 
8 | 
9 | def WNConv1d(*args, **kwargs):
10 |     return weight_norm(nn.Conv1d(*args, **kwargs))
11 | 
12 | 
13 | def WNConvTranspose1d(*args, **kwargs):
14 |     return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15 | 
16 | 
17 | # Scripting this brings model speed up 1.4x
18 | @torch.jit.script
19 | def snake(x, alpha):
20 |     shape = x.shape
21 |     x = x.reshape(shape[0], shape[1], -1)
22 |     x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
23 |     x = x.reshape(shape)
24 |     return x
25 | 
26 | 
27 | class Snake1d(nn.Module):
28 |     def __init__(self, channels):
29 |         super().__init__()
30 |         self.alpha = nn.Parameter(torch.ones(1, channels, 1))
31 | 
32 |     def forward(self, x):
33 |         return snake(x, self.alpha)
34 | 
--------------------------------------------------------------------------------
/resample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from glob import glob
4 | 
5 | import librosa
6 | import soundfile
7 | from tqdm import tqdm
8 | from multiprocessing import Pool
9 | 
10 | from data_conf import data_root
11 | 
12 | def process_wav(wavpath):
13 |     wav, _ = librosa.load(wavpath, sr=tgt_sr)
14 |     soundfile.write(wavpath, wav, tgt_sr)
15 | 
16 | def get_wav_files(path):
17 |     wav_files = []
18 |     for root, dirs, files in os.walk(path):
19 |         for file in files:
20 |             if file.endswith(".wav") or file.endswith(".mp3"):
21 |                 wav_files.append(os.path.join(root, file))
22 |     return wav_files
23 | 
24 | tgt_path = data_root
25 | 
26 | num_processes = 10 # You can adjust the number of processes as needed
27 | tgt_sr = 32000
28 | 
29 | print("Note: this script will overwrite the original files!")
30 | print("All the wav/mp3 files under {} will be resampled to {} Hz".format(tgt_path, tgt_sr))
31 | input("press enter to continue... 
or press Ctrl+C to cancel") 32 | 33 | if __name__ == "__main__": 34 | with Pool(num_processes) as pool: 35 | file_list = get_wav_files(tgt_path) 36 | list(tqdm(pool.imap(process_wav, file_list), total=len(file_list))) 37 | 38 | -------------------------------------------------------------------------------- /text/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, AutoModelForMaskedLM 3 | 4 | bert_models = None 5 | tokenizer = AutoTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext-large") 6 | 7 | def get_bert_feature(text, word2ph, device, language): 8 | global bert_models 9 | if language != "zh": 10 | return torch.zeros(1024, sum(word2ph)) 11 | 12 | if bert_models == None: 13 | bert_models = AutoModelForMaskedLM.from_pretrained( 14 | "hfl/chinese-roberta-wwm-ext-large" 15 | ).to(device) 16 | print('loaded bert model at rank', device) 17 | 18 | with torch.no_grad(): 19 | inputs = tokenizer(text, return_tensors="pt") 20 | for i in inputs: 21 | inputs[i] = inputs[i].to(device) 22 | res = bert_models(**inputs, output_hidden_states=True) 23 | res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1] 24 | 25 | assert len(word2ph) == len(text) 26 | phone_level_feature = [] 27 | for i in range(len(word2ph)): 28 | repeat_feature = res[i].repeat(word2ph[i], 1) 29 | phone_level_feature.append(repeat_feature) 30 | 31 | phone_level_feature = torch.cat(phone_level_feature, dim=0) 32 | 33 | return phone_level_feature.T 34 | 35 | -------------------------------------------------------------------------------- /asr/data_module.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | 3 | from asr.meldataset import build_dataloader 4 | 5 | 6 | 7 | def get_data_path_list(train_path=None, val_path=None): 8 | if train_path is None: 9 | train_path = "Data/train_list.txt" 10 | if val_path is None: 11 | val_path = "Data/val_list.txt" 12 | 13 | with open(train_path, 'r') as f: 14 | train_list = f.readlines() 15 | with open(val_path, 'r') as f: 16 | val_list = f.readlines() 17 | 18 | return train_list, val_list 19 | 20 | class ASRDataModule(pl.LightningDataModule): 21 | def __init__(self, data_dir='dump/', batch_size=64,num_workers=8): 22 | super().__init__() 23 | train_path = f'{data_dir}/train_files.list' 24 | val_path = f'{data_dir}/val_files.list' 25 | train_list, val_list = get_data_path_list(train_path, val_path) 26 | train_dataloader = build_dataloader(train_list, 27 | batch_size=batch_size, 28 | num_workers=num_workers, 29 | dataset_config={}) 30 | 31 | val_dataloader = build_dataloader(val_list, 32 | batch_size=batch_size, 33 | validation=True, 34 | num_workers=1, 35 | dataset_config={}) 36 | self.train_loader = train_dataloader 37 | self.val_loader = val_dataloader 38 | 39 | 40 | def train_dataloader(self): 41 | 42 | return self.train_loader 43 | 44 | def val_dataloader(self): 45 | 46 | return self.val_loader -------------------------------------------------------------------------------- /module/losses.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def feature_loss(fmap_r, fmap_g): 8 | loss = 0 9 | for dr, dg in zip(fmap_r, fmap_g): 10 | for rl, gl in zip(dr, dg): 11 | rl = rl.float().detach() 12 | gl = gl.float() 13 | loss += torch.mean(torch.abs(rl - gl)) 14 | 15 | return loss * 2 16 | 17 | 18 | def 
discriminator_loss(disc_real_outputs, disc_generated_outputs): 19 | loss = 0 20 | r_losses = [] 21 | g_losses = [] 22 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 23 | dr = dr.float() 24 | dg = dg.float() 25 | r_loss = torch.mean((1-dr)**2) 26 | g_loss = torch.mean(dg**2) 27 | loss += (r_loss + g_loss) 28 | r_losses.append(r_loss.item()) 29 | g_losses.append(g_loss.item()) 30 | 31 | return loss, r_losses, g_losses 32 | 33 | 34 | def generator_loss(disc_outputs): 35 | loss = 0 36 | gen_losses = [] 37 | for dg in disc_outputs: 38 | dg = dg.float() 39 | l = torch.mean((1-dg)**2) 40 | gen_losses.append(l) 41 | loss += l 42 | 43 | return loss, gen_losses 44 | 45 | 46 | def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): 47 | """ 48 | z_p, logs_q: [b, h, t_t] 49 | m_p, logs_p: [b, h, t_t] 50 | """ 51 | z_p = z_p.float() 52 | logs_q = logs_q.float() 53 | m_p = m_p.float() 54 | logs_p = logs_p.float() 55 | z_mask = z_mask.float() 56 | 57 | kl = logs_p - logs_q - 0.5 58 | kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) 59 | kl = torch.sum(kl * z_mask) 60 | l = kl / torch.sum(z_mask) 61 | return l 62 | 63 | def mle_loss(z, m, logs, logdet, mask): 64 | l = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * ((z - m)**2)) # neg normal likelihood w/o the constant term 65 | l = l - torch.sum(logdet) # log jacobian determinant 66 | l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes 67 | l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term 68 | return l -------------------------------------------------------------------------------- /configs/vits.json: -------------------------------------------------------------------------------- 1 | { 2 | "train": { 3 | "log_interval": 200, 4 | "eval_interval": 1000, 5 | "seed": 1234, 6 | "epochs": 10000, 7 | "learning_rate": 0.0001, 8 | "betas": [ 9 | 0.8, 10 | 0.99 11 | ], 12 | "eps": 1e-09, 13 | "batch_size": 12, 14 | "fp16_run": false, 15 | "lr_decay": 0.999875, 16 | "segment_size": 20480, 17 | "init_lr_ratio": 1, 18 | "warmup_epochs": 0, 19 | "c_mel": 45, 20 | "c_kl": 1.0 21 | }, 22 | "data": { 23 | "training_files": "dump/train_files.list", 24 | "validation_files": "dump/val_files.list", 25 | "max_wav_value": 32768.0, 26 | "sampling_rate": 32000, 27 | "filter_length": 2048, 28 | "hop_length": 640, 29 | "win_length": 2048, 30 | "n_mel_channels": 128, 31 | "mel_fmin": 0.0, 32 | "mel_fmax": null, 33 | "add_blank": true, 34 | "n_speakers": 5500, 35 | "cleaned_text": true 36 | }, 37 | "model": { 38 | "inter_channels": 192, 39 | "hidden_channels": 192, 40 | "filter_channels": 768, 41 | "n_heads": 2, 42 | "n_layers": 6, 43 | "kernel_size": 3, 44 | "p_dropout": 0.1, 45 | "resblock": "1", 46 | "resblock_kernel_sizes": [ 47 | 3, 48 | 7, 49 | 11 50 | ], 51 | "resblock_dilation_sizes": [ 52 | [ 53 | 1, 54 | 3, 55 | 5 56 | ], 57 | [ 58 | 1, 59 | 3, 60 | 5 61 | ], 62 | [ 63 | 1, 64 | 3, 65 | 5 66 | ] 67 | ], 68 | "upsample_rates": [ 69 | 10, 70 | 8, 71 | 2, 72 | 2, 73 | 2 74 | ], 75 | "upsample_initial_channel": 512, 76 | "upsample_kernel_sizes": [ 77 | 16, 78 | 16, 79 | 8, 80 | 2, 81 | 2 82 | ], 83 | "n_layers_q": 3, 84 | "use_spectral_norm": false, 85 | "gin_channels": 192, 86 | "semantic_frame_rate": "16hz", 87 | "freeze_quantizer": false, 88 | "use_reference_enc": true 89 | }, 90 | "s2_ckpt_dir": "logs/vits", 91 | "content_module": "whisper" 92 | } -------------------------------------------------------------------------------- /extract_spk_embedding.py: 
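# A small sketch (not part of this repository) of how the `.spk.npy` files written
# by the script below could be read back for training or inference. The helper name
# `load_spk_embedding` and its caller are assumptions; inference.py wraps the raw
# pyannote embedding the same way (FloatTensor plus a leading batch dimension).
import numpy as np
import torch

def load_spk_embedding(wav_path, device="cpu"):
    # extract_spk_embedding.py saves the embedding next to the audio file
    emb_path = wav_path.replace(".wav", ".spk.npy").replace(".mp3", ".spk.npy")
    emb = np.load(emb_path)                                  # 1-D numpy vector
    return torch.FloatTensor(emb).unsqueeze(0).to(device)    # [1, emb_dim]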
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from pyannote.audio import Model 3 | from pyannote.audio import Inference 4 | 5 | import math 6 | import multiprocessing 7 | from random import shuffle 8 | import torch.multiprocessing as mp 9 | 10 | import torch 11 | from glob import glob 12 | 13 | from tqdm import tqdm 14 | 15 | import logging 16 | 17 | from data_conf import data_root 18 | 19 | logging.getLogger("numba").setLevel(logging.WARNING) 20 | 21 | 22 | def process_one(file_path, inference, device): 23 | spk_emb_path = file_path.replace(".wav", ".spk.npy").replace(".mp3", ".spk.npy") 24 | try: 25 | np.load(spk_emb_path) 26 | except: 27 | 28 | embedding = inference(file_path) 29 | np.save(spk_emb_path, embedding) 30 | np.save(spk_emb_path, embedding) 31 | 32 | 33 | def process_batch(filenames): 34 | print("Loading models ...") 35 | process_idx = mp.current_process()._identity 36 | rank = process_idx[0] if len(process_idx) > 0 else 0 37 | gpu_id = rank % torch.cuda.device_count() 38 | device = torch.device(f"cuda:{gpu_id}") 39 | print(device) 40 | 41 | model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") 42 | model = model.to(device) 43 | inference = Inference(model, window="whole") 44 | 45 | print("Loaded .") 46 | with torch.no_grad(): 47 | for filename in tqdm(filenames): 48 | process_one(filename, inference, device) 49 | 50 | 51 | in_dir = data_root 52 | 53 | if __name__ == "__main__": 54 | filenames = glob(f"{in_dir}/**/*.wav", recursive=True) # [:10] 55 | filenames += glob(f"{in_dir}/**/*.mp3", recursive=True) # [:10] 56 | shuffle(filenames) 57 | multiprocessing.set_start_method("spawn", force=True) 58 | 59 | num_processes = 1 60 | chunk_size = int(math.ceil(len(filenames) / num_processes)) 61 | chunks = [ 62 | filenames[i: i + chunk_size] for i in range(0, len(filenames), chunk_size) 63 | ] 64 | print([len(c) for c in chunks]) 65 | processes = [ 66 | multiprocessing.Process(target=process_batch, args=(chunk,)) for chunk in chunks 67 | ] 68 | for p in processes: 69 | p.start() 70 | 71 | for p in processes: 72 | p.join() 73 | -------------------------------------------------------------------------------- /extract_ssl.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing 3 | import argparse 4 | from random import shuffle 5 | import torch.multiprocessing as mp 6 | 7 | import torch 8 | from glob import glob 9 | from tqdm import tqdm 10 | 11 | import utils 12 | from data_conf import data_root 13 | from module.cnhubert import get_model, get_content 14 | import logging 15 | 16 | logging.getLogger("numba").setLevel(logging.WARNING) 17 | import librosa 18 | 19 | 20 | def process_one(file_path, model): 21 | 22 | # file_path16k = file_path.replace('genshin_data', 'genshin_data16k') 23 | 24 | ssl_path = file_path.replace(".wav", ".ssl.pt").replace(".mp3", ".ssl.pt") 25 | ssl_content = get_content(model, file_path) 26 | assert not torch.isnan(ssl_content).any(), f"NaN in {file_path}" 27 | torch.save(ssl_content.half().cpu(), ssl_path) 28 | 29 | def process_batch(filenames): 30 | print("Loading content model...") 31 | rank = mp.current_process()._identity 32 | rank = rank[0] if len(rank) > 0 else 0 33 | gpu_id = rank % torch.cuda.device_count() 34 | device = torch.device(f"cuda:{gpu_id}") 35 | ssl_model = get_model() 36 | ssl_model = ssl_model.to(device) 37 | ssl_model.eval() 38 | print("Loaded content model.") 39 | for filename in 
tqdm(filenames):
40 |         try:
41 |             process_one(filename, ssl_model)
42 |         except Exception as e:
43 |             print(f"Error processing {filename}: {e}")
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     parser = argparse.ArgumentParser()
48 | 
49 |     parser.add_argument(
50 |         "--config", type=str, default="configs/vits.json", help="path to config"
51 |     )
52 |     args = parser.parse_args()
53 |     filenames = glob(f"{data_root}/**/*.wav", recursive=True) # [:10]
54 |     filenames += glob(f"{data_root}/**/*.mp3", recursive=True) # [:10]
55 |     hps = utils.get_hparams_from_file(args.config)
56 |     shuffle(filenames)
57 |     multiprocessing.set_start_method("spawn", force=True)
58 | 
59 |     num_processes = 1
60 |     chunk_size = int(math.ceil(len(filenames) / num_processes))
61 |     chunks = [
62 |         filenames[i : i + chunk_size] for i in range(0, len(filenames), chunk_size)
63 |     ]
64 |     print([len(c) for c in chunks])
65 |     processes = [
66 |         multiprocessing.Process(target=process_batch, args=(chunk, )) for chunk in chunks
67 |     ]
68 |     for p in processes:
69 |         p.start()
70 | 
71 |     for p in processes:
72 |         p.join()
--------------------------------------------------------------------------------
/gen_phonemes.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os.path
3 | from glob import glob
4 | 
5 | import torch
6 | from tqdm import tqdm
7 | 
8 | from data_conf import data_root
9 | from text.bert import get_bert_feature
10 | from text.cleaner import clean_text
11 | import numpy as np
12 | from multiprocessing import Pool
13 | 
14 | out_dir = "dump"
15 | os.makedirs(out_dir, exist_ok=True)
16 | phoneme_path = os.path.join(out_dir, "phoneme.npy")
17 | phone_dict = {}
18 | 
19 | 
20 | def process_file(data):
21 |     wav_path, language = data
22 |     lab_path = wav_path.replace(".wav", ".lab").replace(".mp3", ".lab")
23 |     if os.path.exists(lab_path):
24 |         print(lab_path)
25 |         text = open(lab_path, encoding='utf-8').readline().strip()
26 | 
27 |         try:
28 |             phones, word2ph, norm_text = clean_text(text, language)
29 | 
30 |             rank = multiprocessing.current_process()._identity
31 |             rank = rank[0] if len(rank) > 0 else 0
32 |             gpu_id = rank % torch.cuda.device_count()
33 |             device = torch.device(f"cuda:{gpu_id}")
34 |             bert_feature = get_bert_feature(norm_text, word2ph, device, language)
35 |             torch.save(bert_feature.cpu(), wav_path.replace(".wav", ".bert.pt").replace(".mp3", ".bert.pt"))
36 | 
37 |             phones = " ".join(phones)
38 |             return (wav_path, phones)
39 |         except Exception as e:
40 |             print(f"Error in {wav_path}, {text}", e)
41 |             return None
42 |     else:
43 |         return None
44 | 
45 | 
46 | if __name__ == '__main__':
47 | 
48 |     for language in ['zh']:
49 |         filenames = glob(f"{data_root}/{language}/**/*.wav", recursive=True)
50 |         filenames += glob(f"{data_root}/{language}/**/*.mp3", recursive=True)
51 | 
52 |         # Define the number of processes to use
53 |         num_processes = 1 # You can adjust this as needed
54 |         # multiprocessing.set_start_method("spawn", force=True)
55 |         print(len(filenames))
56 |         with Pool(num_processes) as pool:
57 |             results = list(tqdm(pool.imap(process_file, [(f, language) for f in filenames]), total=len(filenames)))
58 | 
59 |         for result in results:
60 |             if result is not None:
61 |                 phone_dict[result[0]] = result[1]
62 |     # print the first 10 entries
63 |     for k, v in list(phone_dict.items())[:10]:
64 |         print(k, v)
65 |     np.save(phoneme_path, phone_dict)
66 | 
--------------------------------------------------------------------------------
/module/cnhubert.py:
--------------------------------------------------------------------------------
1 | import time 2
| 3 | import librosa 4 | import torch 5 | import torch.nn.functional as F 6 | import soundfile as sf 7 | import logging 8 | 9 | logging.getLogger("numba").setLevel(logging.WARNING) 10 | 11 | from transformers import ( 12 | Wav2Vec2FeatureExtractor, 13 | HubertModel, 14 | ) 15 | from torchaudio.transforms import Resample 16 | 17 | import utils 18 | import torch.nn as nn 19 | 20 | class CNHubert(nn.Module): 21 | def __init__(self, input_sample_rate=32000): 22 | super().__init__() 23 | self.model = HubertModel.from_pretrained("TencentGameMate/chinese-hubert-base") 24 | self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("TencentGameMate/chinese-hubert-base") 25 | if input_sample_rate != 16000: 26 | self.resample = Resample(orig_freq=input_sample_rate, new_freq=16000) 27 | else: 28 | self.resample = None 29 | self.input_sample_rate = input_sample_rate 30 | self.log_flag = False 31 | def forward(self, x): 32 | if self.resample is not None: 33 | if not self.log_flag: 34 | print(f"Resampling from {x.shape[-1]} to 16000") 35 | self.log_flag = True 36 | x = self.resample(x) 37 | input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device).to(x.dtype) 38 | feats = self.model(input_values)["last_hidden_state"] 39 | return feats 40 | 41 | 42 | 43 | def get_model(input_sample_rate=32000): 44 | model = CNHubert(input_sample_rate=input_sample_rate) 45 | model.eval() 46 | return model 47 | 48 | def get_content(hmodel, src_path): 49 | input_sample_rate = hmodel.input_sample_rate 50 | wav, _ = librosa.load(src_path, sr=input_sample_rate) 51 | device = hmodel.parameters().__next__().device 52 | dtype = hmodel.parameters().__next__().dtype 53 | wav_16k_tensor = torch.from_numpy(wav).to(device).to(dtype) 54 | with torch.no_grad(): 55 | feats = hmodel(wav_16k_tensor) 56 | return feats.transpose(1,2) 57 | 58 | 59 | if __name__ == '__main__': 60 | model = get_model() 61 | src_path = "/Users/Shared/原音频2.wav" 62 | wav_16k_tensor = utils.load_wav_to_torch_and_resample(src_path, 16000) 63 | model = model 64 | wav_16k_tensor = wav_16k_tensor 65 | feats = get_content(model,wav_16k_tensor) 66 | print(feats.shape) 67 | 68 | -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | punctuation = ['!', '?', '…', ",", ".", '-'] 4 | pu_symbols = punctuation + ["SP", 'SP2', 'SP3', "UNK", 'EOS'] 5 | pad = '_' 6 | 7 | c = ['AA', 'EE', 'OO', 'b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'w', 'x', 'y', 'z', 'zh'] 8 | v = ['E1', 'En1', 'a1', 'ai1', 'an1', 'ang1', 'ao1', 'e1', 'ei1', 'en1', 'eng1', 'er1', 'i1', 'i01', 'ia1', 'ian1', 'iang1', 'iao1', 'ie1', 'in1', 'ing1', 'iong1', 'ir1', 'iu1', 'o1', 'ong1', 'ou1', 'u1', 'ua1', 'uai1', 'uan1', 'uang1', 'ui1', 'un1', 'uo1', 'v1', 'van1', 've1', 'vn1', 'E2', 'En2', 'a2', 'ai2', 'an2', 'ang2', 'ao2', 'e2', 'ei2', 'en2', 'eng2', 'er2', 'i2', 'i02', 'ia2', 'ian2', 'iang2', 'iao2', 'ie2', 'in2', 'ing2', 'iong2', 'ir2', 'iu2', 'o2', 'ong2', 'ou2', 'u2', 'ua2', 'uai2', 'uan2', 'uang2', 'ui2', 'un2', 'uo2', 'v2', 'van2', 've2', 'vn2', 'E3', 'En3', 'a3', 'ai3', 'an3', 'ang3', 'ao3', 'e3', 'ei3', 'en3', 'eng3', 'er3', 'i3', 'i03', 'ia3', 'ian3', 'iang3', 'iao3', 'ie3', 'in3', 'ing3', 'iong3', 'ir3', 'iu3', 'o3', 'ong3', 'ou3', 'u3', 'ua3', 'uai3', 'uan3', 'uang3', 'ui3', 'un3', 'uo3', 'v3', 'van3', 've3', 'vn3', 'E4', 'En4', 'a4', 'ai4', 'an4', 
'ang4', 'ao4', 'e4', 'ei4', 'en4', 'eng4', 'er4', 'i4', 'i04', 'ia4', 'ian4', 'iang4', 'iao4', 'ie4', 'in4', 'ing4', 'iong4', 'ir4', 'iu4', 'o4', 'ong4', 'ou4', 'u4', 'ua4', 'uai4', 'uan4', 'uang4', 'ui4', 'un4', 'uo4', 'v4', 'van4', 've4', 'vn4', 'E5', 'En5', 'a5', 'ai5', 'an5', 'ang5', 'ao5', 'e5', 'ei5', 'en5', 'eng5', 'er5', 'i5', 'i05', 'ia5', 'ian5', 'iang5', 'iao5', 'ie5', 'in5', 'ing5', 'iong5', 'ir5', 'iu5', 'o5', 'ong5', 'ou5', 'u5', 'ua5', 'uai5', 'uan5', 'uang5', 'ui5', 'un5', 'uo5', 'v5', 'van5', 've5', 'vn5'] 9 | 10 | v_without_tone = ['E', 'En', 'a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'i0', 'ia', 'ian', 'iang', 'iao', 'ie', 'in', 'ing', 'iong', 'ir', 'iu', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'ui', 'un', 'uo', 'v', 'van', 've', 'vn'] 11 | 12 | # japanese 13 | ja_symbols = ['I', 'N', 'U', 'a', 'b', 'by', 'ch', 'cl', 'd', 'dy', 'e', 'f', 'g', 'gy', 'h', 'hy', 'i', 'j', 'k', 'ky', 14 | 'm', 'my', 'n', 'ny', 'o', 'p', 'py', 'r', 'ry', 's', 'sh', 't', 'ts', 'u', 'v', 'w', 'y', 'z'] 15 | 16 | arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} 17 | 18 | symbols = c + v + ja_symbols + pu_symbols + list(arpa) 19 | symbols = [pad] + sorted(set(symbols)) 20 | if __name__ == '__main__': 21 | print(len(symbols)) 22 | print(symbols) -------------------------------------------------------------------------------- /transformer/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/lr_schedulers.py 2 | import math 3 | 4 | import torch 5 | from matplotlib import pyplot as plt 6 | from torch import nn 7 | from torch.optim import Adam 8 | 9 | 10 | class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): 11 | """ 12 | Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. 13 | """ 14 | 15 | def __init__(self, 16 | optimizer, 17 | init_lr, 18 | peak_lr, 19 | end_lr, 20 | warmup_steps=10000, 21 | total_steps=400000, 22 | current_step=0): 23 | self.init_lr = init_lr 24 | self.peak_lr = peak_lr 25 | self.end_lr = end_lr 26 | self.optimizer = optimizer 27 | self._warmup_rate = (peak_lr - init_lr) / warmup_steps 28 | self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps) 29 | self._current_step = current_step 30 | self.lr = init_lr 31 | self.warmup_steps = warmup_steps 32 | self.total_steps = total_steps 33 | self._last_lr = [self.lr] 34 | 35 | def set_lr(self, lr): 36 | self._last_lr = [g['lr'] for g in self.optimizer.param_groups] 37 | for g in self.optimizer.param_groups: 38 | g['lr'] = lr 39 | 40 | def step(self): 41 | if self._current_step < self.warmup_steps: 42 | lr = self.init_lr + self._warmup_rate * self._current_step 43 | 44 | elif self._current_step > self.total_steps: 45 | lr = self.end_lr 46 | 47 | else: 48 | decay_ratio = (self._current_step - self.warmup_steps) / ( 49 | self.total_steps - self.warmup_steps) 50 | if decay_ratio < 0.0 or decay_ratio > 1.0: 51 | raise RuntimeError( 52 | "Decay ratio must be in [0.0, 1.0]. 
Fix LR scheduler settings." 53 | ) 54 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 55 | lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) 56 | 57 | self.set_lr(lr) 58 | self.lr = lr 59 | self._current_step += 1 60 | return self.lr 61 | 62 | 63 | if __name__ == '__main__': 64 | m = nn.Linear(10, 10) 65 | opt = Adam(m.parameters(), lr=1e-4) 66 | s = WarmupCosineLRSchedule( 67 | opt, 68 | 1e-6, 69 | 2e-4, 70 | 1e-6, 71 | warmup_steps=2000, 72 | total_steps=20000, 73 | current_step=0) 74 | lrs = [] 75 | for i in range(25000): 76 | s.step() 77 | lrs.append(s.lr) 78 | print(s.lr) 79 | 80 | plt.plot(lrs) 81 | plt.plot(range(0, 25000), lrs) 82 | plt.show() 83 | -------------------------------------------------------------------------------- /transformer/embedding.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/embedding.py 2 | import math 3 | 4 | import torch 5 | from torch import nn 6 | 7 | 8 | class TokenEmbedding(nn.Module): 9 | def __init__( 10 | self, 11 | embedding_dim: int, 12 | vocab_size: int, 13 | dropout: float=0.0, ): 14 | super().__init__() 15 | 16 | self.vocab_size = vocab_size 17 | self.embedding_dim = embedding_dim 18 | 19 | self.dropout = torch.nn.Dropout(p=dropout) 20 | self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_dim) 21 | 22 | @property 23 | def weight(self) -> torch.Tensor: 24 | return self.word_embeddings.weight 25 | 26 | def embedding(self, index: int) -> torch.Tensor: 27 | return self.word_embeddings.weight[index:index + 1] 28 | 29 | def forward(self, x: torch.Tensor): 30 | x = self.word_embeddings(x) 31 | x = self.dropout(x) 32 | return x 33 | 34 | 35 | class SinePositionalEmbedding(nn.Module): 36 | def __init__( 37 | self, 38 | embedding_dim: int, 39 | dropout: float=0.0, 40 | scale: bool=False, 41 | alpha: bool=False, ): 42 | super().__init__() 43 | self.embedding_dim = embedding_dim 44 | self.x_scale = math.sqrt(embedding_dim) if scale else 1.0 45 | self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) 46 | self.dropout = torch.nn.Dropout(p=dropout) 47 | 48 | self.reverse = False 49 | self.pe = None 50 | self.extend_pe(torch.tensor(0.0).expand(1, 4000)) 51 | 52 | def extend_pe(self, x): 53 | """Reset the positional encodings.""" 54 | if self.pe is not None: 55 | if self.pe.size(1) >= x.size(1): 56 | if self.pe.dtype != x.dtype or self.pe.device != x.device: 57 | self.pe = self.pe.to(dtype=x.dtype, device=x.device) 58 | return 59 | pe = torch.zeros(x.size(1), self.embedding_dim) 60 | if self.reverse: 61 | position = torch.arange( 62 | x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1) 63 | else: 64 | position = torch.arange( 65 | 0, x.size(1), dtype=torch.float32).unsqueeze(1) 66 | div_term = torch.exp( 67 | torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * 68 | -(math.log(10000.0) / self.embedding_dim)) 69 | pe[:, 0::2] = torch.sin(position * div_term) 70 | pe[:, 1::2] = torch.cos(position * div_term) 71 | pe = pe.unsqueeze(0) 72 | self.pe = pe.to(device=x.device, dtype=x.dtype).detach() 73 | 74 | def forward(self, x: torch.Tensor) -> torch.Tensor: 75 | self.extend_pe(x) 76 | output = x.unsqueeze(-1) if x.ndim == 2 else x 77 | output = output * self.x_scale + self.alpha * self.pe[:, :x.size(1)] 78 | return self.dropout(output) 79 | -------------------------------------------------------------------------------- /text/japanese.py: 
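# Usage sketch for SinePositionalEmbedding from transformer/embedding.py above
# (illustrative only; the batch size, sequence length and embedding_dim=192 are
# arbitrary example values, not settings taken from a config in this repo).
import torch
from transformer.embedding import SinePositionalEmbedding

pos_emb = SinePositionalEmbedding(embedding_dim=192, dropout=0.0, scale=True, alpha=True)
x = torch.randn(2, 50, 192)        # [batch, time, dim]
y = pos_emb(x)                     # scales x and adds sinusoidal positions
assert y.shape == (2, 50, 192)     # shape is preserved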
-------------------------------------------------------------------------------- 1 | # modified from https://github.com/CjangCjengh/vits/blob/main/text/japanese.py 2 | import re 3 | import sys 4 | 5 | # import pyopenjtalk 6 | 7 | from text import symbols 8 | 9 | # Regular expression matching Japanese without punctuation marks: 10 | _japanese_characters = re.compile( 11 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 12 | 13 | # Regular expression matching non-Japanese characters or punctuation marks: 14 | _japanese_marks = re.compile( 15 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 16 | 17 | # List of (symbol, Japanese) pairs for marks: 18 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 19 | ('%', 'パーセント') 20 | ]] 21 | 22 | 23 | # List of (consonant, sokuon) pairs: 24 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 25 | (r'Q([↑↓]*[kg])', r'k#\1'), 26 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 27 | (r'Q([↑↓]*[sʃ])', r's\1'), 28 | (r'Q([↑↓]*[pb])', r'p#\1') 29 | ]] 30 | 31 | # List of (consonant, hatsuon) pairs: 32 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 33 | (r'N([↑↓]*[pbm])', r'm\1'), 34 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 35 | (r'N([↑↓]*[tdn])', r'n\1'), 36 | (r'N([↑↓]*[kg])', r'ŋ\1') 37 | ]] 38 | 39 | 40 | 41 | def post_replace_ph(ph): 42 | rep_map = { 43 | ':': ',', 44 | ';': ',', 45 | ',': ',', 46 | '。': '.', 47 | '!': '!', 48 | '?': '?', 49 | '\n': '.', 50 | "·": ",", 51 | '、': ",", 52 | '...': '…' 53 | } 54 | if ph in rep_map.keys(): 55 | ph = rep_map[ph] 56 | if ph in symbols: 57 | return ph 58 | if ph not in symbols: 59 | ph = 'UNK' 60 | return ph 61 | 62 | def symbols_to_japanese(text): 63 | for regex, replacement in _symbols_to_japanese: 64 | text = re.sub(regex, replacement, text) 65 | return text 66 | 67 | 68 | def preprocess_jap(text): 69 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 70 | text = symbols_to_japanese(text) 71 | sentences = re.split(_japanese_marks, text) 72 | marks = re.findall(_japanese_marks, text) 73 | text = [] 74 | for i, sentence in enumerate(sentences): 75 | if re.match(_japanese_characters, sentence): 76 | p = pyopenjtalk.g2p(sentence) 77 | text += p.split(" ") 78 | 79 | if i < len(marks): 80 | text += [marks[i].replace(' ', '')] 81 | return text 82 | 83 | def text_normalize(text): 84 | # todo: jap text normalize 85 | return text 86 | 87 | def g2p(norm_text): 88 | phones = preprocess_jap(norm_text) 89 | phones = [post_replace_ph(i) for i in phones] 90 | word2ph = [1 for i in phones] 91 | # todo: implement tones and word2ph 92 | return phones, word2ph 93 | 94 | 95 | if __name__ == '__main__': 96 | for line in open("../../../Downloads/transcript_utf8.txt").readlines(): 97 | text = line.split(":")[1] 98 | phones = g2p(text) 99 | print(phones) 100 | -------------------------------------------------------------------------------- /text/english.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import re 4 | from g2p_en import G2p 5 | 6 | from string import punctuation 7 | 8 | from text import symbols 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | CMU_DICT_PATH = os.path.join(current_file_path, 'cmudict.rep') 12 | CACHE_PATH = os.path.join(current_file_path, 'cmudict_cache.pickle') 13 | _g2p = G2p() 14 | 15 | arpa = {'AH0', 'S', 'AH1', 'EY2', 'AE2', 'EH0', 'OW2', 'UH0', 'NG', 'B', 'G', 
'AY0', 'M', 'AA0', 'F', 'AO0', 'ER2', 'UH1', 'IY1', 'AH2', 'DH', 'IY0', 'EY1', 'IH0', 'K', 'N', 'W', 'IY2', 'T', 'AA1', 'ER1', 'EH2', 'OY0', 'UH2', 'UW1', 'Z', 'AW2', 'AW1', 'V', 'UW2', 'AA2', 'ER', 'AW0', 'UW0', 'R', 'OW1', 'EH1', 'ZH', 'AE0', 'IH2', 'IH', 'Y', 'JH', 'P', 'AY1', 'EY0', 'OY2', 'TH', 'HH', 'D', 'ER0', 'CH', 'AO1', 'AE1', 'AO2', 'OY1', 'AY2', 'IH1', 'OW0', 'L', 'SH'} 16 | 17 | 18 | def replace_phs(phs): 19 | rep_map = { 20 | ';': ',', 21 | ':': ',', 22 | '\'': '-', 23 | '"': '-' 24 | } 25 | phs_new = [] 26 | for ph in phs: 27 | if ph in symbols: 28 | phs_new.append(ph) 29 | elif ph in rep_map.keys(): 30 | phs_new.append(rep_map[ph]) 31 | else: 32 | print('ph not in symbols: ', ph) 33 | return phs_new 34 | 35 | def read_dict(): 36 | g2p_dict = {} 37 | start_line = 49 38 | with open(CMU_DICT_PATH) as f: 39 | line = f.readline() 40 | line_index = 1 41 | while line: 42 | if line_index >= start_line: 43 | line = line.strip() 44 | word_split = line.split(' ') 45 | word = word_split[0] 46 | 47 | syllable_split = word_split[1].split(' - ') 48 | g2p_dict[word] = [] 49 | for syllable in syllable_split: 50 | phone_split = syllable.split(' ') 51 | g2p_dict[word].append(phone_split) 52 | 53 | line_index = line_index + 1 54 | line = f.readline() 55 | 56 | return g2p_dict 57 | 58 | 59 | def cache_dict(g2p_dict, file_path): 60 | with open(file_path, 'wb') as pickle_file: 61 | pickle.dump(g2p_dict, pickle_file) 62 | 63 | 64 | def get_dict(): 65 | if os.path.exists(CACHE_PATH): 66 | with open(CACHE_PATH, 'rb') as pickle_file: 67 | g2p_dict = pickle.load(pickle_file) 68 | else: 69 | g2p_dict = read_dict() 70 | cache_dict(g2p_dict, CACHE_PATH) 71 | 72 | return g2p_dict 73 | 74 | eng_dict = get_dict() 75 | 76 | 77 | def text_normalize(text): 78 | # todo: eng text normalize 79 | return text.replace(";", ",") 80 | 81 | def g2p(text): 82 | 83 | phones = [] 84 | words = re.split(r"([,;.\-\?\!\s+])", text) 85 | for w in words: 86 | if w.upper() in eng_dict: 87 | phns = eng_dict[w.upper()] 88 | for ph in phns: 89 | phones += ph 90 | else: 91 | phone_list = list(filter(lambda p: p != " ", _g2p(w))) 92 | for ph in phone_list: 93 | if ph in arpa: 94 | phones.append(ph) 95 | else: 96 | phones.append(ph) 97 | 98 | phones = replace_phs(phones) 99 | word2ph = [1 for i in phones] 100 | return phones, word2ph 101 | 102 | if __name__ == "__main__": 103 | # print(get_dict()) 104 | print(g2p("hello")) 105 | print(g2p("In this; paper, we propose 1 DSPGAN, a GAN-based universal vocoder.")) 106 | # all_phones = set() 107 | # for k, syllables in eng_dict.items(): 108 | # for group in syllables: 109 | # for ph in group: 110 | # all_phones.add(ph) 111 | # print(all_phones) -------------------------------------------------------------------------------- /asr_train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import Trainer 2 | from pytorch_lightning.callbacks import ModelCheckpoint 3 | from pytorch_lightning.strategies import DDPStrategy 4 | from pytorch_lightning.loggers import WandbLogger 5 | from torch import nn 6 | 7 | from asr.data_module import ASRDataModule 8 | from asr.meldataset import build_dataloader 9 | from utils import * 10 | from asr.models import build_model 11 | from asr.trainer import ASRTrainer 12 | 13 | import os 14 | import os.path as osp 15 | import yaml 16 | import shutil 17 | import click 18 | 19 | def get_data_path_list(train_path=None, val_path=None): 20 | if train_path is None: 21 | train_path = "Data/train_list.txt" 22 | if 
val_path is None: 23 | val_path = "Data/val_list.txt" 24 | 25 | with open(train_path, 'r') as f: 26 | train_list = f.readlines() 27 | with open(val_path, 'r') as f: 28 | val_list = f.readlines() 29 | 30 | return train_list, val_list 31 | def build_criterion(critic_params={}): 32 | criterion = { 33 | "ce": nn.CrossEntropyLoss(ignore_index=-1), 34 | "ctc": torch.nn.CTCLoss(**critic_params.get('ctc', {})), 35 | } 36 | return criterion 37 | 38 | @click.command() 39 | @click.option('-c', '--config_path', default='configs/asr.yml', type=str) 40 | def main(config_path): 41 | config = yaml.safe_load(open(config_path)) 42 | log_dir = config['log_dir'] 43 | os.makedirs(log_dir,exist_ok=True) 44 | shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path))) 45 | 46 | 47 | batch_size = config.get('batch_size', 10) 48 | epochs = config.get('epochs', 1000) 49 | save_freq = config.get('save_freq', 20) 50 | train_path = config.get('train_data', None) 51 | val_path = config.get('val_data', None) 52 | 53 | # train_list, val_list = get_data_path_list(train_path, val_path) 54 | # train_dataloader = build_dataloader(train_list, 55 | # batch_size=batch_size, 56 | # num_workers=8, 57 | # dataset_config=config.get('dataset_params', {})) 58 | # 59 | # val_dataloader = build_dataloader(val_list, 60 | # batch_size=batch_size, 61 | # validation=True, 62 | # num_workers=2, 63 | # dataset_config=config.get('dataset_params', {})) 64 | data_module = ASRDataModule(data_dir='dump', batch_size=batch_size, num_workers=8) 65 | 66 | model = build_model(model_params=config['model_params'] or {}) 67 | 68 | checkpoint_callback = ModelCheckpoint( 69 | dirpath=log_dir, 70 | filename=('{epoch}-{step}'), 71 | every_n_train_steps=None, 72 | every_n_epochs=1, 73 | verbose=True, 74 | save_last=True 75 | ) 76 | logger = WandbLogger(project="asr-align") 77 | 78 | blank_index = 0 79 | 80 | criterion = build_criterion(critic_params={ 81 | 'ctc': {'blank': blank_index}, 82 | }) 83 | training_wrapper = ASRTrainer(model=model, criterion=criterion,mono_start_epoch=10,lr=1e-4) 84 | if config.get('pretrained_model',None): 85 | training_wrapper.load_checkpoint(config['pretrained_model']) 86 | 87 | trainer: Trainer = Trainer( 88 | max_epochs=epochs, 89 | accelerator='gpu', 90 | devices=-1, 91 | benchmark=False, 92 | fast_dev_run=False, 93 | strategy=DDPStrategy(), 94 | logger=logger, 95 | callbacks=[checkpoint_callback]) 96 | 97 | trainer.fit(training_wrapper, data_module) 98 | 99 | 100 | if __name__ == "__main__": 101 | main() -------------------------------------------------------------------------------- /module/wavenet.py: -------------------------------------------------------------------------------- 1 | # https://github.com/CNChTu/Diffusion-SVC/blob/v1_Stable/diffusion/wavenet.py 2 | import math 3 | from math import sqrt 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.nn import Mish 9 | 10 | 11 | class Conv1d(torch.nn.Conv1d): 12 | def __init__(self, *args, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | nn.init.kaiming_normal_(self.weight) 15 | 16 | 17 | class SinusoidalPosEmb(nn.Module): 18 | def __init__(self, dim): 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | device = x.device 24 | half_dim = self.dim // 2 25 | emb = math.log(10000) / (half_dim - 1) 26 | emb = torch.exp(torch.arange(half_dim, device=device) * -emb) 27 | emb = x[:, None] * emb[None, :] 28 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 29 | return emb 30 | 31 | 32 | 
class ResidualBlock(nn.Module): 33 | def __init__(self, encoder_hidden, residual_channels, dilation): 34 | super().__init__() 35 | self.residual_channels = residual_channels 36 | self.dilated_conv = nn.Conv1d( 37 | residual_channels, 38 | 2 * residual_channels, 39 | kernel_size=3, 40 | padding=dilation, 41 | dilation=dilation 42 | ) 43 | self.diffusion_projection = nn.Linear(residual_channels, residual_channels) 44 | self.conditioner_projection = nn.Conv1d(encoder_hidden, 2 * residual_channels, 1) 45 | self.output_projection = nn.Conv1d(residual_channels, 2 * residual_channels, 1) 46 | 47 | def forward(self, x, conditioner, diffusion_step): 48 | diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) 49 | conditioner = self.conditioner_projection(conditioner) 50 | y = x + diffusion_step 51 | 52 | y = self.dilated_conv(y) + conditioner 53 | 54 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 55 | gate, filter = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 56 | y = torch.sigmoid(gate) * torch.tanh(filter) 57 | 58 | y = self.output_projection(y) 59 | 60 | # Using torch.split instead of torch.chunk to avoid using onnx::Slice 61 | residual, skip = torch.split(y, [self.residual_channels, self.residual_channels], dim=1) 62 | return (x + residual) / math.sqrt(2.0), skip 63 | 64 | 65 | class WaveNet(nn.Module): 66 | def __init__(self, in_dims=128, n_layers=20, n_chans=384, n_hidden=256): 67 | super().__init__() 68 | self.input_projection = Conv1d(in_dims, n_chans, 1) 69 | self.diffusion_embedding = SinusoidalPosEmb(n_chans) 70 | self.mlp = nn.Sequential( 71 | nn.Linear(n_chans, n_chans * 4), 72 | Mish(), 73 | nn.Linear(n_chans * 4, n_chans) 74 | ) 75 | self.residual_layers = nn.ModuleList([ 76 | ResidualBlock( 77 | encoder_hidden=n_hidden, 78 | residual_channels=n_chans, 79 | dilation=1 80 | ) 81 | for i in range(n_layers) 82 | ]) 83 | self.skip_projection = Conv1d(n_chans, n_chans, 1) 84 | self.output_projection = Conv1d(n_chans, in_dims, 1) 85 | nn.init.zeros_(self.output_projection.weight) 86 | 87 | def forward(self, x, diffusion_step, cond): 88 | """ 89 | :param x: [B, M, T] 90 | :param diffusion_step: [B, 1] 91 | :param cond: [B, M, T] 92 | :return: 93 | """ 94 | x = self.input_projection(x) # [B, residual_channel, T] 95 | 96 | x = F.relu(x) 97 | diffusion_step = self.diffusion_embedding(diffusion_step) 98 | diffusion_step = self.mlp(diffusion_step) 99 | skip = [] 100 | for layer in self.residual_layers: 101 | x, skip_connection = layer(x, cond, diffusion_step) 102 | skip.append(skip_connection) 103 | 104 | x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) 105 | x = self.skip_projection(x) 106 | x = F.relu(x) 107 | x = self.output_projection(x) # [B, mel_bins, T] 108 | return x -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from text import cleaned_text_to_sequence 2 | from text.cleaner import text_to_sequence, clean_text 3 | from gen_phonemes import get_bert_feature 4 | import torch.nn.functional as F 5 | from module import commons 6 | import torch 7 | import utils 8 | from module.models import SynthesizerTrn 9 | from module.mel_processing import spectrogram_torch, spec_to_mel_torch 10 | import soundfile 11 | import torchaudio 12 | from pyannote.audio import Model 13 | from pyannote.audio import Inference 14 | import numpy 15 | 16 | def text2phoneid(text, lang='zh'): 17 | 
phones, word2ph, norm_text = clean_text(text, lang) 18 | print(phones) 19 | 20 | bert = get_bert_feature(norm_text, word2ph, 'cpu', lang) 21 | phonemes = cleaned_text_to_sequence(phones) 22 | phonemes = commons.intersperse(phonemes, 0) 23 | bert = F.interpolate(bert.unsqueeze(0), scale_factor=2, mode='nearest') 24 | bert = F.pad(bert, (0, 1), value=0).squeeze(0) 25 | return phones, phonemes, bert 26 | 27 | 28 | 29 | def load_model(device="cuda", config_path="configs/s2.json", model_path=None): 30 | device = torch.device(device) 31 | print('loading models...') 32 | hps = utils.get_hparams_from_file(config_path) 33 | net_g = SynthesizerTrn( 34 | hps.data.filter_length // 2 + 1, 35 | hps.train.segment_size // hps.data.hop_length, 36 | n_speakers=hps.data.n_speakers, 37 | **hps.model).to(device) 38 | if model_path is None: 39 | model_path = utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth") 40 | utils.load_checkpoint(model_path, net_g, 41 | None, False) 42 | net_g.eval() 43 | spk_emb_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM") 44 | spk_emb_model = spk_emb_model.to(device) 45 | inference = Inference(spk_emb_model, window="whole") 46 | 47 | return hps, net_g, inference 48 | 49 | 50 | 51 | def get_spepc(hps, filename): 52 | audio, sampling_rate = utils.load_wav_to_torch(filename) 53 | audio = audio.unsqueeze(0) 54 | if sampling_rate != hps.data.sampling_rate: 55 | audio = torchaudio.functional.resample(audio, sampling_rate, hps.data.sampling_rate) 56 | audio_norm = audio 57 | spec = spectrogram_torch(audio_norm, hps.data.filter_length, 58 | hps.data.sampling_rate, hps.data.hop_length, hps.data.win_length, 59 | center=False) 60 | return spec 61 | 62 | 63 | 64 | @torch.no_grad() 65 | @torch.inference_mode() 66 | def decode_to_file(codes, ref_path, save_path): 67 | device = codes.device 68 | hps, net_g, ssl = load_model(device=device) 69 | ref = get_spepc(hps, ref_path).to(device) 70 | 71 | audio = net_g.decode_codes(codes, ref).detach().cpu().numpy()[0, 0] 72 | soundfile.write(save_path, audio, hps.data.sampling_rate) 73 | 74 | import os 75 | if __name__ == '__main__': 76 | 77 | device = 'cpu' 78 | outdir = 'out' 79 | os.makedirs(outdir, exist_ok=True) 80 | txt_list = [ 81 | "先帝创业未半而中道崩殂,今天下三分,益州疲弊,此诚危急存亡之秋也。然侍卫之臣不懈于内,忠志之士忘身于外者,盖追先帝之殊遇,欲报之于陛下也。", 82 | '话说天下大势,分久必合,合久必分。周末七国分争,并入于秦。及秦灭之后,楚汉分争,又并入于汉。汉朝自高祖斩白蛇而起义,一统天下,后来光武中兴,传至献帝,遂分为三国。', 83 | ] 84 | 85 | prompt_list = [ "dataset_raw/zh/Azusa/Azusa_113.wav", 86 | 'dataset_raw/zh/Azusa/Azusa_288.wav',] 87 | 88 | 89 | 90 | hps, model,spk_emb_model = load_model(device=device) 91 | 92 | for name, text in enumerate(txt_list): 93 | for i, prompt_wav_path in enumerate(prompt_list): 94 | out_path = f'{outdir}/{name}_{i}.wav' 95 | phlist, phones, bert = text2phoneid(text) 96 | print(len(phones)) 97 | ref = get_spepc(hps, prompt_wav_path).to(device) 98 | spk_emb = spk_emb_model(prompt_wav_path) 99 | spk_emb = torch.FloatTensor(spk_emb).to(device).unsqueeze(0) 100 | 101 | all_phoneme_ids = torch.LongTensor(phones).to(device).unsqueeze(0) 102 | bert = bert.to(device).unsqueeze(0) 103 | x_lengths = torch.LongTensor([all_phoneme_ids.shape[-1]]).to(device) 104 | 105 | with torch.no_grad(): 106 | wavs = model.infer(all_phoneme_ids, x_lengths, ref, bert,spk_emb,noise_scale=.4) 107 | soundfile.write(out_path, wavs[0,0].cpu().numpy(), hps.data.sampling_rate) -------------------------------------------------------------------------------- /module/mel_processing.py: 
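# Usage sketch for the STFT helpers defined in this file, wired with the audio
# hyperparameters from configs/vits.json (sampling_rate=32000, filter_length=2048,
# hop_length=640, win_length=2048, n_mel_channels=128, mel_fmin=0, mel_fmax=null).
# The random waveform is only a stand-in for real audio in [-1, 1].
import torch
from module.mel_processing import spectrogram_torch, spec_to_mel_torch

wav = torch.rand(1, 32000) * 2 - 1                                   # 1 s of fake audio
spec = spectrogram_torch(wav, 2048, 32000, 640, 2048, center=False)  # [1, 1025, frames]
mel = spec_to_mel_torch(spec, 2048, 128, 32000, 0.0, None)           # [1, 128, frames]
print(spec.shape, mel.shape)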
-------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | import torch.utils.data 8 | import numpy as np 9 | import librosa 10 | import librosa.util as librosa_util 11 | from librosa.util import normalize, pad_center, tiny 12 | from scipy.signal import get_window 13 | from scipy.io.wavfile import read 14 | from librosa.filters import mel as librosa_mel_fn 15 | 16 | MAX_WAV_VALUE = 32768.0 17 | 18 | 19 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 20 | """ 21 | PARAMS 22 | ------ 23 | C: compression factor 24 | """ 25 | return torch.log(torch.clamp(x, min=clip_val) * C) 26 | 27 | 28 | def dynamic_range_decompression_torch(x, C=1): 29 | """ 30 | PARAMS 31 | ------ 32 | C: compression factor used to compress 33 | """ 34 | return torch.exp(x) / C 35 | 36 | 37 | def spectral_normalize_torch(magnitudes): 38 | output = dynamic_range_compression_torch(magnitudes) 39 | return output 40 | 41 | 42 | def spectral_de_normalize_torch(magnitudes): 43 | output = dynamic_range_decompression_torch(magnitudes) 44 | return output 45 | 46 | 47 | mel_basis = {} 48 | hann_window = {} 49 | 50 | 51 | def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False): 52 | if torch.min(y) < -1.: 53 | print('min value is ', torch.min(y)) 54 | if torch.max(y) > 1.: 55 | print('max value is ', torch.max(y)) 56 | 57 | global hann_window 58 | dtype_device = str(y.dtype) + '_' + str(y.device) 59 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 60 | if wnsize_dtype_device not in hann_window: 61 | hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 62 | 63 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 64 | y = y.squeeze(1) 65 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 66 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 67 | 68 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 69 | return spec 70 | 71 | 72 | def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax): 73 | global mel_basis 74 | dtype_device = str(spec.dtype) + '_' + str(spec.device) 75 | fmax_dtype_device = str(fmax) + '_' + dtype_device 76 | if fmax_dtype_device not in mel_basis: 77 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 78 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) 79 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 80 | spec = spectral_normalize_torch(spec) 81 | return spec 82 | 83 | 84 | def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 85 | if torch.min(y) < -1.: 86 | print('min value is ', torch.min(y)) 87 | if torch.max(y) > 1.: 88 | print('max value is ', torch.max(y)) 89 | 90 | global mel_basis, hann_window 91 | dtype_device = str(y.dtype) + '_' + str(y.device) 92 | fmax_dtype_device = str(fmax) + '_' + dtype_device 93 | wnsize_dtype_device = str(win_size) + '_' + dtype_device 94 | if fmax_dtype_device not in mel_basis: 95 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 96 | mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) 97 | if wnsize_dtype_device not in hann_window: 98 
| hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) 99 | 100 | y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') 101 | y = y.squeeze(1) 102 | 103 | spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], 104 | center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) 105 | 106 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 107 | 108 | spec = torch.matmul(mel_basis[fmax_dtype_device], spec) 109 | spec = spectral_normalize_torch(spec) 110 | 111 | return spec 112 | -------------------------------------------------------------------------------- /text/chinese.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | import cn2an 5 | from pypinyin import lazy_pinyin, Style 6 | 7 | from text.symbols import punctuation 8 | from text.tone_sandhi import ToneSandhi 9 | 10 | current_file_path = os.path.dirname(__file__) 11 | pinyin_to_symbol_map = {line.split("\t")[0]: line.strip().split("\t")[1] for line in 12 | open(os.path.join(current_file_path, 'opencpop-strict.txt')).readlines()} 13 | 14 | import jieba.posseg as psg 15 | 16 | 17 | rep_map = { 18 | ':': ',', 19 | ';': ',', 20 | ',': ',', 21 | '。': '.', 22 | '!': '!', 23 | '?': '?', 24 | '\n': '.', 25 | "·": ",", 26 | '、': ",", 27 | '...': '…', 28 | '$': '.', 29 | '—': "-" 30 | } 31 | 32 | tone_modifier = ToneSandhi() 33 | 34 | def replace_punctuation(text): 35 | text = text.replace("嗯", "恩").replace("呣","母") 36 | pattern = re.compile('|'.join(re.escape(p) for p in rep_map.keys())) 37 | 38 | replaced_text = pattern.sub(lambda x: rep_map[x.group()], text) 39 | 40 | replaced_text = re.sub(r'[^\u4e00-\u9fa5'+"".join(punctuation)+r']+', '', replaced_text) 41 | 42 | return replaced_text 43 | 44 | def g2p(text): 45 | pattern = r'(?<=[{0}])\s*'.format(''.join(punctuation)) 46 | sentences = [i for i in re.split(pattern, text) if i.strip()!=''] 47 | phones, word2ph = _g2p(sentences) 48 | return phones, word2ph 49 | 50 | 51 | def _get_initials_finals(word): 52 | initials = [] 53 | finals = [] 54 | orig_initials = lazy_pinyin( 55 | word, neutral_tone_with_five=True, style=Style.INITIALS) 56 | orig_finals = lazy_pinyin( 57 | word, neutral_tone_with_five=True, style=Style.FINALS_TONE3) 58 | for c, v in zip(orig_initials, orig_finals): 59 | initials.append(c) 60 | finals.append(v) 61 | return initials, finals 62 | 63 | 64 | def _g2p(segments): 65 | phones_list = [] 66 | word2ph = [] 67 | for seg in segments: 68 | pinyins = [] 69 | # Replace all English words in the sentence 70 | seg = re.sub('[a-zA-Z]+', '', seg) 71 | seg_cut = psg.lcut(seg) 72 | initials = [] 73 | finals = [] 74 | seg_cut = tone_modifier.pre_merge_for_modify(seg_cut) 75 | for word, pos in seg_cut: 76 | if pos == 'eng': 77 | continue 78 | sub_initials, sub_finals = _get_initials_finals(word) 79 | sub_finals = tone_modifier.modified_tone(word, pos, 80 | sub_finals) 81 | initials.append(sub_initials) 82 | finals.append(sub_finals) 83 | 84 | # assert len(sub_initials) == len(sub_finals) == len(word) 85 | initials = sum(initials, []) 86 | finals = sum(finals, []) 87 | # 88 | for c, v in zip(initials, finals): 89 | raw_pinyin = c+v 90 | # NOTE: post process for pypinyin outputs 91 | # we discriminate i, ii and iii 92 | if c == v: 93 | assert c in punctuation 94 | phone = [c] 95 | word2ph.append(1) 96 | else: 97 | v_without_tone = v[:-1] 98 | 
tone = v[-1] 99 | 100 | pinyin = c+v_without_tone 101 | assert tone in '12345' 102 | 103 | if c: 104 | # 多音节 105 | v_rep_map = { 106 | "uei": 'ui', 107 | 'iou': 'iu', 108 | 'uen': 'un', 109 | } 110 | if v_without_tone in v_rep_map.keys(): 111 | pinyin = c+v_rep_map[v_without_tone] 112 | else: 113 | # 单音节 114 | pinyin_rep_map = { 115 | 'ing': 'ying', 116 | 'i': 'yi', 117 | 'in': 'yin', 118 | 'u': 'wu', 119 | } 120 | if pinyin in pinyin_rep_map.keys(): 121 | pinyin = pinyin_rep_map[pinyin] 122 | else: 123 | single_rep_map = { 124 | 'v': 'yu', 125 | 'e': 'e', 126 | 'i': 'y', 127 | 'u': 'w', 128 | } 129 | if pinyin[0] in single_rep_map.keys(): 130 | pinyin = single_rep_map[pinyin[0]]+pinyin[1:] 131 | 132 | assert pinyin in pinyin_to_symbol_map.keys(), (pinyin, seg, raw_pinyin) 133 | new_c, new_v = pinyin_to_symbol_map[pinyin].split(' ') 134 | new_v = new_v + tone 135 | phone = [new_c, new_v] 136 | word2ph.append(len(phone)) 137 | 138 | phones_list += phone 139 | return phones_list, word2ph 140 | 141 | 142 | 143 | def text_normalize(text): 144 | numbers = re.findall(r'\d+(?:\.?\d+)?', text) 145 | for number in numbers: 146 | text = text.replace(number, cn2an.an2cn(number), 1) 147 | text = replace_punctuation(text) 148 | 149 | return text 150 | 151 | 152 | if __name__ == '__main__': 153 | text = "啊——但是《原神》是由,米哈\游自主,研发的一款全.新开放世界.冒险游戏" 154 | text = "呣呣呣~就是…大人的鼹鼠党吧?" 155 | text = "你好" 156 | text = text_normalize(text) 157 | print(g2p(text)) 158 | 159 | 160 | # # 示例用法 161 | # text = "这是一个示例文本:,你好!这是一个测试..." 162 | # print(g2p_paddle(text)) # 输出: 这是一个示例文本你好这是一个测试 163 | -------------------------------------------------------------------------------- /text/opencpop-strict.txt: -------------------------------------------------------------------------------- 1 | a AA a 2 | ai AA ai 3 | an AA an 4 | ang AA ang 5 | ao AA ao 6 | ba b a 7 | bai b ai 8 | ban b an 9 | bang b ang 10 | bao b ao 11 | bei b ei 12 | ben b en 13 | beng b eng 14 | bi b i 15 | bian b ian 16 | biao b iao 17 | bie b ie 18 | bin b in 19 | bing b ing 20 | bo b o 21 | bu b u 22 | ca c a 23 | cai c ai 24 | can c an 25 | cang c ang 26 | cao c ao 27 | ce c e 28 | cei c ei 29 | cen c en 30 | ceng c eng 31 | cha ch a 32 | chai ch ai 33 | chan ch an 34 | chang ch ang 35 | chao ch ao 36 | che ch e 37 | chen ch en 38 | cheng ch eng 39 | chi ch ir 40 | chong ch ong 41 | chou ch ou 42 | chu ch u 43 | chua ch ua 44 | chuai ch uai 45 | chuan ch uan 46 | chuang ch uang 47 | chui ch ui 48 | chun ch un 49 | chuo ch uo 50 | ci c i0 51 | cong c ong 52 | cou c ou 53 | cu c u 54 | cuan c uan 55 | cui c ui 56 | cun c un 57 | cuo c uo 58 | da d a 59 | dai d ai 60 | dan d an 61 | dang d ang 62 | dao d ao 63 | de d e 64 | dei d ei 65 | den d en 66 | deng d eng 67 | di d i 68 | dia d ia 69 | dian d ian 70 | diao d iao 71 | die d ie 72 | ding d ing 73 | diu d iu 74 | dong d ong 75 | dou d ou 76 | du d u 77 | duan d uan 78 | dui d ui 79 | dun d un 80 | duo d uo 81 | e EE e 82 | ei EE ei 83 | en EE en 84 | eng EE eng 85 | er EE er 86 | fa f a 87 | fan f an 88 | fang f ang 89 | fei f ei 90 | fen f en 91 | feng f eng 92 | fo f o 93 | fou f ou 94 | fu f u 95 | ga g a 96 | gai g ai 97 | gan g an 98 | gang g ang 99 | gao g ao 100 | ge g e 101 | gei g ei 102 | gen g en 103 | geng g eng 104 | gong g ong 105 | gou g ou 106 | gu g u 107 | gua g ua 108 | guai g uai 109 | guan g uan 110 | guang g uang 111 | gui g ui 112 | gun g un 113 | guo g uo 114 | ha h a 115 | hai h ai 116 | han h an 117 | hang h ang 118 | hao h ao 119 | he h e 120 | hei h ei 121 | hen h en 122 | heng h eng 
123 | hong h ong 124 | hou h ou 125 | hu h u 126 | hua h ua 127 | huai h uai 128 | huan h uan 129 | huang h uang 130 | hui h ui 131 | hun h un 132 | huo h uo 133 | ji j i 134 | jia j ia 135 | jian j ian 136 | jiang j iang 137 | jiao j iao 138 | jie j ie 139 | jin j in 140 | jing j ing 141 | jiong j iong 142 | jiu j iu 143 | ju j v 144 | jv j v 145 | juan j van 146 | jvan j van 147 | jue j ve 148 | jve j ve 149 | jun j vn 150 | jvn j vn 151 | ka k a 152 | kai k ai 153 | kan k an 154 | kang k ang 155 | kao k ao 156 | ke k e 157 | kei k ei 158 | ken k en 159 | keng k eng 160 | kong k ong 161 | kou k ou 162 | ku k u 163 | kua k ua 164 | kuai k uai 165 | kuan k uan 166 | kuang k uang 167 | kui k ui 168 | kun k un 169 | kuo k uo 170 | la l a 171 | lai l ai 172 | lan l an 173 | lang l ang 174 | lao l ao 175 | le l e 176 | lei l ei 177 | leng l eng 178 | li l i 179 | lia l ia 180 | lian l ian 181 | liang l iang 182 | liao l iao 183 | lie l ie 184 | lin l in 185 | ling l ing 186 | liu l iu 187 | lo l o 188 | long l ong 189 | lou l ou 190 | lu l u 191 | luan l uan 192 | lun l un 193 | luo l uo 194 | lv l v 195 | lve l ve 196 | ma m a 197 | mai m ai 198 | man m an 199 | mang m ang 200 | mao m ao 201 | me m e 202 | mei m ei 203 | men m en 204 | meng m eng 205 | mi m i 206 | mian m ian 207 | miao m iao 208 | mie m ie 209 | min m in 210 | ming m ing 211 | miu m iu 212 | mo m o 213 | mou m ou 214 | mu m u 215 | na n a 216 | nai n ai 217 | nan n an 218 | nang n ang 219 | nao n ao 220 | ne n e 221 | nei n ei 222 | nen n en 223 | neng n eng 224 | ni n i 225 | nian n ian 226 | niang n iang 227 | niao n iao 228 | nie n ie 229 | nin n in 230 | ning n ing 231 | niu n iu 232 | nong n ong 233 | nou n ou 234 | nu n u 235 | nuan n uan 236 | nun n un 237 | nuo n uo 238 | nv n v 239 | nve n ve 240 | o OO o 241 | ou OO ou 242 | pa p a 243 | pai p ai 244 | pan p an 245 | pang p ang 246 | pao p ao 247 | pei p ei 248 | pen p en 249 | peng p eng 250 | pi p i 251 | pian p ian 252 | piao p iao 253 | pie p ie 254 | pin p in 255 | ping p ing 256 | po p o 257 | pou p ou 258 | pu p u 259 | qi q i 260 | qia q ia 261 | qian q ian 262 | qiang q iang 263 | qiao q iao 264 | qie q ie 265 | qin q in 266 | qing q ing 267 | qiong q iong 268 | qiu q iu 269 | qu q v 270 | qv q v 271 | quan q van 272 | qvan q van 273 | que q ve 274 | qve q ve 275 | qun q vn 276 | qvn q vn 277 | ran r an 278 | rang r ang 279 | rao r ao 280 | re r e 281 | ren r en 282 | reng r eng 283 | ri r ir 284 | rong r ong 285 | rou r ou 286 | ru r u 287 | rua r ua 288 | ruan r uan 289 | rui r ui 290 | run r un 291 | ruo r uo 292 | sa s a 293 | sai s ai 294 | san s an 295 | sang s ang 296 | sao s ao 297 | se s e 298 | sen s en 299 | seng s eng 300 | sha sh a 301 | shai sh ai 302 | shan sh an 303 | shang sh ang 304 | shao sh ao 305 | she sh e 306 | shei sh ei 307 | shen sh en 308 | sheng sh eng 309 | shi sh ir 310 | shou sh ou 311 | shu sh u 312 | shua sh ua 313 | shuai sh uai 314 | shuan sh uan 315 | shuang sh uang 316 | shui sh ui 317 | shun sh un 318 | shuo sh uo 319 | si s i0 320 | song s ong 321 | sou s ou 322 | su s u 323 | suan s uan 324 | sui s ui 325 | sun s un 326 | suo s uo 327 | ta t a 328 | tai t ai 329 | tan t an 330 | tang t ang 331 | tao t ao 332 | te t e 333 | tei t ei 334 | teng t eng 335 | ti t i 336 | tian t ian 337 | tiao t iao 338 | tie t ie 339 | ting t ing 340 | tong t ong 341 | tou t ou 342 | tu t u 343 | tuan t uan 344 | tui t ui 345 | tun t un 346 | tuo t uo 347 | wa w a 348 | wai w ai 349 | wan w an 350 | wang w ang 351 | wei w ei 352 | wen w 
en 353 | weng w eng 354 | wo w o 355 | wu w u 356 | xi x i 357 | xia x ia 358 | xian x ian 359 | xiang x iang 360 | xiao x iao 361 | xie x ie 362 | xin x in 363 | xing x ing 364 | xiong x iong 365 | xiu x iu 366 | xu x v 367 | xv x v 368 | xuan x van 369 | xvan x van 370 | xue x ve 371 | xve x ve 372 | xun x vn 373 | xvn x vn 374 | ya y a 375 | yan y En 376 | yang y ang 377 | yao y ao 378 | ye y E 379 | yi y i 380 | yin y in 381 | ying y ing 382 | yo y o 383 | yong y ong 384 | you y ou 385 | yu y v 386 | yv y v 387 | yuan y van 388 | yvan y van 389 | yue y ve 390 | yve y ve 391 | yun y vn 392 | yvn y vn 393 | za z a 394 | zai z ai 395 | zan z an 396 | zang z ang 397 | zao z ao 398 | ze z e 399 | zei z ei 400 | zen z en 401 | zeng z eng 402 | zha zh a 403 | zhai zh ai 404 | zhan zh an 405 | zhang zh ang 406 | zhao zh ao 407 | zhe zh e 408 | zhei zh ei 409 | zhen zh en 410 | zheng zh eng 411 | zhi zh ir 412 | zhong zh ong 413 | zhou zh ou 414 | zhu zh u 415 | zhua zh ua 416 | zhuai zh uai 417 | zhuan zh uan 418 | zhuang zh uang 419 | zhui zh ui 420 | zhun zh un 421 | zhuo zh uo 422 | zi z i0 423 | zong z ong 424 | zou z ou 425 | zu z u 426 | zuan z uan 427 | zui z ui 428 | zun z un 429 | zuo z uo 430 | -------------------------------------------------------------------------------- /module/commons.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | def init_weights(m, mean=0.0, std=0.01): 9 | classname = m.__class__.__name__ 10 | if classname.find("Conv") != -1: 11 | m.weight.data.normal_(mean, std) 12 | 13 | 14 | def get_padding(kernel_size, dilation=1): 15 | return int((kernel_size*dilation - dilation)/2) 16 | 17 | 18 | def convert_pad_shape(pad_shape): 19 | l = pad_shape[::-1] 20 | pad_shape = [item for sublist in l for item in sublist] 21 | return pad_shape 22 | 23 | 24 | def intersperse(lst, item): 25 | result = [item] * (len(lst) * 2 + 1) 26 | result[1::2] = lst 27 | return result 28 | 29 | 30 | def kl_divergence(m_p, logs_p, m_q, logs_q): 31 | """KL(P||Q)""" 32 | kl = (logs_q - logs_p) - 0.5 33 | kl += 0.5 * (torch.exp(2. * logs_p) + ((m_p - m_q)**2)) * torch.exp(-2. 
* logs_q) 34 | return kl 35 | 36 | 37 | def rand_gumbel(shape): 38 | """Sample from the Gumbel distribution, protect from overflows.""" 39 | uniform_samples = torch.rand(shape) * 0.99998 + 0.00001 40 | return -torch.log(-torch.log(uniform_samples)) 41 | 42 | 43 | def rand_gumbel_like(x): 44 | g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device) 45 | return g 46 | 47 | 48 | def slice_segments(x, ids_str, segment_size=4): 49 | ret = torch.zeros_like(x[:, :, :segment_size]) 50 | for i in range(x.size(0)): 51 | idx_str = ids_str[i] 52 | idx_end = idx_str + segment_size 53 | ret[i] = x[i, :, idx_str:idx_end] 54 | return ret 55 | 56 | 57 | def rand_slice_segments(x, x_lengths=None, segment_size=4): 58 | b, d, t = x.size() 59 | if x_lengths is None: 60 | x_lengths = t 61 | ids_str_max = x_lengths - segment_size + 1 62 | ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) 63 | ret = slice_segments(x, ids_str, segment_size) 64 | return ret, ids_str 65 | 66 | 67 | def get_timing_signal_1d( 68 | length, channels, min_timescale=1.0, max_timescale=1.0e4): 69 | position = torch.arange(length, dtype=torch.float) 70 | num_timescales = channels // 2 71 | log_timescale_increment = ( 72 | math.log(float(max_timescale) / float(min_timescale)) / 73 | (num_timescales - 1)) 74 | inv_timescales = min_timescale * torch.exp( 75 | torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment) 76 | scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1) 77 | signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0) 78 | signal = F.pad(signal, [0, 0, 0, channels % 2]) 79 | signal = signal.view(1, channels, length) 80 | return signal 81 | 82 | 83 | def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4): 84 | b, channels, length = x.size() 85 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 86 | return x + signal.to(dtype=x.dtype, device=x.device) 87 | 88 | 89 | def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1): 90 | b, channels, length = x.size() 91 | signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale) 92 | return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis) 93 | 94 | 95 | def subsequent_mask(length): 96 | mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) 97 | return mask 98 | 99 | 100 | @torch.jit.script 101 | def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): 102 | n_channels_int = n_channels[0] 103 | in_act = input_a + input_b 104 | t_act = torch.tanh(in_act[:, :n_channels_int, :]) 105 | s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) 106 | acts = t_act * s_act 107 | return acts 108 | 109 | 110 | def convert_pad_shape(pad_shape): 111 | l = pad_shape[::-1] 112 | pad_shape = [item for sublist in l for item in sublist] 113 | return pad_shape 114 | 115 | 116 | def shift_1d(x): 117 | x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1] 118 | return x 119 | 120 | 121 | def sequence_mask(length, max_length=None): 122 | if max_length is None: 123 | max_length = length.max() 124 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 125 | return x.unsqueeze(0) < length.unsqueeze(1) 126 | 127 | 128 | def generate_path(duration, mask): 129 | """ 130 | duration: [b, 1, t_x] 131 | mask: [b, 1, t_y, t_x] 132 | """ 133 | device = duration.device 134 | 135 | b, _, t_y, t_x = mask.shape 136 | cum_duration = torch.cumsum(duration, -1) 137 | 138 | cum_duration_flat = 
cum_duration.view(b * t_x) 139 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 140 | path = path.view(b, t_x, t_y) 141 | path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 142 | path = path.unsqueeze(1).transpose(2,3) * mask 143 | return path 144 | 145 | 146 | def clip_grad_value_(parameters, clip_value, norm_type=2): 147 | if isinstance(parameters, torch.Tensor): 148 | parameters = [parameters] 149 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 150 | norm_type = float(norm_type) 151 | if clip_value is not None: 152 | clip_value = float(clip_value) 153 | 154 | total_norm = 0 155 | for p in parameters: 156 | param_norm = p.grad.data.norm(norm_type) 157 | total_norm += param_norm.item() ** norm_type 158 | if clip_value is not None: 159 | p.grad.data.clamp_(min=-clip_value, max=clip_value) 160 | total_norm = total_norm ** (1. / norm_type) 161 | return total_norm 162 | 163 | 164 | def squeeze(x, x_mask=None, n_sqz=2): 165 | b, c, t = x.size() 166 | 167 | t = (t // n_sqz) * n_sqz 168 | x = x[:, :, :t] 169 | x_sqz = x.view(b, c, t // n_sqz, n_sqz) 170 | x_sqz = x_sqz.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t // n_sqz) 171 | 172 | if x_mask is not None: 173 | x_mask = x_mask[:, :, n_sqz - 1::n_sqz] 174 | else: 175 | x_mask = torch.ones(b, 1, t // n_sqz).to(device=x.device, dtype=x.dtype) 176 | return x_sqz * x_mask, x_mask 177 | 178 | 179 | def unsqueeze(x, x_mask=None, n_sqz=2): 180 | b, c, t = x.size() 181 | 182 | x_unsqz = x.view(b, n_sqz, c // n_sqz, t) 183 | x_unsqz = x_unsqz.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz) 184 | 185 | if x_mask is not None: 186 | x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz) 187 | else: 188 | x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype) 189 | return x_unsqz * x_mask, x_mask 190 | -------------------------------------------------------------------------------- /extract_duration.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import librosa 3 | from text import cleaned_text_to_sequence 4 | from asr.models import build_model 5 | import torch 6 | import torchaudio 7 | from tqdm import tqdm 8 | from monotonic_align import mask_from_lens 9 | from monotonic_align.core import maximum_path_c 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import os 13 | 14 | from asr.trainer import calc_wer 15 | 16 | 17 | def maximum_path(neg_cent, mask): 18 | """ Cython optimized version. 
19 | neg_cent: [b, t_t, t_s] 20 | mask: [b, t_t, t_s] 21 | """ 22 | device = neg_cent.device 23 | dtype = neg_cent.dtype 24 | neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32)) 25 | path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32)) 26 | 27 | t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)) 28 | t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)) 29 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 30 | return torch.from_numpy(path).to(device=device, dtype=dtype) 31 | 32 | 33 | def intersperse(lst, item): 34 | result = [item] * (len(lst) * 2 + 1) 35 | result[1::2] = lst 36 | return result 37 | 38 | 39 | def calc_mono_loss(s2s_attn, input_lengths, mel_input_length, text_mask, mel_mask, n_down): 40 | s2s_attn = s2s_attn.transpose(-1, -2) 41 | s2s_attn = s2s_attn[..., 1:] 42 | s2s_attn = s2s_attn.transpose(-1, -2) 43 | 44 | with torch.no_grad(): 45 | attn_mask = (~mel_mask).unsqueeze(-1).expand(mel_mask.shape[0], mel_mask.shape[1], 46 | text_mask.shape[-1]).float().transpose(-1, -2) 47 | attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], 48 | text_mask.shape[1], 49 | mel_mask.shape[-1]).float() 50 | attn_mask = (attn_mask < 1) 51 | 52 | s2s_attn.masked_fill_(attn_mask, 0.0) 53 | 54 | with torch.no_grad(): 55 | mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length) 56 | s2s_attn_mono = maximum_path(s2s_attn, mask_ST) 57 | loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10 58 | 59 | return loss_mono, s2s_attn_mono 60 | 61 | 62 | def get_attention_mono(model, text_input, text_input_length, mel_input, mel_input_length): 63 | mel_input_length = mel_input_length // (2 ** model.n_down) 64 | future_mask = model.get_future_mask( 65 | mel_input.size(2) // (2 ** model.n_down), unmask_future_steps=0).to(text_input.device) 66 | mel_mask = model.length_to_mask(mel_input_length) 67 | text_mask = model.length_to_mask(text_input_length) 68 | ppgs, s2s_pred, s2s_attn = model( 69 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 70 | loss_mono, s2s_attn_mono = calc_mono_loss(s2s_attn, text_input_length, mel_input_length, text_mask, mel_mask, 71 | model.n_down) 72 | _, amax_ppgs = torch.max(ppgs, dim=2) 73 | wers = [calc_wer(target[:text_length], 74 | pred[:mel_length], 75 | ignore_indexes=list(range(5))) \ 76 | for target, pred, text_length, mel_length in zip( 77 | text_input.cpu(), amax_ppgs.cpu(), text_input_length.cpu(), mel_input_length.cpu())] 78 | m_wer = np.mean(wers) 79 | return s2s_attn_mono, m_wer 80 | 81 | 82 | to_mel = torchaudio.transforms.MelSpectrogram( 83 | n_mels=128, n_fft=2048, win_length=2048, hop_length=640) 84 | mean, std = -4, 4 85 | 86 | 87 | def preprocess(wave): 88 | wave_tensor = torch.from_numpy(wave).float() 89 | mel_tensor = to_mel(wave_tensor) 90 | mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std 91 | return mel_tensor 92 | 93 | 94 | config_path = 'configs/asr.yml' 95 | ckpt_path = 'logs/asr/last.ckpt' 96 | dump_dir = 'dump' 97 | phoneme_path = f'{dump_dir}/phoneme.npy' 98 | train_path = f'{dump_dir}/train_files.list' 99 | val_path = f'{dump_dir}/val_files.list' 100 | config = yaml.safe_load(open(config_path)) 101 | model = build_model(model_params=config['model_params'] or {}) 102 | state_dict = torch.load(ckpt_path, map_location="cpu")['state_dict'] 103 | state_dict = {k.replace('model.', ''):v for k, v in state_dict.items()} 104 | model.load_state_dict(state_dict) 105 | 
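# (added explanatory note, not part of the original file)
# The Lightning checkpoint stores weights under "model.<name>", so that prefix
# is stripped above before loading into the bare ASRCNN. The loop below then
# runs the ASR attention on every (wav, phoneme) pair and saves the resulting
# per-phoneme durations next to each audio file as <name>.dur.pt.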
device = 'cuda:0' 106 | model = model.to(device) 107 | model.eval() 108 | phoneme_data = np.load(phoneme_path, allow_pickle=True).item() 109 | 110 | all_files = [line.strip() for line in open(train_path)] 111 | import random 112 | 113 | random.shuffle(all_files) 114 | processed_cnt = 0 115 | all_files = [line.strip() for line in open(val_path)] + all_files 116 | with torch.no_grad(): 117 | for line in tqdm(all_files): 118 | wave_path = line.strip() 119 | try: 120 | phonemes = phoneme_data[wave_path] 121 | except: 122 | print('phoneme not exist ,skip:', wave_path) 123 | continue 124 | if not os.path.exists(wave_path): 125 | print('skip:', wave_path) 126 | continue 127 | wave, sr = librosa.load(wave_path, sr=None) 128 | if wave.shape[-1] == 2: 129 | wave = wave[:, 0].squeeze() 130 | assert sr == 32000 131 | phoneme = phonemes.split(' ') 132 | phoneme_ids = cleaned_text_to_sequence(phoneme) 133 | phoneme_ids = intersperse(phoneme_ids, 0) 134 | 135 | text = torch.LongTensor(phoneme_ids) 136 | mel_tensor = preprocess(wave).squeeze() 137 | 138 | ph_len = len(phoneme_ids) 139 | mel_len = mel_tensor.shape[-1] 140 | ps = mel_len/ph_len 141 | 142 | if ps < 1.2 or ps>10: 143 | print(ph_len, mel_len, ) 144 | print('skip:',wave_path) 145 | continue 146 | acoustic_feature = mel_tensor.squeeze() 147 | length_feature = acoustic_feature.size(1) 148 | # acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)] 149 | 150 | # print(acoustic_feature.size(), text.size()) 151 | 152 | text_input = text.unsqueeze(0).to(device) 153 | text_input_length = torch.LongTensor([len(phoneme_ids)]).to(device) 154 | mel_input = mel_tensor.unsqueeze(0).to(device) 155 | mel_input_length = torch.LongTensor([mel_input.size(2)]).to(device) 156 | 157 | s2s_attn_mono, m_wer = get_attention_mono(model, text_input, text_input_length, mel_input, mel_input_length) 158 | duration = s2s_attn_mono[0].long().sum(-1).detach().cpu().numpy().tolist() 159 | # duration = s2s_attn_mono[0].long().sum(-1).detach().cpu().numpy().tolist() 160 | # print(duration, len(duration), sum(duration)) 161 | duration = s2s_attn_mono[0].long().sum(-1).detach().cpu() 162 | 163 | save_path = wave_path.replace('.wav', '.dur.pt').replace('.mp3', '.dur.pt') 164 | torch.save(duration, save_path) 165 | processed_cnt += 1 166 | # text_input, text_input_length, mel_input, mel_input_length 167 | # print(s2s_attn_mono.shape, duration, len(duration), sum(duration), len(phoneme), acoustic_feature.shape) 168 | # break 169 | 170 | print(processed_cnt) -------------------------------------------------------------------------------- /asr/meldataset.py: -------------------------------------------------------------------------------- 1 | #coding: utf-8 2 | import os 3 | import os.path as osp 4 | import time 5 | import random 6 | import numpy as np 7 | import random 8 | import soundfile as sf 9 | import librosa 10 | 11 | from text import cleaned_text_to_sequence 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | import torchaudio 16 | from torch.utils.data import DataLoader 17 | from text.symbols import symbols 18 | import logging 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.DEBUG) 21 | 22 | import pandas as pd 23 | 24 | 25 | np.random.seed(1) 26 | random.seed(1) 27 | 28 | to_mel = torchaudio.transforms.MelSpectrogram( 29 | n_mels=128, n_fft=2048, win_length=2048, hop_length=640) 30 | mean, std = -4, 4 31 | 32 | def preprocess(wave): 33 | wave_tensor = torch.from_numpy(wave).float() 34 | mel_tensor = 
to_mel(wave_tensor) 35 | mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std 36 | return mel_tensor 37 | def intersperse(lst, item): 38 | result = [item] * (len(lst) * 2 + 1) 39 | result[1::2] = lst 40 | return result 41 | 42 | class FilePathDataset(torch.utils.data.Dataset): 43 | def __init__(self, 44 | data_list, 45 | sr=32000, 46 | data_augmentation=False, 47 | validation=False, 48 | OOD_data="Data/OOD_texts.txt", 49 | min_length=50, 50 | ): 51 | 52 | phoneme_path = 'dump/phoneme.npy' 53 | _data_list = [l.strip() for l in data_list] 54 | self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() 55 | 56 | self.data_list = [data for data in _data_list if os.path.exists(data) and data in self.phoneme_data] 57 | print(f"Data list length: {len(self.data_list)}") 58 | self.sr = sr 59 | assert sr == 32000 60 | 61 | 62 | self.mean, self.std = -4, 4 63 | self.data_augmentation = data_augmentation and (not validation) 64 | self.max_mel_length = 192 65 | 66 | self.min_length = min_length 67 | 68 | 69 | def __len__(self): 70 | return len(self.data_list) 71 | 72 | def __getitem__(self, idx): 73 | path = self.data_list[idx] 74 | 75 | wave, text_tensor, speaker_id = self._load_tensor(path) 76 | 77 | mel_tensor = preprocess(wave).squeeze() 78 | ph_len = len(text_tensor) 79 | mel_len = mel_tensor.shape[-1] 80 | ps = mel_len/ph_len 81 | if ps < 1.2 or ps>10: 82 | print(ph_len, mel_len, ) 83 | print('skip:',path) 84 | return self.__getitem__((idx+1)%len(self.data_list)) 85 | acoustic_feature = mel_tensor.squeeze() 86 | length_feature = acoustic_feature.size(1) 87 | acoustic_feature = acoustic_feature[:, :(length_feature - length_feature % 2)] 88 | 89 | return speaker_id, acoustic_feature, text_tensor, None, None, None, path, wave 90 | 91 | def _load_tensor(self, path): 92 | speaker_id = 0 93 | wave_path = path 94 | wave, sr = librosa.load(wave_path, sr=None) 95 | if wave.shape[-1] == 2: 96 | wave = wave[:, 0].squeeze() 97 | assert sr == 32000 98 | # wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0) 99 | 100 | text = self.phoneme_data[path] 101 | phoneme = text.split(' ') 102 | phoneme_ids = cleaned_text_to_sequence(phoneme) 103 | phoneme_ids = intersperse(phoneme_ids, 0) 104 | 105 | text = torch.LongTensor(phoneme_ids) 106 | 107 | return wave, text, speaker_id 108 | 109 | def _load_data(self, data): 110 | wave, text_tensor, speaker_id = self._load_tensor(data) 111 | mel_tensor = preprocess(wave).squeeze() 112 | 113 | mel_length = mel_tensor.size(1) 114 | if mel_length > self.max_mel_length: 115 | random_start = np.random.randint(0, mel_length - self.max_mel_length) 116 | mel_tensor = mel_tensor[:, random_start:random_start + self.max_mel_length] 117 | 118 | return mel_tensor, speaker_id 119 | 120 | 121 | 122 | class Collater(object): 123 | """ 124 | Args: 125 | adaptive_batch_size (bool): if true, decrease batch size when long data comes. 
126 | """ 127 | 128 | def __init__(self, return_wave=False): 129 | self.text_pad_index = 0 130 | self.min_mel_length = 192 131 | self.max_mel_length = 192 132 | self.return_wave = return_wave 133 | 134 | 135 | def __call__(self, batch): 136 | # batch[0] = wave, mel, text, f0, speakerid 137 | batch_size = len(batch) 138 | 139 | # sort by mel length 140 | lengths = [b[1].shape[1] for b in batch] 141 | batch_indexes = np.argsort(lengths)[::-1] 142 | batch = [batch[bid] for bid in batch_indexes] 143 | 144 | nmels = batch[0][1].size(0) 145 | max_mel_length = max([b[1].shape[1] for b in batch]) 146 | max_text_length = max([b[2].shape[0] for b in batch]) 147 | # max_rtext_length = max([b[3].shape[0] for b in batch]) 148 | 149 | labels = torch.zeros((batch_size)).long() 150 | mels = torch.zeros((batch_size, nmels, max_mel_length)).float() 151 | texts = torch.zeros((batch_size, max_text_length)).long() 152 | # ref_texts = torch.zeros((batch_size, max_rtext_length)).long() 153 | 154 | input_lengths = torch.zeros(batch_size).long() 155 | # ref_lengths = torch.zeros(batch_size).long() 156 | output_lengths = torch.zeros(batch_size).long() 157 | # ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float() 158 | # ref_labels = torch.zeros((batch_size)).long() 159 | paths = ['' for _ in range(batch_size)] 160 | waves = [None for _ in range(batch_size)] 161 | 162 | for bid, (label, mel, text, ref_text, ref_mel, ref_label, path, wave) in enumerate(batch): 163 | mel_size = mel.size(1) 164 | text_size = text.size(0) 165 | labels[bid] = label 166 | mels[bid, :, :mel_size] = mel 167 | texts[bid, :text_size] = text 168 | input_lengths[bid] = text_size 169 | output_lengths[bid] = mel_size 170 | paths[bid] = path 171 | 172 | waves[bid] = wave 173 | 174 | return texts, input_lengths, mels, output_lengths 175 | 176 | 177 | 178 | def build_dataloader(path_list, 179 | validation=False, 180 | OOD_data="Data/OOD_texts.txt", 181 | min_length=50, 182 | batch_size=4, 183 | num_workers=1, 184 | collate_config={}, 185 | dataset_config={}): 186 | 187 | dataset = FilePathDataset(path_list, OOD_data=OOD_data, min_length=min_length, validation=validation, **dataset_config) 188 | collate_fn = Collater(**collate_config) 189 | data_loader = DataLoader(dataset, 190 | batch_size=batch_size, 191 | shuffle=(not validation), 192 | num_workers=num_workers, 193 | drop_last=(not validation), 194 | collate_fn=collate_fn) 195 | 196 | return data_loader 197 | 198 | -------------------------------------------------------------------------------- /asr/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from torch.nn import TransformerEncoder 5 | import torch.nn.functional as F 6 | from asr.layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock 7 | 8 | def build_model(model_params={}, model_type='asr'): 9 | model = ASRCNN(**model_params) 10 | return model 11 | 12 | 13 | class ASRCNN(nn.Module): 14 | def __init__(self, 15 | input_dim=80, 16 | hidden_dim=256, 17 | n_token=35, 18 | n_layers=6, 19 | token_embedding_dim=256, 20 | 21 | ): 22 | super().__init__() 23 | self.n_token = n_token 24 | self.n_down = 0 25 | self.to_mfcc = MFCC() 26 | self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=1) 27 | self.cnns = nn.Sequential( 28 | *[nn.Sequential( 29 | ConvBlock(hidden_dim), 30 | nn.GroupNorm(num_groups=1, num_channels=hidden_dim) 31 | ) for n in range(n_layers)]) 32 | self.projection = 
ConvNorm(hidden_dim, hidden_dim // 2) 33 | self.ctc_linear = nn.Sequential( 34 | LinearNorm(hidden_dim//2, hidden_dim), 35 | nn.ReLU(), 36 | LinearNorm(hidden_dim, n_token)) 37 | self.asr_s2s = ASRS2S( 38 | embedding_dim=token_embedding_dim, 39 | hidden_dim=hidden_dim//2, 40 | n_token=n_token) 41 | 42 | def forward(self, x, src_key_padding_mask=None, text_input=None): 43 | x = self.to_mfcc(x) 44 | x = self.init_cnn(x) 45 | x = self.cnns(x) 46 | 47 | x = self.projection(x) 48 | x = x.transpose(1, 2) 49 | ctc_logit = self.ctc_linear(x) 50 | if text_input is not None: 51 | _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input) 52 | return ctc_logit, s2s_logit, s2s_attn 53 | else: 54 | return ctc_logit 55 | 56 | def get_feature(self, x): 57 | x = self.to_mfcc(x) 58 | x = self.init_cnn(x) 59 | x = self.cnns(x) 60 | x = self.instance_norm(x) 61 | x = self.projection(x) 62 | return x 63 | 64 | def length_to_mask(self, lengths): 65 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 66 | mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device) 67 | return mask 68 | 69 | def get_future_mask(self, out_length, unmask_future_steps=0): 70 | """ 71 | Args: 72 | out_length (int): returned mask shape is (out_length, out_length). 73 | unmask_futre_steps (int): unmasking future step size. 74 | Return: 75 | mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False 76 | """ 77 | index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1) 78 | mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps) 79 | return mask 80 | 81 | class ASRS2S(nn.Module): 82 | def __init__(self, 83 | embedding_dim=256, 84 | hidden_dim=512, 85 | n_location_filters=32, 86 | location_kernel_size=63, 87 | n_token=40): 88 | super(ASRS2S, self).__init__() 89 | self.embedding = nn.Embedding(n_token, embedding_dim) 90 | val_range = math.sqrt(6 / hidden_dim) 91 | self.embedding.weight.data.uniform_(-val_range, val_range) 92 | 93 | self.decoder_rnn_dim = hidden_dim 94 | self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token) 95 | self.attention_layer = Attention( 96 | self.decoder_rnn_dim, 97 | hidden_dim, 98 | hidden_dim, 99 | n_location_filters, 100 | location_kernel_size 101 | ) 102 | self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim) 103 | self.project_to_hidden = nn.Sequential( 104 | LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), 105 | nn.Tanh()) 106 | self.sos = 1 107 | self.eos = 2 108 | 109 | def initialize_decoder_states(self, memory, mask): 110 | """ 111 | moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) 112 | """ 113 | B, L, H = memory.shape 114 | self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) 115 | self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory) 116 | self.attention_weights = torch.zeros((B, L)).type_as(memory) 117 | self.attention_weights_cum = torch.zeros((B, L)).type_as(memory) 118 | self.attention_context = torch.zeros((B, H)).type_as(memory) 119 | self.memory = memory 120 | self.processed_memory = self.attention_layer.memory_layer(memory) 121 | self.mask = mask 122 | self.unk_index = 3 123 | self.random_mask = 0.1 124 | 125 | def forward(self, memory, memory_mask, text_input): 126 | """ 127 | moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim) 128 | moemory_mask.shape = (B, L, ) 129 | texts_input.shape = (B, T) 130 | """ 131 | 
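        # (added explanatory note, not part of the original file)
        # Teacher-forced decoding: the attention/LSTM state is initialised from
        # `memory`, roughly 10% of the target tokens are replaced with the unk
        # index as input noise, an <sos> embedding is prepended, and the decoder
        # is stepped once per target position, collecting hidden states, logits
        # and attention weights.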
self.initialize_decoder_states(memory, memory_mask) 132 | # text random mask 133 | random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device) 134 | _text_input = text_input.clone() 135 | _text_input.masked_fill_(random_mask, self.unk_index) 136 | decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel] 137 | start_embedding = self.embedding( 138 | torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device)) 139 | decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0) 140 | 141 | hidden_outputs, logit_outputs, alignments = [], [], [] 142 | while len(hidden_outputs) < decoder_inputs.size(0): 143 | 144 | decoder_input = decoder_inputs[len(hidden_outputs)] 145 | hidden, logit, attention_weights = self.decode(decoder_input) 146 | hidden_outputs += [hidden] 147 | logit_outputs += [logit] 148 | alignments += [attention_weights] 149 | 150 | hidden_outputs, logit_outputs, alignments = \ 151 | self.parse_decoder_outputs( 152 | hidden_outputs, logit_outputs, alignments) 153 | 154 | return hidden_outputs, logit_outputs, alignments 155 | 156 | 157 | def decode(self, decoder_input): 158 | 159 | cell_input = torch.cat((decoder_input, self.attention_context), -1) 160 | self.decoder_hidden, self.decoder_cell = self.decoder_rnn( 161 | cell_input, 162 | (self.decoder_hidden, self.decoder_cell)) 163 | 164 | attention_weights_cat = torch.cat( 165 | (self.attention_weights.unsqueeze(1), 166 | self.attention_weights_cum.unsqueeze(1)),dim=1) 167 | 168 | self.attention_context, self.attention_weights = self.attention_layer( 169 | self.decoder_hidden, 170 | self.memory, 171 | self.processed_memory, 172 | attention_weights_cat, 173 | self.mask) 174 | 175 | self.attention_weights_cum += self.attention_weights 176 | 177 | hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1) 178 | hidden = self.project_to_hidden(hidden_and_context) 179 | 180 | # dropout to increasing g 181 | logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training)) 182 | 183 | return hidden, logit, self.attention_weights 184 | 185 | def parse_decoder_outputs(self, hidden, logit, alignments): 186 | 187 | # -> [B, T_out + 1, max_time] 188 | alignments = torch.stack(alignments).transpose(0,1) 189 | # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols] 190 | logit = torch.stack(logit).transpose(0, 1).contiguous() 191 | hidden = torch.stack(hidden).transpose(0, 1).contiguous() 192 | 193 | return hidden, logit, alignments 194 | -------------------------------------------------------------------------------- /module/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | import numpy as np 5 | 6 | 7 | DEFAULT_MIN_BIN_WIDTH = 1e-3 8 | DEFAULT_MIN_BIN_HEIGHT = 1e-3 9 | DEFAULT_MIN_DERIVATIVE = 1e-3 10 | 11 | 12 | def piecewise_rational_quadratic_transform(inputs, 13 | unnormalized_widths, 14 | unnormalized_heights, 15 | unnormalized_derivatives, 16 | inverse=False, 17 | tails=None, 18 | tail_bound=1., 19 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 20 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 21 | min_derivative=DEFAULT_MIN_DERIVATIVE): 22 | 23 | if tails is None: 24 | spline_fn = rational_quadratic_spline 25 | spline_kwargs = {} 26 | else: 27 | spline_fn = unconstrained_rational_quadratic_spline 28 | spline_kwargs = { 29 | 'tails': tails, 30 | 'tail_bound': tail_bound 31 | } 32 | 33 | outputs, logabsdet = spline_fn( 34 | 
inputs=inputs, 35 | unnormalized_widths=unnormalized_widths, 36 | unnormalized_heights=unnormalized_heights, 37 | unnormalized_derivatives=unnormalized_derivatives, 38 | inverse=inverse, 39 | min_bin_width=min_bin_width, 40 | min_bin_height=min_bin_height, 41 | min_derivative=min_derivative, 42 | **spline_kwargs 43 | ) 44 | return outputs, logabsdet 45 | 46 | 47 | def searchsorted(bin_locations, inputs, eps=1e-6): 48 | bin_locations[..., -1] += eps 49 | return torch.sum( 50 | inputs[..., None] >= bin_locations, 51 | dim=-1 52 | ) - 1 53 | 54 | 55 | def unconstrained_rational_quadratic_spline(inputs, 56 | unnormalized_widths, 57 | unnormalized_heights, 58 | unnormalized_derivatives, 59 | inverse=False, 60 | tails='linear', 61 | tail_bound=1., 62 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 63 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 64 | min_derivative=DEFAULT_MIN_DERIVATIVE): 65 | inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) 66 | outside_interval_mask = ~inside_interval_mask 67 | 68 | outputs = torch.zeros_like(inputs) 69 | logabsdet = torch.zeros_like(inputs) 70 | 71 | if tails == 'linear': 72 | unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) 73 | constant = np.log(np.exp(1 - min_derivative) - 1) 74 | unnormalized_derivatives[..., 0] = constant 75 | unnormalized_derivatives[..., -1] = constant 76 | 77 | outputs[outside_interval_mask] = inputs[outside_interval_mask] 78 | logabsdet[outside_interval_mask] = 0 79 | else: 80 | raise RuntimeError('{} tails are not implemented.'.format(tails)) 81 | 82 | outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( 83 | inputs=inputs[inside_interval_mask], 84 | unnormalized_widths=unnormalized_widths[inside_interval_mask, :], 85 | unnormalized_heights=unnormalized_heights[inside_interval_mask, :], 86 | unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], 87 | inverse=inverse, 88 | left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, 89 | min_bin_width=min_bin_width, 90 | min_bin_height=min_bin_height, 91 | min_derivative=min_derivative 92 | ) 93 | 94 | return outputs, logabsdet 95 | 96 | def rational_quadratic_spline(inputs, 97 | unnormalized_widths, 98 | unnormalized_heights, 99 | unnormalized_derivatives, 100 | inverse=False, 101 | left=0., right=1., bottom=0., top=1., 102 | min_bin_width=DEFAULT_MIN_BIN_WIDTH, 103 | min_bin_height=DEFAULT_MIN_BIN_HEIGHT, 104 | min_derivative=DEFAULT_MIN_DERIVATIVE): 105 | if torch.min(inputs) < left or torch.max(inputs) > right: 106 | raise ValueError('Input to a transform is not within its domain') 107 | 108 | num_bins = unnormalized_widths.shape[-1] 109 | 110 | if min_bin_width * num_bins > 1.0: 111 | raise ValueError('Minimal bin width too large for the number of bins') 112 | if min_bin_height * num_bins > 1.0: 113 | raise ValueError('Minimal bin height too large for the number of bins') 114 | 115 | widths = F.softmax(unnormalized_widths, dim=-1) 116 | widths = min_bin_width + (1 - min_bin_width * num_bins) * widths 117 | cumwidths = torch.cumsum(widths, dim=-1) 118 | cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) 119 | cumwidths = (right - left) * cumwidths + left 120 | cumwidths[..., 0] = left 121 | cumwidths[..., -1] = right 122 | widths = cumwidths[..., 1:] - cumwidths[..., :-1] 123 | 124 | derivatives = min_derivative + F.softplus(unnormalized_derivatives) 125 | 126 | heights = F.softmax(unnormalized_heights, dim=-1) 127 | heights = min_bin_height + (1 - 
min_bin_height * num_bins) * heights 128 | cumheights = torch.cumsum(heights, dim=-1) 129 | cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) 130 | cumheights = (top - bottom) * cumheights + bottom 131 | cumheights[..., 0] = bottom 132 | cumheights[..., -1] = top 133 | heights = cumheights[..., 1:] - cumheights[..., :-1] 134 | 135 | if inverse: 136 | bin_idx = searchsorted(cumheights, inputs)[..., None] 137 | else: 138 | bin_idx = searchsorted(cumwidths, inputs)[..., None] 139 | 140 | input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] 141 | input_bin_widths = widths.gather(-1, bin_idx)[..., 0] 142 | 143 | input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] 144 | delta = heights / widths 145 | input_delta = delta.gather(-1, bin_idx)[..., 0] 146 | 147 | input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] 148 | input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] 149 | 150 | input_heights = heights.gather(-1, bin_idx)[..., 0] 151 | 152 | if inverse: 153 | a = (((inputs - input_cumheights) * (input_derivatives 154 | + input_derivatives_plus_one 155 | - 2 * input_delta) 156 | + input_heights * (input_delta - input_derivatives))) 157 | b = (input_heights * input_derivatives 158 | - (inputs - input_cumheights) * (input_derivatives 159 | + input_derivatives_plus_one 160 | - 2 * input_delta)) 161 | c = - input_delta * (inputs - input_cumheights) 162 | 163 | discriminant = b.pow(2) - 4 * a * c 164 | assert (discriminant >= 0).all() 165 | 166 | root = (2 * c) / (-b - torch.sqrt(discriminant)) 167 | outputs = root * input_bin_widths + input_cumwidths 168 | 169 | theta_one_minus_theta = root * (1 - root) 170 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 171 | * theta_one_minus_theta) 172 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) 173 | + 2 * input_delta * theta_one_minus_theta 174 | + input_derivatives * (1 - root).pow(2)) 175 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 176 | 177 | return outputs, -logabsdet 178 | else: 179 | theta = (inputs - input_cumwidths) / input_bin_widths 180 | theta_one_minus_theta = theta * (1 - theta) 181 | 182 | numerator = input_heights * (input_delta * theta.pow(2) 183 | + input_derivatives * theta_one_minus_theta) 184 | denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) 185 | * theta_one_minus_theta) 186 | outputs = input_cumheights + numerator / denominator 187 | 188 | derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) 189 | + 2 * input_delta * theta_one_minus_theta 190 | + input_derivatives * (1 - theta).pow(2)) 191 | logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) 192 | 193 | return outputs, logabsdet 194 | -------------------------------------------------------------------------------- /module/quantize.py: -------------------------------------------------------------------------------- 1 | # from descript audio codec 2 | from typing import Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from einops import rearrange 9 | from torch.nn.utils import weight_norm 10 | 11 | 12 | def WNConv1d(*args, **kwargs): 13 | return weight_norm(nn.Conv1d(*args, **kwargs)) 14 | 15 | 16 | class VectorQuantize(nn.Module): 17 | """ 18 | Implementation of VQ similar to Karpathy's repo: 19 | 
https://github.com/karpathy/deep-vector-quantization 20 | Additionally uses following tricks from Improved VQGAN 21 | (https://arxiv.org/pdf/2110.04627.pdf): 22 | 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space 23 | for improved codebook usage 24 | 2. l2-normalized codes: Converts euclidean distance to cosine similarity which 25 | improves training stability 26 | """ 27 | 28 | def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): 29 | super().__init__() 30 | self.codebook_size = codebook_size 31 | self.codebook_dim = codebook_dim 32 | 33 | self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) 34 | self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) 35 | self.codebook = nn.Embedding(codebook_size, codebook_dim) 36 | 37 | def forward(self, z): 38 | """Quantized the input tensor using a fixed codebook and returns 39 | the corresponding codebook vectors 40 | 41 | Parameters 42 | ---------- 43 | z : Tensor[B x D x T] 44 | 45 | Returns 46 | ------- 47 | Tensor[B x D x T] 48 | Quantized continuous representation of input 49 | Tensor[1] 50 | Commitment loss to train encoder to predict vectors closer to codebook 51 | entries 52 | Tensor[1] 53 | Codebook loss to update the codebook 54 | Tensor[B x T] 55 | Codebook indices (quantized discrete representation of input) 56 | Tensor[B x D x T] 57 | Projected latents (continuous representation of input before quantization) 58 | """ 59 | 60 | # Factorized codes (ViT-VQGAN) Project input into low-dimensional space 61 | z_e = self.in_proj(z) # z_e : (B x D x T) 62 | z_q, indices = self.decode_latents(z_e) 63 | 64 | commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) 65 | codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) 66 | 67 | z_q = ( 68 | z_e + (z_q - z_e).detach() 69 | ) # noop in forward pass, straight-through gradient estimator in backward pass 70 | 71 | z_q = self.out_proj(z_q) 72 | 73 | return z_q, commitment_loss, codebook_loss, indices, z_e 74 | 75 | def embed_code(self, embed_id): 76 | return F.embedding(embed_id, self.codebook.weight) 77 | 78 | def decode_code(self, embed_id): 79 | return self.embed_code(embed_id).transpose(1, 2) 80 | 81 | def decode_latents(self, latents): 82 | encodings = rearrange(latents, "b d t -> (b t) d") 83 | codebook = self.codebook.weight # codebook: (N x D) 84 | 85 | # L2 normalize encodings and codebook (ViT-VQGAN) 86 | encodings = F.normalize(encodings) 87 | codebook = F.normalize(codebook) 88 | 89 | # Compute euclidean distance with codebook 90 | dist = ( 91 | encodings.pow(2).sum(1, keepdim=True) 92 | - 2 * encodings @ codebook.t() 93 | + codebook.pow(2).sum(1, keepdim=True).t() 94 | ) 95 | indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) 96 | z_q = self.decode_code(indices) 97 | return z_q, indices 98 | def decode(self, indices): 99 | z_q = self.decode_code(indices) 100 | z_q = self.out_proj(z_q) 101 | return z_q 102 | 103 | class ResidualVectorQuantize(nn.Module): 104 | """ 105 | Introduced in SoundStream: An end2end neural audio codec 106 | https://arxiv.org/abs/2107.03312 107 | """ 108 | 109 | def __init__( 110 | self, 111 | input_dim: int = 512, 112 | n_codebooks: int = 9, 113 | codebook_size: int = 1024, 114 | codebook_dim: Union[int, list] = 8, 115 | quantizer_dropout: float = 0.0, 116 | ): 117 | super().__init__() 118 | if isinstance(codebook_dim, int): 119 | codebook_dim = [codebook_dim for _ in range(n_codebooks)] 120 | 121 | self.n_codebooks = 
n_codebooks 122 | self.codebook_dim = codebook_dim 123 | self.codebook_size = codebook_size 124 | 125 | self.quantizers = nn.ModuleList( 126 | [ 127 | VectorQuantize(input_dim, codebook_size, codebook_dim[i]) 128 | for i in range(n_codebooks) 129 | ] 130 | ) 131 | self.quantizer_dropout = quantizer_dropout 132 | 133 | def forward(self, z, n_quantizers: int = None): 134 | """Quantized the input tensor using a fixed set of `n` codebooks and returns 135 | the corresponding codebook vectors 136 | Parameters 137 | ---------- 138 | z : Tensor[B x D x T] 139 | n_quantizers : int, optional 140 | No. of quantizers to use 141 | (n_quantizers < self.n_codebooks ex: for quantizer dropout) 142 | Note: if `self.quantizer_dropout` is True, this argument is ignored 143 | when in training mode, and a random number of quantizers is used. 144 | Returns 145 | ------- 146 | dict 147 | A dictionary with the following keys: 148 | 149 | "z" : Tensor[B x D x T] 150 | Quantized continuous representation of input 151 | "codes" : Tensor[B x N x T] 152 | Codebook indices for each codebook 153 | (quantized discrete representation of input) 154 | "latents" : Tensor[B x N*D x T] 155 | Projected latents (continuous representation of input before quantization) 156 | "vq/commitment_loss" : Tensor[1] 157 | Commitment loss to train encoder to predict vectors closer to codebook 158 | entries 159 | "vq/codebook_loss" : Tensor[1] 160 | Codebook loss to update the codebook 161 | """ 162 | z_q = 0 163 | residual = z 164 | commitment_loss = 0 165 | codebook_loss = 0 166 | 167 | codebook_indices = [] 168 | latents = [] 169 | 170 | if n_quantizers is None: 171 | n_quantizers = self.n_codebooks 172 | if self.training: 173 | n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 174 | dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) 175 | n_dropout = int(z.shape[0] * self.quantizer_dropout) 176 | n_quantizers[:n_dropout] = dropout[:n_dropout] 177 | n_quantizers = n_quantizers.to(z.device) 178 | 179 | for i, quantizer in enumerate(self.quantizers): 180 | if self.training is False and i >= n_quantizers: 181 | break 182 | 183 | z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( 184 | residual 185 | ) 186 | 187 | # Create mask to apply quantizer dropout 188 | mask = ( 189 | torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers 190 | ) 191 | z_q = z_q + z_q_i * mask[:, None, None] 192 | residual = residual - z_q_i 193 | 194 | # Sum losses 195 | commitment_loss += (commitment_loss_i * mask).mean() 196 | codebook_loss += (codebook_loss_i * mask).mean() 197 | 198 | codebook_indices.append(indices_i) 199 | latents.append(z_e_i) 200 | 201 | codes = torch.stack(codebook_indices, dim=1) 202 | latents = torch.cat(latents, dim=1) 203 | 204 | return z_q, codes, latents, commitment_loss, codebook_loss 205 | 206 | def from_codes(self, codes: torch.Tensor): 207 | """Given the quantized codes, reconstruct the continuous representation 208 | Parameters 209 | ---------- 210 | codes : Tensor[B x N x T] 211 | Quantized discrete representation of input 212 | Returns 213 | ------- 214 | Tensor[B x D x T] 215 | Quantized continuous representation of input 216 | """ 217 | z_q = 0.0 218 | z_p = [] 219 | n_codebooks = codes.shape[1] 220 | for i in range(n_codebooks): 221 | z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) 222 | z_p.append(z_p_i) 223 | 224 | z_q_i = self.quantizers[i].out_proj(z_p_i) 225 | z_q = z_q + z_q_i 226 | return z_q, torch.cat(z_p, dim=1), codes 227 | 228 | def 
from_latents(self, latents: torch.Tensor): 229 | """Given the unquantized latents, reconstruct the 230 | continuous representation after quantization. 231 | 232 | Parameters 233 | ---------- 234 | latents : Tensor[B x N x T] 235 | Continuous representation of input after projection 236 | 237 | Returns 238 | ------- 239 | Tensor[B x D x T] 240 | Quantized representation of full-projected space 241 | Tensor[B x D x T] 242 | Quantized representation of latent space 243 | """ 244 | z_q = 0 245 | z_p = [] 246 | codes = [] 247 | dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) 248 | 249 | n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ 250 | 0 251 | ] 252 | for i in range(n_codebooks): 253 | j, k = dims[i], dims[i + 1] 254 | z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) 255 | z_p.append(z_p_i) 256 | codes.append(codes_i) 257 | 258 | z_q_i = self.quantizers[i].out_proj(z_p_i) 259 | z_q = z_q + z_q_i 260 | 261 | return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) 262 | 263 | 264 | if __name__ == "__main__": 265 | rvq = ResidualVectorQuantize(quantizer_dropout=True) 266 | x = torch.randn(16, 512, 80) 267 | y = rvq(x) 268 | print(y["latents"].shape) 269 | -------------------------------------------------------------------------------- /asr/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import pytorch_lightning.core.module as pl 6 | from monotonic_align import mask_from_lens 7 | from monotonic_align.core import maximum_path_c 8 | import torch.nn.functional as F 9 | import jiwer 10 | from PIL import Image 11 | 12 | def maximum_path(neg_cent, mask): 13 | """ Cython optimized version. 14 | neg_cent: [b, t_t, t_s] 15 | mask: [b, t_t, t_s] 16 | """ 17 | device = neg_cent.device 18 | dtype = neg_cent.dtype 19 | neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32)) 20 | path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32)) 21 | 22 | t_t_max = np.ascontiguousarray(mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)) 23 | t_s_max = np.ascontiguousarray(mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)) 24 | maximum_path_c(path, neg_cent, t_t_max, t_s_max) 25 | return torch.from_numpy(path).to(device=device, dtype=dtype) 26 | 27 | 28 | def drop_duplicated(chars): 29 | ret_chars = [chars[0]] 30 | for prev, curr in zip(chars[:-1], chars[1:]): 31 | if prev != curr: 32 | ret_chars.append(curr) 33 | return ret_chars 34 | def calc_wer(target, pred, ignore_indexes=[0]): 35 | target_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(target))))) 36 | pred_chars = drop_duplicated(list(filter(lambda x: x not in ignore_indexes, map(str, list(pred))))) 37 | target_str = ' '.join(target_chars) 38 | pred_str = ' '.join(pred_chars) 39 | error = jiwer.wer(target_str, pred_str) 40 | return error 41 | 42 | class ASRTrainer(pl.LightningModule): 43 | def __init__(self, model, criterion, mono_start_epoch, lr): 44 | super().__init__() 45 | self.model = model 46 | self.criterion = criterion 47 | self.mono_start_epoch = mono_start_epoch 48 | self.lr=lr 49 | 50 | 51 | def load_checkpoint(self, checkpoint_path): 52 | """Load checkpoint. 53 | 54 | Args: 55 | checkpoint_path (str): Checkpoint path to be loaded. 
56 | """ 57 | state_dict = torch.load(checkpoint_path, map_location="cpu") 58 | self.load_state_dict(state_dict["state_dict"]) 59 | 60 | @staticmethod 61 | def get_gradient_norm(model): 62 | total_norm = 0 63 | for p in model.parameters(): 64 | param_norm = p.grad.data.norm(2) 65 | total_norm += param_norm.item() ** 2 66 | 67 | total_norm = np.sqrt(total_norm) 68 | return total_norm 69 | 70 | @staticmethod 71 | def length_to_mask(lengths): 72 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 73 | mask = torch.gt(mask + 1, lengths.unsqueeze(1)) 74 | return mask 75 | 76 | def calc_mono_loss(self, s2s_attn,input_lengths, mel_input_length, text_mask, mel_mask, n_down): 77 | s2s_attn = s2s_attn.transpose(-1, -2) 78 | s2s_attn = s2s_attn[..., 1:] 79 | s2s_attn = s2s_attn.transpose(-1, -2) 80 | 81 | with torch.no_grad(): 82 | attn_mask = (~mel_mask).unsqueeze(-1).expand(mel_mask.shape[0], mel_mask.shape[1], 83 | text_mask.shape[-1]).float().transpose(-1, -2) 84 | attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], 85 | text_mask.shape[1], 86 | mel_mask.shape[-1]).float() 87 | attn_mask = (attn_mask < 1) 88 | 89 | s2s_attn.masked_fill_(attn_mask, 0.0) 90 | 91 | with torch.no_grad(): 92 | mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length) 93 | s2s_attn_mono = maximum_path(s2s_attn, mask_ST) 94 | loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10 95 | 96 | return loss_mono, s2s_attn_mono 97 | 98 | def get_attention_mono(self, text_input, text_input_length, mel_input, mel_input_length): 99 | mel_input_length = mel_input_length // (2 ** self.model.n_down) 100 | future_mask = self.model.get_future_mask( 101 | mel_input.size(2) // (2 ** self.model.n_down), unmask_future_steps=0).to(self.device) 102 | mel_mask = self.model.length_to_mask(mel_input_length) 103 | text_mask = self.model.length_to_mask(text_input_length) 104 | ppgs, s2s_pred, s2s_attn = self.model( 105 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 106 | loss_mono, s2s_attn_mono = self.calc_mono_loss(s2s_attn, text_input_length, mel_input_length, text_mask, mel_mask, self.model.n_down) 107 | return s2s_attn_mono 108 | 109 | 110 | @staticmethod 111 | def get_image(arrs): 112 | pil_images = [] 113 | height = 0 114 | width = 0 115 | for arr in arrs: 116 | uint_arr = (((arr - arr.min()) / (arr.max() - arr.min())) * 255).astype(np.uint8) 117 | pil_image = Image.fromarray(uint_arr) 118 | pil_images.append(pil_image) 119 | height += uint_arr.shape[0] 120 | width = max(width, uint_arr.shape[1]) 121 | 122 | palette = Image.new('L', (width, height)) 123 | curr_heigth = 0 124 | for pil_image in pil_images: 125 | palette.paste(pil_image, (0, curr_heigth)) 126 | curr_heigth += pil_image.size[1] 127 | 128 | return palette 129 | 130 | def _load(self, states, model, force_load=True): 131 | model_states = model.state_dict() 132 | for key, val in states.items(): 133 | try: 134 | if key not in model_states: 135 | continue 136 | if isinstance(val, nn.Parameter): 137 | val = val.data 138 | 139 | if val.shape != model_states[key].shape: 140 | print("%s does not have same shape" % key) 141 | print(val.shape, model_states[key].shape) 142 | if not force_load: 143 | continue 144 | 145 | min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape)) 146 | slices = [slice(0, min_index) for min_index in min_shape] 147 | model_states[key][slices].copy_(val[slices]) 148 | else: 149 | model_states[key].copy_(val) 150 | except: 151 | 
print("not exist :%s" % key) 152 | 153 | def configure_optimizers(self): 154 | optimizer = optim.AdamW([*self.parameters()], lr=self.lr, betas=(0.9, 0.95), weight_decay=0.1) 155 | return optimizer 156 | 157 | def training_step(self, batch, batch_idx): 158 | text_input, text_input_length, mel_input, mel_input_length = batch 159 | 160 | mel_input_length = mel_input_length // (2 ** self.model.n_down) 161 | future_mask = self.model.get_future_mask( 162 | mel_input.size(2) // (2 ** self.model.n_down), unmask_future_steps=0).to(self.device) 163 | mel_mask = self.model.length_to_mask(mel_input_length) 164 | text_mask = self.model.length_to_mask(text_input_length) 165 | ppgs, s2s_pred, s2s_attn = self.model( 166 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 167 | loss_mono, s2s_attn_mono = self.calc_mono_loss(s2s_attn, text_input_length, mel_input_length, text_mask, 168 | mel_mask, self.model.n_down) 169 | loss_ctc = self.criterion['ctc'](ppgs.log_softmax(dim=2).transpose(0, 1), 170 | text_input, mel_input_length, text_input_length) 171 | 172 | loss_s2s = 0 173 | for _s2s_pred, _text_input, _text_length in zip(s2s_pred, text_input, text_input_length): 174 | loss_s2s += self.criterion['ce'](_s2s_pred[:_text_length], _text_input[:_text_length]) 175 | loss_s2s /= text_input.size(0) 176 | if self.current_epoch > self.mono_start_epoch: 177 | loss_ctc = loss_ctc * 0 178 | else: 179 | loss_mono = loss_mono * 0 180 | loss = loss_ctc + loss_s2s + loss_mono 181 | 182 | 183 | self.log("train/loss_ctc", loss_ctc, on_step=True, prog_bar=True) 184 | self.log("train/loss_s2s", loss_s2s, on_step=True, prog_bar=True) 185 | self.log("train/loss_mono", loss_mono, on_step=True, prog_bar=True) 186 | return loss 187 | 188 | def validation_step(self, batch, batch_idx): 189 | 190 | text_input, text_input_length, mel_input, mel_input_length = batch 191 | mel_input_length = mel_input_length // (2 ** self.model.n_down) 192 | future_mask = self.model.get_future_mask( 193 | mel_input.size(2) // (2 ** self.model.n_down), unmask_future_steps=0).to(self.device) 194 | mel_mask = self.model.length_to_mask(mel_input_length) 195 | text_mask = self.model.length_to_mask(text_input_length) 196 | ppgs, s2s_pred, s2s_attn = self.model( 197 | mel_input, src_key_padding_mask=mel_mask, text_input=text_input) 198 | loss_mono, s2s_attn_mono = self.calc_mono_loss(s2s_attn, text_input_length, mel_input_length, text_mask, 199 | mel_mask, self.model.n_down) 200 | loss_ctc = self.criterion['ctc'](ppgs.log_softmax(dim=2).transpose(0, 1), 201 | text_input, mel_input_length, text_input_length) 202 | loss_s2s = 0 203 | for _s2s_pred, _text_input, _text_length in zip(s2s_pred, text_input, text_input_length): 204 | loss_s2s += self.criterion['ce'](_s2s_pred[:_text_length], _text_input[:_text_length]) 205 | loss_s2s /= text_input.size(0) 206 | loss = loss_ctc + loss_s2s + loss_mono 207 | 208 | self.log("val/ctc", loss_ctc.item(), on_step=False, prog_bar=True) 209 | self.log("val/s2s", loss_s2s.item(), on_step=False, prog_bar=True) 210 | self.log("val/loss", loss.item(), on_step=False, prog_bar=True) 211 | self.log("val/mono", loss_mono.item(), on_step=False, prog_bar=True) 212 | 213 | _, amax_ppgs = torch.max(ppgs, dim=2) 214 | wers = [calc_wer(target[:text_length], 215 | pred[:mel_length], 216 | ignore_indexes=list(range(5))) \ 217 | for target, pred, text_length, mel_length in zip( 218 | text_input.cpu(), amax_ppgs.cpu(), text_input_length.cpu(), mel_input_length.cpu())] 219 | self.log("val/wers", np.mean(wers), on_step=False, 
prog_bar=True) 220 | 221 | _, amax_s2s = torch.max(s2s_pred, dim=2) 222 | acc = [torch.eq(target[:length], pred[:length]).float().mean().item() \ 223 | for target, pred, length in zip(text_input.cpu(), amax_s2s.cpu(), text_input_length.cpu())] 224 | 225 | self.log("val/acc", np.mean(acc), on_step=False, prog_bar=True) 226 | attn_img = self.get_image([s2s_attn[0].cpu().numpy()]) 227 | attn_mono_img = self.get_image([s2s_attn_mono[0].cpu().numpy()]) 228 | # self.logger.experiment.add_image("val/attn", attn_img, self.current_epoch) 229 | # self.logger.experiment.add_image("val/attn_mono", attn_mono_img, self.current_epoch) 230 | self.logger.log_image(key="attn", images=[attn_img, attn_mono_img], caption=["soft", "mono"]) 231 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import sys 4 | import argparse 5 | import logging 6 | import json 7 | import subprocess 8 | 9 | import librosa 10 | import numpy as np 11 | from scipy.io.wavfile import read 12 | import torch 13 | import logging 14 | logging.getLogger('numba').setLevel(logging.INFO) 15 | 16 | MATPLOTLIB_FLAG = False 17 | 18 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) 19 | logger = logging 20 | 21 | 22 | def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False): 23 | assert os.path.isfile(checkpoint_path) 24 | checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') 25 | iteration = checkpoint_dict['iteration'] 26 | learning_rate = checkpoint_dict['learning_rate'] 27 | if optimizer is not None and not skip_optimizer and checkpoint_dict['optimizer'] is not None: 28 | optimizer.load_state_dict(checkpoint_dict['optimizer']) 29 | saved_state_dict = checkpoint_dict['model'] 30 | copy_state_dict(model, saved_state_dict) 31 | print("load ") 32 | logger.info("Loaded checkpoint '{}' (iteration {})".format( 33 | checkpoint_path, iteration)) 34 | return model, optimizer, learning_rate, iteration 35 | 36 | 37 | def copy_state_dict(model, saved_state_dict): 38 | if hasattr(model, 'module'): 39 | state_dict = model.module.state_dict() 40 | else: 41 | state_dict = model.state_dict() 42 | new_state_dict = {} 43 | for k, v in state_dict.items(): 44 | try: 45 | # assert "quantizer" not in k 46 | # print("load", k) 47 | new_state_dict[k] = saved_state_dict[k] 48 | assert saved_state_dict[k].shape == v.shape, (saved_state_dict[k].shape, v.shape) 49 | except: 50 | print("error, %s is not in the checkpoint" % k) 51 | new_state_dict[k] = v 52 | if hasattr(model, 'module'): 53 | model.module.load_state_dict(new_state_dict) 54 | else: 55 | model.load_state_dict(new_state_dict) 56 | 57 | 58 | def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): 59 | logger.info("Saving model and optimizer state at iteration {} to {}".format( 60 | iteration, checkpoint_path)) 61 | if hasattr(model, 'module'): 62 | state_dict = model.module.state_dict() 63 | else: 64 | state_dict = model.state_dict() 65 | torch.save({'model': state_dict, 66 | 'iteration': iteration, 67 | 'optimizer': optimizer.state_dict(), 68 | 'learning_rate': learning_rate}, checkpoint_path) 69 | 70 | 71 | def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): 72 | for k, v in scalars.items(): 73 | writer.add_scalar(k, v, global_step) 74 | for k, v in histograms.items(): 75 | writer.add_histogram(k, v, global_step) 76 | for 
k, v in images.items(): 77 | writer.add_image(k, v, global_step, dataformats='HWC') 78 | for k, v in audios.items(): 79 | writer.add_audio(k, v, global_step, audio_sampling_rate) 80 | 81 | 82 | def latest_checkpoint_path(dir_path, regex="G_*.pth"): 83 | f_list = glob.glob(os.path.join(dir_path, regex)) 84 | f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) 85 | x = f_list[-1] 86 | print(x) 87 | return x 88 | 89 | 90 | def plot_spectrogram_to_numpy(spectrogram): 91 | global MATPLOTLIB_FLAG 92 | if not MATPLOTLIB_FLAG: 93 | import matplotlib 94 | matplotlib.use("Agg") 95 | MATPLOTLIB_FLAG = True 96 | mpl_logger = logging.getLogger('matplotlib') 97 | mpl_logger.setLevel(logging.WARNING) 98 | import matplotlib.pylab as plt 99 | import numpy as np 100 | 101 | fig, ax = plt.subplots(figsize=(10, 2)) 102 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", 103 | interpolation='none') 104 | plt.colorbar(im, ax=ax) 105 | plt.xlabel("Frames") 106 | plt.ylabel("Channels") 107 | plt.tight_layout() 108 | 109 | fig.canvas.draw() 110 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 111 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 112 | plt.close() 113 | return data 114 | 115 | 116 | def plot_alignment_to_numpy(alignment, info=None): 117 | global MATPLOTLIB_FLAG 118 | if not MATPLOTLIB_FLAG: 119 | import matplotlib 120 | matplotlib.use("Agg") 121 | MATPLOTLIB_FLAG = True 122 | mpl_logger = logging.getLogger('matplotlib') 123 | mpl_logger.setLevel(logging.WARNING) 124 | import matplotlib.pylab as plt 125 | import numpy as np 126 | 127 | fig, ax = plt.subplots(figsize=(6, 4)) 128 | im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', 129 | interpolation='none') 130 | fig.colorbar(im, ax=ax) 131 | xlabel = 'Decoder timestep' 132 | if info is not None: 133 | xlabel += '\n\n' + info 134 | plt.xlabel(xlabel) 135 | plt.ylabel('Encoder timestep') 136 | plt.tight_layout() 137 | 138 | fig.canvas.draw() 139 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 140 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 141 | plt.close() 142 | return data 143 | 144 | 145 | def load_wav_to_torch(full_path): 146 | data, sampling_rate = librosa.load(full_path, sr=None) 147 | return torch.FloatTensor(data), sampling_rate 148 | 149 | 150 | def load_filepaths_and_text(filename, split="|"): 151 | with open(filename, encoding='utf-8') as f: 152 | filepaths_and_text = [line.strip().split(split) for line in f] 153 | return filepaths_and_text 154 | 155 | 156 | def get_hparams(init=True, stage=1): 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument('-c', '--config', type=str, default="./configs/config.json", 159 | help='JSON file for configuration') 160 | parser.add_argument('-p', '--pretrain', type=str, required=False,default=None, 161 | help='pretrain dir') 162 | parser.add_argument('-rs', '--resume_step', type=int, required=False,default=None, 163 | help='resume step') 164 | 165 | args = parser.parse_args() 166 | 167 | config_path = args.config 168 | with open(config_path, "r") as f: 169 | data = f.read() 170 | config = json.loads(data) 171 | 172 | hparams = HParams(**config) 173 | hparams.pretrain = args.pretrain 174 | hparams.resume_step = args.resume_step 175 | if stage ==1: 176 | model_dir = hparams.s1_ckpt_dir 177 | else: 178 | model_dir = hparams.s2_ckpt_dir 179 | config_save_path = os.path.join(model_dir, "config.json") 180 | 181 | if not os.path.exists(model_dir): 182 | 
os.makedirs(model_dir) 183 | 184 | with open(config_save_path, "w") as f: 185 | f.write(data) 186 | return hparams 187 | 188 | 189 | 190 | def clean_checkpoints(path_to_models='logs/44k/', n_ckpts_to_keep=2, sort_by_time=True): 191 | """Freeing up space by deleting saved ckpts 192 | 193 | Arguments: 194 | path_to_models -- Path to the model directory 195 | n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth 196 | sort_by_time -- True -> chronologically delete ckpts 197 | False -> lexicographically delete ckpts 198 | """ 199 | import re 200 | ckpts_files = [f for f in os.listdir(path_to_models) if os.path.isfile(os.path.join(path_to_models, f))] 201 | name_key = (lambda _f: int(re.compile('._(\d+)\.pth').match(_f).group(1))) 202 | time_key = (lambda _f: os.path.getmtime(os.path.join(path_to_models, _f))) 203 | sort_key = time_key if sort_by_time else name_key 204 | x_sorted = lambda _x: sorted([f for f in ckpts_files if f.startswith(_x) and not f.endswith('_0.pth')], 205 | key=sort_key) 206 | to_del = [os.path.join(path_to_models, fn) for fn in 207 | (x_sorted('G')[:-n_ckpts_to_keep] + x_sorted('D')[:-n_ckpts_to_keep])] 208 | del_info = lambda fn: logger.info(f".. Free up space by deleting ckpt {fn}") 209 | del_routine = lambda x: [os.remove(x), del_info(x)] 210 | rs = [del_routine(fn) for fn in to_del] 211 | 212 | def get_hparams_from_dir(model_dir): 213 | config_save_path = os.path.join(model_dir, "config.json") 214 | with open(config_save_path, "r") as f: 215 | data = f.read() 216 | config = json.loads(data) 217 | 218 | hparams = HParams(**config) 219 | hparams.model_dir = model_dir 220 | return hparams 221 | 222 | 223 | def get_hparams_from_file(config_path): 224 | with open(config_path, "r") as f: 225 | data = f.read() 226 | config = json.loads(data) 227 | 228 | hparams = HParams(**config) 229 | return hparams 230 | 231 | def check_git_hash(model_dir): 232 | source_dir = os.path.dirname(os.path.realpath(__file__)) 233 | if not os.path.exists(os.path.join(source_dir, ".git")): 234 | logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format( 235 | source_dir 236 | )) 237 | return 238 | 239 | cur_hash = subprocess.getoutput("git rev-parse HEAD") 240 | 241 | path = os.path.join(model_dir, "githash") 242 | if os.path.exists(path): 243 | saved_hash = open(path).read() 244 | if saved_hash != cur_hash: 245 | logger.warn("git hash values are different. 
{}(saved) != {}(current)".format( 246 | saved_hash[:8], cur_hash[:8])) 247 | else: 248 | open(path, "w").write(cur_hash) 249 | 250 | 251 | def get_logger(model_dir, filename="train.log"): 252 | global logger 253 | logger = logging.getLogger(os.path.basename(model_dir)) 254 | logger.setLevel(logging.DEBUG) 255 | 256 | formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s") 257 | if not os.path.exists(model_dir): 258 | os.makedirs(model_dir) 259 | h = logging.FileHandler(os.path.join(model_dir, filename)) 260 | h.setLevel(logging.DEBUG) 261 | h.setFormatter(formatter) 262 | logger.addHandler(h) 263 | return logger 264 | 265 | 266 | class HParams(): 267 | def __init__(self, **kwargs): 268 | for k, v in kwargs.items(): 269 | if type(v) == dict: 270 | v = HParams(**v) 271 | self[k] = v 272 | 273 | def keys(self): 274 | return self.__dict__.keys() 275 | 276 | def items(self): 277 | return self.__dict__.items() 278 | 279 | def values(self): 280 | return self.__dict__.values() 281 | 282 | def __len__(self): 283 | return len(self.__dict__) 284 | 285 | def __getitem__(self, key): 286 | return getattr(self, key) 287 | 288 | def __setitem__(self, key, value): 289 | return setattr(self, key, value) 290 | 291 | def __contains__(self, key): 292 | return key in self.__dict__ 293 | 294 | def __repr__(self): 295 | return self.__dict__.__repr__() 296 | 297 | 298 | from matplotlib import pyplot as plt 299 | 300 | mpl_logger = logging.getLogger('matplotlib') 301 | mpl_logger.setLevel(logging.WARNING) 302 | def plot_alignment(data, titles=None, save_dir=None): 303 | fig, axes = plt.subplots(len(data), 1, figsize=[6,4],dpi=300) 304 | plt.subplots_adjust(top = 0.9, bottom = 0.1, right = 0.95, left = 0.05) 305 | if titles is None: 306 | titles = [None for i in range(len(data))] 307 | 308 | for i in range(len(data)): 309 | im = data[i] 310 | axes[i].imshow(im, origin='lower') 311 | # axes[i].set_xlabel('Audio') 312 | # axes[i].set_ylabel('Text') 313 | axes[i].set_ylim(0, im.shape[0]) 314 | axes[i].set_xlim(0, im.shape[1]) 315 | axes[i].set_title(titles[i], fontsize='medium') 316 | axes[i].tick_params(labelsize='x-small') 317 | axes[i].set_anchor('W') 318 | plt.tight_layout() 319 | 320 | fig.canvas.draw() 321 | data = save_figure_to_numpy(fig) 322 | if save_dir is not None: 323 | plt.savefig(save_dir) 324 | plt.close() 325 | return data 326 | 327 | def save_figure_to_numpy(fig): 328 | # save it to a numpy array. 
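    # The canvas RGB byte buffer is converted to an (H, W, 3) uint8 image below.
    # Note: np.fromstring on raw bytes is deprecated in recent NumPy releases;
    # np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) is the equivalent call.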
329 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') 330 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 331 | return data 332 | -------------------------------------------------------------------------------- /transformer/transformer.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/transformer.py 2 | import copy 3 | import numbers 4 | from functools import partial 5 | from typing import Any 6 | from typing import Callable 7 | from typing import List 8 | from typing import Optional 9 | from typing import Tuple 10 | from typing import Union 11 | 12 | import torch 13 | from .activation import MultiheadAttention 14 | from .scaling import BalancedDoubleSwish 15 | from torch import nn 16 | from torch import Tensor 17 | from torch.nn import functional as F 18 | 19 | _shape_t = Union[int, List[int], torch.Size] 20 | 21 | 22 | class LayerNorm(nn.Module): 23 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"] 24 | normalized_shape: Tuple[int, ...] 25 | eps: float 26 | elementwise_affine: bool 27 | 28 | def __init__( 29 | self, 30 | normalized_shape: _shape_t, 31 | eps: float=1e-5, 32 | elementwise_affine: bool=True, 33 | device=None, 34 | dtype=None, ) -> None: 35 | factory_kwargs = {"device": device, "dtype": dtype} 36 | super(LayerNorm, self).__init__() 37 | if isinstance(normalized_shape, numbers.Integral): 38 | # mypy error: incompatible types in assignment 39 | normalized_shape = (normalized_shape, ) # type: ignore[assignment] 40 | self.normalized_shape = tuple( 41 | normalized_shape) # type: ignore[arg-type] 42 | self.eps = eps 43 | self.elementwise_affine = elementwise_affine 44 | if self.elementwise_affine: 45 | self.weight = nn.Parameter( 46 | torch.empty(self.normalized_shape, **factory_kwargs)) 47 | self.bias = nn.Parameter( 48 | torch.empty(self.normalized_shape, **factory_kwargs)) 49 | else: 50 | self.register_parameter("weight", None) 51 | self.register_parameter("bias", None) 52 | 53 | self.reset_parameters() 54 | 55 | def reset_parameters(self) -> None: 56 | if self.elementwise_affine: 57 | nn.init.ones_(self.weight) 58 | nn.init.zeros_(self.bias) 59 | 60 | def forward(self, input: Tensor, embedding: Any=None) -> Tensor: 61 | if isinstance(input, tuple): 62 | input, embedding = input 63 | return (F.layer_norm( 64 | input, 65 | self.normalized_shape, 66 | self.weight, 67 | self.bias, 68 | self.eps, ), embedding, ) 69 | 70 | assert embedding is None 71 | return F.layer_norm(input, self.normalized_shape, self.weight, 72 | self.bias, self.eps) 73 | 74 | def extra_repr(self) -> str: 75 | return ( 76 | "{normalized_shape}, eps={eps}, " 77 | "elementwise_affine={elementwise_affine}".format(**self.__dict__)) 78 | 79 | 80 | class IdentityNorm(nn.Module): 81 | def __init__( 82 | self, 83 | d_model: int, 84 | eps: float=1e-5, 85 | device=None, 86 | dtype=None, ) -> None: 87 | super(IdentityNorm, self).__init__() 88 | 89 | def forward(self, input: Tensor, embedding: Any=None) -> Tensor: 90 | if isinstance(input, tuple): 91 | return input 92 | 93 | assert embedding is None 94 | return input 95 | 96 | 97 | class TransformerEncoder(nn.Module): 98 | r"""TransformerEncoder is a stack of N encoder layers. Users can build the 99 | BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. 100 | 101 | Args: 102 | encoder_layer: an instance of the TransformerEncoderLayer() class (required). 
103 | num_layers: the number of sub-encoder-layers in the encoder (required). 104 | norm: the layer normalization component (optional). 105 | enable_nested_tensor: if True, input will automatically convert to nested tensor 106 | (and convert back on output). This will improve the overall performance of 107 | TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 108 | 109 | Examples:: 110 | >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) 111 | >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) 112 | >>> src = torch.rand(10, 32, 512) 113 | >>> out = transformer_encoder(src) 114 | """ 115 | __constants__ = ["norm"] 116 | 117 | def __init__(self, encoder_layer, num_layers, norm=None): 118 | super(TransformerEncoder, self).__init__() 119 | self.layers = _get_clones(encoder_layer, num_layers) 120 | self.num_layers = num_layers 121 | self.norm = norm 122 | 123 | def forward( 124 | self, 125 | src: Tensor, 126 | mask: Optional[Tensor]=None, 127 | src_key_padding_mask: Optional[Tensor]=None, 128 | return_layer_states: bool=False, ) -> Tensor: 129 | r"""Pass the input through the encoder layers in turn. 130 | 131 | Args: 132 | src: the sequence to the encoder (required). 133 | mask: the mask for the src sequence (optional). 134 | src_key_padding_mask: the mask for the src keys per batch (optional). 135 | return_layer_states: return layers' state (optional). 136 | 137 | Shape: 138 | see the docs in Transformer class. 139 | """ 140 | if return_layer_states: 141 | layer_states = [] # layers' output 142 | output = src 143 | for mod in self.layers: 144 | output = mod( 145 | output, 146 | src_mask=mask, 147 | src_key_padding_mask=src_key_padding_mask, ) 148 | layer_states.append(output[0]) 149 | 150 | if self.norm is not None: 151 | output = self.norm(output) 152 | 153 | return layer_states, output 154 | 155 | output = src 156 | for mod in self.layers: 157 | output = mod(output, 158 | src_mask=mask, 159 | src_key_padding_mask=src_key_padding_mask) 160 | 161 | if self.norm is not None: 162 | output = self.norm(output) 163 | 164 | return output 165 | 166 | 167 | class TransformerEncoderLayer(nn.Module): 168 | __constants__ = ["batch_first", "norm_first"] 169 | 170 | def __init__( 171 | self, 172 | d_model: int, 173 | nhead: int, 174 | dim_feedforward: int=2048, 175 | dropout: float=0.1, 176 | activation: Union[str, Callable[[Tensor], Tensor]]=F.relu, 177 | batch_first: bool=False, 178 | norm_first: bool=False, 179 | device=None, 180 | dtype=None, 181 | linear1_self_attention_cls: nn.Module=nn.Linear, 182 | linear2_self_attention_cls: nn.Module=nn.Linear, 183 | linear1_feedforward_cls: nn.Module=nn.Linear, 184 | linear2_feedforward_cls: nn.Module=nn.Linear, 185 | layer_norm_cls: nn.Module=LayerNorm, 186 | layer_norm_eps: float=1e-5, 187 | adaptive_layer_norm=False, ) -> None: 188 | factory_kwargs = {"device": device, "dtype": dtype} 189 | super(TransformerEncoderLayer, self).__init__() 190 | self.self_attn = MultiheadAttention( 191 | d_model, 192 | nhead, 193 | dropout=dropout, 194 | batch_first=batch_first, 195 | linear1_cls=linear1_self_attention_cls, 196 | linear2_cls=linear2_self_attention_cls, 197 | **factory_kwargs, ) 198 | 199 | # Implementation of Feedforward model 200 | self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, 201 | **factory_kwargs) 202 | self.dropout = nn.Dropout(dropout) 203 | self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, 204 | **factory_kwargs) 205 | 206 | self.norm_first = norm_first 207 | 
self.dropout1 = nn.Dropout(dropout) 208 | self.dropout2 = nn.Dropout(dropout) 209 | 210 | # Legacy string support for activation function. 211 | if isinstance(activation, str): 212 | activation = _get_activation_fn(activation) 213 | elif isinstance(activation, partial): 214 | activation = activation(d_model) 215 | elif activation == BalancedDoubleSwish: 216 | activation = BalancedDoubleSwish(d_model) 217 | 218 | # # We can't test self.activation in forward() in TorchScript, 219 | # # so stash some information about it instead. 220 | # if activation is F.relu or isinstance(activation, torch.nn.ReLU): 221 | # self.activation_relu_or_gelu = 1 222 | # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): 223 | # self.activation_relu_or_gelu = 2 224 | # else: 225 | # self.activation_relu_or_gelu = 0 226 | self.activation = activation 227 | 228 | norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) 229 | if layer_norm_cls == IdentityNorm: 230 | norm2 = BalancedBasicNorm( 231 | d_model, eps=layer_norm_eps, **factory_kwargs) 232 | else: 233 | norm2 = layer_norm_cls( 234 | d_model, eps=layer_norm_eps, **factory_kwargs) 235 | 236 | if adaptive_layer_norm: 237 | self.norm1 = AdaptiveLayerNorm(d_model, norm1) 238 | self.norm2 = AdaptiveLayerNorm(d_model, norm2) 239 | else: 240 | self.norm1 = norm1 241 | self.norm2 = norm2 242 | 243 | def __setstate__(self, state): 244 | super(TransformerEncoderLayer, self).__setstate__(state) 245 | if not hasattr(self, "activation"): 246 | self.activation = F.relu 247 | 248 | def forward( 249 | self, 250 | src: Tensor, 251 | src_mask: Optional[Tensor]=None, 252 | src_key_padding_mask: Optional[Tensor]=None, ) -> Tensor: 253 | r"""Pass the input through the encoder layer. 254 | 255 | Args: 256 | src: the sequence to the encoder layer (required). 257 | src_mask: the mask for the src sequence (optional). 258 | src_key_padding_mask: the mask for the src keys per batch (optional). 259 | 260 | Shape: 261 | see the docs in Transformer class. 
262 | """ 263 | x, stage_embedding = src, None 264 | is_src_tuple = False 265 | if isinstance(src, tuple): 266 | x, stage_embedding = src 267 | is_src_tuple = True 268 | 269 | if src_key_padding_mask is not None: 270 | _skpm_dtype = src_key_padding_mask.dtype 271 | if _skpm_dtype != torch.bool and not torch.is_floating_point( 272 | src_key_padding_mask): 273 | raise AssertionError( 274 | "only bool and floating types of key_padding_mask are supported" 275 | ) 276 | 277 | if self.norm_first: 278 | x = x + self._sa_block( 279 | self.norm1(x, stage_embedding), 280 | src_mask, 281 | src_key_padding_mask, ) 282 | x = x + self._ff_block(self.norm2(x, stage_embedding)) 283 | else: 284 | x = self.norm1( 285 | x + self._sa_block(x, src_mask, src_key_padding_mask), 286 | stage_embedding, ) 287 | x = self.norm2(x + self._ff_block(x), stage_embedding) 288 | 289 | if is_src_tuple: 290 | return (x, stage_embedding) 291 | return x 292 | 293 | # self-attention block 294 | def _sa_block( 295 | self, 296 | x: Tensor, 297 | attn_mask: Optional[Tensor], 298 | key_padding_mask: Optional[Tensor], ) -> Tensor: 299 | x = self.self_attn( 300 | x, 301 | x, 302 | x, 303 | attn_mask=attn_mask, 304 | key_padding_mask=key_padding_mask, 305 | need_weights=False, )[0] 306 | return self.dropout1(x) 307 | 308 | # feed forward block 309 | def _ff_block(self, x: Tensor) -> Tensor: 310 | x = self.linear2(self.dropout(self.activation(self.linear1(x)))) 311 | return self.dropout2(x) 312 | 313 | 314 | class AdaptiveLayerNorm(nn.Module): 315 | r"""Adaptive Layer Normalization""" 316 | 317 | def __init__(self, d_model, norm) -> None: 318 | super(AdaptiveLayerNorm, self).__init__() 319 | self.project_layer = nn.Linear(d_model, 2 * d_model) 320 | self.norm = norm 321 | self.d_model = d_model 322 | self.eps = self.norm.eps 323 | 324 | def forward(self, input: Tensor, embedding: Tensor=None) -> Tensor: 325 | if isinstance(input, tuple): 326 | input, embedding = input 327 | weight, bias = torch.split( 328 | self.project_layer(embedding), 329 | split_size_or_sections=self.d_model, 330 | dim=-1, ) 331 | return (weight * self.norm(input) + bias, embedding) 332 | 333 | weight, bias = torch.split( 334 | self.project_layer(embedding), 335 | split_size_or_sections=self.d_model, 336 | dim=-1, ) 337 | return weight * self.norm(input) + bias 338 | 339 | def _get_clones(module, N): 340 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 341 | -------------------------------------------------------------------------------- /transformer/scaling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) 2 | # 3 | # See ../../../../LICENSE for clarification regarding multiple authors 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | import logging 17 | import math 18 | import random 19 | from typing import Optional 20 | from typing import Tuple 21 | from typing import Union 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch import Tensor 26 | 27 | 28 | class DoubleSwishFunction(torch.autograd.Function): 29 | """ 30 | double_swish(x) = x * torch.sigmoid(x-1) 31 | This is a definition, originally motivated by its close numerical 32 | similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). 33 | 34 | Memory-efficient derivative computation: 35 | double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) 36 | double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). 37 | Now, s'(x) = s(x) * (1-s(x)). 38 | double_swish'(x) = x * s'(x) + s(x). 39 | = x * s(x) * (1-s(x)) + s(x). 40 | = double_swish(x) * (1-s(x)) + s(x) 41 | ... so we just need to remember s(x) but not x itself. 42 | """ 43 | 44 | @staticmethod 45 | def forward(ctx, x: Tensor) -> Tensor: 46 | requires_grad = x.requires_grad 47 | x_dtype = x.dtype 48 | if x.dtype == torch.float16: 49 | x = x.to(torch.float32) 50 | 51 | s = torch.sigmoid(x - 1.0) 52 | y = x * s 53 | 54 | if requires_grad: 55 | deriv = y * (1 - s) + s 56 | # notes on derivative of x * sigmoid(x - 1): 57 | # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 58 | # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund 59 | # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. 60 | # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which 61 | # floors), should be expectation-preserving. 62 | floor = -0.043637 63 | ceil = 1.2 64 | d_scaled = (deriv - floor) * (255.0 / (ceil - floor) 65 | ) + torch.rand_like(deriv) 66 | if __name__ == "__main__": 67 | # for self-testing only. 68 | assert d_scaled.min() >= 0.0 69 | assert d_scaled.max() < 256.0 70 | d_int = d_scaled.to(torch.uint8) 71 | ctx.save_for_backward(d_int) 72 | if x.dtype == torch.float16 or torch.is_autocast_enabled(): 73 | y = y.to(torch.float16) 74 | return y 75 | 76 | @staticmethod 77 | def backward(ctx, y_grad: Tensor) -> Tensor: 78 | (d, ) = ctx.saved_tensors 79 | # the same constants as used in forward pass. 80 | floor = -0.043637 81 | ceil = 1.2 82 | d = d * ((ceil - floor) / 255.0) + floor 83 | return y_grad * d 84 | 85 | 86 | class DoubleSwish(torch.nn.Module): 87 | def forward(self, x: Tensor) -> Tensor: 88 | """Return double-swish activation function which is an approximation to Swish(Swish(x)), 89 | that we approximate closely with x * sigmoid(x-1). 
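        Under torch.jit scripting or tracing this falls back to the closed-form
        x * sigmoid(x - 1); otherwise the custom autograd function above is used.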
90 | """ 91 | if torch.jit.is_scripting() or torch.jit.is_tracing(): 92 | return x * torch.sigmoid(x - 1.0) 93 | return DoubleSwishFunction.apply(x) 94 | 95 | 96 | class ActivationBalancerFunction(torch.autograd.Function): 97 | @staticmethod 98 | def forward( 99 | ctx, 100 | x: Tensor, 101 | scale_factor: Tensor, 102 | sign_factor: Optional[Tensor], 103 | channel_dim: int, ) -> Tensor: 104 | if channel_dim < 0: 105 | channel_dim += x.ndim 106 | ctx.channel_dim = channel_dim 107 | xgt0 = x > 0 108 | if sign_factor is None: 109 | ctx.save_for_backward(xgt0, scale_factor) 110 | else: 111 | ctx.save_for_backward(xgt0, scale_factor, sign_factor) 112 | return x 113 | 114 | @staticmethod 115 | def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: 116 | if len(ctx.saved_tensors) == 3: 117 | xgt0, scale_factor, sign_factor = ctx.saved_tensors 118 | for _ in range(ctx.channel_dim, x_grad.ndim - 1): 119 | scale_factor = scale_factor.unsqueeze(-1) 120 | sign_factor = sign_factor.unsqueeze(-1) 121 | factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) 122 | else: 123 | xgt0, scale_factor = ctx.saved_tensors 124 | for _ in range(ctx.channel_dim, x_grad.ndim - 1): 125 | scale_factor = scale_factor.unsqueeze(-1) 126 | factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) 127 | neg_delta_grad = x_grad.abs() * factor 128 | return (x_grad - neg_delta_grad, None, None, None, ) 129 | 130 | 131 | def _compute_scale_factor( 132 | x: Tensor, 133 | channel_dim: int, 134 | min_abs: float, 135 | max_abs: float, 136 | gain_factor: float, 137 | max_factor: float, ) -> Tensor: 138 | if channel_dim < 0: 139 | channel_dim += x.ndim 140 | sum_dims = [d for d in range(x.ndim) if d != channel_dim] 141 | x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) 142 | 143 | if min_abs == 0.0: 144 | below_threshold = 0.0 145 | else: 146 | # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if 147 | # x_abs)_mean , min_abs. 148 | below_threshold = ( 149 | (min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( 150 | min=0, max=max_factor) 151 | 152 | above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( 153 | min=0, max=max_factor) 154 | 155 | return below_threshold - above_threshold 156 | 157 | 158 | def _compute_sign_factor( 159 | x: Tensor, 160 | channel_dim: int, 161 | min_positive: float, 162 | max_positive: float, 163 | gain_factor: float, 164 | max_factor: float, ) -> Tensor: 165 | if channel_dim < 0: 166 | channel_dim += x.ndim 167 | sum_dims = [d for d in range(x.ndim) if d != channel_dim] 168 | proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) 169 | if min_positive == 0.0: 170 | factor1 = 0.0 171 | else: 172 | # 0 if proportion_positive >= min_positive, else can be 173 | # as large as max_factor. 174 | factor1 = ((min_positive - proportion_positive) * 175 | (gain_factor / min_positive)).clamp_( 176 | min=0, max=max_factor) 177 | 178 | if max_positive == 1.0: 179 | factor2 = 0.0 180 | else: 181 | # 0 if self.proportion_positive <= max_positive, else can be 182 | # as large as -max_factor. 
183 | factor2 = ((proportion_positive - max_positive) * 184 | (gain_factor / (1.0 - max_positive))).clamp_( 185 | min=0, max=max_factor) 186 | sign_factor = factor1 - factor2 187 | # require min_positive != 0 or max_positive != 1: 188 | assert not isinstance(sign_factor, float) 189 | return sign_factor 190 | 191 | 192 | class ActivationBalancer(torch.nn.Module): 193 | """ 194 | Modifies the backpropped derivatives of a function to try to encourage, for 195 | each channel, that it is positive at least a proportion `threshold` of the 196 | time. It does this by multiplying negative derivative values by up to 197 | (1+max_factor), and positive derivative values by up to (1-max_factor), 198 | interpolated from 1 at the threshold to those extremal values when none 199 | of the inputs are positive. 200 | 201 | Args: 202 | num_channels: the number of channels 203 | channel_dim: the dimension/axis corresponding to the channel, e.g. 204 | -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. 205 | min_positive: the minimum, per channel, of the proportion of the time 206 | that (x > 0), below which we start to modify the derivatives. 207 | max_positive: the maximum, per channel, of the proportion of the time 208 | that (x > 0), above which we start to modify the derivatives. 209 | max_factor: the maximum factor by which we modify the derivatives for 210 | either the sign constraint or the magnitude constraint; 211 | e.g. with max_factor=0.02, the the derivatives would be multiplied by 212 | values in the range [0.98..1.02]. 213 | sign_gain_factor: determines the 'gain' with which we increase the 214 | change in gradient once the constraints on min_positive and max_positive 215 | are violated. 216 | scale_gain_factor: determines the 'gain' with which we increase the 217 | change in gradient once the constraints on min_abs and max_abs 218 | are violated. 219 | min_abs: the minimum average-absolute-value difference from the mean 220 | value per channel, which we allow, before we start to modify 221 | the derivatives to prevent this. 222 | max_abs: the maximum average-absolute-value difference from the mean 223 | value per channel, which we allow, before we start to modify 224 | the derivatives to prevent this. 225 | min_prob: determines the minimum probability with which we modify the 226 | gradients for the {min,max}_positive and {min,max}_abs constraints, 227 | on each forward(). This is done randomly to prevent all layers 228 | from doing it at the same time. Early in training we may use 229 | higher probabilities than this; it will decay to this value. 230 | """ 231 | 232 | def __init__( 233 | self, 234 | num_channels: int, 235 | channel_dim: int, 236 | min_positive: float=0.05, 237 | max_positive: float=0.95, 238 | max_factor: float=0.04, 239 | sign_gain_factor: float=0.01, 240 | scale_gain_factor: float=0.02, 241 | min_abs: float=0.2, 242 | max_abs: float=100.0, 243 | min_prob: float=0.1, ): 244 | super(ActivationBalancer, self).__init__() 245 | self.num_channels = num_channels 246 | self.channel_dim = channel_dim 247 | self.min_positive = min_positive 248 | self.max_positive = max_positive 249 | self.max_factor = max_factor 250 | self.min_abs = min_abs 251 | self.max_abs = max_abs 252 | self.min_prob = min_prob 253 | self.sign_gain_factor = sign_gain_factor 254 | self.scale_gain_factor = scale_gain_factor 255 | 256 | # count measures how many times the forward() function has been called. 
257 | # We occasionally sync this to a tensor called `count`, that exists to 258 | # make sure it is synced to disk when we load and save the model. 259 | self.cpu_count = 0 260 | self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) 261 | 262 | def forward(self, x: Tensor) -> Tensor: 263 | if (torch.jit.is_scripting() or not x.requires_grad or 264 | torch.jit.is_tracing()): 265 | return _no_op(x) 266 | 267 | count = self.cpu_count 268 | self.cpu_count += 1 269 | 270 | if random.random() < 0.01: 271 | # Occasionally sync self.cpu_count with self.count. 272 | # count affects the decay of 'prob'. don't do this on every iter, 273 | # because syncing with the GPU is slow. 274 | self.cpu_count = max(self.cpu_count, self.count.item()) 275 | self.count.fill_(self.cpu_count) 276 | 277 | # the prob of doing some work exponentially decreases from 0.5 till it hits 278 | # a floor at min_prob (==0.1, by default) 279 | prob = max(self.min_prob, 0.5**(1 + (count / 4000.0))) 280 | 281 | if random.random() < prob: 282 | sign_gain_factor = 0.5 283 | if self.min_positive != 0.0 or self.max_positive != 1.0: 284 | sign_factor = _compute_sign_factor( 285 | x, 286 | self.channel_dim, 287 | self.min_positive, 288 | self.max_positive, 289 | gain_factor=self.sign_gain_factor / prob, 290 | max_factor=self.max_factor, ) 291 | else: 292 | sign_factor = None 293 | 294 | scale_factor = _compute_scale_factor( 295 | x.detach(), 296 | self.channel_dim, 297 | min_abs=self.min_abs, 298 | max_abs=self.max_abs, 299 | gain_factor=self.scale_gain_factor / prob, 300 | max_factor=self.max_factor, ) 301 | return ActivationBalancerFunction.apply( 302 | x, 303 | scale_factor, 304 | sign_factor, 305 | self.channel_dim, ) 306 | else: 307 | return _no_op(x) 308 | 309 | 310 | def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, 311 | min_prob=0.25) -> nn.Sequential: 312 | """ 313 | ActivationBalancer -> DoubleSwish 314 | """ 315 | balancer = ActivationBalancer( 316 | d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob) 317 | return nn.Sequential( 318 | balancer, 319 | DoubleSwish(), ) 320 | -------------------------------------------------------------------------------- /module/data_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import random 4 | import numpy as np 5 | import torch 6 | import torch.utils.data 7 | from tqdm import tqdm 8 | 9 | from module import commons 10 | from module.mel_processing import spectrogram_torch 11 | from text import cleaned_text_to_sequence 12 | from text.symbols import v 13 | from utils import load_wav_to_torch, load_filepaths_and_text 14 | import torch.nn.functional as F 15 | 16 | f0_bin = 64 17 | f0_max = 1100.0 18 | f0_min = 50.0 19 | f0_mel_min = 1127 * np.log(1 + f0_min / 700) 20 | f0_mel_max = 1127 * np.log(1 + f0_max / 700) 21 | 22 | 23 | def f0_to_coarse(f0): 24 | f0_mel = 1127 * (1 + f0 / 700).log() 25 | a = (f0_bin - 2) / (f0_mel_max - f0_mel_min) 26 | b = f0_mel_min * a - 1. 
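    # a and b define a linear map of voiced frames so that f0_mel_min -> bin 1 and
    # f0_mel_max -> bin f0_bin - 1; unvoiced frames (f0 == 0, hence f0_mel == 0) end up
    # clamped to bin 1 by the steps below.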
27 | f0_mel = torch.where(f0_mel > 0, f0_mel * a - b, f0_mel) 28 | # torch.clip_(f0_mel, min=1., max=float(f0_bin - 1)) 29 | f0_coarse = torch.round(f0_mel).long() 30 | f0_coarse = f0_coarse * (f0_coarse > 0) 31 | f0_coarse = f0_coarse + ((f0_coarse < 1) * 1) 32 | f0_coarse = f0_coarse * (f0_coarse < f0_bin) 33 | f0_coarse = f0_coarse + ((f0_coarse >= f0_bin) * (f0_bin - 1)) 34 | return f0_coarse 35 | 36 | 37 | """Multi speaker version""" 38 | 39 | 40 | class TextAudioSpeakerLoader(torch.utils.data.Dataset): 41 | """ 42 | 1) loads audio, speaker_id, text pairs 43 | 2) normalizes text and converts them to sequences of integers 44 | 3) computes spectrograms from audio files. 45 | """ 46 | 47 | def __init__(self, audiopaths_sid_text, hparams, get_path=False, meta=None, val=False, 48 | phoneme_path='dump/phoneme.npy'): 49 | self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text) 50 | self.max_wav_value = hparams.max_wav_value 51 | self.sampling_rate = hparams.sampling_rate 52 | self.filter_length = hparams.filter_length 53 | self.hop_length = hparams.hop_length 54 | self.win_length = hparams.win_length 55 | self.sampling_rate = hparams.sampling_rate 56 | self.val = val 57 | 58 | self.get_path = get_path 59 | self.meta = meta 60 | self.phoneme_data = np.load(phoneme_path, allow_pickle=True).item() 61 | 62 | random.seed(1234) 63 | random.shuffle(self.audiopaths_sid_text) 64 | self._filter() 65 | 66 | def _filter(self): 67 | """ 68 | Filter text & store spec lengths 69 | """ 70 | # Store spectrogram lengths for Bucketing 71 | # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) 72 | # spec_length = wav_length // hop_length 73 | print('only chinese and skip other languages!!!!!!!!') 74 | print('only chinese and skip other languages!!!!!!!!') 75 | print('only chinese and skip other languages!!!!!!!!') 76 | print('only chinese and skip other languages!!!!!!!!') 77 | 78 | print("phoneme_data_len:", len(self.phoneme_data.keys())) 79 | print("wav_data_len:", len(self.audiopaths_sid_text)) 80 | 81 | audiopaths_sid_text_new = [] 82 | lengths = [] 83 | skipped = 0 84 | for item in tqdm(self.audiopaths_sid_text): 85 | audiopath = item[0] 86 | if 'zh' not in audiopath: 87 | skipped += 1 88 | continue 89 | try: 90 | phoneme = self.phoneme_data[audiopath] 91 | phoneme = phoneme.split(' ') 92 | phoneme_ids = cleaned_text_to_sequence(phoneme) 93 | except Exception: 94 | skipped += 1 95 | continue 96 | 97 | bert_path = audiopath.replace('.wav', '.bert.pt').replace('.mp3', '.bert.pt') 98 | if not os.path.exists(bert_path): 99 | skipped += 1 100 | continue 101 | duration_path = audiopath.replace('.wav', '.dur.pt').replace('.mp3', '.dur.pt') 102 | if not os.path.exists(duration_path): 103 | skipped += 1 104 | continue 105 | sslpath = audiopath.replace('.wav', '.ssl.pt').replace('.mp3', '.ssl.pt') 106 | if os.path.exists(audiopath) and os.path.exists(sslpath) and ( 107 | os.path.getsize(audiopath) / self.sampling_rate / 2 > 0.6 or self.val): 108 | audiopaths_sid_text_new.append([audiopath, sslpath, bert_path, phoneme_ids, duration_path]) 109 | lengths.append(os.path.getsize(audiopath) // (2 * self.hop_length)) 110 | else: 111 | skipped += 1 112 | print("skipped: ", skipped, ", total: ", len(self.audiopaths_sid_text)) 113 | self.audiopaths_sid_text = audiopaths_sid_text_new 114 | self.lengths = lengths 115 | 116 | def get_audio_text_speaker_pair(self, audiopath_sid_text): 117 | # separate filename, speaker_id and text 118 | audiopath, sslpath, bert_path, phoneme_ids, 
duration_path = audiopath_sid_text 119 | phoneme_ids = commons.intersperse(phoneme_ids, 0) 120 | 121 | text = torch.LongTensor(phoneme_ids) 122 | try: 123 | spec, wav = self.get_audio(audiopath) 124 | except: 125 | spec = torch.zeros(1025, 100) 126 | wav = torch.zeros(1, 100 * self.hop_length) 127 | print("load audio error!!!!!!", audiopath) 128 | 129 | duration = torch.load(duration_path) 130 | assert duration.shape[0] == len(phoneme_ids), (duration.shape, phoneme_ids.shape) 131 | total_len = duration.sum().item() 132 | 133 | assert abs(total_len - spec.shape[-1]) < 3, (total_len, spec.shape[-1], audiopath) 134 | if spec.shape[-1] < total_len: 135 | spec = F.pad(spec, (0, total_len - spec.shape[-1])) 136 | wav = F.pad(wav, (0, total_len * self.hop_length - wav.shape[-1])) 137 | elif spec.shape[-1] > total_len: 138 | spec = spec[:, :total_len] 139 | wav = wav[:, :total_len * self.hop_length] 140 | 141 | ssl = torch.load(sslpath) 142 | ssl = F.interpolate(ssl, size=spec.shape[-1], mode="nearest") 143 | 144 | bert_feature = torch.load(bert_path) 145 | bert_feature = F.interpolate(bert_feature.unsqueeze(0), scale_factor=2, mode='nearest') 146 | bert_feature = F.pad(bert_feature, (0, 1), value=0).squeeze(0) 147 | 148 | assert bert_feature.shape[-1] == len(phoneme_ids) 149 | 150 | spk_emb = np.load(audiopath.replace('.wav', '.spk.npy').replace('.mp3', '.spk.npy')) 151 | spk_emb = torch.FloatTensor(spk_emb) 152 | 153 | return (ssl, spec, wav, text, bert_feature, spk_emb, duration) 154 | 155 | def get_audio(self, filename): 156 | audio, sampling_rate = load_wav_to_torch(filename) 157 | if sampling_rate != self.sampling_rate: 158 | raise ValueError("{} SR doesn't match target {} SR".format( 159 | sampling_rate, self.sampling_rate)) 160 | audio_norm = audio 161 | audio_norm = audio_norm.unsqueeze(0) 162 | spec = spectrogram_torch(audio_norm, self.filter_length, 163 | self.sampling_rate, self.hop_length, self.win_length, 164 | center=False) 165 | spec = torch.squeeze(spec, 0) 166 | return spec, audio_norm 167 | 168 | def get_sid(self, sid): 169 | sid = torch.LongTensor([int(sid)]) 170 | return sid 171 | 172 | def __getitem__(self, index): 173 | return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index]) 174 | 175 | def __len__(self): 176 | return len(self.audiopaths_sid_text) 177 | 178 | 179 | class TextAudioSpeakerCollate(): 180 | """ Zero-pads model inputs and targets 181 | """ 182 | 183 | def __init__(self, return_ids=False): 184 | self.return_ids = return_ids 185 | 186 | def __call__(self, batch): 187 | """Collate's training batch from normalized text, audio and speaker identities 188 | PARAMS 189 | ------ 190 | batch: [text_normalized, spec_normalized, wav_normalized, sid] 191 | """ 192 | # Right zero-pad all one-hot text sequences to max input length 193 | _, ids_sorted_decreasing = torch.sort( 194 | torch.LongTensor([x[1].size(1) for x in batch]), 195 | dim=0, descending=True) 196 | 197 | max_spec_len = max([x[1].size(1) for x in batch]) 198 | max_wav_len = max([x[2].size(1) for x in batch]) 199 | max_text_len = max([x[3].size(0) for x in batch]) 200 | 201 | spec_lengths = torch.LongTensor(len(batch)) 202 | wav_lengths = torch.LongTensor(len(batch)) 203 | text_lengths = torch.LongTensor(len(batch)) 204 | 205 | spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) 206 | wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) 207 | ssl_padded = torch.FloatTensor(len(batch), 768, max_spec_len) 208 | text_padded = torch.LongTensor(len(batch), max_text_len) 
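        # note: every padded buffer except duration_padded is zero_()-ed below, so positions
        # of duration_padded beyond each sample's text length hold uninitialized values and
        # should not be read past text_lengths downstream.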
209 | bert_padded = torch.FloatTensor(len(batch), 1024, max_text_len) 210 | spk_emb_padded = torch.FloatTensor(len(batch), 256) 211 | duration_padded = torch.LongTensor(len(batch), max_text_len) 212 | 213 | spec_padded.zero_() 214 | wav_padded.zero_() 215 | ssl_padded.zero_() 216 | text_padded.zero_() 217 | bert_padded.zero_() 218 | spk_emb_padded.zero_() 219 | for i in range(len(ids_sorted_decreasing)): 220 | row = batch[ids_sorted_decreasing[i]] 221 | 222 | ssl = row[0] 223 | ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :] 224 | 225 | spec = row[1] 226 | spec_padded[i, :, :spec.size(1)] = spec 227 | spec_lengths[i] = spec.size(1) 228 | 229 | wav = row[2] 230 | wav_padded[i, :, :wav.size(1)] = wav 231 | wav_lengths[i] = wav.size(1) 232 | 233 | text = row[3] 234 | text_padded[i, :text.size(0)] = text 235 | text_lengths[i] = text.size(0) 236 | 237 | bert = row[4] 238 | bert_padded[i, :, :bert.size(-1)] = bert 239 | 240 | spk_emb = row[5] 241 | spk_emb_padded[i] = spk_emb 242 | 243 | duration = row[6] 244 | duration_padded[i, :duration.size(0)] = duration 245 | 246 | return ssl_padded, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths, bert_padded, spk_emb_padded, duration_padded 247 | 248 | 249 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 250 | """ 251 | Maintain similar input lengths in a batch. 252 | Length groups are specified by boundaries. 253 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 254 | 255 | It removes samples which are not included in the boundaries. 256 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 257 | """ 258 | 259 | def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): 260 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 261 | self.lengths = dataset.lengths 262 | self.batch_size = batch_size 263 | self.boundaries = boundaries 264 | 265 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 266 | self.total_size = sum(self.num_samples_per_bucket) 267 | self.num_samples = self.total_size // self.num_replicas 268 | 269 | def _create_buckets(self): 270 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 271 | for i in range(len(self.lengths)): 272 | length = self.lengths[i] 273 | idx_bucket = self._bisect(length) 274 | if idx_bucket != -1: 275 | buckets[idx_bucket].append(i) 276 | 277 | for i in range(len(buckets) - 1, 0, -1): 278 | if len(buckets[i]) == 0: 279 | buckets.pop(i) 280 | self.boundaries.pop(i + 1) 281 | 282 | num_samples_per_bucket = [] 283 | for i in range(len(buckets)): 284 | len_bucket = len(buckets[i]) 285 | total_batch_size = self.num_replicas * self.batch_size 286 | rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size 287 | num_samples_per_bucket.append(len_bucket + rem) 288 | return buckets, num_samples_per_bucket 289 | 290 | def __iter__(self): 291 | # deterministically shuffle based on epoch 292 | g = torch.Generator() 293 | g.manual_seed(self.epoch) 294 | 295 | indices = [] 296 | if self.shuffle: 297 | for bucket in self.buckets: 298 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 299 | else: 300 | for bucket in self.buckets: 301 | indices.append(list(range(len(bucket)))) 302 | 303 | batches = [] 304 | for i in range(len(self.buckets)): 305 | bucket = self.buckets[i] 306 | len_bucket = len(bucket) 307 | ids_bucket = indices[i] 
308 | num_samples_bucket = self.num_samples_per_bucket[i] 309 | 310 | # add extra samples to make it evenly divisible 311 | rem = num_samples_bucket - len_bucket 312 | ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] 313 | 314 | # subsample 315 | ids_bucket = ids_bucket[self.rank::self.num_replicas] 316 | 317 | # batching 318 | for j in range(len(ids_bucket) // self.batch_size): 319 | batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]] 320 | batches.append(batch) 321 | 322 | if self.shuffle: 323 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 324 | batches = [batches[i] for i in batch_ids] 325 | self.batches = batches 326 | 327 | assert len(self.batches) * self.batch_size == self.num_samples 328 | return iter(self.batches) 329 | 330 | def _bisect(self, x, lo=0, hi=None): 331 | if hi is None: 332 | hi = len(self.boundaries) - 1 333 | 334 | if hi > lo: 335 | mid = (hi + lo) // 2 336 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 337 | return mid 338 | elif x <= self.boundaries[mid]: 339 | return self._bisect(x, lo, mid) 340 | else: 341 | return self._bisect(x, mid + 1, hi) 342 | else: 343 | return -1 344 | 345 | def __len__(self): 346 | return self.num_samples // self.batch_size 347 | -------------------------------------------------------------------------------- /asr/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from typing import Optional, Any 5 | from torch import Tensor 6 | import torch.nn.functional as F 7 | import torchaudio 8 | import torchaudio.functional as audio_F 9 | 10 | import random 11 | random.seed(0) 12 | 13 | 14 | def _get_activation_fn(activ): 15 | if activ == 'relu': 16 | return nn.ReLU() 17 | elif activ == 'lrelu': 18 | return nn.LeakyReLU(0.2) 19 | elif activ == 'swish': 20 | return lambda x: x*torch.sigmoid(x) 21 | else: 22 | raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ) 23 | 24 | class LinearNorm(torch.nn.Module): 25 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 26 | super(LinearNorm, self).__init__() 27 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 28 | 29 | torch.nn.init.xavier_uniform_( 30 | self.linear_layer.weight, 31 | gain=torch.nn.init.calculate_gain(w_init_gain)) 32 | 33 | def forward(self, x): 34 | return self.linear_layer(x) 35 | 36 | 37 | class ConvNorm(torch.nn.Module): 38 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, 39 | padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): 40 | super(ConvNorm, self).__init__() 41 | if padding is None: 42 | assert(kernel_size % 2 == 1) 43 | padding = int(dilation * (kernel_size - 1) / 2) 44 | 45 | self.conv = torch.nn.Conv1d(in_channels, out_channels, 46 | kernel_size=kernel_size, stride=stride, 47 | padding=padding, dilation=dilation, 48 | bias=bias) 49 | 50 | torch.nn.init.xavier_uniform_( 51 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 52 | 53 | def forward(self, signal): 54 | conv_signal = self.conv(signal) 55 | return conv_signal 56 | 57 | class CausualConv(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None): 59 | super(CausualConv, self).__init__() 60 | if padding is None: 61 | assert(kernel_size % 2 == 1) 62 | padding = 
int(dilation * (kernel_size - 1) / 2) * 2 63 | else: 64 | self.padding = padding * 2 65 | self.conv = nn.Conv1d(in_channels, out_channels, 66 | kernel_size=kernel_size, stride=stride, 67 | padding=self.padding, 68 | dilation=dilation, 69 | bias=bias) 70 | 71 | torch.nn.init.xavier_uniform_( 72 | self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) 73 | 74 | def forward(self, x): 75 | x = self.conv(x) 76 | x = x[:, :, :-self.padding] 77 | return x 78 | 79 | class CausualBlock(nn.Module): 80 | def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'): 81 | super(CausualBlock, self).__init__() 82 | self.blocks = nn.ModuleList([ 83 | self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) 84 | for i in range(n_conv)]) 85 | 86 | def forward(self, x): 87 | for block in self.blocks: 88 | res = x 89 | x = block(x) 90 | x += res 91 | return x 92 | 93 | def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2): 94 | layers = [ 95 | CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), 96 | _get_activation_fn(activ), 97 | nn.BatchNorm1d(hidden_dim), 98 | nn.Dropout(p=dropout_p), 99 | CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), 100 | _get_activation_fn(activ), 101 | nn.Dropout(p=dropout_p) 102 | ] 103 | return nn.Sequential(*layers) 104 | 105 | class ConvBlock(nn.Module): 106 | def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'): 107 | super().__init__() 108 | self._n_groups = 8 109 | self.blocks = nn.ModuleList([ 110 | self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p) 111 | for i in range(n_conv)]) 112 | 113 | 114 | def forward(self, x): 115 | for block in self.blocks: 116 | res = x 117 | x = block(x) 118 | x += res 119 | return x 120 | 121 | def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2): 122 | layers = [ 123 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation), 124 | _get_activation_fn(activ), 125 | nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim), 126 | nn.Dropout(p=dropout_p), 127 | ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1), 128 | _get_activation_fn(activ), 129 | nn.Dropout(p=dropout_p) 130 | ] 131 | return nn.Sequential(*layers) 132 | 133 | class LocationLayer(nn.Module): 134 | def __init__(self, attention_n_filters, attention_kernel_size, 135 | attention_dim): 136 | super(LocationLayer, self).__init__() 137 | padding = int((attention_kernel_size - 1) / 2) 138 | self.location_conv = ConvNorm(2, attention_n_filters, 139 | kernel_size=attention_kernel_size, 140 | padding=padding, bias=False, stride=1, 141 | dilation=1) 142 | self.location_dense = LinearNorm(attention_n_filters, attention_dim, 143 | bias=False, w_init_gain='tanh') 144 | 145 | def forward(self, attention_weights_cat): 146 | processed_attention = self.location_conv(attention_weights_cat) 147 | processed_attention = processed_attention.transpose(1, 2) 148 | processed_attention = self.location_dense(processed_attention) 149 | return processed_attention 150 | 151 | 152 | class Attention(nn.Module): 153 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 154 | attention_location_n_filters, attention_location_kernel_size): 155 | super(Attention, self).__init__() 156 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 157 | bias=False, w_init_gain='tanh') 158 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, 
bias=False, 159 | w_init_gain='tanh') 160 | self.v = LinearNorm(attention_dim, 1, bias=False) 161 | self.location_layer = LocationLayer(attention_location_n_filters, 162 | attention_location_kernel_size, 163 | attention_dim) 164 | self.score_mask_value = -float("inf") 165 | 166 | def get_alignment_energies(self, query, processed_memory, 167 | attention_weights_cat): 168 | """ 169 | PARAMS 170 | ------ 171 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 172 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 173 | attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) 174 | RETURNS 175 | ------- 176 | alignment (batch, max_time) 177 | """ 178 | 179 | processed_query = self.query_layer(query.unsqueeze(1)) 180 | processed_attention_weights = self.location_layer(attention_weights_cat) 181 | energies = self.v(torch.tanh( 182 | processed_query + processed_attention_weights + processed_memory)) 183 | 184 | energies = energies.squeeze(-1) 185 | return energies 186 | 187 | def forward(self, attention_hidden_state, memory, processed_memory, 188 | attention_weights_cat, mask): 189 | """ 190 | PARAMS 191 | ------ 192 | attention_hidden_state: attention rnn last output 193 | memory: encoder outputs 194 | processed_memory: processed encoder outputs 195 | attention_weights_cat: previous and cummulative attention weights 196 | mask: binary mask for padded data 197 | """ 198 | alignment = self.get_alignment_energies( 199 | attention_hidden_state, processed_memory, attention_weights_cat) 200 | 201 | if mask is not None: 202 | alignment.data.masked_fill_(mask, self.score_mask_value) 203 | 204 | attention_weights = F.softmax(alignment, dim=1) 205 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 206 | attention_context = attention_context.squeeze(1) 207 | 208 | return attention_context, attention_weights 209 | 210 | 211 | class ForwardAttentionV2(nn.Module): 212 | def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, 213 | attention_location_n_filters, attention_location_kernel_size): 214 | super(ForwardAttentionV2, self).__init__() 215 | self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, 216 | bias=False, w_init_gain='tanh') 217 | self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, 218 | w_init_gain='tanh') 219 | self.v = LinearNorm(attention_dim, 1, bias=False) 220 | self.location_layer = LocationLayer(attention_location_n_filters, 221 | attention_location_kernel_size, 222 | attention_dim) 223 | self.score_mask_value = -float(1e20) 224 | 225 | def get_alignment_energies(self, query, processed_memory, 226 | attention_weights_cat): 227 | """ 228 | PARAMS 229 | ------ 230 | query: decoder output (batch, n_mel_channels * n_frames_per_step) 231 | processed_memory: processed encoder outputs (B, T_in, attention_dim) 232 | attention_weights_cat: prev. 
and cumulative att weights (B, 2, max_time) 233 | RETURNS 234 | ------- 235 | alignment (batch, max_time) 236 | """ 237 | 238 | processed_query = self.query_layer(query.unsqueeze(1)) 239 | processed_attention_weights = self.location_layer(attention_weights_cat) 240 | energies = self.v(torch.tanh( 241 | processed_query + processed_attention_weights + processed_memory)) 242 | 243 | energies = energies.squeeze(-1) 244 | return energies 245 | 246 | def forward(self, attention_hidden_state, memory, processed_memory, 247 | attention_weights_cat, mask, log_alpha): 248 | """ 249 | PARAMS 250 | ------ 251 | attention_hidden_state: attention rnn last output 252 | memory: encoder outputs 253 | processed_memory: processed encoder outputs 254 | attention_weights_cat: previous and cummulative attention weights 255 | mask: binary mask for padded data 256 | """ 257 | log_energy = self.get_alignment_energies( 258 | attention_hidden_state, processed_memory, attention_weights_cat) 259 | 260 | #log_energy = 261 | 262 | if mask is not None: 263 | log_energy.data.masked_fill_(mask, self.score_mask_value) 264 | 265 | #attention_weights = F.softmax(alignment, dim=1) 266 | 267 | #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME] 268 | #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1] 269 | 270 | #log_total_score = log_alpha + content_score 271 | 272 | #previous_attention_weights = attention_weights_cat[:,0,:] 273 | 274 | log_alpha_shift_padded = [] 275 | max_time = log_energy.size(1) 276 | for sft in range(2): 277 | shifted = log_alpha[:,:max_time-sft] 278 | shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value) 279 | log_alpha_shift_padded.append(shift_padded.unsqueeze(2)) 280 | 281 | biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2) 282 | 283 | log_alpha_new = biased + log_energy 284 | 285 | attention_weights = F.softmax(log_alpha_new, dim=1) 286 | 287 | attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) 288 | attention_context = attention_context.squeeze(1) 289 | 290 | return attention_context, attention_weights, log_alpha_new 291 | 292 | 293 | class PhaseShuffle2d(nn.Module): 294 | def __init__(self, n=2): 295 | super(PhaseShuffle2d, self).__init__() 296 | self.n = n 297 | self.random = random.Random(1) 298 | 299 | def forward(self, x, move=None): 300 | # x.size = (B, C, M, L) 301 | if move is None: 302 | move = self.random.randint(-self.n, self.n) 303 | 304 | if move == 0: 305 | return x 306 | else: 307 | left = x[:, :, :, :move] 308 | right = x[:, :, :, move:] 309 | shuffled = torch.cat([right, left], dim=3) 310 | return shuffled 311 | 312 | class PhaseShuffle1d(nn.Module): 313 | def __init__(self, n=2): 314 | super(PhaseShuffle1d, self).__init__() 315 | self.n = n 316 | self.random = random.Random(1) 317 | 318 | def forward(self, x, move=None): 319 | # x.size = (B, C, M, L) 320 | if move is None: 321 | move = self.random.randint(-self.n, self.n) 322 | 323 | if move == 0: 324 | return x 325 | else: 326 | left = x[:, :, :move] 327 | right = x[:, :, move:] 328 | shuffled = torch.cat([right, left], dim=2) 329 | 330 | return shuffled 331 | 332 | class MFCC(nn.Module): 333 | def __init__(self, n_mfcc=64, n_mels=128): 334 | super(MFCC, self).__init__() 335 | self.n_mfcc = n_mfcc 336 | self.n_mels = n_mels 337 | self.norm = 'ortho' 338 | dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm) 339 | self.register_buffer('dct_mat', dct_mat) 340 | 341 | def forward(self, mel_specgram): 342 | if 
len(mel_specgram.shape) == 2: 343 | mel_specgram = mel_specgram.unsqueeze(0) 344 | unsqueezed = True 345 | else: 346 | unsqueezed = False 347 | # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) 348 | # -> (channel, time, n_mfcc).tranpose(...) 349 | mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) 350 | 351 | # unpack batch 352 | if unsqueezed: 353 | mfcc = mfcc.squeeze(0) 354 | return mfcc 355 | -------------------------------------------------------------------------------- /module/core_vq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | # This implementation is inspired from 8 | # https://github.com/lucidrains/vector-quantize-pytorch 9 | # which is released under MIT License. Hereafter, the original license: 10 | # MIT License 11 | # 12 | # Copyright (c) 2020 Phil Wang 13 | # 14 | # Permission is hereby granted, free of charge, to any person obtaining a copy 15 | # of this software and associated documentation files (the "Software"), to deal 16 | # in the Software without restriction, including without limitation the rights 17 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 18 | # copies of the Software, and to permit persons to whom the Software is 19 | # furnished to do so, subject to the following conditions: 20 | # 21 | # The above copyright notice and this permission notice shall be included in all 22 | # copies or substantial portions of the Software. 23 | # 24 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 25 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 26 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 27 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 28 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 29 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 30 | # SOFTWARE. 
31 | 32 | """Core vector quantization implementation.""" 33 | import typing as tp 34 | 35 | from einops import rearrange, repeat 36 | import torch 37 | from torch import nn 38 | import torch.nn.functional as F 39 | from tqdm import tqdm 40 | 41 | 42 | def default(val: tp.Any, d: tp.Any) -> tp.Any: 43 | return val if val is not None else d 44 | 45 | 46 | def ema_inplace(moving_avg, new, decay: float): 47 | moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) 48 | 49 | 50 | def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): 51 | return (x + epsilon) / (x.sum() + n_categories * epsilon) 52 | 53 | 54 | def uniform_init(*shape: int): 55 | t = torch.empty(shape) 56 | nn.init.kaiming_uniform_(t) 57 | return t 58 | 59 | 60 | def sample_vectors(samples, num: int): 61 | num_samples, device = samples.shape[0], samples.device 62 | 63 | if num_samples >= num: 64 | indices = torch.randperm(num_samples, device=device)[:num] 65 | else: 66 | indices = torch.randint(0, num_samples, (num,), device=device) 67 | 68 | return samples[indices] 69 | 70 | 71 | def kmeans(samples, num_clusters: int, num_iters: int = 10): 72 | dim, dtype = samples.shape[-1], samples.dtype 73 | max_kmeans_samples = 1400 74 | samples = samples[:max_kmeans_samples, :] 75 | means = sample_vectors(samples, num_clusters) 76 | 77 | print("kmeans start ... ") 78 | for _ in tqdm(range(num_iters)): 79 | diffs = rearrange(samples, "n d -> n () d") - rearrange( 80 | means, "c d -> () c d" 81 | ) 82 | dists = -(diffs ** 2).sum(dim=-1) 83 | 84 | buckets = dists.max(dim=-1).indices 85 | bins = torch.bincount(buckets, minlength=num_clusters) 86 | zero_mask = bins == 0 87 | bins_min_clamped = bins.masked_fill(zero_mask, 1) 88 | 89 | new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) 90 | new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) 91 | new_means = new_means / bins_min_clamped[..., None] 92 | 93 | means = torch.where(zero_mask[..., None], means, new_means) 94 | 95 | return means, bins 96 | 97 | 98 | class EuclideanCodebook(nn.Module): 99 | """Codebook with Euclidean distance. 100 | Args: 101 | dim (int): Dimension. 102 | codebook_size (int): Codebook size. 103 | kmeans_init (bool): Whether to use k-means to initialize the codebooks. 104 | If set to true, run the k-means algorithm on the first training batch and use 105 | the learned centroids as initialization. 106 | kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. 107 | decay (float): Decay for exponential moving average over the codebooks. 108 | epsilon (float): Epsilon value for numerical stability. 109 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 110 | that have an exponential moving average cluster size less than the specified threshold with 111 | randomly selected vector from the current batch. 
112 | """ 113 | def __init__( 114 | self, 115 | dim: int, 116 | codebook_size: int, 117 | kmeans_init: int = False, 118 | kmeans_iters: int = 10, 119 | decay: float = 0.99, 120 | epsilon: float = 1e-5, 121 | threshold_ema_dead_code: int = 2, 122 | ): 123 | super().__init__() 124 | self.decay = decay 125 | init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros 126 | embed = init_fn(codebook_size, dim) 127 | 128 | self.codebook_size = codebook_size 129 | 130 | self.kmeans_iters = kmeans_iters 131 | self.epsilon = epsilon 132 | self.threshold_ema_dead_code = threshold_ema_dead_code 133 | 134 | self.register_buffer("inited", torch.Tensor([not kmeans_init])) 135 | self.register_buffer("cluster_size", torch.zeros(codebook_size)) 136 | self.register_buffer("embed", embed) 137 | self.register_buffer("embed_avg", embed.clone()) 138 | 139 | @torch.jit.ignore 140 | def init_embed_(self, data): 141 | if self.inited: 142 | return 143 | 144 | embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) 145 | self.embed.data.copy_(embed) 146 | self.embed_avg.data.copy_(embed.clone()) 147 | self.cluster_size.data.copy_(cluster_size) 148 | self.inited.data.copy_(torch.Tensor([True])) 149 | # Make sure all buffers across workers are in sync after initialization 150 | #broadcast_tensors(self.buffers()) 151 | 152 | def replace_(self, samples, mask): 153 | modified_codebook = torch.where( 154 | mask[..., None], sample_vectors(samples, self.codebook_size), self.embed 155 | ) 156 | self.embed.data.copy_(modified_codebook) 157 | 158 | def expire_codes_(self, batch_samples): 159 | if self.threshold_ema_dead_code == 0: 160 | return 161 | 162 | expired_codes = self.cluster_size < self.threshold_ema_dead_code 163 | if not torch.any(expired_codes): 164 | return 165 | 166 | batch_samples = rearrange(batch_samples, "... d -> (...) d") 167 | self.replace_(batch_samples, mask=expired_codes) 168 | #broadcast_tensors(self.buffers()) 169 | 170 | def preprocess(self, x): 171 | x = rearrange(x, "... d -> (...) d") 172 | return x 173 | 174 | def quantize(self, x): 175 | embed = self.embed.t() 176 | dist = -( 177 | x.pow(2).sum(1, keepdim=True) 178 | - 2 * x @ embed 179 | + embed.pow(2).sum(0, keepdim=True) 180 | ) 181 | embed_ind = dist.max(dim=-1).indices 182 | return embed_ind 183 | 184 | def postprocess_emb(self, embed_ind, shape): 185 | return embed_ind.view(*shape[:-1]) 186 | 187 | def dequantize(self, embed_ind): 188 | quantize = F.embedding(embed_ind, self.embed) 189 | return quantize 190 | 191 | def encode(self, x): 192 | shape = x.shape 193 | # pre-process 194 | x = self.preprocess(x) 195 | # quantize 196 | embed_ind = self.quantize(x) 197 | # post-process 198 | embed_ind = self.postprocess_emb(embed_ind, shape) 199 | return embed_ind 200 | 201 | def decode(self, embed_ind): 202 | quantize = self.dequantize(embed_ind) 203 | return quantize 204 | 205 | def forward(self, x): 206 | shape, dtype = x.shape, x.dtype 207 | x = self.preprocess(x) 208 | 209 | self.init_embed_(x) 210 | 211 | embed_ind = self.quantize(x) 212 | embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) 213 | embed_ind = self.postprocess_emb(embed_ind, shape) 214 | quantize = self.dequantize(embed_ind) 215 | 216 | if self.training: 217 | # We do the expiry of code at that point as buffers are in sync 218 | # and all the workers will take the same decision. 
219 | self.expire_codes_(x) 220 | ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) 221 | embed_sum = x.t() @ embed_onehot 222 | ema_inplace(self.embed_avg, embed_sum.t(), self.decay) 223 | cluster_size = ( 224 | laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) 225 | * self.cluster_size.sum() 226 | ) 227 | embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) 228 | self.embed.data.copy_(embed_normalized) 229 | 230 | return quantize, embed_ind 231 | 232 | 233 | class VectorQuantization(nn.Module): 234 | """Vector quantization implementation. 235 | Currently supports only euclidean distance. 236 | Args: 237 | dim (int): Dimension 238 | codebook_size (int): Codebook size 239 | codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. 240 | decay (float): Decay for exponential moving average over the codebooks. 241 | epsilon (float): Epsilon value for numerical stability. 242 | kmeans_init (bool): Whether to use kmeans to initialize the codebooks. 243 | kmeans_iters (int): Number of iterations used for kmeans initialization. 244 | threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes 245 | that have an exponential moving average cluster size less than the specified threshold with 246 | randomly selected vector from the current batch. 247 | commitment_weight (float): Weight for commitment loss. 248 | """ 249 | def __init__( 250 | self, 251 | dim: int, 252 | codebook_size: int, 253 | codebook_dim: tp.Optional[int] = None, 254 | decay: float = 0.99, 255 | epsilon: float = 1e-5, 256 | kmeans_init: bool = True, 257 | kmeans_iters: int = 50, 258 | threshold_ema_dead_code: int = 2, 259 | commitment_weight: float = 1., 260 | ): 261 | super().__init__() 262 | _codebook_dim: int = default(codebook_dim, dim) 263 | 264 | requires_projection = _codebook_dim != dim 265 | self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) 266 | self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) 267 | 268 | self.epsilon = epsilon 269 | self.commitment_weight = commitment_weight 270 | 271 | self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, 272 | kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, 273 | decay=decay, epsilon=epsilon, 274 | threshold_ema_dead_code=threshold_ema_dead_code) 275 | self.codebook_size = codebook_size 276 | 277 | @property 278 | def codebook(self): 279 | return self._codebook.embed 280 | 281 | def encode(self, x): 282 | x = rearrange(x, "b d n -> b n d") 283 | x = self.project_in(x) 284 | embed_in = self._codebook.encode(x) 285 | return embed_in 286 | 287 | def decode(self, embed_ind): 288 | quantize = self._codebook.decode(embed_ind) 289 | quantize = self.project_out(quantize) 290 | quantize = rearrange(quantize, "b n d -> b d n") 291 | return quantize 292 | 293 | def forward(self, x): 294 | device = x.device 295 | x = rearrange(x, "b d n -> b n d") 296 | x = self.project_in(x) 297 | 298 | quantize, embed_ind = self._codebook(x) 299 | 300 | if self.training: 301 | quantize = x + (quantize - x).detach() 302 | 303 | loss = torch.tensor([0.0], device=device, requires_grad=self.training) 304 | 305 | if self.training: 306 | if self.commitment_weight > 0: 307 | commit_loss = F.mse_loss(quantize.detach(), x) 308 | loss = loss + commit_loss * self.commitment_weight 309 | 310 | quantize = self.project_out(quantize) 311 | quantize = rearrange(quantize, "b n d -> b d n") 312 | return 
quantize, embed_ind, loss 313 | 314 | 315 | class ResidualVectorQuantization(nn.Module): 316 | """Residual vector quantization implementation. 317 | Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf 318 | """ 319 | def __init__(self, *, num_quantizers, **kwargs): 320 | super().__init__() 321 | self.layers = nn.ModuleList( 322 | [VectorQuantization(**kwargs) for _ in range(num_quantizers)] 323 | ) 324 | 325 | def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None): 326 | quantized_out = 0.0 327 | residual = x 328 | 329 | all_losses = [] 330 | all_indices = [] 331 | out_quantized = [] 332 | 333 | n_q = n_q or len(self.layers) 334 | 335 | for i, layer in enumerate(self.layers[:n_q]): 336 | quantized, indices, loss = layer(residual) 337 | residual = residual - quantized 338 | quantized_out = quantized_out + quantized 339 | 340 | all_indices.append(indices) 341 | all_losses.append(loss) 342 | if layers and i in layers: 343 | out_quantized.append(quantized) 344 | 345 | out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) 346 | return quantized_out, out_indices, out_losses, out_quantized 347 | 348 | def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int]= None) -> torch.Tensor: 349 | residual = x 350 | all_indices = [] 351 | n_q = n_q or len(self.layers) 352 | st = st or 0 353 | for layer in self.layers[st:n_q]: 354 | indices = layer.encode(residual) 355 | quantized = layer.decode(indices) 356 | residual = residual - quantized 357 | all_indices.append(indices) 358 | out_indices = torch.stack(all_indices) 359 | return out_indices 360 | 361 | def decode(self, q_indices: torch.Tensor, st: int=0) -> torch.Tensor: 362 | quantized_out = torch.tensor(0.0, device=q_indices.device) 363 | for i, indices in enumerate(q_indices): 364 | layer = self.layers[st + i] 365 | quantized = layer.decode(indices) 366 | quantized_out = quantized_out + quantized 367 | return quantized_out -------------------------------------------------------------------------------- /vits_train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.nn import functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.utils.tensorboard import SummaryWriter 6 | import torch.multiprocessing as mp 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.cuda.amp import autocast, GradScaler 10 | from tqdm import tqdm 11 | import logging 12 | logging.getLogger("matplotlib").setLevel(logging.INFO) 13 | logging.getLogger("h5py").setLevel(logging.INFO) 14 | logging.getLogger("numba").setLevel(logging.INFO) 15 | 16 | from module import commons 17 | import utils 18 | from module.data_utils import ( 19 | TextAudioSpeakerLoader, 20 | TextAudioSpeakerCollate, 21 | DistributedBucketSampler 22 | ) 23 | from module.models import ( 24 | SynthesizerTrn, 25 | MultiPeriodDiscriminator, 26 | ) 27 | from module.losses import ( 28 | generator_loss, 29 | discriminator_loss, 30 | feature_loss, 31 | kl_loss 32 | ) 33 | from module.mel_processing import mel_spectrogram_torch, spec_to_mel_torch 34 | 35 | torch.backends.cudnn.benchmark = True 36 | global_step = 0 37 | 38 | 39 | def main(): 40 | """Assume Single Node Multi GPUs Training Only""" 41 | assert torch.cuda.is_available(), "CPU training is not allowed." 
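# Single-node DDP launch: the rendezvous address/port are pinned to localhost:8000 via
# MASTER_ADDR / MASTER_PORT, and torch.multiprocessing.spawn starts one training process
# per visible GPU, passing each process its rank as the first argument of run().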
42 | 43 | n_gpus = torch.cuda.device_count() 44 | os.environ['MASTER_ADDR'] = 'localhost' 45 | os.environ['MASTER_PORT'] = '8000' 46 | 47 | hps = utils.get_hparams(stage=2) 48 | mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps,)) 49 | 50 | 51 | def run(rank, n_gpus, hps): 52 | global global_step 53 | if rank == 0: 54 | logger = utils.get_logger(hps.s2_ckpt_dir) 55 | logger.info(hps) 56 | utils.check_git_hash(hps.s2_ckpt_dir) 57 | writer = SummaryWriter(log_dir=hps.s2_ckpt_dir) 58 | writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval")) 59 | 60 | dist.init_process_group(backend='gloo' if os.name == 'nt' else 'nccl', init_method='env://', world_size=n_gpus, 61 | rank=rank) 62 | torch.manual_seed(hps.train.seed) 63 | torch.cuda.set_device(rank) 64 | 65 | train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data) 66 | train_sampler = DistributedBucketSampler( 67 | train_dataset, 68 | hps.train.batch_size, 69 | [32, 300, 400, 500, 600, 700, 800, 900, 1000], 70 | num_replicas=n_gpus, 71 | rank=rank, 72 | shuffle=True) 73 | collate_fn = TextAudioSpeakerCollate() 74 | train_loader = DataLoader(train_dataset, num_workers=6, shuffle=False, pin_memory=True, 75 | collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True) 76 | if rank == 0: 77 | eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data, val=True) 78 | eval_loader = DataLoader(eval_dataset, num_workers=0, shuffle=False, 79 | batch_size=1, pin_memory=True, 80 | drop_last=False, collate_fn=collate_fn) 81 | 82 | net_g = SynthesizerTrn( 83 | hps.data.filter_length // 2 + 1, 84 | hps.train.segment_size // hps.data.hop_length, 85 | n_speakers=hps.data.n_speakers, 86 | **hps.model).cuda(rank) 87 | 88 | net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) 89 | for name, param in net_g.named_parameters(): 90 | if not param.requires_grad: 91 | print(name,"not requires_grad") 92 | optim_g = torch.optim.AdamW( 93 | net_g.parameters(), 94 | hps.train.learning_rate, 95 | betas=hps.train.betas, 96 | eps=hps.train.eps) 97 | optim_d = torch.optim.AdamW( 98 | net_d.parameters(), 99 | hps.train.learning_rate, 100 | betas=hps.train.betas, 101 | eps=hps.train.eps) 102 | net_g = DDP(net_g, device_ids=[rank]) 103 | net_d = DDP(net_d, device_ids=[rank]) 104 | 105 | pretrain_dir = hps.pretrain 106 | if pretrain_dir is None: 107 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "G_*.pth"), net_g, 108 | optim_g, False) 109 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.s2_ckpt_dir, "D_*.pth"), net_d, 110 | optim_d, False) 111 | epoch_str = max(epoch_str, 1) 112 | global_step = (epoch_str - 1) * len(train_loader) 113 | else: 114 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(pretrain_dir, "G_*.pth"), net_g, 115 | optim_g, True) 116 | _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(pretrain_dir, "D_*.pth"), net_d, 117 | optim_d, True) 118 | epoch_str = 1 119 | global_step = 0 120 | 121 | if hps.resume_step != None: 122 | global_step = hps.resume_step 123 | 124 | 125 | scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 126 | scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) 127 | 128 | scaler = GradScaler(enabled=hps.train.fp16_run) 129 | 130 | for epoch in range(epoch_str, hps.train.epochs + 1): 131 | if rank == 0: 132 | 
train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, 133 | [train_loader, eval_loader], logger, [writer, writer_eval]) 134 | else: 135 | train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, 136 | [train_loader, None], None, None) 137 | scheduler_g.step() 138 | scheduler_d.step() 139 | 140 | 141 | def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): 142 | net_g, net_d = nets 143 | optim_g, optim_d = optims 144 | scheduler_g, scheduler_d = schedulers 145 | train_loader, eval_loader = loaders 146 | if writers is not None: 147 | writer, writer_eval = writers 148 | 149 | train_loader.batch_sampler.set_epoch(epoch) 150 | global global_step 151 | 152 | net_g.train() 153 | net_d.train() 154 | for batch_idx, (ssl, spec, spec_lengths, y, y_lengths, text, text_lengths, bert,spk_emb_padded,duration) in tqdm(enumerate(train_loader)): 155 | spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) 156 | y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) 157 | ssl = ssl.cuda(rank, non_blocking=True) 158 | bert = bert.cuda(rank, non_blocking=True) 159 | text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(rank, non_blocking=True) 160 | duration = duration.cuda(rank, non_blocking=True) 161 | spk_emb_padded = spk_emb_padded.cuda(rank, non_blocking=True) 162 | 163 | with autocast(enabled=hps.train.fp16_run): 164 | y_hat, attn, ids_slice, x_mask, z_mask, \ 165 | (z, z_p, m_p, logs_p, m_q, logs_q), z_q, quantize_loss, dur_loss, pred_dur, prosody_predict_loss = net_g(text, text_lengths, spec, spec_lengths, ssl, duration, bert, spk_emb_padded) 166 | 167 | mel = spec_to_mel_torch( 168 | spec, 169 | hps.data.filter_length, 170 | hps.data.n_mel_channels, 171 | hps.data.sampling_rate, 172 | hps.data.mel_fmin, 173 | hps.data.mel_fmax) 174 | y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) 175 | y_hat_mel = mel_spectrogram_torch( 176 | y_hat.squeeze(1), 177 | hps.data.filter_length, 178 | hps.data.n_mel_channels, 179 | hps.data.sampling_rate, 180 | hps.data.hop_length, 181 | hps.data.win_length, 182 | hps.data.mel_fmin, 183 | hps.data.mel_fmax 184 | ) 185 | 186 | y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice 187 | 188 | # Discriminator 189 | y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) 190 | with autocast(enabled=False): 191 | loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) 192 | loss_disc_all = loss_disc 193 | optim_d.zero_grad() 194 | scaler.scale(loss_disc_all).backward() 195 | scaler.unscale_(optim_d) 196 | grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None) 197 | scaler.step(optim_d) 198 | 199 | with autocast(enabled=hps.train.fp16_run): 200 | # Generator 201 | y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) 202 | with autocast(enabled=False): 203 | loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel 204 | loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl 205 | loss_dur = torch.sum(dur_loss.float()) 206 | loss_fm = feature_loss(fmap_r, fmap_g) 207 | loss_gen, losses_gen = generator_loss(y_d_hat_g) 208 | loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl + quantize_loss + loss_dur + prosody_predict_loss 209 | 210 | optim_g.zero_grad() 211 | 
scaler.scale(loss_gen_all).backward() 212 | scaler.unscale_(optim_g) 213 | grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None) 214 | scaler.step(optim_g) 215 | scaler.update() 216 | 217 | if rank == 0: 218 | if global_step % hps.train.log_interval == 0: 219 | lr = optim_g.param_groups[0]['lr'] 220 | losses = [loss_disc, loss_gen, loss_fm, loss_mel, quantize_loss, loss_kl] 221 | logger.info('Train Epoch: {} [{:.0f}%]'.format( 222 | epoch, 223 | 100. * batch_idx / len(train_loader))) 224 | logger.info([x.item() for x in losses] + [global_step, lr]) 225 | 226 | scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, 227 | "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} 228 | scalar_dict.update( 229 | {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl, 230 | "loss/g/quantize_loss": quantize_loss,'dur_loss':loss_dur,'prosody_predict_loss':prosody_predict_loss}) 231 | 232 | # scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) 233 | # scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) 234 | # scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) 235 | image_dict = { 236 | "all/mel": utils.plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), 237 | "all/z_q": utils.plot_spectrogram_to_numpy(z_q[0].data.cpu().numpy()), 238 | "all/attn": utils.plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()), 239 | } 240 | utils.summarize( 241 | writer=writer, 242 | global_step=global_step, 243 | images=image_dict, 244 | scalars=scalar_dict) 245 | 246 | if global_step % hps.train.eval_interval == 0: 247 | evaluate(hps, net_g, eval_loader, writer_eval) 248 | utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, 249 | os.path.join(hps.s2_ckpt_dir, "G_{}.pth".format(global_step))) 250 | utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, 251 | os.path.join(hps.s2_ckpt_dir, "D_{}.pth".format(global_step))) 252 | keep_ckpts = getattr(hps.train, 'keep_ckpts', 3) 253 | if keep_ckpts > 0: 254 | utils.clean_checkpoints(path_to_models=hps.s2_ckpt_dir, n_ckpts_to_keep=keep_ckpts, sort_by_time=True) 255 | 256 | 257 | global_step += 1 258 | 259 | if rank == 0: 260 | logger.info('====> Epoch: {}'.format(epoch)) 261 | 262 | 263 | 264 | def evaluate(hps, generator, eval_loader, writer_eval): 265 | generator.eval() 266 | image_dict = {} 267 | audio_dict = {} 268 | print("Evaluating ...") 269 | with torch.no_grad(): 270 | for batch_idx, (ssl, spec, spec_lengths, y, y_lengths, text, text_lengths, bert,sid,duration) in enumerate(eval_loader): 271 | print(111) 272 | spec, spec_lengths = spec.cuda(), spec_lengths.cuda() 273 | y, y_lengths = y.cuda(), y_lengths.cuda() 274 | ssl = ssl.cuda() 275 | text, text_lengths = text.cuda(), text_lengths.cuda() 276 | duration = duration.cuda() 277 | bert = bert.cuda() 278 | sid = sid.cuda() 279 | 280 | y_hat, mask, *_ = generator.module.reconstruct(text, text_lengths, spec, spec_lengths, ssl,duration, bert,sid, noise_scale=0.5) 281 | y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length 282 | 283 | mel = spec_to_mel_torch( 284 | spec, 285 | hps.data.filter_length, 286 | hps.data.n_mel_channels, 287 | hps.data.sampling_rate, 288 | hps.data.mel_fmin, 289 | hps.data.mel_fmax) 290 | y_hat_mel = mel_spectrogram_torch( 291 | y_hat.squeeze(1).float(), 292 | hps.data.filter_length, 293 | hps.data.n_mel_channels, 294 | hps.data.sampling_rate, 295 | hps.data.hop_length, 296 | 
hps.data.win_length, 297 | hps.data.mel_fmin, 298 | hps.data.mel_fmax 299 | ) 300 | image_dict.update({ 301 | f"gen/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) 302 | }) 303 | audio_dict.update({ 304 | f"gen/audio_{batch_idx}": y_hat[0, :, :] 305 | }) 306 | image_dict.update({f"gt/mel_{batch_idx}": utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) 307 | audio_dict.update({f"gt/audio_{batch_idx}": y[0, :, :]}) 308 | 309 | # y_hat, mask, *_ = generator.module.infer(ssl, spec_lengths, speakers, y=None) 310 | # audio_dict.update({ 311 | # f"gen/audio_{batch_idx}_style_pred": y_hat[0, :, :] 312 | # }) 313 | 314 | utils.summarize( 315 | writer=writer_eval, 316 | global_step=global_step, 317 | images=image_dict, 318 | audios=audio_dict, 319 | audio_sampling_rate=hps.data.sampling_rate 320 | ) 321 | generator.train() 322 | 323 | if __name__ == "__main__": 324 | main() 325 | --------------------------------------------------------------------------------
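
A minimal usage sketch for the residual vector quantizer in module/core_vq.py, assuming the project root is importable; the feature dimension, codebook size and number of quantizers below are illustrative placeholders rather than values taken from configs/vits.json:

import torch
from module.core_vq import ResidualVectorQuantization

# Illustrative hyperparameters only.
rvq = ResidualVectorQuantization(
    num_quantizers=4,           # number of residual codebooks
    dim=192,                    # feature dimension of the input
    codebook_size=1024,         # entries per codebook
    kmeans_init=True,           # k-means init on the first training batch
    kmeans_iters=50,
    threshold_ema_dead_code=2,
)

x = torch.randn(2, 192, 100)                   # (batch, dim, frames)
quantized, indices, losses, _ = rvq(x)         # forward pass (also runs the k-means init)
# quantized: (2, 192, 100)  straight-through quantized features
# indices:   (4, 2, 100)    one code index per quantizer and frame
# losses:    (4, 1)         per-quantizer commitment losses (training mode)

codes = rvq.encode(x, n_q=2)                   # use only the first two quantizers
recon = rvq.decode(codes)                      # (2, 192, 100), sum of the decoded layers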