├── LICENSE ├── README.md └── SRVQ.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yi Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Spherical residual vector quantization (SRVQ) 2 | 3 | This repository contains a Pytorch-based minimalist implementation of the spherical residual vector quantization (SRVQ) module used in our [Gull neural audio codec](https://arxiv.org/abs/2404.04947) framework. Find the demo page [here](https://yluo42.github.io/Gull/). 4 | 5 | SRVQ is a modification to the standard RVQ to better quantize unit-norm inputs. The general idea is to use unit-norm codebooks with standard VQ-VAE selection and update scheme at the first hierarchy (R=1), while use **rotation matrices** defined by Householder transformations (treat learnable reflection matrices as learnable rotation matrices) for other hierarchies (R>1). 6 | 7 | # Reference 8 | If you use SRVQ in your project, please consider citing the following paper: 9 | 10 | > @article{luo2024gull, 11 | > title={Gull: A Generative Multifunctional Audio Codec}, 12 | > author={Luo, Yi and Yu, Jianwei and Chen, Hangting and Gu Rongzhi and Weng, Chao}, 13 | > journal={arXiv preprint arXiv:2404.04947}, 14 | > year={2024} 15 | > } 16 | -------------------------------------------------------------------------------- /SRVQ.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # modified from https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py 6 | class SVQ(nn.Module): 7 | """ 8 | Spherical VQ via standard EMA-based VQ-VAE scheme. 9 | """ 10 | def __init__(self, num_code, code_dim, decay=0.99, stale_tolerance=100): 11 | super().__init__() 12 | 13 | self.num_code = num_code 14 | self.code_dim = code_dim 15 | self.decay = decay 16 | self.stale_tolerance = stale_tolerance 17 | self.eps = torch.finfo(torch.float32).eps 18 | 19 | # unit-norm codebooks 20 | embedding = torch.empty(num_code, code_dim).normal_() 21 | embedding = embedding / (embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) 22 | self.register_buffer("embedding", embedding) 23 | self.register_buffer("ema_weight", self.embedding.clone()) 24 | self.register_buffer("ema_count", torch.zeros(self.num_code)) 25 | self.register_buffer("stale_counter", torch.zeros(self.num_code)) 26 | 27 | def forward(self, input): 28 | 29 | B, N, T = input.shape 30 | assert N == self.code_dim 31 | 32 | input_detach = input.detach().mT.reshape(B*T, self.code_dim) # B*T, dim 33 | 34 | # distance 35 | eu_dis = 2 - 2 * input_detach.mm(self.embedding.T) # B*T, num_code 36 | 37 | # best codes 38 | indices = torch.argmin(eu_dis, dim=-1) # B*T 39 | quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim)) # B*T, dim 40 | quantized = quantized.reshape(B, T, N).mT # B, N, T 41 | 42 | if self.training: 43 | # EMA update for codebook 44 | encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code 45 | self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) # num_code 46 | 47 | update_direction = encodings.T.mm(input_detach) # num_code, dim 48 | self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction # num_code, dim 49 | 50 | # Laplace smoothing on the counters 51 | # make sure the denominator will never be zero 52 | n = torch.sum(self.ema_count) 53 | self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n # num_code 54 | 55 | self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) 56 | 57 | # calculate code usage 58 | stale_codes = (encodings.sum(0) == 0).float() # num_code 59 | self.stale_counter = self.stale_counter * stale_codes + stale_codes 60 | 61 | # random replace codes that haven't been used for a while 62 | replace_code = (self.stale_counter == self.stale_tolerance).float() # num_code 63 | if replace_code.max() > 0: 64 | random_input_idx = torch.randperm(input_detach.shape[0]) 65 | random_input = input_detach[random_input_idx].reshape(input_detach.shape) 66 | if random_input.shape[0] < self.num_code: 67 | random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) 68 | random_input = random_input[:self.num_code] # num_code, dim 69 | 70 | self.embedding = self.embedding * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) 71 | self.ema_weight = self.ema_weight * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) 72 | self.ema_count = self.ema_count * (1 - replace_code) 73 | self.stale_counter = self.stale_counter * (1 - replace_code) 74 | 75 | # unit-norm codebooks 76 | self.embedding = self.embedding / (self.embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) 77 | self.ema_weight = self.ema_weight / (self.ema_weight.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) 78 | 79 | return quantized 80 | 81 | class RotVQ(nn.Module): 82 | """ 83 | Rotary VQ via Householder transform. 84 | """ 85 | def __init__(self, num_code, code_dim): 86 | super().__init__() 87 | 88 | self.num_code = num_code 89 | self.code_dim = code_dim 90 | self.eps = torch.finfo(torch.float32).eps 91 | 92 | self.rot_emb = nn.Parameter(torch.randn(num_code, code_dim)) 93 | 94 | def forward(self, prev_input, target): 95 | 96 | B, N, T = prev_input.shape 97 | assert N == self.code_dim 98 | 99 | prev_input = prev_input.mT.reshape(B*T, N) # B*T, dim 100 | target = target.mT.reshape(B*T, N) # B*T, dim 101 | 102 | # rotation matrices 103 | # a more efficient implementation without explicitly calculating the rotation matrices 104 | rot_emb = self.rot_emb / (self.rot_emb.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) 105 | # always contain an identity rotation matrix 106 | rot_emb = torch.cat([rot_emb[:1] * 0., rot_emb[1:]], 0) 107 | eu_dis = 2 - 2 * (target * prev_input).sum(-1).unsqueeze(-1) # B*T, 1 108 | eu_dis = eu_dis + 4 * target.mm(rot_emb.T) * prev_input.mm(rot_emb.T) # B*T, num_code 109 | 110 | # best codes 111 | indices = torch.argmin(eu_dis, dim=-1) # B*T 112 | encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code 113 | rot_emb = encodings.mm(rot_emb) # B*T, dim 114 | quantized = prev_input - 2 * (prev_input * rot_emb).sum(-1).unsqueeze(-1) * rot_emb 115 | quantized = quantized.reshape(B, T, N).mT # B, N, T 116 | 117 | return quantized 118 | 119 | class Quantizer(nn.Module): 120 | def __init__(self, code_dim, decay=0.99, stale_tolerance=100, bit=[12]): 121 | super().__init__() 122 | 123 | self.code_dim = code_dim 124 | self.eps = torch.finfo(torch.float32).eps 125 | 126 | self.RVQ = nn.ModuleList([]) 127 | for i in range(len(bit)): 128 | if i == 0: 129 | self.RVQ.append(SVQ(2**bit[i], code_dim, decay, stale_tolerance)) 130 | else: 131 | self.RVQ.append(RotVQ(2**bit[i], code_dim)) 132 | 133 | def forward(self, input): 134 | 135 | quantized = [] 136 | for i in range(len(self.RVQ)): 137 | if i == 0: 138 | this_quantized = self.RVQ[i](input) 139 | # straight-through estimator 140 | this_quantized = (this_quantized - input).detach() + input 141 | else: 142 | this_quantized = self.RVQ[i](quantized[-1], input) 143 | quantized.append(this_quantized) 144 | 145 | latent_loss = [] 146 | for i in range(len(self.RVQ)): 147 | if i == 0: 148 | latent_loss.append(F.mse_loss(input, quantized[i].detach())) 149 | else: 150 | latent_loss.append((F.mse_loss(input, quantized[i].detach()) + F.mse_loss(input.detach(), quantized[i])) / 2.) 151 | 152 | quantized = torch.stack(quantized, -1) 153 | latent_loss = torch.stack(latent_loss, -1) 154 | 155 | return quantized, latent_loss 156 | 157 | if __name__ == '__main__': 158 | Q = Quantizer(code_dim=64, bit=[10]*5) 159 | input = torch.rand(2, 64, 100) 160 | input = input / input.pow(2).sum(1).sqrt().unsqueeze(1) 161 | 162 | quantized, latent_loss = Q(input) # no need to apply straight-through estimator between quantized and input again 163 | print(quantized.shape) 164 | print(latent_loss) # non-increasing 165 | --------------------------------------------------------------------------------