├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── vqtorch-logo.pdf
│   └── vqtorch-logo.png
├── examples
│   ├── autoencoder.py
│   ├── classification.py
│   ├── experimental_group_affine.py
│   ├── experimental_inplace_update.py
│   └── test.py
├── setup.py
└── vqtorch
    ├── __init__.py
    ├── dists.py
    ├── math_fns.py
    ├── nn
    │   ├── __init__.py
    │   ├── affine.py
    │   ├── gvq.py
    │   ├── pool.py
    │   ├── rvq.py
    │   ├── utils
    │   │   ├── __init__.py
    │   │   ├── init.py
    │   │   └── replace.py
    │   ├── vq.py
    │   └── vq_base.py
    ├── norms.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | **__pycache__/**
3 | **.egg**
4 | **.DS_Store**
5 | **.swp**
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 minyoung huh (jacob)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | ---
7 |
8 | VQTorch is a PyTorch library for vector quantization.
9 |
10 | The library was developed for, and used in, the following work:
11 | - `[1] Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks, Huh et al. ICML2023`
12 |
13 | ## Installation
14 | Development was done on Ubuntu with Python 3.9/3.10 using NVIDIA GPUs. Some requirements may need to be adjusted in order to run.
15 | Some features, such as half-precision cdist and cuda-based kmeans, are only supported on CUDA devices.
16 |
17 | First, install [cupy](https://github.com/cupy/cupy/), choosing the package that matches the `CUDA Version` reported by `nvidia-smi`. `cupy` now seems to support ROCm drivers, but this has not been tested.
18 | ```bash
19 | # recent 12.x cuda versions
20 | pip install cupy-cuda12x
21 |
22 | # 11.x versions (for even older see the repo above)
23 | pip install cupy-cuda11x
24 | ```
25 |
26 | Next, install `vqtorch`
27 | ```bash
28 | git clone https://github.com/minyoungg/vqtorch
29 | cd vqtorch
30 | pip install -e .
31 | ```
32 |
33 | ## Example usage
34 | For examples of using `VectorQuant` for classification and autoencoders, see the [examples](./examples/) directory.
35 |
36 | ```python
37 | import torch
38 | from vqtorch.nn import VectorQuant
39 |
40 | print('Testing VectorQuant')
41 | # create VQ layer
42 | vq_layer = VectorQuant(
43 |     feature_size=32,     # feature dimension corresponding to the vectors
44 |     num_codes=1024,      # number of codebook vectors
45 |     beta=0.98,           # (default: 0.9) commitment trade-off
46 |     kmeans_init=True,    # (default: False) whether to use kmeans++ init
47 |     norm=None,           # (default: None) normalization for the input vectors
48 |     cb_norm=None,        # (default: None) normalization for codebook vectors
49 |     affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
50 |     sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
51 |     replace_freq=20,     # (default: None) frequency to replace dead codes
52 |     dim=-1,              # (default: -1) dimension to be quantized
53 | ).cuda()
54 |
55 | # when `kmeans_init=True`, it is recommended to warm up the codebook before training
56 | with torch.no_grad():
57 |     z_e = torch.randn(128, 8, 8, 32).cuda()
58 |     vq_layer(z_e)
59 |
60 | # standard forward pass
61 | z_e = torch.randn(128, 8, 8, 32).cuda()
62 | z_q, vq_dict = vq_layer(z_e)
63 |
64 | print(vq_dict.keys())
65 | # >>> dict_keys(['z', 'z_q', 'd', 'q', 'loss', 'perplexity'])
66 | ```
67 |
68 | ## Supported features
69 | - `vqtorch.nn.GroupVectorQuant` - Vectors are quantized by first partitioning into `n` subvectors.
70 | - `vqtorch.nn.ResidualVectorQuant` - Vectors are first quantized and the residuals are repeatedly quantized.
71 | - `vqtorch.nn.MaxVecPool2d` - Pools along the vector dimension by selecting the vector with the maximum norm.
72 | - `vqtorch.nn.SoftMaxVecPool2d` - Pools along the vector dimension by the weighted average computed by softmax over the norm.
73 | - `vqtorch.no_vq` - Disables all vector quantization layers that inherit `vqtorch.nn._VQBaseLayer`
74 | ```python
75 | model = VQN(...)
76 | with vqtorch.no_vq():
77 |     out = model(x)
78 | ```
79 |
80 | ## Experimental features
81 | - Group affine parameterization: divides the codebook into groups, each reparameterized with its own affine parameters. Enable it with:
82 | ```python
83 | vq_layer = VectorQuant(..., affine_groups=8)
84 | ```
85 | - In-place alternating optimization: the codebook is updated in place during the forward pass.
86 | ```python
87 | inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=50.0, momentum=0.9)
88 | vq_layer = VectorQuant(..., inplace_optimizer=inplace_optimizer)
89 | ```
90 |
91 | ## Planned features
92 | We aim to incorporate commonly used VQ methods, including probabilistic VQ variants.
93 |
94 |
95 | ## Citations
96 | If features such as `affine parameterization`, `synchronized commitment loss` or `alternating optimization` were useful, please consider citing:
97 |
98 | ```bibtex
99 | @inproceedings{huh2023improvedvqste,
100 | title={Straightening Out the Straight-Through Estimator: Overcoming Optimization Challenges in Vector Quantized Networks},
101 | author={Huh, Minyoung and Cheung, Brian and Agrawal, Pulkit and Isola, Phillip},
102 | booktitle={International Conference on Machine Learning},
103 | year={2023},
104 | organization={PMLR}
105 | }
106 | ```
107 |
108 | If you found the library useful, please consider citing:
109 | ```bibtex
110 | @misc{huh2023vqtorch,
111 | author = {Huh, Minyoung},
112 | title = {vqtorch: {P}y{T}orch Package for Vector Quantization},
113 | year = {2022},
114 | howpublished = {\url{https://github.com/minyoungg/vqtorch}},
115 | }
116 | ```
117 |
--------------------------------------------------------------------------------
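The README describes `GroupVectorQuant` as partitioning each vector into `n` subvectors. It describes `ResidualVectorQuant` as quantizing a vector and then repeatedly quantizing the residual. The pure-Python sketch below illustrates both ideas under simplifying assumptions: plain lists, squared Euclidean distance, and one vector at a time. The helper names (`nearest_code`, `group_quantize`, `residual_quantize`) are hypothetical, and this is not vqtorch's implementation.

```python
# Hypothetical sketch of group and residual vector quantization (not vqtorch code).


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def nearest_code(z, codebook):
    """Return (index, code) of the codebook entry closest to z."""
    idx = min(range(len(codebook)), key=lambda i: sq_dist(z, codebook[i]))
    return idx, codebook[idx]


def group_quantize(z, codebooks):
    """Split z into len(codebooks) equal sub-vectors and quantize each with its
    own codebook (pass the same codebook repeatedly to share it)."""
    groups = len(codebooks)
    size = len(z) // groups
    if size * groups != len(z):
        raise ValueError("feature size must be divisible by the number of groups")
    indices, z_q = [], []
    for g, codebook in enumerate(codebooks):
        idx, code = nearest_code(z[g * size:(g + 1) * size], codebook)
        indices.append(idx)
        z_q.extend(code)
    return indices, z_q


def residual_quantize(z, codebooks):
    """Quantize z, then repeatedly quantize what is left over; the
    reconstruction is the sum of the selected codes."""
    residual = list(z)
    z_q = [0.0] * len(z)
    indices = []
    for codebook in codebooks:
        idx, code = nearest_code(residual, codebook)
        indices.append(idx)
        z_q = [a + b for a, b in zip(z_q, code)]
        residual = [r - c for r, c in zip(residual, code)]
    return indices, z_q


cb = [[0.0, 0.0], [1.0, 1.0], [1.0, 0.0]]
print(group_quantize([0.9, 1.1, 0.1, -0.1], [cb, cb]))  # ([1, 0], [1.0, 1.0, 0.0, 0.0])

stage1 = [[0.0, 0.0], [1.0, 0.0], [1.0, 1.0]]
stage2 = [[0.0, 0.0], [0.25, 0.5], [0.5, 0.5]]
print(residual_quantize([1.25, 0.4], [stage1, stage2]))  # ([1, 1], [1.25, 0.5])
```

Passing the same codebook for every group or stage corresponds to sharing it, which appears to be what the `share` flag in `examples/test.py` controls.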
/assets/vqtorch-logo.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungg/vqtorch/02e60a19bd742c17b0bf3e1925f23796d54cbeac/assets/vqtorch-logo.pdf
--------------------------------------------------------------------------------
/assets/vqtorch-logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/minyoungg/vqtorch/02e60a19bd742c17b0bf3e1925f23796d54cbeac/assets/vqtorch-logo.png
--------------------------------------------------------------------------------
/examples/autoencoder.py:
--------------------------------------------------------------------------------
1 | # mnist VQ experiment with various settings.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import datasets
6 |
7 | from vqtorch.nn import VectorQuant
8 | from tqdm.auto import trange
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 |
12 |
13 | lr = 3e-4
14 | train_iter = 1000
15 | num_codes = 256
16 | seed = 1234
17 |
18 | class SimpleVQAutoEncoder(nn.Module):
19 |     def __init__(self, **vq_kwargs):
20 |         super().__init__()
21 |         self.layers = nn.ModuleList([
22 |             nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
23 |             nn.MaxPool2d(kernel_size=2, stride=2),
24 |             nn.GELU(),
25 |             nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
26 |             nn.MaxPool2d(kernel_size=2, stride=2),
27 |             VectorQuant(32, **vq_kwargs),
28 |             nn.Upsample(scale_factor=2, mode='nearest'),
29 |             nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
30 |             nn.GELU(),
31 |             nn.Upsample(scale_factor=2, mode='nearest'),
32 |             nn.Conv2d(16, 1, kernel_size=3, stride=1, padding=1),
33 |         ])
34 |         return
35 |
36 |     def forward(self, x):
37 |         for layer in self.layers:
38 |             if isinstance(layer, VectorQuant):
39 |                 x, vq_dict = layer(x)
40 |             else:
41 |                 x = layer(x)
42 |         return x.clamp(-1, 1), vq_dict
43 |
44 |
45 | def train(model, train_loader, train_iterations=1000, alpha=10):
46 |     def iterate_dataset(data_loader):
47 |         data_iter = iter(data_loader)
48 |         while True:
49 |             try:
50 |                 x, y = next(data_iter)
51 |             except StopIteration:
52 |                 data_iter = iter(data_loader)
53 |                 x, y = next(data_iter)
54 |             yield x.cuda(), y.cuda()
55 |     data_stream = iterate_dataset(train_loader)  # create the batch stream once
56 |     for _ in (pbar := trange(train_iterations)):
57 |         opt.zero_grad()
58 |         x, _ = next(data_stream)
59 |         out, vq_out = model(x)
60 |         rec_loss = (out - x).abs().mean()
61 |         cmt_loss = vq_out['loss']
62 |         (rec_loss + alpha * cmt_loss).backward()
63 |
64 |         opt.step()
65 |         pbar.set_description(f'rec loss: {rec_loss.item():.3f} | ' + \
66 |                              f'cmt loss: {cmt_loss.item():.3f} | ' + \
67 |                              f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
68 |     return
69 |
70 |
71 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
72 | train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)
73 |
74 | print('baseline')
75 | torch.random.manual_seed(seed)
76 | model = SimpleVQAutoEncoder(num_codes=num_codes).cuda()
77 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
78 | train(model, train_dataset, train_iterations=train_iter)
79 |
80 |
81 | print('+ kmeans init')
82 | torch.random.manual_seed(seed)
83 | model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True).cuda()
84 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
85 | train(model, train_dataset, train_iterations=train_iter)
86 |
87 |
88 | print('+ synchronized update')
89 | torch.random.manual_seed(seed)
90 | model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True, sync_nu=2.0).cuda()
91 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
92 | train(model, train_dataset, train_iterations=train_iter)
93 |
94 |
95 | print('+ affine parameterization')
96 | torch.random.manual_seed(seed)
97 | model = SimpleVQAutoEncoder(num_codes=num_codes, kmeans_init=True, sync_nu=2.0, affine_lr=2.0).cuda()
98 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
99 | train(model, train_dataset, train_iterations=train_iter)
--------------------------------------------------------------------------------
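The examples report `active %` as the number of distinct code indices in `vq_out["q"]` divided by `num_codes`. The stdlib sketch below computes that metric and also lists the codes left unused ("dead") in the batch. How vqtorch itself tracks and replaces dead codes (`replace_freq`) is not shown in this section, so `code_usage` is a hypothetical helper.

```python
from collections import Counter


def code_usage(code_indices, num_codes):
    """Return (active %, unused code indices) for one batch of assignments.

    `active %` follows the expression printed by the examples:
    number of distinct codes used / num_codes * 100.
    """
    counts = Counter(code_indices)
    active_pct = len(counts) / num_codes * 100
    # Counter returns 0 for missing keys without inserting them
    dead = [i for i in range(num_codes) if counts[i] == 0]
    return active_pct, dead


active, dead = code_usage([0, 2, 2, 5], num_codes=8)
print(active, dead)  # 37.5 [1, 3, 4, 6, 7]
```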
/examples/classification.py:
--------------------------------------------------------------------------------
1 | # mnist VQ experiment with various settings.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import datasets
6 |
7 | from vqtorch.nn import VectorQuant
8 | from tqdm.auto import trange
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 |
12 |
13 | lr = 3e-4
14 | train_iter = 1000
15 | num_codes = 256
16 | seed = 1234
17 |
18 | class SimpleVQClassifier(nn.Module):
19 |     def __init__(self, **vq_kwargs):
20 |         super().__init__()
21 |         self.layers = nn.ModuleList([
22 |             nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
23 |             nn.MaxPool2d(kernel_size=2, stride=2),
24 |             nn.GELU(),
25 |             nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
26 |             nn.MaxPool2d(kernel_size=2, stride=2),
27 |             VectorQuant(32, **vq_kwargs),
28 |             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
29 |             nn.AdaptiveMaxPool2d((1, 1)),
30 |             nn.Flatten(),
31 |             nn.Linear(64, 10),
32 |         ])
33 |         return
34 |
35 |     def forward(self, x):
36 |         for layer in self.layers:
37 |             if isinstance(layer, VectorQuant):
38 |                 x, vq_dict = layer(x)
39 |             else:
40 |                 x = layer(x)
41 |         return x, vq_dict
42 |
43 |
44 | def train(model, train_loader, train_iterations=1000, alpha=10):
45 |     def iterate_dataset(data_loader):
46 |         data_iter = iter(data_loader)
47 |         while True:
48 |             try:
49 |                 x, y = next(data_iter)
50 |             except StopIteration:
51 |                 data_iter = iter(data_loader)
52 |                 x, y = next(data_iter)
53 |             yield x.cuda(), y.cuda()
54 |
55 |     criterion = nn.CrossEntropyLoss()
56 |     data_stream = iterate_dataset(train_loader)  # create the batch stream once
57 |     for _ in (pbar := trange(train_iterations)):
58 |         opt.zero_grad()
59 |         x, y = next(data_stream)
60 |         out, vq_out = model(x)
61 |         sce_loss = criterion(out, y)
62 |         cmt_loss = vq_out['loss']
63 |         acc = (out.argmax(dim=1) == y).float().mean()
64 |         (sce_loss + alpha * cmt_loss).backward()
65 |
66 |         opt.step()
67 |         pbar.set_description(f'sce loss: {sce_loss.item():.3f} | ' + \
68 |                              f'cmt loss: {cmt_loss.item():.3f} | ' + \
69 |                              f'acc: {acc.item() * 100:.1f} | ' + \
70 |                              f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
71 |     return
72 |
73 |
74 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
75 | train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)
76 |
77 |
78 | print('baseline')
79 | torch.random.manual_seed(seed)
80 | model = SimpleVQClassifier(num_codes=num_codes).cuda()
81 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
82 | train(model, train_dataset, train_iterations=train_iter)
83 |
84 |
85 | print('+ kmeans init')
86 | torch.random.manual_seed(seed)
87 | model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True).cuda()
88 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
89 | train(model, train_dataset, train_iterations=train_iter)
90 |
91 |
92 | print('+ synchronized update')
93 | torch.random.manual_seed(seed)
94 | model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, sync_nu=1.0).cuda()
95 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
96 | train(model, train_dataset, train_iterations=train_iter)
97 |
98 |
99 | print('+ affine parameterization')
100 | torch.random.manual_seed(seed)
101 | model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, sync_nu=1.0, affine_lr=10).cuda()
102 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
103 | train(model, train_dataset, train_iterations=train_iter)
--------------------------------------------------------------------------------
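The `train` functions in the examples wrap the loader in an infinite generator. The stdlib sketch below shows why that generator should be built once. Calling `next(iterate_dataset(loader))` inside the loop restarts the pass on every step. This simplified `iterate_dataset` drops the `.cuda()` transfer and illustrates the pattern only; it is not the examples' exact code.

```python
import itertools


def iterate_dataset(data_loader):
    """Yield batches forever, restarting the iterable whenever it is exhausted.

    Note: loops forever without yielding if data_loader is empty.
    """
    while True:
        for batch in data_loader:
            yield batch


loader = [1, 2, 3]  # stand-in for a DataLoader; any re-iterable works
stream = iterate_dataset(loader)  # build the generator once, outside the loop
first_seven = list(itertools.islice(stream, 7))
print(first_seven)  # [1, 2, 3, 1, 2, 3, 1]

# Building a new generator every step restarts the pass each time,
# so only the first item of each fresh pass is ever seen:
restarted = [next(iterate_dataset(loader)) for _ in range(3)]
print(restarted)  # [1, 1, 1]
```

With a shuffled `DataLoader`, restarting still returns a random batch on each step. However, training never progresses through an epoch, and the DataLoader iterator is re-created every time.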
/examples/experimental_group_affine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from vqtorch.nn import VectorQuant, GroupVectorQuant, ResidualVectorQuant
3 |
4 |
5 | print('Testing VectorQuant')
6 | # create VQ layer
7 | vq_layer = VectorQuant(
8 |     feature_size=32,     # feature dimension corresponding to the vectors
9 |     num_codes=1024,      # number of codebook vectors
10 |     beta=0.98,           # (default: 0.9) commitment trade-off
11 |     kmeans_init=True,    # (default: False) whether to use kmeans++ init
12 |     norm=None,           # (default: None) normalization for input vectors
13 |     cb_norm=None,        # (default: None) normalization for codebook vectors
14 |     affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
15 |     affine_groups=8,     # *** NEW *** (default: 1) number of affine parameter groups
16 |     sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
17 |     replace_freq=20,     # (default: None) frequency to replace dead codes
18 |     dim=-1,              # (default: -1) dimension to be quantized
19 | ).cuda()
20 |
21 | # when using `kmeans_init`, we can warm up the codebook
22 | with torch.no_grad():
23 |     z_e = torch.randn(128, 8, 8, 32).cuda()
24 |     vq_layer(z_e)
25 |
26 | # standard forward pass
27 | z_e = torch.randn(128, 8, 8, 32).cuda()
28 | z_q, vq_dict = vq_layer(z_e) # equivalent to above
29 | assert z_e.shape == z_q.shape
30 | err = ((z_e - z_q) ** 2).mean().item()
31 | print(f'>>> quantization error: {err:.3f}')
--------------------------------------------------------------------------------
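`experimental_group_affine.py` enables `affine_groups=8`. In the learned variant of `AffineTransform` in `vqtorch/nn/affine.py`, the codebook is split into contiguous groups of codes. Each group gets `scale = 1 + lr_scale * s` and `bias = lr_scale * b`. The pure-Python sketch below applies that mapping. `group_affine` is a hypothetical name, and representing codes as lists is an assumption.

```python
def group_affine(codebook, scale_params, bias_params, lr_scale=1.0):
    """Apply a per-group affine map to a list of code vectors.

    codebook:     num_codes vectors of length feature_size
    scale_params: num_groups raw scale vectors (zero at initialization)
    bias_params:  num_groups raw bias vectors (zero at initialization)
    Codes are split into num_groups contiguous blocks, mirroring
    codebook.view(num_groups, -1, feature_size) in AffineTransform.forward.
    """
    num_groups = len(scale_params)
    per_group = len(codebook) // num_groups
    if per_group * num_groups != len(codebook):
        raise ValueError("num_codes must be divisible by num_groups")
    out = []
    for i, code in enumerate(codebook):
        g = i // per_group  # group this code belongs to
        scale = [1.0 + lr_scale * s for s in scale_params[g]]
        bias = [lr_scale * b for b in bias_params[g]]
        out.append([sc * c + bi for sc, c, bi in zip(scale, code, bias)])
    return out


codebook = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
zeros = [[0.0, 0.0], [0.0, 0.0]]
identity = group_affine(codebook, zeros, zeros)  # unchanged at initialization
shifted = group_affine(codebook, [[0.0, 0.0], [1.0, 0.0]],
                       [[0.0, 0.0], [0.0, 0.5]])  # only group 1 changes
print(shifted)  # [[1.0, 2.0], [3.0, 4.0], [10.0, 6.5], [14.0, 8.5]]
```

Because the raw parameters start at zero, the transform is the identity at initialization.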
/examples/experimental_inplace_update.py:
--------------------------------------------------------------------------------
1 | # mnist VQ experiment with various settings.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import datasets
6 |
7 | from vqtorch.nn import VectorQuant
8 | from tqdm.auto import trange
9 | from torch.utils.data import DataLoader
10 | from torchvision import transforms
11 |
12 |
13 | lr = 3e-4
14 | train_iter = 300
15 | num_codes = 256
16 | seed = 1234
17 |
18 | class SimpleVQClassifier(nn.Module):
19 |     def __init__(self, **vq_kwargs):
20 |         super().__init__()
21 |         self.layers = nn.ModuleList([
22 |             nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
23 |             nn.MaxPool2d(kernel_size=2, stride=2),
24 |             nn.GELU(),
25 |             nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
26 |             nn.MaxPool2d(kernel_size=2, stride=2),
27 |             VectorQuant(32, **vq_kwargs),
28 |             nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
29 |             nn.AdaptiveMaxPool2d((1, 1)),
30 |             nn.Flatten(),
31 |             nn.Linear(64, 10),
32 |         ])
33 |         return
34 |
35 |     def forward(self, x):
36 |         for layer in self.layers:
37 |             if isinstance(layer, VectorQuant):
38 |                 x, vq_dict = layer(x)
39 |             else:
40 |                 x = layer(x)
41 |         return x, vq_dict
42 |
43 |
44 | def train(model, train_loader, train_iterations=1000, alpha=10, ignore_commitment_loss=False):
45 |     def iterate_dataset(data_loader):
46 |         data_iter = iter(data_loader)
47 |         while True:
48 |             try:
49 |                 x, y = next(data_iter)
50 |             except StopIteration:
51 |                 data_iter = iter(data_loader)
52 |                 x, y = next(data_iter)
53 |             yield x.cuda(), y.cuda()
54 |
55 |     criterion = nn.CrossEntropyLoss()
56 |     data_stream = iterate_dataset(train_loader)  # create the batch stream once
57 |     for _ in (pbar := trange(train_iterations)):
58 |         opt.zero_grad()
59 |         x, y = next(data_stream)
60 |         out, vq_out = model(x)
61 |         sce_loss = criterion(out, y)
62 |         cmt_loss = vq_out['loss']
63 |
64 |         if ignore_commitment_loss:
65 |             sce_loss.backward()
66 |         else:
67 |             (sce_loss + alpha * cmt_loss).backward()
68 |
69 |         acc = (out.argmax(dim=1) == y).float().mean()
70 |
71 |         opt.step()
72 |         pbar.set_description(f'sce loss: {sce_loss.item():.3f} | ' + \
73 |                              f'cmt loss: {cmt_loss.item():.3f} | ' + \
74 |                              f'acc: {acc.item() * 100:.1f} | ' + \
75 |                              f'active %: {vq_out["q"].unique().numel() / num_codes * 100:.3f}')
76 |     return
77 |
78 |
79 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
80 | train_dataset = DataLoader(datasets.MNIST(root='~/data/mnist', train=True, download=True, transform=transform), batch_size=256, shuffle=True)
81 |
82 |
83 | print('baseline + kmeans init')
84 | torch.random.manual_seed(seed)
85 | model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True).cuda()
86 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
87 | train(model, train_dataset, train_iterations=train_iter, alpha=50)
88 |
89 |
90 | print('+ inplace alt update')
91 | inplace_optimizer = lambda *args, **kwargs: torch.optim.SGD(*args, **kwargs, lr=50.0, momentum=0.9)
92 | torch.random.manual_seed(seed)
93 | model = SimpleVQClassifier(num_codes=num_codes, kmeans_init=True, beta=1.0, inplace_optimizer=inplace_optimizer).cuda()
94 | opt = torch.optim.AdamW(model.parameters(), lr=lr)
95 | train(model, train_dataset, train_iterations=train_iter, ignore_commitment_loss=True)
96 |
--------------------------------------------------------------------------------
/examples/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from vqtorch.nn import VectorQuant, GroupVectorQuant, ResidualVectorQuant
3 |
4 |
5 | print('Testing VectorQuant')
6 | # create VQ layer
7 | vq_layer = VectorQuant(
8 |     feature_size=32,     # feature dimension corresponding to the vectors
9 |     num_codes=1024,      # number of codebook vectors
10 |     beta=0.98,           # (default: 0.9) commitment trade-off
11 |     kmeans_init=True,    # (default: False) whether to use kmeans++ init
12 |     norm=None,           # (default: None) normalization for input vectors
13 |     cb_norm=None,        # (default: None) normalization for codebook vectors
14 |     affine_lr=10.0,      # (default: 0.0) lr scale for affine parameters
15 |     sync_nu=0.2,         # (default: 0.0) codebook synchronization contribution
16 |     replace_freq=20,     # (default: None) frequency to replace dead codes
17 |     dim=-1,              # (default: -1) dimension to be quantized
18 | ).cuda()
19 |
20 | # when using `kmeans_init`, we can warm up the codebook
21 | with torch.no_grad():
22 |     z_e = torch.randn(128, 8, 8, 32).cuda()
23 |     vq_layer(z_e)
24 |
25 | # standard forward pass
26 | z_e = torch.randn(128, 8, 8, 32).cuda()
27 | z_q, vq_dict = vq_layer(z_e) # equivalent to above
28 | assert z_e.shape == z_q.shape
29 | err = ((z_e - z_q) ** 2).mean().item()
30 | print(f'>>> quantization error: {err:.3f}')
31 |
32 |
33 |
34 | print('Testing GroupVectorQuant')
35 | # create VQ layer
36 | vq_layer = GroupVectorQuant(
37 |     feature_size=32,
38 |     num_codes=1024,
39 |     beta=0.98,
40 |     kmeans_init=True,
41 |     norm=None,
42 |     cb_norm=None,
43 |     affine_lr=10.0,
44 |     sync_nu=0.2,
45 |     replace_freq=20,
46 |     dim=-1,
47 |     groups=4,            # (default: 1) number of groups to divide the feature dimension
48 |     share=False,         # (default: True) when True, same codebook is used for each group
49 | ).cuda()
50 |
51 | # when using `kmeans_init`, we can warm up the codebook
52 | with torch.no_grad():
53 |     z_e = torch.randn(128, 8, 8, 32).cuda()
54 |     vq_layer(z_e)
55 |
56 | # standard forward pass
57 | z_e = torch.randn(128, 8, 8, 32).cuda()
58 | z_q, vq_dict = vq_layer(z_e) # equivalent to above
59 | assert z_e.shape == z_q.shape
60 | err = ((z_e - z_q) ** 2).mean().item()
61 | print(f'>>> quantization error: {err:.3f}')
62 |
63 |
64 |
65 |
66 | print('Testing ResidualVectorQuant')
67 | # create VQ layer
68 | vq_layer = ResidualVectorQuant(
69 |     feature_size=32,
70 |     num_codes=1024,
71 |     beta=0.98,
72 |     kmeans_init=True,
73 |     norm=None,
74 |     cb_norm=None,
75 |     affine_lr=10.0,
76 |     sync_nu=0.2,
77 |     replace_freq=20,
78 |     dim=-1,
79 |     groups=4,            # (default: 1) number of residual quantization steps
80 |     share=True,          # (default: True) when True, same codebook is used for each step
81 | ).cuda()
82 |
83 | # when using `kmeans_init`, we can warm up the codebook
84 | with torch.no_grad():
85 |     z_e = torch.randn(128, 8, 8, 32).cuda()
86 |     vq_layer(z_e)
87 |
88 | # standard forward pass
89 | z_e = torch.randn(128, 8, 8, 32).cuda()
90 | z_q, vq_dict = vq_layer(z_e) # equivalent to above
91 | assert z_e.shape == z_q.shape
92 | err = ((z_e - z_q) ** 2).mean().item()
93 | print(f'>>> quantization error: {err:.3f}')
94 |
--------------------------------------------------------------------------------
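Each test in `examples/test.py` sets `kmeans_init=True` and warms up the codebook with a `no_grad` forward pass. The README describes this option as kmeans++ initialization. The sketch below shows standard k-means++ seeding in pure Python. This section does not show whether vqtorch seeds exactly this way or refines the seeds afterwards (the install requirements list `torchpq`). Treat `kmeans_pp_init` as a hypothetical illustration.

```python
import random


def sq_dist(a, b):
    """Squared Euclidean distance between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans_pp_init(points, k, seed=0):
    """k-means++ seeding: the first center is drawn uniformly; each further
    center is drawn with probability proportional to its squared distance to
    the nearest center chosen so far."""
    rng = random.Random(seed)
    centers = [list(rng.choice(points))]
    while len(centers) < k:
        weights = [min(sq_dist(p, c) for c in centers) for p in points]
        if sum(weights) == 0:  # every point coincides with a chosen center
            break
        centers.append(list(rng.choices(points, weights=weights, k=1)[0]))
    return centers


points = [[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]]
print(kmeans_pp_init(points, k=2, seed=0))  # two distinct points from `points`
```

Points already chosen get zero weight, so with distinct inputs the seeds are distinct.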
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 |     name="vqtorch",
5 |     packages=setuptools.find_packages(),
6 |     version="0.1.0",
7 |     author="Minyoung Huh",
8 |     author_email="minhuh@mit.edu",
9 |     description="vector-quantization for pytorch",
10 |     url="https://github.com/minyoungg/vqtorch",
11 |     classifiers=[
12 |         "Programming Language :: Python :: 3",
13 |         "License :: OSI Approved :: MIT License",
14 |         "Operating System :: OS Independent",
15 |     ],
16 |     install_requires=[
17 |         "torch>=1.13.0",
18 |         "string-color==1.2.3",
19 |         "torchpq==0.3.0.1",
20 |     ],
21 |     python_requires='>=3.6',  # developed on 3.9 / 3.10
22 | )
23 |
--------------------------------------------------------------------------------
/vqtorch/__init__.py:
--------------------------------------------------------------------------------
1 | from vqtorch.nn import *
2 | from vqtorch.utils import *
3 | from vqtorch.math_fns import *
4 |
--------------------------------------------------------------------------------
/vqtorch/dists.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import warnings
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | def check_shape(tensor, codebook):
12 |     if len(tensor.shape) != 3:
13 |         raise RuntimeError(f'expected 3d tensor but found {tensor.size()}')
14 |
15 |     if tensor.size(2) != codebook.size(1):
16 |         raise RuntimeError(
17 |             f'expected tensor and codebook to have the same feature ' + \
18 |             f'dimensions but found: {tensor.size()} vs {codebook.size()}'
19 |         )
20 |     return
21 |
22 |
23 | def get_dist_fns(dist):
24 |     if dist in ['euc', 'euclidean']:
25 |         loss_fn = euclidean_distance
26 |         dist_fn = euclidean_cdist_topk
27 |     elif dist in ['cos', 'cosine']:
28 |         loss_fn = cosine_distance
29 |         dist_fn = cosine_cdist_topk
30 |     else:
31 |         raise ValueError(f'unknown distance method: {dist}')
32 |     return loss_fn, dist_fn
33 |
34 |
35 | def cosine_distance(z, z_q):
36 |     """
37 |     Computes the element-wise cosine distance between z and z_q.
38 |
39 |     NOTE: inputs are L2-normalized and passed to `euclidean_distance`, so the result
40 |     is (2 / feature size) times the per-sample mean of (1 - cos), not a true cosine distance.
41 |     """
42 |     z = F.normalize(z, p=2, dim=-1)
43 |     z_q = F.normalize(z_q, p=2, dim=-1)
44 |     return euclidean_distance(z, z_q)
45 |
46 |
47 | def euclidean_distance(z, z_q):
48 |     """
49 |     Computes the per-sample squared euclidean distance between z and z_q.
50 |
51 |     NOTE: uses spatial averaging and no square root is applied. Hence this is
52 |     not a true euclidean distance, but it makes no difference in practice.
53 |     """
54 |     if z.size() != z_q.size():
55 |         raise RuntimeError(
56 |             f'expected z and z_q to have the same shape but got ' + \
57 |             f'{z.size()} vs {z_q.size()}'
58 |         )
59 |
60 |     z, z_q = z.reshape(z.size(0), -1), z_q.reshape(z_q.size(0), -1)
61 |     return ((z_q - z) ** 2).mean(1)  # .sqrt()
62 |
63 |
64 | def euclidean_cdist_topk(tensor, codebook, compute_chunk_size=1024, topk=1,
65 |                          half_precision=False):
66 |     """
67 |     Compute the euclidean distance between tensor and every element in the
68 |     codebook.
69 |
70 |     Args:
71 |         tensor (Tensor): a 3D tensor of shape [batch x HWG x feats].
72 |         codebook (Tensor): a 2D tensor of shape [num_codes x feats].
73 |         compute_chunk_size (int): the chunk size to use when computing cdist.
74 |         topk (int): stores `topk` distance minimizers. If -1, topk is the
75 |             same length as the codebook.
76 |         half_precision (bool): if True, the distances are computed using
77 |             half-precision to save memory.
78 |     Returns:
79 |         dict with keys:
80 |         d (Tensor): distance matrix of shape [batch x HWG x topk]; each
81 |             element is the distance of a vector to one of its topk nearest codes.
82 |         q (Tensor): code matrix of the same dimension as `d`. The index of the
83 |             corresponding topk distances.
84 |
85 |     NOTE: chunking only splits `tensor`, since the codebook size generally
86 |     does not vary much. Future versions should consider choosing the chunk
87 |     size based on both the codebook size and the feature dimension.
88 |     """
89 |     check_shape(tensor, codebook)
90 |
91 |     b, n, c = tensor.shape
92 |     tensor_dtype = tensor.dtype
93 |     tensor = tensor.reshape(-1, tensor.size(-1))
94 |     tensor = tensor.split(compute_chunk_size)
95 |     dq = []
96 |
97 |     if topk == -1:
98 |         topk = codebook.size(0)
99 |
100 |     for tc in tensor:
101 |         cb = codebook
102 |
103 |         if half_precision:
104 |             tc = tc.half()
105 |             cb = cb.half()
106 |
107 |         d = torch.cdist(tc, cb)
108 |         dq.append(torch.topk(d, k=topk, largest=False, dim=-1))
109 |
110 |     d, q = torch.cat([_dq[0] for _dq in dq]), torch.cat([_dq[1] for _dq in dq])
111 |
112 |     return_dict = {'d': d.to(tensor_dtype).reshape(b, n, -1),
113 |                    'q': q.long().reshape(b, n, -1)}
114 |     return return_dict
115 |
116 |
117 | def cosine_cdist_topk(tensor, codebook, compute_chunk_size=1024, topk=1, half_precision=False):
118 |     """ Computes cosine distance instead. see `euclidean_cdist_topk` """
119 |     check_shape(tensor, codebook)
120 |
121 |     tensor = F.normalize(tensor, p=2, dim=-1)
122 |     codebook = F.normalize(codebook, p=2, dim=-1)
123 |
124 |     dq = euclidean_cdist_topk(tensor, codebook, compute_chunk_size, topk, half_precision)
125 |
126 |     dq['d'] = 0.5 * (dq['d'] ** 2)  # for unit vectors: 0.5 * ||a - b||^2 = 1 - cos(a, b)
127 |     return dq
128 |
--------------------------------------------------------------------------------
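`euclidean_cdist_topk` splits the flattened inputs into chunks, computes the distance to every code, and keeps the `topk` nearest. `cosine_cdist_topk` first L2-normalizes both sides and then converts each distance with `0.5 * d ** 2`. For unit vectors, ||a - b||^2 = 2 - 2 cos(a, b), so 0.5 ||a - b||^2 = 1 - cos(a, b). The pure-Python sketch below illustrates both steps. It reuses the function names for readability, but the list-based signatures and tuple return values are simplified and hypothetical; the real functions take tensors and return dicts.

```python
import math


def cdist_topk(tensor, codebook, topk=1, chunk_size=2):
    """Euclidean distances to the `topk` nearest codes for each row of `tensor`.

    Rows are processed `chunk_size` at a time, mirroring the chunked loop in
    euclidean_cdist_topk (in pure Python this only changes the loop shape).
    """
    ds, qs = [], []
    for start in range(0, len(tensor), chunk_size):
        for row in tensor[start:start + chunk_size]:
            dist = [math.dist(row, code) for code in codebook]
            order = sorted(range(len(codebook)), key=dist.__getitem__)[:topk]
            ds.append([dist[j] for j in order])
            qs.append(order)
    return ds, qs


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_cdist_topk(tensor, codebook, topk=1, chunk_size=2):
    """L2-normalize both sides, then map Euclidean distance d to
    0.5 * d**2, which equals 1 - cos(a, b) for unit vectors."""
    d, q = cdist_topk([normalize(v) for v in tensor],
                      [normalize(v) for v in codebook], topk, chunk_size)
    return [[0.5 * x ** 2 for x in row] for row in d], q


d, q = cdist_topk([[1.0, 0.0], [0.0, 2.0], [3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]])
print(q)  # [[0], [1], [1]]
cd, cq = cosine_cdist_topk([[3.0, 4.0]], [[1.0, 0.0], [0.0, 1.0]], topk=2)
print(cq)  # [[1, 0]]; cd is approximately [[0.2, 0.4]]
```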
/vqtorch/math_fns.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | def entropy(x, dim=-1, eps=1e-8, keepdim=False):
4 |     assert x.min() >= 0., \
5 |         f'function takes non-negative values but found x.min(): {x.min()}'
6 |     is_tensor = True
7 |
8 |     if len(x.shape) == 1:
9 |         is_tensor = False
10 |         x = x.unsqueeze(0)
11 |
12 |     x = x.moveaxis(dim, -1)
13 |     x_shape = x.shape
14 |     x = x.reshape(-1, x.size(-1)) + eps  # reshape: moveaxis may return a non-contiguous view
15 |     p = x / x.sum(dim=1, keepdim=True)
16 |     h = - (p * p.log()).sum(dim=1, keepdim=True)
17 |     h = h.view(*x_shape[:-1], 1).moveaxis(-1, dim)  # move the reduced axis back to `dim`
18 |
19 |     if not keepdim:
20 |         h = h.squeeze(dim)
21 |
22 |     if not is_tensor:
23 |         h = h.squeeze(0)
24 |     return h
25 |
--------------------------------------------------------------------------------
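`entropy` in `vqtorch/math_fns.py` adds `eps`, normalizes the non-negative values into a distribution, and returns its Shannon entropy in nats. A common way to summarize codebook usage is perplexity: `exp` of the entropy of the code-usage histogram. The README lists a `perplexity` entry in `vq_dict`, but whether vqtorch computes it exactly this way is an assumption. The one-dimensional pure-Python sketch below uses hypothetical helper names.

```python
import math


def entropy(values, eps=1e-8):
    """Shannon entropy (nats) of non-negative values after adding eps and
    normalizing to a distribution: a 1-D version of math_fns.entropy."""
    if min(values) < 0:
        raise ValueError("entropy expects non-negative values")
    x = [v + eps for v in values]
    total = sum(x)
    return -sum((v / total) * math.log(v / total) for v in x)


def perplexity(code_indices, num_codes):
    """exp(entropy) of the code-usage histogram: about 1 when a single code is
    used and about num_codes when all codes are used equally often."""
    hist = [0] * num_codes
    for i in code_indices:
        hist[i] += 1
    return math.exp(entropy(hist))


print(round(perplexity([0, 1, 2, 3], num_codes=4), 6))  # 4.0
print(round(perplexity([2, 2, 2], num_codes=4), 6))     # 1.0
```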
/vqtorch/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .vq_base import _VQBaseLayer
2 |
3 | from .vq import VectorQuant
4 | from .gvq import GroupVectorQuant
5 | from .rvq import ResidualVectorQuant
6 |
7 | from .affine import AffineTransform
8 | from . import utils
9 |
--------------------------------------------------------------------------------
/vqtorch/nn/affine.py:
--------------------------------------------------------------------------------
1 | import math
2 | import time
3 | import warnings
4 | import numpy as np
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 |
12 | class AffineTransform(nn.Module):
13 |     def __init__(
14 |             self,
15 |             feature_size,
16 |             use_running_statistics=False,
17 |             momentum=0.1,
18 |             lr_scale=1,
19 |             num_groups=1,
20 |             ):
21 |         super().__init__()
22 |
23 |         self.use_running_statistics = use_running_statistics
24 |         self.num_groups = num_groups
25 |
26 |         if use_running_statistics:
27 |             self.momentum = momentum
28 |             self.register_buffer('running_statistics_initialized', torch.zeros(1))
29 |             self.register_buffer('running_ze_mean', torch.zeros(num_groups, feature_size))
30 |             self.register_buffer('running_ze_var', torch.ones(num_groups, feature_size))
31 |
32 |             self.register_buffer('running_c_mean', torch.zeros(num_groups, feature_size))
33 |             self.register_buffer('running_c_var', torch.ones(num_groups, feature_size))
34 |         else:
35 |             self.scale = nn.parameter.Parameter(torch.zeros(num_groups, feature_size))
36 |             self.bias = nn.parameter.Parameter(torch.zeros(num_groups, feature_size))
37 |         self.lr_scale = lr_scale
38 |         return
39 |
40 |     @torch.no_grad()
41 |     def update_running_statistics(self, z_e, c):
42 |         # we often find it helpful to under-estimate the z_e embedding
43 |         # statistics. Empirically, we observe a slight over-estimation of
44 |         # the statistics, which causes the straight-through estimate to grow
45 |         # indefinitely. While this is not an issue for most model
46 |         # architectures, those without normalized bottlenecks can
47 |         # eventually explode. Placing the VQ layer in certain layers of a
48 |         # ViT exhibits this behavior.
49 |
50 |
51 |         if self.training and self.use_running_statistics:
52 |             unbiased = False
53 |
54 |             ze_mean = z_e.mean([0, 1]).unsqueeze(0)
55 |             ze_var = z_e.var([0, 1], unbiased=unbiased).unsqueeze(0)
56 |
57 |             c_mean = c.mean([0]).unsqueeze(0)
58 |             c_var = c.var([0], unbiased=unbiased).unsqueeze(0)
59 |
60 |             if not self.running_statistics_initialized:
61 |                 self.running_ze_mean.data.copy_(ze_mean)
62 |                 self.running_ze_var.data.copy_(ze_var)
63 |                 self.running_c_mean.data.copy_(c_mean)
64 |                 self.running_c_var.data.copy_(c_var)
65 |                 self.running_statistics_initialized.fill_(1)
66 |             else:
67 |                 self.running_ze_mean = (self.momentum * ze_mean) + (1 - self.momentum) * self.running_ze_mean
68 |                 self.running_ze_var = (self.momentum * ze_var) + (1 - self.momentum) * self.running_ze_var
69 |                 self.running_c_mean = (self.momentum * c_mean) + (1 - self.momentum) * self.running_c_mean
70 |                 self.running_c_var = (self.momentum * c_var) + (1 - self.momentum) * self.running_c_var
71 |
72 |             # wd = 0.9998 # 0.995
73 |             # self.running_ze_mean = wd * self.running_ze_mean
74 |             # self.running_ze_var = wd * self.running_ze_var
75 |         return
76 |
77 |
78 | def forward(self, codebook):
79 | scale, bias = self.get_affine_params()
80 | n, c = codebook.shape
81 | codebook = codebook.view(self.num_groups, -1, codebook.shape[-1])
82 | codebook = scale * codebook + bias
83 | return codebook.reshape(n, c)
84 |
85 |
86 | def get_affine_params(self):
87 | if self.use_running_statistics:
88 | scale = (self.running_ze_var / (self.running_c_var + 1e-8)).sqrt()
89 | bias = - scale * self.running_c_mean + self.running_ze_mean
90 | else:
91 | scale = (1. + self.lr_scale * self.scale)
92 | bias = self.lr_scale * self.bias
93 | return scale.unsqueeze(1), bias.unsqueeze(1)
94 |
--------------------------------------------------------------------------------
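With `use_running_statistics=True`, `AffineTransform` uses `scale = sqrt(var(z_e) / var(c))` and `bias = mean(z_e) - scale * mean(c)` per feature, so the transformed codebook takes on the per-feature mean and variance of the encoder outputs. The sketch below checks that property on random tensors. The shapes and names are illustrative assumptions, and it skips the momentum-based running average that the module maintains across batches.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: encoder outputs z_e (B x N x F) and a codebook (K x F).
z_e = 3.0 * torch.randn(8, 16, 4) + 5.0
codebook = torch.randn(32, 4)

# Per-feature statistics (biased variance), as in update_running_statistics.
ze_mean = z_e.mean([0, 1])
ze_var = z_e.var([0, 1], unbiased=False)
c_mean = codebook.mean(0)
c_var = codebook.var(0, unbiased=False)

# Affine parameters, as in get_affine_params with use_running_statistics=True.
scale = (ze_var / (c_var + 1e-8)).sqrt()
bias = -scale * c_mean + ze_mean

# The transformed codebook now matches the first two moments of z_e per feature.
aligned = scale * codebook + bias
```

Because the statistics are matched per feature, the nearest-neighbour search in `quantize` compares `z_e` against codes that already live on the same scale, regardless of how the raw codebook was initialized.
--------------------------------------------------------------------------------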
/vqtorch/nn/gvq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from stringcolor import cs
5 | from vqtorch.norms import with_codebook_normalization
6 | from .vq import VectorQuant
7 |
8 |
9 | class GroupVectorQuant(VectorQuant):
10 | """
11 | Group vector quantization layer.
12 |
13 | Args:
14 | groups (int): Number of groups. Each feature vector is split into `groups`
15 | equal-sized sub-vectors. When groups=1, it is the same as VectorQuant.
16 | share (bool): If True, all groups share one codebook; otherwise each group has its own.
17 | *rest*: see VectorQuant()
18 | """
19 |
20 | def __init__(
21 | self,
22 | feature_size : int,
23 | num_codes : int,
24 | groups : int = 1,
25 | share : bool = True,
26 | **kwargs,
27 | ):
28 |
29 | if feature_size % groups != 0:
30 | e_msg = f'feature_size {feature_size} must be divisible by groups {groups}.'
31 | raise RuntimeError(str(cs(e_msg, 'red')))
32 |
33 | num_codebooks = 1 if share else groups
34 | in_dim = self.group_size = num_codes // num_codebooks
35 | out_dim = feature_size // groups
36 |
37 | super().__init__(feature_size, num_codes, code_vector_size=out_dim, **kwargs)
38 |
39 | self.groups = groups
40 | self.share = share
41 | self.codebook = nn.Embedding(in_dim * num_codebooks, out_dim)
42 | return
43 |
44 |
45 | def get_codebook_by_group(self, group):
46 | cb = self.codebook.weight
47 | offset = 0 if self.share else group * self.group_size
48 | return cb[offset : offset + self.group_size], offset
49 |
50 |
51 | @with_codebook_normalization
52 | def forward(self, z):
53 |
54 | ######
55 | ## (1) formatting data by groups and invariant to dim
56 | ######
57 |
58 | z = self.prepare_inputs(z, self.groups)
59 |
60 | if not self.enabled:
61 | z = self.to_original_format(z)
62 | return z, {}
63 |
64 | ######
65 | ## (2) quantize latent vector
66 | ######
67 |
68 | z_q = torch.zeros_like(z)
69 | d = torch.zeros(z_q.shape[:-1]).to(z_q.device)
70 | q = torch.zeros(z_q.shape[:-1], dtype=torch.long).to(z_q.device)
71 |
72 | for i in range(self.groups):
73 | # select group
74 | _z = z[..., i:i+1, :]
75 |
76 | # quantize
77 | cb, offset = self.get_codebook_by_group(i)
78 | _z_q, _d, _q = self.quantize(cb, _z)
79 |
80 | # assign to tensor
81 | z_q[..., i:i+1, :] = _z_q
82 | d[..., i:i+1] = _d
83 | q[..., i:i+1] = _q + offset
84 |
85 | to_return = {
86 | 'z' : z, # each group input z_e
87 | 'z_q' : z_q, # quantized output z_q
88 | 'd' : d, # distance function for each group
89 | 'q' : q, # codes, offset into the concatenated codebook
90 | 'loss': self.compute_loss(z, z_q).mean(),
91 | 'perplexity': None,
92 | }
93 |
94 | z_q = self.straight_through_approximation(z, z_q)
95 | z_q = self.to_original_format(z_q)
96 | return z_q, to_return
97 |
--------------------------------------------------------------------------------
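The group logic in `GroupVectorQuant.forward` can be sketched without the class: the feature dimension is moved last and split into `groups` sub-vectors, each group is matched against its own slice of the codebook, and the slice offset is added so that `q` indexes the concatenated codebook. This is a minimal sketch assuming `share=False` and Euclidean nearest-neighbour search via `torch.cdist`; the library's own distance function lives in `vqtorch.dists`, which is not shown here.

```python
import torch

torch.manual_seed(0)

B, C, H, W = 2, 8, 3, 3
groups, num_codes = 4, 16
share = False

num_codebooks = 1 if share else groups
group_size = num_codes // num_codebooks  # codes per codebook
code_dim = C // groups                   # length of each sub-vector

codebook = torch.randn(group_size * num_codebooks, code_dim)
z = torch.randn(B, C, H, W)

# canonical group format: B x H x W x G x C/G
zg = z.moveaxis(1, -1).unflatten(-1, (groups, code_dim))

q = torch.empty(zg.shape[:-1], dtype=torch.long)
for i in range(groups):
    offset = 0 if share else i * group_size
    cb = codebook[offset:offset + group_size]          # this group's codebook slice
    sub = zg[..., i, :].reshape(-1, code_dim)          # all sub-vectors of group i
    local = torch.cdist(sub, cb).argmin(dim=1)         # nearest code within the slice
    q[..., i] = local.view(zg.shape[:-2]) + offset     # global (offset) index

z_q = codebook[q]                      # B x H x W x G x C/G
z_q = z_q.flatten(-2).moveaxis(-1, 1)  # back to B x C x H x W
```

Storing offset indices means a single `nn.Embedding` holds every group's codes, which is why LRU replacement and data-dependent initialization can treat the grouped layer like a plain `VectorQuant` codebook.
--------------------------------------------------------------------------------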
/vqtorch/nn/pool.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class _VecPool2d(nn.Module):
7 | def __init__(self, weighting_fn, kernel_size=3, stride=2, padding=2, dilation=1):
8 | super().__init__()
9 |
10 | self.k = kernel_size
11 | self.s = stride
12 | self.p = padding
13 | self.d = dilation
14 |
15 | self.weighting_fn = weighting_fn
16 |
17 | self.unfold = nn.Unfold(kernel_size=self.k, dilation=self.d, padding=self.p, stride=self.s)
18 | return
19 |
20 | def forward(self, x):
21 | b, c, h, w = x.shape
22 | out_h = (h + 2 * self.p - self.d * (self.k - 1) - 1) // self.s + 1
23 | out_w = (w + 2 * self.p - self.d * (self.k - 1) - 1) // self.s + 1
24 |
25 | with torch.no_grad():
26 | n = x.norm(dim=1, p=2, keepdim=True)
27 | n = self.unfold(n) # B x (K**2) x N
28 | n = self.weighting_fn(n, dim=1).unsqueeze(1)
29 |
30 | x = self.unfold(x) # B x C * (K**2) x N
31 | x = x.view(b, c, -1, x.size(-1))
32 |
33 | x = n * x
34 | x = x.sum(2).view(b, c, out_h, out_w)
35 | return x
36 |
37 |
38 | class MaxVecPool2d(_VecPool2d):
39 | def __init__(self, *args, **kwargs):
40 | super().__init__(MaxVecPool2d.max_onehot, *args, **kwargs)
41 |
42 | @staticmethod
43 | def max_onehot(x, dim):
44 | b, s, n = x.shape
45 |
46 | x = x.argmax(dim=1)
47 | x = x.view(-1)
48 | x = F.one_hot(x, s).view(b, n, s).swapaxes(-1, dim)
49 | return x
50 |
51 |
52 | class SoftMaxVecPool2d(_VecPool2d):
53 | def __init__(self, *args, **kwargs):
54 | super().__init__(torch.softmax, *args, **kwargs)
55 |
--------------------------------------------------------------------------------
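`MaxVecPool2d` keeps, for every window, the entire channel vector at the position with the largest L2 norm, instead of taking an independent maximum per channel as `nn.MaxPool2d` does. The sketch below re-implements that idea as a standalone function. `max_norm_vec_pool` and `out_size` are hypothetical names; `out_size` uses the dilation-aware sliding-window formula that `nn.Unfold` follows.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def out_size(size, k, s, p, d):
    # sliding-window output length, the same rule nn.Unfold and nn.Conv2d use
    return (size + 2 * p - d * (k - 1) - 1) // s + 1


def max_norm_vec_pool(x, k=3, s=2, p=1, d=1):
    """Per window, keep the whole channel vector with the largest L2 norm."""
    b, c, h, w = x.shape
    unfold = nn.Unfold(kernel_size=k, dilation=d, padding=p, stride=s)
    norms = unfold(x.norm(dim=1, keepdim=True))               # B x K^2 x N
    onehot = F.one_hot(norms.argmax(dim=1), k * k)            # B x N x K^2
    onehot = onehot.transpose(1, 2).unsqueeze(1).to(x.dtype)  # B x 1 x K^2 x N
    cols = unfold(x).view(b, c, k * k, -1)                    # B x C x K^2 x N
    pooled = (onehot * cols).sum(2)                           # B x C x N
    return pooled.view(b, c, out_size(h, k, s, p, d), out_size(w, k, s, p, d))


torch.manual_seed(0)
x = torch.randn(1, 2, 4, 4)
y = max_norm_vec_pool(x, k=2, s=2, p=0)  # non-overlapping 2x2 windows
print(y.shape)  # torch.Size([1, 2, 2, 2])
```

Keeping whole vectors matters ahead of a VQ layer: a per-channel max can assemble a vector that never occurred at any single position.
--------------------------------------------------------------------------------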
/vqtorch/nn/rvq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from stringcolor import cs
5 | from vqtorch.norms import with_codebook_normalization
6 | from .vq import VectorQuant
7 |
8 |
9 | class ResidualVectorQuant(VectorQuant):
10 | """
11 | Args:
12 | groups (int): Number of residual quantization steps. When groups=1,
13 | the layer is equivalent to VectorQuant.
14 | share (bool): If True, codebook is shared for every quantization.
15 | *rest*: see VectorQuant()
16 |
17 | NOTE: Don't use L2 normalization on the codebook; ResidualVQ is norm-sensitive.
18 | For norm invariance, consider using a cosine-distance variant.
19 | """
20 |
21 | def __init__(
22 | self,
23 | feature_size : int,
24 | num_codes : int,
25 | groups : int = 1,
26 | share : bool = True,
27 | **kwargs,
28 | ):
29 |
30 | if not share and num_codes % groups != 0:
31 | e_msg = f'num_codes {num_codes} must be divisible by residual groups {groups}.'
32 | raise RuntimeError(str(cs(e_msg, 'red')))
33 |
34 | self.groups = groups
35 | self.share = share
36 |
37 | num_codebooks = 1 if share else groups
38 | in_dim = self.group_size = num_codes // num_codebooks
39 | out_dim = feature_size
40 |
41 | super().__init__(feature_size, num_codes, code_vector_size=out_dim, **kwargs)
42 |
43 | self.groups = groups
44 | self.share = share
45 | self.codebook = nn.Embedding(in_dim * num_codebooks, out_dim)
46 |
47 | return
48 |
49 |
50 | def get_codebook_by_group(self, group):
51 | cb = self.codebook.weight
52 | offset = 0 if self.share else group * self.group_size
53 | return cb[offset : offset + self.group_size], offset
54 |
55 |
56 | @with_codebook_normalization
57 | def forward(self, z):
58 |
59 | ######
60 | ## (1) formatting data by groups and invariant to dim
61 | ######
62 |
63 | z = self.prepare_inputs(z, groups=1)
64 |
65 | if not self.enabled:
66 | z = self.to_original_format(z)
67 | return z, {}
68 |
69 | ######
70 | ## (2) quantize latent vector
71 | ######
72 |
73 | z_q = torch.zeros_like(z)
74 | z_res = torch.zeros(*z.shape[:-2], self.groups + 1, z.shape[-1]).to(z.device)
75 |
76 | d = torch.zeros(*z_q.shape[:-2], self.groups).to(z_q.device)
77 | q = torch.zeros(*z_q.shape[:-2], self.groups, dtype=torch.long).to(z_q.device)
78 |
79 | for i in range(self.groups):
80 | # select group
81 | _z = (z - z_q) # compute residual
82 | z_res[..., i:i+1, :] = _z
83 |
84 | # quantize
85 | cb, offset = self.get_codebook_by_group(i)
86 | _z_q, _d, _q = self.quantize(cb, _z)
87 |
88 | # update estimate
89 | z_q = z_q + _z_q
90 |
91 | # assign to tensor
92 | d[..., i:i+1] = _d
93 | q[..., i:i+1] = _q + offset
94 |
95 | z_res[..., -1:, :] = z - z_q
96 |
97 | to_return = {
98 | 'z' : z, # each group input z_e
99 | 'z_q' : z_q, # quantized output z_q
100 | 'd' : d, # distance function for each group
101 | 'q' : q, # codes, offset into the concatenated codebook
102 | 'z_res': z_res,
103 | 'loss' : self.compute_loss(z, z_q).mean(),
104 | 'perplexity': None,
105 | }
106 |
107 | z_q = self.straight_through_approximation(z, z_q)
108 | z_q = self.to_original_format(z_q)
109 | return z_q, to_return
110 |
--------------------------------------------------------------------------------
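Residual VQ quantizes the input, then repeatedly quantizes whatever error is left, summing the chosen codes into a running estimate. A minimal sketch with one codebook per stage (the `share=False` case), assuming Euclidean nearest-neighbour search via `torch.cdist`. The function name `residual_quantize` and the zero code prepended to each codebook are illustrative choices; the zero code guarantees that no stage can increase the residual.

```python
import torch


def residual_quantize(z, codebooks):
    """Greedy residual VQ: each stage quantizes what earlier stages missed.

    z: N x F; codebooks: list of K x F tensors, one per stage.
    """
    z_q = torch.zeros_like(z)
    codes, residual_norms = [], []
    for cb in codebooks:
        residual = z - z_q
        idx = torch.cdist(residual, cb).argmin(dim=1)  # nearest code to the residual
        z_q = z_q + cb[idx]                            # refine the running estimate
        codes.append(idx)
        residual_norms.append((z - z_q).norm(dim=1))
    return z_q, torch.stack(codes, dim=1), torch.stack(residual_norms, dim=1)


torch.manual_seed(0)
z = torch.randn(64, 8)
# later stages use smaller codes, since they only need to cover smaller residuals
codebooks = [torch.cat([torch.zeros(1, 8), torch.randn(31, 8) * 0.5 ** i]) for i in range(4)]
z_q, codes, norms = residual_quantize(z, codebooks)
```

This is also why the NOTE above warns against L2-normalizing the codebook: residuals shrink from stage to stage, so unit-norm codes cannot represent them well.
--------------------------------------------------------------------------------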
/vqtorch/nn/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .replace import lru_replacement
2 |
--------------------------------------------------------------------------------
/vqtorch/nn/utils/init.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from stringcolor import cs
3 | import warnings
4 | import vqtorch
5 |
6 |
7 |
8 | @torch.no_grad()
9 | def data_dependent_init_forward_hook(self, inputs, outputs, use_kmeans=True, verbose=False):
10 | """ initializes codebook from data """
11 |
12 | if (not self.training) or (self.data_initialized.item() == 1):
13 | return
14 |
15 | if verbose:
16 | print(cs('initializing codebook' + (' with k-means++' if use_kmeans else ' by random sampling'), 'y'))
17 |
18 | def sample_centroids(z_e, num_codes):
19 | """ replaces the data of the codebook with z_e randomly. """
20 |
21 | z_e = z_e.reshape(-1, z_e.size(-1))
22 |
23 | if num_codes >= z_e.size(0):
24 | e_msg = f'\ncodebook size > warmup samples: {num_codes} vs {z_e.size(0)}. ' + \
25 | 'recommended to decrease the codebook size or increase batch size.'
26 |
27 | warnings.warn(str(cs(e_msg, 'yellow')))
28 |
29 | # repeat until it fits (ceil division) and add noise
30 | repeat = -(-num_codes // z_e.shape[0])
31 | new_codes = z_e.data.tile([repeat, 1])[:num_codes]
32 | new_codes += 1e-3 * torch.randn_like(new_codes.data)
33 |
34 | else:
35 | # you have more warmup samples than codebook. subsample data
36 | if use_kmeans:
37 | from torchpq.clustering import KMeans
38 | kmeans = KMeans(n_clusters=num_codes, distance='euclidean', init_mode="kmeans++")
39 | kmeans.fit(z_e.data.T.contiguous())
40 | new_codes = kmeans.centroids.T
41 | else:
42 | indices = torch.randperm(z_e.size(0))[:num_codes] # sample rows without replacement
43 | indices = indices.to(z_e.device)
44 | new_codes = torch.index_select(z_e, 0, indices).data
45 |
46 | return new_codes
47 |
48 | _, misc = outputs
49 | z_e, z_q = misc['z'], misc['z_q']
50 |
51 | if type(self) is vqtorch.nn.VectorQuant:
52 | num_codes = self.codebook.weight.shape[0]
53 | new_codebook = sample_centroids(z_e, num_codes)
54 | self.codebook.weight.data = new_codebook
55 |
56 | elif type(self) is vqtorch.nn.GroupVectorQuant:
57 | if self.share:
58 | # a shared codebook is fit on the sub-vectors of all groups jointly
59 | new_codebook = sample_centroids(z_e, self.group_size)
60 | self.codebook.weight.data = new_codebook
61 | else:
62 | for i in range(self.groups):
63 | offset = i * self.group_size
64 | new_codebook = sample_centroids(z_e[..., i, :], self.group_size)
65 | self.codebook.weight.data[offset:offset+self.group_size] = new_codebook
66 |
67 | elif type(self) is vqtorch.nn.ResidualVectorQuant:
68 | z_e = misc['z_res']
69 |
70 | if self.share:
71 | new_codebook = sample_centroids(z_e, self.group_size)
72 | self.codebook.weight.data = new_codebook
73 | else:
74 | for i in range(self.groups):
75 | offset = i * self.group_size
76 | new_codebook = sample_centroids(z_e[..., i, :], self.group_size)
77 | self.codebook.weight.data[offset:offset+self.group_size] = new_codebook
78 |
79 |
80 | self.data_initialized.fill_(1)
81 | return
82 |
--------------------------------------------------------------------------------
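`sample_centroids` has two regimes. When there are fewer samples `n` than codes, tiling `num_codes // n` times yields fewer than `num_codes` rows whenever `num_codes` is not a multiple of `n`, so the sketch below tiles `ceil(num_codes / n)` times before truncating. When there are enough samples, it draws distinct rows with `torch.randperm`. This covers only the non-k-means path and is a hypothetical standalone version, not the forward hook itself.

```python
import torch


def sample_centroids(z_e, num_codes, noise=1e-3):
    """Hypothetical standalone version of the non-k-means initialization."""
    z_e = z_e.reshape(-1, z_e.size(-1))
    n = z_e.size(0)
    if num_codes >= n:
        # too few samples: tile ceil(num_codes / n) times, truncate, then jitter duplicates
        repeat = -(-num_codes // n)
        codes = z_e.tile([repeat, 1])[:num_codes]
        return codes + noise * torch.randn_like(codes)
    # enough samples: draw num_codes distinct rows uniformly at random
    idx = torch.randperm(n, device=z_e.device)[:num_codes]
    return z_e[idx].clone()


torch.manual_seed(0)
print(sample_centroids(torch.randn(2, 3, 4), 10).shape)   # torch.Size([10, 4])
print(sample_centroids(torch.randn(4, 50, 4), 16).shape)  # torch.Size([16, 4])
```
--------------------------------------------------------------------------------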
/vqtorch/nn/utils/replace.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 |
5 | class ReplaceLRU():
6 | """
7 | Attributes:
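8 | rho (float): scale of the noise added to replacement codes, relative to their norm
9 | timeout (int): number of training forward passes a code may go unused before it is replaced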
10 | """
11 | VALID_POLICIES = ['input_random', 'input_kmeans', 'self']
12 |
13 | def __init__(self, rho=1e-4, timeout=100):
14 | assert timeout > 1
15 | assert rho >= 0.0
16 | self.rho = rho
17 | self.timeout = timeout
18 |
19 | self.policy = 'input_random'
20 | # self.policy = 'input_kmeans'
21 | # self.policy = 'self'
22 | self.tau = 2.0
23 |
24 | assert self.policy in self.VALID_POLICIES
25 | return
26 |
27 | @staticmethod
28 | def apply(module, rho=0., timeout=100):
29 | """ register forward hook """
30 | fn = ReplaceLRU(rho, timeout)
31 | device = next(module.parameters()).device
32 | module.register_forward_hook(fn)
33 | module.register_buffer('_counts', timeout * torch.ones(module.num_codes))
34 | module._counts = module._counts.to(device)
35 | return fn
36 |
37 | def __call__(self, module, inputs, outputs):
38 | """
39 | This function is triggered during forward pass
40 | recall: z_q, misc = vq(x)
41 |
42 | Args
43 | module (nn.VectorQuant)
44 | inputs (tuple): A tuple with 1 element
45 | x (Tensor)
46 | outputs (tuple): A tuple with 2 elements
47 | z_q (Tensor), misc (dict)
48 | """
49 | if not module.training:
50 | return
51 |
52 | # count down all code by 1 and if used, reset timer to timeout value
53 | module._counts -= 1
54 |
55 | # --- computes most recent codebook usage --- #
56 | unique, counts = torch.unique(outputs[1]['q'], return_counts=True)
57 | module._counts.index_fill_(0, unique, self.timeout)
58 |
59 | # --- find how many needs to be replaced --- #
60 | # num_active = self.check_and_replace_dead_codes(module, outputs)
61 | inactive_indices = torch.argwhere(module._counts == 0).squeeze(-1)
62 | num_inactive = inactive_indices.size(0)
63 |
64 | if num_inactive > 0:
65 |
66 | if self.policy == 'self':
67 | # exponential distance allows more recently used codes to be even more preferable
68 | p = torch.zeros_like(module._counts)
69 | p[unique] = counts.float()
70 | p = p / p.sum()
71 | p = torch.exp(self.tau * p) - 1 # the negative 1 is to drive p=0 to stay 0
72 |
73 | selected_indices = torch.multinomial(p, num_inactive, replacement=True)
74 | selected_values = module.codebook.weight.data[selected_indices].clone()
75 |
76 | elif self.policy == 'input_random':
77 | z_e = outputs[1]['z'].flatten(0, -2) # flatten to 2D
78 | z_e = z_e[torch.randperm(z_e.size(0))] # shuffle
79 | mult = num_inactive // z_e.size(0) + 1
80 | if mult > 1: # if there are not enough inputs, repeat them
81 | z_e = torch.cat(mult * [z_e])
82 | selected_values = z_e[:num_inactive]
83 |
84 | elif self.policy == 'input_kmeans':
85 | # can be extremely slow
86 | from torchpq.clustering import KMeans
87 | z_e = outputs[1]['z'].flatten(0, -2) # flatten to 2D
88 | z_e = z_e[torch.randperm(z_e.size(0))] # shuffle
89 | kmeans = KMeans(n_clusters=num_inactive, distance='euclidean', init_mode="kmeans++")
90 | kmeans.fit(z_e.data.T.contiguous())
91 | selected_values = kmeans.centroids.T
92 |
93 | if self.rho > 0:
94 | norm = selected_values.norm(p=2, dim=-1, keepdim=True)
95 | noise = torch.randn_like(selected_values)
96 | selected_values = selected_values + self.rho * norm * noise
97 |
98 | # --- update dead codes with new codes --- #
99 | module.codebook.weight.data[inactive_indices] = selected_values
100 | module._counts[inactive_indices] += self.timeout
101 |
102 | return outputs
103 |
104 |
105 |
106 | def lru_replacement(vq_module, rho=1e-4, timeout=100):
107 | """
108 | Example::
109 | >>> vq = VectorQuant(...)
110 | >>> vq = lru_replacement(vq)
111 | """
112 | ReplaceLRU.apply(vq_module, rho, timeout)
113 | return vq_module
114 |
--------------------------------------------------------------------------------
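`ReplaceLRU` treats `_counts` as a per-code countdown: every training forward pass decrements all counters, codes that appear in `q` are reset to `timeout`, and codes that reach zero are replaced and given a fresh timer. The trace below replays that bookkeeping on a toy usage pattern (the usage list is made up for illustration) and leaves out the actual overwrite of the codebook rows.

```python
import torch

num_codes, timeout = 6, 3
counts = torch.full((num_codes,), float(timeout))

# made-up code usage for four training steps; codes 4 and 5 are never selected
usage = [torch.tensor([0, 1]), torch.tensor([2, 3]), torch.tensor([0, 2]), torch.tensor([1, 3])]

replaced = []
for step, q in enumerate(usage):
    counts -= 1                                      # every code ages by one step
    counts.index_fill_(0, torch.unique(q), timeout)  # codes used this step are refreshed
    dead = torch.argwhere(counts == 0).squeeze(-1)   # codes idle for `timeout` steps
    if dead.numel() > 0:
        replaced.append((step, dead.tolist()))       # the hook would overwrite these rows
        counts[dead] += timeout                      # and restart their timers

print(replaced)  # [(2, [4, 5])]
```

Code 1 is idle at step 2 as well, but its counter has only reached 1, so it survives and is refreshed when it is used again at step 3.
--------------------------------------------------------------------------------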
/vqtorch/nn/vq.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from vqtorch.dists import get_dist_fns
6 | import vqtorch
7 | from vqtorch.norms import with_codebook_normalization
8 | from .vq_base import _VQBaseLayer
9 | from .affine import AffineTransform
10 |
11 |
12 | class VectorQuant(_VQBaseLayer):
13 | """
14 | Vector quantization layer using straight-through estimation.
15 |
16 | Args:
17 | feature_size (int): feature dimension corresponding to the vectors
18 | num_codes (int): number of vectors in the codebook
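19 | beta (float): loss weighting in [0, 1]: the codebook term gets beta, the commitment term 1 - beta
20 | sync_nu (float): weighting of the synchronized codebook update in the straight-through estimator
21 | affine_lr (float): learning-rate scale for the affine transform (0 disables it)
22 | affine_groups (int): number of affine parameter groups
23 | replace_freq (int): steps a code may go unused before it is replaced (0 disables replacement)
24 | inplace_optimizer (callable): optimizer constructor for in-place codebook updates (requires beta=1.0)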
25 | **kwargs: additional arguments for _VQBaseLayer
26 |
27 | Returns:
28 | Quantized vector z_q and return dict
29 | """
30 |
31 |
32 | def __init__(
33 | self,
34 | feature_size : int,
35 | num_codes : int,
36 | beta : float = 0.95,
37 | sync_nu : float = 0.0,
38 | affine_lr: float = 0.0,
39 | affine_groups: int = 1,
40 | replace_freq: int = 0,
41 | inplace_optimizer: torch.optim.Optimizer = None,
42 | **kwargs,
43 | ):
44 |
45 | super().__init__(feature_size, num_codes, **kwargs)
46 | self.loss_fn, self.dist_fn = get_dist_fns('euclidean')
47 |
48 | if beta < 0.0 or beta > 1.0:
49 | raise ValueError(f'beta must be in [0, 1] but got {beta}')
50 |
51 | self.beta = beta
52 | self.nu = sync_nu
53 | self.affine_lr = affine_lr
54 | self.codebook = nn.Embedding(self.num_codes, self.feature_size)
55 |
56 | if inplace_optimizer is not None:
57 | if beta != 1.0:
58 | raise ValueError('inplace_optimizer can only be used with beta=1.0')
59 | self.inplace_codebook_optimizer = inplace_optimizer(self.codebook.parameters())
60 |
61 | if affine_lr > 0:
62 | # defaults to using learnable affine parameters
63 | self.affine_transform = AffineTransform(
64 | self.code_vector_size,
65 | use_running_statistics=False,
66 | lr_scale=affine_lr,
67 | num_groups=affine_groups,
68 | )
69 | if replace_freq > 0:
70 | vqtorch.nn.utils.lru_replacement(self, rho=0.01, timeout=replace_freq)
71 | return
72 |
73 |
74 | def straight_through_approximation(self, z, z_q):
75 | """ passed gradient from z_q to z """
76 | if self.nu > 0:
77 | z_q = z + (z_q - z).detach() + (self.nu * z_q) + (-self.nu * z_q).detach()
78 | else:
79 | z_q = z + (z_q - z).detach()
80 | return z_q
81 |
82 |
83 | def compute_loss(self, z_e, z_q):
84 | """ computes loss between z and z_q """
85 | return ((1.0 - self.beta) * self.loss_fn(z_e, z_q.detach()) + \
86 | (self.beta) * self.loss_fn(z_e.detach(), z_q))
87 |
88 |
89 | def quantize(self, codebook, z):
90 | """
91 | Quantizes the latent codes z with the codebook
92 |
93 | Args:
94 | codebook (Tensor): num_codes x F
95 | z (Tensor): B x ... x F
96 | """
97 |
98 | # reshape to (B x HWG x F) and compute distance
99 | z_shape = z.shape[:-1]
100 | z_flat = z.view(z.size(0), -1, z.size(-1))
101 |
102 | if hasattr(self, 'affine_transform'):
103 | self.affine_transform.update_running_statistics(z_flat, codebook)
104 | codebook = self.affine_transform(codebook)
105 |
106 | with torch.no_grad():
107 | dist_out = self.dist_fn(
108 | tensor=z_flat,
109 | codebook=codebook,
110 | topk=self.topk,
111 | compute_chunk_size=self.cdist_chunk_size,
112 | half_precision=(z.is_cuda),
113 | )
114 |
115 | d = dist_out['d'].view(z_shape)
116 | q = dist_out['q'].view(z_shape).long()
117 |
118 | z_q = F.embedding(q, codebook)
119 |
120 | if self.training and hasattr(self, 'inplace_codebook_optimizer'):
121 | # update codebook inplace
122 | ((z_q - z.detach()) ** 2).mean().backward()
123 | self.inplace_codebook_optimizer.step()
124 | self.inplace_codebook_optimizer.zero_grad()
125 |
126 | # forward pass again with the updated codebook
127 | z_q = F.embedding(q, codebook)
128 |
129 | # NOTE to save compute, we assumed Q did not change.
130 |
131 | return z_q, d, q
132 |
133 | @torch.no_grad()
134 | def get_codebook(self):
135 | cb = self.codebook.weight
136 | if hasattr(self, 'affine_transform'):
137 | cb = self.affine_transform(cb)
138 | return cb
139 |
140 | def get_codebook_affine_params(self):
141 | if hasattr(self, 'affine_transform'):
142 | return self.affine_transform.get_affine_params()
143 | return None
144 |
145 | @with_codebook_normalization
146 | def forward(self, z):
147 |
148 | ######
149 | ## (1) formatting data by groups and invariant to dim
150 | ######
151 |
152 | z = self.prepare_inputs(z, self.groups)
153 |
154 | if not self.enabled:
155 | z = self.to_original_format(z)
156 | return z, {}
157 |
158 | ######
159 | ## (2) quantize latent vector
160 | ######
161 |
162 | z_q, d, q = self.quantize(self.codebook.weight, z)
163 |
164 | # e_mean = F.one_hot(q, num_classes=self.num_codes).view(-1, self.num_codes).float().mean(0)
165 | # perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
166 | perplexity = None
167 |
168 | to_return = {
169 | 'z' : z, # each group input z_e
170 | 'z_q': z_q, # quantized output z_q
171 | 'd' : d, # distance function for each group
172 | 'q' : q, # codes
173 | 'loss': self.compute_loss(z, z_q).mean(),
174 | 'perplexity': perplexity,
175 | }
176 |
177 | z_q = self.straight_through_approximation(z, z_q)
178 | z_q = self.to_original_format(z_q)
179 |
180 | return z_q, to_return
181 |
--------------------------------------------------------------------------------
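The two loss terms in `compute_loss` and the straight-through trick in `straight_through_approximation` can be checked directly with autograd. This sketch assumes `loss_fn` behaves like mean squared error (the real one comes from `get_dist_fns('euclidean')`, which is not shown) and uses `torch.cdist` for the nearest-code lookup. It shows that `z + (z_q - z).detach()` has the forward value `z_q` but an identity gradient with respect to `z`, that only the beta-weighted term sends gradient to the codebook, and that the `sync_nu` term sends `nu` times the usage count back to each selected code.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta, nu = 0.95, 0.1
z = torch.randn(4, 3, requires_grad=True)         # encoder output z_e
codebook = torch.randn(8, 3, requires_grad=True)  # K x F codebook

# nearest code per row (no gradient flows through the assignment)
q = torch.cdist(z.detach(), codebook.detach()).argmin(dim=1)
z_q = F.embedding(q, codebook)

# (1 - beta): commitment term, moves z_e toward its code
# beta:       codebook term, moves the selected codes toward z_e
loss = (1 - beta) * F.mse_loss(z, z_q.detach()) + beta * F.mse_loss(z.detach(), z_q)

# straight-through: forward value is z_q, d(z_st)/dz is the identity
z_st = z + (z_q - z).detach()

# with sync_nu > 0: same forward value, plus a nu-scaled gradient path into the codebook
z_sync = z + (z_q - z).detach() + (nu * z_q) + (-nu * z_q).detach()
```
--------------------------------------------------------------------------------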
/vqtorch/nn/vq_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from vqtorch.norms import get_norm
5 | from vqtorch.nn.utils.init import data_dependent_init_forward_hook
6 |
7 |
8 |
9 | class _VQBaseLayer(nn.Module):
10 | """
11 | Base template code for vector quantization. All VQ layers will inherit
12 | from this class.
13 |
14 | Args:
15 | feature_size (int):
16 | The size of the feature. This is the length of each
17 | code vector and must match the input feature size.
18 | num_codes (int):
19 | Number of codes to use in the codebook.
20 | dim (int): Dimension to quantize. By default, quantization happens on
21 | the channel dimension. For example, given an image tensor
22 | (B x C x H x W) and dim=1, the channels are treated as features
23 | and the resulting codes `q` have the shape (B x H x W).
24 | For transformers (B x N x C), you should set dim=2 or -1.
25 | norm (str): Feature normalization.
26 | cb_norm (str): Codebook normalization.
27 |
28 | Returns:
29 | Quantized vector z_q and return dict
30 |
31 | Attributes:
32 | cdist_chunk_size (int): chunk size for divide-and-conquer topk cdist.
33 | enabled (bool): If false, the model is not quantized and acts as an identity layer.
34 | """
35 |
36 | cdist_chunk_size = 1024
37 | enabled = True
38 |
39 | def __init__(
40 | self,
41 | feature_size : int,
42 | num_codes : int,
43 | dim : int = 1,
44 | norm : str = 'none',
45 | cb_norm : str = 'none',
46 | kmeans_init : bool = False,
47 | code_vector_size : int = None,
48 | ):
49 |
50 | super().__init__()
51 | self.feature_size = feature_size
52 | self.code_vector_size = feature_size if code_vector_size is None else code_vector_size
53 | self.num_codes = num_codes
54 | self.dim = dim
55 |
56 | self.groups = 1 # for group VQ
57 | self.topk = 1 # for probabilistic VQ
58 |
59 | self.norm = norm
60 | self.codebook_norm = cb_norm
61 | self.norm_layer, self.norm_before_grouping = get_norm(norm, feature_size)
62 |
63 | if kmeans_init:
64 | self.register_buffer('data_initialized', torch.zeros(1))
65 | self.register_forward_hook(data_dependent_init_forward_hook)
66 | return
67 |
68 | def quantize(self, codebook, z):
69 | """
70 | Quantizes the latent codes z with the codebook
71 |
72 | Args:
73 | codebook (Tensor): num_codes x C
74 | z (Tensor): B x ... x C
75 | """
76 | raise NotImplementedError
77 |
78 |
79 | def compute_loss(self, z_e, z_q):
80 | """ computes error between z and z_q """
81 | raise NotImplementedError
82 |
83 |
84 | def to_canonical_group_format(self, z, groups):
85 | """
86 | Converts data into canonical group format
87 |
88 | The quantization dim is sent to the last dimension.
89 | The features are then resized such that C -> G x C'
90 |
91 | Args:
92 | x (Tensor): a tensor in group form [B x C x ... ]
93 | groups (int): number of groups
94 | Returns:
95 | x of shape [B x ... x G x C']
96 | """
97 |
98 | z = z.moveaxis(self.dim, -1).contiguous()
99 | z = z.unflatten(-1, (groups, -1))
100 | return z
101 |
102 |
103 | def to_original_format(self, x):
104 | """
105 | Merges group and permutes dimension back
106 |
107 | Args:
108 | x (Tensor): a tensor in group form [B x ... x G x C // G]
109 | Returns:
110 | merged `x` of shape [B x C x ...] (assuming dim=1)
111 | """
112 | return x.flatten(-2, -1).moveaxis(-1, self.dim)
113 |
114 |
115 | def prepare_inputs(self, z, groups):
116 | """
117 | Prepare input with normalization and group format
118 |
119 | Args:
120 | z (Tensor): input tensor [B x C x ... ] (assuming dim=1)
121 | groups (int): number of groups
122 | """
123 |
124 | if len(z.shape) <= 1:
125 | e_msg = f'expected a tensor of at least 2 dimensions but found {z.size()}'
126 | raise ValueError(e_msg)
127 |
128 | if self.norm_before_grouping:
129 | z = self.norm_layer(z)
130 |
131 | z = self.to_canonical_group_format(z, groups)
132 |
133 | if not self.norm_before_grouping:
134 | z = self.norm_layer(z)
135 |
136 | return z
137 |
138 |
139 | @property
140 | def requires_grad(self):
141 | return self.codebook.weight.requires_grad
142 |
143 |
144 | def set_requires_grad(self, requires_grad):
145 | # the codebook is a single nn.Embedding
146 | self.codebook.weight.requires_grad = requires_grad
147 | return
148 |
149 |
150 | def extra_repr(self):
151 | repr = "\n".join([
152 | f"num_codes: {self.num_codes}",
153 | f"groups: {self.groups}",
154 | f"enabled: {self.enabled}",
155 | ])
156 | return repr
157 |
--------------------------------------------------------------------------------
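`to_canonical_group_format` and `to_original_format` are exact inverses: the quantized dimension is moved last and split into `G x C/G`, then merged and moved back. The standalone functions below mirror those methods with `dim` passed explicitly (an assumption for illustration, since the class reads it from `self.dim`), for both an image tensor with `dim=1` and a token tensor with `dim=-1`.

```python
import torch


def to_canonical_group_format(z, dim, groups):
    # send the quantized dim last, then split it into G x C/G
    return z.moveaxis(dim, -1).contiguous().unflatten(-1, (groups, -1))


def to_original_format(x, dim):
    # merge G x C/G back into C, then return it to position `dim`
    return x.flatten(-2, -1).moveaxis(-1, dim)


z_img = torch.randn(2, 8, 5, 5)                             # B x C x H x W, dim=1
g_img = to_canonical_group_format(z_img, dim=1, groups=4)   # B x H x W x 4 x 2

z_tok = torch.randn(2, 10, 8)                               # B x N x C, dim=-1
g_tok = to_canonical_group_format(z_tok, dim=-1, groups=2)  # B x N x 2 x 4
```

Group `g` of the pixel at `(h, w)` holds channels `g * C/G` through `(g + 1) * C/G - 1`, so contiguous channel blocks form each sub-vector.
--------------------------------------------------------------------------------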
/vqtorch/norms.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | MAXNORM_CONSTRAINT_VALUE = 10
7 |
8 |
9 | class Normalize(nn.Module):
10 | """
11 | Simple vector normalization module. By default, vectors are normalized
12 | along the channel dimension: the vector at each spatial
13 | location is normalized independently. Used with the cosine-distance VQ layer.
14 | """
15 | def __init__(self, p=2, dim=1, eps=1e-6):
16 | super().__init__()
17 | self.p = p
18 | self.dim = dim
19 | self.eps = eps
20 | return
21 |
22 | def forward(self, x):
23 | return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps)
24 |
25 |
26 | def max_norm(w, p=2, dim=-1, max_norm=MAXNORM_CONSTRAINT_VALUE, eps=1e-8):
27 | norm = w.norm(p=p, dim=dim, keepdim=True)
28 | desired = torch.clamp(norm.data, max=max_norm)
29 | # desired = torch.clamp(norm, max=max_norm)
30 | return w * (desired / (norm + eps))
31 |
32 |
33 | class MaxNormConstraint(nn.Module):
34 | def __init__(self, max_norm=1, p=2, dim=-1, eps=1e-8):
35 | super().__init__()
36 | self.p = p
37 | self.eps = eps
38 | self.dim = dim
39 | self.max_norm = max_norm
40 |
41 | def forward(self, x):
42 | return max_norm(x, self.p, self.dim, max_norm=self.max_norm, eps=self.eps)
43 |
44 |
45 | @torch.no_grad()
46 | def with_codebook_normalization(func):
47 | def wrapper(*args, **kwargs):
48 | self = args[0]
49 | for n, m in self.named_modules():
50 | if isinstance(m, nn.Embedding):
51 | if self.codebook_norm == 'l2':
52 | m.weight.data = max_norm(m.weight.data, p=2, dim=1, eps=1e-8)
53 | elif self.codebook_norm == 'l2c':
54 | m.weight.data = F.normalize(m.weight.data, p=2, dim=1, eps=1e-8)
55 | return func(*args, **kwargs)
56 | return wrapper
57 |
58 |
59 | def get_norm(norm, num_channels=None):
60 | before_grouping = True
61 | if norm == 'l2':
62 | norm_layer = Normalize(p=2, dim=-1)
63 | before_grouping = False
64 | elif norm == 'l2c':
65 | norm_layer = MaxNormConstraint(p=2, dim=-1, max_norm=MAXNORM_CONSTRAINT_VALUE)
66 | before_grouping = False
67 | elif norm == 'bn':
68 | norm_layer = nn.BatchNorm2d(num_channels)
69 | elif norm == 'gn':
70 | norm_layer = GroupNorm(num_channels)
71 | elif norm in ['none', None]:
72 | norm_layer = nn.Identity()
73 | elif norm == 'in':
74 | norm_layer = nn.InstanceNorm2d(num_channels)
75 | else:
76 | raise ValueError(f'unknown norm {norm}')
77 | return norm_layer, before_grouping
78 |
79 |
80 | def GroupNorm(in_channels):
81 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
82 |
83 |
84 | def match_norm(x, y, dim=-1, eps=1e-8):
85 | """
86 | matches vector norm of x to that of y
87 | Args:
88 | x (Tensor): a tensor of any shape
89 | y (Tensor): a tensor of the same shape as `x`.
90 | dim (int): dimension to match the norm over
91 | eps (float): epsilon to mitigate division by zero.
92 | Returns:
93 | `x` with the same norm as `y` across `dim`
94 | """
95 | assert x.shape == y.shape, \
96 | f'expected `x` and `y` to have the same dim but found {x.shape} vs {y.shape}'
97 |
98 | # move chosen dim to last dim
99 | x = x.moveaxis(dim, -1).contiguous()
100 | y = y.moveaxis(dim, -1).contiguous()
101 | x_shape = x.shape
102 |
103 | # unravel everything such that [GBHW X C]
104 |
105 | # print(x.shape)
106 | x = x.view(-1, x.size(-1))
107 | y = y.view(-1, y.size(-1))
108 |
109 | # compute norm on C
110 | x_norm = torch.norm(x, p=2, dim=1, keepdim=True)
111 | y_norm = torch.norm(y, p=2, dim=1, keepdim=True)
112 |
113 | # clamp x_norm to avoid division by zero
114 | x_norm = torch.clamp(x_norm, min=eps)
115 |
116 | # normalize (x now has same norm as y)
117 | x = y_norm * (x / x_norm)
118 | x = x.view(x_shape)
119 | x = x.moveaxis(-1, dim).contiguous()
120 | return x
121 |
--------------------------------------------------------------------------------
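`max_norm` only shrinks vectors whose L2 norm exceeds the limit, while `F.normalize` (used by `Normalize`) rescales every vector to unit norm; `match_norm` keeps the direction of `x` and borrows the norm of `y`. The sketch below restates `max_norm` and a simplified `match_norm` that works directly on one dimension, skipping the `moveaxis`/`view` bookkeeping of the original. `.detach()` stands in for `.data`, so the clamp target is excluded from autograd in the same way.

```python
import torch
import torch.nn.functional as F


def max_norm(w, max_norm=10.0, dim=-1, eps=1e-8):
    # rescale only vectors whose L2 norm exceeds max_norm; shorter ones pass through
    norm = w.norm(p=2, dim=dim, keepdim=True)
    desired = torch.clamp(norm.detach(), max=max_norm)
    return w * (desired / (norm + eps))


def match_norm(x, y, dim=-1, eps=1e-8):
    # give each vector of x the L2 norm of the matching vector of y
    x_norm = x.norm(p=2, dim=dim, keepdim=True).clamp(min=eps)
    y_norm = y.norm(p=2, dim=dim, keepdim=True)
    return y_norm * (x / x_norm)


w = torch.tensor([[3.0, 4.0], [30.0, 40.0]])  # norms 5 and 50
clamped = max_norm(w, max_norm=10.0)          # first row kept, second scaled to norm 10
unit = F.normalize(w, p=2, dim=-1)            # both rows scaled to norm 1
```
--------------------------------------------------------------------------------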
/vqtorch/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from vqtorch.nn import _VQBaseLayer
5 |
6 |
7 | def is_vq(m):
8 | """ checks if the module is a VQ layer """
9 | return issubclass(type(m), _VQBaseLayer)
10 |
11 | def is_vqn(net):
12 | """ checks if the network contains a VQ layer """
13 | return np.any([is_vq(layer) for layer in net.modules()])
14 |
15 | def get_vq_layers(model):
16 | """ returns vq layers from a network """
17 | return [m for m in model.modules() if is_vq(m)]
18 |
19 |
20 | class no_vq():
21 | """
22 | Function to turn off VQ by setting all the VQ layers to identity function.
23 |
24 | Examples::
25 | >>> with vqtorch.no_vq(model):
26 | ... out = model(x)
27 | """
28 |
29 | def __init__(self, modules):
30 | if type(modules) is not list:
31 | modules = [modules]
32 |
33 | for module in modules:
34 | if isinstance(module, torch.nn.DataParallel):
35 | module = module.module
36 |
37 | assert isinstance(module, torch.nn.Module), \
38 | f'expected input to be nn.Module or a list of nn.Module ' + \
39 | f'but found {type(module)}'
40 |
41 | self.enable_vq(module, enable=False)
42 |
43 | self.modules = modules
44 | return
45 |
46 | def __enter__(self):
47 | pass
48 |
49 | def __exit__(self, exception_type, exception_value, traceback):
50 | for module in self.modules:
51 | if isinstance(module, torch.nn.DataParallel):
52 | module = module.module
53 | self.enable_vq(module, enable=True)
54 | return
55 |
56 | def enable_vq(self, module, enable):
57 | for m in module.modules():
58 | if is_vq(m):
59 | m.enabled = enable
60 | return
61 |
--------------------------------------------------------------------------------