├── .gitignore ├── README.md ├── __init__.py ├── datasets.yaml ├── datasets_prepare.py ├── eval_variant_2.ipynb ├── eval_variant_3.ipynb ├── requirements.txt ├── supervoice_hybrid ├── __init__.py ├── audio.py ├── config.py ├── misc.py ├── models.py ├── tensors.py ├── tokenizers.py ├── transformer.py └── vocoders.py ├── tokenizer_text.model ├── tokenizer_text.vocab ├── train ├── __init__.py ├── dataset.py └── misc.py ├── train_variant_1.py ├── train_variant_1.sh ├── train_variant_2.py ├── train_variant_2.sh ├── train_variant_3.py ├── train_variant_3.sh └── welcome.ipynb /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .ipynb_checkpoints 3 | /external_datasets 4 | /processed_datasets 5 | wandb 6 | output -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Supervoice Hybrid 2 | 3 | A combination of my experiments with VALL-E, Voicebox, Speech Flow, Tortoise TTS, SeamlessM4T and E2 TTS. 4 | 5 | # License 6 | 7 | MIT -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ex3ndr/supervoice-hybrid/b664b91f720180ea9d33dfe1b63fa9b33223f2e4/__init__.py -------------------------------------------------------------------------------- /datasets.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - librilight 3 | - librilight-aligned 4 | - librilight-aligned@medium 5 | - librilight-aligned@large 6 | - librilight-encodec 7 | - librilight-encodec@medium 8 | - librilight-encodec@large -------------------------------------------------------------------------------- /datasets_prepare.py: -------------------------------------------------------------------------------- 1 | # Ignore warnings 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | 5 | import os 6 | import multiprocessing 7 | import glob 8 | import torch 9 | import torchaudio 10 | import csv 11 | import gzip 12 | import json 13 | import math 14 | from pathlib import Path 15 | from tqdm import tqdm 16 | from supervoice_hybrid.audio import load_mono_audio, spectogram 17 | from supervoice_hybrid.config import config 18 | 19 | # 20 | # Parameters 21 | # 22 | 23 | PARAM_WORKERS = max(torch.cuda.device_count() * 2, 4) 24 | 25 | # 26 | # Execution 27 | # 28 | 29 | 30 | def clean_text(s: str) -> str: 31 | table = str.maketrans("’‘,。;?!():-《》、“”【】", "'',.;?!(): <>/\"\"[]") 32 | s = s.translate(table) 33 | return s.strip() 34 | 35 | def execute_parallel(args): 36 | process_id = multiprocessing.current_process()._identity[0] 37 | files, output_dir, index = args 38 | file = files[index]['path'] 39 | cuts = files[index]['cuts'] 40 | device = "cuda:" + str(process_id % torch.cuda.device_count()) 41 | 42 | # Load audio 43 | source = load_mono_audio(file, config.audio.sample_rate, device=device) 44 | 45 | # Process cuts 46 | for cut in cuts: 47 | id, start, duration, text = cut 48 | wav = source[int(start * config.audio.sample_rate):int((start + duration) * config.audio.sample_rate)] 49 | 50 | # Encode 51 | spec = spectogram(wav, config.audio.n_fft, config.audio.n_mels, config.audio.hop_size, config.audio.win_size, config.audio.mel_norm, config.audio.mel_scale, config.audio.sample_rate) 52 | 53 | # Save codecs 54 | output_file = Path(output_dir) / Path(id + 
".pt") 55 | output_file.parent.mkdir(parents=True, exist_ok=True) 56 | if output_file.exists(): 57 | print("File exists", output_file) 58 | torch.save(spec.cpu(), output_file) 59 | 60 | # Save text 61 | output_file = Path(output_dir) / Path(id + ".txt") 62 | if output_file.exists(): 63 | print("File exists", output_file) 64 | with open(output_file, "w") as f: 65 | f.write(text) 66 | 67 | def execute_run(): 68 | torch.multiprocessing.set_start_method('spawn') 69 | 70 | # Collections 71 | collections = ( 72 | # ("small", "./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./external_datasets/librilight/", "./processed_datasets/librilight/"), 73 | # ("medium", "./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./external_datasets/librilight-medium/", "./processed_datasets/librilight-medium/"), 74 | ("large", "./external_datasets/libriheavy/libriheavy_cuts_large.jsonl.gz", "./external_datasets/librilight-large/", "./processed_datasets/librilight-large/"), 75 | ) 76 | 77 | for collection in collections: 78 | name, index_path, files_path, output_path = collection 79 | 80 | # Load index 81 | print("Loading index for collection: " + name) 82 | files = [] 83 | files_map = {} 84 | with gzip.open(index_path, "r") as f: 85 | for line in f: 86 | cut = json.loads(line) 87 | start = math.floor(1000 * cut["start"]) / 1000 88 | duration = math.floor(1000 * cut["duration"]) / 1000 89 | 90 | # Load audio 91 | wav_id = cut["recording"]["id"] 92 | id = cut["supervisions"][0]["id"] 93 | if wav_id.startswith("small/"): 94 | wav_id = wav_id[len("small/"):] 95 | if wav_id.startswith("medium/"): 96 | wav_id = wav_id[len("medium/"):] 97 | if wav_id.startswith("large/"): 98 | wav_id = wav_id[len("large/"):] 99 | if id.startswith("small/"): 100 | id = id[len("small/"):] 101 | if id.startswith("medium/"): 102 | id = id[len("medium/"):] 103 | if id.startswith("large/"): 104 | id = id[len("large/"):] 105 | 106 | # Check if exists 107 | if (Path(output_path) / Path(id + ".pt")).exists(): 108 | continue 109 | 110 | # Load text 111 | text = cut["supervisions"][0]["custom"]["texts"][0] 112 | text = clean_text(text) 113 | 114 | # Find index 115 | if wav_id not in files_map: 116 | files_map[wav_id] = len(files) 117 | files.append({ "path": files_path + wav_id + ".flac", "cuts": []}) 118 | index = files_map[wav_id] 119 | 120 | # Append 121 | files[index]['cuts'].append((id, start, duration, text)) 122 | 123 | # Process files 124 | print("Processing files for collection: " + name) 125 | with multiprocessing.Manager() as manager: 126 | files = manager.list(files) 127 | args_list = [(files, output_path, i) for i in range(len(files))] 128 | with multiprocessing.Pool(processes=PARAM_WORKERS) as pool: 129 | for result in tqdm(pool.imap_unordered(execute_parallel, args_list, chunksize=32), total=len(files)): 130 | pass 131 | 132 | # End 133 | print("Done") 134 | 135 | if __name__ == "__main__": 136 | execute_run() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fairseq2 2 | git+https://github.com/facebookresearch/seamless_communication.git -------------------------------------------------------------------------------- /supervoice_hybrid/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizers import * 2 | from .models import * -------------------------------------------------------------------------------- 
/supervoice_hybrid/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchaudio 3 | import torchaudio.transforms as T 4 | import torchaudio.functional as F 5 | 6 | # 7 | # Cached Hann Window 8 | # 9 | 10 | hann_window_cache = {} 11 | def hann_window(size, device): 12 | global hann_window_cache 13 | key = str(device) + "_" + str(size) 14 | if key in hann_window_cache: 15 | return hann_window_cache[key] 16 | else: 17 | res = torch.hann_window(size).to(device) 18 | hann_window_cache[key] = res 19 | return res 20 | 21 | # 22 | # Mel Log Bank 23 | # 24 | 25 | melscale_fbank_cache = {} 26 | def melscale_fbanks(n_mels, n_fft, f_min, f_max, sample_rate, mel_norm, mel_scale, device): 27 | global melscale_fbank_cache 28 | key = str(n_mels) + "_" + str(n_fft) + "_" + str(f_min) + "_" + str(f_max) + "_" + str(sample_rate) + "_" + str(mel_norm) + "_" + str(mel_scale) + "_"+ str(device) 29 | if key in melscale_fbank_cache: 30 | return melscale_fbank_cache[key] 31 | else: 32 | res = F.melscale_fbanks( 33 | n_freqs=int(n_fft // 2 + 1), 34 | sample_rate=sample_rate, 35 | f_min=f_min, 36 | f_max=f_max, 37 | n_mels=n_mels, 38 | norm=mel_norm, 39 | mel_scale=mel_scale 40 | ).transpose(-1, -2).to(device) 41 | melscale_fbank_cache[key] = res 42 | return res 43 | 44 | # 45 | # Resampler 46 | # 47 | 48 | resampler_cache = {} 49 | def resampler(from_sample_rate, to_sample_rate, device=None): 50 | global resampler_cache 51 | if device is None: 52 | device = "cpu" 53 | key = str(from_sample_rate) + "_" + str(to_sample_rate) + "_" + str(device) 54 | if key in resampler_cache: 55 | return resampler_cache[key] 56 | else: 57 | res = T.Resample( 58 | from_sample_rate, 59 | to_sample_rate, 60 | lowpass_filter_width=64, 61 | rolloff=0.9475937167399596, 62 | resampling_method="sinc_interp_kaiser", 63 | beta=14.769656459379492 64 | ).to(device) 65 | resampler_cache[key] = res 66 | return res 67 | 68 | # 69 | # Spectogram caclulcation 70 | # 71 | 72 | def spectogram(audio, n_fft, n_mels, n_hop, n_window, mel_norm, mel_scale, sample_rate): 73 | 74 | # Hann Window 75 | window = hann_window(n_window, audio.device) 76 | 77 | # STFT 78 | stft = torch.stft(audio, 79 | 80 | # STFT Parameters 81 | n_fft = n_fft, 82 | hop_length = n_hop, 83 | win_length = n_window, 84 | window = window, 85 | center = True, 86 | 87 | onesided = True, # Default to true to real input, but we enforce it just in case 88 | return_complex = True 89 | ) 90 | 91 | # Compute magnitudes (|a + ib| = sqrt(a^2 + b^2)) instead of power spectrum (|a + ib|^2 = a^2 + b^2) 92 | # because magnitude and phase is linear to the input, while power spectrum is quadratic to the input 93 | # and the magnitude is easier to learn for vocoder 94 | # magnitudes = stft[..., :-1].abs() ** 2 # Power 95 | magnitudes = stft[..., :-1].abs() # Amplitude 96 | 97 | # Mel Log Bank 98 | mel_filters = melscale_fbanks(n_mels, n_fft, 0, sample_rate / 2, sample_rate, mel_norm, mel_scale, audio.device) 99 | mel_spec = (mel_filters @ magnitudes) 100 | 101 | # Log 102 | log_spec = torch.clamp(mel_spec, min=1e-5).log() 103 | 104 | return log_spec 105 | 106 | # 107 | # Load Mono Audio 108 | # 109 | 110 | def load_mono_audio(src, sample_rate, device=None): 111 | 112 | # Load audio 113 | audio, sr = torchaudio.load(src) 114 | 115 | # Move to device 116 | if device is not None: 117 | audio = audio.to(device) 118 | 119 | # Resample 120 | if sr != sample_rate: 121 | audio = resampler(sr, sample_rate, device)(audio) 122 | sr = sample_rate 
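# Note: resampler() above memoizes one T.Resample module per (from_rate, to_rate, device)
# key, so repeated loads from the same source sample rate reuse the precomputed sinc kernel.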
123 | 124 | # Convert to mono 125 | if audio.shape[0] > 1: 126 | audio = audio.mean(dim=0, keepdim=True) 127 | 128 | # Convert to single dimension 129 | audio = audio[0] 130 | 131 | return audio -------------------------------------------------------------------------------- /supervoice_hybrid/config.py: -------------------------------------------------------------------------------- 1 | from .misc import dict_to_object 2 | 3 | config = dict_to_object({ 4 | "audio": { 5 | "sample_rate": 24000, 6 | "n_mels": 100, 7 | "n_fft": 1024, 8 | "hop_size": 256, 9 | "win_size": 256 * 4, 10 | "mel_norm": "slaney", 11 | "mel_scale": "slaney", 12 | "norm_std": 2.2615, 13 | "norm_mean": -5.8843 14 | } 15 | }) -------------------------------------------------------------------------------- /supervoice_hybrid/misc.py: -------------------------------------------------------------------------------- 1 | def dict_to_object(src): 2 | class DictToObject: 3 | def __init__(self, dictionary): 4 | for key, value in dictionary.items(): 5 | if isinstance(value, dict): 6 | value = DictToObject(value) 7 | self.__dict__[key] = value 8 | 9 | def __repr__(self): 10 | return f"{self.__dict__}" 11 | return DictToObject(src) -------------------------------------------------------------------------------- /supervoice_hybrid/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .transformer import Transformer 3 | from .config import config 4 | from .tensors import LearnedSinusoidalPosEmb 5 | from xformers.ops import fmha 6 | from torchdiffeq import odeint 7 | 8 | class SupervoceVariant1(torch.nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | # Parameters 13 | self.n_dim = 1024 14 | self.max_seq_len = 8 * 1024 15 | 16 | # Positional embeddings 17 | self.positional_embedding_text = torch.nn.Embedding(self.max_seq_len, self.n_dim) 18 | torch.nn.init.normal_(self.positional_embedding_text.weight, mean=0.0, std=0.02) 19 | self.positional_embedding_audio = torch.nn.Embedding(self.max_seq_len, self.n_dim) 20 | torch.nn.init.normal_(self.positional_embedding_audio.weight, mean=0.0, std=0.02) 21 | 22 | # Text Condition 23 | self.text_embedding = torch.nn.Embedding(8 * 1024, self.n_dim) 24 | torch.nn.init.normal_(self.text_embedding.weight, mean=0.0, std=0.02) 25 | 26 | # Audio embedding 27 | self.audio_embedding = torch.nn.Embedding(1024 + 1, self.n_dim) 28 | 29 | # Input projection 30 | self.input_projection = torch.nn.Linear(self.n_dim * 2, self.n_dim, bias=False) 31 | 32 | # Transformer 33 | self.transformer = Transformer( 34 | n_heads = 16, 35 | n_layers = 12, 36 | n_dim = self.n_dim, 37 | n_dim_head = 16, # n_dim // n_heads 38 | n_dim_ffn = 4096, 39 | att_dropout = 0, 40 | ffn_dropout = 0.1 41 | ) 42 | 43 | # Prediction 44 | self.prediction = torch.nn.Linear(self.n_dim, 1024) 45 | torch.nn.init.normal_(self.prediction.weight, mean=0.0, std=0.02) 46 | torch.nn.init.zeros_(self.prediction.bias) 47 | 48 | 49 | def forward(self, *, condition_text, condition_audio, duration, target = None): 50 | device = condition_text[0].device 51 | 52 | # 53 | # Check shapes 54 | # 55 | 56 | B = len(condition_text) 57 | assert len(condition_audio) == B 58 | assert len(duration) == B 59 | if target is not None: 60 | assert len(target) == B 61 | assert all(t.shape[0] + c.shape[0] == d for t, d, c in zip(target, duration, condition_audio)) 62 | for i in range(B): 63 | assert len(condition_text[i].shape) == 1, condition_text[i].shape 64 | assert 
len(condition_audio[i].shape) == 1, condition_audio[i].shape 65 | assert condition_text[i].shape[0] <= duration[i] 66 | assert condition_audio[i].shape[0] <= duration[i] 67 | if target is not None: 68 | assert target[i].shape[0] + condition_audio[i].shape[0] == duration[i] 69 | 70 | # 71 | # Prepare inputs 72 | # 73 | 74 | # Pad inputs 75 | inputs_text = [] 76 | inputs_audio = [] 77 | inputs_positional = [] 78 | for i in range(B): 79 | d = duration[i] 80 | inputs_text.append(torch.nn.functional.pad(condition_text[i], (0, d - condition_text[i].shape[0]), "constant", 0)) 81 | inputs_audio.append(torch.nn.functional.pad(condition_audio[i] + 1, (0, d - condition_audio[i].shape[0]), "constant", 0)) 82 | inputs_positional.append(torch.arange(d).to(device, non_blocking=True)) 83 | 84 | # Cat everything 85 | inputs_text = torch.cat(inputs_text) 86 | inputs_audio = torch.cat(inputs_audio) 87 | 88 | # Embeddings 89 | inputs_text = self.text_embedding(inputs_text) 90 | inputs_audio = self.audio_embedding(inputs_audio) 91 | 92 | # Positional embeddings 93 | inputs_positional = torch.cat(inputs_positional) 94 | inputs_text += self.positional_embedding_text(inputs_positional) 95 | inputs_audio += self.positional_embedding_audio(inputs_positional) 96 | 97 | # Input projection 98 | inputs = torch.cat([inputs_text, inputs_audio], dim=-1) 99 | inputs = self.input_projection(inputs) 100 | 101 | # 102 | # Run transformer 103 | # 104 | mask = fmha.BlockDiagonalMask.from_seqlens(duration) 105 | x = self.transformer(inputs.unsqueeze(0), mask = mask).squeeze(0) 106 | 107 | # 108 | # Predict 109 | # 110 | 111 | x = self.prediction(x) 112 | 113 | # 114 | # Split predictions 115 | # 116 | 117 | predicted = [] 118 | offset = 0 119 | for i in range(B): 120 | predicted.append(x[offset + (duration[i] - target[i].shape[0]): offset + duration[i]]) 121 | offset += duration[i] 122 | 123 | # 124 | # Loss 125 | # 126 | 127 | if target is not None: 128 | loss = torch.nn.functional.cross_entropy(torch.cat(predicted), torch.cat(target)) 129 | return predicted, loss 130 | else: 131 | return predicted 132 | 133 | 134 | class SupervoiceVariant2(torch.nn.Module): 135 | def __init__(self): 136 | super().__init__() 137 | 138 | # Parameters 139 | self.n_dim = 1024 140 | self.n_vocab = 8 * 1024 141 | self.max_seq_len = 8 * 1024 142 | 143 | # Positional embeddings 144 | self.positional_embedding_text = torch.nn.Embedding(self.max_seq_len, self.n_dim) 145 | torch.nn.init.normal_(self.positional_embedding_text.weight, mean=0.0, std=0.02) 146 | self.positional_embedding_audio = torch.nn.Embedding(self.max_seq_len, self.n_dim) 147 | torch.nn.init.normal_(self.positional_embedding_audio.weight, mean=0.0, std=0.02) 148 | 149 | # Text Embedding 150 | self.text_embedding = torch.nn.Embedding(self.n_vocab, self.n_dim) 151 | torch.nn.init.normal_(self.text_embedding.weight, mean=0.0, std=0.02) 152 | 153 | # Audio embedding 154 | self.audio_embedding = torch.nn.Linear(config.audio.n_mels, self.n_dim, bias=False) 155 | torch.nn.init.normal_(self.audio_embedding.weight, mean=0.0, std=0.02) 156 | 157 | # Noise embedding 158 | self.noise_embedding = torch.nn.Linear(config.audio.n_mels, self.n_dim, bias=False) 159 | torch.nn.init.normal_(self.noise_embedding.weight, mean=0.0, std=0.02) 160 | 161 | # Transformer input 162 | self.input_projection = torch.nn.Linear(3 * self.n_dim, self.n_dim, bias=False) 163 | 164 | # Sinusoidal positional embedding for time 165 | self.time_embedding = LearnedSinusoidalPosEmb(self.n_dim) 166 | 167 | # Transformer 168 | 
self.transformer = Transformer( 169 | n_heads = 16, 170 | n_layers = 12, 171 | n_dim = self.n_dim, 172 | n_dim_head = 16, # n_dim // n_heads 173 | n_dim_ffn = self.n_dim * 4, 174 | att_dropout = 0, 175 | ffn_dropout = 0.1, 176 | enable_skip_connections = True 177 | ) 178 | 179 | # Prediction 180 | self.prediction = torch.nn.Linear(self.n_dim, config.audio.n_mels) 181 | 182 | def sample(self, *, tokens, audio, interval, steps, alpha = None): 183 | 184 | # 185 | # Prepare 186 | # 187 | 188 | # Mask out audio 189 | masked_audio = audio.clone() 190 | masked_audio[interval[0]: interval[1]] = 0 191 | 192 | # Create noise 193 | noise = torch.randn_like(audio) 194 | 195 | # Create time interpolation 196 | times = torch.linspace(0, 1, steps, device = audio.device) 197 | 198 | # 199 | # Solver 200 | # 201 | 202 | def solver(t, z): 203 | 204 | # If alpha is not provided 205 | if alpha is None: 206 | return self.forward( 207 | condition_text = [tokens], 208 | condition_audio = [masked_audio], 209 | noisy_audio = [z], 210 | times = [t], 211 | )[0] 212 | 213 | # If alpha is provided - zero out tokens and audio and mix together 214 | tokens_empty = torch.zeros_like(tokens) 215 | audio_empty = torch.zeros_like(audio) 216 | 217 | # Inference 218 | predicted_mix = self.forward( 219 | condition_text = [torch.zeros_like(tokens), tokens], 220 | condition_audio = [torch.zeros_like(audio), masked_audio], 221 | noisy_audio = [z, z], 222 | times = [t, t] 223 | ) 224 | predicted_conditioned = predicted_mix[1] 225 | predicted_unconditioned = predicted_mix[0] 226 | 227 | # CFG prediction 228 | 229 | # There are different ways to do CFG, this is my very naive version, which worked for me: 230 | # prediction = (1 + alpha) * predicted_conditioned - alpha * predicted_unconditioned 231 | 232 | # Original paper uses a different one, but i found that it simply creates overexposed values 233 | # prediction = predicted_unconditioned + (predicted_conditioned - predicted_unconditioned) * alpha 234 | 235 | # This is from the latest paper that rescales original formula (https://arxiv.org/abs/2305.08891): 236 | prediction = predicted_conditioned + (predicted_conditioned - predicted_unconditioned) * alpha 237 | prediction_rescaled = predicted_conditioned.std() * (prediction / prediction.std()) 238 | 239 | return prediction 240 | 241 | 242 | trajectory = odeint(solver, noise, times, atol = 1e-5, rtol = 1e-5, method = 'midpoint') 243 | 244 | # 245 | # Output predicted interval 246 | # 247 | 248 | return trajectory[-1][interval[0]: interval[1]] 249 | 250 | def forward(self, *, condition_text, condition_audio, noisy_audio, times, intervals = None, target = None): 251 | device = condition_text[0].device 252 | 253 | # Check shapes 254 | B = len(condition_text) 255 | assert len(condition_audio) == B 256 | assert len(noisy_audio) == B 257 | assert len(times) == B 258 | if target is not None: 259 | assert intervals is not None 260 | assert len(target) == B 261 | 262 | # Calculate durations 263 | durations = [c.shape[0] for c in condition_audio] 264 | 265 | # Check inner shapes 266 | for i in range(B): 267 | assert len(condition_text[i].shape) == 1, condition_text[i].shape 268 | assert len(condition_audio[i].shape) == 2, condition_audio[i].shape 269 | assert len(noisy_audio[i].shape) == 2, condition_audio[i].shape 270 | assert condition_text[i].shape[0] <= durations[i], condition_text[i].shape[0] 271 | assert condition_audio[i].shape[0] == durations[i] 272 | assert condition_audio[i].shape[1] == config.audio.n_mels 273 | assert 
noisy_audio[i].shape[0] == durations[i] 274 | assert noisy_audio[i].shape[1] == config.audio.n_mels 275 | if target is not None: 276 | assert len(intervals[i]) == 2 277 | assert intervals[i][0] >= 0 278 | assert intervals[i][1] <= durations[i] 279 | assert intervals[i][0] <= intervals[i][1] 280 | assert target[i].shape[0] == intervals[i][1] - intervals[i][0] 281 | assert target[i].shape[1] == config.audio.n_mels 282 | 283 | # Prepare inputs 284 | inputs_text = [] 285 | inputs_positional = [] 286 | for i in range(B): 287 | d = durations[i] 288 | inputs_text.append(torch.nn.functional.pad(condition_text[i], (0, d - condition_text[i].shape[0]), "constant", 0)) 289 | inputs_positional.append(torch.arange(d).to(device, non_blocking=True)) 290 | 291 | # Cat everything 292 | inputs_positional = torch.cat(inputs_positional) 293 | inputs_text = torch.cat(inputs_text) 294 | inputs_audio = torch.cat(condition_audio) 295 | inputs_noisy = torch.cat(noisy_audio) 296 | 297 | # Text 298 | inputs_text = self.text_embedding(inputs_text) 299 | inputs_text += self.positional_embedding_text(inputs_positional) 300 | 301 | # Audio 302 | inputs_audio = self.audio_embedding(inputs_audio) 303 | inputs_audio += self.positional_embedding_audio(inputs_positional) 304 | inputs_noisy = self.noise_embedding(inputs_noisy) 305 | inputs_noisy += self.positional_embedding_audio(inputs_positional) 306 | 307 | # Input projection 308 | inputs = torch.cat([inputs_text, inputs_audio, inputs_noisy], dim=-1) 309 | inputs = self.input_projection(inputs) 310 | 311 | # Time embeddings 312 | times = self.time_embedding(times) 313 | 314 | # Merge time embeddings 315 | inputs_timed = [] 316 | offset = 0 317 | for i in range(B): 318 | d = durations[i] 319 | inputs_timed.append(torch.cat([inputs[offset: offset + d], times[i].unsqueeze(0)], dim=0)) 320 | offset += d 321 | inputs = torch.cat(inputs_timed) 322 | 323 | # Transformer 324 | attention_mask = None 325 | if len(durations) > 1: # Disable mask for speed if not batched 326 | attention_mask = fmha.BlockDiagonalMask.from_seqlens([i + 1 for i in durations]) 327 | x = self.transformer(inputs.unsqueeze(0), mask = attention_mask).squeeze(0) 328 | 329 | # Predict 330 | x = self.prediction(x) 331 | 332 | # Split predictions 333 | outputs = [] 334 | offset = 0 335 | for i in range(B): 336 | outputs.append(x[offset: offset + durations[i]]) 337 | offset += durations[i] + 1 # +1 for time embedding 338 | 339 | # Compute loss 340 | if target is not None: 341 | 342 | # Compute target intervals 343 | preds = [] 344 | offset = 0 345 | for i in range(B): 346 | preds.append(x[offset + intervals[i][0]: offset + intervals[i][1]]) 347 | offset += durations[i] + 1 # +1 for time embedding 348 | 349 | # Compute loss 350 | target = torch.cat(target, dim = 0) 351 | predd_cat = torch.cat(preds, dim = 0) 352 | loss = torch.nn.functional.mse_loss(predd_cat, target) 353 | 354 | # Normalize by number of frames 355 | loss = loss / target.shape[0] 356 | 357 | return outputs, loss 358 | 359 | return outputs 360 | 361 | 362 | class SupervoiceVariant3(torch.nn.Module): 363 | def __init__(self, flow): 364 | super().__init__() 365 | self.flow = flow 366 | self.flow.transformer.cache_alibi = False 367 | self.text_embedding = torch.nn.Embedding(8 * 1024 + 1, 100) 368 | torch.nn.init.normal_(self.text_embedding.weight, mean=0.0, std=0.02) 369 | 370 | def sample(self, *, tokens, audio, interval, steps, alpha = None): 371 | 372 | # 373 | # Prepare 374 | # 375 | 376 | # Mask out audio 377 | masked_audio = audio.clone() 378 | 
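# The frames inside [interval[0], interval[1]) are zeroed just below, so the ODE solve
# regenerates only that span while the untouched frames serve as the acoustic prompt
# (the same in-filling scheme as SupervoiceVariant2.sample).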
masked_audio[interval[0]: interval[1]] = 0 379 | 380 | # Create noise 381 | noise = torch.randn_like(audio) 382 | 383 | # Create time interpolation 384 | times = torch.linspace(0, 1, steps, device = audio.device) 385 | 386 | # 387 | # Solver 388 | # 389 | 390 | def solver(t, z): 391 | 392 | # If alpha is not provided 393 | if alpha is None: 394 | return self.forward( 395 | condition_text = [tokens], 396 | condition_audio = [masked_audio], 397 | noisy_audio = [z], 398 | times = [t], 399 | )[0] 400 | 401 | # If alpha is provided - zero out tokens and audio and mix together 402 | tokens_empty = torch.zeros_like(tokens) 403 | audio_empty = torch.zeros_like(audio) 404 | 405 | # Inference 406 | predicted_mix = self.forward( 407 | condition_text = [torch.zeros_like(tokens), tokens], 408 | condition_audio = [torch.zeros_like(audio), masked_audio], 409 | noisy_audio = [z, z], 410 | times = [t, t] 411 | ) 412 | predicted_conditioned = predicted_mix[1] 413 | predicted_unconditioned = predicted_mix[0] 414 | 415 | # CFG prediction 416 | 417 | # There are different ways to do CFG, this is my very naive version, which worked for me: 418 | # prediction = (1 + alpha) * predicted_conditioned - alpha * predicted_unconditioned 419 | 420 | # Original paper uses a different one, but i found that it simply creates overexposed values 421 | # prediction = predicted_unconditioned + (predicted_conditioned - predicted_unconditioned) * alpha 422 | 423 | # This is from the latest paper that rescales original formula (https://arxiv.org/abs/2305.08891): 424 | prediction = predicted_conditioned + (predicted_conditioned - predicted_unconditioned) * alpha 425 | prediction_rescaled = predicted_conditioned.std() * (prediction / prediction.std()) 426 | 427 | return prediction 428 | 429 | 430 | trajectory = odeint(solver, noise, times, atol = 1e-5, rtol = 1e-5, method = 'midpoint') 431 | 432 | # 433 | # Output predicted interval 434 | # 435 | 436 | return trajectory[-1][interval[0]: interval[1]] 437 | 438 | def forward(self, *, condition_text, condition_audio, noisy_audio, times, intervals = None, target = None): 439 | device = condition_text[0].device 440 | 441 | # Check shapes 442 | B = len(condition_text) 443 | assert len(condition_audio) == B 444 | assert len(noisy_audio) == B 445 | assert len(times) == B 446 | if target is not None: 447 | assert intervals is not None 448 | assert len(target) == B 449 | 450 | # Calculate durations 451 | durations = [c.shape[0] for c in condition_audio] 452 | 453 | # Check inner shapes 454 | for i in range(B): 455 | assert len(condition_text[i].shape) == 1, condition_text[i].shape 456 | assert len(condition_audio[i].shape) == 2, condition_audio[i].shape 457 | assert len(noisy_audio[i].shape) == 2, condition_audio[i].shape 458 | assert condition_text[i].shape[0] <= durations[i], condition_text[i].shape[0] 459 | assert condition_audio[i].shape[0] == durations[i] 460 | assert condition_audio[i].shape[1] == config.audio.n_mels 461 | assert noisy_audio[i].shape[0] == durations[i] 462 | assert noisy_audio[i].shape[1] == config.audio.n_mels 463 | if target is not None: 464 | assert len(intervals[i]) == 2 465 | assert intervals[i][0] >= 0 466 | assert intervals[i][1] <= durations[i] 467 | assert intervals[i][0] <= intervals[i][1] 468 | assert target[i].shape[0] == intervals[i][1] - intervals[i][0] 469 | assert target[i].shape[1] == config.audio.n_mels 470 | 471 | # Find max duration 472 | max_duration = max(durations) 473 | 474 | # Prepare inputs 475 | inputs_text = [] 476 | inputs_audio = [] 477 | 
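# Unlike variants 1 and 2, which pack items back-to-back and rely on a block-diagonal
# attention mask, this variant pads every item to max_duration, stacks them into a
# regular batch, and passes a boolean loss mask down to the wrapped flow model.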
inputs_noisy = [] 478 | input_mask = None 479 | targets = None 480 | if target is not None: 481 | input_mask = [] 482 | targets = [] 483 | for i in range(B): 484 | 485 | # Pad text and use 0 for padding and 1 for filler 486 | t = condition_text[i] + 1 487 | t = torch.nn.functional.pad(t, (1, durations[i] - t.shape[0]), "constant", 0) # Filer 488 | t = torch.nn.functional.pad(t, (0, max_duration - t.shape[0]), "constant", 0) # Pad 489 | inputs_text.append(t) 490 | 491 | # Pad audio and noise with simple zeros 492 | inputs_audio.append(torch.nn.functional.pad(condition_audio[i], (0, 0, 0, max_duration - condition_audio[i].shape[0]), "constant", 0)) 493 | inputs_noisy.append(torch.nn.functional.pad(noisy_audio[i], (0, 0, 0, max_duration - noisy_audio[i].shape[0]), "constant", 0)) 494 | 495 | # For loss 496 | if target is not None: 497 | 498 | # Create target 499 | targets.append(torch.nn.functional.pad(target[i], (0, 0, 0, max_duration - target[i].shape[0]), "constant", 0)) 500 | 501 | # Create loss mask 502 | mask = torch.zeros(max_duration, device = device, dtype = torch.bool) 503 | mask[intervals[i][0]: intervals[i][1]] = 1 504 | input_mask.append(mask) 505 | 506 | # Stack everything 507 | inputs_text = torch.stack(inputs_text) 508 | inputs_audio = torch.stack(inputs_audio) 509 | inputs_noisy = torch.stack(inputs_noisy) 510 | if target is not None: 511 | input_mask = torch.stack(input_mask) 512 | targets = torch.stack(targets) 513 | 514 | # Cacluate condition 515 | inputs_condition = inputs_audio + self.text_embedding(inputs_text) 516 | 517 | # Run flow 518 | return self.flow( 519 | audio = inputs_condition, 520 | noise = inputs_noisy, 521 | times = times, 522 | mask = input_mask, 523 | target = targets, 524 | mask_loss = True 525 | ) 526 | -------------------------------------------------------------------------------- /supervoice_hybrid/tensors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange 4 | import random 5 | import numpy as np 6 | import math 7 | from torch.nn.utils.rnn import pad_sequence 8 | 9 | class RMSNorm(torch.nn.Module): 10 | def __init__(self, dim): 11 | super().__init__() 12 | self.scale = dim ** 0.5 13 | self.gamma = torch.nn.Parameter(torch.ones(dim)) 14 | 15 | def forward(self, x): 16 | return F.normalize(x, dim = -1, eps = 1e-05) * self.scale * self.gamma 17 | 18 | class AdaptiveRMSNorm(torch.nn.Module): 19 | def __init__( 20 | self, 21 | dim 22 | ): 23 | super().__init__() 24 | self.scale = dim ** 0.5 25 | self.to_gamma = torch.nn.Linear(dim, dim) 26 | self.to_beta = torch.nn.Linear(dim, dim) 27 | 28 | # Identity initialization 29 | torch.nn.init.zeros_(self.to_gamma.weight) 30 | torch.nn.init.ones_(self.to_gamma.bias) 31 | torch.nn.init.zeros_(self.to_beta.weight) 32 | torch.nn.init.zeros_(self.to_beta.bias) 33 | 34 | def forward(self, x, *, condition): 35 | normed = F.normalize(x, dim = -1, eps = 1e-05) * self.scale 36 | gamma, beta = self.to_gamma(condition), self.to_beta(condition) 37 | gamma, beta = map(lambda t: rearrange(t, 'b d -> b 1 d'), (gamma, beta)) 38 | 39 | return normed * gamma + beta 40 | 41 | class ConvPositionEmbed(torch.nn.Module): 42 | def __init__(self, n_dim, kernel_size): 43 | super().__init__() 44 | self.dw_conv1d = torch.nn.Sequential(nn.Conv1d(n_dim, n_dim, kernel_size, groups = n_dim, padding = kernel_size // 2), nn.GELU()) 45 | 46 | def forward(self, x, mask = None): 47 | 48 | if mask is not None: 49 | mask = mask[..., 
None] 50 | x = x.masked_fill(~mask, 0.) 51 | 52 | x = rearrange(x, 'b n c -> b c n') 53 | x = self.dw_conv1d(x) 54 | out = rearrange(x, 'b c n -> b n c') 55 | 56 | if mask is not None: 57 | out = out.masked_fill(~mask, 0.) 58 | 59 | return out 60 | 61 | class LearnedSinusoidalPosEmb(torch.nn.Module): 62 | def __init__(self, dim): 63 | super().__init__() 64 | half_dim = dim // 2 65 | self.weights = torch.nn.Parameter(torch.randn(half_dim)) 66 | 67 | def forward(self, x): 68 | x = rearrange(x, 'b -> b 1') 69 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 70 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 71 | return fouriered 72 | 73 | def probability_binary_mask(shape, true_prob, device): 74 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < true_prob 75 | 76 | 77 | def debug_if_invalid(x): 78 | if torch.isnan(x).any() or torch.isinf(x).any(): 79 | print('Invalid tensor') 80 | 81 | def count_parameters(model): 82 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 83 | 84 | def drop_using_mask(source, replacement, mask): 85 | while mask.dim() < source.dim(): 86 | mask = mask.unsqueeze(-1) 87 | return torch.where(mask, torch.full(source.shape, replacement, dtype = source.dtype, device = source.device), source) 88 | 89 | def merge_mask(source, replacement, mask): 90 | while mask.dim() < source.dim(): 91 | mask = mask.unsqueeze(-1) 92 | return torch.where(mask, replacement, source) 93 | 94 | def random_interval_masking(batch_size, length, *, min_size, min_count, max_count, device): 95 | tensor = torch.full((batch_size, length), False, device=device, dtype=torch.bool) 96 | for i in range(batch_size): 97 | 98 | # Expected sum of all intervals 99 | expected_length = random.randint(min_count, max_count) 100 | 101 | # Number of intervals 102 | num_intervals = random.randint(1, expected_length // min_size) 103 | 104 | # Generate interval lengths 105 | lengths = [min_size] * num_intervals 106 | for _ in range(expected_length - num_intervals * min_size): 107 | lengths[random.randint(0, num_intervals - 1)] += 1 108 | 109 | # Generate start points 110 | placements = [] 111 | offset = 0 112 | remaining = expected_length 113 | for l in lengths: 114 | start_position = random.uniform(offset, remaining - l) 115 | placements.append(start_position) 116 | offset = start_position + l 117 | remaining -= l 118 | 119 | # Write to tensor 120 | for l, p in zip(lengths, placements): 121 | tensor[i, int(p):int(p + l)] = True 122 | 123 | return tensor 124 | 125 | def sinusoids(length, channels, max_timescale=10000): 126 | assert channels % 2 == 0 127 | log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) 128 | inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2)) 129 | scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] 130 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1) 131 | 132 | def list_to_tensors(tensors): 133 | 134 | # Calculate lengths 135 | l = list(map(len, tensors)) 136 | 137 | # Padded tensors 138 | padded = pad_sequence(tensors, batch_first=True, padding_value=0) 139 | 140 | # Mask 141 | mask = torch.zeros((padded.shape[0], padded.shape[1]), device=padded.device) 142 | for i in range(len(l)): 143 | mask[i, l[i]:] = -10000.0 144 | 145 | return padded, mask 146 | 147 | def join_tensors(tensors, sep = None): 148 | 149 | # No separator 150 | if sep is None: 151 | return torch.cat(tensors) 152 | 153 | # Initial 154 | res = tensors[0] 155 | 
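# e.g. join_tensors([a, b, c], sep) -> cat([a, sep, b, sep, c]): the separator sequence
# is spliced between consecutive tensors.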
156 | # Join rest 157 | for t in tensors[1:]: 158 | res = torch.cat([res, sep, t]) 159 | 160 | return res 161 | 162 | def get_slopes_power_of_2(n_heads): 163 | start = (2**(-2**-(math.log2(n_heads)-3))) 164 | ratio = start 165 | return torch.tensor([start*ratio**i for i in range(n_heads)], dtype=torch.float32) -------------------------------------------------------------------------------- /supervoice_hybrid/tokenizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sentencepiece as spm 3 | # from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer 4 | 5 | class SentencePieceTextTokenizer: 6 | def __init__(self, path): 7 | self.sp = spm.SentencePieceProcessor() 8 | self.sp.load(path) 9 | 10 | def encode(self, text): 11 | return torch.tensor(self.sp.encode(text), dtype=torch.long).squeeze(0).squeeze(0) 12 | 13 | def encode_sample(self, text): 14 | return torch.tensor(self.sp.encode(text, enable_sampling=True, alpha=0.1, nbest_size=-1), dtype=torch.long).squeeze(0).squeeze(0) 15 | 16 | 17 | # class UnitTextTokenizer: 18 | # def __init__(self): 19 | # text_tokenizer = load_unity_char_tokenizer("nar_t2u_aligner") 20 | # self.tokenizer = text_tokenizer.create_raw_encoder() 21 | # self.vocab_info = text_tokenizer.vocab_info 22 | # self.bos_idx = self.vocab_info.bos_idx 23 | # self.eos_idx = self.vocab_info.eos_idx 24 | # self.pad_idx = self.vocab_info.pad_idx 25 | 26 | # def encode(self, text: str) -> torch.Tensor: 27 | # return self.tokenizer(text) 28 | 29 | # def encode_sample(self, text: str) -> torch.Tensor: 30 | # return self.tokenizer(text) -------------------------------------------------------------------------------- /supervoice_hybrid/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat, reduce, pack, unpack 6 | from torch.cuda.amp import autocast 7 | from .tensors import RMSNorm, AdaptiveRMSNorm 8 | import xformers.ops as xops 9 | from torch.profiler import record_function 10 | 11 | class Transformer(nn.Module): 12 | def __init__(self, 13 | n_heads, 14 | n_layers, 15 | n_dim, 16 | n_dim_head, 17 | n_dim_ffn, 18 | att_dropout, 19 | ffn_dropout, 20 | enable_skip_connections = False, 21 | adaptive = False 22 | ): 23 | super(Transformer, self).__init__() 24 | self.n_layers = n_layers 25 | self.n_heads = n_heads 26 | self.enable_skip_connections = enable_skip_connections 27 | self.adaptive = adaptive 28 | 29 | # Attention blocks 30 | self.layers = torch.nn.ModuleList([]) 31 | for i in range(n_layers): 32 | self.layers.append(AttentionBlock( 33 | n_heads = n_heads, 34 | n_dim = n_dim, 35 | n_dim_head = n_dim_head, 36 | n_dim_ffn = n_dim_ffn, 37 | att_dropout = att_dropout, 38 | ffn_dropout = ffn_dropout, 39 | adaptive = adaptive 40 | )) 41 | 42 | # Skip connections 43 | self.skip_combiners = torch.nn.ModuleList([]) 44 | if enable_skip_connections: 45 | for i in range(n_layers//2): 46 | self.skip_combiners.append(torch.nn.Linear(n_dim * 2, n_dim)) 47 | 48 | # Output normalization 49 | self.output_norm = RMSNorm(n_dim) if not adaptive else AdaptiveRMSNorm(n_dim) 50 | 51 | def forward(self, x, condition = None, mask = None): 52 | 53 | # Run through attention blocks 54 | connections = [] 55 | for i in range(self.n_layers): 56 | 57 | # Skip connection 58 | if self.n_layers - (self.n_layers // 2) <= i and 
self.enable_skip_connections: 59 | s = connections.pop() 60 | x = torch.cat([x, s], dim = -1) 61 | x = self.skip_combiners[i - (self.n_layers // 2)](x) 62 | 63 | # Attention 64 | with record_function("attention"): 65 | x = self.layers[i](x, condition = condition, mask = mask) 66 | 67 | # Skip connection 68 | if i <= self.n_layers // 2: 69 | connections.append(x) 70 | 71 | # Output normalization 72 | x = self.output_norm(x) 73 | 74 | # Result 75 | return x 76 | 77 | 78 | class AttentionBlock(torch.nn.Module): 79 | def __init__(self, n_heads, n_dim, n_dim_head, n_dim_ffn, att_dropout, ffn_dropout, adaptive): 80 | super(AttentionBlock, self).__init__() 81 | 82 | self.n_heads = n_heads 83 | self.n_dim_head = n_dim_head 84 | self.att_dropout = att_dropout 85 | self.adaptive = adaptive 86 | 87 | # Attention input layer norm 88 | self.attention_ln = RMSNorm(n_dim) if not adaptive else AdaptiveRMSNorm(n_dim) 89 | 90 | # Input -> Query/Key/Value for each head in single tensor for speedup 91 | self.attention = nn.Linear(n_dim, 3 * n_dim_head * n_heads, bias=False) 92 | torch.nn.init.normal_(self.attention.weight, mean=0.0, std=0.02) 93 | 94 | # Attention dropout 95 | # self.attention_dropout = nn.Dropout(att_dropout) 96 | 97 | # Output flatten multiple heads into single tensor 98 | # self.attention_output = nn.Linear(n_dim_head * n_heads, n_dim, bias=False) 99 | self.attention_output = nn.Linear(n_dim_head * n_heads, n_dim) 100 | torch.nn.init.normal_(self.attention_output.weight, mean=0.0, std=0.02) 101 | torch.nn.init.zeros_(self.attention_output.bias) 102 | 103 | # MLP part 104 | self.mlp_ln = RMSNorm(n_dim) if not adaptive else AdaptiveRMSNorm(n_dim) 105 | 106 | self.mlp_input = nn.Linear(n_dim, n_dim_ffn) 107 | torch.nn.init.normal_(self.mlp_input.weight, mean=0.0, std=0.02) 108 | torch.nn.init.zeros_(self.mlp_input.bias) 109 | 110 | self.mlp_output = nn.Linear(n_dim_ffn, n_dim) 111 | torch.nn.init.normal_(self.mlp_output.weight, mean=0.0, std=0.02) 112 | torch.nn.init.zeros_(self.mlp_output.bias) 113 | 114 | self.mlp_output_dropout = nn.Dropout(ffn_dropout) 115 | 116 | def forward(self, x, condition = None, mask = None): 117 | 118 | with record_function("attention:pre"): 119 | # B, T, C = x.size() # batch size, sequence length, context width 120 | 121 | # Residual 122 | residual = x 123 | 124 | # Input normalization 125 | y = self.attention_ln(x) if not self.adaptive else self.attention_ln(x, condition = condition) 126 | 127 | # Calculation Q/K/V for each head 128 | q, k, v = self.attention(y).chunk(3, dim = -1) 129 | 130 | # 131 | # XFormers Implementation 132 | # 133 | 134 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h = self.n_heads), (q, k, v)) 135 | y = xops.memory_efficient_attention(q, k, v, p = self.att_dropout if self.training else 0.0, attn_bias = mask) 136 | y = rearrange(y, 'b n h d -> b n (h d)') 137 | 138 | # 139 | # SDPA implementation 140 | # 141 | 142 | # Reshape for head-first attention 143 | # q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.n_heads), (q, k, v)) 144 | # Run through attention 145 | # y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.att_dropout if self.training else 0.0, attn_mask = mask, is_causal = casual) 146 | # Reshape back 147 | # y = rearrange(y, 'b h n d -> b n (h d)') 148 | 149 | # Output 150 | y = self.attention_output(y) 151 | 152 | # Residual 153 | y = residual + y 154 | residual = y 155 | 156 | with record_function("attention:post-post"): 157 | # MLP 158 | y = self.mlp_ln(y) if not 
self.adaptive else self.mlp_ln(y, condition = condition) 159 | y = self.mlp_input(y) 160 | y = F.gelu(y) 161 | y = self.mlp_output_dropout(y) 162 | y = self.mlp_output(y) 163 | y = residual + y 164 | 165 | return y -------------------------------------------------------------------------------- /supervoice_hybrid/vocoders.py: -------------------------------------------------------------------------------- 1 | from encodec import EncodecModel 2 | 3 | def load_encodec_encoder(): 4 | encodec_model = EncodecModel.encodec_model_24khz() 5 | encodec_model.set_target_bandwidth(6.0) 6 | return encodec_model 7 | 8 | def load_encodec_decoder_direct(): 9 | encodec_model = EncodecModel.encodec_model_24khz() 10 | encodec_model.set_target_bandwidth(6.0) 11 | return encodec_model -------------------------------------------------------------------------------- /tokenizer_text.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ex3ndr/supervoice-hybrid/b664b91f720180ea9d33dfe1b63fa9b33223f2e4/tokenizer_text.model -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ex3ndr/supervoice-hybrid/b664b91f720180ea9d33dfe1b63fa9b33223f2e4/train/__init__.py -------------------------------------------------------------------------------- /train/dataset.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import json 3 | import math 4 | import random 5 | import torch 6 | 7 | def load_encodec_sampler(index, dir, batch_size, tokenizer = None): 8 | 9 | # Load ids 10 | ids = [] 11 | with gzip.open(index, "r") as f: 12 | for line in f: 13 | cut = json.loads(line) 14 | id = cut["supervisions"][0]["id"] 15 | if id.startswith("small/"): 16 | id = id[len("small/"):] 17 | if id.startswith("medium/"): 18 | id = id[len("medium/"):] 19 | if id.startswith("large/"): 20 | id = id[len("large/"):] 21 | ids.append(id) 22 | 23 | def sample(): 24 | loaded = 0 25 | res_encoded = [] 26 | res_text = [] 27 | while loaded < batch_size: 28 | # Pick ID 29 | id = random.choice(ids) 30 | 31 | try: 32 | 33 | # Load text 34 | with open(dir + id + ".txt", 'r') as file: 35 | text = file.read() 36 | if tokenizer is not None: 37 | text = tokenizer.encode(text) if random.random() < 0.3 else tokenizer.encode_sample(text) # 30% chance of sampling optimal 38 | if text.shape[0] == 0: 39 | raise Exception("Empty file") 40 | 41 | # Load encoded 42 | encoded = torch.load(dir + id + ".pt") 43 | 44 | # Append 45 | res_text.append(text) 46 | res_encoded.append(encoded) 47 | loaded += 1 48 | except: 49 | print("Invalid file: " + id) 50 | pass 51 | 52 | return res_encoded, res_text 53 | 54 | return sample 55 | 56 | def load_spec_sampler(index, dir, batch_size, tokenizer = None): 57 | 58 | # Load ids 59 | ids = [] 60 | with gzip.open(index, "r") as f: 61 | for line in f: 62 | cut = json.loads(line) 63 | id = cut["supervisions"][0]["id"] 64 | if id.startswith("small/"): 65 | id = id[len("small/"):] 66 | if id.startswith("medium/"): 67 | id = id[len("medium/"):] 68 | if id.startswith("large/"): 69 | id = id[len("large/"):] 70 | ids.append(id) 71 | 72 | def sample(): 73 | loaded = 0 74 | res_encoded = [] 75 | res_text = [] 76 | while loaded < batch_size: 77 | # Pick ID 78 | id = random.choice(ids) 79 | 80 | try: 81 | 82 | # Load text 83 | with open(dir + id + ".txt", 'r') as file: 84 | text 
= file.read() 85 | if tokenizer is not None: 86 | text = tokenizer.encode(text) if random.random() < 0.3 else tokenizer.encode_sample(text) # 30% chance of sampling optimal 87 | if text.shape[0] == 0: 88 | raise Exception("Empty file") 89 | 90 | # Load encoded 91 | encoded = torch.load(dir + id + ".pt", map_location = "cpu") 92 | 93 | # Append 94 | res_text.append(text) 95 | res_encoded.append(encoded) 96 | loaded += 1 97 | except Exception as e: 98 | print(e) 99 | print("Invalid file: " + id) 100 | pass 101 | 102 | return res_encoded, res_text 103 | 104 | return sample 105 | 106 | def create_async_loader(sampler, num_workers = 1): 107 | 108 | # Dataset 109 | class AsyncDataset(torch.utils.data.IterableDataset): 110 | def __init__(self, sampler): 111 | self.sampler = sampler 112 | def generate(self): 113 | while True: 114 | yield self.sampler() 115 | def __iter__(self): 116 | return iter(self.generate()) 117 | dataset = AsyncDataset(sampler) 118 | 119 | # Load loader 120 | loader = torch.utils.data.DataLoader(dataset, batch_size = 1, num_workers = num_workers, pin_memory = True, shuffle=False) 121 | 122 | return loader -------------------------------------------------------------------------------- /train/misc.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import matplotlib 4 | import torch 5 | from torch.nn.utils import weight_norm 6 | import matplotlib.pyplot as plt 7 | from IPython.display import Audio, display 8 | import numpy as np 9 | 10 | # 11 | # Plotting 12 | # 13 | 14 | def plot_waveform(waveform, sample_rate=16000, title="Waveform", xlim=(0,5)): 15 | waveform = waveform.numpy() 16 | 17 | num_channels, num_frames = waveform.shape 18 | time_axis = torch.arange(0, num_frames) / sample_rate 19 | 20 | figure, axes = plt.subplots(num_channels, 1) 21 | if num_channels == 1: 22 | axes = [axes] 23 | for c in range(num_channels): 24 | axes[c].plot(time_axis, waveform[c], linewidth=1) 25 | axes[c].grid(True) 26 | if num_channels > 1: 27 | axes[c].set_ylabel(f"Channel {c+1}") 28 | if xlim: 29 | axes[c].set_xlim(xlim) 30 | figure.suptitle(title) 31 | 32 | def plot_specgram(spectrogram, title="Spectrogram"): 33 | _, axis = plt.subplots(1, 1) 34 | axis.imshow(spectrogram, cmap="viridis", vmin=-10, vmax=0, origin="lower", aspect="auto") 35 | axis.set_title(title) 36 | plt.tight_layout() 37 | 38 | # 39 | # Utilities 40 | # 41 | 42 | def exists(val): 43 | return val is not None -------------------------------------------------------------------------------- /train_variant_1.py: -------------------------------------------------------------------------------- 1 | # Ignore warnings 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | 5 | # Base 6 | import itertools 7 | from glob import glob 8 | from tqdm import tqdm 9 | import time 10 | from contextlib import nullcontext 11 | from pathlib import Path 12 | import shutil 13 | import math 14 | import random 15 | 16 | # ML 17 | import torch 18 | import wandb 19 | from accelerate import Accelerator, DistributedDataParallelKwargs 20 | from accelerate.utils import set_seed 21 | from torch.profiler import profile, record_function, ProfilerActivity 22 | 23 | # Local 24 | from supervoice_hybrid import SupervoceVariant1, SentencePieceTextTokenizer 25 | from train.dataset import load_encodec_sampler, create_async_loader 26 | 27 | # Experiment 28 | train_experiment = "var1-2" 29 | train_project="hybrid-var1" 30 | train_auto_resume = True 31 | 32 | # Training schedule and parameters 
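# train_target_batch_size determines the number of gradient-accumulation micro-steps per
# optimizer step (train_grad_accum_every in main() below), while train_batch_size is the
# number of utterances the sampler draws per micro-step; the effective batch per process
# is therefore train_batch_size * train_grad_accum_every.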
33 | train_target_batch_size = 16 34 | train_batch_size = 12 35 | train_mixed_precision = "fp16" # "bf16" or "fp16" or None 36 | train_clip_grad_norm = 1 # Common reproductions are using 100 or 1 37 | train_lr_start = 1e-12 38 | train_lr_max = 5e-4 39 | train_steps = 600000 40 | train_warmup_steps = 32000 # I am using faster warmup - it is more natural for me after working on voicebox 41 | 42 | # Utilities 43 | train_loader_workers = 32 44 | train_log_every = 1 45 | train_save_every = 1000 46 | train_watch_every = 1000 47 | 48 | # 49 | # Factory 50 | # 51 | 52 | def create_sampler(): 53 | tokenizer = SentencePieceTextTokenizer("./tokenizer_text.model") 54 | # train_sampler = load_encodec_sampler("./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./external_datasets/libriheavy-encodec/", train_batch_size, tokenizer) 55 | # train_sampler = load_encodec_sampler("./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./external_datasets/libriheavy-medium-encodec/", train_batch_size, tokenizer) 56 | train_sampler = load_encodec_sampler("./external_datasets/libriheavy/libriheavy_cuts_large.jsonl.gz", "./external_datasets/libriheavy-large-encodec/", train_batch_size, tokenizer) 57 | return train_sampler 58 | 59 | def create_model(): 60 | return SupervoceVariant1() 61 | 62 | def do_train(accelerator, model, inputs): 63 | device = accelerator.device 64 | audio, text = inputs 65 | 66 | # Reshape inputs 67 | condition_text = [] 68 | condition_audio = [] 69 | targets = [] 70 | durations = [] 71 | for B in range(len(audio)): 72 | a = audio[B].squeeze(0)[0] 73 | t = text[B].squeeze(0) 74 | 75 | # Calculate split 76 | min_duration = 0 77 | max_duration = a.shape[0] // 3 78 | audio_split = random.randint(min_duration, max_duration) 79 | 80 | # Append 81 | condition_text.append(t.to(device, non_blocking=True)) 82 | condition_audio.append(a[:audio_split].to(device, non_blocking=True)) 83 | targets.append(a[audio_split:].to(device, non_blocking=True)) 84 | durations.append(a.shape[0]) 85 | 86 | # Forward 87 | _, loss = model( 88 | condition_text = condition_text, 89 | condition_audio = condition_audio, 90 | duration = durations, 91 | target = targets 92 | ) 93 | 94 | return loss 95 | 96 | # 97 | # Train 98 | # 99 | 100 | def main(): 101 | 102 | # Calculate gradient accumulation 103 | train_grad_accum_every = train_target_batch_size 104 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 105 | train_grad_accum_every = math.ceil(train_target_batch_size / torch.cuda.device_count()) 106 | print(f"Running with gradient accumulation every {train_grad_accum_every}") 107 | 108 | # Prepare accelerator 109 | ddp_kwargs = DistributedDataParallelKwargs() 110 | accelerator = Accelerator(log_with="wandb", kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps = train_grad_accum_every, mixed_precision=train_mixed_precision) 111 | device = accelerator.device 112 | output_dir = Path("./output") 113 | output_dir.mkdir(parents=True, exist_ok=True) 114 | dtype = torch.float16 if train_mixed_precision == "fp16" else (torch.bfloat16 if train_mixed_precision == "bf16" else torch.float32) 115 | torch.backends.cuda.matmul.allow_tf32 = True 116 | torch.backends.cudnn.allow_tf32 = True 117 | lr_start = train_lr_start * accelerator.num_processes 118 | lr_max = train_lr_max * accelerator.num_processes 119 | random_suffix = ''.join(random.choices('0123456789abcdef', k=6)) 120 | run_id = f"{train_experiment}-{random_suffix}" 121 | 122 | # Prepare dataset 123 | accelerator.print("Loading 
sampler...") 124 | train_sampler = create_sampler() 125 | train_loader = create_async_loader(train_sampler, num_workers = train_loader_workers) 126 | train_cycle = cycle(train_loader) 127 | 128 | # Model 129 | accelerator.print("Loading model...") 130 | step = 1 131 | model = create_model() 132 | raw_model = model 133 | wd_params, no_wd_params = [], [] 134 | for param in model.parameters(): 135 | param_list = no_wd_params if param.ndim < 2 else wd_params 136 | param_list.append(param) 137 | optim = torch.optim.AdamW([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], train_lr_start, betas=[0.9, 0.95],weight_decay=0.01, eps=1e-6) 138 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max = train_steps) 139 | 140 | # Checkpoint 141 | checkpoint = None 142 | if train_auto_resume and (output_dir / f"{train_experiment}.pt").exists(): 143 | checkpoint = torch.load(str(output_dir / f"{train_experiment}.pt"), map_location="cpu") 144 | step = checkpoint['step'] 145 | run_id = checkpoint['run_id'] 146 | 147 | # Accelerate 148 | model, optim = accelerator.prepare(model, optim) 149 | hps = { 150 | "train_lr_start": train_lr_start, 151 | "train_lr_max": train_lr_max, 152 | "grad_accum_every": train_grad_accum_every, 153 | "steps": train_steps, 154 | "warmup_steps": train_warmup_steps, 155 | "mixed_precision": train_mixed_precision, 156 | "clip_grad_norm": train_clip_grad_norm, 157 | } 158 | accelerator.init_trackers(train_project, config=hps, init_kwargs={"wandb":{"name":run_id, "id": run_id, "resume": "allow"}}) 159 | if accelerator.is_main_process: 160 | wandb.watch(model, log="all", log_freq=train_watch_every * train_grad_accum_every) 161 | 162 | # Save 163 | def save(): 164 | # Save step checkpoint 165 | fname = str(output_dir / f"{train_experiment}.pt") 166 | fname_step = str(output_dir / f"{train_experiment}.{step}.pt") 167 | torch.save({ 168 | 169 | # Model 170 | 'model': raw_model.state_dict(), 171 | 172 | # Optimizer 173 | 'optimizer': optim.state_dict(), 174 | 'scheduler': scheduler.state_dict(), 175 | 'scaler': accelerator.scaler.state_dict(), 176 | 'step': step, 177 | 'run_id': run_id, 178 | 179 | }, fname_step) 180 | 181 | # Overwrite main checkpoint 182 | shutil.copyfile(fname_step, fname) 183 | 184 | # Load 185 | if checkpoint is not None: 186 | raw_model.load_state_dict(checkpoint['model']) 187 | optim.load_state_dict(checkpoint['optimizer']) 188 | scheduler.load_state_dict(checkpoint['scheduler']) 189 | accelerator.scaler.load_state_dict(checkpoint['scaler']) 190 | accelerator. 
print(f'Loaded at #{step}') 191 | 192 | # Train step 193 | def train_step(): 194 | model.train() 195 | 196 | # Update LR 197 | if step < train_warmup_steps: 198 | lr = (lr_start + ((lr_max - lr_start) * step) / train_warmup_steps) 199 | for param_group in optim.param_groups: 200 | param_group['lr'] = lr 201 | lr = lr / accelerator.num_processes 202 | else: 203 | scheduler.step() 204 | lr = scheduler.get_last_lr()[0] / accelerator.num_processes 205 | 206 | # Load batch 207 | for _ in range(train_grad_accum_every): 208 | with accelerator.accumulate(model): 209 | 210 | # Load batch 211 | inputs = next(train_cycle) 212 | 213 | # Do train 214 | with record_function("forward"): 215 | with accelerator.autocast(): 216 | loss = do_train(accelerator, model, inputs) 217 | loss = loss / train_grad_accum_every # Rescale loss 218 | 219 | # Backprop 220 | with record_function("backward"): 221 | optim.zero_grad() 222 | accelerator.backward(loss) 223 | if accelerator.sync_gradients: 224 | accelerator.clip_grad_norm_(model.parameters(), train_clip_grad_norm) 225 | optim.step() 226 | 227 | # Log skipping step 228 | if optim.step_was_skipped: 229 | accelerator.print("Step was skipped") 230 | if torch.isnan(loss): 231 | raise ValueError("Loss is NaN") 232 | 233 | return loss * train_grad_accum_every, lr 234 | 235 | # 236 | # Start Training 237 | # 238 | 239 | accelerator.print("Training started at step", step) 240 | while step < train_steps: 241 | 242 | # Step 243 | start = time.time() 244 | loss, lr = train_step() 245 | end = time.time() 246 | 247 | # Advance 248 | step = step + 1 249 | 250 | # Summary 251 | if step % train_log_every == 0: 252 | accelerator.log({ 253 | "learning_rate": lr, 254 | "loss": loss, 255 | "scale": accelerator.scaler.get_scale() if accelerator.scaler is not None else 1.0 256 | }, step=step) 257 | accelerator.print(f'Step {step} | Loss: {loss} | LR: {lr} | Time: {end - start}') 258 | 259 | # Save 260 | if step % train_save_every == 0: 261 | save() 262 | 263 | # End training 264 | if accelerator.is_main_process: 265 | accelerator.print("Finishing training...") 266 | save() 267 | accelerator.end_training() 268 | accelerator.print('✨ Training complete!') 269 | 270 | # 271 | # Utility 272 | # 273 | 274 | def cycle(dl): 275 | while True: 276 | for data in dl: 277 | yield data 278 | 279 | if __name__ == "__main__": 280 | main() 281 | -------------------------------------------------------------------------------- /train_variant_1.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' 3 | accelerate launch ./train_variant_1.py 4 | # while true; do 5 | # accelerate launch ./train_variant_1.py || true 6 | # done -------------------------------------------------------------------------------- /train_variant_2.py: -------------------------------------------------------------------------------- 1 | # Ignore warnings 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | 5 | # Base 6 | import itertools 7 | from glob import glob 8 | from tqdm import tqdm 9 | import time 10 | from contextlib import nullcontext 11 | from pathlib import Path 12 | import shutil 13 | import math 14 | import random 15 | 16 | # ML 17 | import torch 18 | import wandb 19 | from accelerate import Accelerator, DistributedDataParallelKwargs 20 | from accelerate.utils import set_seed 21 | from torch.profiler import profile, record_function, ProfilerActivity 22 | 23 | # Local 24 | from supervoice_hybrid import 
25 | from supervoice_hybrid.config import config 26 | from train.dataset import load_spec_sampler, create_async_loader 27 | 28 | # Experiment 29 | train_experiment = "var2-1" 30 | train_project = "hybrid-var2" 31 | train_auto_resume = True 32 | 33 | # Training schedule and parameters 34 | train_target_batch_size = 16 35 | train_batch_size = 12 36 | train_mixed_precision = "fp16" # "bf16" or "fp16" or None 37 | train_clip_grad_norm = 1 # Common reproductions use 100 or 1 38 | train_lr_start = 1e-12 39 | train_lr_max = 5e-4 40 | train_steps = 600000 41 | train_warmup_steps = 6000 # I am using a faster warmup - it feels more natural after working on voicebox 42 | train_sigma = 1e-5 43 | 44 | # Utilities 45 | train_loader_workers = 32 46 | train_log_every = 1 47 | train_save_every = 1000 48 | train_watch_every = 1000 49 | 50 | # 51 | # Factory 52 | # 53 | 54 | def create_sampler(): 55 | # tokenizer = UnitTextTokenizer() 56 | tokenizer = SentencePieceTextTokenizer("./tokenizer_text.model") 57 | # train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./processed_datasets/librilight/", train_batch_size, tokenizer) 58 | # train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./processed_datasets/librilight-medium/", train_batch_size, tokenizer) 59 | train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_large.jsonl.gz", "./processed_datasets/librilight-large/", train_batch_size, tokenizer) 60 | return train_sampler 61 | 62 | def create_model(): 63 | return SupervoiceVariant2() 64 | 65 | def do_train(accelerator, model, inputs): 66 | device = accelerator.device 67 | audio_r, text_r = inputs 68 | 69 | # Preprocessing 70 | condition_text = [] 71 | condition_audio = [] 72 | noisy_audio = [] 73 | intervals = [] 74 | times = [] 75 | target = [] 76 | for i in range(train_batch_size): 77 | audio = audio_r[i].squeeze(0).T 78 | text = text_r[i].squeeze(0) 79 | 80 | # Normalize audio 81 | audio = (audio - config.audio.norm_mean) / config.audio.norm_std 82 | 83 | # Prepare time and noisy data 84 | time = random.uniform(0, 1) 85 | noise = torch.randn_like(audio) 86 | noisy = (1 - (1 - train_sigma) * time) * noise + time * audio 87 | target_flow = audio - (1 - train_sigma) * noise 88 | 89 | # Calculate interval 90 | interval_start = random.randint(0, math.floor(audio.shape[0] * 0.3)) 91 | interval_end = random.randint(interval_start + math.floor(audio.shape[0] * 0.7), audio.shape[0]) 92 | 93 | # 20% chance of non-conditional 94 | if random.random() < 0.2: 95 | interval_start = 0 96 | interval_end = audio.shape[0] 97 | text = torch.zeros(1).long() 98 | 99 | # Apply mask 100 | audio[interval_start:interval_end,:] = 0 101 | 102 | # Append 103 | condition_text.append(text.to(device, non_blocking=True)) 104 | condition_audio.append(audio.to(device, non_blocking=True)) 105 | noisy_audio.append(noisy.to(device, non_blocking=True)) 106 | intervals.append([interval_start, interval_end]) 107 | times.append(torch.tensor(time).to(device, non_blocking=True)) 108 | target.append(target_flow[interval_start:interval_end,:].to(device, non_blocking=True)) 109 | 110 | # Forward 111 | _, loss = model( 112 | condition_text = condition_text, 113 | condition_audio = condition_audio, 114 | noisy_audio = noisy_audio, 115 | intervals = intervals, 116 | times = times, 117 | target = target 118 | ) 119 | 120 | return loss 121 | 122 | # 123 | # Train 124 | # 125 | 126 | def main():
127 | 128 | # Calculate gradient accumulation 129 | train_grad_accum_every = train_target_batch_size 130 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 131 | train_grad_accum_every = math.ceil(train_target_batch_size / torch.cuda.device_count()) 132 | print(f"Running with gradient accumulation every {train_grad_accum_every}") 133 | 134 | # Prepare accelerator 135 | ddp_kwargs = DistributedDataParallelKwargs() 136 | accelerator = Accelerator(log_with="wandb", kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps = train_grad_accum_every, mixed_precision=train_mixed_precision) 137 | device = accelerator.device 138 | output_dir = Path("./output") 139 | output_dir.mkdir(parents=True, exist_ok=True) 140 | dtype = torch.float16 if train_mixed_precision == "fp16" else (torch.bfloat16 if train_mixed_precision == "bf16" else torch.float32) 141 | torch.backends.cuda.matmul.allow_tf32 = True 142 | torch.backends.cudnn.allow_tf32 = True 143 | lr_start = train_lr_start * accelerator.num_processes 144 | lr_max = train_lr_max * accelerator.num_processes 145 | random_suffix = ''.join(random.choices('0123456789abcdef', k=6)) 146 | run_id = f"{train_experiment}-{random_suffix}" 147 | 148 | # Prepare dataset 149 | accelerator.print("Loading sampler...") 150 | train_sampler = create_sampler() 151 | train_loader = create_async_loader(train_sampler, num_workers = train_loader_workers) 152 | train_cycle = cycle(train_loader) 153 | 154 | # Model 155 | accelerator.print("Loading model...") 156 | step = 1 157 | model = create_model() 158 | raw_model = model 159 | wd_params, no_wd_params = [], [] 160 | for param in model.parameters(): 161 | param_list = no_wd_params if param.ndim < 2 else wd_params 162 | param_list.append(param) 163 | optim = torch.optim.AdamW([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], train_lr_start, betas=[0.9, 0.95], weight_decay=0.01, eps=1e-6) 164 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max = train_steps) 165 | 166 | # Checkpoint 167 | checkpoint = None 168 | if train_auto_resume and (output_dir / f"{train_experiment}.pt").exists(): 169 | checkpoint = torch.load(str(output_dir / f"{train_experiment}.pt"), map_location="cpu") 170 | step = checkpoint['step'] 171 | run_id = checkpoint['run_id'] 172 | 173 | # Accelerate 174 | model, optim = accelerator.prepare(model, optim) 175 | hps = { 176 | "train_lr_start": train_lr_start, 177 | "train_lr_max": train_lr_max, 178 | "grad_accum_every": train_grad_accum_every, 179 | "steps": train_steps, 180 | "warmup_steps": train_warmup_steps, 181 | "mixed_precision": train_mixed_precision, 182 | "clip_grad_norm": train_clip_grad_norm, 183 | } 184 | accelerator.init_trackers(train_project, config=hps, init_kwargs={"wandb":{"name":run_id, "id": run_id, "resume": "allow"}}) 185 | if accelerator.is_main_process: 186 | wandb.watch(model, log="all", log_freq=train_watch_every * train_grad_accum_every) 187 | 188 | # Save 189 | def save(): 190 | # Save step checkpoint 191 | fname = str(output_dir / f"{train_experiment}.pt") 192 | fname_step = str(output_dir / f"{train_experiment}.{step}.pt") 193 | torch.save({ 194 | 195 | # Model 196 | 'model': raw_model.state_dict(), 197 | 198 | # Optimizer 199 | 'optimizer': optim.state_dict(), 200 | 'scheduler': scheduler.state_dict(), 201 | 'scaler': accelerator.scaler.state_dict(), 202 | 'step': step, 203 | 'run_id': run_id, 204 | 205 | }, fname_step) 206 | 207 | # Overwrite main checkpoint 208 | shutil.copyfile(fname_step, fname) 209 | 210 | # Load
211 | if checkpoint is not None: 212 | raw_model.load_state_dict(checkpoint['model']) 213 | optim.load_state_dict(checkpoint['optimizer']) 214 | scheduler.load_state_dict(checkpoint['scheduler']) 215 | accelerator.scaler.load_state_dict(checkpoint['scaler']) 216 | accelerator.print(f'Loaded at #{step}') 217 | 218 | # Train step 219 | def train_step(): 220 | model.train() 221 | 222 | # Update LR 223 | if step < train_warmup_steps: 224 | lr = (lr_start + ((lr_max - lr_start) * step) / train_warmup_steps) 225 | for param_group in optim.param_groups: 226 | param_group['lr'] = lr 227 | lr = lr / accelerator.num_processes 228 | else: 229 | scheduler.step() 230 | lr = scheduler.get_last_lr()[0] / accelerator.num_processes 231 | 232 | # Load batch 233 | for _ in range(train_grad_accum_every): 234 | with accelerator.accumulate(model): 235 | 236 | # Load batch 237 | inputs = next(train_cycle) 238 | 239 | # Do train 240 | with record_function("forward"): 241 | with accelerator.autocast(): 242 | loss = do_train(accelerator, model, inputs) 243 | loss = loss / train_grad_accum_every # Rescale loss 244 | 245 | # Backprop 246 | with record_function("backward"): 247 | optim.zero_grad() 248 | accelerator.backward(loss) 249 | if accelerator.sync_gradients: 250 | accelerator.clip_grad_norm_(model.parameters(), train_clip_grad_norm) 251 | optim.step() 252 | 253 | # Log skipping step 254 | if optim.step_was_skipped: 255 | accelerator.print("Step was skipped") 256 | if torch.isnan(loss): 257 | raise ValueError("Loss is NaN") 258 | 259 | return loss * train_grad_accum_every, lr 260 | 261 | # 262 | # Start Training 263 | # 264 | 265 | accelerator.print("Training started at step", step) 266 | while step < train_steps: 267 | 268 | # Step 269 | start = time.time() 270 | loss, lr = train_step() 271 | end = time.time() 272 | 273 | # Advance 274 | step = step + 1 275 | 276 | # Summary 277 | if step % train_log_every == 0: 278 | accelerator.log({ 279 | "learning_rate": lr, 280 | "loss": loss, 281 | "scale": accelerator.scaler.get_scale() if accelerator.scaler is not None else 1.0 282 | }, step=step) 283 | accelerator.print(f'Step {step} | Loss: {loss} | LR: {lr} | Time: {end - start}') 284 | 285 | # Save 286 | if step % train_save_every == 0: 287 | save() 288 | 289 | # End training 290 | if accelerator.is_main_process: 291 | accelerator.print("Finishing training...") 292 | save() 293 | accelerator.end_training() 294 | accelerator.print('✨ Training complete!') 295 | 296 | # 297 | # Utility 298 | # 299 | 300 | def cycle(dl): 301 | while True: 302 | for data in dl: 303 | yield data 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /train_variant_2.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' 3 | accelerate launch ./train_variant_2.py 4 | # while true; do 5 | # accelerate launch ./train_variant_2.py || true 6 | # done -------------------------------------------------------------------------------- /train_variant_3.py: -------------------------------------------------------------------------------- 1 | # Ignore warnings 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | 5 | # Base 6 | import itertools 7 | from glob import glob 8 | from tqdm import tqdm 9 | import time 10 | from contextlib import nullcontext 11 | from pathlib import Path 12 | import shutil 13 | import math 14 | import random 15 | 16 | # ML 17 | import torch
18 | import wandb 19 | from accelerate import Accelerator, DistributedDataParallelKwargs 20 | from accelerate.utils import set_seed 21 | from torch.profiler import profile, record_function, ProfilerActivity 22 | 23 | # Local 24 | from supervoice_hybrid import SupervoiceVariant3, SentencePieceTextTokenizer 25 | from supervoice_hybrid.config import config 26 | from train.dataset import load_spec_sampler, create_async_loader 27 | 28 | # Experiment 29 | train_experiment = "var3-1" 30 | train_project = "hybrid-var3" 31 | train_auto_resume = True 32 | 33 | # Training schedule and parameters 34 | train_target_batch_size = 8 35 | train_batch_size = 2 36 | train_mixed_precision = "fp16" # "bf16" or "fp16" or None 37 | train_clip_grad_norm = 1 # Common reproductions use 100 or 1 38 | train_lr_start = 1e-12 39 | train_lr_max = 5e-5 40 | train_steps = 60000 41 | train_warmup_steps = 1000 # I am using a faster warmup - it feels more natural after working on voicebox 42 | train_sigma = 1e-5 43 | 44 | # Utilities 45 | train_loader_workers = 32 46 | train_log_every = 1 47 | train_save_every = 1000 48 | train_watch_every = 1000 49 | 50 | # 51 | # Factory 52 | # 53 | 54 | def create_sampler(): 55 | # tokenizer = UnitTextTokenizer() 56 | tokenizer = SentencePieceTextTokenizer("./tokenizer_text.model") 57 | train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_small.jsonl.gz", "./processed_datasets/librilight/", train_batch_size, tokenizer) 58 | # train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_medium.jsonl.gz", "./processed_datasets/librilight-medium/", train_batch_size, tokenizer) 59 | # train_sampler = load_spec_sampler("./external_datasets/libriheavy/libriheavy_cuts_large.jsonl.gz", "./processed_datasets/librilight-large/", train_batch_size, tokenizer) 60 | return train_sampler 61 | 62 | def create_model(): 63 | flow = torch.hub.load(repo_or_dir='ex3ndr/supervoice-flow', model='flow') 64 | return SupervoiceVariant3(flow) 65 | 66 | def do_train(accelerator, model, inputs): 67 | device = accelerator.device 68 | audio_r, text_r = inputs 69 | 70 | # Preprocessing 71 | condition_text = [] 72 | condition_audio = [] 73 | noisy_audio = [] 74 | intervals = [] 75 | times = [] 76 | target = [] 77 | for i in range(train_batch_size): 78 | audio = audio_r[i].squeeze(0).T 79 | text = text_r[i].squeeze(0) 80 | 81 | # Normalize audio 82 | audio = (audio - config.audio.norm_mean) / config.audio.norm_std 83 | 84 | # Prepare time and noisy data 85 | time = random.uniform(0, 1) 86 | noise = torch.randn_like(audio) 87 | noisy = (1 - (1 - train_sigma) * time) * noise + time * audio 88 | target_flow = audio - (1 - train_sigma) * noise 89 | 90 | # Calculate interval 91 | interval_start = random.randint(0, math.floor(audio.shape[0] * 0.3)) 92 | interval_end = random.randint(interval_start + math.floor(audio.shape[0] * 0.7), audio.shape[0]) 93 | 94 | # 20% chance of non-conditional 95 | if random.random() < 0.2: 96 | interval_start = 0 97 | interval_end = audio.shape[0] 98 | text = torch.zeros(1).long() 99 | 100 | # Apply mask 101 | audio[interval_start:interval_end,:] = 0 102 | 103 | # Append 104 | condition_text.append(text.to(device, non_blocking=True)) 105 | condition_audio.append(audio.to(torch.float16).to(device, non_blocking=True)) 106 | noisy_audio.append(noisy.to(torch.float16).to(device, non_blocking=True)) 107 | intervals.append([interval_start, interval_end]) 108 | times.append(torch.tensor(time).to(device, non_blocking=True)) 109 | target.append(target_flow[interval_start:interval_end,:].to(torch.float16).to(device, non_blocking=True))
110 | 111 | # Forward 112 | _, loss = model( 113 | condition_text = condition_text, 114 | condition_audio = condition_audio, 115 | noisy_audio = noisy_audio, 116 | intervals = intervals, 117 | times = times, 118 | target = target 119 | ) 120 | 121 | return loss 122 | 123 | # 124 | # Train 125 | # 126 | 127 | def main(): 128 | 129 | # Calculate gradient accumulation 130 | train_grad_accum_every = train_target_batch_size 131 | if torch.cuda.is_available() and torch.cuda.device_count() > 1: 132 | train_grad_accum_every = math.ceil(train_target_batch_size / torch.cuda.device_count()) 133 | print(f"Running with gradient accumulation every {train_grad_accum_every}") 134 | 135 | # Prepare accelerator 136 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 137 | accelerator = Accelerator(log_with="wandb", kwargs_handlers=[ddp_kwargs], gradient_accumulation_steps = train_grad_accum_every, mixed_precision=train_mixed_precision) 138 | device = accelerator.device 139 | output_dir = Path("./output") 140 | output_dir.mkdir(parents=True, exist_ok=True) 141 | dtype = torch.float16 if train_mixed_precision == "fp16" else (torch.bfloat16 if train_mixed_precision == "bf16" else torch.float32) 142 | torch.backends.cuda.matmul.allow_tf32 = True 143 | torch.backends.cudnn.allow_tf32 = True 144 | lr_start = train_lr_start * accelerator.num_processes 145 | lr_max = train_lr_max * accelerator.num_processes 146 | random_suffix = ''.join(random.choices('0123456789abcdef', k=6)) 147 | run_id = f"{train_experiment}-{random_suffix}" 148 | 149 | # Prepare dataset 150 | accelerator.print("Loading sampler...") 151 | train_sampler = create_sampler() 152 | train_loader = create_async_loader(train_sampler, num_workers = train_loader_workers) 153 | train_cycle = cycle(train_loader) 154 | 155 | # Model 156 | accelerator.print("Loading model...") 157 | step = 1 158 | model = create_model() 159 | raw_model = model 160 | wd_params, no_wd_params = [], [] 161 | for param in model.parameters(): 162 | param_list = no_wd_params if param.ndim < 2 else wd_params 163 | param_list.append(param) 164 | optim = torch.optim.AdamW([{'params': wd_params}, {'params': no_wd_params, 'weight_decay': 0}], train_lr_start, betas=[0.9, 0.95], weight_decay=0.01, eps=1e-6) 165 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max = train_steps) 166 | 167 | # Checkpoint 168 | checkpoint = None 169 | if train_auto_resume and (output_dir / f"{train_experiment}.pt").exists(): 170 | checkpoint = torch.load(str(output_dir / f"{train_experiment}.pt"), map_location="cpu") 171 | step = checkpoint['step'] 172 | run_id = checkpoint['run_id'] 173 | 174 | # Accelerate 175 | model, optim = accelerator.prepare(model, optim) 176 | hps = { 177 | "train_lr_start": train_lr_start, 178 | "train_lr_max": train_lr_max, 179 | "grad_accum_every": train_grad_accum_every, 180 | "steps": train_steps, 181 | "warmup_steps": train_warmup_steps, 182 | "mixed_precision": train_mixed_precision, 183 | "clip_grad_norm": train_clip_grad_norm, 184 | } 185 | accelerator.init_trackers(train_project, config=hps, init_kwargs={"wandb":{"name":run_id, "id": run_id, "resume": "allow"}}) 186 | if accelerator.is_main_process: 187 | wandb.watch(model, log="all", log_freq=train_watch_every * train_grad_accum_every) 188 | 189 | # Save 190 | def save(): 191 | # Save step checkpoint 192 | fname = str(output_dir / f"{train_experiment}.pt") 193 | fname_step = str(output_dir / f"{train_experiment}.{step}.pt")
str(output_dir / f"{train_experiment}.{step}.pt") 194 | torch.save({ 195 | 196 | # Model 197 | 'model': raw_model.state_dict(), 198 | 199 | # Optimizer 200 | 'optimizer': optim.state_dict(), 201 | 'scheduler': scheduler.state_dict(), 202 | 'scaler': accelerator.scaler.state_dict(), 203 | 'step': step, 204 | 'run_id': run_id, 205 | 206 | }, fname_step) 207 | 208 | # Overwrite main checkpoint 209 | shutil.copyfile(fname_step, fname) 210 | 211 | # Load 212 | if checkpoint is not None: 213 | raw_model.load_state_dict(checkpoint['model']) 214 | optim.load_state_dict(checkpoint['optimizer']) 215 | scheduler.load_state_dict(checkpoint['scheduler']) 216 | accelerator.scaler.load_state_dict(checkpoint['scaler']) 217 | accelerator. print(f'Loaded at #{step}') 218 | 219 | # Train step 220 | def train_step(): 221 | model.train() 222 | 223 | # Update LR 224 | if step < train_warmup_steps: 225 | lr = (lr_start + ((lr_max - lr_start) * step) / train_warmup_steps) 226 | for param_group in optim.param_groups: 227 | param_group['lr'] = lr 228 | lr = lr / accelerator.num_processes 229 | else: 230 | scheduler.step() 231 | lr = scheduler.get_last_lr()[0] / accelerator.num_processes 232 | 233 | # Load batch 234 | last_loss = 0 235 | for _ in range(train_grad_accum_every): 236 | with accelerator.accumulate(model): 237 | 238 | # Load batch 239 | inputs = next(train_cycle) 240 | 241 | # Do train 242 | with record_function("forward"): 243 | with accelerator.autocast(): 244 | loss = do_train(accelerator, model, inputs) 245 | loss = loss / train_grad_accum_every # Rescale loss 246 | 247 | # Backprop 248 | with record_function("backward"): 249 | optim.zero_grad() 250 | accelerator.backward(loss) 251 | if accelerator.sync_gradients: 252 | accelerator.clip_grad_norm_(model.parameters(), train_clip_grad_norm) 253 | optim.step() 254 | 255 | # Log skipping step 256 | if optim.step_was_skipped: 257 | accelerator.print("Step was skipped") 258 | if torch.isnan(loss): 259 | raise ValueError("Loss is NaN") 260 | 261 | # Cleanup 262 | last_loss = loss.detach().cpu().item() 263 | del loss 264 | 265 | return last_loss * train_grad_accum_every, lr 266 | 267 | # 268 | # Start Training 269 | # 270 | 271 | accelerator.print("Training started at step", step) 272 | while step < train_steps: 273 | 274 | # Step 275 | start = time.time() 276 | loss, lr = train_step() 277 | end = time.time() 278 | 279 | # Advance 280 | step = step + 1 281 | 282 | # Summary 283 | if step % train_log_every == 0: 284 | accelerator.log({ 285 | "learning_rate": lr, 286 | "loss": loss, 287 | "scale": accelerator.scaler.get_scale() if accelerator.scaler is not None else 1.0 288 | }, step=step) 289 | accelerator.print(f'Step {step} | Loss: {loss} | LR: {lr} | Time: {end - start}') 290 | 291 | # Save 292 | if step % train_save_every == 0: 293 | save() 294 | 295 | # End training 296 | if accelerator.is_main_process: 297 | accelerator.print("Finishing training...") 298 | save() 299 | accelerator.end_training() 300 | accelerator.print('✨ Training complete!') 301 | 302 | # 303 | # Utility 304 | # 305 | 306 | def cycle(dl): 307 | while True: 308 | for data in dl: 309 | yield data 310 | 311 | if __name__ == "__main__": 312 | main() 313 | -------------------------------------------------------------------------------- /train_variant_3.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512' 3 | accelerate launch ./train_variant_3.py 4 | # while true; do 5 | # accelerate launch 
./train_variant_3.py || true 6 | # done --------------------------------------------------------------------------------
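
Note: the do_train functions in train_variant_2.py and train_variant_3.py above build the same conditional flow-matching pair per sample: a noisy interpolation between Gaussian noise and the normalized mel spectrogram, plus the velocity target the model regresses on the masked interval. Below is a minimal standalone sketch of just that computation; the function name flow_matching_pair and the dummy tensor shapes are illustrative assumptions, not part of the repository.

import random
import torch

def flow_matching_pair(audio, sigma = 1e-5):
    # audio: [frames, n_mels] mel spectrogram, already normalized
    t = random.uniform(0, 1)                            # flow time ~ U(0, 1)
    noise = torch.randn_like(audio)                     # x0 ~ N(0, I)
    noisy = (1 - (1 - sigma) * t) * noise + t * audio   # point on the noise-to-data path at time t
    target = audio - (1 - sigma) * noise                # velocity d(noisy)/dt that the model predicts
    return t, noisy, target

# Example with a dummy 200-frame, 100-bin spectrogram
t, noisy, target = flow_matching_pair(torch.randn(200, 100))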