├── .gitignore ├── LICENSE ├── README.md ├── SoundStorm.py ├── arch.png ├── core ├── __init__.py ├── bidirectional_transformer.py ├── codebook.py ├── conformer.py ├── conformer_layers.py ├── modules.py ├── transformer.py ├── vq_f16.py └── vq_modules.py ├── dataset.py ├── helper.py ├── infer.py ├── lr_schedule.py ├── requirnements.txt ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | data/ 162 | checkpoints/ 163 | samples/ 164 | output/ 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Rishikesh (ऋषिकेश) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SoundStorm: Efficient Parallel Audio Generation 2 | 3 | **Work In Progress ...** 4 | 5 | SoundStorm is a model for efficient, non-autoregressive audio generation. SoundStorm receives as input the semantic tokens of 6 | AudioLM, and relies on bidirectional attention and confidence-based parallel decoding to generate the tokens of a neural audio codec. 
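As a shape-level sketch of how the model in `SoundStorm.py` is driven (this mirrors its `__main__` block; the random tensors stand in for real HuBERT semantic tokens and Encodec acoustic codes, and importing the module currently assumes a CUDA device because it allocates a CUDA tensor at module level):

```
import torch
from SoundStorm import SoundStorm

model = SoundStorm()  # defaults: dim=1024, 16 heads, 12 conformer blocks, 8 acoustic quantizers

semantic = torch.randint(0, 1024, (1, 200)).long()     # [B, T]      semantic token ids
acoustic = torch.randint(0, 1024, (1, 8, 200)).long()  # [B, n_q, T] Encodec codebook ids

# Training objective: one quantizer level is masked at random and predicted.
loss, logits, labels = model(semantic, acoustic)

# Intended inference path (confidence-based parallel decoding); the repo is
# work in progress, so treat this as the target API rather than a guarantee:
# generated = model.generate(semantic, acoustic)   # [B, T, n_q]
```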
7 | 8 | ![](arch.png) 9 | 10 | ## Pre-processing and Training Scripts: 11 | 12 | ### DataSet : 13 | 14 | Pre-processing and Data format follows this: https://huggingface.co/datasets/collabora/whisperspeech 15 | 16 | 17 | 18 | ### Start Training: 19 | ``` 20 | python train.py 21 | ``` 22 | **Semantic token path:** `./data/whisperspeech/whisperspeech/librilight/stoks/` 23 | 24 | **Acoustic token path:** `./data/whisperspeech/whisperspeech/librilight/encodec-6kbps/` 25 | 26 | 27 | 28 | 29 | 30 | 31 | ## References : 32 | 33 | * MaskGIT code : https://github.com/dome272/MaskGIT-pytorch 34 | * SoundStorm : https://github.com/feng-yufei/shared_debugging_code 35 | 36 | -------------------------------------------------------------------------------- /SoundStorm.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import joblib 5 | import numpy as np 6 | import torch 7 | from einops import rearrange, reduce 8 | from einops.layers.torch import Rearrange, EinMix 9 | from encodec import EncodecModel 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from core.conformer import Conformer 13 | 14 | _CONFIDENCE_OF_KNOWN_TOKENS = torch.Tensor([torch.inf]).to("cuda") 15 | 16 | def uniform(shape, min = 0, max = 1, device = None): 17 | return torch.zeros(shape, device = device).float().uniform_(0, 1) 18 | 19 | def cosine_schedule(t): 20 | return torch.cos(t * math.pi * 0.5) 21 | 22 | def gamma_func(t): 23 | return np.cos(t * np.pi / 2) 24 | 25 | 26 | def weights_init(m): 27 | classname = m.__class__.__name__ 28 | if "Linear" in classname or "Embedding" == classname: 29 | #print(f"Initializing Module {classname}.") 30 | nn.init.trunc_normal_(m.weight.data, 0.0, 0.02) 31 | # elif "Parameter" in classname: 32 | # return nn.init.trunc_normal_(m, 0.0, 0.02) 33 | 34 | def top_k(logits, thres=0.9): 35 | k = math.ceil((1 - thres) * logits.shape[-1]) 36 | val, ind = logits.topk(k, dim=-1) 37 | probs = torch.full_like(logits, float('-inf')) 38 | probs.scatter_(2, ind, val) 39 | return probs 40 | 41 | 42 | def weights_init(m): 43 | classname = m.__class__.__name__ 44 | if "Linear" in classname or "Embedding" == classname: 45 | #print(f"Initializing Module {classname}.") 46 | nn.init.trunc_normal_(m.weight.data, 0.0, 0.02) 47 | 48 | 49 | def log(t, eps=1e-10): 50 | return torch.log(t + eps) 51 | 52 | 53 | def gumbel_noise(t): 54 | noise = torch.zeros_like(t).uniform_(0, 1) 55 | return -log(-log(noise)) 56 | 57 | 58 | def gumbel_sample(t, temperature=1., dim=-1): 59 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim=dim) 60 | 61 | 62 | class SoundStorm(nn.Module): 63 | 64 | def __init__(self, dim=1024, heads=16, linear_units=4096, num_blocks=12, semantic_codebook_size=1024, 65 | semantic_num_quantizers=1, acoustic_codebook_size=1024, acoustic_num_quantizers=8, 66 | positionwise_conv_kernel_size=5, encodec=None, hubert_kmean_path=None): 67 | 68 | super().__init__() 69 | num_codes_with_mask = acoustic_codebook_size + 1 70 | sos_token = 0 71 | self.steps = [8, 1, 1, 1, 1, 1, 1, 1] 72 | self.filter_threshold = 0.7 73 | self.temperature = 0.5 74 | 75 | self.ignore_index = 1025 76 | self.n_q = acoustic_num_quantizers 77 | 78 | self.semantic_embeds = nn.Embedding((semantic_codebook_size + 2) * semantic_num_quantizers, dim) 79 | 80 | self.code_embeds = nn.ModuleList( 81 | [ 82 | nn.Embedding(num_codes_with_mask + 2, dim) 83 | for _ in range(acoustic_num_quantizers) 84 | ] 85 | ) 86 | 87 | 88 | self.mask_token_id = 
num_codes_with_mask 89 | self.mask_upper_level = num_codes_with_mask 90 | 91 | self.sos_tokens = sos_token 92 | 93 | self.lm = Conformer( 94 | attention_dim=dim, 95 | attention_heads=heads, 96 | linear_units=linear_units, 97 | num_blocks=num_blocks, 98 | positionwise_conv_kernel_size=positionwise_conv_kernel_size 99 | ) 100 | 101 | self.heads = nn.Sequential( 102 | nn.Linear(dim, dim * acoustic_num_quantizers), 103 | Rearrange('b n (h d) -> b (n h) d', h=acoustic_num_quantizers), 104 | nn.GELU(), 105 | nn.LayerNorm(dim, eps=1e-12), 106 | Rearrange('b (n q) d -> b n q d', q=acoustic_num_quantizers) 107 | ) 108 | 109 | 110 | self.bias = nn.ParameterList([ 111 | nn.Parameter(torch.zeros(num_codes_with_mask + 2)) 112 | for _ in range(acoustic_num_quantizers) 113 | 114 | ] 115 | ) 116 | 117 | self.to_logits = nn.Sequential( 118 | nn.LayerNorm(dim), 119 | EinMix( 120 | 'b n q d -> b n q l', 121 | weight_shape='q d l', 122 | bias_shape='q l', 123 | q=acoustic_num_quantizers, 124 | l=num_codes_with_mask + 2, 125 | d=dim 126 | ) 127 | ) 128 | 129 | # project the dimension of semantic tokens to model dimension 130 | self.sem_cond_proj = nn.Linear(dim, dim) 131 | 132 | self.loss = nn.CrossEntropyLoss(reduction='mean', ignore_index=self.ignore_index) 133 | self.apply(weights_init) 134 | 135 | if encodec is not None: 136 | self._read_embedding_from_encodec(encodec) 137 | 138 | if hubert_kmean_path is not None: 139 | self._read_embedding_from_hubert_kmeans(hubert_kmean_path) 140 | def _read_embedding_from_encodec(self, encodec: EncodecModel): 141 | for i, layer in enumerate(encodec.quantizer.vq.layers[:self.n_q]): 142 | layer_weight = layer.codebook 143 | layer_dim = layer_weight.size(1) 144 | code_per_layer = layer_weight.size(0) 145 | assert code_per_layer == 1024 146 | self.code_embeds[i].weight.data[:code_per_layer, :layer_dim] = layer_weight.clone().data 147 | 148 | def _read_embedding_from_hubert_kmeans(self, km_path: str): 149 | km_model = joblib.load(km_path) 150 | centers = km_model.cluster_centers_.transpose() 151 | centers = torch.tensor(centers, dtype=torch.float32).transpose(0, 1) 152 | self.semantic_embeds.weight.data[:centers.size(0), :centers.size(1)] = centers.clone() 153 | 154 | def level_mask(self, code, seq_len, b, t, device): 155 | rand_times = torch.empty(b, device=device).uniform_(0, 1) 156 | batched_randperm = torch.rand((b, seq_len - t), device=device).argsort(dim=-1).float() 157 | 158 | rand_probs = cosine_schedule(rand_times) 159 | 160 | num_tokens_mask = (rand_probs * (seq_len - t)).clamp(min=1.) 
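        # The cosine schedule turns a uniform draw in [0, 1) into a masking ratio in
        # (0, 1]: draws near 0 mask almost every position after the prompt, draws near 1
        # mask only a handful (never fewer than one, because of the clamp above).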
161 | 162 | mask = batched_randperm < rearrange(num_tokens_mask, 'b -> b 1') 163 | prompt_mask = torch.ones((b, t), device=device).eq(0) 164 | mask = torch.cat([prompt_mask, mask], dim=1) 165 | labels = torch.where(mask, code, self.ignore_index) 166 | 167 | code = torch.where(mask, self.mask_token_id, code) 168 | 169 | return code, labels 170 | 171 | def fine_mask(self, code, t): 172 | code[:, t:] = self.mask_upper_level 173 | return code 174 | 175 | 176 | def masking(self, codes, q=None, t=None): 177 | seq_len = codes.shape[1] 178 | batch = codes.shape[0] 179 | codes = rearrange(codes, 'b n q -> q b n') 180 | 181 | masked_codes = [] 182 | 183 | for i, code in enumerate(codes): 184 | if q == i: 185 | c, label = self.level_mask(code, seq_len, batch, t, codes.device) 186 | masked_codes.append(c) 187 | elif i > q: 188 | masked_codes.append(self.fine_mask(code, t)) 189 | else: 190 | masked_codes.append(code) 191 | 192 | return masked_codes, label 193 | 194 | 195 | 196 | 197 | 198 | def forward(self, cond, codes, return_loss=True): 199 | """ 200 | cond: [B, Len] 201 | codes: [B, N_q, Len] 202 | """ 203 | 204 | b, q, n = codes.shape 205 | 206 | codes = rearrange(codes, 'b q n -> b n q', q=q) 207 | 208 | q = random.randint(0, self.n_q - 1) 209 | t = random.randint(0, codes.shape[1] - 1) 210 | 211 | masked_codes, labels = self.masking(codes, q, t) 212 | 213 | masked_codes = torch.stack(masked_codes, dim=0) 214 | masked_codes = rearrange(masked_codes, 'q b n -> b n q') 215 | 216 | 217 | emb = None 218 | 219 | 220 | for i, layer in enumerate(self.code_embeds): 221 | if emb is None: 222 | emb = layer(masked_codes[:, :, i].unsqueeze(-1)).squeeze(-2) 223 | else: 224 | emb = emb + layer(masked_codes[:, :, i].unsqueeze(-1)).squeeze(-2) 225 | 226 | 227 | semb = self.semantic_embeds(cond) # [B, n, d] 228 | 229 | semb = self.sem_cond_proj(semb) 230 | 231 | # emb = reduce(emb, 'b n q d -> b n d', 'sum') # [B, n, d] 232 | 233 | emb = emb + semb 234 | 235 | out, _ = self.lm(emb, None) # [B, n, d] 236 | 237 | out = self.heads(out) # [B, q*n, d] 238 | 239 | logits = self.to_logits(out) # [B, n, q, d] 240 | 241 | #logits = torch.matmul(out[:, :, q], self.code_embeds[q].weight.T) + self.bias[q] 242 | 243 | if return_loss: 244 | logits = logits[:, :, q] # [B, n, d] 245 | 246 | loss = F.cross_entropy( 247 | rearrange(logits, 'b n c -> b c n'), 248 | labels, 249 | ignore_index = self.ignore_index 250 | ) 251 | 252 | return loss, logits, labels 253 | return logits, out 254 | 255 | 256 | 257 | 258 | def tokens_to_logits(self, semb, input_codes): 259 | # sum the embedding of all (unmasked / masked quantizer layers) [B, n, q] 260 | emb = semb 261 | for i, layer in enumerate(self.code_embeds): 262 | emb = emb + layer(input_codes[:, :, i]) 263 | 264 | out, _ = self.lm(emb, None) # [B, n, d] 265 | out = self.heads(out) # [B, q*n, d] 266 | logits = self.to_logits(out) # [B, n, q, d] 267 | 268 | return logits 269 | 270 | def mask_by_random_topk(self, mask_len, probs, temperature=1.0): 271 | confidence = torch.log(probs) + temperature * torch.distributions.gumbel.Gumbel(0, 1).sample(probs.shape).to("cuda") 272 | sorted_confidence, _ = torch.sort(confidence, dim=-1) 273 | # Obtains cut off threshold given the mask lengths. 274 | cut_off = torch.take_along_dim(sorted_confidence, mask_len.to(torch.long), dim=-1) 275 | # Masks tokens with lower confidence. 
276 | masking = (confidence < cut_off) 277 | return masking 278 | 279 | @torch.no_grad() 280 | def generate(self, conds, codes): 281 | # conds : B, Len 282 | # codes : B, N_q, Len 283 | # clip the first 3 sec of ground truth as prompt, remove rest 284 | # if sample too short, use first half 285 | # currently we assume we know the ground-truth length to generate, needs to be replaced in the future 286 | 287 | num_latents_input = int(conds.size(1)) # Scale by 1.5 because HuBERT is 50Hz, Encodec is 75Hz 288 | num_prompt = min(int(num_latents_input * 0.5), 225) # Default is 3 seconds (3*75Hz = 225 frames) 289 | 290 | b, q, n = codes.shape 291 | 292 | codes = rearrange(codes, 'b q n -> b n q', q=q) 293 | 294 | prompt = codes[:, :num_prompt, :] 295 | device = next(self.lm.parameters()).device 296 | num_latents_to_generate = num_latents_input - num_prompt 297 | batch_size = 1 298 | 299 | acoustic_len = num_latents_input 300 | semantic_len = conds.size(1) 301 | 302 | # upsample sem tokens 303 | semb = self.semantic_embeds(conds) # [B, n, d] 304 | # fetch_idx = torch.arange(0, acoustic_len).to(semb.device) * 2 / 3 305 | # fetch_idx_int = fetch_idx.to(torch.int64).clamp(0, semantic_len - 1) 306 | # fetch_idx_res = fetch_idx - fetch_idx_int 307 | # sem_cond_upscale = semb[:, fetch_idx_int] * (1 - fetch_idx_res).unsqueeze(0).unsqueeze(2) \ 308 | # + semb[:, (fetch_idx_int + 1).clamp(0, semantic_len - 1)] * fetch_idx_res.unsqueeze( 309 | # 0).unsqueeze(2) 310 | semb = self.sem_cond_proj(semb) 311 | 312 | # sequence starts off as all masked 313 | seq_len = num_latents_to_generate 314 | shape = (batch_size, seq_len, 8) 315 | seq = torch.full(shape, self.mask_token_id, device=device) 316 | mask = torch.full(shape, True, device=device) 317 | 318 | # from lucidrain's inference code 319 | for rvq_layer in range(8): 320 | 321 | # Calculate number of tokens to have masked at each time step 322 | iter_steps = self.steps[rvq_layer] 323 | times = torch.linspace(0., 1., iter_steps + 1) 324 | all_mask_num_tokens = (cosine_schedule(times[1:]) * seq_len).long() 325 | 326 | for mask_num_tokens, steps_until_x0 in zip(all_mask_num_tokens.tolist(), reversed(range(iter_steps))): 327 | 328 | logits = self.tokens_to_logits(semb, torch.cat([prompt, seq], dim=1)) 329 | logits = logits.view(batch_size, num_latents_to_generate + num_prompt, 8, 1025) 330 | logits = logits[:, num_prompt:, rvq_layer, 331 | :] # Get the logits we want to consider (post-prompt and on given RVQ layer) 332 | 333 | # Top codebook vector index for each of the timestamps 334 | logits = top_k(logits, 335 | self.filter_threshold) # Remove logits below a certain threshold (convert to -inf) 336 | sampled_ids = gumbel_sample(logits, temperature=max(self.temperature, 1e-3)) 337 | 338 | # Temporarily replace all tokens where mask is still True with sample tokens, will be undone below after mask is recomputed 339 | # Only tokens that are unmasked in the update will be kept 340 | seq[:, :, rvq_layer] = torch.where(mask[:, :, rvq_layer], sampled_ids, seq[:, :, rvq_layer]) 341 | 342 | scores = 1 - logits.softmax(dim=-1) 343 | scores = scores.gather(2, rearrange(sampled_ids, 'b n -> b n 1')) # gather the logits that it sampled 344 | scores = rearrange(scores, 'b n 1 -> b n') 345 | 346 | # No more tokens left to unmask, move to next RVQ layer 347 | if mask_num_tokens == 0: 348 | continue 349 | 350 | # Remove scores corresponding to positions that have already been unmasked 351 | scores = scores.masked_fill(~mask[:, :, rvq_layer], -torch.finfo(scores.dtype).max) 352 | 
353 | # High score = low probability logit value so select the highest `mask_num_tokens` to remain masked after this step 354 | mask_indices = scores.topk(mask_num_tokens, dim=-1).indices 355 | mask[:, :, rvq_layer] = torch.zeros_like(scores, dtype=torch.bool).scatter(1, mask_indices, True) 356 | # Update seq with the newly calculated mask 357 | seq[:, :, rvq_layer] = seq[:, :, rvq_layer].masked_fill(mask[:, :, rvq_layer], self.mask_token_id) 358 | 359 | out = torch.cat([prompt, seq], dim=1) 360 | return out 361 | 362 | 363 | 364 | 365 | 366 | def num_params(model, print_out=True): 367 | parameters = filter(lambda p: p.requires_grad, model.parameters()) 368 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 369 | if print_out: 370 | print("Trainable Parameters: %.3fM" % parameters) 371 | 372 | 373 | if __name__ == '__main__': 374 | 375 | cond = torch.randint(1, 1024, (2, 20)).long() 376 | codes = torch.randint(1, 1024, (2, 8, 20)).long() 377 | 378 | model = SoundStorm() 379 | 380 | num_params(model) 381 | 382 | 383 | logits, out, mask = model(cond, codes) 384 | 385 | 386 | -------------------------------------------------------------------------------- /arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/SoundStorm-pytorch/2fefa0c3805ad6d043ae9ab39af4a71a4e2c2e33/arch.png -------------------------------------------------------------------------------- /core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rishikksh20/SoundStorm-pytorch/2fefa0c3805ad6d043ae9ab39af4a71a4e2c2e33/core/__init__.py -------------------------------------------------------------------------------- /core/bidirectional_transformer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if "Linear" in classname or "Embedding" == classname: 9 | print(f"Initializing Module {classname}.") 10 | nn.init.trunc_normal_(m.weight.data, 0.0, 0.02) 11 | # elif "Parameter" in classname: 12 | # return nn.init.trunc_normal_(m, 0.0, 0.02) 13 | 14 | 15 | class Attention(nn.Module): 16 | """ 17 | Simple Self-Attention algorithm. Potential for optimization using a non-quadratic attention mechanism in complexity. 18 | -> Linformer, Reformer etc. 19 | """ 20 | def __init__(self, dim=768, heads=8): 21 | super(Attention, self).__init__() 22 | d = dim // heads 23 | self.q, self.k, self.v = nn.Linear(dim, d), nn.Linear(dim, d), nn.Linear(dim, d) 24 | self.norm = d ** 0.5 25 | self.dropout = nn.Dropout(p=0.1) 26 | 27 | def forward(self, x): 28 | q, k, v = self.q(x), self.k(x), self.v(x) 29 | qk = torch.softmax(q @ torch.transpose(k, 1, 2) / self.norm, dim=1) 30 | qk = self.dropout(qk) 31 | attn = torch.matmul(qk, v) 32 | return attn 33 | 34 | 35 | class MultiHeadAttention(nn.Module): 36 | """ 37 | Implementation of MultiHeadAttention, splitting it up to multiple Self-Attention layers and concatenating 38 | the results and subsequently running it through one linear layer of same dimension. 
39 | """ 40 | def __init__(self, dim=768, heads=8): 41 | super(MultiHeadAttention, self).__init__() 42 | self.self_attention_heads = nn.ModuleList([Attention(dim, heads) for _ in range(heads)]) 43 | self.projector = nn.Linear(dim, dim) 44 | 45 | def forward(self, x): 46 | for i, sa_head in enumerate(self.self_attention_heads): 47 | if i == 0: 48 | out = sa_head(x) 49 | else: 50 | out = torch.cat((out, sa_head(x)), axis=-1) 51 | out = self.projector(out) 52 | return out 53 | 54 | 55 | class PositionalEmbedding(nn.Module): 56 | def __init__(self, d_model, max_len=512): 57 | super().__init__() 58 | 59 | # Compute the positional encodings once in log space. 60 | pe = torch.zeros(max_len, d_model).float() 61 | pe.require_grad = False 62 | 63 | position = torch.arange(0, max_len).float().unsqueeze(1) 64 | div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() 65 | 66 | pe[:, 0::2] = torch.sin(position * div_term) 67 | pe[:, 1::2] = torch.cos(position * div_term) 68 | 69 | pe = pe.unsqueeze(0) 70 | self.register_buffer('pe', pe) 71 | 72 | def forward(self, x): 73 | return self.pe[:, :x.size(1)] 74 | 75 | 76 | class Encoder(nn.Module): 77 | """ 78 | Transformer encoder using MultiHeadAttention and MLP along with skip connections and LayerNorm 79 | """ 80 | def __init__(self, dim=768, hidden_dim=3072): 81 | super(Encoder, self).__init__() 82 | # self.MultiHeadAttention = MultiHeadAttention(dim) 83 | self.MultiHeadAttention = nn.MultiheadAttention(dim, num_heads=8, batch_first=True, dropout=0.1) 84 | self.LayerNorm1 = nn.LayerNorm(dim, eps=1e-12) 85 | self.LayerNorm2 = nn.LayerNorm(dim, eps=1e-12) 86 | self.MLP = nn.Sequential(*[ 87 | nn.Linear(dim, hidden_dim), 88 | nn.GELU(), 89 | nn.Linear(hidden_dim, dim), 90 | nn.Dropout(p=0.1) 91 | ]) 92 | self.dropout = nn.Dropout(p=0.1) 93 | 94 | def forward(self, x): 95 | # attn = self.MultiHeadAttention(x) 96 | attn, _ = self.MultiHeadAttention(x, x, x, need_weights=False) 97 | attn = self.dropout(attn) 98 | x = x.add(attn) 99 | x = self.LayerNorm1(x) 100 | mlp = self.MLP(x) 101 | x = x.add(mlp) 102 | x = self.LayerNorm2(x) 103 | return x 104 | 105 | 106 | class BidirectionalTransformer(nn.Module): 107 | def __init__(self, args): 108 | super(BidirectionalTransformer, self).__init__() 109 | self.num_image_tokens = args.num_image_tokens 110 | self.tok_emb = nn.Embedding(args.num_codebook_vectors + 2, args.dim) 111 | # self.pos_emb = PositionalEmbedding(args.dim, self.num_image_tokens + 1) 112 | self.pos_emb = nn.init.trunc_normal_(nn.Parameter(torch.zeros(self.num_image_tokens + 1, args.dim)), 0., 0.02) 113 | # self.register_buffer("pos_emb", nn.init.trunc_normal_(nn.Parameter(torch.zeros(1024, args.dim)), 0., 0.02)) 114 | self.blocks = nn.Sequential(*[Encoder(args.dim, args.hidden_dim) for _ in range(args.n_layers)]) 115 | self.Token_Prediction = nn.Sequential(*[ 116 | nn.Linear(in_features=args.dim, out_features=args.dim), 117 | nn.GELU(), 118 | nn.LayerNorm(args.dim, eps=1e-12) 119 | ]) 120 | self.bias = nn.Parameter(torch.zeros(self.num_image_tokens+1, args.num_codebook_vectors + 2)) 121 | self.ln = nn.LayerNorm(args.dim, eps=1e-12) 122 | self.drop = nn.Dropout(p=0.1) 123 | self.apply(weights_init) 124 | 125 | def forward(self, x): 126 | token_embeddings = self.tok_emb(x) 127 | t = token_embeddings.shape[1] 128 | position_embeddings = self.pos_emb[:t, :] 129 | # position_embeddings = self.pos_emb(x) 130 | embed = self.drop(self.ln(token_embeddings + position_embeddings)) 131 | embed = self.blocks(embed) 132 | embed = 
self.Token_Prediction(embed) 133 | logits = torch.matmul(embed, self.tok_emb.weight.T) + self.bias 134 | 135 | return logits 136 | -------------------------------------------------------------------------------- /core/codebook.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Codebook(nn.Module): 6 | """ 7 | Codebook mapping: takes in an encoded image and maps each vector onto its closest codebook vector. 8 | Metric: mean squared error = (z_e - z_q)**2 = (z_e**2) - (2*z_e*z_q) + (z_q**2) 9 | """ 10 | 11 | def __init__(self, args): 12 | super().__init__() 13 | self.num_codebook_vectors = args.num_codebook_vectors 14 | self.latent_dim = args.latent_dim 15 | self.beta = args.beta 16 | 17 | self.embedding = nn.Embedding(self.num_codebook_vectors, self.latent_dim) 18 | self.embedding.weight.data.uniform_(-1.0 / self.num_codebook_vectors, 1.0 / self.num_codebook_vectors) 19 | 20 | def forward(self, z): 21 | z = z.permute(0, 2, 3, 1).contiguous() 22 | z_flattened = z.view(-1, self.latent_dim) 23 | 24 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 25 | torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ 26 | torch.matmul(z_flattened, self.embedding.weight.t()) 27 | 28 | min_encoding_indices = torch.argmin(d, dim=1) 29 | z_q = self.embedding(min_encoding_indices).view(z.shape) 30 | 31 | loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean((z_q - z.detach()) ** 2) 32 | 33 | # preserve gradients 34 | z_q = z + (z_q - z).detach() # moving average instead of hard codebook remapping 35 | 36 | z_q = z_q.permute(0, 3, 1, 2) 37 | 38 | return z_q, min_encoding_indices, loss -------------------------------------------------------------------------------- /core/conformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from core.modules import ConvolutionModule 3 | from core.conformer_layers import EncoderLayer, MultiHeadedAttention 4 | from core.modules import LayerNorm 5 | from core.modules import MultiLayeredConv1d 6 | from core.modules import RotaryEmbedding 7 | from core.modules import Swish 8 | from core.modules import repeat 9 | 10 | 11 | 12 | class Conformer(torch.nn.Module): 13 | """ 14 | Conformer encoder module. 15 | Args: 16 | idim (int): Input dimension. 17 | attention_dim (int): Dimension of attention. 18 | attention_heads (int): The number of heads of multi head attention. 19 | linear_units (int): The number of units of position-wise feed forward. 20 | num_blocks (int): The number of decoder blocks. 21 | dropout_rate (float): Dropout rate. 22 | positional_dropout_rate (float): Dropout rate after adding positional encoding. 23 | attention_dropout_rate (float): Dropout rate in attention. 24 | input_layer (Union[str, torch.nn.Module]): Input layer type. 25 | normalize_before (bool): Whether to use layer_norm before the first block. 26 | concat_after (bool): Whether to concat attention layer's input and output. 27 | if True, additional linear will be applied. 28 | i.e. x -> x + linear(concat(x, att(x))) 29 | if False, no additional linear will be applied. i.e. x -> x + att(x) 30 | positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear". 31 | positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer. 32 | macaron_style (bool): Whether to use macaron style for positionwise layer. 33 | pos_enc_layer_type (str): Conformer positional encoding layer type. 
34 | selfattention_layer_type (str): Conformer attention layer type. 35 | activation_type (str): Conformer activation function type. 36 | use_cnn_module (bool): Whether to use convolution module. 37 | cnn_module_kernel (int): Kernerl size of convolution module. 38 | padding_idx (int): Padding idx for input_layer=embed. 39 | """ 40 | 41 | def __init__(self, attention_dim=256, attention_heads=4, linear_units=2048, num_blocks=6, dropout_rate=0.1, attention_dropout_rate=0.0, 42 | normalize_before=True, concat_after=False, positionwise_conv_kernel_size=1, macaron_style=False, 43 | use_cnn_module=False, cnn_module_kernel=31): 44 | super(Conformer, self).__init__() 45 | 46 | activation = Swish() 47 | self.conv_subsampling_factor = 1 48 | 49 | self.rotary_emb = RotaryEmbedding(attention_dim//attention_heads) 50 | 51 | self.normalize_before = normalize_before 52 | 53 | 54 | 55 | 56 | # self-attention module definition 57 | encoder_selfattn_layer = MultiHeadedAttention 58 | encoder_selfattn_layer_args = (attention_heads, attention_dim, attention_dropout_rate) 59 | 60 | # feed-forward module definition 61 | positionwise_layer = MultiLayeredConv1d 62 | positionwise_layer_args = (attention_dim, linear_units, positionwise_conv_kernel_size, dropout_rate,) 63 | 64 | # convolution module definition 65 | convolution_layer = ConvolutionModule 66 | convolution_layer_args = (attention_dim, cnn_module_kernel, activation) 67 | 68 | self.encoders = repeat(num_blocks, lambda lnum: EncoderLayer(attention_dim, encoder_selfattn_layer(*encoder_selfattn_layer_args), 69 | positionwise_layer(*positionwise_layer_args), 70 | positionwise_layer(*positionwise_layer_args) if macaron_style else None, 71 | convolution_layer(*convolution_layer_args) if use_cnn_module else None, dropout_rate, 72 | normalize_before, concat_after)) 73 | if self.normalize_before: 74 | self.after_norm = LayerNorm(attention_dim) 75 | 76 | 77 | def forward(self, xs, masks, embeds = None): 78 | """ 79 | Encode input sequence. 80 | Args: 81 | utterance_embedding: embedding containing lots of conditioning signals 82 | step: indicator for when to start updating the embedding function 83 | xs (torch.Tensor): Input tensor (#batch, time, idim). 84 | masks (torch.Tensor): Mask tensor (#batch, time). 85 | Returns: 86 | torch.Tensor: Output tensor (#batch, time, attention_dim). 87 | torch.Tensor: Mask tensor (#batch, time). 88 | """ 89 | 90 | rotary_emb = self.rotary_emb(xs.shape[-2]) 91 | 92 | 93 | xs, masks, _, _, _ = self.encoders(xs, masks, None, embeds, rotary_emb) 94 | # if isinstance(xs, tuple): 95 | # xs = xs[0] 96 | 97 | 98 | if self.normalize_before: 99 | xs = self.after_norm(xs) 100 | 101 | return xs, masks -------------------------------------------------------------------------------- /core/conformer_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from core.modules import LayerNorm, apply_rotary_pos_emb 4 | import math 5 | 6 | import numpy 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class MultiHeadedAttention(nn.Module): 12 | """ 13 | Multi-Head Attention layer. 14 | Args: 15 | n_head (int): The number of heads. 16 | n_feat (int): The number of features. 17 | dropout_rate (float): Dropout rate. 18 | """ 19 | 20 | def __init__(self, n_head, n_feat, dropout_rate): 21 | """ 22 | Construct an MultiHeadedAttention object. 
23 | """ 24 | super(MultiHeadedAttention, self).__init__() 25 | assert n_feat % n_head == 0 26 | # We assume d_v always equals d_k 27 | self.d_k = n_feat // n_head 28 | self.h = n_head 29 | self.linear_q = nn.Linear(n_feat, n_feat) 30 | self.linear_k = nn.Linear(n_feat, n_feat) 31 | self.linear_v = nn.Linear(n_feat, n_feat) 32 | self.linear_out = nn.Linear(n_feat, n_feat) 33 | self.attn = None 34 | self.dropout = nn.Dropout(p=dropout_rate) 35 | 36 | def forward_qkv(self, query, key, value): 37 | """ 38 | Transform query, key and value. 39 | Args: 40 | query (torch.Tensor): Query tensor (#batch, time1, size). 41 | key (torch.Tensor): Key tensor (#batch, time2, size). 42 | value (torch.Tensor): Value tensor (#batch, time2, size). 43 | Returns: 44 | torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). 45 | torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). 46 | torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). 47 | """ 48 | n_batch = query.size(0) 49 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) 50 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) 51 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) 52 | q = q.transpose(1, 2) # (batch, head, time1, d_k) 53 | k = k.transpose(1, 2) # (batch, head, time2, d_k) 54 | v = v.transpose(1, 2) # (batch, head, time2, d_k) 55 | 56 | return q, k, v 57 | 58 | def forward_attention(self, value, scores, mask): 59 | """ 60 | Compute attention context vector. 61 | Args: 62 | value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). 63 | scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). 64 | mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). 65 | Returns: 66 | torch.Tensor: Transformed value (#batch, time1, d_model) 67 | weighted by the attention score (#batch, time1, time2). 68 | """ 69 | n_batch = value.size(0) 70 | if mask is not None: 71 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) 72 | min_value = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min) 73 | scores = scores.masked_fill(mask, min_value) 74 | self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) 75 | else: 76 | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) 77 | 78 | p_attn = self.dropout(self.attn) 79 | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) 80 | x = (x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)) # (batch, time1, d_model) 81 | 82 | return self.linear_out(x) # (batch, time1, d_model) 83 | 84 | def forward(self, query, key, value, mask, rotary_emb=None): 85 | """ 86 | Compute scaled dot product attention. 87 | Args: 88 | query (torch.Tensor): Query tensor (#batch, time1, size). 89 | key (torch.Tensor): Key tensor (#batch, time2, size). 90 | value (torch.Tensor): Value tensor (#batch, time2, size). 91 | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or 92 | (#batch, time1, time2). 93 | Returns: 94 | torch.Tensor: Output tensor (#batch, time1, d_model). 95 | """ 96 | q, k, v = self.forward_qkv(query, key, value) 97 | 98 | if rotary_emb is not None: 99 | q = apply_rotary_pos_emb(rotary_emb, q) 100 | k = apply_rotary_pos_emb(rotary_emb, k) 101 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) 102 | return self.forward_attention(v, scores, mask) 103 | 104 | 105 | class EncoderLayer(nn.Module): 106 | """ 107 | Encoder layer module. 108 | Args: 109 | size (int): Input dimension. 
110 | self_attn (torch.nn.Module): Self-attention module instance. 111 | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance 112 | can be used as the argument. 113 | feed_forward (torch.nn.Module): Feed-forward module instance. 114 | `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance 115 | can be used as the argument. 116 | feed_forward_macaron (torch.nn.Module): Additional feed-forward module instance. 117 | `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance 118 | can be used as the argument. 119 | conv_module (torch.nn.Module): Convolution module instance. 120 | `ConvlutionModule` instance can be used as the argument. 121 | dropout_rate (float): Dropout rate. 122 | normalize_before (bool): Whether to use layer_norm before the first block. 123 | concat_after (bool): Whether to concat attention layer's input and output. 124 | if True, additional linear will be applied. 125 | i.e. x -> x + linear(concat(x, att(x))) 126 | if False, no additional linear will be applied. i.e. x -> x + att(x) 127 | """ 128 | 129 | def __init__(self, size, self_attn, feed_forward, feed_forward_macaron, conv_module, dropout_rate, normalize_before=True, concat_after=False, 130 | ): 131 | super(EncoderLayer, self).__init__() 132 | self.self_attn = self_attn 133 | self.feed_forward = feed_forward 134 | self.feed_forward_macaron = feed_forward_macaron 135 | self.conv_module = conv_module 136 | 137 | 138 | self.norm_ff = LayerNorm(size) # for the FNN module 139 | self.norm_mha = LayerNorm(size) # for the MHA module 140 | if feed_forward_macaron is not None: 141 | self.norm_ff_macaron = LayerNorm(size) 142 | self.ff_scale = 0.5 143 | else: 144 | self.ff_scale = 1.0 145 | if self.conv_module is not None: 146 | self.norm_conv = LayerNorm(size) # for the CNN module 147 | self.norm_final = LayerNorm(size) # for the final output of the block 148 | self.dropout = nn.Dropout(dropout_rate) 149 | self.size = size 150 | self.normalize_before = normalize_before 151 | self.concat_after = concat_after 152 | if self.concat_after: 153 | self.concat_linear = nn.Linear(size + size, size) 154 | 155 | def forward(self, x_input, mask, cache=None, embeds = None, rotary_emb=None): 156 | """ 157 | Compute encoded features. 158 | Args: 159 | x_input (Union[Tuple, torch.Tensor]): Input tensor w/ or w/o pos emb. 160 | - w/ pos emb: Tuple of tensors [(#batch, time, size), (1, time, size)]. 161 | - w/o pos emb: Tensor (#batch, time, size). 162 | mask (torch.Tensor): Mask tensor for the input (#batch, time). 163 | cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). 164 | Returns: 165 | torch.Tensor: Output tensor (#batch, time, size). 166 | torch.Tensor: Mask tensor (#batch, time). 
167 | """ 168 | x = x_input 169 | 170 | # whether to use macaron style 171 | if self.feed_forward_macaron is not None: 172 | residual = x 173 | if self.normalize_before: 174 | if self.cond_layer_norm: 175 | x = self.norm_ff_macaron(x, embeds) 176 | else: 177 | x = self.norm_ff_macaron(x) 178 | x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) 179 | if not self.normalize_before: 180 | if self.cond_layer_norm: 181 | x = self.norm_ff_macaron(x, embeds) 182 | else: 183 | x = self.norm_ff_macaron(x) 184 | 185 | # multi-headed self-attention module 186 | residual = x 187 | if self.normalize_before: 188 | x = self.norm_mha(x) 189 | 190 | if cache is None: 191 | x_q = x 192 | else: 193 | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) 194 | x_q = x[:, -1:, :] 195 | residual = residual[:, -1:, :] 196 | mask = None if mask is None else mask[:, -1:, :] 197 | 198 | x_att = self.self_attn(x_q, x, x, mask, rotary_emb=rotary_emb) 199 | 200 | if self.concat_after: 201 | x_concat = torch.cat((x, x_att), dim=-1) 202 | x = residual + self.concat_linear(x_concat) 203 | else: 204 | x = residual + self.dropout(x_att) 205 | if not self.normalize_before: 206 | if self.cond_layer_norm: 207 | x = self.norm_mha(x, embeds) 208 | else: 209 | x = self.norm_mha(x) 210 | 211 | # convolution module 212 | if self.conv_module is not None: 213 | residual = x 214 | if self.normalize_before: 215 | x = self.norm_conv(x) 216 | x = residual + self.dropout(self.conv_module(x)) 217 | if not self.normalize_before: 218 | x = self.norm_conv(x) 219 | 220 | # feed forward module 221 | residual = x 222 | if self.normalize_before: 223 | x = self.norm_ff(x) 224 | x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) 225 | if not self.normalize_before: 226 | x = self.norm_ff(x) 227 | 228 | if self.conv_module is not None: 229 | x = self.norm_final(x) 230 | 231 | if cache is not None: 232 | x = torch.cat([cache, x], dim=1) 233 | 234 | return x, mask, None, embeds, rotary_emb -------------------------------------------------------------------------------- /core/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | # rotary embedding 5 | 6 | class RotaryEmbedding(nn.Module): 7 | def __init__(self, dim, theta = 10000): 8 | super().__init__() 9 | inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) 10 | self.register_buffer("inv_freq", inv_freq, persistent = False) 11 | 12 | @property 13 | def device(self): 14 | return next(self.buffers()).device 15 | 16 | def forward(self, seq_len): 17 | t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq) 18 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 19 | freqs = torch.cat((freqs, freqs), dim = -1) 20 | return freqs 21 | 22 | def rotate_half(x): 23 | x1, x2 = x.chunk(2, dim=-1) 24 | return torch.cat((-x2, x1), dim=-1) 25 | 26 | def apply_rotary_pos_emb(pos, t): 27 | return (t * pos.cos()) + (rotate_half(t) * pos.sin()) 28 | 29 | # conformer 30 | 31 | 32 | class MultiSequential(torch.nn.Sequential): 33 | """Multi-input multi-output torch.nn.Sequential""" 34 | 35 | def forward(self, *args): 36 | for m in self: 37 | args = m(*args) 38 | return args 39 | 40 | 41 | def repeat(N, fn): 42 | """repeat module N times 43 | 44 | :param int N: repeat time 45 | :param function fn: function to generate module 46 | :return: repeated loss 47 | :rtype: MultiSequential 48 | """ 49 | return MultiSequential(*[fn(n) for n in range(N)]) 50 | 51 | 
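# Illustrative shapes for the rotary-embedding helpers above (a sketch, not part
# of the library API): with per-head dimension d_k,
#   rotary = RotaryEmbedding(d_k)
#   pos = rotary(seq_len)                    # (seq_len, d_k) angle table
#   q = apply_rotary_pos_emb(pos, q)         # q, k: (batch, heads, seq_len, d_k)
#   k = apply_rotary_pos_emb(pos, k)
# This is how MultiHeadedAttention in core/conformer_layers.py consumes them.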
class LayerNorm(torch.nn.LayerNorm): 52 | """ 53 | Layer normalization core. 54 | Args: 55 | nout (int): Output dim size. 56 | dim (int): Dimension to be normalized. 57 | """ 58 | 59 | def __init__(self, nout, dim=-1, elementwise_affine=True): 60 | """ 61 | Construct an LayerNorm object. 62 | """ 63 | super(LayerNorm, self).__init__(nout, eps=1e-12, elementwise_affine=elementwise_affine) 64 | self.dim = dim 65 | 66 | def forward(self, x): 67 | """ 68 | Apply layer normalization. 69 | Args: 70 | x (torch.Tensor): Input tensor. 71 | Returns: 72 | torch.Tensor: Normalized tensor. 73 | """ 74 | if self.dim == -1: 75 | return super(LayerNorm, self).forward(x) 76 | return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) 77 | 78 | class LinearNorm(torch.nn.Module): 79 | def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): 80 | super(LinearNorm, self).__init__() 81 | self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) 82 | 83 | torch.nn.init.xavier_uniform_( 84 | self.linear_layer.weight, 85 | gain=torch.nn.init.calculate_gain(w_init_gain)) 86 | 87 | def forward(self, x): 88 | return self.linear_layer(x) 89 | 90 | class MultiLayeredConv1d(torch.nn.Module): 91 | """Multi-layered conv1d for Transformer block. 92 | 93 | This is a module of multi-leyered conv1d designed to replace positionwise feed-forward network 94 | in Transforner block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. 95 | 96 | Args: 97 | in_chans (int): Number of input channels. 98 | hidden_chans (int): Number of hidden channels. 99 | kernel_size (int): Kernel size of conv1d. 100 | dropout_rate (float): Dropout rate. 101 | 102 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: 103 | https://arxiv.org/pdf/1905.09263.pdf 104 | 105 | """ 106 | 107 | def __init__( 108 | self, in_chans: int, hidden_chans: int, kernel_size: int, dropout_rate: float 109 | ): 110 | super(MultiLayeredConv1d, self).__init__() 111 | self.w_1 = torch.nn.Conv1d( 112 | in_chans, 113 | hidden_chans, 114 | kernel_size, 115 | stride=1, 116 | padding=(kernel_size - 1) // 2, 117 | ) 118 | self.w_2 = torch.nn.Conv1d( 119 | hidden_chans, in_chans, 1, stride=1, padding=(1 - 1) // 2 120 | ) 121 | self.dropout = torch.nn.Dropout(dropout_rate) 122 | 123 | def forward(self, x: torch.Tensor) -> torch.Tensor: 124 | """Calculate forward propagation. 125 | 126 | Args: 127 | x (Tensor): Batch of input tensors (B, *, in_chans). 128 | 129 | Returns: 130 | Tensor: Batch of output tensors (B, *, hidden_chans) 131 | 132 | """ 133 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) 134 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) 135 | 136 | 137 | class Swish(torch.nn.Module): 138 | """ 139 | Construct an Swish activation function for Conformer. 140 | """ 141 | 142 | def forward(self, x): 143 | """ 144 | Return Swish activation function. 145 | """ 146 | return x * torch.sigmoid(x) 147 | 148 | class ConvolutionModule(nn.Module): 149 | """ 150 | ConvolutionModule in Conformer model. 151 | 152 | Args: 153 | channels (int): The number of channels of conv layers. 154 | kernel_size (int): Kernel size of conv layers. 
155 | 156 | """ 157 | 158 | def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): 159 | super(ConvolutionModule, self).__init__() 160 | # kernel_size should be an odd number for 'SAME' padding 161 | assert (kernel_size - 1) % 2 == 0 162 | 163 | self.pointwise_conv1 = nn.Conv1d(channels, 2 * channels, kernel_size=1, stride=1, padding=0, bias=bias, ) 164 | self.depthwise_conv = nn.Conv1d(channels, channels, kernel_size, stride=1, padding=(kernel_size - 1) // 2, groups=channels, bias=bias, ) 165 | self.norm = nn.GroupNorm(num_groups=32, num_channels=channels) 166 | self.pointwise_conv2 = nn.Conv1d(channels, channels, kernel_size=1, stride=1, padding=0, bias=bias, ) 167 | self.activation = activation 168 | 169 | def forward(self, x): 170 | """ 171 | Compute convolution module. 172 | 173 | Args: 174 | x (torch.Tensor): Input tensor (#batch, time, channels). 175 | 176 | Returns: 177 | torch.Tensor: Output tensor (#batch, time, channels). 178 | 179 | """ 180 | # exchange the temporal dimension and the feature dimension 181 | x = x.transpose(1, 2) 182 | 183 | # GLU mechanism 184 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 185 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 186 | 187 | # 1D Depthwise Conv 188 | x = self.depthwise_conv(x) 189 | x = self.activation(self.norm(x)) 190 | 191 | x = self.pointwise_conv2(x) 192 | 193 | return x.transpose(1, 2) -------------------------------------------------------------------------------- /core/transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import math 7 | import random 8 | from einops import rearrange 9 | from core.bidirectional_transformer import BidirectionalTransformer 10 | _CONFIDENCE_OF_KNOWN_TOKENS = torch.Tensor([torch.inf]).to("cuda") 11 | 12 | 13 | class VQGANTransformer(nn.Module): 14 | def __init__(self, args): 15 | super().__init__() 16 | self.num_image_tokens = args.num_image_tokens 17 | self.sos_token = args.num_codebook_vectors + 1 18 | self.mask_token_id = args.num_codebook_vectors 19 | self.choice_temperature = 4.5 20 | 21 | self.gamma = self.gamma_func("cosine") 22 | 23 | # self.transformer = BidirectionalTransformer( 24 | # patch_size=8, embed_dim=args.dim, depth=args.n_layers, num_heads=12, mlp_ratio=4, qkv_bias=True, 25 | # norm_layer=partial(nn.LayerNorm, eps=1e-6), vocab_size=8192+1) 26 | self.transformer = BidirectionalTransformer(args) 27 | self.vqgan = self.load_vqgan(args) 28 | print(f"Transformer parameters: {sum([p.numel() for p in self.transformer.parameters()])}") 29 | 30 | def load_checkpoint(self, epoch): 31 | self.load_state_dict(torch.load(os.path.join("checkpoints", f"transformer_epoch_{epoch}.pt"))) 32 | print("Check!") 33 | 34 | def masking(self, codes, q=None, t=None): 35 | 36 | codes = rearrange(codes, 'b q n -> q b n') 37 | q = random.randint(0, codes.size(0) - 1) if q is None else q 38 | t = random.randint(0, codes.shape[-1] - 1) if t is None else t 39 | t_mask = torch.ones(codes.shape) 40 | t_mask[:, :, t:] = 0 41 | t_mask[0:q] = 1 42 | masked_indices = self.mask_token_id * torch.ones_like(codes, device=codes.device) 43 | codes = t_mask * codes + (1 - t_mask) * masked_indices 44 | 45 | indices = codes[q - 1] 46 | 47 | gamma = self.gamma_func() 48 | gammas = gamma(np.random.uniform()) 49 | r = math.floor(gammas * indices.shape[1]) 50 | sample = torch.rand(indices.shape, device=indices.device).topk(r, 
dim=1).indices 51 | mask = torch.zeros(indices.shape, dtype=torch.bool, device=indices.device) 52 | mask.scatter_(dim=1, index=sample, value=True) 53 | mask[:, :t] = True 54 | masked_indices = self.mask_token_id * torch.ones_like(indices, device=indices.device) 55 | codes[q - 1] = mask * indices + (~mask) * masked_indices 56 | 57 | codes = rearrange(codes, 'q b n -> b q n') 58 | 59 | return codes, mask # [B, Q, N+1] 60 | 61 | 62 | 63 | 64 | def forward(self, codes, cond): 65 | # codes : [Batch, N_quantizer, Len] 66 | # cond : [Batch, len] 67 | 68 | 69 | b, Q, n = codes.shape 70 | sos_tokens = torch.ones([b, Q+1], dtype=torch.long, device=codes.device) * self.sos_token 71 | a_indices, masks = self.masking(codes) 72 | a_indices = torch.cat((cond.unsqueeze(1), a_indices), dim=1) 73 | 74 | a_indices = torch.cat((sos_tokens.unsqueeze(-1), a_indices), dim=-1) 75 | 76 | logits = self.transformer(a_indices) 77 | 78 | # 3. Loss is only calculated on q level with masked token only (~mask) 79 | 80 | return logits, masks 81 | 82 | def top_k_logits(self, logits, k): 83 | v, ix = torch.topk(logits, k) 84 | out = logits.clone() 85 | if k == 0: 86 | out[:, :] = self.sos_token 87 | else: 88 | out[out < v[..., [-1]]] = self.sos_token 89 | return out 90 | 91 | def gamma_func(self, mode="cosine"): 92 | if mode == "linear": 93 | return lambda r: 1 - r 94 | elif mode == "cosine": 95 | return lambda r: np.cos(r * np.pi / 2) 96 | elif mode == "square": 97 | return lambda r: 1 - r ** 2 98 | elif mode == "cubic": 99 | return lambda r: 1 - r ** 3 100 | else: 101 | raise NotImplementedError 102 | 103 | def create_input_tokens_normal(self, num, label=None): 104 | # label_tokens = label * torch.ones([num, 1]) 105 | # Shift the label by codebook_size 106 | # label_tokens = label_tokens + self.vqgan.codebook.num_codebook_vectors 107 | # Create blank masked tokens 108 | blank_tokens = torch.ones((num, self.num_image_tokens), device="cuda") 109 | masked_tokens = self.mask_token_id * blank_tokens 110 | # Concatenate the two as input_tokens 111 | # input_tokens = torch.concat([label_tokens, masked_tokens], dim=-1) 112 | # return input_tokens.to(torch.int32) 113 | return masked_tokens.to(torch.int64) 114 | 115 | def tokens_to_logits(self, seq): 116 | logits = self.transformer(seq) 117 | # logits = logits[..., :self.vqgan.codebook.num_codebook_vectors] # why is maskgit returning [8, 257, 2025]? 118 | return logits 119 | 120 | def mask_by_random_topk(self, mask_len, probs, temperature=1.0): 121 | confidence = torch.log(probs) + temperature * torch.distributions.gumbel.Gumbel(0, 1).sample(probs.shape).to("cuda") 122 | sorted_confidence, _ = torch.sort(confidence, dim=-1) 123 | # Obtains cut off threshold given the mask lengths. 124 | cut_off = torch.take_along_dim(sorted_confidence, mask_len.to(torch.long), dim=-1) 125 | # Masks tokens with lower confidence. 
126 | masking = (confidence < cut_off) 127 | return masking 128 | 129 | @torch.no_grad() 130 | def sample_good(self, inputs=None, num=1, T=11, mode="cosine"): 131 | # self.transformer.eval() 132 | N = self.num_image_tokens 133 | if inputs is None: 134 | inputs = self.create_input_tokens_normal(num) 135 | else: 136 | inputs = torch.hstack( 137 | (inputs, torch.zeros((inputs.shape[0], N - inputs.shape[1]), device="cuda", dtype=torch.int).fill_(self.mask_token_id))) 138 | 139 | sos_tokens = torch.ones(inputs.shape[0], 1, dtype=torch.long, device=inputs.device) * self.sos_token 140 | inputs = torch.cat((sos_tokens, inputs), dim=1) 141 | 142 | unknown_number_in_the_beginning = torch.sum(inputs == self.mask_token_id, dim=-1) 143 | gamma = self.gamma_func(mode) 144 | cur_ids = inputs # [8, 257] 145 | for t in range(T): 146 | logits = self.tokens_to_logits(cur_ids) # call transformer to get predictions [8, 257, 1024] 147 | sampled_ids = torch.distributions.categorical.Categorical(logits=logits).sample() 148 | 149 | unknown_map = (cur_ids == self.mask_token_id) # which tokens need to be sampled -> bool [8, 257] 150 | sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids) # replace all -1 with their samples and leave the others untouched [8, 257] 151 | 152 | ratio = 1. * (t + 1) / T # just a percentage e.g. 1 / 12 153 | mask_ratio = gamma(ratio) 154 | 155 | probs = F.softmax(logits, dim=-1) # convert logits into probs [8, 257, 1024] 156 | selected_probs = torch.squeeze(torch.take_along_dim(probs, torch.unsqueeze(sampled_ids, -1), -1), -1) # get probability for selected tokens in categorical call, also for already sampled ones [8, 257] 157 | 158 | selected_probs = torch.where(unknown_map, selected_probs, _CONFIDENCE_OF_KNOWN_TOKENS) # ignore tokens which are already sampled [8, 257] 159 | 160 | mask_len = torch.unsqueeze(torch.floor(unknown_number_in_the_beginning * mask_ratio), 1) # floor(256 * 0.99) = 254 --> [254, 254, 254, 254, ....] 161 | mask_len = torch.maximum(torch.zeros_like(mask_len), torch.minimum(torch.sum(unknown_map, dim=-1, keepdim=True)-1, mask_len)) # add -1 later when conditioning and also ones_like. Zeroes just because we have no cond token 162 | # max(1, min(how many unknown tokens, how many tokens we want to sample)) 163 | 164 | # Adds noise for randomness 165 | masking = self.mask_by_random_topk(mask_len, selected_probs, temperature=self.choice_temperature * (1. - ratio)) 166 | # Masks tokens with lower confidence. 
167 | cur_ids = torch.where(masking, self.mask_token_id, sampled_ids) 168 | # print((cur_ids == 8192).count_nonzero()) 169 | 170 | # self.transformer.train() 171 | return cur_ids[:, 1:] 172 | 173 | @torch.no_grad() 174 | def log_images(self, x, mode="cosine"): 175 | log = dict() 176 | 177 | _, z_indices = self.encode_to_z(x) 178 | 179 | # create new sample 180 | index_sample = self.sample_good(mode=mode) 181 | x_new = self.indices_to_image(index_sample) 182 | 183 | # create a "half" sample 184 | z_start_indices = z_indices[:, :z_indices.shape[1] // 2] 185 | half_index_sample = self.sample_good(z_start_indices, mode=mode) 186 | x_sample = self.indices_to_image(half_index_sample) 187 | 188 | # create reconstruction 189 | x_rec = self.indices_to_image(z_indices) 190 | 191 | log["input"] = x 192 | log["rec"] = x_rec 193 | log["half_sample"] = x_sample 194 | log["new_sample"] = x_new 195 | return log, torch.concat((x, x_rec, x_sample, x_new)) 196 | 197 | def indices_to_image(self, indices, p1=32, p2=32): 198 | ix_to_vectors = self.vqgan.codebook.embedding(indices).reshape(indices.shape[0], p1, p2, 32) 199 | # ix_to_vectors = self.vqgan.quantize.embedding(indices).reshape(indices.shape[0], 16, 16, 256) 200 | ix_to_vectors = ix_to_vectors.permute(0, 3, 1, 2) 201 | image = self.vqgan.decode(ix_to_vectors) 202 | return image 203 | 204 | @staticmethod 205 | def create_masked_image(image: torch.Tensor, x_start: int = 100, y_start: int = 100, size: int = 50): 206 | mask = torch.ones_like(image, dtype=torch.int) 207 | mask[:, :, x_start:x_start + size, y_start:y_start + size] = 0 208 | return image * mask, mask 209 | 210 | def inpainting(self, image: torch.Tensor, x_start: int = 100, y_start: int = 100, size: int = 50): 211 | # Note: this function probably doesnt work yet lol 212 | # apply mask on image 213 | masked_image, mask = self.create_masked_image(image, x_start, y_start, size) 214 | 215 | # encode masked image 216 | # _, indices = self.encode_to_z(masked_image) 217 | indices = torch.randint(1024, (1, 256), dtype=torch.int) 218 | mask = mask[:, 0, :, :] 219 | 220 | # set masked patches to be 0 -> so that the sampling part only samples indices for these patches 221 | # 1. idea: just calculate the ratio between 256x256 image and 16x16 latent image and set the area 222 | # which was masked in the original image to 0 in the encoded image 223 | # 2. idea: check if patches which were masked in the original image are always the same in the latent space 224 | # If so: set these to 0 225 | p = 16 226 | patched_mask = mask.unfold(2, p, p).unfold(1, p, p) 227 | patched_mask = torch.transpose(patched_mask, 3, 4) 228 | patched_mask = patched_mask.permute(1, 2, 0, 3, 4) 229 | patched_mask = patched_mask.contiguous().view(patched_mask.size(0) * patched_mask.size(1), 230 | -1) # 256 x 256 i.e. 
16x16 x 256 231 | 232 | indices_mask, _ = torch.min(patched_mask, dim=-1) 233 | indices = indices_mask * indices 234 | 235 | # inpaint the image by using the sample method and provide the masked image indices and condition 236 | sampled_indices = self.sample(indices) 237 | 238 | # reconstruct inpainted image 239 | inpainted_image = self.indices_to_image(sampled_indices) 240 | 241 | # linearly blend the input image and inpainted image at border of mask (to avoid sharp edges at border of mask) 242 | indices_mask = indices_mask.reshape(1, 1, 16, 16).type(torch.float) 243 | upsampled_indices_mask = F.interpolate(indices_mask, scale_factor=16).squeeze(0) 244 | intra = torch.where(mask != upsampled_indices_mask, 1, 0) 245 | 246 | # define mask for blending 247 | n = 128 248 | base = torch.arange(n).view(1, -1).max(torch.arange(n).view(-1, 1)) 249 | right = torch.stack((torch.rot90(base, 1, [0, 1]), base)).reshape(n * 2, n) 250 | left = torch.stack((torch.rot90(base, 2, [0, 1]), torch.rot90(base, 3, [0, 1]))).reshape(n * 2, n) 251 | full = torch.cat((left, right), 1) 252 | 253 | # construct opacity matrix for intra region 254 | min_blend = torch.min(torch.where(intra == 1, full, 1000000)) 255 | max_blend = torch.max(torch.where(intra == 1, full, -1000000)) 256 | mask_blend = torch.where(intra == 1, (full - min_blend) / max_blend, torch.ones_like(intra, dtype=torch.float)) 257 | 258 | mask_real = torch.where(mask == 0, mask.type(torch.float), mask_blend) 259 | mask_fake = torch.where(mask == 0, (1 - mask).type(torch.float), mask_blend) 260 | 261 | blended_image = mask_real * image + mask_fake * inpainted_image 262 | 263 | return blended_image, inpainted_image 264 | 265 | -------------------------------------------------------------------------------- /core/vq_f16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from vq_modules import Encoder, Decoder 4 | from vq_modules import VectorQuantizer2 as VectorQuantizer 5 | 6 | 7 | class VQModel(nn.Module): 8 | def __init__(self, ckpt_path=None): 9 | super().__init__() 10 | ddconfig = {'double_z': False, 'z_channels': 256, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 1, 2, 2, 4], 'num_res_blocks': 2, 'attn_resolutions': [16], 'dropout': 0.0} 11 | embed_dim = 256 12 | n_embed = 1024 13 | self.encoder = Encoder(**ddconfig) 14 | self.decoder = Decoder(**ddconfig) 15 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25) 16 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 17 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 18 | if ckpt_path is not None: 19 | self.init_from_ckpt(ckpt_path) 20 | 21 | def init_from_ckpt(self, path): 22 | sd = torch.load(path, map_location="cpu") 23 | self.load_state_dict(sd, strict=False) 24 | print(f"Restored from {path}") 25 | 26 | def encode(self, x): 27 | h = self.encoder(x) 28 | h = self.quant_conv(h) 29 | quant, emb_loss, info = self.quantize(h) 30 | return quant, emb_loss, info 31 | 32 | def decode(self, quant): 33 | quant = self.post_quant_conv(quant) 34 | dec = self.decoder(quant) 35 | return dec 36 | 37 | def decode_code(self, code_b): 38 | quant_b = self.quantize.embed_code(code_b) 39 | dec = self.decode(quant_b) 40 | return dec 41 | 42 | def forward(self, input): 43 | quant, diff, _ = self.encode(input) 44 | dec = self.decode(quant) 45 | return dec, diff 46 | 47 | 48 | # RuDalle image pos embeddings 49 | 50 | def 
get_image_pos_embeddings(self, image_input_ids, past_length=0): 51 | input_shape = image_input_ids.size() 52 | row_ids = torch.arange(past_length, input_shape[-1] + past_length, 53 | dtype=torch.long, device=self.device) // self.image_tokens_per_dim 54 | row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1]) 55 | col_ids = torch.arange(past_length, input_shape[-1] + past_length, 56 | dtype=torch.long, device=self.device) % self.image_tokens_per_dim 57 | col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1]) 58 | return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids) -------------------------------------------------------------------------------- /core/vq_modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from einops import rearrange 7 | 8 | 9 | def get_timestep_embedding(timesteps, embedding_dim): 10 | """ 11 | This matches the implementation in Denoising Diffusion Probabilistic Models: 12 | From Fairseq. 13 | Build sinusoidal embeddings. 14 | This matches the implementation in tensor2tensor, but differs slightly 15 | from the description in Section 3.5 of "Attention Is All You Need". 16 | """ 17 | assert len(timesteps.shape) == 1 18 | 19 | half_dim = embedding_dim // 2 20 | emb = math.log(10000) / (half_dim - 1) 21 | emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) 22 | emb = emb.to(device=timesteps.device) 23 | emb = timesteps.float()[:, None] * emb[None, :] 24 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 25 | if embedding_dim % 2 == 1: # zero pad 26 | emb = torch.nn.functional.pad(emb, (0,1,0,0)) 27 | return emb 28 | 29 | 30 | def nonlinearity(x): 31 | # swish 32 | return x*torch.sigmoid(x) 33 | 34 | 35 | def Normalize(in_channels): 36 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 37 | 38 | 39 | class Upsample(nn.Module): 40 | def __init__(self, in_channels, with_conv): 41 | super().__init__() 42 | self.with_conv = with_conv 43 | if self.with_conv: 44 | self.conv = torch.nn.Conv2d(in_channels, 45 | in_channels, 46 | kernel_size=3, 47 | stride=1, 48 | padding=1) 49 | 50 | def forward(self, x): 51 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 52 | if self.with_conv: 53 | x = self.conv(x) 54 | return x 55 | 56 | 57 | class Downsample(nn.Module): 58 | def __init__(self, in_channels, with_conv): 59 | super().__init__() 60 | self.with_conv = with_conv 61 | if self.with_conv: 62 | # no asymmetric padding in torch conv, must do it ourselves 63 | self.conv = torch.nn.Conv2d(in_channels, 64 | in_channels, 65 | kernel_size=3, 66 | stride=2, 67 | padding=0) 68 | 69 | def forward(self, x): 70 | if self.with_conv: 71 | pad = (0,1,0,1) 72 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 73 | x = self.conv(x) 74 | else: 75 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 76 | return x 77 | 78 | 79 | class ResnetBlock(nn.Module): 80 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 81 | dropout, temb_channels=512): 82 | super().__init__() 83 | self.in_channels = in_channels 84 | out_channels = in_channels if out_channels is None else out_channels 85 | self.out_channels = out_channels 86 | self.use_conv_shortcut = conv_shortcut 87 | 88 | self.norm1 = Normalize(in_channels) 89 | self.conv1 = torch.nn.Conv2d(in_channels, 90 | out_channels, 91 | 
kernel_size=3, 92 | stride=1, 93 | padding=1) 94 | if temb_channels > 0: 95 | self.temb_proj = torch.nn.Linear(temb_channels, 96 | out_channels) 97 | self.norm2 = Normalize(out_channels) 98 | self.dropout = torch.nn.Dropout(dropout) 99 | self.conv2 = torch.nn.Conv2d(out_channels, 100 | out_channels, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1) 104 | if self.in_channels != self.out_channels: 105 | if self.use_conv_shortcut: 106 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 107 | out_channels, 108 | kernel_size=3, 109 | stride=1, 110 | padding=1) 111 | else: 112 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 113 | out_channels, 114 | kernel_size=1, 115 | stride=1, 116 | padding=0) 117 | 118 | def forward(self, x, temb): 119 | h = x 120 | h = self.norm1(h) 121 | h = nonlinearity(h) 122 | h = self.conv1(h) 123 | 124 | if temb is not None: 125 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 126 | 127 | h = self.norm2(h) 128 | h = nonlinearity(h) 129 | h = self.dropout(h) 130 | h = self.conv2(h) 131 | 132 | if self.in_channels != self.out_channels: 133 | if self.use_conv_shortcut: 134 | x = self.conv_shortcut(x) 135 | else: 136 | x = self.nin_shortcut(x) 137 | 138 | return x+h 139 | 140 | 141 | class AttnBlock(nn.Module): 142 | def __init__(self, in_channels): 143 | super().__init__() 144 | self.in_channels = in_channels 145 | 146 | self.norm = Normalize(in_channels) 147 | self.q = torch.nn.Conv2d(in_channels, 148 | in_channels, 149 | kernel_size=1, 150 | stride=1, 151 | padding=0) 152 | self.k = torch.nn.Conv2d(in_channels, 153 | in_channels, 154 | kernel_size=1, 155 | stride=1, 156 | padding=0) 157 | self.v = torch.nn.Conv2d(in_channels, 158 | in_channels, 159 | kernel_size=1, 160 | stride=1, 161 | padding=0) 162 | self.proj_out = torch.nn.Conv2d(in_channels, 163 | in_channels, 164 | kernel_size=1, 165 | stride=1, 166 | padding=0) 167 | 168 | 169 | def forward(self, x): 170 | h_ = x 171 | h_ = self.norm(h_) 172 | q = self.q(h_) 173 | k = self.k(h_) 174 | v = self.v(h_) 175 | 176 | # compute attention 177 | b,c,h,w = q.shape 178 | q = q.reshape(b,c,h*w) 179 | q = q.permute(0,2,1) # b,hw,c 180 | k = k.reshape(b,c,h*w) # b,c,hw 181 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 182 | w_ = w_ * (int(c)**(-0.5)) 183 | w_ = torch.nn.functional.softmax(w_, dim=2) 184 | 185 | # attend to values 186 | v = v.reshape(b,c,h*w) 187 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 188 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 189 | h_ = h_.reshape(b,c,h,w) 190 | 191 | h_ = self.proj_out(h_) 192 | 193 | return x+h_ 194 | 195 | 196 | class Encoder(nn.Module): 197 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 198 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 199 | resolution, z_channels, double_z=True, **ignore_kwargs): 200 | super().__init__() 201 | self.ch = ch 202 | self.temb_ch = 0 203 | self.num_resolutions = len(ch_mult) 204 | self.num_res_blocks = num_res_blocks 205 | self.resolution = resolution 206 | self.in_channels = in_channels 207 | 208 | # downsampling 209 | self.conv_in = torch.nn.Conv2d(in_channels, 210 | self.ch, 211 | kernel_size=3, 212 | stride=1, 213 | padding=1) 214 | 215 | curr_res = resolution 216 | in_ch_mult = (1,)+tuple(ch_mult) 217 | self.down = nn.ModuleList() 218 | for i_level in range(self.num_resolutions): 219 | block = nn.ModuleList() 220 | attn = nn.ModuleList() 221 | block_in = ch*in_ch_mult[i_level] 222 | block_out = 
ch*ch_mult[i_level] 223 | for i_block in range(self.num_res_blocks): 224 | block.append(ResnetBlock(in_channels=block_in, 225 | out_channels=block_out, 226 | temb_channels=self.temb_ch, 227 | dropout=dropout)) 228 | block_in = block_out 229 | if curr_res in attn_resolutions: 230 | attn.append(AttnBlock(block_in)) 231 | down = nn.Module() 232 | down.block = block 233 | down.attn = attn 234 | if i_level != self.num_resolutions-1: 235 | down.downsample = Downsample(block_in, resamp_with_conv) 236 | curr_res = curr_res // 2 237 | self.down.append(down) 238 | 239 | # middle 240 | self.mid = nn.Module() 241 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 242 | out_channels=block_in, 243 | temb_channels=self.temb_ch, 244 | dropout=dropout) 245 | self.mid.attn_1 = AttnBlock(block_in) 246 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 247 | out_channels=block_in, 248 | temb_channels=self.temb_ch, 249 | dropout=dropout) 250 | 251 | # end 252 | self.norm_out = Normalize(block_in) 253 | self.conv_out = torch.nn.Conv2d(block_in, 254 | 2*z_channels if double_z else z_channels, 255 | kernel_size=3, 256 | stride=1, 257 | padding=1) 258 | 259 | 260 | def forward(self, x): 261 | #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) 262 | 263 | # timestep embedding 264 | temb = None 265 | 266 | # downsampling 267 | hs = [self.conv_in(x)] 268 | for i_level in range(self.num_resolutions): 269 | for i_block in range(self.num_res_blocks): 270 | h = self.down[i_level].block[i_block](hs[-1], temb) 271 | if len(self.down[i_level].attn) > 0: 272 | h = self.down[i_level].attn[i_block](h) 273 | hs.append(h) 274 | if i_level != self.num_resolutions-1: 275 | hs.append(self.down[i_level].downsample(hs[-1])) 276 | 277 | # middle 278 | h = hs[-1] 279 | h = self.mid.block_1(h, temb) 280 | h = self.mid.attn_1(h) 281 | h = self.mid.block_2(h, temb) 282 | 283 | # end 284 | h = self.norm_out(h) 285 | h = nonlinearity(h) 286 | h = self.conv_out(h) 287 | return h 288 | 289 | 290 | class Decoder(nn.Module): 291 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 292 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 293 | resolution, z_channels, give_pre_end=False, **ignorekwargs): 294 | super().__init__() 295 | self.ch = ch 296 | self.temb_ch = 0 297 | self.num_resolutions = len(ch_mult) 298 | self.num_res_blocks = num_res_blocks 299 | self.resolution = resolution 300 | self.in_channels = in_channels 301 | self.give_pre_end = give_pre_end 302 | 303 | # compute in_ch_mult, block_in and curr_res at lowest res 304 | in_ch_mult = (1,)+tuple(ch_mult) 305 | block_in = ch*ch_mult[self.num_resolutions-1] 306 | curr_res = resolution // 2**(self.num_resolutions-1) 307 | self.z_shape = (1,z_channels,curr_res,curr_res) 308 | #print("Working with z of shape {} = {} dimensions.".format( 309 | # self.z_shape, np.prod(self.z_shape))) 310 | 311 | # z to block_in 312 | self.conv_in = torch.nn.Conv2d(z_channels, 313 | block_in, 314 | kernel_size=3, 315 | stride=1, 316 | padding=1) 317 | 318 | # middle 319 | self.mid = nn.Module() 320 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 321 | out_channels=block_in, 322 | temb_channels=self.temb_ch, 323 | dropout=dropout) 324 | self.mid.attn_1 = AttnBlock(block_in) 325 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 326 | out_channels=block_in, 327 | temb_channels=self.temb_ch, 328 | dropout=dropout) 329 | 330 | # upsampling 331 | self.up = nn.ModuleList() 332 | for i_level in 
reversed(range(self.num_resolutions)): 333 | block = nn.ModuleList() 334 | attn = nn.ModuleList() 335 | block_out = ch*ch_mult[i_level] 336 | for i_block in range(self.num_res_blocks+1): 337 | block.append(ResnetBlock(in_channels=block_in, 338 | out_channels=block_out, 339 | temb_channels=self.temb_ch, 340 | dropout=dropout)) 341 | block_in = block_out 342 | if curr_res in attn_resolutions: 343 | attn.append(AttnBlock(block_in)) 344 | up = nn.Module() 345 | up.block = block 346 | up.attn = attn 347 | if i_level != 0: 348 | up.upsample = Upsample(block_in, resamp_with_conv) 349 | curr_res = curr_res * 2 350 | self.up.insert(0, up) # prepend to get consistent order 351 | 352 | # end 353 | self.norm_out = Normalize(block_in) 354 | self.conv_out = torch.nn.Conv2d(block_in, 355 | out_ch, 356 | kernel_size=3, 357 | stride=1, 358 | padding=1) 359 | 360 | def forward(self, z): 361 | #assert z.shape[1:] == self.z_shape[1:] 362 | self.last_z_shape = z.shape 363 | 364 | # timestep embedding 365 | temb = None 366 | 367 | # z to block_in 368 | h = self.conv_in(z) 369 | 370 | # middle 371 | h = self.mid.block_1(h, temb) 372 | h = self.mid.attn_1(h) 373 | h = self.mid.block_2(h, temb) 374 | 375 | # upsampling 376 | for i_level in reversed(range(self.num_resolutions)): 377 | for i_block in range(self.num_res_blocks+1): 378 | h = self.up[i_level].block[i_block](h, temb) 379 | if len(self.up[i_level].attn) > 0: 380 | h = self.up[i_level].attn[i_block](h) 381 | if i_level != 0: 382 | h = self.up[i_level].upsample(h) 383 | 384 | # end 385 | if self.give_pre_end: 386 | return h 387 | 388 | h = self.norm_out(h) 389 | h = nonlinearity(h) 390 | h = self.conv_out(h) 391 | return h 392 | 393 | 394 | class VectorQuantizer2(nn.Module): 395 | """ 396 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 397 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 398 | """ 399 | # NOTE: due to a bug the beta term was applied to the wrong term. for 400 | # backwards compatibility we use the buggy version by default, but you can 401 | # specify legacy=False to fix it. 402 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 403 | sane_index_shape=False, legacy=True): 404 | super().__init__() 405 | self.n_e = n_e 406 | self.e_dim = e_dim 407 | self.beta = beta 408 | self.legacy = legacy 409 | 410 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 411 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 412 | 413 | self.remap = remap 414 | if self.remap is not None: 415 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 416 | self.re_embed = self.used.shape[0] 417 | self.unknown_index = unknown_index # "random" or "extra" or integer 418 | if self.unknown_index == "extra": 419 | self.unknown_index = self.re_embed 420 | self.re_embed = self.re_embed+1 421 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" 422 | f"Using {self.unknown_index} for unknown indices.") 423 | else: 424 | self.re_embed = n_e 425 | 426 | self.sane_index_shape = sane_index_shape 427 | 428 | def remap_to_used(self, inds): 429 | ishape = inds.shape 430 | assert len(ishape)>1 431 | inds = inds.reshape(ishape[0],-1) 432 | used = self.used.to(inds) 433 | match = (inds[:,:,None]==used[None,None,...]).long() 434 | new = match.argmax(-1) 435 | unknown = match.sum(2)<1 436 | if self.unknown_index == "random": 437 | new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) 438 | else: 439 | new[unknown] = self.unknown_index 440 | return new.reshape(ishape) 441 | 442 | def unmap_to_all(self, inds): 443 | ishape = inds.shape 444 | assert len(ishape)>1 445 | inds = inds.reshape(ishape[0],-1) 446 | used = self.used.to(inds) 447 | if self.re_embed > self.used.shape[0]: # extra token 448 | inds[inds>=self.used.shape[0]] = 0 # simply set to zero 449 | back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) 450 | return back.reshape(ishape) 451 | 452 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 453 | assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" 454 | assert rescale_logits==False, "Only for interface compatible with Gumbel" 455 | assert return_logits==False, "Only for interface compatible with Gumbel" 456 | # reshape z -> (batch, height, width, channel) and flatten 457 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 458 | z_flattened = z.view(-1, self.e_dim) 459 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 460 | 461 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 462 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 463 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 464 | 465 | min_encoding_indices = torch.argmin(d, dim=1) 466 | z_q = self.embedding(min_encoding_indices).view(z.shape) 467 | perplexity = None 468 | min_encodings = None 469 | 470 | # compute loss for embedding 471 | if not self.legacy: 472 | loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ 473 | torch.mean((z_q - z.detach()) ** 2) 474 | else: 475 | loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ 476 | torch.mean((z_q - z.detach()) ** 2) 477 | 478 | # preserve gradients 479 | z_q = z + (z_q - z).detach() 480 | 481 | # reshape back to match original input shape 482 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 483 | 484 | if self.remap is not None: 485 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis 486 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 487 | min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten 488 | 489 | if self.sane_index_shape: 490 | min_encoding_indices = min_encoding_indices.reshape( 491 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 492 | 493 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 494 | 495 | def get_codebook_entry(self, indices, shape): 496 | # shape specifying (batch, height, width, channel) 497 | if self.remap is not None: 498 | indices = indices.reshape(shape[0],-1) # add batch axis 499 | indices = self.unmap_to_all(indices) 500 | indices = indices.reshape(-1) # flatten again 501 | 502 | # get quantized latent vectors 503 | z_q = self.embedding(indices) 504 | 505 | if shape is not None: 506 | z_q = z_q.view(shape) 507 | # reshape back to match original input shape 508 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 509 | 510 | return z_q 
-------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | import torch 3 | import pandas as pd 4 | import random 5 | from pathlib import Path 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | def load_librilight_data(encodec_path, stoks_path, speaker=None): 11 | speakers = [] 12 | atoks = [] 13 | stoks = [] 14 | for path in Path(encodec_path).rglob('*.encodec'): 15 | speakers.append(path.parents[1].name) 16 | atoks.append(path) 17 | stoks.append(Path(stoks_path) / path.relative_to(encodec_path).with_suffix('.stoks')) 18 | data = pd.DataFrame(dict(atoks=atoks, stoks=stoks, speaker=speakers)) 19 | if speaker: data = data[data['speaker'] == speaker] 20 | return data 21 | 22 | 23 | class SADataset(torch.utils.data.Dataset): 24 | def __init__(self, data, speakers): 25 | self.data = data 26 | self.samples = [(i, j) for i, name in enumerate(data['stoks']) for j in 27 | range(torch.load(name, map_location='cpu').shape[0])] 28 | self.speakers = speakers 29 | 30 | def __len__(self): 31 | return len(self.samples) 32 | 33 | def S_tokens(self): 34 | return len(self) * 1500 35 | 36 | def hours(self): 37 | return len(self) * 30 / 3600 38 | 39 | def __repr__(self): 40 | return f"Dataset: {len(self)} samples ({len(self.data['speaker'].unique())} speakers), {self.S_tokens()} Stokens, {self.hours():.1f} hours)" 41 | 42 | def __getitem__(self, idx): 43 | i, j = self.samples[idx] 44 | row = self.data.iloc[i] 45 | jA = j * 2250 46 | Stoks = torch.load(row['stoks'], map_location='cpu')[j] 47 | Atoks = torch.load(row['atoks'], map_location='cpu')[0, :, jA:jA + 2250] 48 | return Stoks, F.pad(Atoks, (0, 2250 - Atoks.shape[-1]), value=1026), torch.tensor(self.speakers[row['speaker']]) 49 | 50 | 51 | import math 52 | 53 | 54 | def load_datasets( 55 | stoks_path: Path, # semantic tokens path 56 | encodec_path: Path, # encodec tokens path 57 | subsample: float = 1, # use a fraction of the files 58 | val_split: float = 0.001, # how much data to use for validation 59 | speaker: str = None # only load a single speaker id 60 | ): 61 | data = load_librilight_data(encodec_path, stoks_path, speaker=speaker) 62 | 63 | speakers = {id: i for i, id in enumerate(data['speaker'].unique())} 64 | 65 | # select at most 4 frequent speakers from the dataset for the validation set 66 | # this way, even when subsampling, we avoid the problem of having someone 67 | # in the validation set that's not in the training set 68 | val_speakers = data.groupby('speaker').size().sort_values()[-4:].index 69 | Nval = math.ceil(val_split * len(data) / len(val_speakers)) 70 | val_idxs = [] 71 | for idx in val_speakers: 72 | val_idxs += list(data[data['speaker'] == idx][:Nval].index) 73 | 74 | train_idxs = list(set(data.index) - set(val_idxs)) 75 | 76 | random.seed(0) 77 | random.shuffle(train_idxs) 78 | Ntrain = int(len(train_idxs) * subsample) 79 | 80 | val_data, train_data = data.loc[val_idxs], data.loc[train_idxs[:Ntrain]] 81 | 82 | return SADataset(train_data, speakers), SADataset(val_data, speakers) 83 | 84 | def get_tts_dataset(semb_path, encodec_path, batch_size): 85 | s_train_ds, val_ds = load_datasets(semb_path, encodec_path, 86 | subsample=1.0, speaker='6454') 87 | 88 | pin_mem = True 89 | num_workers = 4 90 | shuffle = True 91 | 92 | 93 | train_set = DataLoader( 94 | s_train_ds, 95 | batch_size=batch_size, 96 | num_workers=num_workers, 97 | shuffle=shuffle, 98 | 
pin_memory=pin_mem, 99 | ) 100 | 101 | valid_set = DataLoader( 102 | val_ds, 103 | batch_size=1, 104 | num_workers=0, 105 | shuffle=False, 106 | pin_memory=False, 107 | ) 108 | return train_set, valid_set -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_channels, out_channels): 8 | super(ResidualBlock, self).__init__() 9 | self.in_channels = in_channels 10 | self.out_channels = out_channels 11 | self.block = nn.Sequential( 12 | GroupNorm(in_channels), 13 | Swish(), 14 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 15 | GroupNorm(out_channels), 16 | Swish(), 17 | nn.Conv2d(out_channels, out_channels, 3, 1, 1) 18 | ) 19 | if in_channels != out_channels: 20 | self.channel_up = nn.Conv2d(in_channels, out_channels, 1, 1, 0) 21 | 22 | def forward(self, x): 23 | if self.in_channels != self.out_channels: 24 | return self.block(x) + self.channel_up(x) 25 | else: 26 | return x + self.block(x) 27 | 28 | 29 | class UpSampleBlock(nn.Module): 30 | def __init__(self, channels): 31 | super(UpSampleBlock, self).__init__() 32 | self.conv = nn.Conv2d(channels, channels, 3, 1, 1) 33 | 34 | def forward(self, x): 35 | x = F.interpolate(x, scale_factor=2.) 36 | return self.conv(x) 37 | 38 | 39 | class DownSampleBlock(nn.Module): 40 | def __init__(self, channels): 41 | super(DownSampleBlock, self).__init__() 42 | self.conv = nn.Conv2d(channels, channels, 3, 2, 0) 43 | 44 | def forward(self, x): 45 | pad = (0, 1, 0, 1) 46 | x = F.pad(x, pad, mode="constant", value=0) 47 | return self.conv(x) 48 | 49 | 50 | class NonLocalBlock(nn.Module): 51 | def __init__(self, in_channels): 52 | super().__init__() 53 | self.in_channels = in_channels 54 | 55 | self.norm = GroupNorm(in_channels) 56 | self.q = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0) 57 | self.k = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0) 58 | self.v = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0) 59 | self.proj_out = torch.nn.Conv2d(in_channels, in_channels, 1, 1, 0) 60 | 61 | def forward(self, x): 62 | h_ = self.norm(x) 63 | q = self.q(h_) 64 | k = self.k(h_) 65 | v = self.v(h_) 66 | 67 | b, c, h, w = q.shape 68 | 69 | q = q.reshape(b, c, h * w) 70 | q = q.permute(0, 2, 1) 71 | k = k.reshape(b, c, h * w) 72 | v = v.reshape(b, c, h * w) 73 | 74 | attn = torch.bmm(q, k) 75 | attn = attn * (int(c) ** (-0.5)) 76 | attn = F.softmax(attn, dim=2) 77 | 78 | attn = attn.permute(0, 2, 1) 79 | A = torch.bmm(v, attn) 80 | A = A.reshape(b, c, h, w) 81 | 82 | A = self.proj_out(A) 83 | 84 | return x + A 85 | 86 | 87 | class GroupNorm(nn.Module): 88 | def __init__(self, in_channels): 89 | super(GroupNorm, self).__init__() 90 | self.gn = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 91 | 92 | def forward(self, x): 93 | return self.gn(x) 94 | 95 | 96 | class Swish(nn.Module): 97 | def forward(self, x): 98 | return x * torch.sigmoid(x) 99 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from einops import rearrange 4 | from tqdm import tqdm 5 | import argparse 6 | import torch 7 | import torch.nn.functional as F 8 | from SoundStorm import SoundStorm 9 | from encodec import EncodecModel 10 | 
from scipy.io.wavfile import write 11 | 12 | DEVICE = "cuda:0" 13 | 14 | CHKPT = "./checkpoints/transformer_step_40000.pt" 15 | 16 | 17 | 18 | 19 | encodec_ = EncodecModel.encodec_model_24khz() 20 | encodec_.normalize = False 21 | encodec_.set_target_bandwidth(6.0) 22 | encodec_ = encodec_.cuda() 23 | target_sample_hz = 24000 24 | torch_args = True 25 | 26 | path = None 27 | if __name__ == '__main__': 28 | 29 | #cond = "./data/semantic_code/26_495_000007_000003.npy" 30 | #codes = "./data/codec_code/26_495_000007_000003.npy" 31 | cond = "../../../data/whisperspeech/whisperspeech/librilight/stoks/large/6454/a_christmas_miscellany_2018_1807_librivox_64kb_mp3/christmasmiscellany2018_02_various_64kb.stoks" 32 | codes = "../../../data/whisperspeech/whisperspeech/librilight/encodec-6kbps/large/6454/a_christmas_miscellany_2018_1807_librivox_64kb_mp3/christmasmiscellany2018_02_various_64kb.encodec" 33 | 34 | 35 | j = 2 36 | if torch_args: 37 | jA = j * 2250 38 | Stoks = torch.load(cond, map_location=DEVICE)[j] 39 | Atoks = torch.load(codes, map_location=DEVICE)[0, :, jA:jA + 2250] 40 | codes = F.pad(Atoks, (0, 2250 - Atoks.shape[-1]), value=1026) 41 | 42 | 43 | 44 | prompt = codes[:, :750].clone().detach().unsqueeze(0).cuda() 45 | semb = Stoks.unsqueeze(0) 46 | b, n = semb.shape 47 | semb = semb.reshape(b, n // 2, 2) 48 | semb = semb.repeat_interleave(2, -1)[:, :, :3] 49 | semb[:, :, 1] = 1025 50 | semb = semb.reshape(b, n // 2 * 3) 51 | assert semb.shape[-1] == codes.shape[-1] 52 | print("shape of semb", semb.shape) 53 | print("shape of codes", codes.shape) 54 | 55 | # codec = rearrange(codes.unsqueeze(0), 'b q n -> q b n') 56 | # emb = encodec_.quantizer.decode(codec) 57 | # 58 | # audio = encodec_.decoder(emb).squeeze() 59 | # print("shape of audio", audio.shape) 60 | # write("org.wav", target_sample_hz, audio.detach().cpu().numpy()) 61 | else: 62 | semb = np.load(cond) # [L/2] 63 | codes = np.load(codes) # [L, number_of_quantizers] 64 | 65 | semb = np.repeat(semb, 2, axis=0) 66 | mel_len = min(semb.shape[0], codes.shape[0]) 67 | 68 | if semb.shape[0] > mel_len: 69 | semb = semb[:mel_len] 70 | else: 71 | codes = codes[:mel_len] 72 | 73 | assert semb.shape[0] == codes.shape[0] 74 | prompt = torch.from_numpy(codes[:600, :]).T.unsqueeze(0).to(DEVICE) 75 | semb = torch.from_numpy(semb).unsqueeze(0).to(DEVICE) 76 | model = SoundStorm(dim=768, heads=8, linear_units=3072, num_blocks=8).to(DEVICE) 77 | 78 | chkpt = torch.load(CHKPT, map_location=DEVICE) 79 | model.load_state_dict(chkpt['model']) 80 | model.eval() 81 | 82 | codec = model.generate(semb, prompt) # [B, q, n] 83 | 84 | np.save("out.npy", codec.cpu().detach().numpy()) 85 | 86 | codes = rearrange(codec, 'b q n -> q b n') 87 | emb = encodec_.quantizer.decode(codes) 88 | 89 | audio = encodec_.decoder(emb) 90 | 91 | write("out_conf2.wav", target_sample_hz, audio.detach().cpu().numpy()) -------------------------------------------------------------------------------- /lr_schedule.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.optim import Adam 4 | import math 5 | 6 | 7 | class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler): 8 | """ 9 | Implements Warmup learning rate schedule until 'warmup_steps', going from 'init_lr' to 'peak_lr' for multiple optimizers. 
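    After 'warmup_steps' the rate decays from 'peak_lr' to 'end_lr' along a half-cosine over the
    remaining 'total_steps' - 'warmup_steps' steps, and is held at 'end_lr' once 'total_steps' is passed.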
10 | """ 11 | def __init__(self, optimizer, init_lr, peak_lr, end_lr, warmup_steps=10000, total_steps=400000, current_step=0): 12 | self.init_lr = init_lr 13 | self.peak_lr = peak_lr 14 | self.end_lr = end_lr 15 | self.optimizer = optimizer 16 | self._warmup_rate = (peak_lr - init_lr) / warmup_steps 17 | self._decay_rate = (end_lr - peak_lr) / (total_steps - warmup_steps) 18 | self._current_step = current_step 19 | self.lr = init_lr 20 | self.warmup_steps = warmup_steps 21 | self.total_steps = total_steps 22 | self._last_lr = [self.lr] 23 | 24 | def set_lr(self, lr): 25 | self._last_lr = [g['lr'] for g in self.optimizer.param_groups] 26 | for g in self.optimizer.param_groups: 27 | g['lr'] = lr 28 | 29 | def step(self): 30 | if self._current_step < self.warmup_steps: 31 | lr = self.init_lr + self._warmup_rate * self._current_step 32 | 33 | elif self._current_step > self.total_steps: 34 | lr = self.end_lr 35 | 36 | else: 37 | decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps) 38 | if decay_ratio < 0.0 or decay_ratio > 1.0: 39 | raise RuntimeError( 40 | "Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.") 41 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) 42 | lr = self.end_lr + coeff * (self.peak_lr - self.end_lr) 43 | 44 | self.set_lr(lr) 45 | self.lr = lr 46 | self._current_step += 1 47 | return self.lr 48 | 49 | 50 | if __name__ == '__main__': 51 | m = nn.Linear(10, 10) 52 | opt = Adam(m.parameters(), lr=1e-4) 53 | s = WarmupCosineLRSchedule(opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000,total_steps=20000, current_step=0) 54 | lrs = [] 55 | for i in range(25000): 56 | s.step() 57 | lrs.append(s.lr) 58 | print(s.lr) 59 | 60 | 61 | # plt.plot(lrs) 62 | # plt.plot(range(0, 25000), lrs) 63 | # plt.show() -------------------------------------------------------------------------------- /requirnements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers 3 | tqdm 4 | einops 5 | pandas 6 | matplotlib 7 | scipy 8 | encodec 9 | tensorboard -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | from tqdm import tqdm 5 | import argparse 6 | import torch 7 | import torch.nn.functional as F 8 | from SoundStorm import SoundStorm 9 | from dataset import get_tts_dataset 10 | from lr_schedule import WarmupCosineLRSchedule 11 | from torch.utils.tensorboard import SummaryWriter 12 | from encodec import EncodecModel 13 | 14 | def topk_accuracy(output, target, topk=(1,)): 15 | """Computes the precision@k for the specified values of k""" 16 | with torch.no_grad(): 17 | maxk = max(topk) 18 | batch_size = target.size(0) 19 | 20 | _, pred = output.topk(maxk, 1, True, True) 21 | pred = pred.t() 22 | correct = (pred == target.unsqueeze(dim=0)).expand_as(pred) 23 | 24 | res = [] 25 | for k in topk: 26 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 27 | res.append(correct_k.mul_(100.0 / batch_size)) 28 | 29 | 30 | return res 31 | 32 | class TrainTransformer: 33 | def __init__(self, args): 34 | encodec_ = EncodecModel.encodec_model_24khz() 35 | encodec_.normalize = False 36 | encodec_.set_target_bandwidth(6.0) 37 | encodec_ = encodec_.to(device=args.device) 38 | 39 | self.model = SoundStorm(encodec=encodec_).to(device=args.device) 40 | self.optim = torch.optim.AdamW(self.model.parameters(), lr=2e-4, 
betas=(0.8, 0.99), eps=0.000000001) 41 | self.lr_schedule = WarmupCosineLRSchedule(self.optim, 42 | init_lr=0.000001, 43 | peak_lr=0.0002, 44 | end_lr=0.00001, 45 | warmup_steps=2000, 46 | total_steps=40000 47 | ) 48 | 49 | 50 | if args.run_name: 51 | self.logger = SummaryWriter(f"./runs/{args.run_name}") 52 | else: 53 | self.logger = SummaryWriter() 54 | self.train(args) 55 | 56 | def train(self, args): 57 | 58 | train_dataset, valid_dataset = get_tts_dataset(args.spath, args.epath, args.batch_size) 59 | len_train_dataset = len(train_dataset) 60 | step = 0 61 | start_from_epoch = 0 62 | if args.chkpt is not None: 63 | checkpoint = torch.load(args.chkpt) 64 | start_from_epoch = checkpoint["epoch"] 65 | 66 | self.lr_schedule = torch.optim.lr_scheduler.ExponentialLR(self.optim, gamma=0.999875, 67 | last_epoch=start_from_epoch - 1) 68 | if args.chkpt is not None: 69 | print("Loading checkpoints") 70 | self.model.load_state_dict(checkpoint["model"]) 71 | self.optim.load_state_dict(checkpoint["optim"]) 72 | self.lr_schedule.load_state_dict(checkpoint["schedular"]) 73 | step = checkpoint["step"] 74 | args.start_from_epoch = checkpoint["epoch"] 75 | 76 | self.model.train() 77 | for epoch in range(args.start_from_epoch + 1, args.epochs + 1): 78 | print(f"Epoch {epoch}:") 79 | with tqdm(range(len(train_dataset))) as pbar: 80 | 81 | for i, (cond, codes, ids) in zip(pbar, train_dataset): 82 | start = random.randint(0, 999) 83 | codes = codes.cuda() # [B, 8, 2250] 84 | cond = cond.cuda() # [B, 1500] 85 | 86 | b, n = cond.shape 87 | cond = cond.reshape(b, n // 2, 2) 88 | cond = cond.repeat_interleave(2, -1)[:, :, :3] 89 | cond[:, :, 1] = 1025 90 | cond = cond.reshape(b, n // 2 * 3) 91 | assert cond.shape[-1] == codes.shape[-1] 92 | 93 | loss, logit, target = self.model(cond.long()[:, start:start + 375], 94 | codes.long()[:, :, start:start + 375]) 95 | # loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) 96 | loss.backward() 97 | 98 | ### Calculate accuracy: 99 | mask = target.eq(1025) 100 | maske_target = target[~mask] 101 | masked_logits = logit[~mask] 102 | 103 | masked_token_prediction = torch.argmax(masked_logits, dim=-1) 104 | token_correct = (masked_token_prediction == maske_target).sum() 105 | token_total = maske_target.shape[0] 106 | token_accuracy = token_correct / token_total 107 | 108 | topk = topk_accuracy(masked_logits, maske_target, topk=(1, 10)) 109 | 110 | if step % args.accum_grad == 0: 111 | self.optim.step() 112 | self.lr_schedule.step() 113 | self.optim.zero_grad() 114 | step += 1 115 | pbar.set_postfix(Top10k=topk[1].cpu().detach().numpy().item(), 116 | Top1k=topk[0].cpu().detach().numpy().item(), 117 | Transformer_Loss=np.round(loss.cpu().detach().numpy().item(), 4)) 118 | pbar.update(0) 119 | self.logger.add_scalar("Cross Entropy Loss", np.round(loss.cpu().detach().numpy().item(), 4), step) 120 | self.logger.add_scalar("Accuracy", token_accuracy.cpu().detach().numpy().item(), step) 121 | self.logger.add_scalar("Top10K", topk[1].cpu().detach().numpy().item(), step) 122 | self.logger.add_scalar("Top1K", topk[0].cpu().detach().numpy().item(), step) 123 | if step % args.ckpt_interval == 0: 124 | torch.save({ 125 | "model": self.model.state_dict(), 126 | "optim": self.optim.state_dict(), 127 | "schedular": self.lr_schedule.state_dict(), 128 | "step": step, 129 | "epoch": epoch + 1, 130 | }, 131 | os.path.join("checkpoints", f"transformer_step_{step}.pt") 132 | ) 133 | 134 | if step % args.validation_step == 0: 135 | self.validation_step(valid_dataset, 
step) 136 | 137 | def validation_step(self, valid_set, steps): 138 | self.model.eval() 139 | accuracy = 0 140 | avg_loss = 0 141 | top10k = 0 142 | with tqdm(range(len(valid_set))) as pbar: 143 | for i, (cond, codes, ids) in zip(pbar, valid_set): 144 | codes = codes.cuda()[:, :, :750] # [B, 8, 2250] 145 | cond = cond.cuda()[:, :500] # [B, 1500] 146 | 147 | b, n = cond.shape 148 | cond = cond.reshape(b, n // 2, 2) 149 | cond = cond.repeat_interleave(2, -1)[:, :, :3] 150 | cond[:, :, 1] = 1025 151 | cond = cond.reshape(b, n // 2 * 3) 152 | assert cond.shape[-1] == codes.shape[-1] 153 | with torch.no_grad(): 154 | loss, logit, target = self.model(cond.long()[:, :750], codes.long()[:, :, :750]) 155 | # loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), target.reshape(-1)) 156 | avg_loss = avg_loss + loss 157 | ### Calculate accuracy: 158 | mask = target.eq(1025) 159 | maske_target = target[~mask] 160 | masked_logits = logit[~mask] 161 | 162 | masked_token_prediction = torch.argmax(masked_logits, dim=-1) 163 | token_correct = (masked_token_prediction == maske_target).sum() 164 | token_total = maske_target.shape[0] 165 | token_accuracy = token_correct / token_total 166 | 167 | topk = topk_accuracy(masked_logits, maske_target, topk=(1, 10)) 168 | top10k = top10k + topk[-1] 169 | accuracy = accuracy + token_accuracy 170 | pbar.set_postfix(Accuracy=token_accuracy.cpu().detach().numpy().item(), 171 | Transformer_Loss=np.round(loss.cpu().detach().numpy().item(), 4)) 172 | pbar.update(0) 173 | 174 | self.logger.add_scalar("Infer loss:", np.round((avg_loss.cpu().detach().numpy().item()) / len(valid_set), 4), 175 | steps) 176 | self.logger.add_scalar("Infer Accuracy", accuracy.cpu().detach().numpy().item() / len(valid_set), 177 | steps) 178 | self.logger.add_scalar("Infer Top10k", top10k.cpu().detach().numpy().item() / len(valid_set), 179 | steps) 180 | self.model.train() 181 | 182 | 183 | 184 | if __name__ == '__main__': 185 | parser = argparse.ArgumentParser(description="VQGAN") 186 | parser.add_argument('--run-name', type=str, default=None) 187 | parser.add_argument('--nq', type=int, default=8, help='Number of quantizer.') 188 | parser.add_argument('--spath ', type=str, default='./data/whisperspeech/whisperspeech/librilight/stoks/', 189 | help='Path to data.') 190 | parser.add_argument('--epath ', type=str, default='./data/whisperspeech/whisperspeech/librilight/encodec-6kbps/', 191 | help='Path to data.') 192 | parser.add_argument('--device', type=str, default="cuda", help='Which device the training is on.') 193 | parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training.') 194 | parser.add_argument('--accum-grad', type=int, default=10, help='Number for gradient accumulation.') 195 | parser.add_argument('--epochs', type=int, default=300, help='Number of epochs to train.') 196 | parser.add_argument('--start-from-epoch', type=int, default=0, help='Number of epochs to train.') 197 | parser.add_argument('--ckpt-interval', type=int, default=5000, help='Number of epochs to train.') 198 | parser.add_argument('--validation_step', type=int, default=1000, help='Number of epochs to train.') 199 | parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate.') 200 | parser.add_argument('--chkpt', type=str, default=None, help='checkpoint path to load') 201 | 202 | parser.add_argument('--n-layers', type=int, default=12, help='Number of layers of transformer.') 203 | parser.add_argument('--dim', type=int, default=768, help='Dimension of transformer.') 
204 | parser.add_argument('--hidden-dim', type=int, default=3072, help='Dimension of transformer.') 205 | 206 | args = parser.parse_args() 207 | args.run_name = "tests3" 208 | args.checkpoint_path = r".\checkpoints" 209 | args.n_layers = 24 210 | args.dim = 768 211 | args.hidden_dim = 3072 212 | args.batch_size = 16 213 | args.accum_grad = 4 214 | args.epochs = 1000 215 | 216 | args.start_from_epoch = 0 217 | 218 | #args.spath = "../../../data/whisperspeech/whisperspeech/librilight/stoks/" 219 | #args.epath = "../../../data/whisperspeech/whisperspeech/librilight/encodec-6kbps/" 220 | 221 | train_transformer = TrainTransformer(args) 222 | 223 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import albumentations 3 | import numpy as np 4 | import torch.nn as nn 5 | from PIL import Image 6 | from torch.utils.data import Dataset, DataLoader 7 | import matplotlib.pyplot as plt 8 | 9 | 10 | import torch 11 | from typing import List 12 | from typing import List 13 | import torch.nn.functional as F 14 | 15 | 16 | def get_mask_from_lengths(lengths, max_len=None): 17 | batch_size = lengths.shape[0] 18 | if max_len is None: 19 | max_len = torch.max(lengths).item() 20 | 21 | ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) 22 | mask = ids >= lengths.unsqueeze(1).expand(-1, max_len) 23 | 24 | return mask 25 | 26 | 27 | @torch.jit.script 28 | def pad_2d_tensor(xs: List[torch.Tensor], pad_value: float = 0.0): 29 | max_len = max([xs[i].size(0) for i in range(len(xs))]) 30 | 31 | out_list = [] 32 | 33 | for i, batch in enumerate(xs): 34 | one_batch_padded = F.pad( 35 | batch, (0, 0, 0, max_len - batch.size(0)), "constant", pad_value 36 | ) 37 | out_list.append(one_batch_padded) 38 | 39 | out_padded = torch.stack(out_list) 40 | return out_padded 41 | 42 | 43 | def pad_list(xs, pad_value): 44 | """Perform padding for the list of tensors. 45 | 46 | Args: 47 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 48 | pad_value (float): Value for padding. 49 | 50 | Returns: 51 | Tensor: Padded tensor (B, Tmax, `*`). 52 | 53 | Examples: 54 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 55 | >>> x 56 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 57 | >>> pad_list(x, 0) 58 | tensor([[1., 1., 1., 1.], 59 | [1., 1., 0., 0.], 60 | [1., 0., 0., 0.]]) 61 | 62 | """ 63 | n_batch = len(xs) 64 | max_len = max(x.size(0) for x in xs) 65 | pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) 66 | 67 | for i in range(n_batch): 68 | pad[i, : xs[i].size(0)] = xs[i] 69 | 70 | return pad 71 | 72 | def length_to_mask(lengths): 73 | mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths) 74 | mask = torch.gt(mask+1, lengths.unsqueeze(1)) 75 | return mask 76 | 77 | 78 | 79 | def make_pad_mask(lengths: List[int], xs: torch.Tensor = None, length_dim: int = -1): 80 | """Make mask tensor containing indices of padded part. 81 | 82 | Args: 83 | lengths (LongTensor or List): Batch of lengths (B,). 84 | xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. 85 | length_dim (int, optional): Dimension indicator of the above tensor. See the example. 86 | 87 | Returns: 88 | Tensor: Mask tensor containing indices of padded part. 89 | 90 | Examples: 91 | With only lengths. 
92 | 93 | >>> lengths = [5, 3, 2] 94 | >>> make_non_pad_mask(lengths) 95 | masks = [[0, 0, 0, 0 ,0], 96 | [0, 0, 0, 1, 1], 97 | [0, 0, 1, 1, 1]] 98 | 99 | With the reference tensor. 100 | 101 | >>> xs = torch.zeros((3, 2, 4)) 102 | >>> make_pad_mask(lengths, xs) 103 | tensor([[[0, 0, 0, 0], 104 | [0, 0, 0, 0]], 105 | [[0, 0, 0, 1], 106 | [0, 0, 0, 1]], 107 | [[0, 0, 1, 1], 108 | [0, 0, 1, 1]]], dtype=torch.uint8) 109 | >>> xs = torch.zeros((3, 2, 6)) 110 | >>> make_pad_mask(lengths, xs) 111 | tensor([[[0, 0, 0, 0, 0, 1], 112 | [0, 0, 0, 0, 0, 1]], 113 | [[0, 0, 0, 1, 1, 1], 114 | [0, 0, 0, 1, 1, 1]], 115 | [[0, 0, 1, 1, 1, 1], 116 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 117 | 118 | With the reference tensor and dimension indicator. 119 | 120 | >>> xs = torch.zeros((3, 6, 6)) 121 | >>> make_pad_mask(lengths, xs, 1) 122 | tensor([[[0, 0, 0, 0, 0, 0], 123 | [0, 0, 0, 0, 0, 0], 124 | [0, 0, 0, 0, 0, 0], 125 | [0, 0, 0, 0, 0, 0], 126 | [0, 0, 0, 0, 0, 0], 127 | [1, 1, 1, 1, 1, 1]], 128 | [[0, 0, 0, 0, 0, 0], 129 | [0, 0, 0, 0, 0, 0], 130 | [0, 0, 0, 0, 0, 0], 131 | [1, 1, 1, 1, 1, 1], 132 | [1, 1, 1, 1, 1, 1], 133 | [1, 1, 1, 1, 1, 1]], 134 | [[0, 0, 0, 0, 0, 0], 135 | [0, 0, 0, 0, 0, 0], 136 | [1, 1, 1, 1, 1, 1], 137 | [1, 1, 1, 1, 1, 1], 138 | [1, 1, 1, 1, 1, 1], 139 | [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) 140 | >>> make_pad_mask(lengths, xs, 2) 141 | tensor([[[0, 0, 0, 0, 0, 1], 142 | [0, 0, 0, 0, 0, 1], 143 | [0, 0, 0, 0, 0, 1], 144 | [0, 0, 0, 0, 0, 1], 145 | [0, 0, 0, 0, 0, 1], 146 | [0, 0, 0, 0, 0, 1]], 147 | [[0, 0, 0, 1, 1, 1], 148 | [0, 0, 0, 1, 1, 1], 149 | [0, 0, 0, 1, 1, 1], 150 | [0, 0, 0, 1, 1, 1], 151 | [0, 0, 0, 1, 1, 1], 152 | [0, 0, 0, 1, 1, 1]], 153 | [[0, 0, 1, 1, 1, 1], 154 | [0, 0, 1, 1, 1, 1], 155 | [0, 0, 1, 1, 1, 1], 156 | [0, 0, 1, 1, 1, 1], 157 | [0, 0, 1, 1, 1, 1], 158 | [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) 159 | 160 | """ 161 | if length_dim == 0: 162 | raise ValueError("length_dim cannot be 0: {}".format(length_dim)) 163 | 164 | if not isinstance(lengths, list): 165 | lengths = lengths.tolist() 166 | bs = int(len(lengths)) 167 | if xs is None: 168 | maxlen = int(max(lengths)) 169 | else: 170 | maxlen = xs.size(length_dim) 171 | 172 | seq_range = torch.arange(0, maxlen, dtype=torch.int64) 173 | seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) 174 | seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) 175 | mask = seq_range_expand >= seq_length_expand 176 | 177 | if xs is not None: 178 | assert xs.size(0) == bs, (xs.size(0), bs) 179 | 180 | if length_dim < 0: 181 | length_dim = xs.dim() + length_dim 182 | # ind = (:, None, ..., None, :, , None, ..., None) 183 | ind = tuple( 184 | slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) 185 | ) 186 | mask = mask[ind].expand_as(xs).to(xs.device) 187 | return mask 188 | 189 | 190 | 191 | def make_non_pad_mask(lengths, xs=None, length_dim=-1): 192 | """Make mask tensor containing indices of non-padded part. 193 | 194 | Args: 195 | lengths (LongTensor or List): Batch of lengths (B,). 196 | xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. 197 | length_dim (int, optional): Dimension indicator of the above tensor. See the example. 198 | 199 | Returns: 200 | ByteTensor: mask tensor containing indices of padded part. 201 | 202 | Examples: 203 | With only lengths. 
204 | 205 | >>> lengths = [5, 3, 2] 206 | >>> make_non_pad_mask(lengths) 207 | masks = [[1, 1, 1, 1 ,1], 208 | [1, 1, 1, 0, 0], 209 | [1, 1, 0, 0, 0]] 210 | 211 | With the reference tensor. 212 | 213 | >>> xs = torch.zeros((3, 2, 4)) 214 | >>> make_non_pad_mask(lengths, xs) 215 | tensor([[[1, 1, 1, 1], 216 | [1, 1, 1, 1]], 217 | [[1, 1, 1, 0], 218 | [1, 1, 1, 0]], 219 | [[1, 1, 0, 0], 220 | [1, 1, 0, 0]]], dtype=torch.uint8) 221 | >>> xs = torch.zeros((3, 2, 6)) 222 | >>> make_non_pad_mask(lengths, xs) 223 | tensor([[[1, 1, 1, 1, 1, 0], 224 | [1, 1, 1, 1, 1, 0]], 225 | [[1, 1, 1, 0, 0, 0], 226 | [1, 1, 1, 0, 0, 0]], 227 | [[1, 1, 0, 0, 0, 0], 228 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 229 | 230 | With the reference tensor and dimension indicator. 231 | 232 | >>> xs = torch.zeros((3, 6, 6)) 233 | >>> make_non_pad_mask(lengths, xs, 1) 234 | tensor([[[1, 1, 1, 1, 1, 1], 235 | [1, 1, 1, 1, 1, 1], 236 | [1, 1, 1, 1, 1, 1], 237 | [1, 1, 1, 1, 1, 1], 238 | [1, 1, 1, 1, 1, 1], 239 | [0, 0, 0, 0, 0, 0]], 240 | [[1, 1, 1, 1, 1, 1], 241 | [1, 1, 1, 1, 1, 1], 242 | [1, 1, 1, 1, 1, 1], 243 | [0, 0, 0, 0, 0, 0], 244 | [0, 0, 0, 0, 0, 0], 245 | [0, 0, 0, 0, 0, 0]], 246 | [[1, 1, 1, 1, 1, 1], 247 | [1, 1, 1, 1, 1, 1], 248 | [0, 0, 0, 0, 0, 0], 249 | [0, 0, 0, 0, 0, 0], 250 | [0, 0, 0, 0, 0, 0], 251 | [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) 252 | >>> make_non_pad_mask(lengths, xs, 2) 253 | tensor([[[1, 1, 1, 1, 1, 0], 254 | [1, 1, 1, 1, 1, 0], 255 | [1, 1, 1, 1, 1, 0], 256 | [1, 1, 1, 1, 1, 0], 257 | [1, 1, 1, 1, 1, 0], 258 | [1, 1, 1, 1, 1, 0]], 259 | [[1, 1, 1, 0, 0, 0], 260 | [1, 1, 1, 0, 0, 0], 261 | [1, 1, 1, 0, 0, 0], 262 | [1, 1, 1, 0, 0, 0], 263 | [1, 1, 1, 0, 0, 0], 264 | [1, 1, 1, 0, 0, 0]], 265 | [[1, 1, 0, 0, 0, 0], 266 | [1, 1, 0, 0, 0, 0], 267 | [1, 1, 0, 0, 0, 0], 268 | [1, 1, 0, 0, 0, 0], 269 | [1, 1, 0, 0, 0, 0], 270 | [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) 271 | 272 | """ 273 | return ~make_pad_mask(lengths, xs, length_dim) 274 | 275 | 276 | 277 | # --------------------------------------------- # 278 | # Data Utils 279 | # --------------------------------------------- # 280 | 281 | class ImagePaths(Dataset): 282 | def __init__(self, path, size=None): 283 | self.size = size 284 | 285 | self.images = [os.path.join(path, file) for file in os.listdir(path)] 286 | self._length = len(self.images) 287 | 288 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 289 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 290 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 291 | 292 | def __len__(self): 293 | return self._length 294 | 295 | def preprocess_image(self, image_path): 296 | image = Image.open(image_path) 297 | if not image.mode == "RGB": 298 | image = image.convert("RGB") 299 | image = np.array(image).astype(np.uint8) 300 | image = self.preprocessor(image=image)["image"] 301 | image = (image / 127.5 - 1.0).astype(np.float32) 302 | image = image.transpose(2, 0, 1) 303 | return image 304 | 305 | def __getitem__(self, i): 306 | example = self.preprocess_image(self.images[i]) 307 | return example 308 | 309 | 310 | def load_data(args): 311 | train_data = ImagePaths(args.dataset_path, size=256) 312 | train_loader = DataLoader(train_data, batch_size=args.batch_size, shuffle=False) 313 | return train_loader 314 | 315 | 316 | # --------------------------------------------- # 317 | # Module Utils 318 | # for Encoder, Decoder etc. 
319 | # --------------------------------------------- # 320 | 321 | def weights_init(m): 322 | classname = m.__class__.__name__ 323 | if classname.find('Conv') != -1: 324 | nn.init.normal_(m.weight.data, 0.0, 0.02) 325 | elif classname.find('BatchNorm') != -1: 326 | nn.init.normal_(m.weight.data, 1.0, 0.02) 327 | nn.init.constant_(m.bias.data, 0) 328 | 329 | 330 | def plot_images(images: dict): 331 | x = images["input"] 332 | reconstruction = images["rec"] 333 | half_sample = images["half_sample"] 334 | new_sample = images["new_sample"] 335 | 336 | fig, axarr = plt.subplots(1, 4) 337 | axarr[0].imshow(x.cpu().detach().numpy()[0].transpose(1, 2, 0)) 338 | axarr[1].imshow(reconstruction.cpu().detach().numpy()[0].transpose(1, 2, 0)) 339 | axarr[2].imshow(half_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) 340 | axarr[3].imshow(new_sample.cpu().detach().numpy()[0].transpose(1, 2, 0)) 341 | plt.show() 342 | --------------------------------------------------------------------------------
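A note on the decoding schedule: the iterative sampler in `sample_good` above drives its re-masking with `self.gamma_func(mode="cosine")`, which is defined earlier in the same class and not shown in this part of the listing. The sketch below is only an illustration of how such a cosine schedule shrinks the number of re-masked tokens to zero over the T decoding steps; `cosine_gamma` is a hypothetical stand-in, not the repository's own function.

import math

def cosine_gamma(r: float) -> float:
    # assumed MaskGIT-style schedule: fraction of positions still masked at progress r in (0, 1]
    return math.cos(r * math.pi / 2)

T = 11                   # decoding iterations, matching sample_good(T=11)
unknown_at_start = 256   # e.g. every token of a 16x16 grid starts out masked
for t in range(T):
    ratio = (t + 1) / T
    mask_len = math.floor(unknown_at_start * cosine_gamma(ratio))
    print(f"step {t}: re-mask the {mask_len} lowest-confidence tokens")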