├── .gitignore ├── README.md ├── examples └── generate.py ├── pytorch_pretrained_gans ├── .gitignore ├── BigBiGAN │ ├── __init__.py │ ├── model │ │ ├── BigGAN.py │ │ ├── __init__.py │ │ ├── layers.py │ │ └── sync_batchnorm │ │ │ ├── __init__.py │ │ │ ├── batchnorm.py │ │ │ ├── batchnorm_reimpl.py │ │ │ ├── comm.py │ │ │ ├── replicate.py │ │ │ └── unittest.py │ └── weights │ │ └── download.sh ├── BigGAN │ ├── __init__.py │ ├── config.py │ ├── convert_tf_to_pytorch.py │ ├── file_utils.py │ ├── model.py │ └── utils.py ├── CIPS │ ├── GeneratorsCIPS.py │ ├── __init__.py │ ├── blocks.py │ └── op │ │ ├── __init__.py │ │ ├── fused_act.py │ │ ├── fused_bias_act.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu ├── StudioGAN │ ├── __init__.py │ ├── configs │ │ ├── ImageNet │ │ │ ├── BigGAN2048 │ │ │ │ └── BigGAN2048.json │ │ │ ├── BigGAN256 │ │ │ │ └── BigGAN256.json │ │ │ ├── ContraGAN2048 │ │ │ │ └── ContraGAN2048.json │ │ │ ├── ContraGAN256 │ │ │ │ └── ContraGAN256.json │ │ │ ├── SAGAN │ │ │ │ └── SAGAN.json │ │ │ ├── SNGAN │ │ │ │ └── SNGAN.json │ │ │ └── download.sh │ │ └── TinyImageNet │ │ │ ├── ACGAN │ │ │ └── ACGAN.json │ │ │ ├── BigGAN │ │ │ └── BigGAN.json │ │ │ ├── ContraGAN │ │ │ └── DiffAugGAN(C).json │ │ │ ├── GGAN │ │ │ └── GGAN.json │ │ │ ├── LSGAN │ │ │ └── LSGAN.json │ │ │ ├── ProjGAN │ │ │ └── ProjGAN.json │ │ │ ├── SAGAN │ │ │ └── SAGAN.json │ │ │ ├── SNGAN │ │ │ └── SNGAN.json │ │ │ ├── WGAN-WC │ │ │ └── WGAN-WC.json │ │ │ └── download.sh │ ├── loader.py │ ├── main.py │ ├── models │ │ ├── __init__.py │ │ ├── big_resnet.py │ │ ├── big_resnet_deep.py │ │ └── resnet.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ ├── utils │ │ ├── __init__.py │ │ ├── ada.py │ │ ├── ada_op │ │ │ ├── __init__.py │ │ │ ├── fused_act.py │ │ │ ├── fused_bias_act.cpp │ │ │ ├── fused_bias_act_kernel.cu │ │ │ ├── upfirdn2d.cpp │ │ │ ├── upfirdn2d.py │ │ │ └── upfirdn2d_kernel.cu │ │ ├── biggan_utils.py │ │ ├── cr_diff_aug.py │ │ ├── diff_aug.py │ │ ├── load_checkpoint.py │ │ ├── log.py │ │ ├── losses.py │ │ ├── make_hdf5.py │ │ ├── misc.py │ │ ├── model_ops.py │ │ └── sample.py │ └── worker.py ├── __init__.py ├── self_conditioned │ ├── __init__.py │ └── gan_training │ │ ├── __init__.py │ │ ├── checkpoints.py │ │ ├── config.py │ │ ├── distributions.py │ │ ├── eval.py │ │ ├── inputs.py │ │ ├── logger.py │ │ ├── metrics │ │ ├── __init__.py │ │ ├── clustering_metrics.py │ │ ├── fid.py │ │ ├── inception_score.py │ │ └── tf_is │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ └── inception_score.py │ │ ├── models │ │ ├── __init__.py │ │ ├── blocks.py │ │ ├── dcgan_deep.py │ │ ├── dcgan_shallow.py │ │ ├── resnet2.py │ │ ├── resnet2s.py │ │ └── resnet3.py │ │ ├── train.py │ │ └── utils.py └── stylegan2_ada_pytorch │ ├── __init__.py │ ├── dnnlib │ ├── __init__.py │ └── util.py │ └── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py ├── setup.py └── tests └── test.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Custom 2 | .vscode 3 | wandb 4 | outputs 5 | tmp* 6 | slurm-logs 7 | 
inversion/gans/BigGAN/weights 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | .github 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # Lightning /research 38 | test_tube_exp/ 39 | tests/tests_tt_dir/ 40 | tests/save_dir 41 | default/ 42 | data/ 43 | test_tube_logs/ 44 | test_tube_data/ 45 | datasets/ 46 | model_weights/ 47 | tests/save_dir 48 | tests/tests_tt_dir/ 49 | processed/ 50 | raw/ 51 | 52 | # PyInstaller 53 | # Usually these files are written by a python script from a template 54 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 55 | *.manifest 56 | *.spec 57 | 58 | # Installer logs 59 | pip-log.txt 60 | pip-delete-this-directory.txt 61 | 62 | # Unit test / coverage reports 63 | htmlcov/ 64 | .tox/ 65 | .coverage 66 | .coverage.* 67 | .cache 68 | nosetests.xml 69 | coverage.xml 70 | *.cover 71 | .hypothesis/ 72 | .pytest_cache/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | 130 | # IDEs 131 | .idea 132 | .vscode 133 | 134 | # seed project 135 | lightning_logs/ 136 | MNIST 137 | .DS_Store 138 | -------------------------------------------------------------------------------- /examples/generate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Example usage: 3 | python generate.py 4 | """ 5 | import torch 6 | from pytorch_pretrained_gans import make_gan 7 | 8 | # BigGAN (conditional) 9 | G = make_gan(gan_type='biggan', model_name='biggan-deep-256') # -> nn.Module 10 | y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) 11 | z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) 12 | x = G(z=z, y=y) # -> torch.Size([1, 3, 256, 256]) 13 | assert z.shape == torch.Size([1, 128]) 14 | assert x.shape == torch.Size([1, 3, 256, 256]) 15 | 16 | # BigBiGAN (unconditional) 17 | G = make_gan(gan_type='bigbigan') # -> nn.Module 18 | z = G.sample_latent(batch_size=1) # -> torch.Size([1, 120]) 19 | x = G(z=z) # -> torch.Size([1, 3, 128, 128]) 20 | assert z.shape == torch.Size([1, 120]) 21 | assert x.shape == torch.Size([1, 3, 128, 128]) 22 | 23 | # Self-Conditioned GAN (trained without labels, conditioned on self-inferred clusters) 24 | G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned') # -> nn.Module 25 | y = G.sample_class(batch_size=1) # -> torch.Size([1, 1000]) 26 | z = G.sample_latent(batch_size=1) # -> torch.Size([1, 256]) 27 | x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128]) 28 | assert z.shape == torch.Size([1, 256])
29 | assert x.shape == torch.Size([1, 3, 128, 128]) 30 | 31 | # StyleGAN2 (unconditional) 32 | G = make_gan(gan_type='stylegan2').to('cuda') # -> nn.Module 33 | z = G.sample_latent(batch_size=1, device='cuda') # -> torch.Size([1, 18, 512]) 34 | x = G(z=z) # -> torch.Size([1, 3, 1024, 1024]) 35 | assert z.shape == torch.Size([1, 18, 512]) 36 | assert x.shape == torch.Size([1, 3, 1024, 1024]) 37 | 38 | try: 39 | # StudioGAN (conditional) 40 | G = make_gan(gan_type='studiogan', model_name='SAGAN') # -> nn.Module 41 | y = G.sample_class(batch_size=1) # -> torch.Size([1]) 42 | z = G.sample_latent(batch_size=1) # -> torch.Size([1, 128]) 43 | x = G(z=z, y=y) # -> torch.Size([1, 3, 128, 128]) 44 | assert z.shape == torch.Size([1, 128]) 45 | assert x.shape == torch.Size([1, 3, 128, 128]) 46 | except Exception: 47 | print('Please download StudioGAN models as specified in the repo') -------------------------------------------------------------------------------- /pytorch_pretrained_gans/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pth 3 | *.pkl 4 | *.npy 5 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/__init__.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import torch 3 | from torch import nn 4 | from torch.utils import model_zoo 5 | from .model import BigGAN 6 | 7 | 8 | _WEIGHTS_URL = "https://github.com/greeneggsandyaml/tmp/releases/download/0.0.1/BigBiGAN_x1.pth" 9 | 10 | 11 | class GeneratorWrapper(nn.Module): 12 | """ A wrapper to put the GAN in a standard format -- here, a modified 13 | version of the old UnconditionalBigGAN class """ 14 | 15 | def __init__(self, big_gan): 16 | super().__init__() 17 | self.big_gan = big_gan 18 | self.dim_z = self.big_gan.dim_z 19 | self.conditional = False 20 | 21 | def forward(self, z): 22 | classes = torch.zeros(z.shape[0], dtype=torch.int64, device=z.device) 23 | return self.big_gan(z, self.big_gan.shared(classes)) 24 | 25 | def sample_latent(self, batch_size, device='cpu'): 26 | z = torch.randn((batch_size, self.dim_z), device=device) 27 | return z 28 | 29 | 30 | def make_biggan_config(resolution): 31 | attn_dict = {128: '64', 256: '128', 512: '64'} 32 | dim_z_dict = {128: 120, 256: 140, 512: 128} 33 | config = { 34 | 'G_param': 'SN', 'D_param': 'SN', 35 | 'G_ch': 96, 'D_ch': 96, 36 | 'D_wide': True, 'G_shared': True, 37 | 'shared_dim': 128, 'dim_z': dim_z_dict[resolution], 38 | 'hier': True, 'cross_replica': False, 39 | 'mybn': False, 'G_activation': nn.ReLU(inplace=True), 40 | 'G_attn': attn_dict[resolution], 41 | 'norm_style': 'bn', 42 | 'G_init': 'ortho', 'skip_init': True, 'no_optim': True, 43 | 'G_fp16': False, 'G_mixed_precision': False, 44 | 'accumulate_stats': False, 'num_standing_accumulations': 16, 45 | 'G_eval_mode': True, 46 | 'BN_eps': 1e-04, 'SN_eps': 1e-04, 47 | 'num_G_SVs': 1, 'num_G_SV_itrs': 1, 'resolution': resolution, 48 | 'n_classes': 1000} 49 | return config 50 | 51 | 52 | def make_bigbigan(model_name='bigbigan-128'): 53 | assert model_name == 'bigbigan-128' 54 | config = make_biggan_config(resolution=128) 55 | G = BigGAN.Generator(**config) 56 | checkpoint = model_zoo.load_url(_WEIGHTS_URL, map_location='cpu') 57 | G.load_state_dict(checkpoint) # , strict=False) 58 | G = GeneratorWrapper(G) 59 | return G 60 | --------------------------------------------------------------------------------
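The `GeneratorWrapper` above is what gives every model in this package a uniform interface: a `dim_z` attribute, a `sample_latent` method, and a `forward` from latents to images in roughly [-1, 1]. A minimal usage sketch, assuming the weights download succeeds and that torchvision is installed for saving images (torchvision is not otherwise required by this package):

    import torch
    from torchvision.utils import save_image
    from pytorch_pretrained_gans import make_gan

    G = make_gan(gan_type='bigbigan')   # GeneratorWrapper around BigGAN.Generator
    z = G.sample_latent(batch_size=4)   # (4, 120) standard-normal latents
    with torch.no_grad():
        x = G(z=z)                      # (4, 3, 128, 128), approximately in [-1, 1]
    save_image((x.clamp(-1, 1) + 1) / 2, 'bigbigan_samples.png', nrow=2)  # rescale to [0, 1]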
/pytorch_pretrained_gans/BigBiGAN/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/pytorch-pretrained-gans/2982fdab4e683165e45bc2f4a64c2942a7a3a1b7/pytorch_pretrained_gans/BigBiGAN/model/__init__.py -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability.
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object.
58 | 59 | - During the replication, as data parallel triggers a callback on each module, all slave devices should 60 | call `register_slave(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices are collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | Messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License.
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after it is created by the 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/model/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigBiGAN/weights/download.sh: -------------------------------------------------------------------------------- 1 | wget https://www.dropbox.com/s/9w2i45h455k3b4p/BigBiGAN_x1.pth 2 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigGAN/__init__.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import os 3 | import torch 4 | import numpy as np 5 | 6 | from .model import BigGAN 7 | 8 | 9 | def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): 10 | """ Create a truncated noise vector. 11 | Params: 12 | batch_size: batch size. 13 | dim_z: dimension of z 14 | truncation: truncation value to use 15 | seed: seed for the random generator 16 | Output: 17 | array of shape (batch_size, dim_z) 18 | """ 19 | from scipy.stats import truncnorm 20 | state = None if seed is None else np.random.RandomState(seed) 21 | values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) 22 | return truncation * values 23 | 24 | 25 | class GeneratorWrapper(torch.nn.Module): 26 | """ A wrapper to put the GAN in a standard format """ 27 | 28 | def __init__(self, G, truncation=0.4): 29 | super().__init__() 30 | self.G = G 31 | self.dim_z = G.config.z_dim 32 | self.conditional = True 33 | self.num_classes = 1000 34 | 35 | self.truncation = truncation 36 | 37 | def forward(self, z, y=None, return_y=False): 38 | """ In original code, z -> noise_vector, y -> class_vector """ 39 | if y is None: 40 | y = self.sample_class(batch_size=z.shape[0], device=z.device) 41 | elif y.dtype == torch.long: 42 | y = torch.eye(self.num_classes, dtype=torch.float, device=y.device)[y] 43 | else: 44 | y = y.to(z.device) 45 | x = self.G(z, y, truncation=self.truncation) 46 | x = torch.clamp(x, min=-1, max=1) # this shouldn't really be necessary 47 | return (x, y) if return_y else x 48 | 49 | def sample_latent(self, batch_size, device='cpu'): 50 | z = truncated_noise_sample(truncation=self.truncation, batch_size=batch_size) 51 | z = torch.from_numpy(z).to(device) 52 | return z 53 | 54 | def sample_class(self, batch_size, device='cpu'): 55 | y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) 56 | y = torch.eye(self.num_classes, dtype=torch.float, device=device)[y] 57 | return y 58 | 59 | 60 | def make_biggan(model_name='biggan-deep-256') -> torch.nn.Module: 61 | G = BigGAN.from_pretrained(model_name).eval() 62 | G = GeneratorWrapper(G) 63 | return G.eval() 64 | 65 | 66 | if __name__ == '__main__': 67 | # Testing 68 | device = torch.device('cuda') 69 | G = make_biggan('biggan-deep-512') 70 | G.to(device).eval() 71 | print('Created G') 72 | print(f'Params: {sum(p.numel() for p in G.parameters()):_}') 73 | z = torch.randn([1, G.dim_z]).to(device) 74 | print(f'z.shape: {z.shape}')
75 | x = G(z) 76 | print(f'x.shape: {x.shape}') 77 | print(f'x.max(): {x.max()}') 78 | print(f'x.min(): {x.min()}') 79 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/BigGAN/config.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | """ 3 | BigGAN config. 4 | """ 5 | from __future__ import (absolute_import, division, print_function, unicode_literals) 6 | 7 | import copy 8 | import json 9 | 10 | class BigGANConfig(object): 11 | """ Configuration class to store the configuration of a `BigGAN`. 12 | Defaults are for the 128x128 model. 13 | Each layers tuple is (up-sample in this layer?, input channels, output channels). 14 | """ 15 | def __init__(self, 16 | output_dim=128, 17 | z_dim=128, 18 | class_embed_dim=128, 19 | channel_width=128, 20 | num_classes=1000, 21 | layers=[(False, 16, 16), 22 | (True, 16, 16), 23 | (False, 16, 16), 24 | (True, 16, 8), 25 | (False, 8, 8), 26 | (True, 8, 4), 27 | (False, 4, 4), 28 | (True, 4, 2), 29 | (False, 2, 2), 30 | (True, 2, 1)], 31 | attention_layer_position=8, 32 | eps=1e-4, 33 | n_stats=51): 34 | """Constructs BigGANConfig. """ 35 | self.output_dim = output_dim 36 | self.z_dim = z_dim 37 | self.class_embed_dim = class_embed_dim 38 | self.channel_width = channel_width 39 | self.num_classes = num_classes 40 | self.layers = layers 41 | self.attention_layer_position = attention_layer_position 42 | self.eps = eps 43 | self.n_stats = n_stats 44 | 45 | @classmethod 46 | def from_dict(cls, json_object): 47 | """Constructs a `BigGANConfig` from a Python dictionary of parameters.""" 48 | config = BigGANConfig() 49 | for key, value in json_object.items(): 50 | config.__dict__[key] = value 51 | return config 52 | 53 | @classmethod 54 | def from_json_file(cls, json_file): 55 | """Constructs a `BigGANConfig` from a json file of parameters.""" 56 | with open(json_file, "r", encoding='utf-8') as reader: 57 | text = reader.read() 58 | return cls.from_dict(json.loads(text)) 59 | 60 | def __repr__(self): 61 | return str(self.to_json_string()) 62 | 63 | def to_dict(self): 64 | """Serializes this instance to a Python dictionary.""" 65 | output = copy.deepcopy(self.__dict__) 66 | return output 67 | 68 | def to_json_string(self): 69 | """Serializes this instance to a JSON string.""" 70 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 71 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import NamedTuple 3 | 4 | try: 5 | from .GeneratorsCIPS import CIPSskip 6 | model_available = True 7 | except Exception: 8 | CIPSskip = None 9 | model_available = False 10 | 11 | 12 | class Churches256Arguments(NamedTuple): 13 | """CIPSskip for LSUN-Churches-256""" 14 | Generator = CIPSskip 15 | size = 256 16 | coords_size = 256 17 | fc_dim = 512 18 | latent = 512 19 | style_dim = 512 20 | n_mlp = 8 21 | activation = None 22 | channel_multiplier = 2 23 | coords_integer_values = False 24 | 25 | 26 | MODELS = { 27 | # Download from https://github.com/saic-mdal/CIPS#pretrained-checkpoints 28 | 'churches': ('/home/luke/projects/experiments/gan-seg/src/segmentation/gans/CIPS/churches_g_ema.pt', Churches256Arguments), 29 | } 30 | 31 | 32 | class GeneratorWrapper(torch.nn.Module): 33 | """ A wrapper to put the GAN in a standard format """ 34 | 35 | def __init__(self, g_ema, args, truncation=0.7, device='cpu'):
36 | super().__init__() 37 | self.G = g_ema.to(device) 38 | self.dim_z = g_ema.style_dim 39 | self.conditional = False 40 | 41 | self.truncation = truncation 42 | self.truncation_latent = get_latent_mean(g_ema, args, device) 43 | self.x_channel, self.y_channel = convert_to_coord_format_unbatched( 44 | args.coords_size, args.coords_size, device, 45 | integer_values=args.coords_integer_values) 46 | self.coords_size = args.coords_size 47 | 48 | def forward(self, z): 49 | x_channel = self.x_channel.repeat(z.size(0), 1, self.coords_size, 1).to(z.device) 50 | y_channel = self.y_channel.repeat(z.size(0), 1, 1, self.coords_size).to(z.device) 51 | converted_full = torch.cat((x_channel, y_channel), dim=1) 52 | sample, _ = self.G( 53 | coords=converted_full, 54 | latent=[z], 55 | return_latents=False, 56 | truncation=self.truncation, 57 | truncation_latent=self.truncation_latent, 58 | input_is_latent=True) 59 | sample = torch.clamp(sample, min=-1, max=1) # I don't know if this is needed, I think it is though 60 | return sample 61 | 62 | 63 | def convert_to_coord_format_unbatched(h, w, device='cpu', integer_values=False): 64 | if integer_values: 65 | x_channel = torch.arange(w, dtype=torch.float, device=device).view(1, 1, 1, -1) 66 | y_channel = torch.arange(h, dtype=torch.float, device=device).view(1, 1, -1, 1) 67 | else: 68 | x_channel = torch.linspace(-1, 1, w, device=device).view(1, 1, 1, -1) 69 | y_channel = torch.linspace(-1, 1, h, device=device).view(1, 1, -1, 1) 70 | return (x_channel, y_channel) 71 | 72 | 73 | def make_cips(model_name='churches', **kwargs) -> torch.nn.Module: 74 | if not model_available: 75 | raise Exception('Could not load model. Do you have CUDA?') 76 | 77 | checkpoint_path, args = MODELS[model_name] 78 | g_ema = args.Generator( 79 | size=args.size, 80 | hidden_size=args.fc_dim, 81 | style_dim=args.latent, 82 | n_mlp=args.n_mlp, 83 | activation=args.activation, 84 | channel_multiplier=args.channel_multiplier) 85 | ckpt = torch.load(checkpoint_path, map_location='cpu') 86 | g_ema.load_state_dict(ckpt) 87 | G = GeneratorWrapper(g_ema, args, **kwargs) 88 | return G.eval() 89 | 90 | 91 | @torch.no_grad() 92 | def get_latent_mean(g_ema, args, device): 93 | 94 | # Get sample input 95 | n_sample = 1 96 | sample_z = torch.randn(n_sample, args.latent, device=device) 97 | x_channel, y_channel = convert_to_coord_format_unbatched(args.coords_size, args.coords_size, device, 98 | integer_values=args.coords_integer_values) 99 | x_channel = x_channel.repeat(sample_z.size(0), 1, args.coords_size, 1).to(device) 100 | y_channel = y_channel.repeat(sample_z.size(0), 1, 1, args.coords_size).to(device) 101 | converted_full = torch.cat((x_channel, y_channel), dim=1) 102 | 103 | # Generate a bunch of times and average the latents 104 | latents = [] 105 | samples = [] 106 | for _ in range(100): 107 | sample_z = torch.randn(n_sample, args.latent, device=device) 108 | sample, latent = g_ema(converted_full, [sample_z], return_latents=True) 109 | latents.append(latent.cpu()) 110 | samples.append(sample.cpu()) 111 | samples = torch.cat(samples, 0) 112 | latents = torch.cat(latents, 0) 113 | truncation_latent = latents.mean(0).cuda() 114 | assert len(truncation_latent.shape) == 1 and truncation_latent.size(0) == 512, 'something is wrong' 115 | return truncation_latent 116 | 117 | 118 | if __name__ == '__main__': 119 | # Testing 120 | device = torch.device('cuda') 121 | G = make_cips(device=device) 122 | print('Created G') 123 | print(f'Params: {sum(p.numel() for p in G.parameters()):_}') 124 | z = torch.randn([1, G.dim_z]).to(device)
125 | print(f'z.shape: {z.shape}') 126 | x = G(z) 127 | print(f'x.shape: {x.shape}') 128 | import pdb 129 | pdb.set_trace() 130 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0;
64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/CIPS/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import join 3 | from pathlib import Path 4 | import numpy as np 5 | import torch 6 | import json 7 | from collections.abc import MutableMapping 8 | 9 | from .models import resnet 10 | from .models import big_resnet 11 | from .models import big_resnet_deep 12 | 13 | 14 | # Download here: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN#imagenet-3x128x128 15 | ROOT = Path(__file__).parent 16 | ARCHS = { 17 | 'resnet': resnet, 18 | 'big_resnet': big_resnet, 19 | 'big_resnet_deep': big_resnet_deep, 20 | } 21 | 22 | 23 | class Config(object): 24 | def __init__(self, dict_): 25 | self.__dict__.update(dict_) 26 | 27 | 28 | def flatten(d): 29 | items = [] 30 | for k, v in d.items(): 31 | if isinstance(v, MutableMapping): 32 | items.extend(flatten(v).items()) 33 | else: 34 | items.append((k, v)) 35 | return dict(items) 36 | 37 | 38 | def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None): 39 | """ Create a truncated noise vector. 40 | Params: 41 | batch_size: batch size.
42 | dim_z: dimension of z 43 | truncation: truncation value to use 44 | seed: seed for the random generator 45 | Output: 46 | array of shape (batch_size, dim_z) 47 | """ 48 | from scipy.stats import truncnorm 49 | state = None if seed is None else np.random.RandomState(seed) 50 | values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32) 51 | return truncation * values 52 | 53 | 54 | class GeneratorWrapper(torch.nn.Module): 55 | """ A wrapper to put the GAN in a standard format """ 56 | 57 | def __init__(self, Gen, cfgs): 58 | super().__init__() 59 | self.G = Gen 60 | self.dim_z = Gen.z_dim 61 | self.conditional = True 62 | self.num_classes = cfgs.num_classes 63 | 64 | self.truncation = 1.0 65 | 66 | def forward(self, z, y=None, return_y=False): 67 | if y is not None: 68 | # the model is conditional and the user gives us a class 69 | y = y.to(z.device) 70 | elif self.num_classes is not None: 71 | # the model is conditional but the user does not give us a class 72 | y = self.sample_class(batch_size=z.shape[0], device=z.device) 73 | else: 74 | # the model is unconditional 75 | y = None 76 | x = self.G(z, label=y, evaluation=True) 77 | x = torch.clamp(x, min=-1, max=1) # this shouldn't really be necessary 78 | return (x, y) if return_y else x 79 | 80 | def sample_latent(self, batch_size, device='cpu'): 81 | z = torch.randn((batch_size, self.dim_z), device=device) 82 | return z 83 | 84 | def sample_class(self, batch_size, device='cpu'): 85 | y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) 86 | return y 87 | 88 | 89 | def get_config_and_checkpoint(root): 90 | paths = list(map(str, root.iterdir())) 91 | checkpoint_path = [p for p in paths if '.pth' in p] 92 | config_path = [p for p in paths if '.json' in p] 93 | assert len(checkpoint_path) == 1, f'no checkpoint found in {root}' 94 | assert len(config_path) == 1, f'no config found in {root}' 95 | checkpoint_path = checkpoint_path[0] 96 | config_path = config_path[0] 97 | with open(config_path) as f: 98 | cfgs = json.load(f) 99 | cfgs = Config(flatten(cfgs)) 100 | cfgs.mixed_precision = False 101 | return cfgs, checkpoint_path 102 | 103 | 104 | def make_studiogan(model_name='SAGAN', dataset='ImageNet') -> torch.nn.Module: 105 | 106 | # Get configs and model checkpoint path 107 | cfgs, checkpoint_path = get_config_and_checkpoint(ROOT / 'configs' / dataset / model_name) 108 | 109 | # From: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN/blob/master/src/loader.py#L90 110 | Generator = ARCHS[cfgs.architecture].Generator 111 | Gen = Generator( 112 | cfgs.z_dim, cfgs.shared_dim, cfgs.img_size, cfgs.g_conv_dim, cfgs.g_spectral_norm, cfgs.attention, 113 | cfgs.attention_after_nth_gen_block, cfgs.activation_fn, cfgs.conditional_strategy, cfgs.num_classes, 114 | cfgs.g_init, cfgs.G_depth, cfgs.mixed_precision) 115 | 116 | # Checkpoint 117 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 118 | Gen.load_state_dict(checkpoint['state_dict']) 119 | 120 | # Wrap 121 | G = GeneratorWrapper(Gen, cfgs) 122 | return G.eval() 123 | 124 | 125 | if __name__ == '__main__': 126 | # Testing 127 | device = 'cuda' 128 | G = make_studiogan('BigGAN2048').to(device) 129 | print('Created G') 130 | print(f'Params: {sum(p.numel() for p in G.parameters()):_}') 131 | z = torch.randn([1, G.dim_z]).to(device) 132 | print(f'z.shape: {z.shape}') 133 | x = G(z) 134 | print(f'x.shape: {x.shape}') 135 | print(x.max(), x.min()) 136 | 
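# A hedged usage sketch for the loader above, assuming a checkpoint has been
# downloaded with configs/ImageNet/download.sh so that exactly one .pth sits
# next to its .json config (which get_config_and_checkpoint asserts):
#
#   from pytorch_pretrained_gans import make_gan
#   G = make_gan(gan_type='studiogan', model_name='SAGAN')  # conditional, 1000 ImageNet classes
#   z = G.sample_latent(batch_size=2)  # (2, 128): SAGAN's z_dim
#   y = G.sample_class(batch_size=2)   # (2,) integer class labels
#   x = G(z=z, y=y)                    # (2, 3, 128, 128), clamped to [-1, 1]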
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN2048/BigGAN2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator":false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 4, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 120, 27 | "shared_dim": 128, 28 | "g_conv_dim": 96, 29 | "d_conv_dim": 96, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 8, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.00005, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 2, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda":"N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug":false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": true, 107 | "ema_decay": 0.9999, 108 | "ema_start": 20000 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/BigGAN256/BigGAN256.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator":false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | 
"d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 4, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 120, 27 | "shared_dim": 128, 28 | "g_conv_dim": 96, 29 | "d_conv_dim": 96, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.00005, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 2, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda":"N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug":false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": true, 107 | "ema_decay": 0.9999, 108 | "ema_start": 20000 109 | } 110 | } 111 | } 112 | 113 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN2048/ContraGAN2048.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ContraGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": 1536, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 4, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 120, 27 | "shared_dim": 128, 28 | "g_conv_dim": 96, 29 | "d_conv_dim": 96, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 8, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.00005, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 2, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": 1.0, 54 | 
"margin": 0.0, 55 | "tempering_type": "constant", 56 | "tempering_step": "N/A", 57 | "start_temperature": 1.0, 58 | "end_temperature": 1.0, 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda":"N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": true, 107 | "ema_decay": 0.9999, 108 | "ema_start": 20000 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/ContraGAN256/ContraGAN256.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "data_processing":{ 4 | "dataset_name": "imagenet", 5 | "data_path": "./data/ILSVRC2012", 6 | "img_size": 128, 7 | "num_classes": 1000, 8 | "batch_size4prcsing": 256, 9 | "chunk_size": 500, 10 | "compression": false 11 | }, 12 | 13 | "train": { 14 | "model": { 15 | "architecture": "big_resnet", 16 | "conditional_strategy": "ContraGAN", 17 | "pos_collected_numerator": false, 18 | "hypersphere_dim": 1536, 19 | "nonlinear_embed": false, 20 | "normalize_embed": true, 21 | "g_spectral_norm": true, 22 | "d_spectral_norm": true, 23 | "activation_fn": "ReLU", 24 | "attention": true, 25 | "attention_after_nth_gen_block": 4, 26 | "attention_after_nth_dis_block": 1, 27 | "z_dim": 120, 28 | "shared_dim": 128, 29 | "g_conv_dim": 96, 30 | "d_conv_dim": 96, 31 | "G_depth":"N/A", 32 | "D_depth":"N/A" 33 | }, 34 | 35 | "optimization": { 36 | "optimizer": "Adam", 37 | "batch_size": 256, 38 | "accumulation_steps": 1, 39 | "d_lr": 0.0002, 40 | "g_lr": 0.00005, 41 | "momentum": "N/A", 42 | "nesterov": "N/A", 43 | "alpha": "N/A", 44 | "beta1": 0.0, 45 | "beta2": 0.999, 46 | "g_steps_per_iter": 1, 47 | "d_steps_per_iter": 2, 48 | "total_step": 200000 49 | }, 50 | 51 | "loss_function": { 52 | "adv_loss": "hinge", 53 | 54 | "contrastive_lambda": 1.0, 55 | "margin": 0.0, 56 | "tempering_type": "constant", 57 | "tempering_step": "N/A", 58 | "start_temperature": 1.0, 59 | "end_temperature": 1.0, 60 | 61 | "weight_clipping_for_dis": false, 62 | "weight_clipping_bound": "N/A", 63 | 64 | "gradient_penalty_for_dis": false, 65 | "gradient_penalty_lambda": "N/A", 66 | 67 | "deep_regret_analysis_for_dis": false, 68 | "regret_penalty_lambda": "N/A", 69 | 70 | "cr": false, 71 | "cr_lambda":"N/A", 72 | 73 | "bcr": false, 74 | "real_lambda": "N/A", 75 | "fake_lambda": "N/A", 76 | 77 | "zcr": false, 78 | "gen_lambda": "N/A", 79 | "dis_lambda": "N/A", 80 | "sigma_noise": "N/A" 81 | }, 82 | 83 | "initialization":{ 84 | "g_init": "ortho", 85 | "d_init": "ortho" 86 | }, 87 | 88 | "training_and_sampling_setting":{ 89 | 
"random_flip_preprocessing": true, 90 | "diff_aug": false, 91 | 92 | "ada": false, 93 | "ada_target": "N/A", 94 | "ada_length": "N/A", 95 | 96 | "prior": "gaussian", 97 | "truncated_factor": 1, 98 | 99 | "latent_op": false, 100 | "latent_op_rate":"N/A", 101 | "latent_op_step":"N/A", 102 | "latent_op_step4eval":"N/A", 103 | "latent_op_alpha":"N/A", 104 | "latent_op_beta":"N/A", 105 | "latent_norm_reg_weight":"N/A", 106 | 107 | "ema": true, 108 | "ema_decay": 0.9999, 109 | "ema_start": 20000 110 | } 111 | } 112 | } 113 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SAGAN/SAGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 4, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 1000000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/SNGAN/SNGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"data_processing":{ 3 | "dataset_name": "imagenet", 4 | "data_path": "./data/ILSVRC2012", 5 | "img_size": 128, 6 | "num_classes": 1000, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0002, 39 | "g_lr": 0.00005, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 2, 47 | "total_step": 500000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/ImageNet/download.sh: -------------------------------------------------------------------------------- 1 | # SAGAN 2 | # https://drive.google.com/file/d/1exrZloM2bHYJyU5_v9XUMlNcy5TA6gMN/view?usp=sharing 3 | cd SAGAN 4 | gdown --id 1exrZloM2bHYJyU5_v9XUMlNcy5TA6gMN 5 | cd .. 6 | 7 | # SNGAN 8 | # https://drive.google.com/file/d/1L4Jk9v_vRojdj9ZpLBak8OxoahdsX5yn/view?usp=sharing 9 | cd SNGAN 10 | gdown --id 1L4Jk9v_vRojdj9ZpLBak8OxoahdsX5yn 11 | cd .. 12 | 13 | # BigGAN2048 14 | # https://drive.google.com/file/d/14VIJUsYcItZrfNk_PcjglNXH_sPyd504/view?usp=sharing 15 | cd BigGAN2048 16 | gdown --id 16tZIHrXFYFM6mXmEF-4YA7vO1D-s7meq 17 | cd .. 18 | 19 | # ContraGAN256 20 | # https://drive.google.com/file/d/15ipVwbQpncc678tGdT7VsDcFCstUpP1n/view?usp=sharing 21 | cd ContraGAN256 22 | gdown --id 15ipVwbQpncc678tGdT7VsDcFCstUpP1n 23 | cd .. 
24 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ACGAN/ACGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ACGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": false, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/BigGAN/BigGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": 
true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 100, 27 | "shared_dim": 128, 28 | "g_conv_dim": 80, 29 | "d_conv_dim": 80, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 1024, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda":"N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda":"N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": true, 107 | "ema_decay": 0.9999, 108 | "ema_start": 20000 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ContraGAN/DiffAugGAN(C).json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "big_resnet", 15 | "conditional_strategy": "ContraGAN", 16 | "pos_collected_numerator": true, 17 | "hypersphere_dim": 768, 18 | "nonlinear_embed": false, 19 | "normalize_embed": true, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | "attention_after_nth_dis_block": 1, 26 | "z_dim": 100, 27 | "shared_dim": 128, 28 | "g_conv_dim": 80, 29 | "d_conv_dim": 80, 30 | "G_depth":"N/A", 31 | "D_depth":"N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 1024, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 200000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": 1.0, 
54 | "margin": 0.0, 55 | "tempering_type": "constant", 56 | "tempering_step": "N/A", 57 | "start_temperature": 1.0, 58 | "end_temperature": 1.0, 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": true, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate": "N/A", 100 | "latent_op_step": "N/A", 101 | "latent_op_step4eval": "N/A", 102 | "latent_op_alpha": "N/A", 103 | "latent_op_beta": "N/A", 104 | "latent_norm_reg_weight": "N/A", 105 | 106 | "ema": true, 107 | "ema_decay": 0.9999, 108 | "ema_start": 20000 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/GGAN/GGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "no", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": false, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | 
"training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/LSGAN/LSGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "no", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": false, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "least_square", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/ProjGAN/ProjGAN.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": false, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate": "N/A", 100 | "latent_op_step": "N/A", 101 | "latent_op_step4eval": "N/A", 102 | "latent_op_alpha": "N/A", 103 | "latent_op_beta": "N/A", 104 | "latent_norm_reg_weight": "N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SAGAN/SAGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": true, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": true, 24 | "attention_after_nth_gen_block": 3, 25 | 
"attention_after_nth_dis_block": 1, 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/SNGAN/SNGAN.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "ProjGAN", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": true, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "hinge", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | 
"end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": false, 64 | "gradient_penalty_lambda": "N/A", 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": "N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } 112 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/WGAN-WC/WGAN-WC.json: -------------------------------------------------------------------------------- 1 | { 2 | "data_processing":{ 3 | "dataset_name": "tiny_imagenet", 4 | "data_path": "./data/TINY_ILSVRC2012", 5 | "img_size": 64, 6 | "num_classes": 200, 7 | "batch_size4prcsing": 256, 8 | "chunk_size": 500, 9 | "compression": false 10 | }, 11 | 12 | "train": { 13 | "model": { 14 | "architecture": "resnet", 15 | "conditional_strategy": "no", 16 | "pos_collected_numerator": false, 17 | "hypersphere_dim": "N/A", 18 | "nonlinear_embed": false, 19 | "normalize_embed": false, 20 | "g_spectral_norm": false, 21 | "d_spectral_norm": false, 22 | "activation_fn": "ReLU", 23 | "attention": false, 24 | "attention_after_nth_gen_block": "N/A", 25 | "attention_after_nth_dis_block": "N/A", 26 | "z_dim": 128, 27 | "shared_dim": "N/A", 28 | "g_conv_dim": 64, 29 | "d_conv_dim": 64, 30 | "G_depth": "N/A", 31 | "D_depth": "N/A" 32 | }, 33 | 34 | "optimization": { 35 | "optimizer": "Adam", 36 | "batch_size": 256, 37 | "accumulation_steps": 1, 38 | "d_lr": 0.0004, 39 | "g_lr": 0.0001, 40 | "momentum": "N/A", 41 | "nesterov": "N/A", 42 | "alpha": "N/A", 43 | "beta1": 0.0, 44 | "beta2": 0.999, 45 | "g_steps_per_iter": 1, 46 | "d_steps_per_iter": 1, 47 | "total_step": 100000 48 | }, 49 | 50 | "loss_function": { 51 | "adv_loss": "wasserstein", 52 | 53 | "contrastive_lambda": "N/A", 54 | "margin": "N/A", 55 | "tempering_type": "N/A", 56 | "tempering_step": "N/A", 57 | "start_temperature": "N/A", 58 | "end_temperature": "N/A", 59 | 60 | "weight_clipping_for_dis": false, 61 | "weight_clipping_bound": "N/A", 62 | 63 | "gradient_penalty_for_dis": true, 64 | "gradient_penalty_lambda": 10.0, 65 | 66 | "deep_regret_analysis_for_dis": false, 67 | "regret_penalty_lambda": "N/A", 68 | 69 | "cr": false, 70 | "cr_lambda": "N/A", 71 | 72 | "bcr": false, 73 | "real_lambda": "N/A", 74 | "fake_lambda": "N/A", 75 | 76 | "zcr": false, 77 | "gen_lambda": "N/A", 78 | "dis_lambda": "N/A", 79 | "sigma_noise": "N/A" 80 | }, 81 | 82 | "initialization":{ 83 | "g_init": "ortho", 84 | "d_init": "ortho" 85 | }, 86 | 87 | "training_and_sampling_setting":{ 88 | "random_flip_preprocessing": true, 89 | "diff_aug": false, 90 | 91 | "ada": false, 92 | "ada_target": 
"N/A", 93 | "ada_length": "N/A", 94 | 95 | "prior": "gaussian", 96 | "truncated_factor": 1, 97 | 98 | "latent_op": false, 99 | "latent_op_rate":"N/A", 100 | "latent_op_step":"N/A", 101 | "latent_op_step4eval":"N/A", 102 | "latent_op_alpha":"N/A", 103 | "latent_op_beta":"N/A", 104 | "latent_norm_reg_weight":"N/A", 105 | 106 | "ema": false, 107 | "ema_decay": "N/A", 108 | "ema_start": "N/A" 109 | } 110 | } 111 | } -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/configs/TinyImageNet/download.sh: -------------------------------------------------------------------------------- 1 | 2 | # LSGAN 3 | # https://drive.google.com/file/d/1Wa5CrUAoxgW730Z1MXg7QJ-O2Y_XiIOC/view?usp=sharing 4 | mkdir LSGAN 5 | cd LSGAN 6 | # gdown --id 1Wa5CrUAoxgW730Z1MXg7QJ-O2Y_XiIOC 7 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/LSGAN.json 8 | cd .. 9 | 10 | # GGAN 11 | # https://drive.google.com/file/d/1U5644ZhZUdoJDUPLELoQHtPA9qOdZIXS/view?usp=sharing 12 | mkdir GGAN 13 | cd GGAN 14 | # gdown --id 1U5644ZhZUdoJDUPLELoQHtPA9qOdZIXS 15 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/GGAN.json 16 | cd .. 17 | 18 | # WGAN-WC 19 | # https://drive.google.com/file/d/1TbWjWx8PhSHKmh-gv3WTYybSKWpoj8_u/view?usp=sharing 20 | mkdir WGAN-WC 21 | cd WGAN-WC 22 | # gdown --id 1TbWjWx8PhSHKmh-gv3WTYybSKWpoj8_u 23 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/WGAN-WC.json 24 | cd .. 25 | 26 | # ACGAN 27 | # https://drive.google.com/file/d/14JkiZLONLXAP1JCfixlSPkbKxK5mG1cF/view?usp=sharing 28 | mkdir ACGAN 29 | cd ACGAN 30 | # gdown --id 14JkiZLONLXAP1JCfixlSPkbKxK5mG1cF 31 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/ACGAN.json 32 | cd .. 33 | 34 | # ProjGAN 35 | # https://drive.google.com/file/d/1mRtit-GFIHjD--YLG-PzhThKkyOW7zoI/view?usp=sharing 36 | mkdir ProjGAN 37 | cd ProjGAN 38 | # gdown --id 1mRtit-GFIHjD--YLG-PzhThKkyOW7zoI 39 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/ProjGAN.json 40 | cd .. 41 | 42 | # SNGAN 43 | # https://drive.google.com/file/d/1xHrk4bt0Xbatvt3hs4RoMM3E6BCBUAmw/view?usp=sharing 44 | mkdir SNGAN 45 | cd SNGAN 46 | # gdown --id 1xHrk4bt0Xbatvt3hs4RoMM3E6BCBUAmw 47 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/SNGAN.json 48 | cd .. 49 | 50 | # SAGAN 51 | # https://drive.google.com/file/d/1vaEwUqUF_qC5uUBRNW413vt_8QMYfuoN/view?usp=sharing 52 | mkdir SAGAN 53 | cd SAGAN 54 | # gdown --id 1vaEwUqUF_qC5uUBRNW413vt_8QMYfuoN 55 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/SAGAN.json 56 | cd .. 57 | 58 | # BigGAN 59 | # https://drive.google.com/file/d/16FqpBcB318De2HM7XS6UNFT7zs-3XD6e/view?usp=sharing 60 | mkdir BigGAN 61 | cd BigGAN 62 | # gdown --id 16FqpBcB318De2HM7XS6UNFT7zs-3XD6e 63 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/BigGAN.json 64 | cd .. 
65 | 66 | # ContraGAN 67 | # https://drive.google.com/file/d/1NKcNjtg51rfmFvTSMmZlkweMAlQJhYUA/view?usp=sharing 68 | mkdir ContraGAN 69 | cd ContraGAN 70 | # gdown --id 1NKcNjtg51rfmFvTSMmZlkweMAlQJhYUA 71 | wget https://raw.githubusercontent.com/POSTECH-CVLab/PyTorch-StudioGAN/master/src/configs/TINY_ILSVRC2012/DiffAugGAN(C).json 72 | cd .. 73 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/pytorch-pretrained-gans/2982fdab4e683165e45bc2f4a64c2942a7a3a1b7/pytorch_pretrained_gans/StudioGAN/models/__init__.py -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/pytorch-pretrained-gans/2982fdab4e683165e45bc2f4a64c2942a7a3a1b7/pytorch_pretrained_gans/StudioGAN/sync_batchnorm/__init__.py -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : batchnorm_reimpl.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.init as init 39 | 40 | __all__ = ['BatchNorm2dReimpl'] 41 | 42 | 43 | class BatchNorm2dReimpl(nn.Module): 44 | """ 45 | A re-implementation of batch normalization, used for testing the numerical 46 | stability. 
47 | 48 | Author: acgtyrant 49 | See also: 50 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 51 | """ 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 53 | super().__init__() 54 | 55 | self.num_features = num_features 56 | self.eps = eps 57 | self.momentum = momentum 58 | self.weight = nn.Parameter(torch.empty(num_features)) 59 | self.bias = nn.Parameter(torch.empty(num_features)) 60 | self.register_buffer('running_mean', torch.zeros(num_features)) 61 | self.register_buffer('running_var', torch.ones(num_features)) 62 | self.reset_parameters() 63 | 64 | def reset_running_stats(self): 65 | self.running_mean.zero_() 66 | self.running_var.fill_(1) 67 | 68 | def reset_parameters(self): 69 | self.reset_running_stats() 70 | init.uniform_(self.weight) 71 | init.zeros_(self.bias) 72 | 73 | def forward(self, input_): 74 | batchsize, channels, height, width = input_.size() 75 | numel = batchsize * height * width 76 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 77 | sum_ = input_.sum(1) 78 | sum_of_square = input_.pow(2).sum(1) 79 | mean = sum_ / numel 80 | sumvar = sum_of_square - sum_ * mean 81 | 82 | self.running_mean = ( 83 | (1 - self.momentum) * self.running_mean 84 | + self.momentum * mean.detach() 85 | ) 86 | unbias_var = sumvar / (numel - 1) 87 | self.running_var = ( 88 | (1 - self.momentum) * self.running_var 89 | + self.momentum * unbias_var.detach() 90 | ) 91 | 92 | bias_var = sumvar / numel 93 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 94 | output = ( 95 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 96 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 97 | 98 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 99 | 100 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : replicate.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 
33 | """ 34 | 35 | 36 | import functools 37 | 38 | from torch.nn.parallel.data_parallel import DataParallel 39 | 40 | __all__ = [ 41 | 'CallbackContext', 42 | 'execute_replication_callbacks', 43 | 'DataParallelWithCallback', 44 | 'patch_replication_callback' 45 | ] 46 | 47 | 48 | class CallbackContext(object): 49 | pass 50 | 51 | 52 | def execute_replication_callbacks(modules): 53 | """ 54 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 55 | 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Note that, as all modules are isomorphism, we assign each sub-module with a context 59 | (shared among multiple copies of this module on different devices). 60 | Through this context, different copies can share some information. 61 | 62 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 63 | of any slave copies. 64 | """ 65 | master_copy = modules[0] 66 | nr_modules = len(list(master_copy.modules())) 67 | ctxs = [CallbackContext() for _ in range(nr_modules)] 68 | 69 | for i, module in enumerate(modules): 70 | for j, m in enumerate(module.modules()): 71 | if hasattr(m, '__data_parallel_replicate__'): 72 | m.__data_parallel_replicate__(ctxs[j], i) 73 | 74 | 75 | class DataParallelWithCallback(DataParallel): 76 | """ 77 | Data Parallel with a replication callback. 78 | 79 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 80 | original `replicate` function. 81 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 82 | 83 | Examples: 84 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 85 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 86 | # sync_bn.__data_parallel_replicate__ will be invoked. 87 | """ 88 | 89 | def replicate(self, module, device_ids): 90 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | 95 | def patch_replication_callback(data_parallel): 96 | """ 97 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 98 | Useful when you have customized `DataParallel` implementation. 99 | 100 | Examples: 101 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 102 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 103 | > patch_replication_callback(sync_bn) 104 | # this is equivalent to 105 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 106 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 107 | """ 108 | 109 | assert isinstance(data_parallel, DataParallel) 110 | 111 | old_replicate = data_parallel.replicate 112 | 113 | @functools.wraps(old_replicate) 114 | def new_replicate(module, device_ids): 115 | modules = old_replicate(module, device_ids) 116 | execute_replication_callbacks(modules) 117 | return modules 118 | 119 | data_parallel.replicate = new_replicate 120 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | """ 2 | -*- coding: utf-8 -*- 3 | File : unittest.py 4 | Author : Jiayuan Mao 5 | Email : maojiayuan@gmail.com 6 | Date : 27/01/2018 7 | 8 | This file is part of Synchronized-BatchNorm-PyTorch. 
9 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 10 | Distributed under MIT License. 11 | 12 | MIT License 13 | 14 | Copyright (c) 2018 Jiayuan MAO 15 | 16 | Permission is hereby granted, free of charge, to any person obtaining a copy 17 | of this software and associated documentation files (the "Software"), to deal 18 | in the Software without restriction, including without limitation the rights 19 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 20 | copies of the Software, and to permit persons to whom the Software is 21 | furnished to do so, subject to the following conditions: 22 | 23 | The above copyright notice and this permission notice shall be included in all 24 | copies or substantial portions of the Software. 25 | 26 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 27 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 28 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 29 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 30 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 31 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 32 | SOFTWARE. 33 | """ 34 | 35 | 36 | import unittest 37 | import torch 38 | 39 | 40 | class TorchTestCase(unittest.TestCase): 41 | def assertTensorClose(self, x, y): 42 | adiff = float((x - y).abs().max()) 43 | if (y == 0).all(): 44 | rdiff = 'NaN' 45 | else: 46 | rdiff = float((adiff / y).abs().max()) 47 | 48 | message = ( 49 | 'Tensor close check failed\n' 50 | 'adiff={}\n' 51 | 'rdiff={}\n' 52 | ).format(adiff, rdiff) 53 | self.assertTrue(torch.allclose(x, y), message) 54 | 55 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/pytorch-pretrained-gans/2982fdab4e683165e45bc2f4a64c2942a7a3a1b7/pytorch_pretrained_gans/StudioGAN/utils/__init__.py -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_act.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 
15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | """ 24 | 25 | 26 | import os 27 | 28 | import torch 29 | from torch import nn 30 | from torch.nn import functional as F 31 | from torch.autograd import Function 32 | from torch.utils.cpp_extension import load 33 | 34 | 35 | module_path = os.path.dirname(__file__) 36 | fused = load( 37 | "fused", 38 | sources=[ 39 | os.path.join(module_path, "fused_bias_act.cpp"), 40 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 41 | ], 42 | ) 43 | 44 | 45 | class FusedLeakyReLUFunctionBackward(Function): 46 | @staticmethod 47 | def forward(ctx, grad_output, out, negative_slope, scale): 48 | ctx.save_for_backward(out) 49 | ctx.negative_slope = negative_slope 50 | ctx.scale = scale 51 | 52 | empty = grad_output.new_empty(0) 53 | 54 | grad_input = fused.fused_bias_act( 55 | grad_output, empty, out, 3, 1, negative_slope, scale 56 | ) 57 | 58 | dim = [0] 59 | 60 | if grad_input.ndim > 2: 61 | dim += list(range(2, grad_input.ndim)) 62 | 63 | grad_bias = grad_input.sum(dim).detach() 64 | 65 | return grad_input, grad_bias 66 | 67 | @staticmethod 68 | def backward(ctx, gradgrad_input, gradgrad_bias): 69 | out, = ctx.saved_tensors 70 | gradgrad_out = fused.fused_bias_act( 71 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 72 | ) 73 | 74 | return gradgrad_out, None, None, None 75 | 76 | 77 | class FusedLeakyReLUFunction(Function): 78 | @staticmethod 79 | def forward(ctx, input, bias, negative_slope, scale): 80 | empty = input.new_empty(0) 81 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 82 | ctx.save_for_backward(out) 83 | ctx.negative_slope = negative_slope 84 | ctx.scale = scale 85 | 86 | return out 87 | 88 | @staticmethod 89 | def backward(ctx, grad_output): 90 | out, = ctx.saved_tensors 91 | 92 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 93 | grad_output, out, ctx.negative_slope, ctx.scale 94 | ) 95 | 96 | return grad_input, grad_bias, None, None 97 | 98 | 99 | class FusedLeakyReLU(nn.Module): 100 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 101 | super().__init__() 102 | 103 | self.bias = nn.Parameter(torch.zeros(channel)) 104 | self.negative_slope = negative_slope 105 | self.scale = scale 106 | 107 | def forward(self, input): 108 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 109 | 110 | 111 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 112 | if input.device.type == "cpu": 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 123 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5
| 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | */ 24 | 25 | 26 | #include <torch/extension.h> 27 | 28 | 29 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 30 | int act, int grad, float alpha, float scale); 31 | 32 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 33 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 34 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 35 | 36 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 37 | int act, int grad, float alpha, float scale) { 38 | CHECK_CUDA(input); 39 | CHECK_CUDA(bias); 40 | 41 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 42 | } 43 | 44 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 45 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 46 | } 47 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include <torch/types.h> 8 | 9 | #include <ATen/ATen.h> 10 | #include <ATen/AccumulateType.h> 11 | #include <ATen/cuda/CUDAContext.h> 12 | #include <ATen/cuda/CUDAApplyUtils.cuh> 13 | 14 | #include <cuda.h> 15 | #include <cuda_runtime.h> 16 | 17 | 18 | template <typename scalar_t> 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ?
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>( 81 | y.data_ptr<scalar_t>(), 82 | x.data_ptr<scalar_t>(), 83 | b.data_ptr<scalar_t>(), 84 | ref.data_ptr<scalar_t>(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } 100 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE.
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/ada_op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | /* 2 | MIT License 3 | 4 | Copyright (c) 2019 Kim Seonghyeon 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | */ 24 | 25 | 26 | #include <torch/extension.h> 27 | 28 | 29 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 30 | int up_x, int up_y, int down_x, int down_y, 31 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 32 | 33 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 34 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 35 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 36 | 37 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 38 | int up_x, int up_y, int down_x, int down_y, 39 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 40 | CHECK_CUDA(input); 41 | CHECK_CUDA(kernel); 42 | 43 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 44 | } 45 | 46 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 47 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 48 | } 49 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/biggan_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 25 | """ 26 | 27 | 28 | import random 29 | 30 | from .sample import sample_latents 31 | 32 | import torch 33 | 34 | 35 | class ema(object): 36 | def __init__(self, source, target, decay=0.9999, start_itr=0): 37 | self.source = source 38 | self.target = target 39 | self.decay = decay 40 | # Optional parameter indicating what iteration to start the decay at 41 | self.start_itr = start_itr 42 | # Initialize target's params to be source's 43 | self.source_dict = self.source.state_dict() 44 | self.target_dict = self.target.state_dict() 45 | print('Initializing EMA parameters to be source parameters...') 46 | with torch.no_grad(): 47 | for key in self.source_dict: 48 | self.target_dict[key].data.copy_(self.source_dict[key].data) 49 | self.target_dict[key].requires_grad = False 50 | 51 | def update(self, itr=None): 52 | # If an iteration counter is provided and itr is less than the start itr, 53 | # peg the ema weights to the underlying weights.
54 | if itr is not None and itr < self.start_itr: 55 | decay = 0.0 56 | else: 57 | decay = self.decay 58 | with torch.no_grad(): 59 | for key in self.source_dict: 60 | self.target_dict[key].data.copy_(self.target_dict[key].data * decay + 61 | self.source_dict[key].data * (1 - decay)) 62 | 63 | 64 | class ema_DP_SyncBN(object): 65 | def __init__(self, source, target, decay=0.9999, start_itr=0): 66 | self.source = source 67 | self.target = target 68 | self.decay = decay 69 | self.start_itr = start_itr 70 | # Initialize target's params to be source's 71 | print('Initializing EMA parameters to be source parameters...') 72 | for key in self.source.state_dict(): 73 | self.target.state_dict()[key].data.copy_(self.source.state_dict()[key].data) 74 | self.target.state_dict()[key].requires_grad = False 75 | 76 | def update(self, itr=None): 77 | # If an iteration counter is provided and itr is less than the start itr, 78 | # peg the ema weights to the underlying weights. 79 | if itr is not None and itr < self.start_itr: 80 | decay = 0.0 81 | else: 82 | decay = self.decay 83 | 84 | for key in self.source.state_dict(): 85 | data = self.target.state_dict()[key].data * decay + \ 86 | self.source.state_dict()[key].detach().data * (1. - decay) 87 | self.target.state_dict()[key].data.copy_(data) 88 | 89 | 90 | def ortho(model, strength=1e-4, blacklist=[]): 91 | with torch.no_grad(): 92 | for param in model.parameters(): 93 | # Only apply this to parameters with at least 2 axes, and not in the blacklist 94 | if len(param.shape) < 2 or any([param is item for item in blacklist]): 95 | continue 96 | w = param.view(param.shape[0], -1) 97 | grad = (2 * torch.mm(torch.mm(w, w.t()) 98 | * (1. - torch.eye(w.shape[0], device=w.device)), w)) 99 | param.grad.data += strength * grad.view(param.shape) 100 | 101 |
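# A minimal usage sketch for the ema helper above (hedged: G, G_ema and
# train_step are hypothetical stand-ins, not names defined in this file):
#
#     import copy
#     G_ema = copy.deepcopy(G)
#     averager = ema(source=G, target=G_ema, decay=0.9999, start_itr=1000)
#     for itr in range(total_itrs):
#         train_step(G)         # update the live generator as usual
#         averager.update(itr)  # then fold its weights into the EMA copy
#
# Before start_itr the target simply tracks the source (decay 0.0); afterwards
# it keeps an exponential moving average of the weights, which is the copy
# typically used for evaluation and sampling.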
102 | # Interp function; expects x0 and x1 to be of shape (shape0, 1, rest_of_shape..) 103 | def interp(x0, x1, num_midpoints): 104 | lerp = torch.linspace(0, 1.0, num_midpoints + 2, device='cuda').to(x0.dtype) 105 | return ((x0 * (1 - lerp.view(1, -1, 1))) + (x1 * lerp.view(1, -1, 1))) 106 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/cr_diff_aug.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/cr_diff_aug.py 6 | 7 | 8 | import random 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | 13 | 14 | 15 | def CR_DiffAug(x, flip=True, translation=True): 16 | if flip: 17 | x = random_flip(x, 0.5) 18 | if translation: 19 | x = random_translation(x, 1/8) 20 | if flip or translation: 21 | x = x.contiguous() 22 | return x 23 | 24 | 25 | def random_flip(x, p): 26 | x_out = x.clone() 27 | n, c, h, w = x.shape[0], x.shape[1], x.shape[2], x.shape[3] 28 | flip_prob = torch.FloatTensor(n, 1).uniform_(0.0, 1.0) 29 | flip_mask = flip_prob < p 30 | flip_mask = flip_mask.type(torch.bool).view(n, 1, 1, 1).repeat(1, c, h, w).to(x.device) 31 | x_out[flip_mask] = torch.flip(x[flip_mask].view(-1, c, h, w), [3]).view(-1) 32 | return x_out 33 | 34 | 35 | def random_translation(x, ratio): 36 | max_t_x, max_t_y = int(x.shape[2]*ratio), int(x.shape[3]*ratio) 37 | t_x = torch.randint(-max_t_x, max_t_x + 1, size = [x.shape[0], 1, 1], device=x.device) 38 | t_y = torch.randint(-max_t_y, max_t_y + 1, size = [x.shape[0], 1, 1], device=x.device) 39 | 40 | grid_batch, grid_x, grid_y = torch.meshgrid( 41 | torch.arange(x.shape[0], dtype=torch.long, device=x.device), 42 | torch.arange(x.shape[2], dtype=torch.long, device=x.device), 43 | torch.arange(x.shape[3], dtype=torch.long, device=x.device), 44 | ) 45 | 46 | grid_x = (grid_x + t_x) + max_t_x 47 | grid_y = (grid_y + t_y) + max_t_y 48 | x_pad = F.pad(input=x, pad=[max_t_y, max_t_y, max_t_x, max_t_x], mode='reflect') 49 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 50 | return x 51 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/diff_aug.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | * Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | * Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 16 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 18 | DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 19 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 21 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 22 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 23 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 24 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 25 | """ 26 | 27 | 28 | import torch 29 | import torch.nn.functional as F 30 | 31 | 32 | 33 | ### Differentiable Augmentation for Data-Efficient GAN Training (https://arxiv.org/abs/2006.10738) 34 | ### Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han 35 | ### https://github.com/mit-han-lab/data-efficient-gans 36 | 37 | 38 | def DiffAugment(x, policy='', channels_first=True): 39 | if policy: 40 | if not channels_first: 41 | x = x.permute(0, 3, 1, 2) 42 | for p in policy.split(','): 43 | for f in AUGMENT_FNS[p]: 44 | x = f(x) 45 | if not channels_first: 46 | x = x.permute(0, 2, 3, 1) 47 | x = x.contiguous() 48 | return x 49 | 50 | 51 | def rand_brightness(x): 52 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5) 53 | return x 54 | 55 | 56 | def rand_saturation(x): 57 | x_mean = x.mean(dim=1, keepdim=True) 58 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean 59 | return x 60 | 61 | 62 | def rand_contrast(x): 63 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True) 64 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean 65 | return x 66 | 67 | 68 | def rand_translation(x, ratio=0.125): 69 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 70 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device) 71 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device) 72 | grid_batch, grid_x, grid_y = torch.meshgrid( 73 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 74 | torch.arange(x.size(2), dtype=torch.long, device=x.device), 75 | torch.arange(x.size(3), dtype=torch.long, device=x.device), 76 | ) 77 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1) 78 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1) 79 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0]) 80 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2) 81 | return x 82 | 83 | 84 | def rand_cutout(x, ratio=0.5): 85 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5) 86 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device) 87 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device) 88 | grid_batch, grid_x, grid_y = torch.meshgrid( 89 | torch.arange(x.size(0), dtype=torch.long, device=x.device), 90 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device), 91 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device), 92 | ) 93 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1) 94 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1) 95 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device) 96 | mask[grid_batch, grid_x, grid_y] = 0 97 | x = x * 
mask.unsqueeze(1) 98 | return x 99 | 100 | 101 | AUGMENT_FNS = { 102 | 'color': [rand_brightness, rand_saturation, rand_contrast], 103 | 'translation': [rand_translation], 104 | 'cutout': [rand_cutout], 105 | } 106 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/load_checkpoint.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/load_checkpoint.py 6 | 7 | 8 | import os 9 | 10 | import torch 11 | 12 | 13 | 14 | def load_checkpoint(model, optimizer, filename, metric=False, ema=False): 15 | start_step = 0 16 | if ema: 17 | checkpoint = torch.load(filename) 18 | model.load_state_dict(checkpoint['state_dict']) 19 | return model 20 | else: 21 | checkpoint = torch.load(filename) 22 | seed = checkpoint['seed'] 23 | run_name = checkpoint['run_name'] 24 | start_step = checkpoint['step'] 25 | model.load_state_dict(checkpoint['state_dict']) 26 | optimizer.load_state_dict(checkpoint['optimizer']) 27 | ada_p = checkpoint['ada_p'] 28 | for state in optimizer.state.values(): 29 | for k, v in state.items(): 30 | if isinstance(v, torch.Tensor): 31 | state[k] = v.cuda() 32 | 33 | if metric: 34 | best_step = checkpoint['best_step'] 35 | best_fid = checkpoint['best_fid'] 36 | best_fid_checkpoint_path = checkpoint['best_fid_checkpoint_path'] 37 | return model, optimizer, seed, run_name, start_step, ada_p, best_step, best_fid, best_fid_checkpoint_path 38 | return model, optimizer, seed, run_name, start_step, ada_p 39 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/log.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/log.py 6 | 7 | 8 | import json 9 | import os 10 | import logging 11 | from os.path import dirname, abspath, exists, join 12 | from datetime import datetime 13 | 14 | 15 | 16 | def make_run_name(format, framework, phase): 17 | return format.format( 18 | framework=framework, 19 | phase=phase, 20 | timestamp=datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 21 | ) 22 | 23 | 24 | def make_logger(run_name, log_output): 25 | if log_output is not None: 26 | run_name = log_output.split('/')[-1].split('.')[0] 27 | logger = logging.getLogger(run_name) 28 | logger.propagate = False 29 | log_filepath = log_output if log_output is not None else join('logs', f'{run_name}.log') 30 | 31 | log_dir = dirname(abspath(log_filepath)) 32 | if not exists(log_dir): 33 | os.makedirs(log_dir) 34 | 35 | if not logger.handlers: # execute only if logger doesn't already exist 36 | file_handler = logging.FileHandler(log_filepath, 'a', 'utf-8') 37 | stream_handler = logging.StreamHandler(os.sys.stdout) 38 | 39 | formatter = logging.Formatter('[%(levelname)s] %(asctime)s > %(message)s', datefmt='%Y-%m-%d %H:%M:%S') 40 | 41 | file_handler.setFormatter(formatter) 42 | stream_handler.setFormatter(formatter) 43 | 44 | logger.addHandler(file_handler) 45 | logger.addHandler(stream_handler) 46 | logger.setLevel(logging.INFO) 47 | return logger 48 | 49 | 50 | def make_checkpoint_dir(checkpoint_dir, 
run_name): 51 | checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else join('checkpoints', run_name) 52 | if not exists(abspath(checkpoint_dir)): 53 | os.makedirs(checkpoint_dir) 54 | return checkpoint_dir 55 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/make_hdf5.py: -------------------------------------------------------------------------------- 1 | """ 2 | this code is borrowed from https://github.com/ajbrock/BigGAN-PyTorch 3 | 4 | MIT License 5 | 6 | Copyright (c) 2019 Andy Brock 7 | Permission is hereby granted, free of charge, to any person obtaining a copy 8 | of this software and associated documentation files (the "Software"), to deal 9 | in the Software without restriction, including without limitation the rights 10 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | copies of the Software, and to permit persons to whom the Software is 12 | furnished to do so, subject to the following conditions: 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | """ 23 | 24 | 25 | import os 26 | import sys 27 | import h5py as h5 28 | import numpy as np 29 | import PIL 30 | from argparse import ArgumentParser 31 | from tqdm import tqdm, trange 32 | 33 | from data_utils.load_dataset import LoadDataset 34 | 35 | import torch 36 | import torchvision.transforms as transforms 37 | from torch.utils.data import DataLoader 38 | 39 | 40 | 41 | def make_hdf5(model_config, train_config, mode): 42 | if 'hdf5' in model_config['dataset_name']: 43 | raise ValueError('Reading from an HDF5 file which you will probably be ' 44 | 'about to overwrite! Override this error only if you know ' 45 | 'what you''re doing!') 46 | 47 | file_name = '{dataset_name}_{size}_{mode}.hdf5'.format(dataset_name=model_config['dataset_name'], size=model_config['img_size'], mode=mode) 48 | file_path = os.path.join(model_config['data_path'], file_name) 49 | train = True if mode == "train" else False 50 | 51 | if os.path.isfile(file_path): 52 | print("{file_name} exist!\nThe file are located in the {file_path}".format(file_name=file_name, file_path=file_path)) 53 | else: 54 | dataset = LoadDataset(model_config['dataset_name'], model_config['data_path'], train=train, download=True, resize_size=model_config['img_size'], 55 | hdf5_path=None, random_flip=False) 56 | 57 | loader = DataLoader(dataset, 58 | batch_size=model_config['batch_size4prcsing'], 59 | shuffle=False, 60 | pin_memory=False, 61 | num_workers=train_config['num_workers'], 62 | drop_last=False) 63 | 64 | print('Starting to load %s into an HDF5 file with chunk size %i and compression %s...' 
% (model_config['dataset_name'], 65 | model_config['chunk_size'], 66 | model_config['compression'])) 67 | # Loop over loader 68 | for i,(x,y) in enumerate(tqdm(loader)): 69 | # Numpyify x, y 70 | x = (255 * ((x + 1) / 2.0)).byte().numpy() 71 | y = y.numpy() 72 | # If we're on the first batch, prepare the hdf5 73 | if i==0: 74 | with h5.File(file_path, 'w') as f: 75 | print('Producing dataset of len %d' % len(loader.dataset)) 76 | imgs_dset = f.create_dataset('imgs', x.shape, dtype='uint8', maxshape=(len(loader.dataset), 3, 77 | model_config['img_size'], model_config['img_size']), 78 | chunks=(model_config['chunk_size'], 3, model_config['img_size'], model_config['img_size']), compression=model_config['compression']) 79 | print('Image chunks chosen as ' + str(imgs_dset.chunks)) 80 | imgs_dset[...] = x 81 | 82 | labels_dset = f.create_dataset('labels', y.shape, dtype='int64', maxshape=(len(loader.dataset),), 83 | chunks=(model_config['chunk_size'],), compression=model_config['compression']) 84 | print('Label chunks chosen as ' + str(labels_dset.chunks)) 85 | labels_dset[...] = y 86 | # Else append to the hdf5 87 | else: 88 | with h5.File(file_path, 'a') as f: 89 | f['imgs'].resize(f['imgs'].shape[0] + x.shape[0], axis=0) 90 | f['imgs'][-x.shape[0]:] = x 91 | f['labels'].resize(f['labels'].shape[0] + y.shape[0], axis=0) 92 | f['labels'][-y.shape[0]:] = y 93 | return file_path 94 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/StudioGAN/utils/sample.py: -------------------------------------------------------------------------------- 1 | # PyTorch StudioGAN: https://github.com/POSTECH-CVLab/PyTorch-StudioGAN 2 | # The MIT License (MIT) 3 | # See license file or visit https://github.com/POSTECH-CVLab/PyTorch-StudioGAN for details 4 | 5 | # src/utils/sample.py 6 | 7 | 8 | import numpy as np 9 | import random 10 | from numpy import linalg 11 | from math import sin, cos, sqrt 12 | 13 | from .losses import latent_optimise 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.nn import DataParallel 18 | 19 | 20 | def sample_latents(dist, batch_size, dim, truncated_factor=1, num_classes=None, perturb=None, device=torch.device("cpu"), sampler="default"): 21 | if num_classes: 22 | if sampler == "default": 23 | y_fake = torch.randint(low=0, high=num_classes, size=(batch_size,), dtype=torch.long, device=device) 24 | elif sampler == "class_order_some": 25 | assert batch_size % 8 == 0, "The size of the batches should be a multiple of 8." 
26 | num_classes_plot = batch_size // 8 27 | indices = np.random.permutation(num_classes)[:num_classes_plot] 28 | elif sampler == "class_order_all": 29 | batch_size = num_classes * 8 30 | indices = [c for c in range(num_classes)] 31 | elif isinstance(sampler, int): 32 | y_fake = torch.tensor([sampler] * batch_size, dtype=torch.long).to(device) 33 | else: 34 | raise NotImplementedError 35 | 36 | if sampler in ["class_order_some", "class_order_all"]: 37 | y_fake = [] 38 | for idx in indices: 39 | y_fake += [idx] * 8 40 | y_fake = torch.tensor(y_fake, dtype=torch.long).to(device) 41 | else: 42 | y_fake = None 43 | 44 | if isinstance(perturb, float) and perturb > 0.0: 45 | if dist == "gaussian": 46 | latents = torch.randn(batch_size, dim, device=device) / truncated_factor 47 | eps = perturb * torch.randn(batch_size, dim, device=device) 48 | latents_eps = latents + eps 49 | elif dist == "uniform": 50 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 51 | eps = perturb * torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 52 | latents_eps = latents + eps 53 | elif dist == "hyper_sphere": 54 | latents, latents_eps = random_ball(batch_size, dim, perturb=perturb) 55 | latents, latents_eps = torch.FloatTensor(latents).to(device), torch.FloatTensor(latents_eps).to(device) 56 | return latents, y_fake, latents_eps 57 | else: 58 | if dist == "gaussian": 59 | latents = torch.randn(batch_size, dim, device=device) / truncated_factor 60 | elif dist == "uniform": 61 | latents = torch.FloatTensor(batch_size, dim).uniform_(-1.0, 1.0).to(device) 62 | elif dist == "hyper_sphere": 63 | latents = torch.FloatTensor(random_ball(batch_size, dim, perturb=perturb)).to(device)  # random_ball returns a NumPy array 64 | return latents, y_fake 65 | 66 | 67 | def random_ball(batch_size, z_dim, perturb=False): 68 | if perturb: 69 | normal = np.random.normal(size=(z_dim, batch_size)) 70 | random_directions = normal / linalg.norm(normal, axis=0) 71 | random_radii = np.random.random(batch_size) ** (1 / z_dim)  # np.random: the stdlib random.random() takes no size argument 72 | zs = 1.0 * (random_directions * random_radii).T 73 | 74 | normal_perturb = normal + 0.05 * np.random.normal(size=(z_dim, batch_size)) 75 | perturb_random_directions = normal_perturb / linalg.norm(normal_perturb, axis=0) 76 | perturb_random_radii = np.random.random(batch_size) ** (1 / z_dim) 77 | zs_perturb = 1.0 * (perturb_random_directions * perturb_random_radii).T 78 | return zs, zs_perturb 79 | else: 80 | normal = np.random.normal(size=(z_dim, batch_size)) 81 | random_directions = normal / linalg.norm(normal, axis=0) 82 | random_radii = np.random.random(batch_size) ** (1 / z_dim) 83 | zs = 1.0 * (random_directions * random_radii).T 84 | return zs 85 | 86 | 87 | # Convenience function to sample an index, not actually a 1-hot 88 | def sample_1hot(batch_size, num_classes, device='cuda'): 89 | return torch.randint(low=0, high=num_classes, size=(batch_size,), 90 | device=device, dtype=torch.int64, requires_grad=False) 91 | 92 | 93 | def make_mask(labels, n_cls, device): 94 | labels = labels.detach().cpu().numpy() 95 | n_samples = labels.shape[0] 96 | mask_multi = np.zeros([n_cls, n_samples]) 97 | for c in range(n_cls): 98 | c_indices = np.where(labels == c) 99 | mask_multi[c, c_indices] = +1 100 | 101 | mask_multi = torch.tensor(mask_multi).type(torch.long) 102 | return mask_multi.to(device) 103 | 104 | 105 | def target_class_sampler(dataset, target_class): 106 | try: 107 | targets = dataset.data.targets 108 | except AttributeError:  # datasets without .data.targets expose .labels instead 109 | targets = dataset.labels 110 | weights = [True if target == target_class else False for target in targets] 111 |
num_samples = sum(weights) 112 | weights = torch.DoubleTensor(weights) 113 | sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights), replacement=False) 114 | return num_samples, sampler 115 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/__init__.py: -------------------------------------------------------------------------------- 1 | from .BigGAN import make_biggan 2 | from .BigBiGAN import make_bigbigan 3 | from .self_conditioned import make_selfcond_gan 4 | from .stylegan2_ada_pytorch import make_stylegan2 5 | from .StudioGAN import make_studiogan 6 | from .CIPS import make_cips 7 | 8 | 9 | def make_gan(*, gan_type, **kwargs): 10 | t = gan_type.lower() 11 | if t == 'bigbigan': 12 | G = make_bigbigan(**kwargs) 13 | elif t == 'selfconditionedgan': 14 | G = make_selfcond_gan(**kwargs) 15 | elif t == 'studiogan': 16 | G = make_studiogan(**kwargs) 17 | elif t == 'stylegan2': 18 | G = make_stylegan2(**kwargs) 19 | elif t == 'cips': 20 | G = make_cips(**kwargs) 21 | elif t == 'biggan': 22 | G = make_biggan(**kwargs) 23 | else: 24 | raise NotImplementedError(f'Unrecognized GAN type: {gan_type}') 25 | return G 26 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import model_zoo 3 | from .gan_training.models import generator_dict 4 | 5 | # Config, adapted from: 6 | # - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/default.yaml 7 | # - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/unconditional.yaml 8 | # - https://github.com/stevliu/self-conditioned-gan/blob/master/configs/imagenet/selfcondgan.yaml 9 | configs = { 10 | 'unconditional': { 11 | 'generator': { 12 | 'name': 'resnet2', 13 | 'kwargs': {}, 14 | # Unconditional 15 | 'nlabels': 1, 16 | 'conditioning': 'unconditional', 17 | }, 18 | 'z_dist': { 19 | 'dim': 256 20 | }, 21 | 'data': { 22 | 'img_size': 128 23 | }, 24 | 'pretrained': { 25 | 'model': 'http://selfcondgan.csail.mit.edu/weights/uncondgan_i_model.pt' 26 | } 27 | }, 28 | 'self_conditioned': { 29 | 'generator': { 30 | 'name': 'resnet2', 31 | 'kwargs': {}, 32 | # Self-conditional 33 | 'nlabels': 100, 34 | 'conditioning': 'embedding', 35 | }, 36 | 'z_dist': { 37 | 'dim': 256 38 | }, 39 | 'data': { 40 | 'img_size': 128 41 | }, 42 | 'pretrained': { 43 | 'model': 'http://selfcondgan.csail.mit.edu/weights/selfcondgan_i_model.pt' 44 | } 45 | } 46 | } 47 | 48 | 49 | class GeneratorWrapper(torch.nn.Module): 50 | """ A wrapper to put the GAN in a standard format and add metadata (dim_z) """ 51 | 52 | def __init__(self, generator, dim_z, nlabels): 53 | super().__init__() 54 | self.G = generator 55 | self.dim_z = dim_z 56 | self.conditional = True 57 | self.num_classes = nlabels 58 | 59 | def forward(self, z, y=None, return_y=False): 60 | if y is None: 61 | y = self.sample_class(batch_size=z.shape[0], device=z.device) 62 | else: 63 | y = y.to(z.device) 64 | x = self.G(z, y) 65 | return (x, y) if return_y else x 66 | 67 | def sample_latent(self, batch_size, device='cpu'): 68 | z = torch.randn(size=(batch_size, self.dim_z), device=device) 69 | return z 70 | 71 | def sample_class(self, batch_size=None, device='cpu'): 72 | y = torch.randint(low=0, high=self.num_classes, size=(batch_size,), device=device) 73 | return y 74 | 75 | 76 | def 
make_selfcond_gan(model_name='self_conditioned'): 77 | """ A helper function for loading a (pretrained) GAN """ 78 | 79 | # Get generator configuration 80 | assert model_name in {'self_conditioned', 'unconditional'} 81 | config = configs[model_name] 82 | 83 | # Create GAN 84 | Generator = generator_dict[config['generator']['name']] 85 | generator = Generator( 86 | z_dim=config['z_dist']['dim'], 87 | nlabels=config['generator']['nlabels'], 88 | size=config['data']['img_size'], 89 | conditioning=config['generator']['conditioning'], 90 | **config['generator']['kwargs'] 91 | ) 92 | 93 | # Load checkpoint 94 | checkpoint = model_zoo.load_url(config['pretrained']['model'], map_location='cpu') 95 | generator.load_state_dict(checkpoint['generator']) 96 | print(f"Loaded pretrained GAN weights (iteration: {checkpoint['it']})") 97 | 98 | # Wrap GAN 99 | G = GeneratorWrapper( 100 | generator=generator, 101 | dim_z=config['z_dist']['dim'], 102 | nlabels=config['generator']['nlabels'] 103 | ).eval() 104 | 105 | return G 106 | 107 | 108 | if __name__ == "__main__": 109 | 110 | # Load model 111 | G = make_selfcond_gan('self_conditioned') 112 | print(f'Parameters: {sum(p.numel() for p in G.parameters()) / 10**6} million') 113 | 114 | # Example usage 115 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 116 | G.to(device) 117 | with torch.no_grad(): 118 | z = torch.randn(7, G.dim_z, requires_grad=False, device=device) 119 | x = G(z) 120 | print(f'Input shape: {z.shape}') 121 | print(f'Output shape: {x.shape}') 122 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lukemelas/pytorch-pretrained-gans/2982fdab4e683165e45bc2f4a64c2942a7a3a1b7/pytorch_pretrained_gans/self_conditioned/gan_training/__init__.py -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torch import optim 3 | from os import path 4 | from gan_training.models import generator_dict, discriminator_dict 5 | from gan_training.train import toggle_grad 6 | from clusterers import clusterer_dict 7 | 8 | 9 | # General config 10 | def load_config(path, default_path): 11 | ''' Loads config file. 12 | 13 | Args: 14 | path (str): path to config file 15 | default_path (str): path to the default config file 16 | ''' 17 | # Load configuration from file itself 18 | with open(path, 'r') as f: 19 | cfg_special = yaml.load(f, Loader=yaml.FullLoader) 20 | 21 | # Check if we should inherit from a config 22 | inherit_from = cfg_special.get('inherit_from') 23 | 24 | # If yes, load this config first as default 25 | # If no, use the default_path 26 | if inherit_from is not None: 27 | cfg = load_config(inherit_from, default_path) 28 | elif default_path is not None: 29 | with open(default_path, 'r') as f: 30 | cfg = yaml.load(f, Loader=yaml.FullLoader) 31 | else: 32 | cfg = dict() 33 | 34 | # Include main configuration 35 | update_recursive(cfg, cfg_special) 36 | 37 | return cfg 38 | 39 |
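# A hedged sketch of the config inheritance that load_config implements.
# The YAML below is illustrative (hypothetical file names, not files shipped
# in this repo); keys in the child file recursively override inherited ones:
#
#     # child.yaml
#     inherit_from: configs/imagenet/default.yaml
#     generator:
#         nlabels: 100
#
#     cfg = load_config('child.yaml', default_path=None)
#
# The file named by inherit_from (or default_path, when no inherit_from key
# is present) is loaded first, then updated in place with the child's entries
# by update_recursive below.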
40 | def update_recursive(dict1, dict2): 41 | ''' Update two config dictionaries recursively. 42 | 43 | Args: 44 | dict1 (dict): first dictionary to be updated 45 | dict2 (dict): second dictionary whose entries should be used 46 | 47 | ''' 48 | for k, v in dict2.items(): 49 | # Add item if not yet in dict1 50 | if k not in dict1: 51 | dict1[k] = None 52 | # Update 53 | if isinstance(dict1[k], dict): 54 | update_recursive(dict1[k], v) 55 | else: 56 | dict1[k] = v 57 | 58 | 59 | def get_clusterer(config): 60 | return clusterer_dict[config['clusterer']['name']] 61 | 62 | 63 | def build_models(config): 64 | # Get classes 65 | Generator = generator_dict[config['generator']['name']] 66 | Discriminator = discriminator_dict[config['discriminator']['name']] 67 | 68 | # Build models 69 | generator = Generator(z_dim=config['z_dist']['dim'], 70 | nlabels=config['generator']['nlabels'], 71 | size=config['data']['img_size'], 72 | conditioning=config['generator']['conditioning'], 73 | **config['generator']['kwargs']) 74 | discriminator = Discriminator( 75 | nlabels=config['discriminator']['nlabels'], 76 | conditioning=config['discriminator']['conditioning'], 77 | size=config['data']['img_size'], 78 | **config['discriminator']['kwargs']) 79 | 80 | return generator, discriminator 81 | 82 | 83 | def build_optimizers(generator, discriminator, config): 84 | optimizer = config['training']['optimizer'] 85 | lr_g = config['training']['lr_g'] 86 | lr_d = config['training']['lr_d'] 87 | 88 | 89 | toggle_grad(generator, True) 90 | toggle_grad(discriminator, True) 91 | 92 | g_params = generator.parameters() 93 | d_params = discriminator.parameters() 94 | 95 | if optimizer == 'rmsprop': 96 | g_optimizer = optim.RMSprop(g_params, lr=lr_g, alpha=0.99, eps=1e-8) 97 | d_optimizer = optim.RMSprop(d_params, lr=lr_d, alpha=0.99, eps=1e-8) 98 | elif optimizer == 'adam': 99 | beta1 = config['training']['beta1'] 100 | beta2 = config['training']['beta2'] 101 | g_optimizer = optim.Adam(g_params, lr=lr_g, betas=(beta1, beta2), eps=1e-8) 102 | d_optimizer = optim.Adam(d_params, lr=lr_d, betas=(beta1, beta2), eps=1e-8) 103 | elif optimizer == 'sgd': 104 | g_optimizer = optim.SGD(g_params, lr=lr_g, momentum=0.) 105 | d_optimizer = optim.SGD(d_params, lr=lr_d, momentum=0.) 106 | else: raise NotImplementedError(f'Unsupported optimizer: {optimizer}')  # avoid a confusing NameError at the return below 107 | return g_optimizer, d_optimizer 108 | 109 |
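# A hedged sketch of the config['training'] block that build_optimizers
# expects (values are illustrative, not taken from a shipped config file):
#
#     config = {'training': {'optimizer': 'adam',
#                            'lr_g': 1e-4, 'lr_d': 4e-4,
#                            'beta1': 0.0, 'beta2': 0.99}}
#     g_optimizer, d_optimizer = build_optimizers(generator, discriminator, config)
#
# beta1/beta2 are only read for the 'adam' case.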
110 | # Some utility functions 111 | def get_parameter_groups(parameters, gradient_scales, base_lr): 112 | param_groups = [] 113 | for p in parameters: 114 | c = gradient_scales.get(p, 1.) 115 | param_groups.append({'params': [p], 'lr': c * base_lr}) 116 | return param_groups 117 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import distributions 3 | 4 | 5 | def get_zdist(dist_name, dim, device=None): 6 | # Get distribution 7 | if dist_name == 'uniform': 8 | low = -torch.ones(dim, device=device) 9 | high = torch.ones(dim, device=device) 10 | zdist = distributions.Uniform(low, high) 11 | elif dist_name == 'gauss': 12 | mu = torch.zeros(dim, device=device) 13 | scale = torch.ones(dim, device=device) 14 | zdist = distributions.Normal(mu, scale) 15 | else: 16 | raise NotImplementedError 17 | 18 | # Add dim attribute 19 | zdist.dim = dim 20 | 21 | return zdist 22 | 23 | 24 | def get_ydist(nlabels, device=None): 25 | logits = torch.zeros(nlabels, device=device) 26 | ydist = distributions.categorical.Categorical(logits=logits) 27 | 28 | # Add nlabels attribute 29 | ydist.nlabels = nlabels 30 | 31 | return ydist 32 | 33 | 34 | def interpolate_sphere(z1, z2, t): 35 | p = (z1 * z2).sum(dim=-1, keepdim=True) 36 | p = p / z1.pow(2).sum(dim=-1, keepdim=True).sqrt() 37 | p = p / z2.pow(2).sum(dim=-1, keepdim=True).sqrt() 38 | omega = torch.acos(p) 39 | s1 = torch.sin((1-t)*omega)/torch.sin(omega) 40 | s2 = torch.sin(t*omega)/torch.sin(omega) 41 | z = s1 * z1 + s2 * z2 42 | 43 | return z 44 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | from gan_training.metrics import inception_score 6 | 7 | class Evaluator(object): 8 | def __init__(self, 9 | generator, 10 | zdist, 11 | ydist, 12 | train_loader, 13 | clusterer, 14 | batch_size=64, 15 | inception_nsamples=10000, 16 | device=None): 17 | self.generator = generator 18 | self.clusterer = clusterer 19 | self.train_loader = train_loader 20 | self.zdist = zdist 21 | self.ydist = ydist 22 | self.inception_nsamples = inception_nsamples 23 | self.batch_size = batch_size 24 | self.device = device 25 | 26 | def sample_z(self, batch_size): 27 | return self.zdist.sample((batch_size, )).to(self.device) 28 | 29 | def get_y(self, x, y): 30 | return self.clusterer.get_labels(x, y).to(self.device) 31 | 32 | def get_fake_real_samples(self, N): 33 | ''' returns N fake images and N real images in pytorch form''' 34 | with torch.no_grad(): 35 | self.generator.eval() 36 | fake_imgs = [] 37 | real_imgs = [] 38 | while len(fake_imgs) < N: 39 | for x_real, y_gt in self.train_loader: 40 | x_real = x_real.cuda() 41 | z = self.sample_z(x_real.size(0)) 42 | y = self.get_y(x_real, y_gt) 43 | samples = self.generator(z, y) 44 | samples = [s.data.cpu() for s in samples] 45 | fake_imgs.extend(samples) 46 | real_batch = [img.data.cpu() for img in x_real] 47 | real_imgs.extend(real_batch) 48 | assert (len(real_imgs) == len(fake_imgs)) 49 | if len(fake_imgs) >= N: 50 | fake_imgs = fake_imgs[:N] 51 | real_imgs = real_imgs[:N] 52 | return fake_imgs, real_imgs 53 | 54 | def compute_inception_score(self): 55 | imgs, _ = self.get_fake_real_samples(self.inception_nsamples) 56 | imgs = [img.numpy() for img in imgs] 57 | score, score_std = inception_score(imgs, 58 | device=self.device, 59 |
resize=True, 60 | splits=1) 61 | 62 | return score, score_std 63 | 64 | def create_samples(self, z, y=None): 65 | self.generator.eval() 66 | batch_size = z.size(0) 67 | # Parse y 68 | if y is None: 69 | raise NotImplementedError() 70 | elif isinstance(y, int): 71 | y = torch.full((batch_size, ), 72 | y, 73 | device=self.device, 74 | dtype=torch.int64) 75 | # Sample x 76 | with torch.no_grad(): 77 | x = self.generator(z, y) 78 | return x 79 | 80 | 81 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/logger.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import torchvision 4 | import copy 5 | 6 | 7 | class Logger(object): 8 | def __init__(self, 9 | log_dir='./logs', 10 | img_dir='./imgs', 11 | monitoring=None, 12 | monitoring_dir=None): 13 | self.stats = dict() 14 | self.log_dir = log_dir 15 | self.img_dir = img_dir 16 | 17 | if not os.path.exists(log_dir): 18 | os.makedirs(log_dir) 19 | 20 | if not os.path.exists(img_dir): 21 | os.makedirs(img_dir) 22 | 23 | if not (monitoring is None or monitoring == 'none'): 24 | self.setup_monitoring(monitoring, monitoring_dir) 25 | else: 26 | self.monitoring = None 27 | self.monitoring_dir = None 28 | 29 | def setup_monitoring(self, monitoring, monitoring_dir=None): 30 | self.monitoring = monitoring 31 | self.monitoring_dir = monitoring_dir 32 | 33 | if monitoring == 'telemetry': 34 | import telemetry 35 | self.tm = telemetry.ApplicationTelemetry() 36 | if self.tm.get_status() == 0: 37 | print('Telemetry successfully connected.') 38 | elif monitoring == 'tensorboard': 39 | import tensorboardX 40 | self.tb = tensorboardX.SummaryWriter(monitoring_dir) 41 | else: 42 | raise NotImplementedError('Monitoring tool "%s" not supported!' % 43 | monitoring) 44 | 45 | def add(self, category, k, v, it): 46 | if category not in self.stats: 47 | self.stats[category] = {} 48 | 49 | if k not in self.stats[category]: 50 | self.stats[category][k] = [] 51 | 52 | self.stats[category][k].append((it, v)) 53 | 54 | k_name = '%s/%s' % (category, k) 55 | if self.monitoring == 'telemetry': 56 | self.tm.metric_push_async({'metric': k_name, 'value': v, 'it': it}) 57 | elif self.monitoring == 'tensorboard': 58 | self.tb.add_scalar(k_name, v, it) 59 | 60 | def add_imgs(self, imgs, class_name, it): 61 | outdir = os.path.join(self.img_dir, class_name) 62 | if not os.path.exists(outdir): 63 | os.makedirs(outdir) 64 | outfile = os.path.join(outdir, '%08d.png' % it) 65 | 66 | imgs = imgs / 2 + 0.5 67 | imgs = torchvision.utils.make_grid(imgs) 68 | torchvision.utils.save_image(copy.deepcopy(imgs), outfile, nrow=8) 69 | 70 | if self.monitoring == 'tensorboard': 71 | self.tb.add_image(class_name, copy.deepcopy(imgs), it) 72 | 73 | def get_last(self, category, k, default=0.): 74 | if category not in self.stats: 75 | return default 76 | elif k not in self.stats[category]: 77 | return default 78 | else: 79 | return self.stats[category][k][-1][1] 80 | 81 | def save_stats(self, filename): 82 | filename = os.path.join(self.log_dir, filename) 83 | with open(filename, 'wb') as f: 84 | pickle.dump(self.stats, f) 85 | 86 | def load_stats(self, filename): 87 | filename = os.path.join(self.log_dir, filename) 88 | if not os.path.exists(filename): 89 | print('Warning: file "%s" does not exist!' 
% filename) 90 | return 91 | 92 | try: 93 | with open(filename, 'rb') as f: 94 | self.stats = pickle.load(f) 95 | except EOFError: 96 | print('Warning: log file corrupted!') 97 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from gan_training.metrics.inception_score import inception_score 2 | 3 | __all__ = [ 4 | 'inception_score' 5 | ] 6 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/metrics/clustering_metrics.py: -------------------------------------------------------------------------------- 1 | def warn(*args, **kwargs): 2 | pass 3 | 4 | 5 | import warnings 6 | warnings.warn = warn 7 | 8 | from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_rand_score, homogeneity_score 9 | from sklearn import metrics 10 | 11 | import numpy as np 12 | 13 | 14 | def nmi(inferred, gt): 15 | return normalized_mutual_info_score(inferred, gt) 16 | 17 | 18 | def acc(inferred, gt): 19 | gt = gt.astype(np.int64) 20 | assert inferred.size == gt.size 21 | D = max(inferred.max(), gt.max()) + 1 22 | w = np.zeros((D, D), dtype=np.int64) 23 | for i in range(inferred.size): 24 | w[inferred[i], gt[i]] += 1 25 | from scipy.optimize import linear_sum_assignment  # sklearn.utils.linear_assignment_ was removed from modern scikit-learn 26 | ind = zip(*linear_sum_assignment(w.max() - w)) 27 | return sum([w[i, j] for i, j in ind]) * 1.0 / inferred.size 28 | 29 | 30 | def purity_score(y_true, y_pred): 31 | contingency_matrix = metrics.cluster.contingency_matrix(y_true, y_pred) 32 | return np.sum(np.amax(contingency_matrix, 33 | axis=0)) / np.sum(contingency_matrix) 34 | 35 | 36 | def ari(inferred, gt): 37 | return adjusted_rand_score(gt, inferred) 38 | 39 | 40 | def homogeneity(inferred, gt): 41 | return homogeneity_score(gt, inferred) 42 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | 6 | from torchvision.models.inception import inception_v3 7 | 8 | import numpy as np 9 | from scipy.stats import entropy 10 | 11 | 12 | def inception_score(imgs, device=None, batch_size=32, resize=False, splits=1): 13 | """Computes the inception score of the generated images imgs 14 | 15 | Args: 16 | imgs: Torch dataset of (3xHxW) numpy images normalized in the 17 | range [-1, 1] 18 | device: torch device on which to run Inception v3 19 | batch_size: batch size for feeding into Inception v3 20 | splits: number of splits 21 | """ 22 | N = len(imgs) 23 | 24 | assert batch_size > 0 25 | assert N > batch_size 26 | 27 | # Set up dataloader 28 | dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size) 29 | 30 | # Load inception model 31 | inception_model = inception_v3(pretrained=True, transform_input=False) 32 | inception_model = inception_model.to(device) 33 | inception_model.eval() 34 | up = nn.Upsample(size=(299, 299), mode='bilinear').to(device) 35 | 36 | def get_pred(x): 37 | with torch.no_grad(): 38 | if resize: 39 | x = up(x) 40 | x = inception_model(x) 41 | out = F.softmax(x, dim=-1) 42 | out = out.cpu().numpy() 43 | return out 44 | 45 | # Get predictions 46 | preds = np.zeros((N, 1000)) 47 | 48 | for
i, batch in enumerate(dataloader, 0): 49 | batchv = batch.to(device) 50 | batch_size_i = batch.size()[0] 51 | 52 | preds[i * batch_size:i * batch_size + batch_size_i] = get_pred(batchv) 53 | 54 | # Now compute the mean kl-div 55 | split_scores = [] 56 | 57 | for k in range(splits): 58 | part = preds[k * (N // splits):(k + 1) * (N // splits), :] 59 | py = np.mean(part, axis=0) 60 | scores = [] 61 | for i in range(part.shape[0]): 62 | pyx = part[i, :] 63 | scores.append(entropy(pyx, py)) 64 | split_scores.append(np.exp(np.mean(scores))) 65 | 66 | return np.mean(split_scores), np.std(split_scores) 67 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/README.md: -------------------------------------------------------------------------------- 1 | Inception Score 2 | ===================================== 3 | 4 | A new Tensorflow implementation of the "Inception Score" (IS) for the evaluation of generative models, with a bug raised in [https://github.com/openai/improved-gan/issues/29](https://github.com/openai/improved-gan/issues/29) fixed. 5 | 6 | ## Major Dependency 7 | - `tensorflow >= 1.14` 8 | 9 | ## Features 10 | - Fast, easy-to-use and memory-efficient, written in a way that is similar to the original implementation 11 | - No prior knowledge about Tensorflow is necessary if your are using CPU or GPU 12 | - Makes use of [TF-GAN](https://github.com/tensorflow/gan) 13 | - Downloads InceptionV1 automatically 14 | - Compatible with both Python 2 and Python 3 15 | 16 | ## Usage 17 | - If you are working with GPU, use `inception_score.py`; if you are working with TPU, use `inception_score_tpu.py` and pass a Tensorflow Session and a [TPUStrategy](https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy) as additional arguments. 18 | - Call `get_inception_score(images, splits=10)`, where `images` is a numpy array with values ranging from 0 to 255 and shape in the form `[N, 3, HEIGHT, WIDTH]` where `N`, `HEIGHT` and `WIDTH` can be arbitrary. `dtype` of the images is recommended to be `np.uint8` to save CPU memory. 19 | - A smaller `BATCH_SIZE` reduces GPU/TPU memory usage, but at the cost of a slight slowdown. 20 | - If you want to compute a general "Classifier Score" with probabilities `preds` from another classifier, call `preds2score(preds, splits=10)`. `preds` can be a numpy array of arbitrary shape `[N, num_classes]`. 21 | ## Links 22 | - The Inception Score was proposed in the paper [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498) 23 | - Code for the [Fréchet Inception Distance](https://github.com/tsc2017/Frechet-Inception-Distance) 24 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/metrics/tf_is/inception_score.py: -------------------------------------------------------------------------------- 1 | ''' 2 | From https://github.com/tsc2017/Inception-Score 3 | Code derived from https://github.com/openai/improved-gan/blob/master/inception_score/model.py and https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py 4 | 5 | Usage: 6 | Call get_inception_score(images, splits=10) 7 | Args: 8 | images: A numpy array with values ranging from 0 to 255 and shape in the form [N, 3, HEIGHT, WIDTH] where N, HEIGHT and WIDTH can be arbitrary. A dtype of np.uint8 is recommended to save CPU memory. 
9 | splits: The number of splits of the images, default is 10. 10 | Returns: 11 | Mean and standard deviation of the Inception Score across the splits. 12 | ''' 13 | 14 | import os 15 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 16 | import tensorflow as tf 17 | import functools 18 | import numpy as np 19 | import time 20 | from tqdm import tqdm 21 | from tensorflow.python.ops import array_ops 22 | tfgan = tf.contrib.gan 23 | 24 | session=tf.compat.v1.InteractiveSession() 25 | 26 | # A smaller BATCH_SIZE reduces GPU memory usage, but at the cost of a slight slowdown 27 | BATCH_SIZE = 64 28 | INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' 29 | INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' 30 | 31 | # Run images through Inception. 32 | inception_images = tf.compat.v1.placeholder(tf.float32, [None, 3, None, None]) 33 | def inception_logits(images = inception_images, num_splits = 1): 34 | images = tf.transpose(images, [0, 2, 3, 1]) 35 | size = 299 36 | images = tf.compat.v1.image.resize_bilinear(images, [size, size]) 37 | generated_images_list = array_ops.split(images, num_or_size_splits = num_splits) 38 | logits = tf.map_fn( 39 | fn = functools.partial( 40 | tfgan.eval.run_inception, 41 | default_graph_def_fn = functools.partial( 42 | tfgan.eval.get_graph_def_from_url_tarball, 43 | INCEPTION_URL, 44 | INCEPTION_FROZEN_GRAPH, 45 | os.path.basename(INCEPTION_URL)), 46 | output_tensor = 'logits:0'), 47 | elems = array_ops.stack(generated_images_list), 48 | parallel_iterations = 8, 49 | back_prop = False, 50 | swap_memory = True, 51 | name = 'RunClassifier') 52 | logits = array_ops.concat(array_ops.unstack(logits), 0) 53 | return logits 54 | 55 | logits=inception_logits() 56 | 57 | def get_inception_probs(inps): 58 | n_batches = int(np.ceil(float(inps.shape[0]) / BATCH_SIZE)) 59 | preds = np.zeros([inps.shape[0], 1000], dtype = np.float32) 60 | for i in tqdm(range(n_batches)): 61 | inp = inps[i * BATCH_SIZE:(i + 1) * BATCH_SIZE] / 255. * 2 - 1 62 | preds[i * BATCH_SIZE : i * BATCH_SIZE + min(BATCH_SIZE, inp.shape[0])] = session.run(logits,{inception_images: inp})[:, :1000] 63 | preds = np.exp(preds) / np.sum(np.exp(preds), 1, keepdims=True) 64 | return preds 65 | 66 | def preds2score(preds, splits=10): 67 | scores = [] 68 | for i in range(splits): 69 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 70 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 71 | kl = np.mean(np.sum(kl, 1)) 72 | scores.append(np.exp(kl)) 73 | return np.mean(scores), np.std(scores) 74 | 75 | def get_inception_score(images, splits=10): 76 | assert(type(images) == np.ndarray) 77 | assert(len(images.shape) == 4) 78 | assert(images.shape[1] == 3) 79 | assert(np.min(images[0]) >= 0 and np.max(images[0]) > 10), 'Image values should be in the range [0, 255]' 80 | print('Calculating Inception Score with %i images in %i splits' % (images.shape[0], splits)) 81 | start_time=time.time() 82 | preds = get_inception_probs(images) 83 | mean, std = preds2score(preds, splits) 84 | print('Inception Score calculation time: %f s' % (time.time() - start_time)) 85 | return mean, std # Reference values: 11.38 for 50000 CIFAR-10 training set images, or mean=11.31, std=0.10 if in 10 splits. 
86 | 87 | def compute_is_from_npz(path): 88 | with np.load(path) as data: 89 | fake_imgs = data['fake'] 90 | fake_imgs = fake_imgs.transpose(0, 3, 1, 2) 91 | print(fake_imgs.shape) 92 | return get_inception_score(fake_imgs) 93 | 94 | 95 | if __name__ == '__main__': 96 | import argparse 97 | import json 98 | 99 | parser = argparse.ArgumentParser('compute TF IS') 100 | parser.add_argument('--samples', help='path to samples') 101 | parser.add_argument('--it', type=str, help='path to samples') 102 | parser.add_argument('--results_dir', help='path to results_dir') 103 | args = parser.parse_args() 104 | 105 | it = args.it 106 | results_dir = args.results_dir 107 | mean, std = compute_is_from_npz(args.samples) 108 | 109 | with open(os.path.join(args.results_dir, 'is_results.json')) as f: 110 | is_results = json.load(f) 111 | 112 | is_results[it] = float(mean) 113 | print(f'{results_dir} iteration {it} IS: {mean}') 114 | 115 | with open(os.path.join(args.results_dir, 'is_results.json'), 'w') as f: 116 | f.write(json.dumps(is_results)) -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/models/__init__.py: -------------------------------------------------------------------------------- 1 | from . import (dcgan_deep, dcgan_shallow, resnet2) 2 | 3 | generator_dict = { 4 | 'resnet2': resnet2.Generator, 5 | 'dcgan_deep': dcgan_deep.Generator, 6 | 'dcgan_shallow': dcgan_shallow.Generator 7 | } 8 | 9 | discriminator_dict = { 10 | 'resnet2': resnet2.Discriminator, 11 | 'dcgan_deep': dcgan_deep.Discriminator, 12 | 'dcgan_shallow': dcgan_shallow.Discriminator 13 | } 14 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_deep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import torch.utils.data 5 | import torch.utils.data.distributed 6 | from . 
import blocks 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, 11 | nlabels, 12 | conditioning, 13 | z_dim=128, 14 | nc=3, 15 | ngf=64, 16 | embed_dim=256, 17 | **kwargs): 18 | super(Generator, self).__init__() 19 | 20 | assert conditioning != 'unconditional' or nlabels == 1 21 | 22 | if conditioning == 'embedding': 23 | self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim) 24 | self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8) 25 | elif conditioning == 'unconditional': 26 | self.get_latent = blocks.Identity() 27 | self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8) 28 | else: 29 | raise NotImplementedError( 30 | f"{conditioning} not implemented for generator") 31 | 32 | bn = blocks.BatchNorm2d 33 | 34 | self.nlabels = nlabels 35 | 36 | self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1) 37 | self.bn1 = bn(ngf * 4, nlabels) 38 | 39 | self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1) 40 | self.bn2 = bn(ngf * 2, nlabels) 41 | 42 | self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1) 43 | self.bn3 = bn(ngf, nlabels) 44 | 45 | self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh()) 46 | 47 | def forward(self, input, y): 48 | y = y.clamp(None, self.nlabels - 1) 49 | out = self.get_latent(input, y) 50 | 51 | out = self.fc(out) 52 | out = out.view(out.size(0), -1, 4, 4) 53 | out = F.relu(self.bn1(self.conv1(out), y)) 54 | out = F.relu(self.bn2(self.conv2(out), y)) 55 | out = F.relu(self.bn3(self.conv3(out), y)) 56 | return self.conv_out(out) 57 | 58 | 59 | class Discriminator(nn.Module): 60 | def __init__(self, 61 | nlabels, 62 | conditioning, 63 | nc=3, 64 | ndf=64, 65 | pack_size=1, 66 | features='penultimate', 67 | **kwargs): 68 | 69 | super(Discriminator, self).__init__() 70 | 71 | assert conditioning != 'unconditional' or nlabels == 1 72 | 73 | self.nlabels = nlabels 74 | 75 | self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 3, 1, 1), nn.LeakyReLU(0.1)) 76 | self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf, 4, 2, 1), nn.LeakyReLU(0.1)) 77 | self.conv3 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 3, 1, 1), nn.LeakyReLU(0.1)) 78 | self.conv4 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 2, 4, 2, 1), nn.LeakyReLU(0.1)) 79 | self.conv5 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 3, 1, 1), nn.LeakyReLU(0.1)) 80 | self.conv6 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 4, 4, 2, 1), nn.LeakyReLU(0.1)) 81 | self.conv7 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 3, 1, 1), nn.LeakyReLU(0.1)) 82 | 83 | if conditioning == 'mask': 84 | self.fc_out = blocks.LinearConditionalMaskLogits( 85 | ndf * 8 * 4 * 4, nlabels) 86 | elif conditioning == 'unconditional': 87 | self.fc_out = blocks.LinearUnconditionalLogits( 88 | ndf * 8 * 4 * 4) 89 | else: 90 | raise NotImplementedError( 91 | f"{conditioning} not implemented for discriminator") 92 | 93 | self.features = features 94 | self.pack_size = pack_size 95 | print(f'Getting features from {self.features}') 96 | 97 | def stack(self, x): 98 | # pacgan 99 | nc = self.pack_size 100 | assert (x.size(0) % nc == 0) 101 | if nc == 1: 102 | return x 103 | x_new = [] 104 | for i in range(x.size(0) // nc): 105 | imgs_to_stack = x[i * nc:(i + 1) * nc] 106 | x_new.append(torch.cat([t for t in imgs_to_stack], dim=0)) 107 | return torch.stack(x_new) 108 | 109 | def forward(self, input, y=None, get_features=False): 110 | input = self.stack(input) 111 | out = self.conv1(input) 112 | out = self.conv2(out) 113 | out = self.conv3(out) 114 | out = self.conv4(out) 115 | out = self.conv5(out) 116 | out = self.conv6(out) 117 | out = 
self.conv7(out)

        if get_features and self.features == "penultimate":
            return out.view(out.size(0), -1)
        if get_features and self.features == "summed":
            return out.view(out.size(0), out.size(1), -1).sum(dim=2)

        out = out.view(out.size(0), -1)
        y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result


if __name__ == '__main__':
    # Smoke test. The constructors require `nlabels` and `conditioning`, and the
    # forward passes require a label tensor, so build minimal unconditional models.
    z = torch.zeros((1, 128))
    y = torch.zeros((1,), dtype=torch.long)
    x = torch.zeros((1, 3, 32, 32))
    g = Generator(nlabels=1, conditioning='unconditional')
    d = Discriminator(nlabels=1, conditioning='unconditional')

    g(z, y)
    d(g(z, y), y)
    d(x, y)
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/models/dcgan_shallow.py: --------------------------------------------------------------------------------
import torch
from torch import nn
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from . import blocks


class Generator(nn.Module):
    def __init__(self,
                 nlabels,
                 conditioning,
                 z_dim=128,
                 nc=3,
                 ngf=64,
                 embed_dim=256,
                 **kwargs):
        super(Generator, self).__init__()

        assert conditioning != 'unconditional' or nlabels == 1

        if conditioning == 'embedding':
            self.get_latent = blocks.LatentEmbeddingConcat(nlabels, embed_dim)
            self.fc = nn.Linear(z_dim + embed_dim, 4 * 4 * ngf * 8)
        elif conditioning == 'unconditional':
            self.get_latent = blocks.Identity()
            self.fc = nn.Linear(z_dim, 4 * 4 * ngf * 8)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for generator")

        bn = blocks.BatchNorm2d

        self.nlabels = nlabels

        self.conv1 = nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1)
        self.bn1 = bn(ngf * 4, nlabels)

        self.conv2 = nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1)
        self.bn2 = bn(ngf * 2, nlabels)

        self.conv3 = nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1)
        self.bn3 = bn(ngf, nlabels)

        self.conv_out = nn.Sequential(nn.Conv2d(ngf, nc, 3, 1, 1), nn.Tanh())

    def forward(self, input, y):
        y = y.clamp(None, self.nlabels - 1)

        out = self.get_latent(input, y)
        out = self.fc(out)

        out = out.view(out.size(0), -1, 4, 4)
        out = F.relu(self.bn1(self.conv1(out), y))
        out = F.relu(self.bn2(self.conv2(out), y))
        out = F.relu(self.bn3(self.conv3(out), y))
        return self.conv_out(out)


class Discriminator(nn.Module):
    def __init__(self,
                 nlabels,
                 conditioning,
                 features='penultimate',
                 pack_size=1,
                 nc=3,
                 ndf=64,
                 **kwargs):
        super(Discriminator, self).__init__()

        assert conditioning != 'unconditional' or nlabels == 1

        self.nlabels = nlabels

        self.conv1 = nn.Sequential(nn.Conv2d(nc * pack_size, ndf, 4, 2, 1),
                                   nn.BatchNorm2d(ndf),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf * 2, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 2),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 4),
                                   nn.LeakyReLU(0.2, inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1),
                                   nn.BatchNorm2d(ndf * 8),
                                   nn.LeakyReLU(0.2, inplace=True))

        if conditioning == 'mask':
            self.fc_out = blocks.LinearConditionalMaskLogits(ndf * 8 * 4, nlabels)
        elif conditioning == 'unconditional':
            self.fc_out = blocks.LinearUnconditionalLogits(ndf * 8 * 4)
        else:
            raise NotImplementedError(
                f"{conditioning} not implemented for discriminator")

        self.pack_size = pack_size
        self.features = features
        print(f'Getting features from {self.features}')

    def stack(self, x):
        # pacgan
        nc = self.pack_size
        if nc == 1:
            return x
        x_new = []
        for i in range(x.size(0) // nc):
            imgs_to_stack = x[i * nc:(i + 1) * nc]
            x_new.append(torch.cat([t for t in imgs_to_stack], dim=0))
        return torch.stack(x_new)

    def forward(self, input, y=None, get_features=False):
        input = self.stack(input)
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = out.view(out.size(0), -1)
        if get_features:
            return out
        y = y.clamp(None, self.nlabels - 1)
        result = self.fc_out(out, y)
        assert (len(result.shape) == 1)
        return result


if __name__ == '__main__':
    # Smoke test, mirroring dcgan_deep: minimal unconditional models with labels.
    z = torch.zeros((1, 128))
    y = torch.zeros((1,), dtype=torch.long)
    x = torch.zeros((1, 3, 32, 32))
    g = Generator(nlabels=1, conditioning='unconditional')
    d = Discriminator(nlabels=1, conditioning='unconditional')

    g(z, y)
    d(g(z, y), y)
    d(x, y)
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/train.py: --------------------------------------------------------------------------------
# coding: utf-8
import torch
from torch.nn import functional as F
import torch.utils.data
import torch.utils.data.distributed
from torch import autograd
import numpy as np


class Trainer(object):
    def __init__(self,
                 generator,
                 discriminator,
                 g_optimizer,
                 d_optimizer,
                 gan_type,
                 reg_type,
                 reg_param):

        self.generator = generator
        self.discriminator = discriminator
        self.g_optimizer = g_optimizer
        self.d_optimizer = d_optimizer
        self.gan_type = gan_type
        self.reg_type = reg_type
        self.reg_param = reg_param

        print(f'D reg gamma: {self.reg_param}')

    def generator_trainstep(self, y, z):
        assert (y.size(0) == z.size(0))
        toggle_grad(self.generator, True)
        toggle_grad(self.discriminator, False)

        self.generator.train()
        self.discriminator.train()
        self.g_optimizer.zero_grad()

        x_fake = self.generator(z, y)
        d_fake = self.discriminator(x_fake, y)
        gloss = self.compute_loss(d_fake, 1)
        gloss.backward()

        self.g_optimizer.step()

        return gloss.item()

    def discriminator_trainstep(self, x_real, y, z):
        toggle_grad(self.generator, False)
        toggle_grad(self.discriminator, True)
        self.generator.train()
        self.discriminator.train()
        self.d_optimizer.zero_grad()

        # On real data
        x_real.requires_grad_()

        d_real = self.discriminator(x_real, y)
        dloss_real = self.compute_loss(d_real, 1)

        if self.reg_type == 'real' or self.reg_type == 'real_fake':
            dloss_real.backward(retain_graph=True)
            reg = self.reg_param * compute_grad2(d_real, x_real).mean()
            reg.backward()
        else:
            dloss_real.backward()

        # On fake data
        with torch.no_grad():
            x_fake = self.generator(z, y)

        x_fake.requires_grad_()
        d_fake = self.discriminator(x_fake, y)
        dloss_fake = self.compute_loss(d_fake, 0)

        if self.reg_type == 'fake' or self.reg_type == 
'real_fake': 77 | dloss_fake.backward(retain_graph=True) 78 | reg = self.reg_param * compute_grad2(d_fake, x_fake).mean() 79 | reg.backward() 80 | else: 81 | dloss_fake.backward() 82 | 83 | if self.reg_type == 'wgangp': 84 | reg = self.reg_param * self.wgan_gp_reg(x_real, x_fake, y) 85 | reg.backward() 86 | elif self.reg_type == 'wgangp0': 87 | reg = self.reg_param * self.wgan_gp_reg( 88 | x_real, x_fake, y, center=0.) 89 | reg.backward() 90 | 91 | self.d_optimizer.step() 92 | 93 | dloss = (dloss_real + dloss_fake) 94 | if self.reg_type == 'none': 95 | reg = torch.tensor(0.) 96 | 97 | return dloss.item(), reg.item() 98 | 99 | def compute_loss(self, d_out, target): 100 | targets = d_out.new_full(size=d_out.size(), fill_value=target) 101 | 102 | if self.gan_type == 'standard': 103 | loss = F.binary_cross_entropy_with_logits(d_out, targets) 104 | elif self.gan_type == 'wgan': 105 | loss = (2 * target - 1) * d_out.mean() 106 | else: 107 | raise NotImplementedError 108 | 109 | return loss 110 | 111 | def wgan_gp_reg(self, x_real, x_fake, y, center=1.): 112 | batch_size = y.size(0) 113 | eps = torch.rand(batch_size, device=y.device).view(batch_size, 1, 1, 1) 114 | x_interp = (1 - eps) * x_real + eps * x_fake 115 | x_interp = x_interp.detach() 116 | x_interp.requires_grad_() 117 | d_out = self.discriminator(x_interp, y) 118 | 119 | reg = (compute_grad2(d_out, x_interp).sqrt() - center).pow(2).mean() 120 | 121 | return reg 122 | 123 | 124 | # Utility functions 125 | def toggle_grad(model, requires_grad): 126 | for p in model.parameters(): 127 | p.requires_grad_(requires_grad) 128 | 129 | 130 | def compute_grad2(d_out, x_in): 131 | batch_size = x_in.size(0) 132 | grad_dout = autograd.grad(outputs=d_out.sum(), 133 | inputs=x_in, 134 | create_graph=True, 135 | retain_graph=True, 136 | only_inputs=True)[0] 137 | grad_dout2 = grad_dout.pow(2) 138 | assert (grad_dout2.size() == x_in.size()) 139 | reg = grad_dout2.view(batch_size, -1).sum(1) 140 | return reg 141 | 142 | 143 | def update_average(model_tgt, model_src, beta): 144 | toggle_grad(model_src, False) 145 | toggle_grad(model_tgt, False) 146 | 147 | param_dict_src = dict(model_src.named_parameters()) 148 | 149 | for p_name, p_tgt in model_tgt.named_parameters(): 150 | p_src = param_dict_src[p_name] 151 | assert (p_src is not p_tgt) 152 | p_tgt.copy_(beta * p_tgt + (1. - beta) * p_src) 153 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/self_conditioned/gan_training/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torch.utils.data.distributed 4 | import torchvision 5 | 6 | import os 7 | 8 | 9 | def save_images(imgs, outfile, nrow=8): 10 | imgs = imgs / 2 + 0.5 # unnormalize 11 | torchvision.utils.save_image(imgs, outfile, nrow=nrow) 12 | 13 | 14 | def get_nsamples(data_loader, N): 15 | x = [] 16 | y = [] 17 | n = 0 18 | for x_next, y_next in data_loader: 19 | x.append(x_next) 20 | y.append(y_next) 21 | n += x_next.size(0) 22 | if n > N: 23 | break 24 | x = torch.cat(x, dim=0)[:N] 25 | y = torch.cat(y, dim=0)[:N] 26 | return x, y 27 | 28 | 29 | def update_average(model_tgt, model_src, beta): 30 | param_dict_src = dict(model_src.named_parameters()) 31 | 32 | for p_name, p_tgt in model_tgt.named_parameters(): 33 | p_src = param_dict_src[p_name] 34 | assert (p_src is not p_tgt) 35 | p_tgt.copy_(beta * p_tgt + (1. 
- beta) * p_src)


def get_most_recent(d, ext):
    if not os.path.exists(d):
        print(f'Directory {d} does not exist')
        return -1
    its = []
    for f in os.listdir(d):
        try:
            it = int(f.split(ext + "_")[1].split('.pt')[0])
            its.append(it)
        except Exception:
            pass
    if len(its) == 0:
        print(f'Found no files with extension "{ext}" under {d}')
        return -1
    return max(its)
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/__init__.py: --------------------------------------------------------------------------------
import os
import sys
import torch
from torch.hub import urlparse, get_dir, download_url_to_file
import pickle


MODELS = {
    'ffhq': ('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl', None),
    'afhqwild': ('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl', None),
}


def download_url(url, download_dir=None, filename=None):
    parts = urlparse(url)
    if download_dir is None:
        hub_dir = get_dir()
        download_dir = os.path.join(hub_dir, 'checkpoints')
    if filename is None:
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(download_dir, filename)
    if not os.path.exists(cached_file):
        sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file))
        download_url_to_file(url, cached_file)
    return cached_file


class GeneratorWrapper(torch.nn.Module):
    """ A wrapper to put the GAN in a standard format. This wrapper takes
    w as input, rather than (z, c) """

    def __init__(self, G, num_classes=None):
        super().__init__()
        self.G = G  # NOTE! This takes in w, rather than z
        self.dim_z = G.synthesis.w_dim
        self.conditional = (num_classes is not None)
        self.num_classes = num_classes

        self.num_ws = G.synthesis.num_ws
        self.truncation_psi = 0.5
        self.truncation_cutoff = 8

    def forward(self, z):
        r"""The input `z` is expected to be `w`, not `z`, in the notation
        of the original StyleGAN 2 paper"""
        if len(z.shape) == 2:  # expand to num_ws layers (e.g. 18 at 1024x1024)
            z = z.unsqueeze(1).repeat(1, self.num_ws, 1)
        return self.G.synthesis(z)

    def sample_latent(self, batch_size, device='cpu'):
        z = torch.randn([batch_size, self.dim_z], device=device)
        c = None  # class conditioning not implemented; the released models here are unconditional
        w = self.G.mapping(z, c, truncation_psi=self.truncation_psi, truncation_cutoff=self.truncation_cutoff)
        return w


def add_utils_to_path():
    from pathlib import Path
    util_path = str(Path(__file__).parent)
    if util_path not in sys.path:
        sys.path.append(util_path)
    print(f'Added {util_path} to path')


def make_stylegan2(model_name='ffhq') -> torch.nn.Module:
    """G outputs an image in NCHW format with dtype float32, normalized
    to the range [-1, +1]. 
Some models also take a conditioning class label, 69 | which is passed as img = G(z, c)""" 70 | add_utils_to_path() # we need dnnlib and torch_utils in the path 71 | url, num_classes = MODELS[model_name] 72 | cached_file = download_url(url) 73 | assert cached_file.endswith('.pkl') 74 | with open(cached_file, 'rb') as f: 75 | G = pickle.load(f)['G_ema'] 76 | G = GeneratorWrapper(G, num_classes=num_classes) 77 | return G.eval() 78 | 79 | 80 | if __name__ == '__main__': 81 | # Testing 82 | G = make_stylegan2().cuda() 83 | print('Created G') 84 | print(f'Params: {sum(p.numel() for p in G.parameters()):_}') 85 | z = torch.randn([1, G.dim_z]).cuda() 86 | print(f'z.shape: {z.shape}') 87 | x = G(z) 88 | print(f'x.shape: {x.shape}') 89 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 
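# A typical call, for orientation (a sketch; the exact kwargs live in the ops
# modules such as bias_act.py and upfirdn2d.py, not here):
#
#   _plugin = get_plugin('upfirdn2d_plugin', sources=['upfirdn2d.cpp', 'upfirdn2d.cu'])
#   y = _plugin.upfirdn2d(x, f, upx, upy, downx, downy, padx0, padx1, pady0, pady1, flip, gain)
#
# get_plugin() is defined further down; the helper below locates a host
# compiler on Windows.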
27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 
86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "bias_act.h"

//------------------------------------------------------------------------

static bool has_same_layout(torch::Tensor x, torch::Tensor y)
{
    if (x.dim() != y.dim())
        return false;
    for (int64_t i = 0; i < x.dim(); i++)
    {
        if (x.size(i) != y.size(i))
            return false;
        if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
            return false;
    }
    return true;
}

//------------------------------------------------------------------------

static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
{
    // Validate arguments.
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
    TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
    TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
    TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same shape, dtype, and device as x");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(b.dim() == 1, "b must have rank 1");
    TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
    TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
    TORCH_CHECK(grad >= 0, "grad must be non-negative");

    // Validate layout.
    TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
    TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
    TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
    TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
    TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    torch::Tensor y = torch::empty_like(x);
    TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");

    // Initialize CUDA kernel parameters.
    bias_act_kernel_params p;
    p.x = x.data_ptr();
    p.b = (b.numel()) ? b.data_ptr() : NULL;
    p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
    p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
    p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
    p.y = y.data_ptr();
    p.grad = grad;
    p.act = act;
    p.alpha = alpha;
    p.gain = gain;
    p.clamp = clamp;
    p.sizeX = (int)x.numel();
    p.sizeB = (int)b.numel();
    p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;

    // Choose CUDA kernel.
    void* kernel;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "bias_act_cuda", [&]
    {
        kernel = choose_bias_act_kernel<scalar_t>(p);
    });
    TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");

    // Launch CUDA kernel. 
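    // Launch geometry (for orientation): a 1-D launch where each thread handles
    // loopX elements, so gridSize * blockSize * loopX covers all sizeX elements.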
    p.loopX = 4;
    int blockSize = 4 * 32;
    int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("bias_act", &bias_act);
}

//------------------------------------------------------------------------
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/bias_act.h: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct bias_act_kernel_params
{
    const void* x;      // [sizeX]
    const void* b;      // [sizeB] or NULL
    const void* xref;   // [sizeX] or NULL
    const void* yref;   // [sizeX] or NULL
    const void* dy;     // [sizeX] or NULL
    void* y;            // [sizeX]

    int grad;
    int act;
    float alpha;
    float gain;
    float clamp;

    int sizeX;
    int sizeB;
    int stepB;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection.

template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);

//------------------------------------------------------------------------
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/fma.py: --------------------------------------------------------------------------------
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
#----------------------------------------------------------------------------

def grid_sample(input, grid):
    if _should_use_custom_op():
        return _GridSample2dForward.apply(input, grid)
    return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)

#----------------------------------------------------------------------------

def _should_use_custom_op():
    if not enabled:
        return False
    if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']):
        return True
    warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().')
    return False

#----------------------------------------------------------------------------

class _GridSample2dForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, grid):
        assert input.ndim == 4
        assert grid.ndim == 4
        output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
        ctx.save_for_backward(input, grid)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, grid = ctx.saved_tensors
        grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
        return grad_input, grad_grid

#----------------------------------------------------------------------------

class _GridSample2dBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, grad_output, input, grid):
        op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
        grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
        ctx.save_for_backward(grid)
        return grad_input, grad_grid

    @staticmethod
    def backward(ctx, grad2_grad_input, grad2_grad_grid):
        _ = grad2_grad_grid  # unused
        grid, = ctx.saved_tensors
        grad2_grad_output = None
        grad2_input = None
        grad2_grid = None

        if ctx.needs_input_grad[0]:
            grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)

        assert not ctx.needs_input_grad[2]
        return grad2_grad_output, grad2_input, grad2_grid

#----------------------------------------------------------------------------
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.cpp: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "upfirdn2d.h"

//------------------------------------------------------------------------

static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
{
    // Validate arguments. 
    TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
    TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
    TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
    TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
    TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
    TORCH_CHECK(x.dim() == 4, "x must be rank 4");
    TORCH_CHECK(f.dim() == 2, "f must be rank 2");
    TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
    TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
    TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");

    // Create output tensor.
    const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
    int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
    int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
    TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
    torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
    TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");

    // Initialize CUDA kernel parameters.
    upfirdn2d_kernel_params p;
    p.x = x.data_ptr();
    p.f = f.data_ptr<float>();
    p.y = y.data_ptr();
    p.up = make_int2(upx, upy);
    p.down = make_int2(downx, downy);
    p.pad0 = make_int2(padx0, pady0);
    p.flip = (flip) ? 1 : 0;
    p.gain = gain;
    p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
    p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
    p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
    p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
    p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
    p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
    p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
    p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;

    // Choose CUDA kernel.
    upfirdn2d_kernel_spec spec;
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
    {
        spec = choose_upfirdn2d_kernel<scalar_t>(p);
    });

    // Set looping options.
    p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
    p.loopMinor = spec.loopMinor;
    p.loopX = spec.loopX;
    p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
    p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;

    // Compute grid size.
    dim3 blockSize, gridSize;
    if (spec.tileOutW < 0) // large
    {
        blockSize = dim3(4, 32, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
            (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
            p.launchMajor);
    }
    else // small
    {
        blockSize = dim3(256, 1, 1);
        gridSize = dim3(
            ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
            (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
            p.launchMajor);
    }

    // Launch CUDA kernel. 
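    // Launch geometry (for orientation): grid z iterates over batch*channel
    // (major) loops, grid x covers output rows times the minor loop, and grid y
    // covers output columns in tiles of tileOutW * loopX.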
    void* args[] = {&p};
    AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
    return y;
}

//------------------------------------------------------------------------

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("upfirdn2d", &upfirdn2d);
}

//------------------------------------------------------------------------
-------------------------------------------------------------------------------- /pytorch_pretrained_gans/stylegan2_ada_pytorch/torch_utils/ops/upfirdn2d.h: --------------------------------------------------------------------------------
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// NVIDIA CORPORATION and its licensors retain all intellectual property
// and proprietary rights in and to this software, related documentation
// and any modifications thereto. Any use, reproduction, disclosure or
// distribution of this software and related documentation without an express
// license agreement from NVIDIA CORPORATION is strictly prohibited.

#include <cuda_runtime.h>

//------------------------------------------------------------------------
// CUDA kernel parameters.

struct upfirdn2d_kernel_params
{
    const void* x;
    const float* f;
    void* y;

    int2 up;
    int2 down;
    int2 pad0;
    int flip;
    float gain;

    int4 inSize;        // [width, height, channel, batch]
    int4 inStride;
    int2 filterSize;    // [width, height]
    int2 filterStride;
    int4 outSize;       // [width, height, channel, batch]
    int4 outStride;
    int sizeMinor;
    int sizeMajor;

    int loopMinor;
    int loopMajor;
    int loopX;
    int launchMinor;
    int launchMajor;
};

//------------------------------------------------------------------------
// CUDA kernel specialization.

struct upfirdn2d_kernel_spec
{
    void* kernel;
    int tileOutW;
    int tileOutH;
    int loopMinor;
    int loopX;
};

//------------------------------------------------------------------------
// CUDA kernel selection. 
template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);

//------------------------------------------------------------------------
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
#!/usr/bin/env python

from setuptools import setup, find_packages

setup(
    name='pytorch_pretrained_gans',
    version='0.0.1',
    description='Project',
    author='Luke Melas-Kyriazi',
    author_email='',
    url='https://github.com/lukemelas/',
    install_requires=[],
    packages=find_packages(),
    package_data={'pytorch_pretrained_gans':
        [
            'BigBiGAN/model/weights',
            'StudioGAN/configs',
        ]
    },
)
-------------------------------------------------------------------------------- /tests/test.py: --------------------------------------------------------------------------------
"""
pytest tests/test.py --disable-pytest-warnings -s
"""
import pytest
import torch

@pytest.mark.skip(reason="disabled")
def test_load_models():

    from pytorch_pretrained_gans import make_gan

    # # BigGAN (conditional)
    # G = make_gan(gan_type='biggan', model_name='biggan-deep-128')
    # y = G.sample_class(batch_size=1)   # -> torch.Size([1, 1000])
    # z = G.sample_latent(batch_size=1)  # -> torch.Size([1, 128])
    # x = G(z=z, y=y)                    # -> torch.Size([1, 3, 128, 128])
    # assert list(x.shape) == [1, 3, 128, 128]

    # SelfCondGAN (conditional)
    G = make_gan(gan_type='selfconditionedgan', model_name='self_conditioned')
    y = G.sample_class(batch_size=1)   # -> torch.Size([1, 1000])
    z = G.sample_latent(batch_size=1)  # -> torch.Size([1, 256])
    x = G(z=z, y=y)                    # -> torch.Size([1, 3, 128, 128])
    assert list(x.shape) == [1, 3, 128, 128]
--------------------------------------------------------------------------------
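A smoke test for the StyleGAN2-ADA wrapper could follow the same pattern; a minimal sketch (assuming 'stylegan2' is the registered gan_type key for this backend, and that a GPU plus the downloaded 'ffhq' weights are available), not part of the original suite:

@pytest.mark.skip(reason="requires a GPU and a large weights download")
def test_stylegan2_ffhq():
    from pytorch_pretrained_gans import make_gan
    G = make_gan(gan_type='stylegan2', model_name='ffhq').cuda()  # gan_type key assumed
    w = G.sample_latent(batch_size=1, device='cuda')  # returns w, already broadcast by G.mapping
    x = G(w)                                          # -> torch.Size([1, 3, 1024, 1024])
    assert list(x.shape) == [1, 3, 1024, 1024]        # FFHQ generates at 1024x1024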