├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── vqtorch-logo.pdf
│   └── vqtorch-logo.png
├── examples
│   ├── autoencoder.py
│   ├── classification.py
│   ├── experimental_group_affine.py
│   ├── experimental_inplace_update.py
│   └── test.py
├── setup.py
└── vqtorch
    ├── __init__.py
    ├── dists.py
    ├── math_fns.py
    ├── nn
    │   ├── __init__.py
    │   ├── affine.py
    │   ├── gvq.py
    │   ├── pool.py
    │   ├── rvq.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   ├── init.py
    │   │   └── replace.py
    │   ├── vq.py
    │   └── vq_base.py
    ├── norms.py
    └── utils.py


/.gitignore:
--------------------------------------------------------------------------------
*.pyc
**__pycache__/**
**.egg**
**.DS_Store**
**.swp**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 minyoung huh (jacob)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
---

VQTorch is a PyTorch library for vector quantization.

The library was developed for, and used in:
- `[1] Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks, Huh et al. ICML2023`

## Installation
Development was done on Ubuntu with Python 3.9/3.10 using NVIDIA GPUs. Some requirements may need to be adjusted to run on other setups.
Some features, such as half-precision cdist and CUDA-based k-means, are only supported on CUDA devices.

First, install the version of [cupy](https://github.com/cupy/cupy/) that matches your CUDA version, i.e. the `CUDA Version` number reported by `nvidia-smi`. `cupy` now appears to support ROCm drivers, but this has not been tested.
```bash
# recent 12.x cuda versions
pip install cupy-cuda12x

# 11.x versions (for even older versions, see the repo above)
pip install cupy-cuda11x
```

Next, install `vqtorch`:
```bash
git clone https://github.com/minyoungg/vqtorch
cd vqtorch
pip install -e .
```

## Example usage
For classification and autoencoder examples using `VectorQuant`, see [here](./examples/).
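For intuition, the core operation of a VQ layer is a nearest-neighbour lookup into a codebook, combined with a straight-through estimator so that gradients reach the encoder. The snippet below is a minimal plain-PyTorch sketch of that idea only; it is not vqtorch's implementation. `quantize_ste` and `commit_weight` are hypothetical names, and the loss follows the common VQ-VAE formulation (codebook term plus weighted commitment term), which may differ from how vqtorch's `beta` trade-off is defined. It runs on CPU.

```python
import torch
import torch.nn.functional as F


def quantize_ste(z_e: torch.Tensor, codebook: torch.Tensor, commit_weight: float = 0.25):
    """Hypothetical helper: nearest-neighbour VQ with a straight-through estimator.

    z_e:      [..., D] input vectors, quantized along the last dimension
    codebook: [K, D] code vectors
    Returns (quantized tensor, code indices, VQ-VAE style loss).
    """
    flat = z_e.reshape(-1, z_e.shape[-1])       # [N, D]
    dists = torch.cdist(flat, codebook)         # [N, K] euclidean distances
    q = dists.argmin(dim=-1)                    # [N] index of the nearest code
    z_q = codebook[q].reshape(z_e.shape)        # [..., D] quantized vectors

    # codebook term moves codes toward the (detached) inputs;
    # commitment term moves inputs toward the (detached) codes
    loss = F.mse_loss(z_q, z_e.detach()) + commit_weight * F.mse_loss(z_e, z_q.detach())

    # straight-through: forward uses z_q, backward copies gradients onto z_e
    z_q = z_e + (z_q - z_e).detach()
    return z_q, q.reshape(z_e.shape[:-1]), loss


# usage on random data
z_e = torch.randn(8, 4, 4, 32, requires_grad=True)
codebook = torch.randn(256, 32, requires_grad=True)
z_q, q, loss = quantize_ste(z_e, codebook)
(z_q.sum() + loss).backward()
print(z_q.shape, q.shape)  # torch.Size([8, 4, 4, 32]) torch.Size([8, 4, 4])
```

The `.detach()` trick is what makes the argmin lookup trainable end to end: the forward value equals the selected code, while the gradient of `z_q` with respect to `z_e` is the identity.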

```python
import torch
from vqtorch.nn import VectorQuant

print('Testing VectorQuant')
# create VQ layer
vq_layer = VectorQuant(
    feature_size=32,     # feature dimension corresponding to the vectors
    num_codes=1024,      # number of codebook vectors
    beta=0.98,           # (default: 0.9) commitment trade-off
    kmeans_init=True,    # (default: False) whether to use kmeans++ init
    norm=None,           # (default: None) normalization for the input vectors
    cb_norm=None,        # (default: None) normalization for codebook vectors
    affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
    sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
    replace_freq=20,     # (default: None) frequency to replace dead codes
    dim=-1,              # (default: -1) dimension to be quantized
).cuda()

# when `kmeans_init=True`, it is recommended to warm up the codebook before training
with torch.no_grad():
    z_e = torch.randn(128, 8, 8, 32).cuda()
    vq_layer(z_e)

# standard forward pass
z_e = torch.randn(128, 8, 8, 32).cuda()
z_q, vq_dict = vq_layer(z_e)

print(vq_dict.keys())
# dict_keys(['z', 'z_q', 'd', 'q', 'loss', 'perplexity'])
```

## Supported features
- `vqtorch.nn.GroupVectorQuant` - Vectors are quantized by first partitioning them into `n` subvectors.
- `vqtorch.nn.ResidualVectorQuant` - Vectors are first quantized, and the residuals are then repeatedly quantized.
- `vqtorch.nn.MaxVecPool2d` - Pools along the vector dimension by selecting the vector with the maximum norm.
- `vqtorch.nn.SoftMaxVecPool2d` - Pools along the vector dimension by a weighted average, with weights computed by a softmax over the norms.
- `vqtorch.no_vq` - Disables all vector quantization layers that inherit from `vqtorch.nn._VQBaseLayer`
```python
model = VQN(...)
with vqtorch.no_vq():
    out = model(x)
```

## Experimental features
- Group affine parameterization: divides the codebook into groups, and each group is reparameterized with its own affine parameters. Enable it via:
```python
vq_layer = VectorQuant(..., affine_groups=8)
```
- In-place alternating optimization: the codebook is updated in place during the forward pass by a separate optimizer (`inplace_optimizer`).
```python
inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=50.0, momentum=0.9)
vq_layer = VectorQuant(..., inplace_optimizer=inplace_optimizer)
```

## Planned features
We aim to incorporate commonly used VQ methods, including probabilistic VQ variants.


## Citations
If features such as `affine parameterization`, `synchronized commitment loss`, or `alternating optimization` were useful, please consider citing:

```bibtex
@inproceedings{huh2023improvedvqste,
  title={Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
  author={Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
  booktitle={International Conference on Machine Learning},
  year={2023},
  organization={PMLR}
}
```

If you found the library useful, please consider citing:
```bibtex
@misc{huh2023vqtorch,
  author = {Huh, Minyoung},
  title = {vqtorch: {P}y{T}orch Package for Vector Quantization},
  year = {2022},
  howpublished = {\url{https://github.com/minyoungg/vqtorch}},
}
```

--------------------------------------------------------------------------------
/assets/vqtorch-logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungg/vqtorch/02e60a19bd742c17b0bf3e1925f23796d54cbeac/assets/vqtorch-logo.pdf
--------------------------------------------------------------------------------
/assets/vqtorch-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungg/vqtorch/02e60a19bd742c17b0bf3e1925f23796d54cbeac/assets/vqtorch-logo.png
--------------------------------------------------------------------------------
/examples/autoencoder.py:
--------------------------------------------------------------------------------
# mnist VQ experiment with various settings.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets

from vqtorch.nn import VectorQuant
from tqdm.auto import trange
from torch.utils.data import DataLoader
from torchvision import transforms


lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234

class SimpleVQAutoEncoder(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            VectorQuant(32, **vq_kwargs),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.GELU(),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
        ])
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuant):
                x, vq_dict = layer(x)
            else:
                x = layer(x)
        return x.clamp(-1, 1), vq_dict


def train(model, train_loader, train_iterations=1000, alpha=10):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.cuda(), y.cuda()

    # build the infinite batch stream once instead of re-creating it every step
    data_stream = iterate_dataset(train_loader)
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, _ = next(data_stream)
        out, vq_out = model(x)
        rec_loss = (out - x).abs().mean()
        cmt_loss = vq_out['loss']
        (rec_loss + alpha * cmt_loss).backward()

        opt.step()
        pbar.set_description(f'rec loss: {rec_loss.item():.3f} | ' + \
                             f'cmt loss: {cmt_loss.item():.3f} | ' + \
                             f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
    return


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)

print('baseline')
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(num_codes=num_codes).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ kmeans init')
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ synchronized update')
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True, sync_nu=2.0).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ affine parameterization')
torch.random.manual_seed(seed)
model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True, sync_nu=2.0, affine_lr=2.0).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
--------------------------------------------------------------------------------
/examples/classification.py:
--------------------------------------------------------------------------------
# mnist VQ experiment with various settings.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets

from vqtorch.nn import VectorQuant
from tqdm.auto import trange
from torch.utils.data import DataLoader
from torchvision import transforms


lr = 3e-4
train_iter = 1000
num_codes = 256
seed = 1234

class SimpleVQClassifier(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            VectorQuant(32, **vq_kwargs),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        ])
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuant):
                x, vq_dict = layer(x)
            else:
                x = layer(x)
        return x, vq_dict


def train(model, train_loader, train_iterations=1000, alpha=10):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.cuda(), y.cuda()

    criterion = nn.CrossEntropyLoss()

    # build the infinite batch stream once instead of re-creating it every step
    data_stream = iterate_dataset(train_loader)
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, y = next(data_stream)
        out, vq_out = model(x)
        sce_loss = criterion(out, y)
        cmt_loss = vq_out['loss']
        acc = (out.argmax(dim=1) == y).float().mean()
        (sce_loss + alpha * cmt_loss).backward()

        opt.step()
        pbar.set_description(f'sce loss: {sce_loss.item():.3f} | ' + \
                             f'cmt loss: {cmt_loss.item():.3f} | ' + \
                             f'acc: {acc.item() * 100:.1f} | ' + \
                             f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
    return


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)


print('baseline')
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ kmeans init')
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ synchronized update')
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, sync_nu=1.0).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)


print('+ affine parameterization')
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, sync_nu=1.0, affine_lr=10).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter)
--------------------------------------------------------------------------------
/examples/experimental_group_affine.py:
--------------------------------------------------------------------------------
import torch
from vqtorch.nn import VectorQuant, GroupVectorQuant, ResidualVectorQuant


print('Testing VectorQuant')
# create VQ layer
vq_layer = VectorQuant(
    feature_size=32,     # feature dimension corresponding to the vectors
    num_codes=1024,      # number of codebook vectors
    beta=0.98,           # (default: 0.9) commitment trade-off
    kmeans_init=True,    # (default: False) whether to use kmeans++ init
    norm=None,           # (default: None) normalization for input vector
    cb_norm=None,        # (default: None) normalization for codebook vectors
    affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
    affine_groups=8,     # *** NEW *** (default: 1) number of affine parameter groups
    sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
    replace_freq=20,     # (default: None) frequency to replace dead codes
    dim=-1,              # (default: -1) dimension to be quantized
).cuda()

# when using `kmeans_init`, we can warm up the codebook
with torch.no_grad():
    z_e = torch.randn(128, 8, 8, 32).cuda()
    vq_layer(z_e)

# standard forward pass
z_e = torch.randn(128, 8, 8, 32).cuda()
z_q, vq_dict = vq_layer(z_e)  # same call as the warm-up above, now with autograd enabled
assert z_e.shape == z_q.shape
err = ((z_e - z_q) ** 2).mean().item()
print(f'>>> quantization error: {err:.3f}')
--------------------------------------------------------------------------------
/examples/experimental_inplace_update.py:
--------------------------------------------------------------------------------
# mnist VQ experiment with various settings.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets

from vqtorch.nn import VectorQuant
from tqdm.auto import trange
from torch.utils.data import DataLoader
from torchvision import transforms


lr = 3e-4
train_iter = 300
num_codes = 256
seed = 1234

class SimpleVQClassifier(nn.Module):
    def __init__(self, **vq_kwargs):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.GELU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),
            VectorQuant(32, **vq_kwargs),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.AdaptiveMaxPool2d((1, 1)),
            nn.Flatten(),
            nn.Linear(64, 10),
        ])
        return

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, VectorQuant):
                x, vq_dict = layer(x)
            else:
                x = layer(x)
        return x, vq_dict


def train(model, train_loader, train_iterations=1000, alpha=10, ignore_commitment_loss=False):
    def iterate_dataset(data_loader):
        data_iter = iter(data_loader)
        while True:
            try:
                x, y = next(data_iter)
            except StopIteration:
                data_iter = iter(data_loader)
                x, y = next(data_iter)
            yield x.cuda(), y.cuda()

    criterion = nn.CrossEntropyLoss()

    # build the infinite batch stream once instead of re-creating it every step
    data_stream = iterate_dataset(train_loader)
    for _ in (pbar := trange(train_iterations)):
        opt.zero_grad()
        x, y = next(data_stream)
        out, vq_out = model(x)
        sce_loss = criterion(out, y)
        cmt_loss = vq_out['loss']

        if ignore_commitment_loss:
            sce_loss.backward()
        else:
            (sce_loss + alpha * cmt_loss).backward()

        acc = (out.argmax(dim=1) == y).float().mean()

        opt.step()
        pbar.set_description(f'sce loss: {sce_loss.item():.3f} | ' + \
                             f'cmt loss: {cmt_loss.item():.3f} | ' + \
                             f'acc: {acc.item() * 100:.1f} | ' + \
                             f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
    return


transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)


print('baseline + kmeans init')
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter, alpha=50)


print('+ inplace alt update')
inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=50.0, momentum=0.9)
torch.random.manual_seed(seed)
model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, beta=1.0, inplace_optimizer=inplace_optimizer).cuda()
opt = torch.optim.AdamW(model.parameters(), lr=lr)
train(model, train_dataset, train_iterations=train_iter, ignore_commitment_loss=True)
--------------------------------------------------------------------------------
/examples/test.py:
--------------------------------------------------------------------------------
import torch
from vqtorch.nn import VectorQuant, GroupVectorQuant, ResidualVectorQuant


print('Testing VectorQuant')
# create VQ layer
vq_layer = VectorQuant(
    feature_size=32,     # feature dimension corresponding to the vectors
    num_codes=1024,      # number of codebook vectors
    beta=0.98,           # (default: 0.9) commitment trade-off
    kmeans_init=True,    # (default: False) whether to use kmeans++ init
    norm=None,           # (default: None) normalization for input vector
    cb_norm=None,        # (default: None) normalization for codebook vectors
    affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
    sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
    replace_freq=20,     # (default: None) frequency to replace dead codes
    dim=-1,              # (default: -1) dimension to be quantized
).cuda()

# when using `kmeans_init`, we can warm up the codebook
with torch.no_grad():
    z_e = torch.randn(128, 8, 8, 32).cuda()
    vq_layer(z_e)

# standard forward pass
z_e = torch.randn(128, 8, 8, 32).cuda()
z_q, vq_dict = vq_layer(z_e)  # same call as the warm-up above, now with autograd enabled
assert z_e.shape == z_q.shape
err = ((z_e - z_q) ** 2).mean().item()
print(f'>>> quantization error: {err:.3f}')


print('Testing GroupVectorQuant')
# create VQ layer
vq_layer = GroupVectorQuant(
    feature_size=32,
    num_codes=1024,
    beta=0.98,
    kmeans_init=True,
    norm=None,
    cb_norm=None,
    affine_lr=10.0,
    sync_nu=0.2,
    replace_freq=20,
    dim=-1,
    groups=4,            # (default: 1) number of groups to divide the feature dimension
    share=False,         # (default: True) when True, same codebook is used for each group
).cuda()

# when using `kmeans_init`, we can warm up the codebook
with torch.no_grad():
    z_e = torch.randn(128, 8, 8, 32).cuda()
    vq_layer(z_e)

# standard forward pass
z_e = torch.randn(128, 8, 8, 32).cuda()
z_q, vq_dict = vq_layer(z_e)  # same call as the warm-up above, now with autograd enabled
assert z_e.shape == z_q.shape
err = ((z_e - z_q) ** 2).mean().item()
print(f'>>> quantization error: {err:.3f}')


print('Testing ResidualVectorQuant')
# create VQ layer
vq_layer = ResidualVectorQuant(
    feature_size=32,
    num_codes=1024,
    beta=0.98,
    kmeans_init=True,
    norm=None,
    cb_norm=None,
    affine_lr=10.0,
    sync_nu=0.2,
    replace_freq=20,
    dim=-1,
    groups=4,            # (default: 1) number of groups to divide the feature dimension
    share=True,          # (default: True) when True, same codebook is used for each group
).cuda()

# when using `kmeans_init`, we can warm up the codebook
with torch.no_grad():
    z_e = torch.randn(128, 8, 8, 32).cuda()
    vq_layer(z_e)

# standard forward pass
z_e = torch.randn(128, 8, 8, 32).cuda()
z_q, vq_dict = vq_layer(z_e)  # same call as the warm-up above, now with autograd enabled
assert z_e.shape == z_q.shape
err = ((z_e - z_q) ** 2).mean().item()
print(f'>>> quantization error: {err:.3f}')
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

setuptools.setup(
    name="vqtorch",
    packages=setuptools.find_packages(),
    version="0.1.0",
    author="Minyoung Huh",
    author_email="minhuh@mit.edu",
    description="vector-quantization for pytorch",
    url="https://github.com/minyoungg/vqtorch",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=[
        "torch>=1.13.0",
        "string-color==1.2.3",
        "torchpq==0.3.0.1",
    ],
    python_requires='>=3.6',  # developed on 3.9 / 3.10
)
--------------------------------------------------------------------------------
/vqtorch/__init__.py:
--------------------------------------------------------------------------------
from vqtorch.nn import *
from vqtorch.utils import *
from vqtorch.math_fns import *
--------------------------------------------------------------------------------
/vqtorch/dists.py:
--------------------------------------------------------------------------------
import math
import time
import warnings
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F


def check_shape(tensor, codebook):
    if len(tensor.shape) != 3:
        raise RuntimeError(f'expected 3d tensor but found {tensor.size()}')

    if tensor.size(2) != codebook.size(1):
        raise RuntimeError(
            f'expected tensor and codebook to have the same feature ' + \
            f'dimensions but found: {tensor.size()} vs {codebook.size()}'
        )
    return


def get_dist_fns(dist):
    if dist in ['euc', 'euclidean']:
        loss_fn = euclidean_distance
        dist_fn = euclidean_cdist_topk
    elif dist in ['cos', 'cosine']:
        loss_fn = cosine_distance
        dist_fn = cosine_cdist_topk
    else:
        raise ValueError(f'unknown distance method: {dist}')
    return loss_fn, dist_fn


def cosine_distance(z, z_q):
    """
    Computes the element-wise cosine distance of z and z_q by applying
    `euclidean_distance` to the L2-normalized vectors.

    NOTE: this is not a true cosine distance. For unit vectors the squared
    euclidean distance equals 2 * (1 - cosine similarity), so the result is
    proportional to the mean cosine distance.
    """

    z = F.normalize(z, p=2, dim=-1)
    z_q = F.normalize(z_q, p=2, dim=-1)
    return euclidean_distance(z, z_q)


def euclidean_distance(z, z_q):
    """
    Computes the element-wise euclidean distance of z and z_q.

    NOTE: uses spatial averaging and no square root is applied, hence this is
    not a true euclidean distance, but it makes no difference in practice.
    """
    if z.size() != z_q.size():
        raise RuntimeError(
            f'expected z and z_q to have the same shape but got ' + \
            f'{z.size()} vs {z_q.size()}'
        )

    z, z_q = z.reshape(z.size(0), -1), z_q.reshape(z_q.size(0), -1)
    return ((z_q - z) ** 2).mean(1)  # .sqrt()


def euclidean_cdist_topk(tensor, codebook, compute_chunk_size=1024, topk=1,
                         half_precision=False):
    """
    Compute the euclidean distance between tensor and every element in the
    codebook.

    Args:
        tensor (Tensor): a 3D tensor of shape [batch x HWG x feats].
        codebook (Tensor): a 2D tensor of shape [num_codes x feats].
        compute_chunk_size (int): the chunk size to use when computing cdist.
        topk (int): stores `topk` distance minimizers. If -1, topk is the
            same length as the codebook.
        half_precision (bool): if True, matrix multiplication is computed
            using half-precision to save memory.
    Returns:
        a dict with keys:
            'd' (Tensor): distance matrix of shape [batch x HWG x topk], the
                distances from each vector to its `topk` nearest codes.
            'q' (Tensor): code matrix of the same shape as `d`, the indices of
                the corresponding `topk` codes.

    NOTE: the chunking only splits `tensor`, since the codebook size generally
    does not vary much. Future versions should consider choosing the chunk
    size based on the codebook and feature dimension sizes as well.
    """
    check_shape(tensor, codebook)

    b, n, c = tensor.shape
    tensor_dtype = tensor.dtype
    tensor = tensor.reshape(-1, tensor.size(-1))
    tensor = tensor.split(compute_chunk_size)
    dq = []

    if topk == -1:
        topk = codebook.size(0)

    for tc in tensor:
        cb = codebook

        if half_precision:
            tc = tc.half()
            cb = cb.half()

        d = torch.cdist(tc, cb)
        dq.append(torch.topk(d, k=topk, largest=False, dim=-1))

    d, q = torch.cat([_dq[0] for _dq in dq]), torch.cat([_dq[1] for _dq in dq])

    return_dict = {'d': d.to(tensor_dtype).reshape(b, n, -1),
                   'q': q.long().reshape(b, n, -1)}
    return return_dict


def cosine_cdist_topk(tensor, codebook, compute_chunk_size=1024, topk=1,
                      half_precision=False):
    """ Computes cosine distance instead. See `euclidean_cdist_topk`. """
    check_shape(tensor, codebook)

    tensor = F.normalize(tensor, p=2, dim=-1)
    codebook = F.normalize(codebook, p=2, dim=-1)

    dq = euclidean_cdist_topk(tensor, codebook, compute_chunk_size, topk,
                              half_precision)

    # for unit vectors, 0.5 * ||a - b||^2 = 1 - cos(a, b)
    dq['d'] = 0.5 * (dq['d'] ** 2)
    return dq
--------------------------------------------------------------------------------
/vqtorch/math_fns.py:
--------------------------------------------------------------------------------


def entropy(x, dim=-1, eps=1e-8, keepdim=False):
    assert x.min() >= 0., \
        f'function takes non-negative values but found x.min(): {x.min()}'
    is_tensor = True

    if len(x.shape) == 1:
        is_tensor = False
        x = x.unsqueeze(0)

    x = x.moveaxis(dim, -1)
    x_shape = x.shape
    # reshape (not view): moveaxis may return a non-contiguous tensor
    x = x.reshape(-1, x.size(-1)) + eps
    p = x / x.sum(dim=1, keepdim=True)
    h = - (p * p.log()).sum(dim=1, keepdim=True)
    # move the reduced axis back to its original position
    h = h.reshape(*x_shape[:-1], 1).moveaxis(-1, dim)

    if not keepdim:
        h = h.squeeze(dim)

    if not is_tensor:
        h = h.squeeze(0)
    return h
--------------------------------------------------------------------------------
/vqtorch/nn/__init__.py:
--------------------------------------------------------------------------------
from .vq_base import _VQBaseLayer

from .vq import VectorQuant
from .gvq import GroupVectorQuant
from .rvq import ResidualVectorQuant

from .affine import AffineTransform
from . import utils
--------------------------------------------------------------------------------
/vqtorch/nn/affine.py:
--------------------------------------------------------------------------------
import math
import time
import warnings
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F



class AffineTransform(nn.Module):
    def __init__(
            self,
            feature_size,
            use_running_statistics=False,
            momentum=0.1,
            lr_scale=1,
            num_groups=1,
            ):
        super().__init__()

        self.use_running_statistics = use_running_statistics
        self.num_groups = num_groups

        if use_running_statistics:
            self.momentum = momentum
            self.register_buffer('running_statistics_initialized', torch.zeros(1))
            self.register_buffer('running_ze_mean', torch.zeros(num_groups, feature_size))
            self.register_buffer('running_ze_var', torch.ones(num_groups, feature_size))

            self.register_buffer('running_c_mean', torch.zeros(num_groups, feature_size))
            self.register_buffer('running_c_var', torch.ones(num_groups, feature_size))
        else:
            self.scale = nn.parameter.Parameter(torch.zeros(num_groups, feature_size))
            self.bias = nn.parameter.Parameter(torch.zeros(num_groups, feature_size))
            self.lr_scale = lr_scale
        return

    @torch.no_grad()
    def update_running_statistics(self, z_e, c):
        # We find it helpful to slightly under-estimate the z_e embedding
        # statistics. Empirically, we observe a slight over-estimation of the
        # statistics, which causes the straight-through estimate to grow
        # indefinitely. While this is not an issue for most model
        # architectures, architectures without normalized bottlenecks can
        # eventually explode.
48 | # placing the VQ layer in certain layers of ViT exhibits this behavior 49 | 50 | 51 | if self.training and self.use_running_statistics: 52 | unbiased = False 53 | 54 | ze_mean = z_e.mean([0, 1]).unsqueeze(0) 55 | ze_var = z_e.var([0, 1], unbiased=unbiased).unsqueeze(0) 56 | 57 | c_mean = c.mean([0]).unsqueeze(0) 58 | c_var = c.var([0], unbiased=unbiased).unsqueeze(0) 59 | 60 | if not self.running_statistics_initialized: 61 | self.running_ze_mean.data.copy_(ze_mean) 62 | self.running_ze_var.data.copy_(ze_var) 63 | self.running_c_mean.data.copy_(c_mean) 64 | self.running_c_var.data.copy_(c_var) 65 | self.running_statistics_initialized.fill_(1) 66 | else: 67 | self.running_ze_mean = (self.momentum * ze_mean) + (1 - self.momentum) * self.running_ze_mean 68 | self.running_ze_var = (self.momentum * ze_var) + (1 - self.momentum) * self.running_ze_var 69 | self.running_c_mean = (self.momentum * c_mean) + (1 - self.momentum) * self.running_c_mean 70 | self.running_c_var = (self.momentum * c_var) + (1 - self.momentum) * self.running_c_var 71 | 72 | # wd = 0.9998 # 0.995 73 | # self.running_ze_mean = wd * self.running_ze_mean 74 | # self.running_ze_var = wd * self.running_ze_var 75 | return 76 | 77 | 78 | def forward(self, codebook): 79 | scale, bias = self.get_affine_params() 80 | n, c = codebook.shape 81 | codebook = codebook.view(self.num_groups, -1, codebook.shape[-1]) 82 | codebook = scale * codebook + bias 83 | return codebook.reshape(n, c) 84 | 85 | 86 | def get_affine_params(self): 87 | if self.use_running_statistics: 88 | scale = (self.running_ze_var / (self.running_c_var + 1e-8)).sqrt() 89 | bias = - scale * self.running_c_mean + self.running_ze_mean 90 | else: 91 | scale = (1. 
+ self.lr_scale * self.scale) 92 | bias = self.lr_scale * self.bias 93 | return scale.unsqueeze(1), bias.unsqueeze(1) 94 | -------------------------------------------------------------------------------- /vqtorch/nn/gvq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from stringcolor import cs 5 | from vqtorch.norms import with_codebook_normalization 6 | from .vq import VectorQuant 7 | 8 | 9 | class GroupVectorQuant(VectorQuant): 10 | """ 11 | Group vector quantization layer. 12 | 13 | Args: 14 | groups (int): Number of groups for vector quantization. The vectors are split 15 | into `groups` equal chunks. When groups=1, it is the same as VectorQuant. 16 | share (bool): If True, a single codebook is shared by all sub-vectors. 17 | *rest*: see VectorQuant() 18 | """ 19 | 20 | def __init__( 21 | self, 22 | feature_size : int, 23 | num_codes : int, 24 | groups : int = 1, 25 | share : bool = True, 26 | **kwargs, 27 | ): 28 | 29 | if feature_size % groups != 0: 30 | e_msg = f'feature_size {feature_size} must be divisible by groups {groups}.'
31 | raise RuntimeError(str(cs(e_msg, 'red'))) 32 | 33 | num_codebooks = 1 if share else groups 34 | in_dim = self.group_size = num_codes // num_codebooks 35 | out_dim = feature_size // groups 36 | 37 | super().__init__(feature_size, num_codes, code_vector_size=out_dim, **kwargs) 38 | 39 | self.groups = groups 40 | self.share = share 41 | self.codebook = nn.Embedding(in_dim * num_codebooks, out_dim) 42 | return 43 | 44 | 45 | def get_codebook_by_group(self, group): 46 | cb = self.codebook.weight 47 | offset = 0 if self.share else group * self.group_size 48 | return cb[offset : offset + self.group_size], offset 49 | 50 | 51 | @with_codebook_normalization 52 | def forward(self, z): 53 | 54 | ###### 55 | ## (1) formatting data by groups and invariant to dim 56 | ###### 57 | 58 | z = self.prepare_inputs(z, self.groups) 59 | 60 | if not self.enabled: 61 | z = self.to_original_format(z) 62 | return z, {} 63 | 64 | ###### 65 | ## (2) quantize latent vector 66 | ###### 67 | 68 | z_q = torch.zeros_like(z) 69 | d = torch.zeros(z_q.shape[:-1]).to(z_q.device) 70 | q = torch.zeros(z_q.shape[:-1], dtype=torch.long).to(z_q.device) 71 | 72 | for i in range(self.groups): 73 | # select group 74 | _z = z[..., i:i+1, :] 75 | 76 | # quantize 77 | cb, offset = self.get_codebook_by_group(i) 78 | _z_q, _d, _q = self.quantize(cb, _z) 79 | 80 | # assign to tensor 81 | z_q[..., i:i+1, :] = _z_q 82 | d[..., i:i+1] = _d 83 | q[..., i:i+1] = _q + offset 84 | 85 | to_return = { 86 | 'z' : z, # each group input z_e 87 | 'z_q' : z_q, # quantized output z_q 88 | 'd' : d, # distance function for each group 89 | 'q' : q, # codes using offsetted indices 90 | 'loss': self.compute_loss(z, z_q), 91 | 'perplexity': None, 92 | } 93 | 94 | z_q = self.straight_through_approximation(z, z_q) 95 | z_q = self.to_original_format(z_q) 96 | return z_q, to_return 97 | -------------------------------------------------------------------------------- /vqtorch/nn/pool.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class _VecPool2d(nn.Module): 7 | def __init__(self, weighting_fn, kernel_size=3, stride=2, padding=2, dilation=1): 8 | super().__init__() 9 | 10 | self.k = kernel_size 11 | self.s = stride 12 | self.p = padding 13 | self.d = dilation 14 | 15 | self.weighting_fn = weighting_fn 16 | 17 | self.unfold = nn.Unfold(kernel_size=self.k, dilation=self.d, padding=self.p, stride=self.s) 18 | return 19 | 20 | def forward(self, x): 21 | b, c, h, w = x.shape 22 | out_h = (h - self.k + 2 * self.p) // self.s + 1 23 | out_w = (w - self.k + 2 * self.p) // self.s + 1 24 | 25 | with torch.no_grad(): 26 | n = x.norm(dim=1, p=2, keepdim=True) 27 | n = self.unfold(n) # B x (K**2) x N 28 | n = self.weighting_fn(n, dim=1).unsqueeze(1) 29 | 30 | x = self.unfold(x) # B x C * (K**2) x N 31 | x = x.view(b, c, -1, x.size(-1)) 32 | 33 | x = n * x 34 | x = x.sum(2).view(b, c, out_h, out_w) 35 | return x 36 | 37 | 38 | class MaxVecPool2d(_VecPool2d): 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(MaxVecPool2d.max_onehot, *args, **kwargs) 41 | 42 | @staticmethod 43 | def max_onehot(x, dim): 44 | b, s, n = x.shape 45 | 46 | x = x.argmax(dim=1) 47 | x = x.view(-1) 48 | x = F.one_hot(x, s).view(b, n, s).swapaxes(-1, dim) 49 | return x 50 | 51 | 52 | class SoftMaxVecPool2d(_VecPool2d): 53 | def __init__(self, *args, **kwargs): 54 | super().__init__(torch.softmax, *args, **kwargs) 55 | -------------------------------------------------------------------------------- /vqtorch/nn/rvq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from stringcolor import cs 5 | from vqtorch.norms import with_codebook_normalization 6 | from .vq import VectorQuant 7 | 8 | 9 | class ResidualVectorQuant(VectorQuant): 10 | """ 11 | Args 12 | groups (int): 
Number of residual quantization steps to apply. When groups=1, 13 | the layer is equivalent to VectorQuant. 14 | share (bool): If True, the codebook is shared across every quantization step. 15 | *rest*: see VectorQuant() 16 | 17 | NOTE: Don't use L2 normalization on the codebook. ResidualVQ is norm-sensitive. 18 | For norm invariance, consider using the cosine distance variant. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | feature_size : int, 24 | num_codes : int, 25 | groups : int = 1, 26 | share : bool = True, 27 | **kwargs, 28 | ): 29 | 30 | if not share and not feature_size % groups == 0: 31 | e_msg = f'feature_size {feature_size} must be divisible by residual groups {groups}.' 32 | raise RuntimeError(str(cs(e_msg, 'red'))) 33 | 34 | self.groups = groups 35 | self.share = share 36 | 37 | num_codebooks = 1 if share else groups 38 | in_dim = self.group_size = num_codes // num_codebooks 39 | out_dim = feature_size 40 | 41 | super().__init__(feature_size, num_codes, code_vector_size=out_dim, **kwargs) 42 | 43 | self.groups = groups 44 | self.share = share 45 | self.codebook = nn.Embedding(in_dim * num_codebooks, out_dim) 46 | 47 | return 48 | 49 | 50 | def get_codebook_by_group(self, group): 51 | cb = self.codebook.weight 52 | offset = 0 if self.share else group * self.group_size 53 | return cb[offset : offset + self.group_size], offset 54 | 55 | 56 | @with_codebook_normalization 57 | def forward(self, z): 58 | 59 | ###### 60 | ## (1) formatting data by groups and invariant to dim 61 | ###### 62 | 63 | z = self.prepare_inputs(z, groups=1) 64 | 65 | if not self.enabled: 66 | z = self.to_original_format(z) 67 | return z, {} 68 | 69 | ###### 70 | ## (2) quantize latent vector 71 | ###### 72 | 73 | z_q = torch.zeros_like(z) 74 | z_res = torch.zeros(*z.shape[:-2], self.groups + 1, z.shape[-1]).to(z.device) 75 | 76 | d = torch.zeros(*z_q.shape[:-2], self.groups).to(z_q.device) 77 | q = torch.zeros(*z_q.shape[:-2], self.groups, dtype=torch.long).to(z_q.device) 78 | 79 | for i in
range(self.groups): 80 | # compute residual 81 | _z = (z - z_q) # part of z not yet explained by z_q 82 | z_res[..., i:i+1, :] = _z 83 | 84 | # quantize 85 | cb, offset = self.get_codebook_by_group(i) 86 | _z_q, _d, _q = self.quantize(cb, _z) 87 | 88 | # update estimate 89 | z_q = z_q + _z_q 90 | 91 | # assign to tensor 92 | d[..., i:i+1] = _d 93 | q[..., i:i+1] = _q + offset 94 | 95 | z_res[..., -1:, :] = z - z_q 96 | 97 | to_return = { 98 | 'z' : z, # input z_e 99 | 'z_q' : z_q, # quantized output z_q 100 | 'd' : d, # distances for each residual step 101 | 'q' : q, # codes using offset indices 102 | 'z_res': z_res, 103 | 'loss' : self.compute_loss(z, z_q), 104 | 'perplexity': None, 105 | } 106 | 107 | z_q = self.straight_through_approximation(z, z_q) 108 | z_q = self.to_original_format(z_q) 109 | return z_q, to_return 110 | -------------------------------------------------------------------------------- /vqtorch/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .replace import lru_replacement 2 | -------------------------------------------------------------------------------- /vqtorch/nn/utils/init.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from stringcolor import cs 3 | import warnings 4 | import vqtorch 5 | 6 | 7 | 8 | @torch.no_grad() 9 | def data_dependent_init_forward_hook(self, inputs, outputs, use_kmeans=True, verbose=False): 10 | """ initializes codebook from data """ 11 | 12 | if (not self.training) or (self.data_initialized.item() == 1): 13 | return 14 | 15 | if verbose: 16 | print(cs('initializing codebook from data', 'y')) 17 | 18 | def sample_centroids(z_e, num_codes): 19 | """ samples num_codes new code vectors from z_e (k-means++ or random subsampling). """ 20 | 21 | z_e = z_e.reshape(-1, z_e.size(-1)) 22 | 23 | if num_codes >= z_e.size(0): 24 | e_msg = f'\ncodebook size >= warmup samples: {num_codes} vs {z_e.size(0)}.
' + \ 25 | 'recommended to decrease the codebook size or increase batch size.' 26 | 27 | warnings.warn(str(cs(e_msg, 'yellow'))) 28 | 29 | # repeat (rounding up) until it fits and add noise 30 | repeat = (num_codes + z_e.shape[0] - 1) // z_e.shape[0] 31 | new_codes = z_e.data.tile([repeat, 1])[:num_codes] 32 | new_codes += 1e-3 * torch.randn_like(new_codes.data) 33 | 34 | else: 35 | # more warmup samples than codes: cluster or subsample the data 36 | if use_kmeans: 37 | from torchpq.clustering import KMeans 38 | kmeans = KMeans(n_clusters=num_codes, distance='euclidean', init_mode="kmeans++") 39 | kmeans.fit(z_e.data.T.contiguous()) 40 | new_codes = kmeans.centroids.T 41 | else: 42 | indices = torch.randint(low=0, high=z_e.size(0), size=(num_codes,)) 43 | indices = indices.to(z_e.device) 44 | new_codes = torch.index_select(z_e, 0, indices).to(z_e.device).data 45 | 46 | return new_codes 47 | 48 | _, misc = outputs 49 | z_e, z_q = misc['z'], misc['z_q'] 50 | 51 | if type(self) is vqtorch.nn.VectorQuant: 52 | num_codes = self.codebook.weight.shape[0] 53 | new_codebook = sample_centroids(z_e, num_codes) 54 | self.codebook.weight.data = new_codebook 55 | 56 | elif type(self) is vqtorch.nn.GroupVectorQuant: 57 | if self.share: 58 | 59 | new_codebook = sample_centroids(z_e, self.group_size) 60 | self.codebook.weight.data = new_codebook 61 | else: 62 | for i in range(self.groups): 63 | offset = i * self.group_size 64 | new_codebook = sample_centroids(z_e[..., i, :], self.group_size) 65 | self.codebook.weight.data[offset:offset+self.group_size] = new_codebook 66 | 67 | elif type(self) is vqtorch.nn.ResidualVectorQuant: 68 | z_e = misc['z_res'] 69 | 70 | if self.share: 71 | new_codebook = sample_centroids(z_e, self.group_size) 72 | self.codebook.weight.data = new_codebook 73 | else: 74 | for i in range(self.groups): 75 | offset = i * self.group_size 76 | new_codebook = sample_centroids(z_e[..., i, :], self.group_size) 77 | self.codebook.weight.data[offset:offset+self.group_size] =
new_codebook 78 | 79 | 80 | self.data_initialized.fill_(1) 81 | return 82 | -------------------------------------------------------------------------------- /vqtorch/nn/utils/replace.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | 5 | class ReplaceLRU(): 6 | """ 7 | Attributes: 8 | rho (float): scale of the noise added to replacement codes, relative to their norm 9 | timeout (int): number of training steps a code may go unused before it is replaced 10 | """ 11 | VALID_POLICIES = ['input_random', 'input_kmeans', 'self'] 12 | 13 | def __init__(self, rho=1e-4, timeout=100): 14 | assert timeout > 1 15 | assert rho >= 0.0 16 | self.rho = rho 17 | self.timeout = timeout 18 | 19 | self.policy = 'input_random' 20 | # self.policy = 'input_kmeans' 21 | # self.policy = 'self' 22 | self.tau = 2.0 23 | 24 | assert self.policy in self.VALID_POLICIES 25 | return 26 | 27 | @staticmethod 28 | def apply(module, rho=0., timeout=100): 29 | """ register forward hook """ 30 | fn = ReplaceLRU(rho, timeout) 31 | device = next(module.parameters()).device 32 | module.register_forward_hook(fn) 33 | module.register_buffer('_counts', timeout * torch.ones(module.num_codes)) 34 | module._counts = module._counts.to(device) 35 | return fn 36 | 37 | def __call__(self, module, inputs, outputs): 38 | """ 39 | This function is triggered after every forward pass 40 | recall: z_q, misc = vq(x) 41 | 42 | Args: 43 | module (nn.VectorQuant) 44 | inputs (tuple): A tuple with 1 element 45 | x (Tensor) 46 | outputs (tuple): A tuple with 2 elements 47 | z_q (Tensor), misc (dict) 48 | """ 49 | if not module.training: 50 | return 51 | 52 | # count down every code by 1 and, if used, reset its timer to the timeout value 53 | module._counts -= 1 54 | 55 | # --- computes most recent codebook usage --- # 56 | unique, counts = torch.unique(outputs[1]['q'], return_counts=True) 57 | module._counts.index_fill_(0, unique, self.timeout) 58 | 59 | # --- find how many need to be replaced --- # 60 | # num_active = self.check_and_replace_dead_codes(module, outputs) 61 |
inactive_indices = torch.argwhere(module._counts == 0).squeeze(-1) 62 | num_inactive = inactive_indices.size(0) 63 | 64 | if num_inactive > 0: 65 | 66 | if self.policy == 'self': 67 | # exponential weighting makes frequently used codes even more likely to be sampled 68 | p = torch.zeros_like(module._counts) 69 | p[unique] = counts.float() 70 | p = p / p.sum() 71 | p = torch.exp(self.tau * p) - 1 # subtracting 1 keeps unused codes (p=0) at probability 0 72 | 73 | selected_indices = torch.multinomial(p, num_inactive, replacement=True) 74 | selected_values = module.codebook.weight.data[selected_indices].clone() 75 | 76 | elif self.policy == 'input_random': 77 | z_e = outputs[1]['z'].flatten(0, -2) # flatten to 2D 78 | z_e = z_e[torch.randperm(z_e.size(0))] # shuffle 79 | mult = num_inactive // z_e.size(0) + 1 80 | if mult > 1: # if there are not enough inputs 81 | z_e = torch.cat(mult * [z_e]) 82 | selected_values = z_e[:num_inactive] 83 | 84 | elif self.policy == 'input_kmeans': 85 | # can be extremely slow 86 | from torchpq.clustering import KMeans 87 | z_e = outputs[1]['z'].flatten(0, -2) # flatten to 2D 88 | z_e = z_e[torch.randperm(z_e.size(0))] # shuffle 89 | kmeans = KMeans(n_clusters=num_inactive, distance='euclidean', init_mode="kmeans++") 90 | kmeans.fit(z_e.data.T.contiguous()) 91 | selected_values = kmeans.centroids.T 92 | 93 | if self.rho > 0: 94 | norm = selected_values.norm(p=2, dim=-1, keepdim=True) 95 | noise = torch.randn_like(selected_values) 96 | selected_values = selected_values + self.rho * norm * noise 97 | 98 | # --- update dead codes with new codes --- # 99 | module.codebook.weight.data[inactive_indices] = selected_values 100 | module._counts[inactive_indices] += self.timeout 101 | 102 | return outputs 103 | 104 | 105 | 106 | def lru_replacement(vq_module, rho=1e-4, timeout=100): 107 | """ 108 | Example:: 109 | >>> vq = VectorQuant(...)
>>> vq = lru_replacement(vq) 111 | """ 112 | ReplaceLRU.apply(vq_module, rho, timeout) 113 | return vq_module 114 | -------------------------------------------------------------------------------- /vqtorch/nn/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from vqtorch.dists import get_dist_fns 6 | import vqtorch 7 | from vqtorch.norms import with_codebook_normalization 8 | from .vq_base import _VQBaseLayer 9 | from .affine import AffineTransform 10 | 11 | 12 | class VectorQuant(_VQBaseLayer): 13 | """ 14 | Vector quantization layer using straight-through estimation. 15 | 16 | Args: 17 | feature_size (int): feature dimension corresponding to the vectors 18 | num_codes (int): number of vectors in the codebook 19 | beta (float): weights the codebook loss by beta and the commitment loss by (1 - beta) 20 | sync_nu (float): sync loss weighting 21 | affine_lr (float): learning-rate scale for the affine transform (disabled when 0) 22 | affine_groups (int): number of affine parameter groups 23 | replace_freq (int): if > 0, replace codes left unused for this many training steps 24 | inplace_optimizer (Optimizer): optimizer constructor for in-place codebook updates (requires beta=1.0) 25 | **kwargs: additional arguments for _VQBaseLayer 26 | 27 | Returns: 28 | Quantized vector z_q and return dict 29 | """ 30 | 31 | 32 | def __init__( 33 | self, 34 | feature_size : int, 35 | num_codes : int, 36 | beta : float = 0.95, 37 | sync_nu : float = 0.0, 38 | affine_lr: float = 0.0, 39 | affine_groups: int = 1, 40 | replace_freq: int = 0, 41 | inplace_optimizer: torch.optim.Optimizer = None, 42 | **kwargs, 43 | ): 44 | 45 | super().__init__(feature_size, num_codes, **kwargs) 46 | self.loss_fn, self.dist_fn = get_dist_fns('euclidean') 47 | 48 | if beta < 0.0 or beta > 1.0: 49 | raise ValueError(f'beta must be in [0, 1] but got {beta}') 50 | 51 | self.beta = beta 52 | self.nu = sync_nu 53 | self.affine_lr = affine_lr 54 | self.codebook = nn.Embedding(self.num_codes, self.feature_size) 55 | 56 |
| if inplace_optimizer is not None: 57 | if beta != 1.0: 58 | raise ValueError('inplace_optimizer can only be used with beta=1.0') 59 | self.inplace_codebook_optimizer = inplace_optimizer(self.codebook.parameters()) 60 | 61 | if affine_lr > 0: 62 | # defaults to using learnable affine parameters 63 | self.affine_transform = AffineTransform( 64 | self.code_vector_size, 65 | use_running_statistics=False, 66 | lr_scale=affine_lr, 67 | num_groups=affine_groups, 68 | ) 69 | if replace_freq > 0: 70 | vqtorch.nn.utils.lru_replacement(self, rho=0.01, timeout=replace_freq) 71 | return 72 | 73 | 74 | def straight_through_approximation(self, z, z_q): 75 | """ passes the gradient of z_q straight through to z """ 76 | if self.nu > 0: 77 | z_q = z + (z_q - z).detach() + (self.nu * z_q) + (-self.nu * z_q).detach() 78 | else: 79 | z_q = z + (z_q - z).detach() 80 | return z_q 81 | 82 | 83 | def compute_loss(self, z_e, z_q): 84 | """ computes loss between z and z_q """ 85 | return ((1.0 - self.beta) * self.loss_fn(z_e, z_q.detach()) + \ 86 | (self.beta) * self.loss_fn(z_e.detach(), z_q)) 87 | 88 | 89 | def quantize(self, codebook, z): 90 | """ 91 | Quantizes the latent codes z with the codebook 92 | 93 | Args: 94 | codebook (Tensor): N x F (N code vectors) 95 | z (Tensor): B x ...
x F 96 | """ 97 | 98 | # reshape to (B x HWG x F//G) and compute distance 99 | z_shape = z.shape[:-1] 100 | z_flat = z.view(z.size(0), -1, z.size(-1)) 101 | 102 | if hasattr(self, 'affine_transform'): 103 | self.affine_transform.update_running_statistics(z_flat, codebook) 104 | codebook = self.affine_transform(codebook) 105 | 106 | with torch.no_grad(): 107 | dist_out = self.dist_fn( 108 | tensor=z_flat, 109 | codebook=codebook, 110 | topk=self.topk, 111 | compute_chunk_size=self.cdist_chunk_size, 112 | half_precision=(z.is_cuda), 113 | ) 114 | 115 | d = dist_out['d'].view(z_shape) 116 | q = dist_out['q'].view(z_shape).long() 117 | 118 | z_q = F.embedding(q, codebook) 119 | 120 | if self.training and hasattr(self, 'inplace_codebook_optimizer'): 121 | # update codebook in place 122 | ((z_q - z.detach()) ** 2).mean().backward() 123 | self.inplace_codebook_optimizer.step() 124 | self.inplace_codebook_optimizer.zero_grad() 125 | 126 | # forward pass again with the updated codebook 127 | z_q = F.embedding(q, codebook) 128 | 129 | # NOTE to save compute, we assume q did not change.
130 | 131 | return z_q, d, q 132 | 133 | @torch.no_grad() 134 | def get_codebook(self): 135 | cb = self.codebook.weight 136 | if hasattr(self, 'affine_transform'): 137 | cb = self.affine_transform(cb) 138 | return cb 139 | 140 | def get_codebook_affine_params(self): 141 | if hasattr(self, 'affine_transform'): 142 | return self.affine_transform.get_affine_params() 143 | return None 144 | 145 | @with_codebook_normalization 146 | def forward(self, z): 147 | 148 | ###### 149 | ## (1) formatting data by groups and invariant to dim 150 | ###### 151 | 152 | z = self.prepare_inputs(z, self.groups) 153 | 154 | if not self.enabled: 155 | z = self.to_original_format(z) 156 | return z, {} 157 | 158 | ###### 159 | ## (2) quantize latent vector 160 | ###### 161 | 162 | z_q, d, q = self.quantize(self.codebook.weight, z) 163 | 164 | # e_mean = F.one_hot(q, num_classes=self.num_codes).view(-1, self.num_codes).float().mean(0) 165 | # perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) 166 | perplexity = None 167 | 168 | to_return = { 169 | 'z' : z, # each group input z_e 170 | 'z_q': z_q, # quantized output z_q 171 | 'd' : d, # distances for each group 172 | 'q' : q, # codes 173 | 'loss': self.compute_loss(z, z_q).mean(), 174 | 'perplexity': perplexity, 175 | } 176 | 177 | z_q = self.straight_through_approximation(z, z_q) 178 | z_q = self.to_original_format(z_q) 179 | 180 | return z_q, to_return 181 | -------------------------------------------------------------------------------- /vqtorch/nn/vq_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from vqtorch.norms import get_norm 5 | from vqtorch.nn.utils.init import data_dependent_init_forward_hook 6 | 7 | 8 | 9 | class _VQBaseLayer(nn.Module): 10 | """ 11 | Base template code for vector quantization. All VQ layers inherit 12 | from this class.
13 | 14 | Args: 15 | feature_size (int): 16 | The size of the feature. This is the length of each 17 | code vector; it must match the input feature size. 18 | num_codes (int): 19 | Number of codes to use in the codebook. 20 | dim (int): Dimension to quantize. By default quantization happens on 21 | the channel dimension. For example, given an image tensor 22 | (B x C x H x W) and dim=1, the channels are treated as features 23 | and the resulting codes `q` have the shape (B x H x W). 24 | For transformers (B x N x C), you should set dim=2 or -1. 25 | norm (str): Feature normalization. 26 | cb_norm (str): Codebook normalization. 27 | 28 | Returns: 29 | Quantized vector z_q and return dict 30 | 31 | Attributes: 32 | cdist_chunk_size (int): chunk size for divide-and-conquer topk cdist. 33 | enabled (bool): If False, the model is not quantized and acts as an identity layer. 34 | """ 35 | 36 | cdist_chunk_size = 1024 37 | enabled = True 38 | 39 | def __init__( 40 | self, 41 | feature_size : int, 42 | num_codes : int, 43 | dim : int = 1, 44 | norm : str = 'none', 45 | cb_norm : str = 'none', 46 | kmeans_init : bool = False, 47 | code_vector_size : int = None, 48 | ): 49 | 50 | super().__init__() 51 | self.feature_size = feature_size 52 | self.code_vector_size = feature_size if code_vector_size is None else code_vector_size 53 | self.num_codes = num_codes 54 | self.dim = dim 55 | 56 | self.groups = 1 # for group VQ 57 | self.topk = 1 # for probabilistic VQ 58 | 59 | self.norm = norm 60 | self.codebook_norm = cb_norm 61 | self.norm_layer, self.norm_before_grouping = get_norm(norm, feature_size) 62 | 63 | if kmeans_init: 64 | self.register_buffer('data_initialized', torch.zeros(1)) 65 | self.register_forward_hook(data_dependent_init_forward_hook) 66 | return 67 | 68 | def quantize(self, codebook, z): 69 | """ 70 | Quantizes the latent codes z with the codebook 71 | 72 | Args: 73 | codebook (Tensor): N x C (N code vectors) 74 | z (Tensor): B x ...
x C 75 | """ 76 | raise NotImplementedError 77 | 78 | 79 | def compute_loss(self, z_e, z_q): 80 | """ computes error between z and z_q """ 81 | raise NotImplementedError 82 | 83 | 84 | def to_canonical_group_format(self, z, groups): 85 | """ 86 | Converts data into canonical group format 87 | 88 | The quantization dim is sent to the last dimension. 89 | The features are then reshaped such that C -> G x C' 90 | 91 | Args: 92 | z (Tensor): a tensor of shape [B x C x ... ] (assuming dim=1) 93 | groups (int): number of groups 94 | Returns: 95 | z of shape [B x ... x G x C'] 96 | """ 97 | 98 | z = z.moveaxis(self.dim, -1).contiguous() 99 | z = z.unflatten(-1, (groups, -1)) 100 | return z 101 | 102 | 103 | def to_original_format(self, x): 104 | """ 105 | Merges the groups and moves the feature dimension back to `dim` 106 | 107 | Args: 108 | x (Tensor): a tensor in group form [B x ... x G x C // G] 109 | Returns: 110 | merged `x` of shape [B x C x ...] (assuming dim=1) 111 | """ 112 | return x.flatten(-2, -1).moveaxis(-1, self.dim) 113 | 114 | 115 | def prepare_inputs(self, z, groups): 116 | """ 117 | Prepare input with normalization and group format 118 | 119 | Args: 120 | z (Tensor): a tensor of shape [B x C x ...
] 121 | groups (int): number of groups 122 | """ 123 | 124 | if len(z.shape) <= 1: 125 | e_msg = f'expected a tensor of at least 2 dimensions but found {z.size()}' 126 | raise ValueError(e_msg) 127 | 128 | if self.norm_before_grouping: 129 | z = self.norm_layer(z) 130 | 131 | z = self.to_canonical_group_format(z, groups) 132 | 133 | if not self.norm_before_grouping: 134 | z = self.norm_layer(z) 135 | 136 | return z 137 | 138 | 139 | @property 140 | def requires_grad(self): 141 | return self.codebook.weight.requires_grad 142 | 143 | 144 | def set_requires_grad(self, requires_grad): 145 | # the codebook is a single nn.Embedding 146 | self.codebook.weight.requires_grad = requires_grad 147 | return 148 | 149 | 150 | def extra_repr(self): 151 | repr = "\n".join([ 152 | f"num_codes: {self.num_codes}", 153 | f"groups: {self.groups}", 154 | f"enabled: {self.enabled}", 155 | ]) 156 | return repr 157 | -------------------------------------------------------------------------------- /vqtorch/norms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | MAXNORM_CONSTRAINT_VALUE = 10 7 | 8 | 9 | class Normalize(nn.Module): 10 | """ 11 | Simple vector normalization module. By default, vectors are normalized 12 | along the channel dimension: each vector associated with a spatial 13 | location is normalized. Used with the cosine-distance VQ layer.
14 | """ 15 | def __init__(self, p=2, dim=1, eps=1e-6): 16 | super().__init__() 17 | self.p = p 18 | self.dim = dim 19 | self.eps = eps 20 | return 21 | 22 | def forward(self, x): 23 | return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps) 24 | 25 | 26 | def max_norm(w, p=2, dim=-1, max_norm=MAXNORM_CONSTRAINT_VALUE, eps=1e-8): 27 | norm = w.norm(p=p, dim=dim, keepdim=True) 28 | desired = torch.clamp(norm.data, max=max_norm) 29 | # desired = torch.clamp(norm, max=max_norm) 30 | return w * (desired / (norm + eps)) 31 | 32 | 33 | class MaxNormConstraint(nn.Module): 34 | def __init__(self, max_norm=1, p=2, dim=-1, eps=1e-8): 35 | super().__init__() 36 | self.p = p 37 | self.eps = eps 38 | self.dim = dim 39 | self.max_norm = max_norm 40 | 41 | def forward(self, x): 42 | return max_norm(x, self.p, self.dim, max_norm=self.max_norm) 43 | 44 | 45 | @torch.no_grad() 46 | def with_codebook_normalization(func): 47 | def wrapper(*args): 48 | self = args[0] 49 | for n, m in self.named_modules(): 50 | if isinstance(m, nn.Embedding): 51 | if self.codebook_norm == 'l2': 52 | m.weight.data = max_norm(m.weight.data, p=2, dim=1, eps=1e-8) 53 | elif self.codebook_norm == 'l2c': 54 | m.weight.data = F.normalize(m.weight.data, p=2, dim=1, eps=1e-8) 55 | return func(*args) 56 | return wrapper 57 | 58 | 59 | def get_norm(norm, num_channels=None): 60 | before_grouping = True 61 | if norm == 'l2': 62 | norm_layer = Normalize(p=2, dim=-1) 63 | before_grouping = False 64 | elif norm == 'l2c': 65 | norm_layer = MaxNormConstraint(p=2, dim=-1, max_norm=MAXNORM_CONSTRAINT_VALUE) 66 | before_grouping = False 67 | elif norm == 'bn': 68 | norm_layer = nn.BatchNorm2d(num_channels) 69 | elif norm == 'gn': 70 | norm_layer = GroupNorm(num_channels) 71 | elif norm in ['none', None]: 72 | norm_layer = nn.Identity() 73 | elif norm == 'in': 74 | norm_layer = nn.InstanceNorm2d(num_channels) 75 | else: 76 | raise ValueError(f'unknown norm {norm}') 77 | return norm_layer, before_grouping 78 | 79 | 80 | 
def GroupNorm(in_channels): 81 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 82 | 83 | 84 | def match_norm(x, y, dim=-1, eps=1e-8): 85 | """ 86 | matches vector norm of x to that of y 87 | Args: 88 | x (Tensor): a tensor of any shape 89 | y (Tensor): a tensor of the same shape as `x`. 90 | dim (int): dimension to match the norm over 91 | eps (float): epsilon to mitigate division by zero. 92 | Returns: 93 | `x` with the same norm as `y` across `dim` 94 | """ 95 | assert x.shape == y.shape, \ 96 | f'expected `x` and `y` to have the same shape but found {x.shape} vs {y.shape}' 97 | 98 | # move chosen dim to last dim 99 | x = x.moveaxis(dim, -1).contiguous() 100 | y = y.moveaxis(dim, -1).contiguous() 101 | x_shape = x.shape 102 | 103 | # unravel everything such that [GBHW x C] 104 | 105 | 106 | x = x.view(-1, x.size(-1)) 107 | y = y.view(-1, y.size(-1)) 108 | 109 | # compute norm on C 110 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True) 111 | y_norm = torch.norm(y, p=2, dim=1, keepdim=True) 112 | 113 | # clamp x_norm to avoid division by 0 114 | x_norm = torch.clamp(x_norm, min=eps) 115 | 116 | # normalize (x now has same norm as y) 117 | x = y_norm * (x / x_norm) 118 | x = x.view(x_shape) 119 | x = x.moveaxis(-1, dim).contiguous() 120 | return x 121 | -------------------------------------------------------------------------------- /vqtorch/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from vqtorch.nn import _VQBaseLayer 5 | 6 | 7 | def is_vq(m): 8 | """ checks if the module is a VQ layer """ 9 | return issubclass(type(m), _VQBaseLayer) 10 | 11 | def is_vqn(net): 12 | """ checks if the network contains a VQ layer """ 13 | return np.any([is_vq(layer) for layer in net.modules()]) 14 | 15 | def get_vq_layers(model): 16 | """ returns vq layers from a network """ 17 | return [m for m in model.modules() if is_vq(m)] 18 | 19 |
20 | class no_vq(): 21 | """ 22 | Context manager that turns off VQ by setting all VQ layers in the given module(s) to identity functions. VQ is disabled when the object is constructed and re-enabled on exit. 23 | 24 | Examples:: 25 | >>> with vqtorch.no_vq(model): 26 | ... out = model(x) 27 | """ 28 | 29 | def __init__(self, modules): 30 | if type(modules) is not list: 31 | modules = [modules] 32 | 33 | for module in modules: 34 | if isinstance(module, torch.nn.DataParallel): 35 | module = module.module 36 | 37 | assert isinstance(module, torch.nn.Module), \ 38 | f'expected input to be nn.Module or a list of nn.Module ' + \ 39 | f'but found {type(module)}' 40 | 41 | self.enable_vq(module, enable=False) 42 | 43 | self.modules = modules 44 | return 45 | 46 | def __enter__(self): 47 | pass 48 | 49 | def __exit__(self, exception_type, exception_value, traceback): 50 | for module in self.modules: 51 | if isinstance(module, torch.nn.DataParallel): 52 | module = module.module 53 | self.enable_vq(module, enable=True) 54 | return 55 | 56 | def enable_vq(self, module, enable): 57 | for m in module.modules(): 58 | if is_vq(m): 59 | m.enabled = enable 60 | return 61 | --------------------------------------------------------------------------------
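`VectorQuant.straight_through_approximation` (with `sync_nu=0`) and `VectorQuant.compute_loss` in `vqtorch/nn/vq.py` can be illustrated in isolation. The following is a minimal sketch under a few assumptions. It uses plain PyTorch. `F.mse_loss` stands in for `loss_fn`, because the real loss and distance functions come from `vqtorch.dists`, which is not shown here. The helper names `straight_through` and `vq_loss` are hypothetical and not part of vqtorch.

```python
import torch
import torch.nn.functional as F


def straight_through(z, z_q):
    # hypothetical helper: forward value equals z_q, gradient is copied to z unchanged
    return z + (z_q - z).detach()


def vq_loss(z_e, z_q, beta=0.95):
    # hypothetical helper mirroring the weighting in VectorQuant.compute_loss,
    # with F.mse_loss as an assumed stand-in for loss_fn
    commitment = F.mse_loss(z_e, z_q.detach())  # only moves the encoder output
    codebook = F.mse_loss(z_e.detach(), z_q)    # only moves the selected code
    return (1.0 - beta) * commitment + beta * codebook


z = torch.tensor([[0.2, 0.9]], requires_grad=True)
codebook = torch.tensor([[0.0, 1.0], [1.0, 0.0]], requires_grad=True)

# nearest-neighbour assignment; no gradient flows through the argmin
q = torch.cdist(z.detach(), codebook.detach()).argmin(dim=-1)
z_q = F.embedding(q, codebook)

out = straight_through(z, z_q)  # same values as z_q
out.sum().backward()            # gradient lands on z, not on the codebook

loss = vq_loss(z, z_q, beta=0.95)
loss.backward()                 # codebook row q receives beta * (z_q - z)
```

The detach trick keeps the forward pass exactly quantized while treating quantization as the identity in the backward pass. The codebook is therefore trained only through the loss term.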
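The countdown logic of the `ReplaceLRU` hook in `vqtorch/nn/utils/replace.py` can be sketched without a full VQ layer. The sketch below assumes the default `'input_random'` policy and a plain tensor codebook. `lru_step` is a hypothetical helper written for illustration; it is not a vqtorch function.

```python
import torch


def lru_step(codebook, counts, codes_used, z_e, timeout, rho=0.0):
    """Hypothetical single LRU update (sketch of the 'input_random' policy).

    codebook:   (N, F) tensor, modified in place
    counts:     (N,) float tensor of remaining lifetimes, modified in place
    codes_used: LongTensor of code indices selected in this batch
    z_e:        (M, F) encoder outputs used as replacement candidates
    """
    counts -= 1                                              # every code ages by one step
    counts.index_fill_(0, torch.unique(codes_used), timeout)  # used codes reset their timer

    dead = torch.argwhere(counts == 0).squeeze(-1)
    if dead.numel() > 0:
        cand = z_e[torch.randperm(z_e.size(0))]              # shuffle candidate inputs
        mult = dead.numel() // cand.size(0) + 1               # tile if there are too few
        cand = torch.cat(mult * [cand])[: dead.numel()]
        if rho > 0:
            # noise proportional to each replacement vector's norm
            cand = cand + rho * cand.norm(dim=-1, keepdim=True) * torch.randn_like(cand)
        codebook[dead] = cand
        counts[dead] += timeout                               # restart the timer
    return dead


codebook = torch.zeros(4, 2)
counts = torch.full((4,), 2.0)
z_e = torch.ones(3, 2)

lru_step(codebook, counts, torch.tensor([0, 1]), z_e, timeout=2)         # codes 2, 3 age
dead = lru_step(codebook, counts, torch.tensor([0, 1]), z_e, timeout=2)  # codes 2, 3 expire
```

Codes 2 and 3 are never selected, so after `timeout` steps they are overwritten with (shuffled) encoder outputs. This is the mechanism that keeps dead codes from wasting codebook capacity.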
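The loop in `ResidualVectorQuant.forward` quantizes, at each step, whatever the running estimate `z_q` has not yet explained. A minimal sketch is shown below. It assumes a single shared codebook (the `share=True` case) and nearest-neighbour search via `torch.cdist`. It omits normalization, the straight-through estimator and the loss. The function name `residual_quantize` is hypothetical.

```python
import torch


def residual_quantize(z, codebook, steps):
    """Hypothetical greedy residual VQ with a shared codebook.

    z:        (M, F) inputs
    codebook: (N, F) code vectors
    Returns the running reconstruction z_q (M, F) and the codes (M, steps).
    """
    z_q = torch.zeros_like(z)
    codes = []
    for _ in range(steps):
        residual = z - z_q                                  # part of z not yet explained
        q = torch.cdist(residual, codebook).argmin(dim=-1)  # nearest code per row
        z_q = z_q + codebook[q]                             # refine the estimate
        codes.append(q)
    return z_q, torch.stack(codes, dim=-1)


codebook = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
z = torch.tensor([[1.0, 0.6]])
z_q, codes = residual_quantize(z, codebook, steps=2)
# step 1 picks [1, 0] (residual error 0.6); step 2 quantizes [0, 0.6] to [0, 1] (error 0.4)
```

Each step adds one code per vector, so `steps` codes from an `N`-entry codebook can represent far more points than a single lookup. This is why `q` in the source has one entry per residual step.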
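When `use_running_statistics=True`, `AffineTransform.get_affine_params` in `vqtorch/nn/affine.py` uses scale = sqrt(var(z_e) / (var(c) + 1e-8)) and bias = mean(z_e) - scale * mean(c). This maps the codebook so that its per-feature mean and variance match those of `z_e`. The quick numeric check below makes one simplifying assumption: it uses batch statistics directly, in place of the momentum-averaged running buffers, to keep the sketch short.

```python
import torch

torch.manual_seed(0)
z_e = 3.0 * torch.randn(512, 8) + 5.0  # encoder outputs: mean near 5, std near 3
codebook = 0.1 * torch.randn(64, 8)    # small-scale codebook

# batch statistics stand in for the running buffers (assumption)
ze_mean, ze_var = z_e.mean(0), z_e.var(0, unbiased=False)
c_mean, c_var = codebook.mean(0), codebook.var(0, unbiased=False)

# same formulas as get_affine_params
scale = (ze_var / (c_var + 1e-8)).sqrt()
bias = ze_mean - scale * c_mean
aligned = scale * codebook + bias

# per feature, `aligned` now has the mean and (up to the 1e-8 term) the variance of z_e
```

Aligning first and second moments this way moves codes into the region the encoder actually occupies, so fewer of them start out unused.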