├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── maskbit.png ├── maskbit_pytorch ├── __init__.py ├── maskbit.py └── trainer.py └── pyproject.toml /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## MaskBit - Pytorch (wip) 4 | 5 | Implementation of [MaskBit](https://arxiv.org/abs/2409.16211), proposed by ByteDance 6 | 7 | This paper can be viewed as a modernized version of the architecture from [Taming Transformers](https://arxiv.org/abs/2012.09841) by Esser et al. 8 | 9 | The autoencoder uses the binary scalar quantization proposed in [MagVit2](https://arxiv.org/abs/2310.05737). Generation is then done with non-autoregressive masked decoding, where masking a bit means setting it (`-1` or `+1`) to `0`; groups of bits are projected directly into transformer tokens, without an explicit embedding for the resulting trit 10 | 11 | ## Usage 12 | 13 | ```python 14 | import torch 15 | from maskbit_pytorch import BQVAE, MaskBit 16 | 17 | images = torch.randn(1, 3, 64, 64) 18 | 19 | # train vae 20 | 21 | vae = BQVAE( 22 | image_size = 64, 23 | dim = 512 24 | ) 25 | 26 | loss = vae(images, return_loss = True) 27 | loss.backward() 28 | 29 | # train maskbit 30 | 31 | maskbit = MaskBit( 32 | vae, 33 | dim = 512, 34 | bits_group_size = 512, 35 | depth = 2 36 | ) 37 | 38 | loss = maskbit(images) 39 | loss.backward() 40 | 41 | # after much training 42 | 43 | sampled_image = maskbit.sample() # (1, 3, 64, 64) 44 | ``` 45 | 46 | ## Citations 47 | 48 | ```bibtex 49 | @inproceedings{Weber2024MaskBitEI, 50 | title = {MaskBit: Embedding-free Image Generation via Bit Tokens}, 51 | author = {Mark Weber and Lijun Yu and Qihang Yu and Xueqing Deng and Xiaohui Shen and Daniel Cremers and Liang-Chieh Chen}, 52 | year = {2024}, 53 | url = {https://api.semanticscholar.org/CorpusID:272832013} 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /maskbit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/maskbit-pytorch/eaca2ee7c5e27bcb5c16ef4ab148c9f26b3610fe/maskbit.png -------------------------------------------------------------------------------- /maskbit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from maskbit_pytorch.maskbit import ( 2 | BQVAE, 3 | MaskBit 4 | ) 5 | 6 | from maskbit_pytorch.trainer import ( 7 | BQVAETrainer, 8 | MaskBitTrainer 9 | ) 10 | -------------------------------------------------------------------------------- /maskbit_pytorch/maskbit.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from math import ceil, prod, log2 4 | from functools import cache 5 | 6 | import torch 7 | from torch import nn, pi, tensor, Tensor 8 | import torch.nn.functional as F 9 | import torch.distributed as dist 10 | from torch.nn import Module, ModuleList 11 | 12 | from vector_quantize_pytorch import ( 13 | LFQ 14 | ) 15 | 16 | from x_transformers import Encoder 17 | 18 | import einx 19 | from einops.layers.torch import Rearrange 20 | from einops import rearrange, repeat, pack, unpack 21 | 22 | from tqdm import tqdm 23 | 24 | # ein notation 25 | # b - batch 26 | # c - channels 27 | # h - height 28 | # w - width 29 | # n - raw bits sequence length 30 | # ng - sequence of bit groups 31 | 32 | # tensor typing 33 | 34 | import jaxtyping 35 | from jaxtyping import jaxtyped 36 | from beartype import beartype 37 | from beartype.door import is_bearable 38 | 39 | class TorchTyping: 40 | def __init__(self,
abstract_dtype): 41 | self.abstract_dtype = abstract_dtype 42 | 43 | def __getitem__(self, shapes: str): 44 | return self.abstract_dtype[Tensor, shapes] 45 | 46 | Float = TorchTyping(jaxtyping.Float) 47 | Int = TorchTyping(jaxtyping.Int) 48 | Bool = TorchTyping(jaxtyping.Bool) 49 | 50 | # helper functions 51 | 52 | def exists(v): 53 | return v is not None 54 | 55 | def is_empty(t: Tensor): 56 | return t.numel() == 0 57 | 58 | def default(v, d): 59 | return v if exists(v) else d 60 | 61 | def divisible_by(num, den): 62 | return (num % den) == 0 63 | 64 | def pack_one(t, pattern): 65 | t, packed_shape = pack([t], pattern) 66 | 67 | def inverse(t, unpack_pattern = None): 68 | unpack_pattern = default(unpack_pattern, pattern) 69 | return unpack(t, packed_shape, unpack_pattern)[0] 70 | 71 | return t, inverse 72 | 73 | # distributed helpers 74 | 75 | @cache 76 | def is_distributed(): 77 | return dist.is_initialized() and dist.get_world_size() > 1 78 | 79 | def maybe_distributed_mean(t): 80 | if not is_distributed(): 81 | return t 82 | 83 | dist.all_reduce(t) 84 | t = t / dist.get_world_size() 85 | return t 86 | 87 | # tensor helpers 88 | 89 | def log(t, eps = 1e-20): 90 | return t.clamp(min = eps).log() 91 | 92 | def calc_entropy(logits): 93 | prob = logits.softmax(dim = -1) 94 | return (-prob * log(prob)).sum(dim = -1) 95 | 96 | def gumbel_noise(t): 97 | noise = torch.rand_like(t) 98 | return -log(-log(noise)) 99 | 100 | def gumbel_sample(t, temperature = 1., dim = -1): 101 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) 102 | 103 | # adversarial related 104 | 105 | def hinge_discr_loss(fake, real): 106 | return (F.relu(1 + fake) + F.relu(1 - real)).mean() 107 | 108 | def hinge_gen_loss(fake): 109 | return -fake.mean() 110 | 111 | class ScalarEMA(Module): 112 | def __init__(self, decay: float): 113 | super().__init__() 114 | self.decay = decay 115 | 116 | self.register_buffer('initted', tensor(False)) 117 | self.register_buffer('ema', tensor(0.)) 118 | 119 | @torch.no_grad() 120 | def forward( 121 | self, 122 | values: Float['b'] 123 | ): 124 | if is_empty(values): 125 | return 126 | 127 | values = values.mean() 128 | values = maybe_distributed_mean(values) 129 | 130 | if not self.initted: 131 | self.ema.copy_(values) 132 | self.initted.copy_(tensor(True)) 133 | return 134 | 135 | self.ema.lerp_(values, 1. 
- self.decay) 136 | 137 | class ChanRMSNorm(Module): 138 | def __init__(self, dim): 139 | super().__init__() 140 | self.scale = dim ** 0.5 141 | self.gamma = nn.Parameter(torch.zeros(dim)) 142 | 143 | def forward(self, x): 144 | gamma = rearrange(self.gamma, 'c -> c 1 1') 145 | return F.normalize(x, dim = 1) * self.scale * (gamma + 1) 146 | 147 | class Discriminator(Module): 148 | def __init__( 149 | self, 150 | dims: tuple[int, ...], 151 | channels = 3, 152 | init_kernel_size = 5, 153 | ema_decay = 0.99, 154 | ): 155 | super().__init__() 156 | first_dim, *_, last_dim = dims 157 | dim_pairs = zip(dims[:-1], dims[1:]) 158 | 159 | self.layers = ModuleList([]) 160 | 161 | self.layers.append( 162 | nn.Sequential( 163 | nn.Conv2d(channels, first_dim, init_kernel_size, padding = init_kernel_size // 2), 164 | nn.SiLU() 165 | ) 166 | ) 167 | 168 | for dim_in, dim_out in dim_pairs: 169 | layer = nn.Sequential( 170 | nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), 171 | ChanRMSNorm(dim_out), 172 | nn.SiLU() 173 | ) 174 | 175 | self.layers.append(layer) 176 | 177 | dim = last_dim 178 | 179 | self.to_logits = nn.Sequential( 180 | nn.Conv2d(dim, dim, 1), 181 | nn.SiLU(), 182 | nn.Conv2d(dim, 1, 4) 183 | ) 184 | 185 | # for keeping track of the exponential moving averages of real and fake predictions 186 | # for the lecam divergence gan technique employed https://arxiv.org/abs/2104.03310 187 | 188 | self.ema_real = ScalarEMA(ema_decay) 189 | self.ema_fake = ScalarEMA(ema_decay) 190 | 191 | def forward( 192 | self, 193 | x: Float['b c h w'], 194 | is_real: bool | Bool['b'] | None = None 195 | ): 196 | batch, device = x.shape[0], x.device 197 | 198 | for layer in self.layers: 199 | x = layer(x) 200 | 201 | preds = self.to_logits(x) 202 | 203 | if not self.training or not exists(is_real): 204 | return preds 205 | 206 | if isinstance(is_real, bool): 207 | is_real = torch.full((batch,), is_real, dtype = torch.bool, device = device) 208 | 209 | is_fake = ~is_real 210 | 211 | preds_real = preds[is_real] 212 | preds_fake = preds[is_fake] 213 | 214 | self.ema_real(preds_real) 215 | self.ema_fake(preds_fake) 216 | 217 | reg_loss = 0. 218 | 219 | if not is_empty(preds_real) and self.ema_fake.initted: 220 | reg_loss = reg_loss + ((preds_real - self.ema_fake.ema) ** 2).mean() 221 | 222 | if not is_empty(preds_fake) and self.ema_real.initted: 223 | reg_loss = reg_loss + ((preds_fake - self.ema_real.ema) ** 2).mean() 224 | 225 | return preds, reg_loss 226 | 227 | # resnet block 228 | 229 | class Block(Module): 230 | def __init__( 231 | self, 232 | dim, 233 | dropout = 0. 234 | ): 235 | super().__init__() 236 | self.proj = nn.Conv2d(dim, dim, 3, padding = 1) 237 | self.act = nn.SiLU() 238 | self.dropout = nn.Dropout(dropout) 239 | 240 | def forward(self, x): 241 | x = self.proj(x) 242 | x = self.act(x) 243 | return self.dropout(x) 244 | 245 | class ResnetBlock(Module): 246 | def __init__( 247 | self, 248 | dim, 249 | *, 250 | dropout = 0. 
251 | ): 252 | super().__init__() 253 | self.block1 = Block(dim, dropout = dropout) 254 | self.block2 = Block(dim) 255 | 256 | def forward(self, x): 257 | h = self.block1(x) 258 | h = self.block2(h) 259 | return h + x 260 | 261 | # down and upsample 262 | 263 | class Upsample(Module): 264 | def __init__( 265 | self, 266 | dim, 267 | dim_out = None 268 | ): 269 | super().__init__() 270 | dim_out = default(dim_out, dim) 271 | conv = nn.Conv2d(dim, dim_out * 4, 1) 272 | 273 | self.net = nn.Sequential( 274 | conv, 275 | nn.SiLU(), 276 | nn.PixelShuffle(2) 277 | ) 278 | 279 | self.init_conv_(conv) 280 | 281 | def init_conv_(self, conv): 282 | o, i, h, w = conv.weight.shape 283 | conv_weight = torch.empty(o // 4, i, h, w) 284 | nn.init.kaiming_uniform_(conv_weight) 285 | conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...') 286 | 287 | conv.weight.data.copy_(conv_weight) 288 | nn.init.zeros_(conv.bias.data) 289 | 290 | def forward(self, x): 291 | return self.net(x) 292 | 293 | def Downsample( 294 | dim, 295 | dim_out = None 296 | ): 297 | dim_out = default(dim_out, dim) 298 | 299 | return nn.Sequential( 300 | Rearrange('b c (h s1) (w s2) -> b (c s1 s2) h w', s1 = 2, s2 = 2), 301 | nn.Conv2d(dim * 4, dim_out, 1) 302 | ) 303 | 304 | # binary quantization vae 305 | 306 | class BQVAE(Module): 307 | 308 | @beartype 309 | def __init__( 310 | self, 311 | dim, 312 | *, 313 | image_size, 314 | channels = 3, 315 | depth = 2, 316 | proj_in_kernel_size = 7, 317 | entropy_loss_weight = 0.1, 318 | reg_loss_weight = 1e-2, 319 | gen_loss_weight = 1e-1, 320 | lfq_kwargs: dict = dict(), 321 | discr_kwargs: dict = dict() 322 | ): 323 | super().__init__() 324 | self.image_size = image_size 325 | self.channels = channels 326 | 327 | self.proj_in = nn.Conv2d(channels, dim, proj_in_kernel_size, padding = proj_in_kernel_size // 2) 328 | 329 | self.encoder = ModuleList([]) 330 | 331 | # encoder 332 | 333 | curr_dim = dim 334 | for _ in range(depth): 335 | self.encoder.append(ModuleList([ 336 | ResnetBlock(curr_dim), 337 | Downsample(curr_dim, curr_dim * 2) 338 | ])) 339 | 340 | curr_dim *= 2 341 | image_size //= 2 342 | 343 | # middle 344 | 345 | self.mid_block = ResnetBlock(curr_dim) 346 | 347 | # codebook 348 | 349 | self.codebook_input_shape = (curr_dim, image_size, image_size) 350 | 351 | # precompute how many bits a single sample is compressed to 352 | # so maskbit can take this value during sampling 353 | 354 | self.bits_per_image = prod(self.codebook_input_shape) 355 | 356 | self.lfq = LFQ( 357 | codebook_size = 2, # number of codes is not applicable, as they simply group all the bits and project into tokens for the transformer 358 | dim = curr_dim, 359 | **lfq_kwargs 360 | ) 361 | 362 | # decoder 363 | 364 | self.decoder = ModuleList([]) 365 | 366 | for _ in range(depth): 367 | self.decoder.append(ModuleList([ 368 | Upsample(curr_dim, curr_dim // 2), 369 | ResnetBlock(curr_dim // 2), 370 | ])) 371 | 372 | curr_dim //= 2 373 | 374 | self.proj_out = nn.Conv2d(curr_dim, channels, 3, padding = 1) 375 | 376 | # discriminator 377 | 378 | self.discr = Discriminator( 379 | dims = (dim,) * int(log2(image_size) - 2), 380 | channels = channels, 381 | **discr_kwargs 382 | ) 383 | 384 | # aux loss 385 | 386 | self.entropy_loss_weight = entropy_loss_weight 387 | 388 | self.reg_loss_weight = reg_loss_weight 389 | 390 | self.gen_loss_weight = gen_loss_weight 391 | 392 | # tensor typing related 393 | 394 | self._c = channels 395 | 396 | def decode_bits_to_images( 397 | self, 398 | bits: Float['b d h w'] | Float['b n'] | 
Bool['b d h w'] | Bool['b n'] 399 | ): 400 | 401 | if bits.dtype == torch.bool: 402 | bits = bits.float() * 2 - 1 403 | 404 | if bits.ndim == 2: 405 | fmap_height, fmap_width = self.codebook_input_shape[-2:] 406 | bits = rearrange(bits, 'b (d h w) -> b d h w', h = fmap_height, w = fmap_width) 407 | 408 | x = bits 409 | 410 | for upsample, resnet in self.decoder: 411 | x = upsample(x) 412 | x = resnet(x) 413 | 414 | recon = self.proj_out(x) 415 | 416 | return recon 417 | 418 | def forward( 419 | self, 420 | images: Float['b {self._c} h w'], 421 | *, 422 | return_loss = True, 423 | return_discr_loss = False, 424 | return_details = False, 425 | return_quantized_bits = False, 426 | return_bits_as_bool = False 427 | ): 428 | batch = images.shape[0] 429 | 430 | assert images.shape[-2:] == ((self.image_size,) * 2) 431 | assert not (return_loss and return_quantized_bits) 432 | 433 | x = self.proj_in(images) 434 | 435 | for resnet, downsample in self.encoder: 436 | x = resnet(x) 437 | x = downsample(x) 438 | 439 | x = self.mid_block(x) 440 | 441 | bits, _, entropy_aux_loss = self.lfq(x) 442 | 443 | if return_quantized_bits: 444 | if return_bits_as_bool: 445 | bits = bits > 0. 446 | 447 | return bits 448 | 449 | assert (bits.numel() // batch) == self.bits_per_image 450 | 451 | x = bits 452 | 453 | for upsample, resnet in self.decoder: 454 | x = upsample(x) 455 | x = resnet(x) 456 | 457 | recon = self.proj_out(x) 458 | 459 | if return_discr_loss: 460 | images = images.requires_grad_() 461 | recon = recon.detach() 462 | 463 | discr_real_logits, reg_loss_real = self.discr(images, is_real = True) 464 | discr_fake_logits, reg_loss_fake = self.discr(recon, is_real = False) 465 | 466 | discr_loss = hinge_discr_loss(discr_fake_logits, discr_real_logits) 467 | 468 | reg_loss = (reg_loss_real + reg_loss_fake) / 2 469 | 470 | loss = discr_loss + reg_loss * self.reg_loss_weight 471 | 472 | if not return_details: 473 | return loss 474 | 475 | return loss, recon, (discr_loss, reg_loss_real, reg_loss_fake) 476 | 477 | if not return_loss: 478 | return recon 479 | 480 | recon_loss = F.mse_loss(images, recon) 481 | 482 | discr_fake_logits = self.discr(recon) 483 | 484 | gen_loss = hinge_gen_loss(discr_fake_logits) 485 | 486 | total_loss = ( 487 | recon_loss + 488 | entropy_aux_loss * self.entropy_loss_weight + 489 | gen_loss * self.gen_loss_weight 490 | ) 491 | 492 | if not return_details: 493 | return total_loss 494 | 495 | return total_loss, recon, (recon_loss, entropy_aux_loss, gen_loss) 496 | 497 | # class 498 | 499 | class MaskBit(Module): 500 | 501 | @beartype 502 | def __init__( 503 | self, 504 | vae: BQVAE, 505 | *, 506 | bits_group_size, 507 | dim, 508 | depth, 509 | bits_groups = 2, 510 | dim_head = 64, 511 | heads = 8, 512 | encoder_kwargs: dict = dict(), 513 | loss_ignore_index = -1, 514 | train_frac_bits_flipped = 0.05 515 | ): 516 | super().__init__() 517 | 518 | vae.eval() 519 | self.vae = vae 520 | 521 | self.bits_groups = bits_groups 522 | # bits_group_size (bits per "token") / bit_groups consecutive bits are masked at a time 523 | 524 | assert divisible_by(bits_group_size, bits_groups) 525 | self.consecutive_bits_to_mask = bits_group_size // bits_groups 526 | 527 | self.demasking_transformer = nn.Sequential( 528 | Rearrange('b (n g) -> b n g', g = bits_group_size), 529 | nn.Linear(bits_group_size, dim), 530 | Encoder( 531 | dim = dim, 532 | depth = depth, 533 | attn_dim_head = dim_head, 534 | heads = heads, 535 | **encoder_kwargs 536 | ), 537 | nn.Linear(dim, bits_group_size * 2), 538 | Rearrange('b 
n (g bits) -> b (n g) bits', bits = 2) 539 | ) 540 | 541 | self.loss_ignore_index = loss_ignore_index 542 | 543 | self.train_frac_bits_flipped = train_frac_bits_flipped 544 | 545 | # tensor typing 546 | 547 | self._c = vae.channels 548 | 549 | def parameters(self): 550 | return self.demasking_transformer.parameters() 551 | 552 | @property 553 | def device(self): 554 | return next(self.parameters()).device 555 | 556 | @torch.no_grad() 557 | def sample( 558 | self, 559 | batch_size = 1, 560 | num_demasking_steps = 18, 561 | temperature = 1., 562 | return_bits = False, 563 | return_bits_as_bool = False, 564 | ): 565 | device = self.device 566 | 567 | seq_len = self.vae.bits_per_image 568 | 569 | bits = torch.zeros(batch_size, seq_len, device = device) # start off all masked, 0. 570 | 571 | # times go from 0. to 1. for `num_demasking_steps` 572 | 573 | times = torch.linspace(0., 1., num_demasking_steps, device = device) 574 | noise_levels = torch.cos(times * pi * 0.5) 575 | num_bits_to_mask = (noise_levels * seq_len).long().ceil().clamp(min = 1) 576 | 577 | # iteratively denoise with attention 578 | 579 | for ind, bits_to_mask in tqdm(enumerate(num_bits_to_mask)): 580 | is_first = ind == 0 581 | 582 | # if not the first step, mask by the previous step's bit predictions with highest entropy 583 | 584 | if not is_first: 585 | entropy = calc_entropy(logits) 586 | remask_indices = entropy.topk(bits_to_mask.item(), dim = -1).indices 587 | bits.scatter_(1, remask_indices, 0.) # recall they use 0. for masking 588 | 589 | # ask the attention network to predict the bits 590 | 591 | logits = self.demasking_transformer(bits) 592 | 593 | # sample the bits 594 | 595 | bits = gumbel_sample(logits, temperature = temperature) 596 | bits = (bits * 2 - 1.) # bits are -1. or +1 597 | 598 | images = self.vae.decode_bits_to_images(bits) 599 | 600 | if not return_bits: 601 | return images 602 | 603 | if return_bits_as_bool: 604 | bits = bits > 0. 
605 | 606 | return images, bits 607 | 608 | def forward( 609 | self, 610 | images: Float['b {self._c} h w'] 611 | ): 612 | batch, device = images.shape[0], self.device 613 | 614 | with torch.no_grad(): 615 | self.vae.eval() 616 | 617 | bits = self.vae( 618 | images, 619 | return_loss = False, 620 | return_quantized_bits = True 621 | ) 622 | 623 | # pack the bits into one long sequence 624 | 625 | bits, _ = pack_one(bits, 'b *') 626 | 627 | num_bits, orig_bits = bits.shape[-1], bits 628 | 629 | # flip a few of the bits, so that the model learns to predict for tokens that are not masked 630 | 631 | if self.train_frac_bits_flipped > 0.: 632 | num_bits_to_flip = num_bits * self.train_frac_bits_flipped 633 | flip_mask = torch.rand_like(bits).argsort(dim = -1) < num_bits_to_flip 634 | 635 | bits = torch.where(flip_mask, bits * -1, bits) 636 | 637 | # get the masking fraction, which is a function of time and the noising schedule (we will go with the successful cosine schedule here from Nichol et al) 638 | 639 | times = torch.rand(batch, device = device) 640 | noise_level = torch.cos(times * pi * 0.5) 641 | 642 | # determine num bit groups and reshape 643 | 644 | assert divisible_by(num_bits, self.consecutive_bits_to_mask) 645 | 646 | bits = rearrange(bits, 'b (ng g) -> b ng g', g = self.consecutive_bits_to_mask) 647 | 648 | bit_group_seq_len = bits.shape[1] 649 | 650 | num_bit_group_mask = (bit_group_seq_len * noise_level).ceil().clamp(min = 1) 651 | 652 | # mask some fraction of the bits 653 | 654 | mask = torch.rand((batch, bit_group_seq_len), device = device).argsort(dim = -1) < num_bit_group_mask 655 | 656 | masked_bits = einx.where('b ng, , b ng g -> b (ng g)', mask, 0., bits) # main contribution of the paper is just this line of code where they mask bits to 0. 
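        # descriptive note on the line above: with the README usage (bits_group_size = 512) and the
        # default bits_groups = 2, consecutive_bits_to_mask = 256, so `mask` has shape (b, ng) over
        # groups of 256 consecutive bits; the einx.where call zeroes every bit inside a masked group
        # while unmasked groups keep their quantized -1 / +1 values, then flattens back to 'b (ng g)'.
        # the demasking transformer below regroups this flat sequence into tokens of bits_group_size
        # bits, so each transformer token covers bits_groups maskable sub-groups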
657 | 658 | # attention 659 | 660 | preds = self.demasking_transformer(masked_bits) 661 | 662 | # loss mask 663 | 664 | mask = repeat(mask, 'b ng -> b (ng g)', g = self.consecutive_bits_to_mask) 665 | 666 | loss_mask = mask | flip_mask 667 | 668 | # get loss 669 | 670 | labels = (orig_bits[loss_mask] > 0.).long() 671 | 672 | loss = F.cross_entropy( 673 | preds[loss_mask], 674 | labels, 675 | ignore_index = self.loss_ignore_index 676 | ) 677 | 678 | return loss 679 | -------------------------------------------------------------------------------- /maskbit_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from shutil import rmtree 3 | from functools import partial 4 | 5 | from beartype import beartype 6 | 7 | import torch 8 | from torch import nn, tensor 9 | from torch.nn import Module, ModuleList 10 | from torch.utils.data import Dataset, DataLoader, random_split 11 | 12 | from adam_atan2_pytorch import Adam 13 | 14 | import torchvision.transforms as T 15 | from torchvision.datasets import ImageFolder 16 | from torchvision.utils import make_grid, save_image 17 | 18 | from maskbit_pytorch.maskbit import BQVAE, MaskBit 19 | 20 | from einops import rearrange 21 | 22 | from accelerate import ( 23 | Accelerator, 24 | DistributedType, 25 | DistributedDataParallelKwargs 26 | ) 27 | 28 | from ema_pytorch import EMA 29 | 30 | from PIL import Image, ImageFile 31 | ImageFile.LOAD_TRUNCATED_IMAGES = True 32 | 33 | # helper functions 34 | 35 | def exists(val): 36 | return val is not None 37 | 38 | def identity(t, *args, **kwargs): 39 | return t 40 | 41 | def noop(*args, **kwargs): 42 | pass 43 | 44 | def find_index(arr, cond): 45 | for ind, el in enumerate(arr): 46 | if cond(el): 47 | return ind 48 | return None 49 | 50 | def find_and_pop(arr, cond, default = None): 51 | ind = find_index(arr, cond) 52 | 53 | if exists(ind): 54 | return arr.pop(ind) 55 | 56 | if callable(default): 57 | return default() 58 | 59 | return default 60 | 61 | def cycle(dl): 62 | while True: 63 | for data in dl: 64 | yield data 65 | 66 | def cast_tuple(t): 67 | return t if isinstance(t, (tuple, list)) else (t,) 68 | 69 | def yes_or_no(question): 70 | answer = input(f'{question} (y/n) ') 71 | return answer.lower() in ('yes', 'y') 72 | 73 | def accum_log(log, new_logs): 74 | for key, new_value in new_logs.items(): 75 | old_value = log.get(key, 0.) 
76 | log[key] = old_value + new_value 77 | return log 78 | 79 | def pair(val): 80 | return val if isinstance(val, tuple) else (val, val) 81 | 82 | def convert_image_to_fn(img_type, image): 83 | if image.mode != img_type: 84 | return image.convert(img_type) 85 | return image 86 | 87 | # image related helpers fnuctions and dataset 88 | 89 | class ImageDataset(Dataset): 90 | def __init__( 91 | self, 92 | folder, 93 | image_size, 94 | exts = ['jpg', 'jpeg', 'png'] 95 | ): 96 | super().__init__() 97 | self.folder = folder 98 | self.image_size = image_size 99 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 100 | 101 | print(f'{len(self.paths)} training samples found at {folder}') 102 | assert len(self) > 0 103 | 104 | self.transform = T.Compose([ 105 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 106 | T.Resize(image_size), 107 | T.RandomHorizontalFlip(), 108 | T.CenterCrop(image_size), 109 | T.ToTensor() 110 | ]) 111 | 112 | def __len__(self): 113 | return len(self.paths) 114 | 115 | def __getitem__(self, index): 116 | path = self.paths[index] 117 | img = Image.open(path) 118 | return self.transform(img) 119 | 120 | # vae trainer class 121 | 122 | class BQVAETrainer(Module): 123 | 124 | @beartype 125 | def __init__( 126 | self, 127 | vae: BQVAE, 128 | *, 129 | folder, 130 | num_train_steps, 131 | batch_size, 132 | image_size, 133 | lr = 3e-4, 134 | grad_accum_every = 1, 135 | max_grad_norm = None, 136 | discr_max_grad_norm = None, 137 | save_results_every = 100, 138 | save_model_every = 1000, 139 | results_folder = './results', 140 | valid_frac = 0.05, 141 | random_split_seed = 42, 142 | use_ema = True, 143 | ema_beta = 0.995, 144 | ema_update_after_step = 0, 145 | ema_update_every = 1, 146 | ema_kwargs: dict = dict(), 147 | accelerate_kwargs: dict = dict() 148 | ): 149 | super().__init__() 150 | 151 | # instantiate accelerator 152 | 153 | kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', []) 154 | 155 | ddp_kwargs = find_and_pop( 156 | kwargs_handlers, 157 | lambda x: isinstance(x, DistributedDataParallelKwargs), 158 | partial(DistributedDataParallelKwargs, find_unused_parameters = True) 159 | ) 160 | 161 | ddp_kwargs.find_unused_parameters = True 162 | kwargs_handlers.append(ddp_kwargs) 163 | accelerate_kwargs.update(kwargs_handlers = kwargs_handlers) 164 | 165 | self.accelerator = Accelerator(**accelerate_kwargs) 166 | 167 | # vae 168 | 169 | self.vae = vae 170 | 171 | # training params 172 | 173 | self.register_buffer('steps', tensor(0)) 174 | 175 | self.num_train_steps = num_train_steps 176 | self.batch_size = batch_size 177 | self.grad_accum_every = grad_accum_every 178 | 179 | all_parameters = set(vae.parameters()) 180 | discr_parameters = set(vae.discr.parameters()) 181 | vae_parameters = all_parameters - discr_parameters 182 | 183 | self.vae_parameters = vae_parameters 184 | 185 | # optimizers 186 | 187 | self.optim = Adam(vae_parameters, lr = lr) 188 | self.discr_optim = Adam(discr_parameters, lr = lr) 189 | 190 | self.max_grad_norm = max_grad_norm 191 | self.discr_max_grad_norm = discr_max_grad_norm 192 | 193 | # create dataset 194 | 195 | self.ds = ImageDataset(folder, image_size) 196 | 197 | # split for validation 198 | 199 | if valid_frac > 0: 200 | train_size = int((1 - valid_frac) * len(self.ds)) 201 | valid_size = len(self.ds) - train_size 202 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 203 | self.print(f'training 
with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 204 | else: 205 | self.valid_ds = self.ds 206 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 207 | 208 | # dataloader 209 | 210 | self.dl = DataLoader( 211 | self.ds, 212 | batch_size = batch_size, 213 | shuffle = True 214 | ) 215 | 216 | self.valid_dl = DataLoader( 217 | self.valid_ds, 218 | batch_size = batch_size, 219 | shuffle = True 220 | ) 221 | 222 | # prepare with accelerator 223 | 224 | ( 225 | self.vae, 226 | self.optim, 227 | self.discr_optim, 228 | self.dl, 229 | self.valid_dl 230 | ) = self.accelerator.prepare( 231 | self.vae, 232 | self.optim, 233 | self.discr_optim, 234 | self.dl, 235 | self.valid_dl 236 | ) 237 | 238 | self.use_ema = use_ema 239 | 240 | if use_ema: 241 | self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every, **ema_kwargs) 242 | self.ema_vae = self.accelerator.prepare(self.ema_vae) 243 | 244 | self.dl_iter = cycle(self.dl) 245 | self.valid_dl_iter = cycle(self.valid_dl) 246 | 247 | self.save_model_every = save_model_every 248 | self.save_results_every = save_results_every 249 | 250 | self.results_folder = Path(results_folder) 251 | 252 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): 253 | rmtree(str(self.results_folder)) 254 | 255 | self.results_folder.mkdir(parents = True, exist_ok = True) 256 | 257 | def save(self, path): 258 | if not self.accelerator.is_local_main_process: 259 | return 260 | 261 | pkg = dict( 262 | model = self.accelerator.get_state_dict(self.vae), 263 | optim = self.optim.state_dict(), 264 | discr_optim = self.discr_optim.state_dict() 265 | ) 266 | 267 | torch.save(pkg, path) 268 | 269 | def load(self, path): 270 | path = Path(path) 271 | assert path.exists() 272 | pkg = torch.load(path) 273 | 274 | vae = self.accelerator.unwrap_model(self.vae) 275 | vae.load_state_dict(pkg['model']) 276 | 277 | self.optim.load_state_dict(pkg['optim']) 278 | self.discr_optim.load_state_dict(pkg['discr_optim']) 279 | 280 | def print(self, msg): 281 | self.accelerator.print(msg) 282 | 283 | @property 284 | def device(self): 285 | return self.accelerator.device 286 | 287 | @property 288 | def is_distributed(self): 289 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 290 | 291 | @property 292 | def is_main(self): 293 | return self.accelerator.is_main_process 294 | 295 | @property 296 | def is_local_main(self): 297 | return self.accelerator.is_local_main_process 298 | 299 | def train_step(self): 300 | acc = self.accelerator 301 | device = self.device 302 | 303 | steps = int(self.steps.item()) 304 | 305 | self.vae.train() 306 | discr = self.vae.module.discr if self.is_distributed else self.vae.discr 307 | 308 | if self.use_ema: 309 | ema_vae = self.ema_vae.module if self.is_distributed else self.ema_vae 310 | 311 | # logs 312 | 313 | logs = dict() 314 | 315 | # update vae (generator) 316 | 317 | for _ in range(self.grad_accum_every): 318 | img = next(self.dl_iter) 319 | img = img.to(device) 320 | 321 | with acc.autocast(): 322 | loss = self.vae( 323 | img, 324 | return_loss = True 325 | ) 326 | 327 | acc.backward(loss / self.grad_accum_every) 328 | 329 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 330 | 331 | if exists(self.max_grad_norm): 332 | acc.clip_grad_norm_(self.vae.parameters(), 
self.max_grad_norm) 333 | 334 | self.optim.step() 335 | self.optim.zero_grad() 336 | 337 | # update discriminator 338 | 339 | if exists(discr): 340 | self.discr_optim.zero_grad() 341 | 342 | for _ in range(self.grad_accum_every): 343 | img = next(self.dl_iter) 344 | img = img.to(device) 345 | 346 | loss = self.vae( 347 | img, 348 | return_discr_loss = True 349 | ) 350 | 351 | acc.backward(loss / self.grad_accum_every) 352 | 353 | accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) 354 | 355 | if exists(self.discr_max_grad_norm): 356 | acc.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm) 357 | 358 | self.discr_optim.step() 359 | 360 | # log 361 | 362 | self.print(f"{steps}: vae loss: {logs['loss']:.3f} - discr loss: {logs['discr_loss']:.3f}") 363 | 364 | # update exponential moving averaged generator 365 | 366 | if self.use_ema: 367 | ema_vae.update() 368 | 369 | # sample results every so often 370 | 371 | if not (steps % self.save_results_every): 372 | vaes_to_evaluate = ((self.vae, str(steps)),) 373 | 374 | if self.use_ema: 375 | vaes_to_evaluate = ((ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate 376 | 377 | for model, filename in vaes_to_evaluate: 378 | model.eval() 379 | 380 | valid_data = next(self.valid_dl_iter) 381 | valid_data = valid_data.to(device) 382 | 383 | _, recons, _ = model(valid_data, return_details = True) 384 | 385 | # else save a grid of images 386 | 387 | imgs_and_recons = rearrange([valid_data, recons], 'r b ... -> (b r) ...') 388 | 389 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.) 390 | grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1)) 391 | 392 | logs['reconstructions'] = grid 393 | 394 | save_image(grid, str(self.results_folder / f'{filename}.png')) 395 | 396 | self.print(f'{steps}: saving to {str(self.results_folder)}') 397 | 398 | # save model every so often 399 | 400 | acc.wait_for_everyone() 401 | 402 | if self.is_main and not (steps % self.save_model_every): 403 | state_dict = acc.unwrap_model(self.vae).state_dict() 404 | model_path = str(self.results_folder / f'vae.{steps}.pt') 405 | acc.save(state_dict, model_path) 406 | 407 | if self.use_ema: 408 | ema_state_dict = acc.unwrap_model(self.ema_vae).state_dict() 409 | model_path = str(self.results_folder / f'vae.{steps}.ema.pt') 410 | acc.save(ema_state_dict, model_path) 411 | 412 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 413 | 414 | self.steps += 1 415 | return logs 416 | 417 | def forward(self): 418 | 419 | while self.steps < self.num_train_steps: 420 | logs = self.train_step() 421 | 422 | self.print('training complete') 423 | 424 | # maskbit trainer 425 | 426 | class MaskBitTrainer(Module): 427 | def __init__( 428 | self, 429 | maskbit: MaskBit, 430 | folder, 431 | num_train_steps, 432 | batch_size, 433 | image_size, 434 | lr = 3e-4, 435 | grad_accum_every = 1, 436 | max_grad_norm = None, 437 | save_results_every = 100, 438 | save_model_every = 1000, 439 | results_folder = './results', 440 | valid_frac = 0.05, 441 | random_split_seed = 42, 442 | accelerate_kwargs: dict = dict() 443 | ): 444 | super().__init__() 445 | 446 | # instantiate accelerator 447 | 448 | kwargs_handlers = accelerate_kwargs.get('kwargs_handlers', []) 449 | 450 | ddp_kwargs = find_and_pop( 451 | kwargs_handlers, 452 | lambda x: isinstance(x, DistributedDataParallelKwargs), 453 | partial(DistributedDataParallelKwargs, find_unused_parameters = True) 454 | ) 455 | 456 | ddp_kwargs.find_unused_parameters = True 457 | 
kwargs_handlers.append(ddp_kwargs) 458 | accelerate_kwargs.update(kwargs_handlers = kwargs_handlers) 459 | 460 | self.accelerator = Accelerator(**accelerate_kwargs) 461 | 462 | # training params 463 | 464 | self.register_buffer('steps', tensor(0)) 465 | 466 | self.num_train_steps = num_train_steps 467 | self.batch_size = batch_size 468 | self.grad_accum_every = grad_accum_every 469 | 470 | # model 471 | 472 | self.maskbit = maskbit 473 | 474 | # optimizers 475 | 476 | self.optim = Adam(maskbit.parameters(), lr = lr) 477 | 478 | self.max_grad_norm = max_grad_norm 479 | 480 | # create dataset 481 | 482 | self.ds = ImageDataset(folder, image_size) 483 | 484 | # split for validation 485 | 486 | if valid_frac > 0: 487 | train_size = int((1 - valid_frac) * len(self.ds)) 488 | valid_size = len(self.ds) - train_size 489 | self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed)) 490 | self.print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples') 491 | else: 492 | self.valid_ds = self.ds 493 | self.print(f'training with shared training and valid dataset of {len(self.ds)} samples') 494 | 495 | # dataloader 496 | 497 | self.dl = DataLoader( 498 | self.ds, 499 | batch_size = batch_size, 500 | shuffle = True 501 | ) 502 | 503 | self.valid_dl = DataLoader( 504 | self.valid_ds, 505 | batch_size = batch_size, 506 | shuffle = True 507 | ) 508 | 509 | # prepare with accelerator 510 | 511 | ( 512 | self.maskbit, 513 | self.optim, 514 | self.dl, 515 | self.valid_dl 516 | ) = self.accelerator.prepare( 517 | self.maskbit, 518 | self.optim, 519 | self.dl, 520 | self.valid_dl 521 | ) 522 | 523 | self.dl_iter = cycle(self.dl) 524 | self.valid_dl_iter = cycle(self.valid_dl) 525 | 526 | self.save_model_every = save_model_every 527 | self.save_results_every = save_results_every 528 | 529 | self.results_folder = Path(results_folder) 530 | 531 | if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'): 532 | rmtree(str(self.results_folder)) 533 | 534 | self.results_folder.mkdir(parents = True, exist_ok = True) 535 | 536 | def save(self, path): 537 | if not self.accelerator.is_local_main_process: 538 | return 539 | 540 | pkg = dict( 541 | model = self.accelerator.get_state_dict(self.maskbit), 542 | optim = self.optim.state_dict(), 543 | ) 544 | 545 | torch.save(pkg, path) 546 | 547 | def load(self, path): 548 | path = Path(path) 549 | assert path.exists() 550 | pkg = torch.load(path) 551 | 552 | maskbit = self.accelerator.unwrap_model(self.maskbit) 553 | maskbit.load_state_dict(pkg['model']) 554 | 555 | self.optim.load_state_dict(pkg['optim']) 556 | 557 | def print(self, msg): 558 | self.accelerator.print(msg) 559 | 560 | @property 561 | def device(self): 562 | return self.accelerator.device 563 | 564 | @property 565 | def is_distributed(self): 566 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 567 | 568 | @property 569 | def is_main(self): 570 | return self.accelerator.is_main_process 571 | 572 | @property 573 | def is_local_main(self): 574 | return self.accelerator.is_local_main_process 575 | 576 | def train_step(self): 577 | acc = self.accelerator 578 | device = self.device 579 | 580 | steps = int(self.steps.item()) 581 | 582 | self.maskbit.train() 583 | 584 | # logs 585 | 586 | logs = dict() 587 | 588 | # update vae (generator) 589 | 590 
| for _ in range(self.grad_accum_every): 591 | img = next(self.dl_iter) 592 | img = img.to(device) 593 | 594 | with acc.autocast(): 595 | loss = self.maskbit(img) 596 | 597 | acc.backward(loss / self.grad_accum_every) 598 | 599 | accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) 600 | 601 | if exists(self.max_grad_norm): 602 | acc.clip_grad_norm_(self.maskbit.parameters(), self.max_grad_norm) 603 | 604 | self.optim.step() 605 | self.optim.zero_grad() 606 | 607 | # log 608 | 609 | self.print(f"{steps}: maskbit loss: {logs['loss']:.3f}") 610 | 611 | # save model every so often 612 | 613 | acc.wait_for_everyone() 614 | 615 | if self.is_main and not (steps % self.save_model_every): 616 | state_dict = acc.unwrap_model(self.maskbit).state_dict() 617 | model_path = str(self.results_folder / f'maskbit.{steps}.pt') 618 | acc.save(state_dict, model_path) 619 | 620 | self.print(f'{steps}: saving model to {str(self.results_folder)}') 621 | 622 | self.steps += 1 623 | return logs 624 | 625 | def forward(self): 626 | 627 | while self.steps < self.num_train_steps: 628 | logs = self.train_step() 629 | 630 | self.print('training complete') 631 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "maskbit-pytorch" 3 | version = "0.0.2" 4 | description = "MaskBit" 5 | authors = [ 6 | { name = "Phil Wang", email = "lucidrains@gmail.com" } 7 | ] 8 | readme = "README.md" 9 | requires-python = ">= 3.9" 10 | license = { file = "LICENSE" } 11 | keywords = [ 12 | 'artificial intelligence', 13 | 'deep learning', 14 | 'image generation', 15 | 'scalar quantization', 16 | 'transformers', 17 | 'attention mechanism' 18 | ] 19 | 20 | classifiers=[ 21 | 'Development Status :: 4 - Beta', 22 | 'Intended Audience :: Developers', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | 'License :: OSI Approved :: MIT License', 25 | 'Programming Language :: Python :: 3.9', 26 | ] 27 | 28 | dependencies = [ 29 | "accelerate", 30 | "adam-atan2-pytorch>=0.1.1", 31 | "beartype", 32 | "einx>=0.3.0", 33 | "einops>=0.8.0", 34 | "ema-pytorch", 35 | "jaxtyping", 36 | "pillow", 37 | "torch>=2.2", 38 | "torchvision", 39 | "vector-quantize-pytorch>=1.17.8", 40 | "x-transformers>=1.37.9" 41 | ] 42 | 43 | [project.urls] 44 | Homepage = "https://pypi.org/project/maskbit-pytorch/" 45 | Repository = "https://github.com/lucidrains/maskbit-pytorch" 46 | 47 | [project.optional-dependencies] 48 | examples = [] 49 | test = [ 50 | "pytest" 51 | ] 52 | 53 | [tool.pytest.ini_options] 54 | pythonpath = [ 55 | "." 56 | ] 57 | 58 | [build-system] 59 | requires = ["hatchling"] 60 | build-backend = "hatchling.build" 61 | 62 | [tool.rye] 63 | managed = true 64 | dev-dependencies = [] 65 | 66 | [tool.hatch.metadata] 67 | allow-direct-references = true 68 | 69 | [tool.hatch.build.targets.wheel] 70 | packages = ["maskbit_pytorch"] 71 | --------------------------------------------------------------------------------
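The README usage exercises `BQVAE` and `MaskBit` directly; `maskbit_pytorch/trainer.py` additionally provides `BQVAETrainer` and `MaskBitTrainer`, which wrap the two training stages with `accelerate`, dataset handling and periodic checkpointing. Below is a minimal two-stage training sketch based on the constructor signatures above; the image folder, step counts and batch sizes are placeholder values, and the model hyperparameters simply reuse the README example.

```python
from maskbit_pytorch import BQVAE, MaskBit, BQVAETrainer, MaskBitTrainer

# stage 1 - train the binary-quantization autoencoder on a folder of images

vae = BQVAE(
    dim = 512,
    image_size = 64
)

vae_trainer = BQVAETrainer(
    vae,
    folder = './path/to/images',   # placeholder path
    image_size = 64,
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 100_000,
    results_folder = './results'
)

vae_trainer()   # calling the trainer runs its training loop (Module.forward)

# stage 2 - freeze the trained autoencoder and train the MaskBit demasking transformer

maskbit = MaskBit(
    vae,
    dim = 512,
    bits_group_size = 512,
    depth = 2
)

maskbit_trainer = MaskBitTrainer(
    maskbit,
    folder = './path/to/images',
    image_size = 64,
    batch_size = 4,
    grad_accum_every = 8,
    num_train_steps = 100_000,
    results_folder = './results'
)

maskbit_trainer()

# after training, sample images non-autoregressively

images = maskbit.sample(batch_size = 4)   # (4, 3, 64, 64)
```

Each trainer saves checkpoints to `results_folder` every `save_model_every` steps, and asks for confirmation before clearing a non-empty results folder.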