├── .gitignore ├── LICENSE ├── MANIFEST.in ├── readme.md ├── setup.cfg ├── setup.py └── torchtools ├── __init__.py ├── lr_scheduler ├── __init__.py ├── delayed.py └── inverse_sqrt.py ├── nn ├── __init__.py ├── adain.py ├── alias_free_activation.py ├── equal_layers.py ├── evonorm2d.py ├── fourier_features.py ├── functional │ ├── __init__.py │ ├── gradient_penalty.py │ ├── magnitude_preserving.py │ ├── perceptual.py │ └── vq.py ├── gp_loss.py ├── haar_dwt.py ├── magnitude_preserving.py ├── mish.py ├── modulation.py ├── perceptual.py ├── pixel_normalzation.py ├── pos_embeddings.py ├── simple_self_attention.py ├── stylegan2 │ ├── __init__.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── transformers.py └── vq.py ├── optim ├── __init__.py ├── lamb.py ├── lookahead.py ├── novograd.py ├── over9000.py ├── radam.py ├── ralamb.py └── ranger.py ├── transforms ├── __init__.py ├── models │ ├── __init__.py │ └── saliency_model_v9.pt └── smart_crop.py └── utils ├── __init__.py ├── diffusion.py ├── diffusion2.py ├── gamma_parametrization.py └── weight_normalization.py /.gitignore: -------------------------------------------------------------------------------- 1 | */**/__pycache__ 2 | *.egg-info 3 | /dist 4 | *.pyc 5 | .idea -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Pablo Pernías 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include readme.md 3 | recursive-include torchtools/transforms/models * -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Pytorch Tools 2 | 3 | ## Install 4 | 5 | Requirements: 6 | ``` 7 | PyTorch >= 1.0.0 8 | Torchvision 9 | Numpy >= 1.0.0 10 | ``` 11 | 12 | ``` 13 | # In order to install the latest (beta) use 14 | pip install git+https://github.com/pabloppp/pytorch-tools -U 15 | 16 | # if you want to install a specific version to avoid breaking changes (for example, v0.3.5), use 17 | pip install git+https://github.com/pabloppp/pytorch-tools@0.3.5 -U 18 | ``` 19 | 20 | # Current available tools 21 | 22 | ## Optimizers 23 | 24 | Comparison table taken from https://github.com/mgrankin/over9000 25 | And the article explaining these recent improvements https://medium.com/@lessw/how-we-beat-the-fastai-leaderboard-score-by-19-77-a-cbb2338fab5c 26 | 27 | Dataset | LR Schedule| Imagenette size 128, 5 epoch | Imagewoof size 128, 5 epoch 28 | --- | -- | --- | --- 29 | Adam - baseline |OneCycle| 0.8493 | 0.6125 30 | RangerLars (RAdam + LARS + Lookahead) |Flat and anneal| 0.8732 | 0.6523 31 | Ralamb (RAdam + LARS) |Flat and anneal| 0.8675 | 0.6367 32 | Ranger (RAdam + Lookahead) |Flat and anneal| 0.8594 | 0.5946 33 | Novograd |Flat and anneal| 0.8711 | 0.6126 34 | Radam |Flat and anneal| 0.8444 | 0.537 35 | Lookahead |OneCycle| 0.8578 | 0.6106 36 | Lamb |OneCycle| 0.8400 | 0.5597 37 | 38 | ### Ranger 39 | Taken as is from https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 40 | Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d 41 | 42 | Example of use: 43 | ```python 44 | from torchtools.optim import Ranger 45 | 46 | optimizer = Ranger(model.parameters()) 47 | ``` 48 | 49 | ### RAdam 50 | Taken as is from https://github.com/LiyuanLucasLiu/RAdam 51 | Blog post: https://medium.com/@lessw/new-state-of-the-art-ai-optimizer-rectified-adam-radam-5d854730807b 52 | Original Paper: https://arxiv.org/abs/1908.03265 53 | 54 | Example of use: 55 | ```python 56 | from torchtools.optim import RAdam, PlainRAdam, AdamW 57 | 58 | optimizer = RAdam(model.parameters()) 59 | # optimizer = PlainRAdam(model.parameters()) 60 | # optimizer = AdamW(model.parameters()) 61 | ``` 62 | 63 | ### RangerLars (Over9000) 64 | Taken as is from https://github.com/mgrankin/over9000 65 | 66 | Example of use: 67 | ```python 68 | from torchtools.optim import RangerLars # Over9000 69 | 70 | optimizer = RangerLars(model.parameters()) 71 | ``` 72 | 73 | ### Novograd 74 | Taken as is from https://github.com/mgrankin/over9000 75 | 76 | Example of use: 77 | ```python 78 | from torchtools.optim import Novograd 79 | 80 | optimizer = Novograd(model.parameters()) 81 | ``` 82 | 83 | ### Ralamb 84 | Taken as is from https://github.com/mgrankin/over9000 85 | 86 | Example of use: 87 | ```python 88 | from torchtools.optim import Ralamb 89 | 90 | optimizer = Ralamb(model.parameters()) 91 | ``` 92 | 93 | ### Lookahead 94 | Taken as is from https://github.com/lonePatient/lookahead_pytorch 95 | Original Paper: https://arxiv.org/abs/1907.08610 96 | 97 | This Lookahead implementation can be used with any optimizer. 98 | 99 | Example of
use: 100 | ```python 101 | from torch import optim 102 | from torchtools.optim import Lookahead 103 | 104 | optimizer = optim.Adam(model.parameters(), lr=0.001) 105 | optimizer = Lookahead(base_optimizer=optimizer, k=5, alpha=0.5) 106 | 107 | # for a base Lookahead + Adam you can just do: 108 | # 109 | # from torchtools.optim import LookaheadAdam 110 | ``` 111 | 112 | ### Lamb 113 | Taken as is from https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py 114 | Original Paper: https://arxiv.org/abs/1904.00962 115 | 116 | Example of use: 117 | ```python 118 | from torchtools.optim import Lamb 119 | 120 | optimizer = Lamb(model.parameters()) 121 | ``` 122 | 123 | ## LR Schedulers 124 | 125 | ### Delayed LR 126 | Allows for a customizable number of initial steps where the learning rate remains fixed. 127 | After those steps, the learning rate will be updated according to the supplied scheduler's policy. 128 | 129 | Example of use: 130 | ```python 131 | from torch import optim, nn 132 | from torchtools.lr_scheduler import DelayerScheduler 133 | 134 | optimizer = optim.Adam(model.parameters(), lr=0.001) # define your optimizer here, the lr that you set will be the one used for the initial delay steps 135 | 136 | delay_epochs = 10 137 | total_epochs = 20 138 | base_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, total_epochs - delay_epochs) # the scheduler that will be applied once the delay is over 139 | delayed_scheduler = DelayerScheduler(optimizer, delay_epochs, base_scheduler) # keep the initial lr for the first 10 epochs 140 | 141 | for epoch in range(total_epochs): 142 | # train(...) 143 | delayed_scheduler.step() 144 | 145 | # The lr will be 0.001 for the first 10 epochs, then it will use the policy from the base_scheduler for the rest of the epochs 146 | 147 | 148 | # for a base DelayerScheduler + CosineAnnealingLR you can just do: 149 | # 150 | # from torchtools.lr_scheduler import DelayedCosineAnnealingLR 151 | # scheduler = DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs) # the sum of both must be the total number of epochs 152 | ``` 153 | 154 | ## Activations 155 | 156 | ### Mish 157 | Original implementation: https://github.com/digantamisra98/Mish 158 | Original Paper: https://arxiv.org/abs/1908.08681v1 159 | Implementation taken as is from https://github.com/lessw2020/mish 160 | 161 | Example of use: 162 | ```python 163 | from torchtools.nn import Mish 164 | 165 | # Then you can just use Mish as a replacement for any activation function, like ReLU 166 | ``` 167 | 168 | ### AliasFreeActivation 169 | Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 by Rosinality 170 | I modularized this activation so it can be easily used inside any model without having to deal with complex initialization. 171 | 172 | This activation takes on a lot of responsibility: it internally defines the channels and size of the input based on a set of parameters, instead of receiving them as arguments, which means that the rest of the layers (convolutions, positional embedding, etc.) must adapt to it. 173 | 174 | Example of use: 175 | ```python 176 | from torchtools.nn.alias_free_activation import AliasFreeActivation 177 | from torchtools.nn import EqualLeakyReLU 178 | 179 | # We can use the static function to get the filter parameters for a specific level. 180 | # It can be especially useful to obtain the initial size and channels.
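# (note) alias_level_params returns a tuple of (cutoff, stopband, band_half, channels, size) for the
# requested level, so taking [-2:] below gives us the channels and spatial size of level 0.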
181 | max_size, max_channels = 256, 512 182 | first_channels, first_size = AliasFreeActivation.alias_level_params( 183 | 0, max_levels=14, max_size=max_size, max_channels=max_channels 184 | )[-2:] 185 | 186 | class MyModel(nn.Module): 187 | def __init__(self, level, max_levels=14, max_size=256, max_channels=512, margin=10): 188 | ... 189 | # The AliasFreeActivation wraps a base activation, in this case an EqualLeakyReLU 190 | leaky_relu = EqualLeakyReLU(negative_slope=0.2) 191 | self.activation = AliasFreeActivation( 192 | leaky_relu, level, max_levels=max_levels, max_size=max_size, max_channels=max_channels, margin=margin 193 | ) 194 | self.conv = nn.Conv2d(self.activation.channels_prev, self.activation.channels, kernel_size=3, padding=1) 195 | ... 196 | 197 | def forward(self, x): # the channels and size of x depend on the level of this module. 198 | ... 199 | x = self.conv(x) 200 | x = self.activation(x) 201 | ... 202 | 203 | ``` 204 | 205 | ## Layers 206 | 207 | ### SimpleSelfAttention 208 | Implementation taken as is from https://github.com/sdoria/SimpleSelfAttention 209 | 210 | Example of use: 211 | ```python 212 | from torchtools.nn import SimpleSelfAttention 213 | 214 | # The input of the layer has to have at least 3 dimensions (B, C, N), 215 | # the attention will be performed on the 2nd dimension. 216 | # 217 | # For images, the input will be internally reshaped to 3 dimensions, 218 | # and reshaped back to the original shape before returning it 219 | ``` 220 | 221 | ### PixelNorm 222 | Inspired by https://github.com/github-pengge/PyTorch-progressive_growing_of_gans 223 | 224 | Example of use: 225 | ```python 226 | from torchtools.nn import PixelNorm 227 | 228 | model = nn.Sequential( 229 | nn.Conv2d(...), 230 | PixelNorm(), 231 | nn.ReLU() 232 | ) 233 | 234 | # It doesn't require any parameters, it just performs a simple element-wise normalization 235 | # x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8) 236 | # 237 | # Just use it as a regular layer, generally after convolutions and before ReLU 238 | # (warning) since it performs a square root it's pretty slow if the layer sizes are big 239 | ``` 240 | 241 | ### Adaptive Instance Normalization - AdaIN 242 | Implementation based on https://github.com/SiskonEmilia/StyleGAN-PyTorch 243 | Original Paper https://arxiv.org/abs/1703.06868 244 | 245 | Example of use: 246 | ```python 247 | from torchtools.nn import AdaIN 248 | 249 | class MyModel(nn.Module): 250 | def __init__(self, n_channels): 251 | ... 252 | # AdaIN will require the style vector to be 2*size 253 | self.style = nn.Linear(input_size, output_size*2) 254 | self.adain = AdaIN(output_size) 255 | ... 256 | 257 | def forward(self, x, w): 258 | ... 259 | x = self.adain(x, self.style(w)) 260 | ... 261 | 262 | # AdaIN will "transfer" a style encoded in a latent vector w into any tensor x. 263 | # In order to do this it first needs to be passed through a linear layer that will return 2 tensors (actually, one tensor of twice the size required, that we'll then split in 2) 264 | # It will then perform an instance normalization to "whiten" the tensor, followed by a de-normalization using the values generated by the linear layer, thus encoding the original vector w in the tensor.
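# Shape walk-through (minimal sketch, assuming x is (B, C, H, W) and w is (B, w_dim)):
#   style = self.style(w)                                      -> (B, 2*C)
#   factor, bias = style.view(B, 2*C, 1, 1).chunk(2, dim=1)    -> two (B, C, 1, 1) tensors
#   out = InstanceNorm2d(C)(x) * factor + bias                 -> (B, C, H, W), same shape as x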
265 | ``` 266 | 267 | ### EvoNorm 268 | Implementation taken as is from https://github.com/digantamisra98/EvoNorm, all credit goes to digantamisra98 269 | Original Paper https://arxiv.org/abs/2004.02967 270 | 271 | Example of use: 272 | ```python 273 | from torchtools.nn import EvoNorm2D 274 | 275 | model = nn.Sequential( 276 | nn.Conv2d(...), 277 | EvoNorm2D(c_hidden), # For S0 version 278 | # evoB0 = EvoNorm2D(input, affine = True, version = 'B0', training = True) # For B0 version 279 | nn.ReLU() 280 | ) 281 | ``` 282 | 283 | ### GPT Transformer Encoder Layer 284 | Implementation based on MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy 285 | 286 | It can be used as a drop-in replacement for `torch.nn.TransformerEncoderLayer`. 287 | 288 | Example of use: 289 | ```python 290 | from torchtools.nn import GPTTransformerEncoderLayer 291 | 292 | class MyTransformer(nn.Module): 293 | def __init__(self, n_channels): 294 | ... 295 | encoder_layer = GPTTransformerEncoderLayer(d_model=512, nhead=8) 296 | self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6) 297 | ... 298 | 299 | ``` 300 | 301 | ### Stylegan2 ModulatedConv2d 302 | Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 by Rosinality 303 | 304 | It extends `torch.nn.Conv2d`, so you can use it as a drop-in replacement; the only Conv2d parameter that you cannot use is 'groups', since it will be overridden for this to work. 305 | 306 | It also includes a parameter `ema_decay` that will add the EMA normalization used in Alias-free GAN (defaults to 1, meaning that it's disabled). 307 | 308 | Example of use: 309 | ```python 310 | from torchtools.nn import ModulatedConv2d 311 | 312 | class MyModel(nn.Module): 313 | def __init__(self): 314 | ... 315 | self.conv = ModulatedConv2d(16, 32, kernel_size=3, padding=1) 316 | # SUGGESTIONS: 317 | # set bias=False if you want to handle bias on your own 318 | # set demodulate=False for RGB output 319 | # set ema_decay=0.9989 to imitate the alias-free gan setup 320 | ... 321 | 322 | def forward(self, x, w): 323 | ... 324 | x = self.conv(x, w) # 'x' is a 4D tensor (B x C x W x H) and 'w' is a 2D tensor (B x C) 325 | ... 326 | ``` 327 | 328 | ### Equal Layers (EqualLinear, EqualLeakyReLU, EqualConv2d) 329 | Implementation based on https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94 330 | 331 | These layers extend the base classes (nn.Linear, nn.Conv2d, nn.LeakyReLU) so you can use them as drop-in replacements, although they include some optional parameters. 332 | 333 | Example of use: 334 | ```python 335 | from torchtools.nn import EqualLinear, EqualLeakyReLU, EqualConv2d 336 | 337 | class MyModel(nn.Module): 338 | def __init__(self): 339 | ... 340 | self.linear = EqualLinear(16, 32, bias_init=1, lr_mul=0.01) # bias_init and lr_mul are extra optional params 341 | self.leaky_relu = EqualLeakyReLU(negative_slope=0.2) 342 | self.conv = EqualConv2d(16, 32, kernel_size=3, padding=1) 343 | # Since these classes extend the base classes, you can use all parameters from the original classes. 344 | ... 345 | 346 | ``` 347 | 348 | ### FourierFeatures2d 349 | Implementation inspired by https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L88 350 | but improved using my own understanding of how this should work...
351 | 352 | It creates a 2d tensor of embeddings following a Fourier series based on the parameters you provide. These features are dynamic, meaning that affine transformations can be applied to them in order to shift, rotate, and even scale them (experimental). 353 | 354 | ```python 355 | from torchtools.nn import FourierFeatures2d 356 | 357 | class MyModel(nn.Module): 358 | def __init__(self, dim=256, margin=10, cutoff=2): 359 | ... 360 | self.feats = FourierFeatures2d(4+margin*2, dim, cutoff) # optionally enable scaling with allow_scaling=True 361 | # Also, you can randomize the frequencies if you plan on keeping them fixed, by setting w_scale to any value > 0 362 | ... 363 | 364 | def forward(self, affine): 365 | ... 366 | embds = self.feats(affine) # 'affine' should be a Bx4 tensor, or Bx6 if scaling is enabled... 367 | # the default or initial affine values should be [1, 0, 0, 0, 1, 1] => ([1, 0]: rotation, [0, 0]: shift, [1, 1]: scale) 368 | ... 369 | 370 | ``` 371 | 372 | 373 | 374 | ## Criterions 375 | 376 | ### Gradient Penalty (for WGAN-GP) 377 | Implementation taken with minor changes from https://github.com/caogang/wgan-gp 378 | Original paper https://arxiv.org/pdf/1704.00028.pdf 379 | 380 | Example of use: 381 | ```python 382 | from torchtools.nn import GPLoss 383 | # This criterion defines the gradient penalty for WGAN GP 384 | # For an example of a training cycle refer to this repo https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py#L185 385 | 386 | discriminator = ... 387 | gpcriterion = GPLoss(discriminator) # l = 10 by default 388 | 389 | gradient_penalty = gpcriterion(real_data, fake_data) 390 | discriminator_loss = ... + gradient_penalty # add the gp component to the Wasserstein loss 391 | ``` 392 | 393 | ### Total Variation Loss 394 | Total Variation denoising https://www.wikiwand.com/en/Total_variation_denoising 395 | 396 | Example of use: 397 | ```python 398 | # This loss (or regularization) is useful for removing artifacts and noise in generated images. 399 | # It's widely used in style transfer. 400 | from torchtools.nn import TVLoss 401 | 402 | tvcriterion = TVLoss() # reduction = 'sum' and alpha = 1e-4 by default 403 | 404 | G = ... # output image 405 | tv_loss = tvcriterion(G) 406 | loss = ... + tv_loss # add the tv loss component to your reconstruction loss 407 | ``` 408 | 409 | 410 | ## Vector Quantization 411 | ### VectorQuantize: Encoding-based quantization [(source)](torchtools/nn/vq.py#L5) 412 | This transforms any tensor into its quantized version using a codebook of embeddings. 413 | It uses a straight-through approach for applying the gradients. 414 | Passing a tensor through the **VectorQuantize** module will return a new tensor with the same dimensions, but replacing each vector along the last dimension with its nearest neighbor from the codebook, which has a limited number of entries, thus quantizing the tensor.
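For intuition, here is a minimal standalone sketch of what the quantization step does conceptually (a plain nearest-neighbor lookup plus the straight-through gradient trick); it is not the module's exact implementation, which also computes the codebook losses and optional EMA updates:

```python
import torch

x = torch.randn(4, 8)            # 4 vectors to quantize, embedding_size = 8
codebook = torch.randn(32, 8)    # k = 32 codebook entries

# nearest-neighbor lookup using squared euclidean distances to every codebook entry
dist = x.pow(2).sum(1, keepdim=True) - 2 * x @ codebook.t() + codebook.pow(2).sum(1)
indices = dist.argmin(dim=1)     # index of the closest embedding for each input vector
quantized = codebook[indices]    # (4, 8) quantized output

# straight-through trick: the forward pass uses the quantized values,
# but gradients flow back to x as if quantization were the identity
quantized = x + (quantized - x).detach()
```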
415 | 416 | For the quantization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L4) 417 | 418 | The output of the module is the quantized tensor, as well as a tuple with the loss components of the codebook (needed for training), and the indices of the quantized vectors, in the form: `qx, (vq_loss, commit_loss), indices` 419 | 420 | When **creating a new instance of the module**, it accepts the following parameters: 421 | - **embedding_size**: the size of the embeddings used in the codebook, should match the last dimension of the tensor you want to quantize 422 | - **k**: the size of the codebook, or number of embeddings. 423 | - **ema_decay** (default=0.99): the Exponentially Moving Average decay used (this will only be used if ema_loss is True) 424 | - **ema_loss** (default=False): Enables Exponentially Moving Average updates of the codebook (instead of relying on gradient descent, since EMA converges faster) 425 | 426 | When **calling the forward method** of the module, it accepts the following parameters: 427 | - **x**: this is the tensor you want to quantize, make sure the dimension that you want to quantize (by default the last one) matches the embedding_size defined when instantiating the module 428 | - **get_losses** (default=True): when False, the vq_loss and commit_loss components of the output will both be None, this should speed up the model a little when used for inference. 429 | - **dim** (default=-1): The dimension across which the input should be quantized. 430 | 431 | Example of use: 432 | ```python 433 | from torchtools.nn import VectorQuantize 434 | 435 | e = torch.randn(1, 16, 16, 8) # create a random tensor with 8 as its last dimension size 436 | vquantizer = VectorQuantize(8, k=32, ema_loss=True) # we create the module with an embedding size of 8, a codebook of size 32 and make the codebook update using EMA 437 | qe, (vq_loss, commit_loss), indices = vquantizer.forward(e) # we quantize our tensor while also getting the loss components and the indices 438 | 439 | # NOTE While the model is in training mode, the codebook will always be updated when calling the forward method; in order to freeze the codebook for inference put it in evaluation mode with 'vquantizer.eval()' 440 | 441 | # NOTE 2 In order to update the module properly, add the loss components to the final model loss before calling backward(). If you set ema_loss to True you only need to add the commit_loss to the total loss, and it's usually multiplied by a value between 0.1 and 2, with 0.25 being a good default value 442 | 443 | loss = ... # whatever loss you have for your final output 444 | loss += commit_loss * 0.25 445 | # loss += vq_loss # only if you didn't set ema_loss to True 446 | 447 | ... 448 | loss.backward() 449 | optimizer.step() 450 | 451 | ``` 452 | 453 | --- 454 | 455 | ### Binarize: binarize the input tensor [(source)](torchtools/nn/vq.py#L55) 456 | This transforms the values of a tensor into 0s and 1s depending on whether they're above or below a specified threshold. 457 | It uses a straight-through approach for applying the gradients, so it's effectively differentiable.
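For intuition, a tiny standalone sketch of the straight-through behavior described above (the module implements it with a custom autograd Function rather than the detach trick shown here, but the effect on the gradients is conceptually the same):

```python
import torch

x = torch.randn(8, 16, requires_grad=True)
hard = (x > 0.5).float()          # forward pass: hard 0/1 thresholding
out = x + (hard - x).detach()     # backward pass: gradients flow through as if this were the identity

out.sum().backward()
print(x.grad.unique())            # tensor([1.]) -> the thresholding didn't block the gradients
```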
458 | 459 | For the quantization it relies on a differentiable function that you can see [here](torchtools/nn/functional/vq.py#L36) 460 | 461 | Example of use: 462 | ```python 463 | from torchtools.nn import Binarize 464 | 465 | e = torch.randn(8, 16) # create a random tensor with any dimensions 466 | 467 | binarizer = Binarize(threshold=0.5) # you can set the threshold you want, for example if your output was passed through a tanh activation, 0 might be a better threshold since tanh outputs values between -1 and 1 468 | 469 | bq = binarizer(e) # will return a tensor with the same shape as e, but full of 0s and 1s 470 | ``` 471 | 472 | ## Embeddings 473 | 474 | ### RotaryEmbedding 475 | Implementation taken as is from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py#L161 476 | 477 | Example of use: 478 | ```python 479 | from torchtools.nn import RotaryEmbedding 480 | 481 | class MyModel(nn.Module): 482 | def __init__(self, dim): 483 | ... 484 | self.rotary_pos_embd = RotaryEmbedding(dim) 485 | ... 486 | 487 | def forward(self, x): 488 | cos, sin = self.rotary_pos_embd(x) # returns the cached cos & sin tensors used to rotate the queries and keys 489 | ... 490 | 491 | 492 | ``` 493 | 494 | ## Diffusion 495 | 496 | ### Diffuzz 497 | Custom (non-cached) continuous forward/backward diffusion. 498 | It's not SUPER performant since it calculates all the required values on the fly instead of caching them (although I added a very simple cache where you can specify the number of steps that you want to cache). In general I don't think this makes an extremely big difference in terms of performance, it greatly simplifies the code, and it removes the concept of having a fixed number of timesteps for the forward diffusion (I always found it weird to train a model assuming 1000 forward diffusion steps and then use way fewer steps during inference) by using a continuous value between 0 and 1 to decide how much noise we'll be adding to the output (1 being pure gaussian noise). 499 | 500 | During sampling the same applies: instead of having a fixed number of steps, the diffuzz module will accept a noised input, a pair of values t & t_prev (between 0 and 1) and a predicted noise, and it will try to remove that noise at a scale that goes from step t to step t_prev. So if we want to denoise in 10 steps we'll tell it to go from 1.0 to 0.9, then to 0.8, etc., while if we want to denoise in 100 steps we'll start at 1.0 and go to 0.99, then to 0.98, etc. 501 | 502 | Example of use during training: 503 | ```python 504 | from torchtools.utils import Diffuzz 505 | device = "cuda" 506 | 507 | diffuzz = Diffuzz(device=device) 508 | # diffuzz = Diffuzz(device=device, cache_steps=10000) # optionally you can pass a 'cache_steps' parameter to speed up the noising process 509 | custom_unet = CustomUnet().to(device) # Custom model with output size = input size 510 | 511 | input_tensor = torch.randn(8, 3, 16, 16, device=device) # an image, audio signal, or whatever... 512 | 513 | t = torch.rand(input_tensor.size(0), device=device) # get a tensor with batch_size values between 0 and 1 514 | noised_tensor, noise = diffuzz.diffuse(input_tensor, t) 515 | 516 | predicted_noise = custom_unet(noised_tensor, t) 517 | loss = nn.functional.mse_loss(predicted_noise, noise) 518 | 519 | # Optionally the diffuzz module provides loss gamma weighting (untested), but for this to work the loss 520 | # should not be averaged on the batch dimension before applying it.
521 | 522 | # loss = nn.functional.mse_loss(predicted_noise, noise, reduction='none').mean(dim=[1, 2, 3]) 523 | # loss = (loss * diffuzz.p2_weight(t)).mean() 524 | 525 | ``` 526 | 527 | Example of use for sampling: 528 | ```python 529 | from torchtools.utils import Diffuzz 530 | device = "cuda" 531 | 532 | sampled = diffuzz.sample( 533 | custom_unet, {'c': conditioning}, 534 | (conditioning.size(0), 3, 16, 16), 535 | timesteps=20, sampler='ddim' 536 | )[-1] 537 | ``` 538 | 539 | the `sample` method accepts a `sampler` parameter, currently only `ddpm` (default) and `ddim` are implemented, but I'm planning on adding more, very likely by borrowing (and appropriately citing) code from this repo https://github.com/ozanciga/diffusion-for-beginners 540 | 541 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = readme.md -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='torchtools', 5 | packages=find_packages(), 6 | description='PyTorch useful tools', 7 | version='0.3.5', 8 | url='https://github.com/pabloppp/pytorch-tools', 9 | author='Pablo Pernías', 10 | author_email='pablo@pernias.com', 11 | keywords=['pip', 'pytorch', 'tools', 'RAdam', 'Lookahead', 'RALamb', 'quantization'], 12 | zip_safe=False, 13 | install_requires=[ 14 | 'torch>=1.6', 15 | 'torchvision', 16 | 'numpy>=1.0', 17 | 'ninja>=1.0' 18 | ], 19 | package_data={ 20 | 'stylegan2.tools': ['torchtools/nn/stylegan2/*'], 21 | 'transforms.models': ['torchtools/transforms/models/*'] 22 | }, 23 | include_package_data=True, 24 | ) 25 | -------------------------------------------------------------------------------- /torchtools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pabloppp/pytorch-tools/6472bd5e223142bfbe105e2006432252f1ce3709/torchtools/__init__.py -------------------------------------------------------------------------------- /torchtools/lr_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | from .delayed import DelayerScheduler, DelayedCosineAnnealingLR 2 | from .inverse_sqrt import InverseSqrtLR 3 | -------------------------------------------------------------------------------- /torchtools/lr_scheduler/delayed.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR 2 | 3 | class DelayerScheduler(_LRScheduler): 4 | """ Starts with a flat lr schedule until it reaches N epochs the applies a scheduler 5 | 6 | Args: 7 | optimizer (Optimizer): Wrapped optimizer. 8 | delay_epochs: number of epochs to keep the initial lr until starting aplying the scheduler 9 | after_scheduler: after target_epoch, use this scheduler(eg. 
ReduceLROnPlateau) 10 | """ 11 | 12 | def __init__(self, optimizer, delay_epochs, after_scheduler): 13 | self.delay_epochs = delay_epochs 14 | self.after_scheduler = after_scheduler 15 | self.finished = False 16 | super().__init__(optimizer) 17 | 18 | def get_lr(self): 19 | if self.last_epoch >= self.delay_epochs: 20 | if not self.finished: 21 | self.after_scheduler.base_lrs = self.base_lrs 22 | self.finished = True 23 | return self.after_scheduler.get_lr() 24 | 25 | return self.base_lrs 26 | 27 | def step(self, epoch=None): 28 | if self.finished: 29 | if epoch is None: 30 | self.after_scheduler.step(None) 31 | else: 32 | self.after_scheduler.step(epoch - self.delay_epochs) 33 | else: 34 | return super(DelayerScheduler, self).step(epoch) 35 | 36 | def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs): 37 | base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs) 38 | return DelayerScheduler(optimizer, delay_epochs, base_scheduler) -------------------------------------------------------------------------------- /torchtools/lr_scheduler/inverse_sqrt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | from torch.optim.lr_scheduler import LRScheduler 4 | 5 | 6 | class InverseSqrtLR(LRScheduler): 7 | def __init__(self, optimizer, lr, warmup_steps, pre_warmup_lr=None, last_epoch=-1, verbose=False): 8 | warmup_steps = max(warmup_steps, 1) 9 | self.lr = lr * warmup_steps**0.5 10 | self.warmup_steps = warmup_steps 11 | self.pre_warmup_lr = pre_warmup_lr if pre_warmup_lr is not None else lr 12 | super().__init__(optimizer, last_epoch, verbose) 13 | 14 | def _process_lr(self, _): 15 | warmup_factor = min(self.last_epoch/self.warmup_steps, 1) # this grows linearly from 0 to 1 during the warmup 16 | base_lr = self.lr / max(self.last_epoch, self.warmup_steps)**0.5 17 | return warmup_factor * base_lr + (1-warmup_factor)*self.pre_warmup_lr 18 | 19 | def get_lr(self): 20 | if not self._get_lr_called_within_step: 21 | warnings.warn("To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning) 22 | 23 | lr = self._process_lr(self.lr) 24 | return [lr for _ in self.optimizer.param_groups] 25 | 26 | def _get_closed_form_lr(self): 27 | return [self._process_lr(base_lr) for base_lr in self.base_lrs] -------------------------------------------------------------------------------- /torchtools/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .mish import Mish 2 | from .simple_self_attention import SimpleSelfAttention 3 | from .vq import VectorQuantize, Binarize, FSQ 4 | from .gp_loss import GPLoss 5 | from .pixel_normalzation import PixelNorm 6 | from .perceptual import TVLoss 7 | from .adain import AdaIN 8 | from .transformers import GPTTransformerEncoderLayer 9 | from .evonorm2d import EvoNorm2D 10 | from .pos_embeddings import RotaryEmbedding 11 | from .modulation import ModulatedConv2d 12 | from .equal_layers import EqualConv2d, EqualLeakyReLU, EqualLinear 13 | from .fourier_features import FourierFeatures2d 14 | # from .alias_free_activation import AliasFreeActivation 15 | from .magnitude_preserving import MP_GELU, MP_SiLU, Gain 16 | from .haar_dwt import HaarForward, HaarInverse -------------------------------------------------------------------------------- /torchtools/nn/adain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class 
AdaIN(nn.Module): 5 | def __init__(self, n_channels): 6 | super(AdaIN, self).__init__() 7 | self.norm = nn.InstanceNorm2d(n_channels) 8 | 9 | def forward(self, image, style): 10 | factor, bias = style.view(style.size(0), style.size(1), 1, 1).chunk(2, dim=1) 11 | result = self.norm(image) * factor + bias 12 | return result -------------------------------------------------------------------------------- /torchtools/nn/alias_free_activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | from .stylegan2 import upfirdn2d 5 | 6 | #### 7 | # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM 8 | # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L225 9 | # But I simplified it into a single (almots) self-contained module. 10 | # Probably I give this module too much reponsibility but meh... 11 | #### 12 | 13 | class AliasFreeActivation(nn.Module): 14 | def __init__(self, activation, level, max_levels, max_size, max_channels, margin, start_cutoff=2, critical_layers=2, window_size=6): 15 | super().__init__() 16 | self.activation = activation 17 | 18 | # Filter features 19 | self.cutoff, self.stopband, self.band_half, self.channels, self.size = self.alias_level_params( 20 | level, max_levels, max_size, max_channels, start_cutoff, critical_layers 21 | ) 22 | self.cutoff_prev, self.stopband_prev, self.band_half_prev, self.channels_prev, self.size_prev = self.alias_level_params( 23 | max(level-1, 0), max_levels, max_size, max_channels, start_cutoff, critical_layers 24 | ) 25 | 26 | # Filters 27 | self.scale_factor = 2 if self.size_prev < self.size else 1 28 | up_filter = self._lowpass_filter( 29 | window_size * self.scale_factor * 2, self.cutoff_prev, self.band_half_prev, self.size * self.scale_factor * 2 30 | ) 31 | self.register_buffer("up_filter", (up_filter / up_filter.sum()) * 2 * self.scale_factor) 32 | 33 | down_filter = self._lowpass_filter( 34 | window_size * self.scale_factor, self.cutoff, self.band_half, self.size * self.scale_factor * 2 35 | ) 36 | self.register_buffer("down_filter", down_filter / down_filter.sum()) 37 | 38 | p = self.up_filter.shape[0] - (2*self.scale_factor) 39 | self.up_pad = ((p + 1) // 2 + (2*self.scale_factor) - 1, p // 2) 40 | 41 | p = self.down_filter.shape[0] - 2 42 | self.down_pad = ((p + 1) // 2, p // 2) 43 | self.margin = margin 44 | 45 | @staticmethod 46 | def alias_level_params(level, max_levels, max_size, max_channels, start_cutoff=2, critical_layers=2, base_channels=2**14): 47 | end_cutoff = max_size//2 48 | cutoff = start_cutoff * (end_cutoff / start_cutoff) ** min(level / (max_levels - critical_layers), 1) 49 | 50 | start_stopband = start_cutoff ** 2.1 51 | end_stopband = end_cutoff * (2 ** 0.3) 52 | stopband = start_stopband * (end_stopband/start_stopband) ** min(level / (max_levels - critical_layers), 1) 53 | 54 | size = 2 ** math.ceil(math.log(min(2 * stopband, max_size), 2)) 55 | band_half = max(stopband, size / 2) - cutoff 56 | channels = min(round(base_channels / size), max_channels) 57 | 58 | return cutoff, stopband, band_half, channels, size 59 | 60 | def _lowpass_filter(self, n_taps, cutoff, band_half, sr): 61 | window = self._kaiser_window(n_taps, band_half, sr) 62 | ind = torch.arange(n_taps) - (n_taps - 1) / 2 63 | lowpass = 2 * cutoff / sr * torch.sinc(2 * cutoff / sr * ind) * window 64 | 65 | return lowpass 66 | 67 | def _kaiser_window(self, n_taps, f_h, sr): 68 | beta = self._kaiser_beta(n_taps, f_h, sr) 69 | ind = 
torch.arange(n_taps) - (n_taps - 1) / 2 70 | return torch.i0(beta * torch.sqrt(1 - ((2 * ind) / (n_taps - 1)) ** 2)) / torch.i0(torch.tensor(beta)) 71 | 72 | def _kaiser_attenuation(self, n_taps, f_h, sr): 73 | df = (2 * f_h) / (sr / 2) 74 | return 2.285 * (n_taps - 1) * math.pi * df + 7.95 75 | 76 | 77 | def _kaiser_beta(self, n_taps, f_h, sr): 78 | atten = self._kaiser_attenuation(n_taps, f_h, sr) 79 | if atten > 50: 80 | return 0.1102 * (atten - 8.7) 81 | 82 | elif 50 >= atten >= 21: 83 | return 0.5842 * (atten - 21) ** 0.4 + 0.07886 * (atten - 21) 84 | else: 85 | return 0.0 86 | 87 | def forward(self, x): 88 | x = self._upsample(x, self.up_filter, 2*self.scale_factor, pad=self.up_pad) 89 | x = self.activation(x) 90 | x = self._downsample(x, self.down_filter, 2, pad=self.down_pad) 91 | if self.scale_factor > 1 and self.margin > 0: 92 | m = self.scale_factor * self.margin // 2 93 | x = x[:, :, m:-m, m:-m] 94 | return x 95 | 96 | def _upsample(self, x, kernel, factor, pad=(0, 0)): 97 | x = upfirdn2d(x, kernel.unsqueeze(0), up=(factor, 1), pad=(*pad, 0, 0)) 98 | x = upfirdn2d(x, kernel.unsqueeze(1), up=(1, factor), pad=(0, 0, *pad)) 99 | return x 100 | 101 | def _downsample(self, x, kernel, factor, pad=(0, 0)): 102 | x = upfirdn2d(x, kernel.unsqueeze(0), down=(factor, 1), pad=(*pad, 0, 0)) 103 | x = upfirdn2d(x, kernel.unsqueeze(1), down=(1, factor), pad=(0, 0, *pad)) 104 | return x 105 | 106 | def extra_repr(self): 107 | info_string = f'cutoff={self.cutoff}, stopband={self.stopband}, band_half={self.band_half}, channels={self.channels}, size={self.size}' 108 | return info_string -------------------------------------------------------------------------------- /torchtools/nn/equal_layers.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import math 5 | 6 | #### 7 | # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM 8 | # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/stylegan2/model.py#L94 9 | # But made it extend from the base modules to avoid some boilerplate 10 | #### 11 | 12 | class EqualLinear(nn.Linear): 13 | def __init__(self, *args, bias_init=0, lr_mul=1, **kwargs): 14 | super().__init__(*args, **kwargs) 15 | 16 | self.scale = (1 / math.sqrt(self.in_features)) * lr_mul 17 | self.lr_mul = lr_mul 18 | 19 | nn.init.normal_(self.weight, std=1/lr_mul) 20 | if self.bias is not None: 21 | nn.init.constant_(self.bias, bias_init) 22 | 23 | def forward(self, x): 24 | return nn.functional.linear(x, self.weight * self.scale, self.bias * self.lr_mul) 25 | 26 | 27 | class EqualConv2d(nn.Conv2d): 28 | def __init__(self, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | 31 | fan_in = self.in_channels * self.kernel_size[0] ** 2 32 | self.scale = 1 / math.sqrt(fan_in) 33 | 34 | nn.init.normal_(self.weight) 35 | if self.bias is not None: 36 | nn.init.zeros_(self.bias) 37 | 38 | def forward(self, x): 39 | return self._conv_forward(x, self.weight * self.scale, self.bias) 40 | 41 | 42 | class EqualLeakyReLU(nn.LeakyReLU): 43 | def __init__(self, *args, scale=2**0.5, **kwargs): 44 | super().__init__(*args, **kwargs) 45 | self.scale = scale 46 | 47 | def forward(self, x): 48 | return super().forward(x) * self.scale -------------------------------------------------------------------------------- /torchtools/nn/evonorm2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | ## Taken as is from 
https://github.com/digantamisra98/EvoNorm all credit goes to digantamisra98 6 | class SwishImplementation(torch.autograd.Function): 7 | @staticmethod 8 | def forward(ctx, i): 9 | ctx.save_for_backward(i) 10 | return i * torch.sigmoid(i) 11 | 12 | @staticmethod 13 | def backward(ctx, grad_output): 14 | sigmoid_i = torch.sigmoid(ctx.saved_variables[0]) 15 | return grad_output * (sigmoid_i * (1 + ctx.saved_variables[0] * (1 - sigmoid_i))) 16 | 17 | 18 | class MemoryEfficientSwish(nn.Module): 19 | def forward(self, x): 20 | return SwishImplementation.apply(x) 21 | 22 | def instance_std(x, eps=1e-5): 23 | var = torch.var(x, dim = (2, 3), keepdim=True).expand_as(x) 24 | if torch.isnan(var).any(): 25 | var = torch.zeros(var.shape) 26 | return torch.sqrt(var + eps) 27 | 28 | def group_std(x, groups = 32, eps = 1e-5): 29 | N, C, H, W = x.size() 30 | x = torch.reshape(x, (N, groups, C // groups, H, W)) 31 | var = torch.var(x, dim = (2, 3, 4), keepdim = True).expand_as(x) 32 | return torch.reshape(torch.sqrt(var + eps), (N, C, H, W)) 33 | 34 | class EvoNorm2D(nn.Module): 35 | 36 | def __init__(self, input, non_linear = True, version = 'S0', efficient = False, affine = True, momentum = 0.9, eps = 1e-5, groups = 32, training = True): 37 | super(EvoNorm2D, self).__init__() 38 | self.non_linear = non_linear 39 | self.version = version 40 | self.training = training 41 | self.momentum = momentum 42 | self.efficient = efficient 43 | if self.version == 'S0': 44 | self.swish = MemoryEfficientSwish() 45 | self.groups = groups 46 | self.eps = eps 47 | if self.version not in ['B0', 'S0']: 48 | raise ValueError("Invalid EvoNorm version") 49 | self.insize = input 50 | self.affine = affine 51 | 52 | if self.affine: 53 | self.gamma = nn.Parameter(torch.ones(1, self.insize, 1, 1)) 54 | self.beta = nn.Parameter(torch.zeros(1, self.insize, 1, 1)) 55 | if self.non_linear: 56 | self.v = nn.Parameter(torch.ones(1,self.insize,1,1)) 57 | else: 58 | self.register_parameter('gamma', None) 59 | self.register_parameter('beta', None) 60 | self.register_buffer('v', None) 61 | self.register_buffer('running_var', torch.ones(1, self.insize, 1, 1)) 62 | 63 | self.reset_parameters() 64 | 65 | def reset_parameters(self): 66 | self.running_var.fill_(1) 67 | 68 | def _check_input_dim(self, x): 69 | if x.dim() != 4: 70 | raise ValueError('expected 4D input (got {}D input)' 71 | .format(x.dim())) 72 | 73 | def forward(self, x): 74 | self._check_input_dim(x) 75 | if self.version == 'S0': 76 | if self.non_linear: 77 | if not self.efficient: 78 | num = x * torch.sigmoid(self.v * x) # Original Swish Implementation, however memory intensive. 
79 | else: 80 | num = self.swish(x) # Experimental Memory Efficient Variant of Swish 81 | return num / group_std(x, groups = self.groups, eps = self.eps) * self.gamma + self.beta 82 | else: 83 | return x * self.gamma + self.beta 84 | if self.version == 'B0': 85 | if self.training: 86 | var = torch.var(x, dim = (0, 2, 3), unbiased = False, keepdim = True) 87 | self.running_var.mul_(self.momentum) 88 | self.running_var.add_((1 - self.momentum) * var) 89 | else: 90 | var = self.running_var 91 | 92 | if self.non_linear: 93 | den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x, eps = self.eps)) 94 | return x / den * self.gamma + self.beta 95 | else: 96 | return x * self.gamma + self.beta -------------------------------------------------------------------------------- /torchtools/nn/fourier_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | 5 | class FourierFeatures2d(nn.Module): 6 | def __init__(self, size, dim, cutoff, affine_eps=1e-8, freq_range=[-0.5, 0.5], w_scale=0, allow_scaling=False, op_order=['r', 't', 's']): 7 | super().__init__() 8 | self.size = size 9 | self.dim = dim 10 | self.cutoff = cutoff 11 | self.freq_range = freq_range 12 | self.affine_eps = affine_eps 13 | self.w_scale = w_scale 14 | coords = torch.linspace(freq_range[0], freq_range[1], size+1)[:-1] 15 | freqs = torch.linspace(0, cutoff, dim // 4) 16 | if w_scale > 0: 17 | freqs = freqs @ (torch.randn(dim // 4, dim // 4) * w_scale) 18 | coord_map = torch.outer(freqs, coords) 19 | coord_map = 2 * math.pi * coord_map 20 | self.register_buffer("coord_h", coord_map.view(freqs.shape[0], 1, size)) 21 | self.register_buffer("coord_w", self.coord_h.transpose(1, 2).detach()) 22 | self.register_buffer("lf", freqs.view(1, dim // 4, 1, 1) * 2*math.pi * 2/size) 23 | self.allow_scaling = allow_scaling 24 | for op in op_order: 25 | assert op in ['r', 't', 's'], f"Operation not valid: {op}" 26 | self.op_order = op_order 27 | 28 | def forward(self, affine): 29 | norm = ((affine[:, 0:1].pow(2) + affine[:, 1:2].pow(2)).sqrt() + self.affine_eps).expand(affine.size(0), 4) 30 | if self.allow_scaling: 31 | assert affine.size(-1) == 6, f"If scaling is enabled, 2 extra values must be passed for a total of 6, and not {affine.size(-1)}" 32 | norm = torch.cat([norm, norm.new_ones(affine.size(0), 2)], dim=1) 33 | else: 34 | assert affine.size(-1) == 4, f"If scaling is disabled, 4 affine values should be passed, and not {affine.size(-1)}" 35 | affine = affine / norm 36 | affine = affine[:, :, None, None, None] 37 | 38 | coord_h, coord_w = self.coord_h.unsqueeze(0), self.coord_w.unsqueeze(0) 39 | 40 | for op in reversed(self.op_order): 41 | if op == 's' and self.allow_scaling: 42 | coord_h = coord_h / nn.functional.threshold(affine[:, 5], 1.0, 1.0) # scale 43 | coord_w = coord_w / nn.functional.threshold(affine[:, 4], 1.0, 1.0) 44 | 45 | elif op == 't': 46 | coord_h = coord_h - (affine[:, 3] * self.lf) # shift 47 | coord_w = coord_w - (affine[:, 2] * self.lf) 48 | 49 | elif op == 'r': 50 | _coord_h = -coord_w * affine[:, 1] + coord_h * affine[:, 0] # rotation 51 | coord_w = coord_w * affine[:, 0] + coord_h * affine[:, 1] 52 | coord_h = _coord_h 53 | 54 | coord_h = torch.cat((torch.sin(coord_h), torch.cos(coord_h)), 1) 55 | coord_w = torch.cat((torch.sin(coord_w), torch.cos(coord_w)), 1) 56 | 57 | coords = torch.cat((coord_h, coord_w), 1) 58 | return coords 59 | 60 | def extra_repr(self): 61 | info_string = f'size={self.size}, 
dim={self.dim}, cutoff={self.cutoff}, freq_range={self.freq_range}' 62 | if self.w_scale > 0: 63 | info_string += f', w_scale={self.w_scale}' 64 | if self.allow_scaling: 65 | info_string += f', allow_scaling={self.allow_scaling}' 66 | return info_string 67 | -------------------------------------------------------------------------------- /torchtools/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .vq import vector_quantize, binarize 2 | from .gradient_penalty import gradient_penalty 3 | from .perceptual import total_variation 4 | from .magnitude_preserving import mp_cat, mp_sum 5 | -------------------------------------------------------------------------------- /torchtools/nn/functional/gradient_penalty.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN WITH FEW MODIFICATIONS FROM https://github.com/caogang/wgan-gp 3 | # ORIGINAL PAPER https://arxiv.org/pdf/1704.00028.pdf 4 | #### 5 | 6 | import torch 7 | from torch import autograd 8 | 9 | def gradient_penalty(netD, real_data, fake_data, l=10): 10 | batch_size = real_data.size(0) 11 | alpha = real_data.new_empty((batch_size, 1, 1, 1)).uniform_(0, 1) 12 | alpha = alpha.expand_as(real_data) 13 | 14 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 15 | interpolates = autograd.Variable(interpolates, requires_grad=True) 16 | 17 | disc_interpolates = netD(interpolates) 18 | 19 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 20 | grad_outputs=real_data.new_ones(disc_interpolates.size()), 21 | create_graph=True, retain_graph=True, only_inputs=True)[0] 22 | 23 | gradients = gradients.view(gradients.size(0), -1) 24 | gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12) 25 | gradient_penalty = ((gradients_norm - 1) ** 2).mean() * l 26 | 27 | return gradient_penalty -------------------------------------------------------------------------------- /torchtools/nn/functional/magnitude_preserving.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mp_cat(*args, dim=1, t=0.5): 4 | if isinstance(t, float): 5 | t = [1-t, t] 6 | assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)" 7 | 8 | w = [m/a.size(dim)**0.5 for a, m in zip(args, t)] 9 | C = (sum([a.size(dim) for a in args]) / sum([m**2 for m in t]))**0.5 10 | 11 | return torch.cat([a*v for a, v in zip(args, w)], dim=dim) * C 12 | 13 | def mp_sum(*args, t=0.5): 14 | if isinstance(t, float): 15 | t = [1-t, t] 16 | 17 | assert len(args) == len(t), "t must be a single scalar or a list of scalars of length len(args)" 18 | assert abs(sum(t)-1) < 1e-3 , "the values of t should all add up to one" 19 | 20 | return sum([a*m for a, m in zip(args, t)]) / sum([m**2 for m in t])**0.5 21 | -------------------------------------------------------------------------------- /torchtools/nn/functional/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def total_variation(X, reduction='sum'): 4 | tv_h = torch.abs(X[:, :, :, 1:] - X[:, :, :, :-1]) 5 | tv_v = torch.abs(X[:, :, 1:] - X[:, :, :-1]) 6 | 7 | tv = torch.mean(tv_h) + torch.mean(tv_v) if reduction == 'mean' else torch.sum(tv_h) + torch.sum(tv_v) 8 | 9 | return tv -------------------------------------------------------------------------------- /torchtools/nn/functional/vq.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | class vector_quantize(Function): 5 | @staticmethod 6 | def forward(ctx, x, codebook): 7 | with torch.no_grad(): 8 | codebook_sqr = torch.sum(codebook ** 2, dim=1) 9 | x_sqr = torch.sum(x ** 2, dim=1, keepdim=True) 10 | 11 | dist = torch.addmm(codebook_sqr + x_sqr, x, codebook.t(), alpha=-2.0, beta=1.0) 12 | _, indices = dist.min(dim=1) 13 | 14 | ctx.save_for_backward(indices, codebook) 15 | ctx.mark_non_differentiable(indices) 16 | 17 | nn = torch.index_select(codebook, 0, indices) 18 | return nn, indices 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output, grad_indices): 22 | grad_inputs, grad_codebook = None, None 23 | 24 | if ctx.needs_input_grad[0]: 25 | grad_inputs = grad_output.clone() 26 | if ctx.needs_input_grad[1]: 27 | # Gradient wrt. the codebook 28 | indices, codebook = ctx.saved_tensors 29 | 30 | grad_codebook = torch.zeros_like(codebook) 31 | grad_codebook.index_add_(0, indices, grad_output) 32 | 33 | return (grad_inputs, grad_codebook) 34 | 35 | 36 | class binarize(Function): 37 | @staticmethod 38 | def forward(ctx, x, threshold=0.5): 39 | with torch.no_grad(): 40 | binarized = (x > threshold).float() 41 | ctx.mark_non_differentiable(binarized) 42 | 43 | return binarized 44 | 45 | @staticmethod 46 | def backward(ctx, grad_output): 47 | grad_inputs = None 48 | 49 | if ctx.needs_input_grad[0]: 50 | grad_inputs = grad_output.clone() 51 | 52 | return grad_inputs -------------------------------------------------------------------------------- /torchtools/nn/gp_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .functional import gradient_penalty 4 | 5 | class GPLoss(nn.Module): 6 | def __init__(self, discriminator, l=10): 7 | super(GPLoss, self).__init__() 8 | self.discriminator = discriminator 9 | self.l = l 10 | 11 | def forward(self, real_data, fake_data): 12 | return gradient_penalty(self.discriminator, real_data, fake_data, self.l) -------------------------------------------------------------------------------- /torchtools/nn/haar_dwt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | # Taken almost as is from https://github.com/bes-dev/haar_pytorch 5 | class HaarForward(nn.Module): 6 | """ 7 | Performs a 2d DWT Forward decomposition of an image using Haar Wavelets 8 | set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt 9 | """ 10 | def __init__(self, beta=2): 11 | super().__init__() 12 | self.alpha = 0.5 13 | self.beta = beta 14 | 15 | def forward(self, x: torch.Tensor) -> torch.Tensor: 16 | """ 17 | Performs a 2d DWT Forward decomposition of an image using Haar Wavelets 18 | 19 | Arguments: 20 | x (torch.Tensor): input tensor of shape [b, c, h, w] 21 | 22 | Returns: 23 | out (torch.Tensor): output tensor of shape [b, c * 4, h / 2, w / 2] 24 | """ 25 | 26 | ll = self.alpha/self.beta * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] + x[:,:,1::2,0::2] + x[:,:,1::2,1::2]) 27 | lh = self.alpha * (x[:,:,0::2,0::2] + x[:,:,0::2,1::2] - x[:,:,1::2,0::2] - x[:,:,1::2,1::2]) 28 | hl = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] + x[:,:,1::2,0::2] - x[:,:,1::2,1::2]) 29 | hh = self.alpha * (x[:,:,0::2,0::2] - x[:,:,0::2,1::2] - x[:,:,1::2,0::2] + x[:,:,1::2,1::2]) 30 | return torch.cat([ll,lh,hl,hh], axis=1) 31 | 32 | 33 | class 
HaarInverse(nn.Module): 34 | """ 35 | Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets 36 | set beta=1 for regular haard dwt, with beta=2 we make a magnitude preserving dwt 37 | """ 38 | def __init__(self, beta=2): 39 | super().__init__() 40 | self.alpha = 0.5 41 | self.beta = beta 42 | 43 | def forward(self, x: torch.Tensor) -> torch.Tensor: 44 | """ 45 | Performs a 2d DWT Inverse reconstruction of an image using Haar Wavelets 46 | 47 | Arguments: 48 | x (torch.Tensor): input tensor of shape [b, c, h, w] 49 | 50 | Returns: 51 | out (torch.Tensor): output tensor of shape [b, c / 4, h * 2, w * 2] 52 | """ 53 | assert x.size(1) % 4 == 0, "The number of channels must be divisible by 4." 54 | size = [x.shape[0], x.shape[1] // 4, x.shape[2] * 2, x.shape[3] * 2] 55 | f = lambda i: x[:, size[1] * i : size[1] * (i + 1)] 56 | out = torch.zeros(size, dtype=x.dtype, device=x.device) 57 | out[:,:,0::2,0::2] = self.alpha * (f(0)*self.beta + f(1) + f(2) + f(3)) 58 | out[:,:,0::2,1::2] = self.alpha * (f(0)*self.beta + f(1) - f(2) - f(3)) 59 | out[:,:,1::2,0::2] = self.alpha * (f(0)*self.beta - f(1) + f(2) - f(3)) 60 | out[:,:,1::2,1::2] = self.alpha * (f(0)*self.beta - f(1) - f(2) + f(3)) 61 | return out -------------------------------------------------------------------------------- /torchtools/nn/magnitude_preserving.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class MP_GELU(nn.GELU): 5 | def forward(self, x): 6 | return super().forward(x) / 0.652 # ¯\_(ツ)_/¯ 7 | 8 | class MP_SiLU(nn.SiLU): 9 | def forward(self, x): 10 | return super().forward(x) / 0.596 # ¯\_(ツ)_/¯ 11 | 12 | class Gain(nn.Module): 13 | def __init__(self, init_w=0.0): 14 | super().__init__() 15 | self.g = nn.Parameter(torch.tensor([init_w])) 16 | 17 | def forward(self, x): 18 | return x * self.g 19 | 20 | -------------------------------------------------------------------------------- /torchtools/nn/mish.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lessw2020/mish 3 | # ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1 4 | #### 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F #(uncomment if needed,but you likely already have it) 9 | 10 | #Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function" 11 | #https://arxiv.org/abs/1908.08681v1 12 | #implemented for PyTorch / FastAI by lessw2020 13 | #github: https://github.com/lessw2020/mish 14 | 15 | class Mish(nn.Module): 16 | def __init__(self): 17 | super().__init__() 18 | 19 | def forward(self, x): 20 | #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!) 
21 | return x *( torch.tanh(F.softplus(x))) -------------------------------------------------------------------------------- /torchtools/nn/modulation.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import math 5 | 6 | #### 7 | # TOTALLY INSPIRED AND EVEN COPIED SOME CHUNKS FROM 8 | # https://github.com/rosinality/alias-free-gan-pytorch/blob/main/model.py#L143 9 | # But made it extend from the base Conv2d to avoid some boilerplate 10 | #### 11 | class ModulatedConv2d(nn.Conv2d): 12 | def __init__(self, *args, demodulate=True, ema_decay=1.0, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | 15 | fan_in = self.in_channels * self.kernel_size[0] ** 2 16 | self.scale = 1 / math.sqrt(fan_in) 17 | 18 | self.demodulate = demodulate 19 | self.ema_decay = ema_decay 20 | self.register_buffer("ema_var", torch.tensor(1.0)) 21 | nn.init.normal_(self.weight) 22 | if self.bias is not None: 23 | nn.init.zeros_(self.bias) 24 | 25 | def forward(self, x, w): 26 | batch, in_channels, height, width = x.shape 27 | 28 | style = w.view(batch, 1, in_channels, 1, 1) 29 | weight = self.scale * self.weight.unsqueeze(0) * style 30 | 31 | if self.demodulate: 32 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 33 | weight = weight * demod.view(batch, self.out_channels, 1, 1, 1) 34 | 35 | weight = weight.view( 36 | batch * self.out_channels, in_channels, self.kernel_size[0], self.kernel_size[1] 37 | ) 38 | 39 | if self.ema_decay < 1: 40 | if self.training: 41 | var = x.pow(2).mean((0, 1, 2, 3)) 42 | self.ema_var.mul_(self.ema_decay).add_(var.detach(), alpha=1 - self.ema_decay) 43 | 44 | weight = weight / (torch.sqrt(self.ema_var) + 1e-8) 45 | 46 | input = x.view(1, batch * in_channels, height, width) 47 | self.groups = batch 48 | out = self._conv_forward(input, weight, None) 49 | _, _, height, width = out.shape 50 | out = out.view(batch, self.out_channels, height, width) 51 | if self.bias is not None: 52 | out = out + self.bias.view(1, -1, 1, 1) 53 | return out 54 | -------------------------------------------------------------------------------- /torchtools/nn/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .functional import total_variation 4 | 5 | class TVLoss(nn.Module): 6 | def __init__(self, reduction='sum', alpha=1e-4): 7 | super(TVLoss, self).__init__() 8 | self.reduction = reduction 9 | self.alpha = alpha 10 | 11 | def forward(self, x): 12 | return total_variation(x, reduction=self.reduction) * self.alpha -------------------------------------------------------------------------------- /torchtools/nn/pixel_normalzation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class PixelNorm(nn.Module): 5 | def __init__(self, dim=1, eps=1e-4): 6 | super().__init__() 7 | self.dim = dim 8 | self.eps = eps 9 | 10 | def forward(self, x): 11 | return x / (torch.sqrt(torch.mean(x ** 2, dim=self.dim, keepdim=True)) + self.eps) -------------------------------------------------------------------------------- /torchtools/nn/pos_embeddings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | #### 5 | # CODE TAKEN FROM https://github.com/lucidrains/x-transformers 6 | #### 7 | 8 | class RotaryEmbedding(nn.Module): 9 | def __init__(self, dim, base=10000): 10 | 
super().__init__() 11 | inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)) 12 | self.register_buffer('inv_freq', inv_freq) 13 | self.seq_len_cached = None 14 | self.cos_cached = None 15 | self.sin_cached = None 16 | 17 | def forward(self, x, seq_dim=1): 18 | seq_len = x.shape[seq_dim] 19 | if seq_len != self.seq_len_cached: 20 | self.seq_len_cached = seq_len 21 | t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) 22 | freqs = torch.einsum('i,j->ij', t, self.inv_freq) 23 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 24 | self.cos_cached = emb.cos()[:, None, None, :] 25 | self.sin_cached = emb.sin()[:, None, None, :] 26 | return self.cos_cached, self.sin_cached 27 | 28 | # rotary pos emb helpers: 29 | def rotate_half(x): 30 | x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:] 31 | return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0 32 | 33 | @torch.jit.script 34 | def apply_rotary_pos_emb(q, k, cos, sin): 35 | return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin) 36 | -------------------------------------------------------------------------------- /torchtools/nn/simple_self_attention.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch, math, sys 3 | 4 | #### 5 | # CODE TAKEN FROM https://github.com/sdoria/SimpleSelfAttention 6 | #### 7 | 8 | #Unmodified from https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 9 | def conv1d(ni, no, ks=1, stride=1, padding=0, bias=False): 10 | "Create and initialize a `nn.Conv1d` layer with spectral normalization." 11 | conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias) 12 | nn.init.kaiming_normal_(conv.weight) 13 | if bias: conv.bias.data.zero_() 14 | return nn.utils.spectral_norm(conv) 15 | 16 | # Adapted from SelfAttention layer at https://github.com/fastai/fastai/blob/5c51f9eabf76853a89a9bc5741804d2ed4407e49/fastai/layers.py 17 | # Inspired by https://arxiv.org/pdf/1805.08318.pdf 18 | class SimpleSelfAttention(nn.Module): 19 | 20 | def __init__(self, n_in, ks=1, sym=False): 21 | super().__init__() 22 | self.conv = conv1d(n_in, n_in, ks, padding=ks//2, bias=False) 23 | self.gamma = nn.Parameter(torch.Tensor([0.])) 24 | self.sym = sym 25 | self.n_in = n_in 26 | 27 | def forward(self, x): 28 | if self.sym: 29 | # symmetry hack by https://github.com/mgrankin 30 | c = self.conv.weight.view(self.n_in,self.n_in) 31 | c = (c + c.t())/2 32 | self.conv.weight = c.view(self.n_in,self.n_in,1) 33 | 34 | size = x.size() 35 | x = x.view(*size[:2],-1) # (C,N) 36 | 37 | # changed the order of mutiplication to avoid O(N^2) complexity 38 | # (x*xT)*(W*x) instead of (x*(xT*(W*x))) 39 | 40 | convx = self.conv(x) # (C,C) * (C,N) = (C,N) => O(NC^2) 41 | xxT = torch.bmm(x, x.permute(0,2,1).contiguous()) # (C,N) * (N,C) = (C,C) => O(NC^2) 42 | o = torch.bmm(xxT, convx) # (C,C) * (C,N) = (C,N) => O(NC^2) 43 | o = self.gamma * o + x 44 | 45 | return o.view(*size).contiguous() -------------------------------------------------------------------------------- /torchtools/nn/stylegan2/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d -------------------------------------------------------------------------------- /torchtools/nn/stylegan2/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 
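// upfirdn2d = UPsample, FIR filter, DowNsample fused into a single op: zero-insert by
// (up_x, up_y), pad by (pad_x0, pad_x1, pad_y0, pad_y1), convolve with the 2D FIR kernel,
// then keep every (down_x, down_y)-th sample. It is the resampling primitive behind
// StyleGAN2's blurred up/downsampling layers.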
#include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1) { 20 | CHECK_INPUT(input); 21 | CHECK_INPUT(kernel); 22 | 23 | at::DeviceGuard guard(input.device()); 24 | 25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 26 | pad_y0, pad_y1); 27 | } 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 31 | } -------------------------------------------------------------------------------- /torchtools/nn/stylegan2/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = 
pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = (pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | 
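# Illustrative usage sketch (not part of the original file; tensor names and sizes are made up).
# It builds a normalized 4-tap binomial blur kernel and applies a blurred 2x upsample on a CUDA
# tensor; with up=2, down=1 and pad=(2, 1), the out_h/out_w formula above yields exactly twice
# the input resolution.
#
#   k1d = torch.tensor([1., 3., 3., 1.])
#   kernel = torch.outer(k1d, k1d)
#   kernel = kernel / kernel.sum()
#   x = torch.randn(4, 3, 64, 64, device="cuda")
#   y = upfirdn2d(x, kernel.to(x), up=2, down=1, pad=(2, 1))  # y.shape == (4, 3, 128, 128)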
-------------------------------------------------------------------------------- /torchtools/nn/stylegan2/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __device__ __forceinline__ int floor_div(int a, int b) { 18 | int t = 1 - a / b; 19 | return (a + t * b) / b - t; 20 | } 21 | 22 | struct UpFirDn2DKernelParams { 23 | int up_x; 24 | int up_y; 25 | int down_x; 26 | int down_y; 27 | int pad_x0; 28 | int pad_x1; 29 | int pad_y0; 30 | int pad_y1; 31 | 32 | int major_dim; 33 | int in_h; 34 | int in_w; 35 | int minor_dim; 36 | int kernel_h; 37 | int kernel_w; 38 | int out_h; 39 | int out_w; 40 | int loop_major; 41 | int loop_x; 42 | }; 43 | 44 | template 45 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 46 | const scalar_t *kernel, 47 | const UpFirDn2DKernelParams p) { 48 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 49 | int out_y = minor_idx / p.minor_dim; 50 | minor_idx -= out_y * p.minor_dim; 51 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 52 | int major_idx_base = blockIdx.z * p.loop_major; 53 | 54 | if (out_x_base >= p.out_w || out_y >= p.out_h || 55 | major_idx_base >= p.major_dim) { 56 | return; 57 | } 58 | 59 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 60 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 61 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 62 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 63 | 64 | for (int loop_major = 0, major_idx = major_idx_base; 65 | loop_major < p.loop_major && major_idx < p.major_dim; 66 | loop_major++, major_idx++) { 67 | for (int loop_x = 0, out_x = out_x_base; 68 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 69 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 70 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 71 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 72 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 73 | 74 | const scalar_t *x_p = 75 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 76 | minor_idx]; 77 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 78 | int x_px = p.minor_dim; 79 | int k_px = -p.up_x; 80 | int x_py = p.in_w * p.minor_dim; 81 | int k_py = -p.up_y * p.kernel_w; 82 | 83 | scalar_t v = 0.0f; 84 | 85 | for (int y = 0; y < h; y++) { 86 | for (int x = 0; x < w; x++) { 87 | v += static_cast(*x_p) * static_cast(*k_p); 88 | x_p += x_px; 89 | k_p += k_px; 90 | } 91 | 92 | x_p += x_py - w * x_px; 93 | k_p += k_py - w * k_px; 94 | } 95 | 96 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 97 | minor_idx] = v; 98 | } 99 | } 100 | } 101 | 102 | template 104 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 105 | const scalar_t *kernel, 106 | const UpFirDn2DKernelParams p) { 107 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 108 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 109 | 110 | __shared__ volatile float sk[kernel_h][kernel_w]; 111 | 
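  // sk caches the FIR taps (stored flipped further down) and sx caches one input tile in shared
  // memory, so every output tile is computed from on-chip data instead of repeated global reads.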
__shared__ volatile float sx[tile_in_h][tile_in_w]; 112 | 113 | int minor_idx = blockIdx.x; 114 | int tile_out_y = minor_idx / p.minor_dim; 115 | minor_idx -= tile_out_y * p.minor_dim; 116 | tile_out_y *= tile_out_h; 117 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 118 | int major_idx_base = blockIdx.z * p.loop_major; 119 | 120 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 121 | major_idx_base >= p.major_dim) { 122 | return; 123 | } 124 | 125 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 126 | tap_idx += blockDim.x) { 127 | int ky = tap_idx / kernel_w; 128 | int kx = tap_idx - ky * kernel_w; 129 | scalar_t v = 0.0; 130 | 131 | if (kx < p.kernel_w & ky < p.kernel_h) { 132 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 133 | } 134 | 135 | sk[ky][kx] = v; 136 | } 137 | 138 | for (int loop_major = 0, major_idx = major_idx_base; 139 | loop_major < p.loop_major & major_idx < p.major_dim; 140 | loop_major++, major_idx++) { 141 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 142 | loop_x < p.loop_x & tile_out_x < p.out_w; 143 | loop_x++, tile_out_x += tile_out_w) { 144 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 145 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 146 | int tile_in_x = floor_div(tile_mid_x, up_x); 147 | int tile_in_y = floor_div(tile_mid_y, up_y); 148 | 149 | __syncthreads(); 150 | 151 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 152 | in_idx += blockDim.x) { 153 | int rel_in_y = in_idx / tile_in_w; 154 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 155 | int in_x = rel_in_x + tile_in_x; 156 | int in_y = rel_in_y + tile_in_y; 157 | 158 | scalar_t v = 0.0; 159 | 160 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 161 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 162 | p.minor_dim + 163 | minor_idx]; 164 | } 165 | 166 | sx[rel_in_y][rel_in_x] = v; 167 | } 168 | 169 | __syncthreads(); 170 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 171 | out_idx += blockDim.x) { 172 | int rel_out_y = out_idx / tile_out_w; 173 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 174 | int out_x = rel_out_x + tile_out_x; 175 | int out_y = rel_out_y + tile_out_y; 176 | 177 | int mid_x = tile_mid_x + rel_out_x * down_x; 178 | int mid_y = tile_mid_y + rel_out_y * down_y; 179 | int in_x = floor_div(mid_x, up_x); 180 | int in_y = floor_div(mid_y, up_y); 181 | int rel_in_x = in_x - tile_in_x; 182 | int rel_in_y = in_y - tile_in_y; 183 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 184 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 185 | 186 | if (out_x < p.out_w & out_y < p.out_h) { 187 | scalar_t v = 0.0; 188 | 189 | #pragma unroll 190 | for (int y = 0; y < kernel_h / up_y; y++) 191 | #pragma unroll 192 | for (int x = 0; x < kernel_w / up_x; x++) 193 | v += sx[rel_in_y + y][rel_in_x + x] * 194 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 195 | 196 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 197 | minor_idx] = v; 198 | } 199 | } 200 | } 201 | } 202 | } 203 | 204 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 205 | const torch::Tensor &kernel, int up_x, int up_y, 206 | int down_x, int down_y, int pad_x0, int pad_x1, 207 | int pad_y0, int pad_y1) { 208 | int curDevice = -1; 209 | cudaGetDevice(&curDevice); 210 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 211 | 212 | UpFirDn2DKernelParams p; 213 | 214 | auto x = input.contiguous(); 215 | auto k = kernel.contiguous(); 216 | 
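  // The Python wrapper has already reshaped the input to [major, in_h, in_w, minor]
  // (batch * channel folded into major, minor = 1), which is exactly the layout the
  // parameter struct below and both kernels index.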
217 | p.major_dim = x.size(0); 218 | p.in_h = x.size(1); 219 | p.in_w = x.size(2); 220 | p.minor_dim = x.size(3); 221 | p.kernel_h = k.size(0); 222 | p.kernel_w = k.size(1); 223 | p.up_x = up_x; 224 | p.up_y = up_y; 225 | p.down_x = down_x; 226 | p.down_y = down_y; 227 | p.pad_x0 = pad_x0; 228 | p.pad_x1 = pad_x1; 229 | p.pad_y0 = pad_y0; 230 | p.pad_y1 = pad_y1; 231 | 232 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 233 | p.down_y; 234 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 235 | p.down_x; 236 | 237 | auto out = 238 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 239 | 240 | int mode = -1; 241 | 242 | int tile_out_h = -1; 243 | int tile_out_w = -1; 244 | 245 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 246 | void *cuda_kernel = (void *)upfirdn2d_kernel_large; 247 | 248 | if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 249 | p.kernel_h <= 1 && p.kernel_w <= 24) { 250 | cuda_kernel = 251 | (void *)upfirdn2d_kernel; 252 | tile_out_h = 8; 253 | tile_out_w = 128; 254 | } 255 | 256 | if (p.up_x == 2 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 257 | p.kernel_h <= 1 && p.kernel_w <= 12) { 258 | cuda_kernel = 259 | (void *)upfirdn2d_kernel; 260 | tile_out_h = 8; 261 | tile_out_w = 128; 262 | } 263 | 264 | if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 24 && p.kernel_w <= 1) { 266 | cuda_kernel = 267 | (void *)upfirdn2d_kernel; 268 | tile_out_h = 32; 269 | tile_out_w = 32; 270 | } 271 | 272 | if (p.up_x == 1 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 273 | p.kernel_h <= 12 && p.kernel_w <= 1) { 274 | cuda_kernel = 275 | (void *)upfirdn2d_kernel; 276 | tile_out_h = 32; 277 | tile_out_w = 32; 278 | } 279 | 280 | // 281 | 282 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 && 283 | p.kernel_h <= 1 && p.kernel_w <= 24) { 284 | cuda_kernel = 285 | (void *)upfirdn2d_kernel; 286 | tile_out_h = 8; 287 | tile_out_w = 64; 288 | } 289 | 290 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 1 && 291 | p.kernel_h <= 1 && p.kernel_w <= 12) { 292 | cuda_kernel = 293 | (void *)upfirdn2d_kernel; 294 | tile_out_h = 8; 295 | tile_out_w = 64; 296 | } 297 | 298 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 && 299 | p.kernel_h <= 24 && p.kernel_w <= 1) { 300 | cuda_kernel = 301 | (void *)upfirdn2d_kernel; 302 | tile_out_h = 16; 303 | tile_out_w = 32; 304 | } 305 | 306 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 2 && 307 | p.kernel_h <= 12 && p.kernel_w <= 1) { 308 | cuda_kernel = 309 | (void *)upfirdn2d_kernel; 310 | tile_out_h = 16; 311 | tile_out_w = 32; 312 | } 313 | 314 | // 315 | 316 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 317 | p.kernel_h <= 4 && p.kernel_w <= 4) { 318 | cuda_kernel = 319 | (void *)upfirdn2d_kernel; 320 | tile_out_h = 16; 321 | tile_out_w = 64; 322 | } 323 | 324 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 325 | p.kernel_h <= 3 && p.kernel_w <= 3) { 326 | cuda_kernel = 327 | (void *)upfirdn2d_kernel; 328 | tile_out_h = 16; 329 | tile_out_w = 64; 330 | } 331 | 332 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 333 | p.kernel_h <= 4 && p.kernel_w <= 4) { 334 | cuda_kernel = 335 | (void *)upfirdn2d_kernel; 336 | tile_out_h = 16; 337 | tile_out_w = 64; 338 | } 339 | 340 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 341 | 
p.kernel_h <= 2 && p.kernel_w <= 2) { 342 | cuda_kernel = 343 | (void *)upfirdn2d_kernel; 344 | tile_out_h = 16; 345 | tile_out_w = 64; 346 | } 347 | 348 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 349 | p.kernel_h <= 4 && p.kernel_w <= 4) { 350 | cuda_kernel = (void *)upfirdn2d_kernel; 351 | tile_out_h = 8; 352 | tile_out_w = 32; 353 | } 354 | 355 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 356 | p.kernel_h <= 2 && p.kernel_w <= 2) { 357 | cuda_kernel = (void *)upfirdn2d_kernel; 358 | tile_out_h = 8; 359 | tile_out_w = 32; 360 | } 361 | 362 | dim3 block_size; 363 | dim3 grid_size; 364 | 365 | if (tile_out_h > 0 && tile_out_w > 0) { 366 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 367 | p.loop_x = 1; 368 | block_size = dim3(32 * 8, 1, 1); 369 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 370 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 371 | (p.major_dim - 1) / p.loop_major + 1); 372 | } else { 373 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 374 | p.loop_x = 4; 375 | block_size = dim3(4, 32, 1); 376 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 377 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 378 | (p.major_dim - 1) / p.loop_major + 1); 379 | } 380 | 381 | scalar_t *out_p = out.data_ptr(); 382 | scalar_t *x_p = x.data_ptr(); 383 | scalar_t *k_p = k.data_ptr(); 384 | 385 | void *args[] = {&out_p, &x_p, &k_p, &p}; 386 | AT_CUDA_CHECK( 387 | cudaLaunchKernel(cuda_kernel, grid_size, block_size, args, 0, stream)); 388 | }); 389 | 390 | return out; 391 | } -------------------------------------------------------------------------------- /torchtools/nn/transformers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | # Based on the GPT2 implementatyion from MinGPT https://github.com/karpathy/minGPT by Andrej Karpathy 5 | class GPTTransformerEncoderLayer(nn.Module): 6 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0): 7 | super().__init__() 8 | self.ln1 = nn.LayerNorm(d_model) 9 | self.ln2 = nn.LayerNorm(d_model) 10 | self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 11 | self.mlp = nn.Sequential( 12 | nn.Linear(d_model, dim_feedforward), 13 | nn.GELU(), 14 | nn.Linear(dim_feedforward, d_model), 15 | nn.Dropout(dropout), 16 | ) 17 | 18 | def forward(self, x, src_mask=None, src_key_padding_mask=None): 19 | x = self.ln1(x) 20 | x = x + self.attn(x, x, x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 21 | x = x + self.mlp(self.ln2(x)) 22 | return x -------------------------------------------------------------------------------- /torchtools/nn/vq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .functional.vq import vector_quantize, binarize 4 | import numpy as np 5 | 6 | class VectorQuantize(nn.Module): 7 | def __init__(self, embedding_size, k, ema_decay=0.99, ema_loss=False): 8 | """ 9 | Takes an input of variable size (as long as the last dimension matches the embedding size). 
10 | Returns one tensor containing the nearest neigbour embeddings to each of the inputs, 11 | with the same size as the input, vq and commitment components for the loss as a touple 12 | in the second output and the indices of the quantized vectors in the third: 13 | quantized, (vq_loss, commit_loss), indices 14 | """ 15 | super(VectorQuantize, self).__init__() 16 | 17 | self.codebook = nn.Embedding(k, embedding_size) 18 | self.codebook.weight.data.uniform_(-1./k, 1./k) 19 | self.vq = vector_quantize.apply 20 | 21 | self.ema_decay = ema_decay 22 | self.ema_loss = ema_loss 23 | if ema_loss: 24 | self.register_buffer('ema_element_count', torch.ones(k)) 25 | self.register_buffer('ema_weight_sum', torch.zeros_like(self.codebook.weight)) 26 | 27 | def _laplace_smoothing(self, x, epsilon): 28 | n = torch.sum(x) 29 | return ((x + epsilon) / (n + x.size(0) * epsilon) * n) 30 | 31 | def _updateEMA(self, z_e_x, indices): 32 | mask = nn.functional.one_hot(indices, self.ema_element_count.size(0)).float() 33 | elem_count = mask.sum(dim=0) 34 | weight_sum = torch.mm(mask.t(), z_e_x) 35 | 36 | self.ema_element_count = (self.ema_decay * self.ema_element_count) + ((1-self.ema_decay) * elem_count) 37 | self.ema_element_count = self._laplace_smoothing(self.ema_element_count, 1e-5) 38 | self.ema_weight_sum = (self.ema_decay * self.ema_weight_sum) + ((1-self.ema_decay) * weight_sum) 39 | 40 | self.codebook.weight.data = self.ema_weight_sum / self.ema_element_count.unsqueeze(-1) 41 | 42 | def idx2vq(self, idx, dim=-1): 43 | q_idx = self.codebook(idx) 44 | if dim != -1: 45 | q_idx = q_idx.movedim(-1, dim) 46 | return q_idx 47 | 48 | def forward(self, x, get_losses=True, dim=-1): 49 | if dim != -1: 50 | x = x.movedim(dim, -1) 51 | z_e_x = x.contiguous().view(-1, x.size(-1)) if len(x.shape) > 2 else x 52 | z_q_x, indices = self.vq(z_e_x, self.codebook.weight.detach()) 53 | vq_loss, commit_loss = None, None 54 | if self.ema_loss and self.training: 55 | self._updateEMA(z_e_x.detach(), indices.detach()) 56 | # pick the graded embeddings after updating the codebook in order to have a more accurate commitment loss 57 | z_q_x_grd = torch.index_select(self.codebook.weight, dim=0, index=indices) 58 | if get_losses: 59 | vq_loss = (z_q_x_grd - z_e_x.detach()).pow(2).mean() 60 | commit_loss = (z_e_x - z_q_x_grd.detach()).pow(2).mean() 61 | 62 | z_q_x = z_q_x.view(x.shape) 63 | if dim != -1: 64 | z_q_x = z_q_x.movedim(-1, dim) 65 | return z_q_x, (vq_loss, commit_loss), indices.view(x.shape[:-1]) 66 | 67 | class Binarize(nn.Module): 68 | def __init__(self, threshold=0.5): 69 | """ 70 | Takes an input of any size. 
71 | Returns an output of the same size but with its values binarized (0 if input is below a threshold, 1 if its above) 72 | """ 73 | super(Binarize, self).__init__() 74 | 75 | self.bin = binarize.apply 76 | self.threshold = threshold 77 | 78 | def forward(self, x): 79 | return self.bin(x, self.threshold) 80 | 81 | # Finite Scalar Quantization: https://arxiv.org/abs/2309.15505 82 | class FSQ(nn.Module): 83 | def __init__(self, bins, dim=-1, eps=1e-1): 84 | super().__init__() 85 | self.dim = dim 86 | self.eps = eps 87 | self.register_buffer('bins', torch.tensor(bins)) 88 | self.register_buffer('bases', torch.tensor([1] + np.cumprod(bins[:-1]).tolist())) 89 | self.codebook_size = np.prod(bins) 90 | 91 | self.in_shift, self.out_shift = None, None 92 | 93 | def _round(self, x, quantize): 94 | x = x.sigmoid() * (1-1e-7) 95 | if quantize is True: 96 | x_rounded = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins) 97 | x = x + (x_rounded - x).detach() 98 | x_sigmoid = x 99 | x = (x / (1-1e-7)).logit(eps=self.eps) 100 | return x, x_sigmoid 101 | 102 | def vq_to_idx(self, x, is_sigmoid=False): 103 | if not is_sigmoid: 104 | x = x.sigmoid() * (1-1e-7) 105 | x = x.sub(1/(self.bins*2)).mul(self.bins).round().div(self.bins).div(1-1/self.bins) 106 | x = x.mul(self.bins-1).long() 107 | x = (x * self.bases).sum(dim=-1).long() 108 | return x 109 | 110 | def idx_to_vq(self, x): 111 | x = x.unsqueeze(-1) // self.bases % self.bins 112 | x = x.div(self.bins-1) 113 | x = (x / (1-1e-7)).logit(eps=self.eps) 114 | if self.dim != -1: 115 | x = x.movedim(-1, self.dim) 116 | return x 117 | 118 | def forward(self, x, quantize=True): 119 | if self.dim != -1: 120 | x = x.movedim(self.dim, -1) 121 | 122 | x, x_sigmoid = self._round(x, quantize=quantize) 123 | idx = self.vq_to_idx(x_sigmoid, is_sigmoid=True) 124 | 125 | if self.dim != -1: 126 | x = x.movedim(-1, self.dim) 127 | return x, idx 128 | -------------------------------------------------------------------------------- /torchtools/optim/__init__.py: -------------------------------------------------------------------------------- 1 | from .radam import RAdam, PlainRAdam, AdamW 2 | from .ranger import Ranger 3 | from .lookahead import Lookahead, LookaheadAdam 4 | from .over9000 import Over9000, RangerLars 5 | from .ralamb import Ralamb 6 | from .novograd import Novograd 7 | from .lamb import Lamb 8 | -------------------------------------------------------------------------------- /torchtools/optim/lamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import collections 6 | import math 7 | 8 | import torch 9 | from torch.optim import Optimizer 10 | 11 | try: 12 | from tensorboardX import SummaryWriter 13 | 14 | def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 15 | """Log a histogram of trust ratio scalars in across layers.""" 16 | results = collections.defaultdict(list) 17 | for group in optimizer.param_groups: 18 | for p in group['params']: 19 | state = optimizer.state[p] 20 | for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 21 | if i in state: 22 | results[i].append(state[i]) 23 | 24 | for k, v in results.items(): 25 | event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 26 | except ModuleNotFoundError as e: 27 | print("To use this log_lamb_rs, please run 'pip install tensorboardx'. 
Also you must have Tensorboard running to see results") 28 | 29 | class Lamb(Optimizer): 30 | r"""Implements Lamb algorithm. 31 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 32 | Arguments: 33 | params (iterable): iterable of parameters to optimize or dicts defining 34 | parameter groups 35 | lr (float, optional): learning rate (default: 1e-3) 36 | betas (Tuple[float, float], optional): coefficients used for computing 37 | running averages of gradient and its square (default: (0.9, 0.999)) 38 | eps (float, optional): term added to the denominator to improve 39 | numerical stability (default: 1e-8) 40 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 41 | adam (bool, optional): always use trust ratio = 1, which turns this into 42 | Adam. Useful for comparison purposes. 43 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 44 | https://arxiv.org/abs/1904.00962 45 | """ 46 | 47 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 48 | weight_decay=0, adam=False): 49 | if not 0.0 <= lr: 50 | raise ValueError("Invalid learning rate: {}".format(lr)) 51 | if not 0.0 <= eps: 52 | raise ValueError("Invalid epsilon value: {}".format(eps)) 53 | if not 0.0 <= betas[0] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 55 | if not 0.0 <= betas[1] < 1.0: 56 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 57 | defaults = dict(lr=lr, betas=betas, eps=eps, 58 | weight_decay=weight_decay) 59 | self.adam = adam 60 | super(Lamb, self).__init__(params, defaults) 61 | 62 | def step(self, closure=None): 63 | """Performs a single optimization step. 64 | Arguments: 65 | closure (callable, optional): A closure that reevaluates the model 66 | and returns the loss. 67 | """ 68 | loss = None 69 | if closure is not None: 70 | loss = closure() 71 | 72 | for group in self.param_groups: 73 | for p in group['params']: 74 | if p.grad is None: 75 | continue 76 | grad = p.grad.data 77 | if grad.is_sparse: 78 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 79 | 80 | state = self.state[p] 81 | 82 | # State initialization 83 | if len(state) == 0: 84 | state['step'] = 0 85 | # Exponential moving average of gradient values 86 | state['exp_avg'] = torch.zeros_like(p.data) 87 | # Exponential moving average of squared gradient values 88 | state['exp_avg_sq'] = torch.zeros_like(p.data) 89 | 90 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 91 | beta1, beta2 = group['betas'] 92 | 93 | state['step'] += 1 94 | 95 | # Decay the first and second moment running average coefficient 96 | # m_t 97 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 98 | # v_t 99 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 100 | 101 | # Paper v3 does not use debiasing. 102 | # bias_correction1 = 1 - beta1 ** state['step'] 103 | # bias_correction2 = 1 - beta2 ** state['step'] 104 | # Apply bias to lr to avoid broadcast. 
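                # The layer-wise trust ratio computed below, ||w|| / ||adam_step||, is what turns
                # Adam into Lamb: each parameter tensor's step is rescaled by the ratio of its
                # (clamped) weight norm to its Adam-step norm, the property the LAMB paper relies
                # on for very large batch training.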
105 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 106 | 107 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 108 | 109 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 110 | if group['weight_decay'] != 0: 111 | adam_step.add_(group['weight_decay'], p.data) 112 | 113 | adam_norm = adam_step.pow(2).sum().sqrt() 114 | if weight_norm == 0 or adam_norm == 0: 115 | trust_ratio = 1 116 | else: 117 | trust_ratio = weight_norm / adam_norm 118 | state['weight_norm'] = weight_norm 119 | state['adam_norm'] = adam_norm 120 | state['trust_ratio'] = trust_ratio 121 | if self.adam: 122 | trust_ratio = 1 123 | 124 | p.data.add_(-step_size * trust_ratio, adam_step) 125 | 126 | return loss -------------------------------------------------------------------------------- /torchtools/optim/lookahead.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch 3 | # Original paper: https://arxiv.org/abs/1907.08610 4 | #### 5 | # Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py 6 | 7 | """ Lookahead Optimizer Wrapper. 8 | Implementation modified from: https://github.com/alphadl/lookahead.pytorch 9 | Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610 10 | """ 11 | import torch 12 | from torch.optim.optimizer import Optimizer 13 | from collections import defaultdict 14 | 15 | 16 | class Lookahead(Optimizer): 17 | def __init__(self, base_optimizer, alpha=0.5, k=6): 18 | if not 0.0 <= alpha <= 1.0: 19 | raise ValueError(f'Invalid slow update rate: {alpha}') 20 | if not 1 <= k: 21 | raise ValueError(f'Invalid lookahead steps: {k}') 22 | defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0) 23 | self.base_optimizer = base_optimizer 24 | self.param_groups = self.base_optimizer.param_groups 25 | self.defaults = base_optimizer.defaults 26 | self.defaults.update(defaults) 27 | self.state = defaultdict(dict) 28 | # manually add our defaults to the param groups 29 | for name, default in defaults.items(): 30 | for group in self.param_groups: 31 | group.setdefault(name, default) 32 | 33 | def update_slow(self, group): 34 | for fast_p in group["params"]: 35 | if fast_p.grad is None: 36 | continue 37 | param_state = self.state[fast_p] 38 | if 'slow_buffer' not in param_state: 39 | param_state['slow_buffer'] = torch.empty_like(fast_p.data) 40 | param_state['slow_buffer'].copy_(fast_p.data) 41 | slow = param_state['slow_buffer'] 42 | slow.add_(group['lookahead_alpha'], fast_p.data - slow) 43 | fast_p.data.copy_(slow) 44 | 45 | def sync_lookahead(self): 46 | for group in self.param_groups: 47 | self.update_slow(group) 48 | 49 | def step(self, closure=None): 50 | # print(self.k) 51 | # assert id(self.param_groups) == id(self.base_optimizer.param_groups) 52 | loss = self.base_optimizer.step(closure) 53 | for group in self.param_groups: 54 | group['lookahead_step'] += 1 55 | if group['lookahead_step'] % group['lookahead_k'] == 0: 56 | self.update_slow(group) 57 | return loss 58 | 59 | def state_dict(self): 60 | fast_state_dict = self.base_optimizer.state_dict() 61 | slow_state = { 62 | (id(k) if isinstance(k, torch.Tensor) else k): v 63 | for k, v in self.state.items() 64 | } 65 | fast_state = fast_state_dict['state'] 66 | param_groups = fast_state_dict['param_groups'] 67 | return { 68 | 'state': fast_state, 69 | 'slow_state': slow_state, 70 | 'param_groups': 
param_groups, 71 | } 72 | 73 | def load_state_dict(self, state_dict): 74 | fast_state_dict = { 75 | 'state': state_dict['state'], 76 | 'param_groups': state_dict['param_groups'], 77 | } 78 | self.base_optimizer.load_state_dict(fast_state_dict) 79 | 80 | # We want to restore the slow state, but share param_groups reference 81 | # with base_optimizer. This is a bit redundant but least code 82 | slow_state_new = False 83 | if 'slow_state' not in state_dict: 84 | print('Loading state_dict from optimizer without Lookahead applied.') 85 | state_dict['slow_state'] = defaultdict(dict) 86 | slow_state_new = True 87 | slow_state_dict = { 88 | 'state': state_dict['slow_state'], 89 | 'param_groups': state_dict['param_groups'], # this is pointless but saves code 90 | } 91 | super(Lookahead, self).load_state_dict(slow_state_dict) 92 | self.param_groups = self.base_optimizer.param_groups # make both ref same container 93 | if slow_state_new: 94 | # reapply defaults to catch missing lookahead specific ones 95 | for name, default in self.defaults.items(): 96 | for group in self.param_groups: 97 | group.setdefault(name, default) 98 | 99 | 100 | def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs): 101 | adam = Adam(params, *args, **kwargs) 102 | return Lookahead(adam, alpha, k) 103 | -------------------------------------------------------------------------------- /torchtools/optim/novograd.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # http://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | import torch 20 | from torch.optim import Optimizer 21 | import math 22 | 23 | 24 | class AdamW(Optimizer): 25 | """Implements AdamW algorithm. 26 | 27 | It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
28 | 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 39 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 40 | 41 | Adam: A Method for Stochastic Optimization: 42 | https://arxiv.org/abs/1412.6980 43 | On the Convergence of Adam and Beyond: 44 | https://openreview.net/forum?id=ryQu7f-RZ 45 | """ 46 | 47 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, 48 | weight_decay=0, amsgrad=False): 49 | if not 0.0 <= lr: 50 | raise ValueError("Invalid learning rate: {}".format(lr)) 51 | if not 0.0 <= eps: 52 | raise ValueError("Invalid epsilon value: {}".format(eps)) 53 | if not 0.0 <= betas[0] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 55 | if not 0.0 <= betas[1] < 1.0: 56 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 57 | defaults = dict(lr=lr, betas=betas, eps=eps, 58 | weight_decay=weight_decay, amsgrad=amsgrad) 59 | super(AdamW, self).__init__(params, defaults) 60 | 61 | def __setstate__(self, state): 62 | super(AdamW, self).__setstate__(state) 63 | for group in self.param_groups: 64 | group.setdefault('amsgrad', False) 65 | 66 | def step(self, closure=None): 67 | """Performs a single optimization step. 68 | 69 | Arguments: 70 | closure (callable, optional): A closure that reevaluates the model 71 | and returns the loss. 72 | """ 73 | loss = None 74 | if closure is not None: 75 | loss = closure() 76 | 77 | for group in self.param_groups: 78 | for p in group['params']: 79 | if p.grad is None: 80 | continue 81 | grad = p.grad.data 82 | if grad.is_sparse: 83 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 84 | amsgrad = group['amsgrad'] 85 | 86 | state = self.state[p] 87 | 88 | # State initialization 89 | if len(state) == 0: 90 | state['step'] = 0 91 | # Exponential moving average of gradient values 92 | state['exp_avg'] = torch.zeros_like(p.data) 93 | # Exponential moving average of squared gradient values 94 | state['exp_avg_sq'] = torch.zeros_like(p.data) 95 | if amsgrad: 96 | # Maintains max of all exp. moving avg. of sq. grad. values 97 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 98 | 99 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 100 | if amsgrad: 101 | max_exp_avg_sq = state['max_exp_avg_sq'] 102 | beta1, beta2 = group['betas'] 103 | 104 | state['step'] += 1 105 | # Decay the first and second moment running average coefficient 106 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 107 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 108 | if amsgrad: 109 | # Maintains the maximum of all 2nd moment running avg. till now 110 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 111 | # Use the max. for normalizing running avg. 
of gradient 112 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 113 | else: 114 | denom = exp_avg_sq.sqrt().add_(group['eps']) 115 | 116 | bias_correction1 = 1 - beta1 ** state['step'] 117 | bias_correction2 = 1 - beta2 ** state['step'] 118 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 119 | p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom)) 120 | 121 | return loss 122 | 123 | 124 | class Novograd(Optimizer): 125 | """ 126 | Implements Novograd algorithm. 127 | 128 | Args: 129 | params (iterable): iterable of parameters to optimize or dicts defining 130 | parameter groups 131 | lr (float, optional): learning rate (default: 1e-3) 132 | betas (Tuple[float, float], optional): coefficients used for computing 133 | running averages of gradient and its square (default: (0.95, 0)) 134 | eps (float, optional): term added to the denominator to improve 135 | numerical stability (default: 1e-8) 136 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 137 | grad_averaging: gradient averaging 138 | amsgrad (boolean, optional): whether to use the AMSGrad variant of this 139 | algorithm from the paper `On the Convergence of Adam and Beyond`_ 140 | (default: False) 141 | """ 142 | 143 | def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8, 144 | weight_decay=0, grad_averaging=False, amsgrad=False): 145 | if not 0.0 <= lr: 146 | raise ValueError("Invalid learning rate: {}".format(lr)) 147 | if not 0.0 <= eps: 148 | raise ValueError("Invalid epsilon value: {}".format(eps)) 149 | if not 0.0 <= betas[0] < 1.0: 150 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 151 | if not 0.0 <= betas[1] < 1.0: 152 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 153 | defaults = dict(lr=lr, betas=betas, eps=eps, 154 | weight_decay=weight_decay, 155 | grad_averaging=grad_averaging, 156 | amsgrad=amsgrad) 157 | 158 | super(Novograd, self).__init__(params, defaults) 159 | 160 | def __setstate__(self, state): 161 | super(Novograd, self).__setstate__(state) 162 | for group in self.param_groups: 163 | group.setdefault('amsgrad', False) 164 | 165 | def step(self, closure=None): 166 | """Performs a single optimization step. 167 | 168 | Arguments: 169 | closure (callable, optional): A closure that reevaluates the model 170 | and returns the loss. 171 | """ 172 | loss = None 173 | if closure is not None: 174 | loss = closure() 175 | 176 | for group in self.param_groups: 177 | for p in group['params']: 178 | if p.grad is None: 179 | continue 180 | grad = p.grad.data 181 | if grad.is_sparse: 182 | raise RuntimeError('Sparse gradients are not supported.') 183 | amsgrad = group['amsgrad'] 184 | 185 | state = self.state[p] 186 | 187 | # State initialization 188 | if len(state) == 0: 189 | state['step'] = 0 190 | # Exponential moving average of gradient values 191 | state['exp_avg'] = torch.zeros_like(p.data) 192 | # Exponential moving average of squared gradient values 193 | state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 194 | if amsgrad: 195 | # Maintains max of all exp. moving avg. of sq. grad. 
values 196 | state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device) 197 | 198 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 199 | if amsgrad: 200 | max_exp_avg_sq = state['max_exp_avg_sq'] 201 | beta1, beta2 = group['betas'] 202 | 203 | state['step'] += 1 204 | 205 | norm = torch.sum(torch.pow(grad, 2)) 206 | 207 | if exp_avg_sq == 0: 208 | exp_avg_sq.copy_(norm) 209 | else: 210 | exp_avg_sq.mul_(beta2).add_(1 - beta2, norm) 211 | 212 | if amsgrad: 213 | # Maintains the maximum of all 2nd moment running avg. till now 214 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 215 | # Use the max. for normalizing running avg. of gradient 216 | denom = max_exp_avg_sq.sqrt().add_(group['eps']) 217 | else: 218 | denom = exp_avg_sq.sqrt().add_(group['eps']) 219 | 220 | grad.div_(denom) 221 | if group['weight_decay'] != 0: 222 | grad.add_(group['weight_decay'], p.data) 223 | if group['grad_averaging']: 224 | grad.mul_(1 - beta1) 225 | exp_avg.mul_(beta1).add_(grad) 226 | 227 | p.data.add_(-group['lr'], exp_avg) 228 | 229 | return loss 230 | -------------------------------------------------------------------------------- /torchtools/optim/over9000.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import torch, math 6 | from torch.optim.optimizer import Optimizer 7 | import itertools as it 8 | from .lookahead import Lookahead 9 | from .ralamb import Ralamb 10 | 11 | 12 | # RAdam + LARS + LookAHead 13 | 14 | # Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py 15 | # RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20 16 | 17 | def Over9000(params, alpha=0.5, k=6, *args, **kwargs): 18 | ralamb = Ralamb(params, *args, **kwargs) 19 | return Lookahead(ralamb, alpha, k) 20 | 21 | 22 | RangerLars = Over9000 23 | -------------------------------------------------------------------------------- /torchtools/optim/radam.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam 3 | # Paper: https://arxiv.org/abs/1908.03265 4 | #### 5 | 6 | import math 7 | import torch 8 | from torch.optim.optimizer import Optimizer, required 9 | 10 | class RAdam(Optimizer): 11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 12 | if not 0.0 <= lr: 13 | raise ValueError("Invalid learning rate: {}".format(lr)) 14 | if not 0.0 <= eps: 15 | raise ValueError("Invalid epsilon value: {}".format(eps)) 16 | if not 0.0 <= betas[0] < 1.0: 17 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 18 | if not 0.0 <= betas[1] < 1.0: 19 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 20 | 21 | self.degenerated_to_sgd = degenerated_to_sgd 22 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 23 | for param in params: 24 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 25 | param['buffer'] = [[None, None, None] for _ in range(10)] 26 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 27 | super(RAdam, self).__init__(params, defaults) 28 | 29 | def __setstate__(self, state): 30 | super(RAdam, self).__setstate__(state) 31 | 32 | 
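    # step() below follows the RAdam recipe: N_sma approximates the length of the simple moving
    # average behind the adaptive term; once N_sma >= 5 the variance-rectified adaptive step is
    # applied, otherwise the update falls back to a plain momentum step (only when
    # degenerated_to_sgd=True, else that parameter update is skipped).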
def step(self, closure=None): 33 | 34 | loss = None 35 | if closure is not None: 36 | loss = closure() 37 | 38 | for group in self.param_groups: 39 | 40 | for p in group['params']: 41 | if p.grad is None: 42 | continue 43 | grad = p.grad.data.float() 44 | if grad.is_sparse: 45 | raise RuntimeError('RAdam does not support sparse gradients') 46 | 47 | p_data_fp32 = p.data.float() 48 | 49 | state = self.state[p] 50 | 51 | if len(state) == 0: 52 | state['step'] = 0 53 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 54 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 55 | else: 56 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 57 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 58 | 59 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 60 | beta1, beta2 = group['betas'] 61 | 62 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 63 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 64 | 65 | state['step'] += 1 66 | buffered = group['buffer'][int(state['step'] % 10)] 67 | if state['step'] == buffered[0]: 68 | N_sma, step_size = buffered[1], buffered[2] 69 | else: 70 | buffered[0] = state['step'] 71 | beta2_t = beta2 ** state['step'] 72 | N_sma_max = 2 / (1 - beta2) - 1 73 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 74 | buffered[1] = N_sma 75 | 76 | # more conservative since it's an approximated value 77 | if N_sma >= 5: 78 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 79 | elif self.degenerated_to_sgd: 80 | step_size = 1.0 / (1 - beta1 ** state['step']) 81 | else: 82 | step_size = -1 83 | buffered[2] = step_size 84 | 85 | # more conservative since it's an approximated value 86 | if N_sma >= 5: 87 | if group['weight_decay'] != 0: 88 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 89 | denom = exp_avg_sq.sqrt().add_(group['eps']) 90 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 91 | p.data.copy_(p_data_fp32) 92 | elif step_size > 0: 93 | if group['weight_decay'] != 0: 94 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 95 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 96 | p.data.copy_(p_data_fp32) 97 | 98 | return loss 99 | 100 | class PlainRAdam(Optimizer): 101 | 102 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 103 | if not 0.0 <= lr: 104 | raise ValueError("Invalid learning rate: {}".format(lr)) 105 | if not 0.0 <= eps: 106 | raise ValueError("Invalid epsilon value: {}".format(eps)) 107 | if not 0.0 <= betas[0] < 1.0: 108 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 109 | if not 0.0 <= betas[1] < 1.0: 110 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 111 | 112 | self.degenerated_to_sgd = degenerated_to_sgd 113 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 114 | 115 | super(PlainRAdam, self).__init__(params, defaults) 116 | 117 | def __setstate__(self, state): 118 | super(PlainRAdam, self).__setstate__(state) 119 | 120 | def step(self, closure=None): 121 | 122 | loss = None 123 | if closure is not None: 124 | loss = closure() 125 | 126 | for group in self.param_groups: 127 | 128 | for p in group['params']: 129 | if p.grad is None: 130 | continue 131 | grad = p.grad.data.float() 132 | if grad.is_sparse: 133 | raise RuntimeError('RAdam does not support sparse gradients') 134 | 135 | p_data_fp32 = 
p.data.float() 136 | 137 | state = self.state[p] 138 | 139 | if len(state) == 0: 140 | state['step'] = 0 141 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 142 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 143 | else: 144 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 145 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 146 | 147 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 148 | beta1, beta2 = group['betas'] 149 | 150 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 151 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 152 | 153 | state['step'] += 1 154 | beta2_t = beta2 ** state['step'] 155 | N_sma_max = 2 / (1 - beta2) - 1 156 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 157 | 158 | 159 | # more conservative since it's an approximated value 160 | if N_sma >= 5: 161 | if group['weight_decay'] != 0: 162 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 163 | step_size = group['lr'] * math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 164 | denom = exp_avg_sq.sqrt().add_(group['eps']) 165 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 166 | p.data.copy_(p_data_fp32) 167 | elif self.degenerated_to_sgd: 168 | if group['weight_decay'] != 0: 169 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 170 | step_size = group['lr'] / (1 - beta1 ** state['step']) 171 | p_data_fp32.add_(-step_size, exp_avg) 172 | p.data.copy_(p_data_fp32) 173 | 174 | return loss 175 | 176 | 177 | class AdamW(Optimizer): 178 | 179 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup = 0): 180 | if not 0.0 <= lr: 181 | raise ValueError("Invalid learning rate: {}".format(lr)) 182 | if not 0.0 <= eps: 183 | raise ValueError("Invalid epsilon value: {}".format(eps)) 184 | if not 0.0 <= betas[0] < 1.0: 185 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 186 | if not 0.0 <= betas[1] < 1.0: 187 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 188 | 189 | defaults = dict(lr=lr, betas=betas, eps=eps, 190 | weight_decay=weight_decay, warmup = warmup) 191 | super(AdamW, self).__init__(params, defaults) 192 | 193 | def __setstate__(self, state): 194 | super(AdamW, self).__setstate__(state) 195 | 196 | def step(self, closure=None): 197 | loss = None 198 | if closure is not None: 199 | loss = closure() 200 | 201 | for group in self.param_groups: 202 | 203 | for p in group['params']: 204 | if p.grad is None: 205 | continue 206 | grad = p.grad.data.float() 207 | if grad.is_sparse: 208 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 209 | 210 | p_data_fp32 = p.data.float() 211 | 212 | state = self.state[p] 213 | 214 | if len(state) == 0: 215 | state['step'] = 0 216 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 217 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 218 | else: 219 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 220 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 221 | 222 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 223 | beta1, beta2 = group['betas'] 224 | 225 | state['step'] += 1 226 | 227 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 228 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 229 | 230 | denom = exp_avg_sq.sqrt().add_(group['eps']) 231 | bias_correction1 = 1 - beta1 ** state['step'] 232 | 
bias_correction2 = 1 - beta2 ** state['step'] 233 | 234 | if group['warmup'] > state['step']: 235 | scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup'] 236 | else: 237 | scheduled_lr = group['lr'] 238 | 239 | step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1 240 | 241 | if group['weight_decay'] != 0: 242 | p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32) 243 | 244 | p_data_fp32.addcdiv_(-step_size, exp_avg, denom) 245 | 246 | p.data.copy_(p_data_fp32) 247 | 248 | return loss -------------------------------------------------------------------------------- /torchtools/optim/ralamb.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/mgrankin/over9000 3 | #### 4 | 5 | import torch, math 6 | from torch.optim.optimizer import Optimizer 7 | 8 | # RAdam + LARS 9 | class Ralamb(Optimizer): 10 | 11 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): 12 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 13 | self.buffer = [[None, None, None] for ind in range(10)] 14 | super(Ralamb, self).__init__(params, defaults) 15 | 16 | def __setstate__(self, state): 17 | super(Ralamb, self).__setstate__(state) 18 | 19 | def step(self, closure=None): 20 | 21 | loss = None 22 | if closure is not None: 23 | loss = closure() 24 | 25 | for group in self.param_groups: 26 | 27 | for p in group['params']: 28 | if p.grad is None: 29 | continue 30 | grad = p.grad.data.float() 31 | if grad.is_sparse: 32 | raise RuntimeError('Ralamb does not support sparse gradients') 33 | 34 | p_data_fp32 = p.data.float() 35 | 36 | state = self.state[p] 37 | 38 | if len(state) == 0: 39 | state['step'] = 0 40 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 41 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 42 | else: 43 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 44 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 45 | 46 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 47 | beta1, beta2 = group['betas'] 48 | 49 | # Decay the first and second moment running average coefficient 50 | # m_t 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | # v_t 53 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 54 | 55 | state['step'] += 1 56 | buffered = self.buffer[int(state['step'] % 10)] 57 | 58 | if state['step'] == buffered[0]: 59 | N_sma, radam_step_size = buffered[1], buffered[2] 60 | else: 61 | buffered[0] = state['step'] 62 | beta2_t = beta2 ** state['step'] 63 | N_sma_max = 2 / (1 - beta2) - 1 64 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 65 | buffered[1] = N_sma 66 | 67 | # more conservative since it's an approximated value 68 | if N_sma >= 5: 69 | radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 70 | else: 71 | radam_step_size = 1.0 / (1 - beta1 ** state['step']) 72 | buffered[2] = radam_step_size 73 | 74 | if group['weight_decay'] != 0: 75 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 76 | 77 | # more conservative since it's an approximated value 78 | radam_step = p_data_fp32.clone() 79 | if N_sma >= 5: 80 | denom = exp_avg_sq.sqrt().add_(group['eps']) 81 | radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom) 82 | else: 83 | radam_step.add_(-radam_step_size * group['lr'], exp_avg) 84 | 85 | radam_norm = 
radam_step.pow(2).sum().sqrt() 86 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 87 | if weight_norm == 0 or radam_norm == 0: 88 | trust_ratio = 1 89 | else: 90 | trust_ratio = weight_norm / radam_norm 91 | 92 | state['weight_norm'] = weight_norm 93 | state['adam_norm'] = radam_norm 94 | state['trust_ratio'] = trust_ratio 95 | 96 | if N_sma >= 5: 97 | p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom) 98 | else: 99 | p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg) 100 | 101 | p.data.copy_(p_data_fp32) 102 | 103 | return loss -------------------------------------------------------------------------------- /torchtools/optim/ranger.py: -------------------------------------------------------------------------------- 1 | #### 2 | # CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 3 | # Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d 4 | #### 5 | 6 | import math 7 | import torch 8 | from torch.optim.optimizer import Optimizer, required 9 | import itertools as it 10 | from .lookahead import Lookahead 11 | from .radam import RAdam 12 | 13 | 14 | def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs): 15 | radam = RAdam(params, betas=betas, *args, **kwargs) 16 | return Lookahead(radam, alpha, k) 17 | -------------------------------------------------------------------------------- /torchtools/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .smart_crop import SmartCrop -------------------------------------------------------------------------------- /torchtools/transforms/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pabloppp/pytorch-tools/6472bd5e223142bfbe105e2006432252f1ce3709/torchtools/transforms/models/__init__.py -------------------------------------------------------------------------------- /torchtools/transforms/models/saliency_model_v9.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pabloppp/pytorch-tools/6472bd5e223142bfbe105e2006432252f1ce3709/torchtools/transforms/models/saliency_model_v9.pt -------------------------------------------------------------------------------- /torchtools/transforms/smart_crop.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch import nn 4 | import numpy as np 5 | import os 6 | 7 | # MICRO RESNET 8 | class ResBlock(nn.Module): 9 | def __init__(self, channels): 10 | super(ResBlock, self).__init__() 11 | 12 | self.resblock = nn.Sequential( 13 | nn.ReflectionPad2d(1), 14 | nn.Conv2d(channels, channels, kernel_size=3), 15 | nn.InstanceNorm2d(channels, affine=True), 16 | nn.ReLU(), 17 | nn.ReflectionPad2d(1), 18 | nn.Conv2d(channels, channels, kernel_size=3), 19 | nn.InstanceNorm2d(channels, affine=True), 20 | ) 21 | 22 | def forward(self, x): 23 | out = self.resblock(x) 24 | return out + x 25 | 26 | class Upsample2d(nn.Module): 27 | def __init__(self, scale_factor): 28 | super(Upsample2d, self).__init__() 29 | 30 | self.interp = nn.functional.interpolate 31 | self.scale_factor = scale_factor 32 | 33 | def forward(self, x): 34 | x = self.interp(x, scale_factor=self.scale_factor, mode='nearest') 35 | return x 36 | 37 | class MicroResNet(nn.Module): 38 | def 
__init__(self): 39 | super(MicroResNet, self).__init__() 40 | 41 | self.downsampler = nn.Sequential( 42 | nn.ReflectionPad2d(4), 43 | nn.Conv2d(3, 8, kernel_size=9, stride=4), 44 | nn.InstanceNorm2d(8, affine=True), 45 | nn.ReLU(), 46 | nn.ReflectionPad2d(1), 47 | nn.Conv2d(8, 16, kernel_size=3, stride=2), 48 | nn.InstanceNorm2d(16, affine=True), 49 | nn.ReLU(), 50 | nn.ReflectionPad2d(1), 51 | nn.Conv2d(16, 32, kernel_size=3, stride=2), 52 | nn.InstanceNorm2d(32, affine=True), 53 | nn.ReLU(), 54 | ) 55 | 56 | self.residual = nn.Sequential( 57 | ResBlock(32), 58 | nn.Conv2d(32, 64, kernel_size=1, bias=False, groups=32), 59 | ResBlock(64), 60 | ) 61 | 62 | self.segmentator = nn.Sequential( 63 | nn.ReflectionPad2d(1), 64 | nn.Conv2d(64, 16, kernel_size=3), 65 | nn.InstanceNorm2d(16, affine=True), 66 | nn.ReLU(), 67 | Upsample2d(scale_factor=2), 68 | nn.ReflectionPad2d(4), 69 | nn.Conv2d(16, 1, kernel_size=9), 70 | nn.Sigmoid() 71 | ) 72 | 73 | def forward(self, x): 74 | out = self.downsampler(x) 75 | out = self.residual(out) 76 | out = self.segmentator(out) 77 | return out 78 | 79 | # SmartCrop module 80 | class SmartCrop(nn.Module): 81 | def __init__(self, output_size, randomize_p=0.0, randomize_q=0.1, temperature=0.03): 82 | super().__init__() 83 | self.output_size = output_size 84 | self.randomize_p, self.randomize_q = randomize_p, randomize_q 85 | self.temperature = temperature 86 | if isinstance(self.output_size, int): 87 | self.output_size = (self.output_size, self.output_size) 88 | self.saliency_model = MicroResNet().eval().requires_grad_(False) 89 | checkpoint = torch.load(os.path.dirname(__file__) + "/models/saliency_model_v9.pt", map_location="cpu") 90 | self.saliency_model.load_state_dict(checkpoint) 91 | 92 | def forward(self, image): 93 | is_batch = len(image.shape) == 4 94 | if not is_batch: 95 | image = image.unsqueeze(0) 96 | with torch.no_grad(): 97 | resized_image = torchvision.transforms.functional.resize(image, 240, antialias=True) 98 | saliency_map = self.saliency_model(resized_image) 99 | tempered_heatmap = saliency_map.view(saliency_map.size(0), -1).div(self.temperature).softmax(-1) 100 | tempered_heatmap = tempered_heatmap / tempered_heatmap.sum(dim=1, keepdim=True)  # keepdim so the normalization broadcasts per sample 101 | tempered_heatmap = (tempered_heatmap > tempered_heatmap.max(dim=-1, keepdim=True)[0]*0.75).float()  # keepdim so the threshold broadcasts for batched inputs 102 | saliency_map = tempered_heatmap.view(*saliency_map.shape) 103 | 104 | # GET CENTROID 105 | coord_space = torch.cat([ 106 | torch.linspace(0, 1, saliency_map.size(-2))[None, None, :, None].expand(-1, -1, -1, saliency_map.size(-1)), 107 | torch.linspace(0, 1, saliency_map.size(-1))[None, None, None, :].expand(-1, -1, saliency_map.size(-2), -1), 108 | ], dim=1) 109 | centroid = (coord_space * saliency_map).sum(dim=[-1, -2]) / saliency_map.sum(dim=[-1, -2]) 110 | # CROP 111 | crops = [] 112 | for i in range(image.size(0)): 113 | if np.random.rand() < self.randomize_p: 114 | centroid[i, 0] += np.random.uniform(-self.randomize_q, self.randomize_q) 115 | centroid[i, 1] += np.random.uniform(-self.randomize_q, self.randomize_q) 116 | top = (centroid[i, 0]*image.size(-2)-self.output_size[-2]/2).clamp(min=0, max=max(0, image.size(-2)-self.output_size[-2])).int() 117 | left = (centroid[i, 1]*image.size(-1)-self.output_size[-1]/2).clamp(min=0, max=max(0, image.size(-1)-self.output_size[-1])).int() 118 | bottom, right = top + self.output_size[-2], left + self.output_size[-1] 119 | crop = image[i, :, top:bottom, left:right] 120 | if crop.size(-2) < self.output_size[-2] or crop.size(-1) < self.output_size[-1]: 121 | crop = 
torchvision.transforms.functional.center_crop(crop, self.output_size) 122 | crops.append(crop) 123 | if is_batch: 124 | crops = torch.stack(crops, dim=0) 125 | else: 126 | crops = crops[0] 127 | return crops -------------------------------------------------------------------------------- /torchtools/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion import Diffuzz 2 | from .diffusion2 import Diffuzz2 3 | from .gamma_parametrization import apply_gamma_reparam, gamma_reparam_model, remove_gamma_reparam 4 | from .weight_normalization import apply_weight_norm, weight_norm_model, remove_weight_norm -------------------------------------------------------------------------------- /torchtools/utils/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # Samplers -------------------------------------------------------------------- 4 | class SimpleSampler(): 5 | def __init__(self, diffuzz): 6 | self.current_step = -1 7 | self.diffuzz = diffuzz 8 | 9 | def __call__(self, *args, **kwargs): 10 | self.current_step += 1 11 | return self.step(*args, **kwargs) 12 | 13 | def init_x(self, shape): 14 | return torch.randn(*shape, device=self.diffuzz.device) 15 | 16 | def step(self, x, t, t_prev, noise): 17 | raise NotImplementedError("You should override the 'step' function.") 18 | 19 | class DDPMSampler(SimpleSampler): 20 | def step(self, x, t, t_prev, noise): 21 | alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 22 | alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) 23 | alpha = (alpha_cumprod / alpha_cumprod_prev) 24 | 25 | mu = (1.0 / alpha).sqrt() * (x - (1-alpha) * noise / (1-alpha_cumprod).sqrt()) 26 | std = ((1-alpha) * (1. - alpha_cumprod_prev) / (1. 
- alpha_cumprod)).sqrt() * torch.randn_like(mu) 27 | return mu + std * (t_prev != 0).float().view(t_prev.size(0), *[1 for _ in x.shape[1:]]) 28 | 29 | class DDIMSampler(SimpleSampler): 30 | def step(self, x, t, t_prev, noise): 31 | alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 32 | alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) 33 | 34 | x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / (alpha_cumprod).sqrt() 35 | dp_xt = (1 - alpha_cumprod_prev).sqrt() 36 | return (alpha_cumprod_prev).sqrt() * x0 + dp_xt * noise 37 | 38 | class DPMSolverPlusPlusSampler(SimpleSampler): # FIXME: CURRENTLY NOT WORKING 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | self.q_ts = {} 42 | 43 | def _get_coef(self, alpha_cumprod): 44 | log_alpha_t = alpha_cumprod.log() 45 | alpha_t = log_alpha_t.exp() 46 | sigma_t = (1-alpha_t ** 2).sqrt() 47 | lambda_t = log_alpha_t - sigma_t.log() 48 | return alpha_t, sigma_t, lambda_t 49 | 50 | def init_x(self, shape): 51 | alpha_cumprod = self.diffuzz._alpha_cumprod(torch.ones(shape[0], device=self.diffuzz.device)).view(-1, *[1 for _ in shape[1:]]) 52 | return torch.randn(*shape, device=self.diffuzz.device) * self._get_coef(alpha_cumprod)[1] 53 | 54 | def step(self, x, t, t_prev, noise): 55 | alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 56 | stride = (t_prev - t) 57 | if self.current_step == 0: 58 | alpha_t, sigma_t, _ = self._get_coef(alpha_cumprod) 59 | elif self.current_step == 1: 60 | alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]]) 61 | alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod) 62 | _, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next) 63 | h = lambda_t - lambda_t_next 64 | x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * self.q_ts[self.current_step-1] 65 | else: 66 | alpha_cumprod_next = self.diffuzz._alpha_cumprod(t+stride).view(t.size(0), *[1 for _ in x.shape[1:]]) 67 | alpha_cumprod_next_next = self.diffuzz._alpha_cumprod(t+stride*2).view(t.size(0), *[1 for _ in x.shape[1:]]) 68 | 69 | alpha_t, sigma_t, lambda_t = self._get_coef(alpha_cumprod) 70 | _, sigma_t_next, lambda_t_next = self._get_coef(alpha_cumprod_next) 71 | _, _, lambda_t_next_next = self._get_coef(alpha_cumprod_next_next) 72 | 73 | h = lambda_t - lambda_t_next 74 | h_next = lambda_t_next - lambda_t_next_next 75 | 76 | r = h_next / h 77 | D = (1 + 1 / (2 * r)) * self.q_ts[self.current_step-1] - 1 / (2 * r) * self.q_ts[self.current_step-2] 78 | x = sigma_t / sigma_t_next * x - alpha_t * torch.expm1(-h) * D 79 | self.q_ts[self.current_step] = (x - sigma_t * noise) / alpha_t 80 | return x 81 | 82 | sampler_dict = { 83 | 'ddpm': DDPMSampler, 84 | 'ddim': DDIMSampler, 85 | 'dpmsolver++': DPMSolverPlusPlusSampler, 86 | } 87 | 88 | # Custom simplified foward/backward diffusion (cosine schedule) 89 | class Diffuzz(): 90 | def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(0.0001, 0.9999)): 91 | self.device = device 92 | self.s = torch.tensor([s]).to(device) 93 | self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 94 | self.scaler = scaler 95 | self.cached_steps = None 96 | self.clamp_range = clamp_range 97 | if cache_steps is not None: 98 | self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) 99 | 100 | def _alpha_cumprod(self, t): 101 | if 
self.cached_steps is None: 102 | if self.scaler > 1: 103 | t = 1 - (1-t) ** self.scaler 104 | elif self.scaler < 1: 105 | t = t ** self.scaler 106 | alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod 107 | return alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) 108 | else: 109 | return self.cached_steps[t.mul(len(self.cached_steps)-1).long()] 110 | 111 | def diffuse(self, x, t, noise=None): # t -> [0, 1] 112 | if noise is None: 113 | noise = torch.randn_like(x) 114 | alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 115 | return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise 116 | 117 | def undiffuse(self, x, t, t_prev, noise, sampler=None): 118 | if sampler is None: 119 | sampler = DDPMSampler(self) 120 | return sampler(x, t, t_prev, noise) 121 | 122 | def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, unconditional_inputs=None, sampler='ddpm', half=False): 123 | r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) 124 | if isinstance(sampler, str): 125 | if sampler in sampler_dict: 126 | sampler = sampler_dict[sampler](self) 127 | else: 128 | raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") 129 | elif issubclass(sampler, SimpleSampler): 130 | sampler = sampler(self) 131 | else: 132 | raise ValueError("Sampler should be either a string or a SimpleSampler object.") 133 | preds = [] 134 | x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() 135 | if half: 136 | r_range = r_range.half() 137 | x = x.half() 138 | for i in range(0, timesteps): 139 | if mask is not None and x_init is not None: 140 | x_renoised, _ = self.diffuse(x_init, r_range[i]) 141 | x = x * mask + x_renoised * (1-mask) 142 | pred_noise = model(x, r_range[i], **model_inputs) 143 | if cfg is not None: 144 | if unconditional_inputs is None: 145 | unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} 146 | pred_noise_unconditional = model(x, r_range[i], **unconditional_inputs) 147 | pred_noise = torch.lerp(pred_noise_unconditional, pred_noise, cfg) 148 | x = self.undiffuse(x, r_range[i], r_range[i+1], pred_noise, sampler=sampler) 149 | preds.append(x) 150 | return preds 151 | 152 | def p2_weight(self, t, k=1.0, gamma=1.0): 153 | alpha_cumprod = self._alpha_cumprod(t) 154 | return (k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma -------------------------------------------------------------------------------- /torchtools/utils/diffusion2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | # Samplers -------------------------------------------------------------------- 5 | class SimpleSampler(): 6 | def __init__(self, diffuzz, mode="v"): 7 | self.current_step = -1 8 | self.diffuzz = diffuzz 9 | if mode not in ['v', 'e', 'x']: 10 | raise Exception("Mode should be either 'v', 'e' or 'x'") 11 | self.mode = mode 12 | 13 | def __call__(self, *args, **kwargs): 14 | self.current_step += 1 15 | return self.step(*args, **kwargs) 16 | 17 | def init_x(self, shape): 18 | return torch.randn(*shape, device=self.diffuzz.device) 19 | 20 | def step(self, x, t, t_prev, noise): 21 | raise NotImplementedError("You should override the 'step' function.") 22 | # 
https://github.com/ozanciga/diffusion-for-beginners/blob/main/samplers/ddim.py 24 | class DDIMSampler(SimpleSampler): 25 | def step(self, x, t, t_prev, pred, eta=0): 26 | alpha_cumprod = self.diffuzz._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 27 | alpha_cumprod_prev = self.diffuzz._alpha_cumprod(t_prev).view(t_prev.size(0), *[1 for _ in x.shape[1:]]) 28 | 29 | sigma_tau = eta * ((1 - alpha_cumprod_prev) / (1 - alpha_cumprod)).sqrt() * (1 - alpha_cumprod / alpha_cumprod_prev).sqrt() if eta > 0 else 0 30 | if self.mode == 'v': 31 | x0 = alpha_cumprod.sqrt() * x - (1-alpha_cumprod).sqrt() * pred 32 | noise = (1-alpha_cumprod).sqrt() * x + alpha_cumprod.sqrt() * pred 33 | elif self.mode == 'x': 34 | x0 = pred 35 | noise = (x - x0 * alpha_cumprod.sqrt()) / (1 - alpha_cumprod).sqrt() 36 | else: 37 | noise = pred 38 | x0 = (x - (1 - alpha_cumprod).sqrt() * noise) / alpha_cumprod.sqrt() 39 | renoised = alpha_cumprod_prev.sqrt() * x0 + (1 - alpha_cumprod_prev - sigma_tau ** 2).sqrt() * noise + sigma_tau * torch.randn_like(x) 40 | return x0, renoised, pred 41 | 42 | class DDPMSampler(DDIMSampler): 43 | def step(self, x, t, t_prev, pred, eta=1): 44 | return super().step(x, t, t_prev, pred, eta) 45 | 46 | sampler_dict = { 47 | 'ddpm': DDPMSampler, 48 | 'ddim': DDIMSampler, 49 | } 50 | 51 | # Custom simplified foward/backward diffusion (cosine schedule) 52 | class Diffuzz2(): 53 | def __init__(self, s=0.008, device="cpu", cache_steps=None, scaler=1, clamp_range=(1e-7, 1-1e-7)): 54 | self.device = device 55 | self.s = torch.tensor([s]).to(device) 56 | self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2 57 | self.scaler = 2 * np.log(1/scaler) 58 | self.cached_steps = None 59 | self.clamp_range = clamp_range 60 | if cache_steps is not None: 61 | self.cached_steps = self._alpha_cumprod(torch.linspace(0, 1, cache_steps, device=device)) 62 | 63 | def _alpha_cumprod(self, t): 64 | if self.cached_steps is None: 65 | alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod 66 | alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) 67 | if self.scaler != 1: 68 | alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(self.scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1]) 69 | return alpha_cumprod 70 | else: 71 | return self.cached_steps[t.mul(len(self.cached_steps)-1).long()] 72 | 73 | def scale_t(self, t, scaler): 74 | scaler = 2 * np.log(1/scaler) 75 | alpha_cumprod = torch.cos((t + self.s) / (1 + self.s) * torch.pi * 0.5).clamp(0, 1) ** 2 / self._init_alpha_cumprod 76 | alpha_cumprod = alpha_cumprod.clamp(self.clamp_range[0], self.clamp_range[1]) 77 | if scaler != 1: 78 | alpha_cumprod = (alpha_cumprod/(1-alpha_cumprod)).log().add(scaler).sigmoid().clamp(self.clamp_range[0], self.clamp_range[1]) 79 | return (((alpha_cumprod * self._init_alpha_cumprod) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + self.s) - self.s 80 | 81 | def diffuse(self, x, t, noise=None): # t -> [0, 1] 82 | if noise is None: 83 | noise = torch.randn_like(x) 84 | alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 85 | return alpha_cumprod.sqrt() * x + (1-alpha_cumprod).sqrt() * noise, noise 86 | 87 | def get_v(self, x, t, noise): 88 | alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in x.shape[1:]]) 89 | # x0 = alpha_cumprod * noised − (1-alpha_cumprod).sqrt() * pred_v 90 | # noise = (1-alpha_cumprod).sqrt() * noised + alpha_cumprod * pred_v 91 | 
return alpha_cumprod.sqrt() * noise - (1-alpha_cumprod).sqrt() * x 92 | 93 | def x0_from_v(self, noised, pred_v, t): 94 | alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]]) 95 | return alpha_cumprod.sqrt() * noised - (1-alpha_cumprod).sqrt() * pred_v 96 | 97 | def noise_from_v(self, noised, pred_v, t): 98 | alpha_cumprod = self._alpha_cumprod(t).view(t.size(0), *[1 for _ in noised.shape[1:]]) 99 | return (1-alpha_cumprod).sqrt() * noised + alpha_cumprod.sqrt() * pred_v 100 | 101 | def undiffuse(self, x, t, t_prev, pred, sampler=None, **kwargs): 102 | if sampler is None: 103 | sampler = DDPMSampler(self) 104 | return sampler(x, t, t_prev, pred, **kwargs) 105 | 106 | def sample(self, model, model_inputs, shape, mask=None, t_start=1.0, t_end=0.0, timesteps=20, x_init=None, cfg=3.0, cfg_rho=0.7, unconditional_inputs=None, sampler='ddpm', dtype=None, sample_mode='v', sampler_params={}, t_scaler=1): 107 | r_range = torch.linspace(t_start, t_end, timesteps+1)[:, None].expand(-1, shape[0] if x_init is None else x_init.size(0)).to(self.device) 108 | if t_scaler != 1: 109 | r_range = self.scale_t(r_range, t_scaler) 110 | if isinstance(sampler, str): 111 | if sampler in sampler_dict: 112 | sampler = sampler_dict[sampler](self, sample_mode) 113 | else: 114 | raise ValueError(f"If sampler is a string it must be one of the supported samplers: {list(sampler_dict.keys())}") 115 | elif issubclass(sampler, SimpleSampler): 116 | sampler = sampler(self, sample_mode) 117 | else: 118 | raise ValueError("Sampler should be either a string or a SimpleSampler object.") 119 | 120 | x = sampler.init_x(shape) if x_init is None or mask is not None else x_init.clone() 121 | if dtype is not None: 122 | r_range = r_range.to(dtype) 123 | x = x.to(dtype) 124 | if cfg is not None: 125 | if unconditional_inputs is None: 126 | unconditional_inputs = {k: torch.zeros_like(v) for k, v in model_inputs.items()} 127 | model_inputs = {k:torch.cat([v, v_u]) if isinstance(v, torch.Tensor) else None for (k, v), (k_u, v_u) in zip(model_inputs.items(), unconditional_inputs.items())} 128 | for i in range(0, timesteps): 129 | if mask is not None and x_init is not None: 130 | x_renoised, _ = self.diffuse(x_init, r_range[i]) 131 | x = x * mask + x_renoised * (1-mask) 132 | if cfg is not None: 133 | pred, pred_unconditional = model(torch.cat([x] * 2), torch.cat([r_range[i]] * 2), **model_inputs).chunk(2) 134 | pred_cfg = torch.lerp(pred_unconditional, pred, cfg) 135 | if cfg_rho > 0: 136 | std_pos, std_cfg = pred.std(), pred_cfg.std() 137 | pred = cfg_rho * (pred_cfg * std_pos/(std_cfg+1e-9)) + pred_cfg * (1-cfg_rho) 138 | else: 139 | pred = pred_cfg 140 | else: 141 | pred = model(x, r_range[i], **model_inputs) 142 | 143 | diff_out = self.undiffuse(x, r_range[i], r_range[i+1], pred, sampler=sampler, **sampler_params) 144 | x = diff_out[1] 145 | altered_vars = yield diff_out 146 | 147 | # Update some running variables if the user wants 148 | if altered_vars is not None: 149 | cfg = altered_vars.get('cfg', cfg) 150 | cfg_rho = altered_vars.get('cfg_rho', cfg_rho) 151 | sampler = altered_vars.get('sampler', sampler) 152 | unconditional_inputs = altered_vars.get('unconditional_inputs', unconditional_inputs) 153 | model_inputs = altered_vars.get('model_inputs', model_inputs) 154 | x = altered_vars.get('x', x) 155 | mask = altered_vars.get('mask', mask) 156 | x_init = altered_vars.get('x_init', x_init) 157 | 158 | def p2_weight(self, t, k=1.0, gamma=1.0): 159 | alpha_cumprod = self._alpha_cumprod(t) 160 | return 
(k + alpha_cumprod / (1 - alpha_cumprod)) ** -gamma 161 | 162 | def truncated_snr_weight(self, t, min=1.0, max=None): 163 | alpha_cumprod = self._alpha_cumprod(t) 164 | snr = (alpha_cumprod / (1 - alpha_cumprod)) 165 | if min is not None or max is not None: 166 | snr = snr.clamp(min=min, max=max) 167 | return snr 168 | -------------------------------------------------------------------------------- /torchtools/utils/gamma_parametrization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class _GammaScaling(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.gamma = nn.Parameter(torch.ones(1)) 9 | 10 | def forward(self, w): 11 | return w * self.gamma 12 | 13 | def apply_gamma_reparam(module, name="weight"): # this reparametrizes the parameters of a single module 14 | nn.utils.parametrizations.spectral_norm(module, name) 15 | nn.utils.parametrize.register_parametrization(module, name, _GammaScaling()) 16 | return module 17 | 18 | def gamma_reparam_model(model): 19 | for module in model.modules(): # this reparametrizes all linear and multi-head attention layers of the model 20 | if isinstance(module, nn.Linear) and not torch.nn.utils.parametrize.is_parametrized(module, "weight"): 21 | apply_gamma_reparam(module, "weight") 22 | elif isinstance(module, nn.MultiheadAttention) and not torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 23 | apply_gamma_reparam(module, "in_proj_weight") 24 | return model 25 | 26 | def remove_gamma_reparam(model): 27 | for module in model.modules(): 28 | if torch.nn.utils.parametrize.is_parametrized(module, "weight"): 29 | nn.utils.parametrize.remove_parametrizations(module, "weight") 30 | elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 31 | nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight") 32 | -------------------------------------------------------------------------------- /torchtools/utils/weight_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class _WeightNorm(nn.Module): 5 | def __init__(self, eps=1e-4): 6 | super().__init__() 7 | self.eps = eps 8 | 9 | def _normalize(self, w): 10 | norm_dims = list(range(1, len(w.shape))) 11 | w_norm = torch.linalg.vector_norm(w, dim=norm_dims, keepdim=True) 12 | # w_norm = torch.norm_except_dim(w, 2, 0).clone() 13 | return w / (w_norm + self.eps) 14 | 15 | def forward(self, w): 16 | if self.training: 17 | with torch.no_grad(): 18 | fan_in = w[0].numel()**0.5 19 | w.data = self._normalize(w.data.clone()) * fan_in 20 | # w.copy_(self._normalize(w) * fan_in) 21 | return self._normalize(w) 22 | 23 | def apply_weight_norm(module, name="weight", init_weight=True): # this reparametrizes the parameters of a single module 24 | if init_weight: 25 | torch.nn.init.normal_(getattr(module, name)) 26 | nn.utils.parametrize.register_parametrization(module, name, _WeightNorm(), unsafe=True) 27 | return module 28 | 29 | def weight_norm_model(model, whitelist=None, init_weight=True): 30 | whitelist = whitelist or [] 31 | 32 | def check_parameter(module, name): 33 | return hasattr(module, name) and not torch.nn.utils.parametrize.is_parametrized(module, name) and isinstance(getattr(module, name), nn.Parameter) 34 | 35 | for name, module in model.named_modules(): # this reparametrizes all layers of the model that have a "weight" parameter 36 | if not any([w in name for w in whitelist]): 37 | if 
check_parameter(module, "weight"): 38 | apply_weight_norm(module, init_weight=init_weight) 39 | elif check_parameter(module, "in_proj_weight"): 40 | apply_weight_norm(module, 'in_proj_weight', init_weight=init_weight) 41 | return model 42 | 43 | def remove_weight_norm(model): 44 | for module in model.modules(): 45 | if torch.nn.utils.parametrize.is_parametrized(module, "weight"): 46 | nn.utils.parametrize.remove_parametrizations(module, "weight") 47 | elif torch.nn.utils.parametrize.is_parametrized(module, "in_proj_weight"): 48 | nn.utils.parametrize.remove_parametrizations(module, "in_proj_weight") --------------------------------------------------------------------------------
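
Example of use (a minimal sketch for the diffusion and weight-normalization utilities above). `DummyDenoiser`, the tensor shapes and the single training step are illustrative assumptions, not part of the library: any module taking `(x, t, **model_inputs)` and returning a tensor shaped like `x` can stand in for it. Everything else uses only the APIs defined in `torchtools.utils`.

```python
import torch
from torch import nn
from torchtools.utils import Diffuzz2, weight_norm_model, remove_weight_norm

# Hypothetical stand-in for a real denoising network (assumption, for illustration only):
# it concatenates a broadcast timestep map to the input and predicts a 'v' target.
class DummyDenoiser(nn.Module):
    def __init__(self, channels=3):
        super().__init__()
        self.net = nn.Conv2d(channels + 1, channels, kernel_size=3, padding=1)

    def forward(self, x, t, cond=None):
        t_map = t.view(-1, 1, 1, 1).expand(-1, 1, x.size(-2), x.size(-1))
        return self.net(torch.cat([x, t_map], dim=1))

device = "cpu"
model = weight_norm_model(DummyDenoiser()).to(device)  # weight-normalize every 'weight' parameter
diffuzz = Diffuzz2(device=device)

# Training-style step: diffuse clean images and regress the 'v' target
images = torch.randn(4, 3, 32, 32, device=device)
t = torch.rand(images.size(0), device=device)
noised, noise = diffuzz.diffuse(images, t)
loss = nn.functional.mse_loss(model(noised, t), diffuzz.get_v(images, t, noise))
loss.backward()

# Sampling: Diffuzz2.sample is a generator that yields (x0, renoised, pred) at every step
model.eval()
with torch.no_grad():
    sampling = diffuzz.sample(model, {'cond': None}, shape=(4, 3, 32, 32),
                              timesteps=20, cfg=None, sampler='ddim', sample_mode='v')
    for x0, x, pred in sampling:
        pass  # after the last step, x holds the generated samples

remove_weight_norm(model)  # bake the normalized weights back into plain parameters
```

Note that `weight_norm_model` re-initializes the reparametrized weights by default; pass `init_weight=False` to keep the existing values, or use the `whitelist` argument to skip modules whose names match the given substrings.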