├── LICENSE
├── README.md
├── product_vector_quantize.py
├── residual_vector_quantize.py
├── vector_quantize.py
├── vector_quantize_gumbel_softmax.py
└── vector_quantize_moving_average.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 0x11DE784A

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## PyTorch Vector Quantization
A PyTorch library of vector quantization methods.
Vector quantization has been used successfully in high-quality image and audio generation, e.g., VQ-VAE and VQGAN.

Implemented methods:
- [x] Vector Quantization
- [x] Vector Quantization with momentum (exponential moving average) codebook updates
- [x] Vector Quantization based on the Gumbel-Softmax trick
- [x] Product Quantization
- [x] Residual Quantization

## Usage

```python
import torch
from vector_quantize import VectorQuantizer

vq = VectorQuantizer(
    n_e=1024,   # codebook vocabulary size (number of code vectors)
    e_dim=256,  # dimension of each code vector
    beta=1.0,   # weight on the commitment loss
)

x = torch.randn(1, 256, 16, 16)          # input of shape (N, C, H, W), with C == e_dim
quantized, commit_loss, indices = vq(x)  # shapes: (1, 256, 16, 16), scalar, (1, 16, 16)
```
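
The composite quantizers wrap `VectorQuantizer` and follow the same call convention. A minimal sketch (constructor arguments as defined in this repository; the sizes are illustrative):

```python
import torch
from residual_vector_quantize import ResidualVectorQuantizer
from product_vector_quantize import ProductVectorQuantizer

# residual quantization: each stage quantizes what the previous stages missed
rvq = ResidualVectorQuantizer(n_e=1024, e_dim=256, num_quantizers=4)
x = torch.randn(1, 256, 16, 16)
quantized, loss, indices = rvq(x)  # (1, 256, 16, 16), scalar, (1, 4, 16, 16)

# product quantization: channels are split into num_quantizers groups,
# each with its own codebook (e_dim must be divisible by num_quantizers)
pvq = ProductVectorQuantizer(n_e=1024, e_dim=256, num_quantizers=4)
quantized, loss, indices = pvq(x)  # (1, 256, 16, 16), scalar, (1, 4, 16, 16)
```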
--------------------------------------------------------------------------------
/product_vector_quantize.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from vector_quantize import VectorQuantizer


class ProductVectorQuantizer(nn.Module):
    """Product quantization: the channel dimension is split into `num_quantizers`
    groups and each group is quantized with its own codebook.
    """

    def __init__(self, n_e, e_dim, num_quantizers):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.num_quantizers = num_quantizers
        assert self.e_dim % self.num_quantizers == 0
        self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim // num_quantizers) for _ in range(num_quantizers)])

    def get_codebook(self):
        all_codebook = []
        for quantizer in self.vq_layers:
            codebook = quantizer.get_codebook()
            all_codebook.append(codebook)
        return torch.stack(all_codebook)

    def forward(self, z):
        all_z_q = []
        all_losses = []
        all_min_encoding_indices = []

        # quantize each channel group with its own codebook
        z_chunk = torch.chunk(z, self.num_quantizers, dim=1)
        for idx, quantizer in enumerate(self.vq_layers):
            z_q, loss, min_encoding_indices = quantizer(z_chunk[idx])

            all_z_q.append(z_q)
            all_losses.append(loss)
            all_min_encoding_indices.append(min_encoding_indices)

        all_z_q = torch.cat(all_z_q, dim=1)
        mean_losses = torch.stack(all_losses).mean()
        all_min_encoding_indices = torch.stack(all_min_encoding_indices, dim=1)

        return all_z_q, mean_losses, all_min_encoding_indices
--------------------------------------------------------------------------------
/residual_vector_quantize.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn

from vector_quantize import VectorQuantizer


class ResidualVectorQuantizer(nn.Module):
    """ References:
        SoundStream: An End-to-End Neural Audio Codec
        https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, n_e, e_dim, num_quantizers):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.num_quantizers = num_quantizers
        self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim) for _ in range(num_quantizers)])

    def get_codebook(self):
        all_codebook = []
        for quantizer in self.vq_layers:
            codebook = quantizer.get_codebook()
            all_codebook.append(codebook)
        return torch.stack(all_codebook)

    def forward(self, z):
        all_losses = []
        all_min_encoding_indices = []

        # each stage quantizes the residual left over by the previous stages
        z_q = 0
        residual = z
        for quantizer in self.vq_layers:
            z_res, loss, min_encoding_indices = quantizer(residual)
            residual = residual - z_res
            z_q = z_q + z_res

            all_losses.append(loss)
            all_min_encoding_indices.append(min_encoding_indices)

        mean_losses = torch.stack(all_losses).mean()
        all_min_encoding_indices = torch.stack(all_min_encoding_indices, dim=1)

        return z_q, mean_losses, all_min_encoding_indices
--------------------------------------------------------------------------------
/vector_quantize.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class VectorQuantizer(nn.Module):
    """
    Reference:
        Taming Transformers for High-Resolution Image Synthesis
        https://arxiv.org/pdf/2012.09841.pdf
    """

    def __init__(self, n_e, e_dim, beta=1.0):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

    def get_codebook(self):
        return self.embedding.weight

    def get_codebook_entry(self, indices, shape=None):
        # get quantized latent vectors
        z_q = self.embedding(indices)
        if shape is not None:
            # shape specifies (batch, height, width, channel)
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1, keepdim=True).t() - \
            2 * torch.matmul(z_flattened, self.embedding.weight.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)

        # commitment loss (weighted by beta) + codebook loss
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        min_encoding_indices = min_encoding_indices.view(z.shape[:-1])

        return z_q, loss, min_encoding_indices
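

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; shapes are illustrative):
    # quantize a random feature map and rebuild it from the returned indices.
    vq = VectorQuantizer(n_e=1024, e_dim=256, beta=1.0)

    x = torch.randn(1, 256, 16, 16)   # (N, C, H, W) with C == e_dim
    quantized, loss, indices = vq(x)  # (1, 256, 16, 16), scalar, (1, 16, 16)

    # get_codebook_entry maps indices back to code vectors; `shape` is (N, H, W, C)
    rebuilt = vq.get_codebook_entry(indices, shape=(1, 16, 16, 256))
    assert torch.allclose(rebuilt, quantized, atol=1e-5)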
--------------------------------------------------------------------------------
/vector_quantize_gumbel_softmax.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import einsum


class GumbelQuantize(nn.Module):
    """
    Reference:
        Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016
        https://arxiv.org/abs/1611.01144
    """

    def __init__(self, hidden_channel, n_e, e_dim, kl_weight=1.0,
                 temp_init=1.0, straight_through=True):
        super().__init__()

        self.e_dim = e_dim
        self.n_e = n_e

        self.straight_through = straight_through
        self.temperature = temp_init
        self.kl_weight = kl_weight

        self.proj = nn.Conv2d(hidden_channel, n_e, kernel_size=1)
        self.embedding = nn.Embedding(n_e, e_dim)

    def get_codebook(self):
        return self.embedding.weight

    def get_codebook_entry(self, indices, shape=None):
        # get quantized latent vectors
        z_q = self.embedding(indices)
        if shape is not None:
            # shape specifies (batch, height, width, channel)
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q

    def forward(self, z, temp=None):
        # use hard (one-hot) samples according to straight_through during training,
        # always hard at inference time
        hard = self.straight_through if self.training else True
        temp = self.temperature if temp is None else temp

        logits = self.proj(z)
        soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard)
        min_encoding_indices = soft_one_hot.argmax(dim=1)

        z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embedding.weight)

        # KL divergence between the code distribution and a uniform prior over codes
        code_prob = F.softmax(logits, dim=1)
        loss = self.kl_weight * torch.sum(code_prob * torch.log(code_prob * self.n_e + 1e-10), dim=1).mean()

        return z_q, loss, min_encoding_indices
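

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; channel sizes are
    # illustrative). `hidden_channel` is the width of the encoder features that
    # the 1x1 projection maps onto the n_e code logits.
    gq = GumbelQuantize(hidden_channel=64, n_e=1024, e_dim=256)
    gq.train()                               # training: relaxed / straight-through samples

    z = torch.randn(1, 64, 16, 16)           # (N, hidden_channel, H, W)
    z_q, kl_loss, indices = gq(z, temp=1.0)  # (1, 256, 16, 16), scalar, (1, 16, 16)

    gq.eval()                                # inference: hard one-hot code selection
    z_q_eval, _, _ = gq(z)
    print(z_q.shape, kl_loss.item(), indices.shape, z_q_eval.shape)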
--------------------------------------------------------------------------------
/vector_quantize_moving_average.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def moving_average(moving_avg, new, decay):
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def laplace_smoothing(x, n_categories, eps=1e-5):
    return (x + eps) / (x.sum() + n_categories * eps)


class MovingAverageVectorQuantizer(nn.Module):
    """
    Reference:
        https://github.com/deepmind/sonnet
    """

    def __init__(self, n_e, e_dim, decay=0.99, beta=1.0, eps=1e-5):
        super().__init__()

        self.n_e = n_e
        self.e_dim = e_dim
        self.decay = decay
        self.beta = beta
        self.eps = eps

        # the codebook is a buffer updated by exponential moving average, not by gradients
        embedding = torch.randn(n_e, e_dim)
        self.register_buffer('embedding', embedding)
        self.register_buffer('embedding_avg', embedding.clone())
        self.register_buffer('cluster_size', torch.zeros(n_e))

    def get_codebook(self):
        return self.embedding

    def get_codebook_entry(self, indices, shape=None):
        # get quantized latent vectors
        z_q = F.embedding(indices, self.embedding)
        if shape is not None:
            # shape specifies (batch, height, width, channel)
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q

    def forward(self, z):
        # reshape z -> (batch, height, width, channel) and flatten
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.view(-1, self.e_dim)

        # distances from z to embeddings e: (z - e)^2 = z^2 + e^2 - 2 e * z
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding ** 2, dim=1, keepdim=True).t() - \
            2 * torch.matmul(z_flattened, self.embedding.t())

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = F.embedding(min_encoding_indices, self.embedding).view(z.shape)

        # update codebook embedding by moving average
        if self.training:
            embedding_onehot = F.one_hot(min_encoding_indices, self.n_e).type(z_flattened.dtype)
            embedding_sum = embedding_onehot.t() @ z_flattened
            # TODO: all-reduce embedding_onehot and embedding_sum across gpus
            moving_average(self.cluster_size, embedding_onehot.sum(0), self.decay)
            moving_average(self.embedding_avg, embedding_sum, self.decay)
            n = self.cluster_size.sum()
            cluster_size = laplace_smoothing(self.cluster_size, self.n_e, self.eps) * n
            embedding_normalized = self.embedding_avg / cluster_size.unsqueeze(1)
            self.embedding.data.copy_(embedding_normalized)

        # commitment loss (weighted by beta) + codebook term; the codebook itself
        # receives no gradient since it is refreshed by the EMA update above
        loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
            torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through estimator)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = z_q.permute(0, 3, 1, 2).contiguous()

        min_encoding_indices = min_encoding_indices.view(z.shape[:-1])

        return z_q, loss, min_encoding_indices
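

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module; sizes are illustrative):
    # in training mode a forward pass refreshes the codebook buffer in place via EMA.
    mvq = MovingAverageVectorQuantizer(n_e=1024, e_dim=256, decay=0.99)
    mvq.train()

    before = mvq.get_codebook().clone()
    x = torch.randn(8, 256, 16, 16)  # (N, C, H, W) with C == e_dim
    z_q, loss, indices = mvq(x)      # (8, 256, 16, 16), scalar, (8, 16, 16)

    # the codebook moved without any optimizer step
    print((mvq.get_codebook() - before).abs().max().item())
--------------------------------------------------------------------------------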