├── .gitignore
├── LICENSE
├── README.md
├── StyleGAN2.py
├── cuda
│   ├── __init__.py
│   ├── conv2d_gradfix.py
│   ├── fused_act.py
│   ├── fused_bias_act.cpp
│   ├── fused_bias_act_kernel.cu
│   ├── upfirdn2d.cpp
│   ├── upfirdn2d.py
│   └── upfirdn2d_kernel.cu
├── main.py
├── metric
│   └── cal_fid.py
├── networks.py
├── ops.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Junho Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# stylegan2-pytorch
Pytorch implementation of StyleGAN2 in my style

## Usage
```
> python main.py --dataset FFHQ --img_size 256 --batch_size 8
```

## Reference
* [rosinality-pytorch](https://github.com/rosinality/stylegan2-pytorch)

## Author
* [Junho Kim](http://bit.ly/jhkim_resume)
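With the default `--nsml False` path, training images are read from `./dataset/<dataset_name>` (see the `dataset_path` logic in `StyleGAN2.py` below). A folder layout such as the following is expected; the image file names here are only illustrative:

```
dataset
└── FFHQ
    ├── 00000.png
    ├── 00001.png
    └── ...
```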
--------------------------------------------------------------------------------
/StyleGAN2.py:
--------------------------------------------------------------------------------
import torch.utils.data

from utils import *
import time
from networks import *
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torchvision
from functools import partial
from metric.cal_fid import InceptionV3, calculate_fid

print = partial(print, flush=True)

def run_fn(rank, args, world_size):
    device = torch.device('cuda', rank)
    torch.backends.cudnn.benchmark = True

    model = StyleGAN2(args, world_size)
    model.build_model(rank, device)
    model.train_model(rank, device)

class StyleGAN2():
    def __init__(self, args, NUM_GPUS):
        super(StyleGAN2, self).__init__()
        """ Model """
        self.model_name = 'StyleGAN2'
        self.phase = args['phase']
        self.checkpoint_dir = args['checkpoint_dir']
        self.result_dir = args['result_dir']
        self.log_dir = args['log_dir']
        self.sample_dir = args['sample_dir']
        self.dataset_name = args['dataset']
        self.NUM_GPUS = NUM_GPUS


        """ Training parameters """
        self.img_size = args['img_size']
        self.batch_size = args['batch_size']
        self.global_batch_size = self.batch_size * self.NUM_GPUS
        self.n_total_image = args['n_total_image'] * 1000
        self.iteration = self.n_total_image // self.global_batch_size

        self.g_reg_every = args['g_reg_every']
        self.d_reg_every = args['d_reg_every']
        self.lr = args['lr']


        """ Network parameters """
        self.channel_multiplier = args['channel_multiplier']
        self.lazy_regularization = args['lazy_regularization']
        self.r1_gamma = 10.0
        self.path_batch_shrink = 2
        self.path_weight = 2.0
        self.path_decay = 0.01
        self.mean_path_length = 0

        self.latent_dim = 512
        self.mixing_prob = args['mixing_prob']


        """ Print parameters """
        self.print_freq = args['print_freq']
        self.save_freq = args['save_freq']
        self.log_template = 'step [{}/{}]: elapsed: {:.2f}s, d_loss: {:.3f}, g_loss: {:.3f}, fid: {:.2f}, best_fid: {:.2f}, best_fid_iter: {}'
        self.n_sample = args['n_sample']

        """ MISC """
        self.nsml_flag = args['nsml']

        if self.nsml_flag:
            import nsml
            self.nsml = nsml
            self.dataset_name = os.path.basename(self.nsml.DATASET_PATH)
            dataset_path = os.path.join(self.nsml.DATASET_PATH, 'train')
            self.dataset_path = dataset_path

        else:
            dataset_path = './dataset'
            self.dataset_path = os.path.join(dataset_path, self.dataset_name)

        """ Directory """
        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)
        self.checkpoint_dir = os.path.join(self.checkpoint_dir, self.model_dir)
        check_folder(self.checkpoint_dir)
        self.log_dir = os.path.join(self.log_dir, self.model_dir)
        check_folder(self.log_dir)

    ##################################################################################
    # Model
    ##################################################################################
    def build_model(self, rank, device):
        if self.phase == 'train':
            """ Init process """
            build_init_procss(rank, world_size=self.NUM_GPUS, device=device)

            """ Dataset Load """
            dataset = ImageDataset(dataset_path=self.dataset_path, img_size=self.img_size)
            self.dataset_num = dataset.__len__()
            loader = torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, num_workers=4,
                                                 sampler=distributed_sampler(dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=True),
                                                 drop_last=True, pin_memory=True)
            self.dataset_iter = infinite_iterator(loader)

            """ Calculate FID metric """
            self.fid_dataset = ImageDataset(dataset_path=self.dataset_path, img_size=299, fid_transform=True)
            self.fid_loader = torch.utils.data.DataLoader(self.fid_dataset, batch_size=self.batch_size, num_workers=4,
                                                          sampler=distributed_sampler(dataset, rank=rank, num_replicas=self.NUM_GPUS, shuffle=False),
                                                          drop_last=False, pin_memory=True)
            self.inception = InceptionV3().to(device)
            self.inception = dataparallel_and_sync(self.inception, rank)

            """ Network """
            self.generator = Generator(size=self.img_size, channel_multiplier=self.channel_multiplier).to(device)
            self.discriminator = Discriminator(size=self.img_size, channel_multiplier=self.channel_multiplier).to(device)
            self.g_ema = deepcopy(self.generator).to(device)

            """ Optimizer """
            g_reg_ratio = self.g_reg_every / (self.g_reg_every + 1)
            d_reg_ratio = self.d_reg_every / (self.d_reg_every + 1)
            self.g_optim = torch.optim.Adam(self.generator.parameters(), lr=self.lr * g_reg_ratio, betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio))
            self.d_optim = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr * d_reg_ratio, betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio))

            """ Distributed Learning """
            self.generator = dataparallel_and_sync(self.generator, rank)
            self.discriminator = dataparallel_and_sync(self.discriminator, rank)
            self.g_ema = dataparallel_and_sync(self.g_ema, rank)


            """ Checkpoint """
            latest_ckpt_name, start_iter = find_latest_ckpt(self.checkpoint_dir)
            if latest_ckpt_name is not None:
                print('Latest checkpoint restored!!
', latest_ckpt_name) 134 | print('start iteration : ', start_iter) 135 | self.start_iteration = start_iter 136 | 137 | latest_ckpt = os.path.join(self.checkpoint_dir, latest_ckpt_name) 138 | ckpt = torch.load(latest_ckpt, map_location=device) 139 | 140 | self.generator.load_state_dict(ckpt["generator"]) 141 | self.discriminator.load_state_dict(ckpt["discriminator"]) 142 | self.g_ema.load_state_dict(ckpt["g_ema"]) 143 | 144 | self.g_optim.load_state_dict(ckpt["g_optim"]) 145 | self.d_optim.load_state_dict(ckpt["d_optim"]) 146 | 147 | else: 148 | if rank == 0: 149 | print('Not restoring from saved checkpoint') 150 | self.start_iteration = 0 151 | 152 | else: 153 | """ Init process """ 154 | build_init_procss(rank, world_size=self.NUM_GPUS, device=device) 155 | 156 | """ Network """ 157 | self.g_ema = Generator(size=self.img_size, channel_multiplier=self.channel_multiplier).to(device) 158 | self.g_ema = dataparallel_and_sync(self.g_ema, rank) 159 | 160 | """ Checkpoint """ 161 | latest_ckpt_name, start_iter = find_latest_ckpt(self.checkpoint_dir) 162 | if latest_ckpt_name is not None: 163 | print('Latest checkpoint restored!! ', latest_ckpt_name) 164 | print('start iteration : ', start_iter) 165 | self.start_iteration = start_iter 166 | 167 | latest_ckpt = os.path.join(self.checkpoint_dir, latest_ckpt_name) 168 | ckpt = torch.load(latest_ckpt, map_location=device) 169 | 170 | self.g_ema.load_state_dict(ckpt["g_ema"]) 171 | 172 | else: 173 | print('Not restoring from saved checkpoint') 174 | self.start_iteration = 0 175 | 176 | def d_train_step(self, real_images, d_regularize=False, device=torch.device('cuda')): 177 | # gradient check 178 | requires_grad(self.discriminator, True) 179 | requires_grad(self.generator, False) 180 | 181 | # forward pass 182 | noise = mixing_noise(self.batch_size, self.latent_dim, self.mixing_prob, device) 183 | fake_images = self.generator(noise) 184 | 185 | real_logit = self.discriminator(real_images) 186 | fake_logit = self.discriminator(fake_images) 187 | 188 | # loss 189 | d_loss = d_logistic_loss(real_logit, fake_logit) 190 | 191 | if d_regularize: 192 | real_images.requires_grad = True 193 | real_logit = self.discriminator(real_images) 194 | r1_penalty = d_r1_loss(real_logit, real_images) 195 | r1_penalty = (self.r1_gamma / 2 * r1_penalty * self.d_reg_every + 0 * real_logit[0]).mean() 196 | 197 | d_loss += r1_penalty 198 | 199 | apply_gradients(d_loss, self.d_optim) 200 | 201 | return d_loss 202 | 203 | def g_train_step(self, g_regularize, device=torch.device('cuda')): 204 | # gradient check 205 | requires_grad(self.discriminator, False) 206 | requires_grad(self.generator, True) 207 | 208 | # forward pass 209 | noise = mixing_noise(self.batch_size, self.latent_dim, self.mixing_prob, device) 210 | fake_images = self.generator(noise) 211 | 212 | fake_logit = self.discriminator(fake_images) 213 | 214 | # loss 215 | g_loss = g_nonsaturating_loss(fake_logit) 216 | 217 | if g_regularize: 218 | path_batch_size = max(1, self.batch_size // self.path_batch_shrink) 219 | noise = mixing_noise(path_batch_size, self.latent_dim, self.mixing_prob, device) 220 | fake_img, latents = self.generator(noise, return_latents=True) 221 | 222 | path_loss, mean_path_length, path_lengths = g_path_regularize(fake_img, latents, self.mean_path_length) 223 | self.mean_path_length = mean_path_length 224 | 225 | weighted_path_loss = self.path_weight * self.g_reg_every * path_loss 226 | 227 | g_loss += weighted_path_loss 228 | 229 | apply_gradients(g_loss, self.g_optim) 230 | 231 | return 
g_loss 232 | 233 | def train_model(self, rank, device): 234 | start_time = time.time() 235 | fid_start_time = time.time() 236 | 237 | # setup tensorboards 238 | train_summary_writer = SummaryWriter(self.log_dir) 239 | 240 | 241 | # start training 242 | if rank == 0: 243 | print() 244 | print(self.dataset_path) 245 | print("Dataset number : ", self.dataset_num) 246 | print("GPUs : ", self.NUM_GPUS) 247 | print("Each batch size : ", self.batch_size) 248 | print("Global batch size : ", self.global_batch_size) 249 | print("Target image size : ", self.img_size) 250 | print("Print frequency : ", self.print_freq) 251 | print("Save frequency : ", self.save_freq) 252 | print("PyTorch Version :", torch.__version__) 253 | print('max_steps: {}'.format(self.iteration)) 254 | print() 255 | losses = {'g/loss': 0.0, 'd/loss': 0.0} 256 | fid_dict = {'metric/fid': 0.0, 'metric/best_fid': 0.0, 'metric/best_fid_iter': 0} 257 | 258 | 259 | fid = 0 260 | best_fid = 1000 261 | best_fid_iter = 0 262 | 263 | for idx in range(self.start_iteration, self.iteration): 264 | iter_start_time = time.time() 265 | 266 | real_img = next(self.dataset_iter) 267 | real_img = real_img.to(device) 268 | 269 | if idx == 0: 270 | if rank == 0: 271 | print("count params") 272 | g_params = count_parameters(self.generator) 273 | d_params = count_parameters(self.discriminator) 274 | print("G network parameters : ", format(g_params, ',')) 275 | print("D network parameters : ", format(d_params, ',')) 276 | print("Total network parameters : ", format(g_params + d_params, ',')) 277 | print() 278 | 279 | # update discriminator 280 | if (idx + 1) % self.d_reg_every == 0: 281 | d_loss = self.d_train_step(real_img, d_regularize=True, device=device) 282 | else: 283 | d_loss = self.d_train_step(real_img, d_regularize=False, device=device) 284 | 285 | losses['d/loss'] = d_loss 286 | 287 | # update generator 288 | if (idx + 1) % self.g_reg_every == 0: 289 | g_loss = self.g_train_step(g_regularize=True, device=device) 290 | else: 291 | g_loss = self.g_train_step(g_regularize=False, device=device) 292 | 293 | losses['g/loss'] = g_loss 294 | 295 | # moving average 296 | moving_average(self.g_ema, self.generator, decay=0.999) 297 | 298 | losses = reduce_loss_dict(losses) 299 | 300 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1 : 301 | if rank == 0: 302 | print("calculate fid ...") 303 | fid_start_time = time.time() 304 | 305 | fid = calculate_fid(self.fid_loader, self.g_ema, self.inception, self.dataset_name, rank, device, 306 | self.latent_dim, fake_samples=50000, batch_size=self.batch_size) 307 | 308 | if rank == 0: 309 | fid_end_time = time.time() 310 | fid_elapsed = fid_end_time - fid_start_time 311 | print("calculate fid finish: {:.2f}s".format(fid_elapsed)) 312 | if fid < best_fid: 313 | print("BEST FID UPDATED") 314 | best_fid = fid 315 | best_fid_iter = idx 316 | self.torch_save(idx, fid) 317 | 318 | fid_dict['metric/best_fid'] = best_fid 319 | fid_dict['metric/best_fid_iter'] = best_fid_iter 320 | fid_dict['metric/fid'] = fid 321 | 322 | 323 | if rank == 0: 324 | # save to tensorboard 325 | if self.nsml_flag: 326 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1: 327 | self.nsml.report(**losses, scope=locals(), step=idx) 328 | self.nsml.report(**fid_dict, scope=locals(), step=idx) 329 | else: 330 | self.nsml.report(**losses, scope=locals(), step=idx) 331 | 332 | for k, v in losses.items(): 333 | train_summary_writer.add_scalar(k, v, global_step=idx) 334 | 335 | if np.mod(idx, self.save_freq) == 0 or idx == 
self.iteration - 1: 336 | train_summary_writer.add_scalar('fid', fid, global_step=idx) 337 | else: 338 | for k, v in losses.items(): 339 | train_summary_writer.add_scalar(k, v, global_step=idx) 340 | 341 | if np.mod(idx, self.save_freq) == 0 or idx == self.iteration - 1: 342 | train_summary_writer.add_scalar('fid', fid, global_step=idx) 343 | 344 | # save every self.save_freq 345 | """ 346 | if np.mod(idx + 1, self.save_freq) == 0: 347 | print("ckpt save") 348 | self.torch_save(idx) 349 | """ 350 | if np.mod(idx + 1, self.print_freq) == 0: 351 | with torch.no_grad(): 352 | partial_size = int(self.n_sample ** 0.5) 353 | sample_z = [torch.randn([self.n_sample, self.latent_dim], device=device)] 354 | self.g_ema.eval() 355 | 356 | sample = self.g_ema(sample_z) 357 | 358 | torchvision.utils.save_image(sample, './{}/fake_{:06d}.png'.format(self.sample_dir, idx + 1), 359 | nrow=partial_size, 360 | normalize=True, range=(-1, 1)) 361 | # normalize = set to the range (0, 1) by range(min, max) 362 | 363 | elapsed = time.time() - iter_start_time 364 | print(self.log_template.format(idx, self.iteration, elapsed, losses['d/loss'], losses['g/loss'], 365 | fid_dict['metric/fid'], fid_dict['metric/best_fid'], fid_dict['metric/best_fid_iter'])) 366 | 367 | dist.barrier() 368 | 369 | if rank == 0: 370 | # save model for final step 371 | self.torch_save(self.iteration, fid) 372 | 373 | print("LAST FID: ", fid) 374 | print("BEST FID: {}, {}".format(best_fid, best_fid_iter)) 375 | print("Total train time: %4.4f" % (time.time() - start_time)) 376 | 377 | dist.barrier() 378 | 379 | def torch_save(self, idx, fid): 380 | torch.save( 381 | { 382 | 'generator': self.generator.state_dict(), 383 | 'discriminator': self.discriminator.state_dict(), 384 | 'g_ema': self.g_ema.state_dict(), 385 | 'g_optim': self.g_optim.state_dict(), 386 | 'd_optim': self.d_optim.state_dict() 387 | }, 388 | os.path.join(self.checkpoint_dir, 'iter_{}_fid_{}.pt'.format(idx, fid)) 389 | ) 390 | 391 | @property 392 | def model_dir(self): 393 | return "{}_{}_{}".format(self.model_name, self.dataset_name, self.img_size) 394 | -------------------------------------------------------------------------------- /cuda/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /cuda/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def 
conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def 
forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 | 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /cuda/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, bias, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | 
grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
        )

        if not ctx.bias:
            grad_bias = None

        return grad_input, grad_bias, None, None


class FusedLeakyReLU(nn.Module):
    def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
        super().__init__()

        if bias:
            self.bias = nn.Parameter(torch.zeros(channel))

        else:
            self.bias = None

        self.negative_slope = negative_slope
        self.scale = scale

    def forward(self, input):
        return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)


def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # CPU fallback: plain bias add + leaky_relu, mirroring the fused CUDA op
    if input.device.type == "cpu":
        if bias is not None:
            rest_dim = [1] * (input.ndim - bias.ndim - 1)
            return (
                F.leaky_relu(
                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
                )
                * scale
            )

        else:
            return F.leaky_relu(input, negative_slope=negative_slope) * scale

    else:
        return FusedLeakyReLUFunction.apply(
            input.contiguous(), bias, negative_slope, scale
        )
--------------------------------------------------------------------------------
/cuda/fused_bias_act.cpp:
--------------------------------------------------------------------------------

#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor fused_bias_act(const torch::Tensor &input,
                             const torch::Tensor &bias,
                             const torch::Tensor &refer, int act, int grad,
                             float alpha, float scale) {
  CHECK_INPUT(input);
  CHECK_INPUT(bias);

  at::DeviceGuard guard(input.device());

  return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
}
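The extension bound above is driven by the CUDA kernel below, which fuses the bias add, the LeakyReLU, and a constant gain into one pass over the tensor. As a reference for what the `act=3, grad=0` forward path (the one taken by `FusedLeakyReLUFunction`) computes, here is a minimal pure-PyTorch sketch; `reference_bias_act` is a hypothetical helper, not part of the repo:

```
import torch
import torch.nn.functional as F

def reference_bias_act(x, bias=None, negative_slope=0.2, scale=2 ** 0.5):
    # bias is broadcast over every dim except the channel dim (dim 1),
    # matching the (xi / step_b) % size_b indexing in the kernel below
    if bias is not None:
        x = x + bias.view(1, -1, *([1] * (x.ndim - 2)))
    # act=3 selects LeakyReLU; the sqrt(2) gain keeps activations near unit variance
    return F.leaky_relu(x, negative_slope) * scale
```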
--------------------------------------------------------------------------------
/cuda/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
// Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
//
// This work is made available under the Nvidia Source Code License-NC.
// To view a copy of this license, visit
// https://nvlabs.github.io/stylegan2/license.html

#include <torch/types.h>

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>


#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __global__ void
fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
                      const scalar_t *p_ref, int act, int grad, scalar_t alpha,
                      scalar_t scale, int loop_x, int size_x, int step_b,
                      int size_b, int use_bias, int use_ref) {
  int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;

  scalar_t zero = 0.0;

  for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
       loop_idx++, xi += blockDim.x) {
    scalar_t x = p_x[xi];

    if (use_bias) {
      x += p_b[(xi / step_b) % size_b];
    }

    scalar_t ref = use_ref ? p_ref[xi] : zero;

    scalar_t y;

    switch (act * 10 + grad) {
    default:
    case 10:
      y = x;
      break;
    case 11:
      y = x;
      break;
    case 12:
      y = 0.0;
      break;

    case 30:
      y = (x > 0.0) ? x : x * alpha;
      break;
    case 31:
      y = (ref > 0.0) ? x : x * alpha;
      break;
    case 32:
      y = 0.0;
      break;
    }

    out[xi] = y * scale;
  }
}

torch::Tensor fused_bias_act_op(const torch::Tensor &input,
                                const torch::Tensor &bias,
                                const torch::Tensor &refer, int act, int grad,
                                float alpha, float scale) {
  int curDevice = -1;
  cudaGetDevice(&curDevice);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto x = input.contiguous();
  auto b = bias.contiguous();
  auto ref = refer.contiguous();

  int use_bias = b.numel() ? 1 : 0;
  int use_ref = ref.numel() ? 1 : 0;

  int size_x = x.numel();
  int size_b = b.numel();
  int step_b = 1;

  for (int i = 1 + 1; i < x.dim(); i++) {
    step_b *= x.size(i);
  }

  int loop_x = 4;
  int block_size = 4 * 32;
  int grid_size = (size_x - 1) / (loop_x * block_size) + 1;

  auto y = torch::empty_like(x);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x.scalar_type(), "fused_bias_act_kernel", [&] {
        fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
            y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
            b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
            scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
      });

  return y;
}
--------------------------------------------------------------------------------
/cuda/upfirdn2d.cpp:
--------------------------------------------------------------------------------
#include <ATen/ATen.h>
#include <torch/extension.h>

torch::Tensor upfirdn2d_op(const torch::Tensor &input,
                           const torch::Tensor &kernel, int up_x, int up_y,
                           int down_x, int down_y, int pad_x0, int pad_x1,
                           int pad_y0, int pad_y1);

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
                        int up_x, int up_y, int down_x, int down_y, int pad_x0,
                        int pad_x1, int pad_y0, int pad_y1) {
  CHECK_INPUT(input);
  CHECK_INPUT(kernel);

  at::DeviceGuard guard(input.device());

  return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
                      pad_y0, pad_y1);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
}
--------------------------------------------------------------------------------
/cuda/upfirdn2d.py:
--------------------------------------------------------------------------------
from collections import abc
import os

import torch
from torch.nn import functional as F
from torch.autograd import Function
from torch.utils.cpp_extension import load


module_path = os.path.dirname(__file__)
upfirdn2d_op = load(
    "upfirdn2d",
    sources=[
        os.path.join(module_path, "upfirdn2d.cpp"),
        os.path.join(module_path, "upfirdn2d_kernel.cu"),
    ],
)


class UpFirDn2dBackward(Function):
    @staticmethod
    def forward(
        ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad,
in_size, out_size 24 | ): 25 | 26 | up_x, up_y = up 27 | down_x, down_y = down 28 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 29 | 30 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 31 | 32 | grad_input = upfirdn2d_op.upfirdn2d( 33 | grad_output, 34 | grad_kernel, 35 | down_x, 36 | down_y, 37 | up_x, 38 | up_y, 39 | g_pad_x0, 40 | g_pad_x1, 41 | g_pad_y0, 42 | g_pad_y1, 43 | ) 44 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 45 | 46 | ctx.save_for_backward(kernel) 47 | 48 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 49 | 50 | ctx.up_x = up_x 51 | ctx.up_y = up_y 52 | ctx.down_x = down_x 53 | ctx.down_y = down_y 54 | ctx.pad_x0 = pad_x0 55 | ctx.pad_x1 = pad_x1 56 | ctx.pad_y0 = pad_y0 57 | ctx.pad_y1 = pad_y1 58 | ctx.in_size = in_size 59 | ctx.out_size = out_size 60 | 61 | return grad_input 62 | 63 | @staticmethod 64 | def backward(ctx, gradgrad_input): 65 | kernel, = ctx.saved_tensors 66 | 67 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 68 | 69 | gradgrad_out = upfirdn2d_op.upfirdn2d( 70 | gradgrad_input, 71 | kernel, 72 | ctx.up_x, 73 | ctx.up_y, 74 | ctx.down_x, 75 | ctx.down_y, 76 | ctx.pad_x0, 77 | ctx.pad_x1, 78 | ctx.pad_y0, 79 | ctx.pad_y1, 80 | ) 81 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 82 | gradgrad_out = gradgrad_out.view( 83 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 84 | ) 85 | 86 | return gradgrad_out, None, None, None, None, None, None, None, None 87 | 88 | 89 | class UpFirDn2d(Function): 90 | @staticmethod 91 | def forward(ctx, input, kernel, up, down, pad): 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | out = upfirdn2d_op.upfirdn2d( 120 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 121 | ) 122 | # out = out.view(major, out_h, out_w, minor) 123 | out = out.view(-1, channel, out_h, out_w) 124 | 125 | return out 126 | 127 | @staticmethod 128 | def backward(ctx, grad_output): 129 | kernel, grad_kernel = ctx.saved_tensors 130 | 131 | grad_input = None 132 | 133 | if ctx.needs_input_grad[0]: 134 | grad_input = UpFirDn2dBackward.apply( 135 | grad_output, 136 | kernel, 137 | grad_kernel, 138 | ctx.up, 139 | ctx.down, 140 | ctx.pad, 141 | ctx.g_pad, 142 | ctx.in_size, 143 | ctx.out_size, 144 | ) 145 | 146 | return grad_input, None, None, None, None 147 | 148 | 149 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 150 | if not isinstance(up, abc.Iterable): 151 | up = (up, up) 152 | 153 | if not isinstance(down, abc.Iterable): 154 | down = (down, down) 155 | 156 | if len(pad) == 2: 157 | pad = 
(pad[0], pad[1], pad[0], pad[1]) 158 | 159 | if input.device.type == "cpu": 160 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 161 | 162 | else: 163 | out = UpFirDn2d.apply(input, kernel, up, down, pad) 164 | 165 | return out 166 | 167 | 168 | def upfirdn2d_native( 169 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 170 | ): 171 | _, channel, in_h, in_w = input.shape 172 | input = input.reshape(-1, in_h, in_w, 1) 173 | 174 | _, in_h, in_w, minor = input.shape 175 | kernel_h, kernel_w = kernel.shape 176 | 177 | out = input.view(-1, in_h, 1, in_w, 1, minor) 178 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 179 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 180 | 181 | out = F.pad( 182 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 183 | ) 184 | out = out[ 185 | :, 186 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 187 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 188 | :, 189 | ] 190 | 191 | out = out.permute(0, 3, 1, 2) 192 | out = out.reshape( 193 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 194 | ) 195 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 196 | out = F.conv2d(out, w) 197 | out = out.reshape( 198 | -1, 199 | minor, 200 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 201 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 202 | ) 203 | out = out.permute(0, 2, 3, 1) 204 | out = out[:, ::down_y, ::down_x, :] 205 | 206 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 207 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 208 | 209 | return out.view(-1, channel, out_h, out_w) 210 | -------------------------------------------------------------------------------- /cuda/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 18 | int c = a / b; 19 | 20 | if (c * b > a) { 21 | c--; 22 | } 23 | 24 | return c; 25 | } 26 | 27 | struct UpFirDn2DKernelParams { 28 | int up_x; 29 | int up_y; 30 | int down_x; 31 | int down_y; 32 | int pad_x0; 33 | int pad_x1; 34 | int pad_y0; 35 | int pad_y1; 36 | 37 | int major_dim; 38 | int in_h; 39 | int in_w; 40 | int minor_dim; 41 | int kernel_h; 42 | int kernel_w; 43 | int out_h; 44 | int out_w; 45 | int loop_major; 46 | int loop_x; 47 | }; 48 | 49 | template 50 | __global__ void upfirdn2d_kernel_large(scalar_t *out, const scalar_t *input, 51 | const scalar_t *kernel, 52 | const UpFirDn2DKernelParams p) { 53 | int minor_idx = blockIdx.x * blockDim.x + threadIdx.x; 54 | int out_y = minor_idx / p.minor_dim; 55 | minor_idx -= out_y * p.minor_dim; 56 | int out_x_base = blockIdx.y * p.loop_x * blockDim.y + threadIdx.y; 57 | int major_idx_base = blockIdx.z * p.loop_major; 58 | 59 | if (out_x_base >= p.out_w || out_y >= p.out_h || 60 | major_idx_base >= p.major_dim) { 61 | return; 62 | } 63 | 64 | int mid_y = out_y * p.down_y + p.up_y - 1 - p.pad_y0; 65 | int in_y = min(max(floor_div(mid_y, p.up_y), 0), p.in_h); 66 | int h = min(max(floor_div(mid_y + p.kernel_h, p.up_y), 0), p.in_h) - in_y; 67 | int kernel_y = mid_y + p.kernel_h - (in_y + 1) * p.up_y; 68 | 69 | for (int loop_major = 0, major_idx = major_idx_base; 70 | loop_major < p.loop_major && major_idx < p.major_dim; 71 | loop_major++, major_idx++) { 72 | for (int loop_x = 0, out_x = out_x_base; 73 | loop_x < p.loop_x && out_x < p.out_w; loop_x++, out_x += blockDim.y) { 74 | int mid_x = out_x * p.down_x + p.up_x - 1 - p.pad_x0; 75 | int in_x = min(max(floor_div(mid_x, p.up_x), 0), p.in_w); 76 | int w = min(max(floor_div(mid_x + p.kernel_w, p.up_x), 0), p.in_w) - in_x; 77 | int kernel_x = mid_x + p.kernel_w - (in_x + 1) * p.up_x; 78 | 79 | const scalar_t *x_p = 80 | &input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + 81 | minor_idx]; 82 | const scalar_t *k_p = &kernel[kernel_y * p.kernel_w + kernel_x]; 83 | int x_px = p.minor_dim; 84 | int k_px = -p.up_x; 85 | int x_py = p.in_w * p.minor_dim; 86 | int k_py = -p.up_y * p.kernel_w; 87 | 88 | scalar_t v = 0.0f; 89 | 90 | for (int y = 0; y < h; y++) { 91 | for (int x = 0; x < w; x++) { 92 | v += static_cast(*x_p) * static_cast(*k_p); 93 | x_p += x_px; 94 | k_p += k_px; 95 | } 96 | 97 | x_p += x_py - w * x_px; 98 | k_p += k_py - w * k_px; 99 | } 100 | 101 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 102 | minor_idx] = v; 103 | } 104 | } 105 | } 106 | 107 | template 109 | __global__ void upfirdn2d_kernel(scalar_t *out, const scalar_t *input, 110 | const scalar_t *kernel, 111 | const UpFirDn2DKernelParams p) { 112 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 113 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 114 | 115 | __shared__ volatile float sk[kernel_h][kernel_w]; 116 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 117 | 118 | int minor_idx = blockIdx.x; 119 | int tile_out_y = minor_idx / p.minor_dim; 120 | minor_idx -= tile_out_y * p.minor_dim; 121 | tile_out_y *= tile_out_h; 122 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 123 | int major_idx_base = blockIdx.z * 
p.loop_major; 124 | 125 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | 126 | major_idx_base >= p.major_dim) { 127 | return; 128 | } 129 | 130 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; 131 | tap_idx += blockDim.x) { 132 | int ky = tap_idx / kernel_w; 133 | int kx = tap_idx - ky * kernel_w; 134 | scalar_t v = 0.0; 135 | 136 | if (kx < p.kernel_w & ky < p.kernel_h) { 137 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 138 | } 139 | 140 | sk[ky][kx] = v; 141 | } 142 | 143 | for (int loop_major = 0, major_idx = major_idx_base; 144 | loop_major < p.loop_major & major_idx < p.major_dim; 145 | loop_major++, major_idx++) { 146 | for (int loop_x = 0, tile_out_x = tile_out_x_base; 147 | loop_x < p.loop_x & tile_out_x < p.out_w; 148 | loop_x++, tile_out_x += tile_out_w) { 149 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 150 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 151 | int tile_in_x = floor_div(tile_mid_x, up_x); 152 | int tile_in_y = floor_div(tile_mid_y, up_y); 153 | 154 | __syncthreads(); 155 | 156 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 157 | in_idx += blockDim.x) { 158 | int rel_in_y = in_idx / tile_in_w; 159 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 160 | int in_x = rel_in_x + tile_in_x; 161 | int in_y = rel_in_y + tile_in_y; 162 | 163 | scalar_t v = 0.0; 164 | 165 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 166 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * 167 | p.minor_dim + 168 | minor_idx]; 169 | } 170 | 171 | sx[rel_in_y][rel_in_x] = v; 172 | } 173 | 174 | __syncthreads(); 175 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; 176 | out_idx += blockDim.x) { 177 | int rel_out_y = out_idx / tile_out_w; 178 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 179 | int out_x = rel_out_x + tile_out_x; 180 | int out_y = rel_out_y + tile_out_y; 181 | 182 | int mid_x = tile_mid_x + rel_out_x * down_x; 183 | int mid_y = tile_mid_y + rel_out_y * down_y; 184 | int in_x = floor_div(mid_x, up_x); 185 | int in_y = floor_div(mid_y, up_y); 186 | int rel_in_x = in_x - tile_in_x; 187 | int rel_in_y = in_y - tile_in_y; 188 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 189 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 190 | 191 | scalar_t v = 0.0; 192 | 193 | #pragma unroll 194 | for (int y = 0; y < kernel_h / up_y; y++) 195 | #pragma unroll 196 | for (int x = 0; x < kernel_w / up_x; x++) 197 | v += sx[rel_in_y + y][rel_in_x + x] * 198 | sk[kernel_y + y * up_y][kernel_x + x * up_x]; 199 | 200 | if (out_x < p.out_w & out_y < p.out_h) { 201 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + 202 | minor_idx] = v; 203 | } 204 | } 205 | } 206 | } 207 | } 208 | 209 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 210 | const torch::Tensor &kernel, int up_x, int up_y, 211 | int down_x, int down_y, int pad_x0, int pad_x1, 212 | int pad_y0, int pad_y1) { 213 | int curDevice = -1; 214 | cudaGetDevice(&curDevice); 215 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 216 | 217 | UpFirDn2DKernelParams p; 218 | 219 | auto x = input.contiguous(); 220 | auto k = kernel.contiguous(); 221 | 222 | p.major_dim = x.size(0); 223 | p.in_h = x.size(1); 224 | p.in_w = x.size(2); 225 | p.minor_dim = x.size(3); 226 | p.kernel_h = k.size(0); 227 | p.kernel_w = k.size(1); 228 | p.up_x = up_x; 229 | p.up_y = up_y; 230 | p.down_x = down_x; 231 | p.down_y = down_y; 232 | p.pad_x0 = pad_x0; 233 | p.pad_x1 = pad_x1; 234 | 
p.pad_y0 = pad_y0; 235 | p.pad_y1 = pad_y1; 236 | 237 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / 238 | p.down_y; 239 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / 240 | p.down_x; 241 | 242 | auto out = 243 | at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 244 | 245 | int mode = -1; 246 | 247 | int tile_out_h = -1; 248 | int tile_out_w = -1; 249 | 250 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 251 | p.kernel_h <= 4 && p.kernel_w <= 4) { 252 | mode = 1; 253 | tile_out_h = 16; 254 | tile_out_w = 64; 255 | } 256 | 257 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && 258 | p.kernel_h <= 3 && p.kernel_w <= 3) { 259 | mode = 2; 260 | tile_out_h = 16; 261 | tile_out_w = 64; 262 | } 263 | 264 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 265 | p.kernel_h <= 4 && p.kernel_w <= 4) { 266 | mode = 3; 267 | tile_out_h = 16; 268 | tile_out_w = 64; 269 | } 270 | 271 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && 272 | p.kernel_h <= 2 && p.kernel_w <= 2) { 273 | mode = 4; 274 | tile_out_h = 16; 275 | tile_out_w = 64; 276 | } 277 | 278 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 279 | p.kernel_h <= 4 && p.kernel_w <= 4) { 280 | mode = 5; 281 | tile_out_h = 8; 282 | tile_out_w = 32; 283 | } 284 | 285 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && 286 | p.kernel_h <= 2 && p.kernel_w <= 2) { 287 | mode = 6; 288 | tile_out_h = 8; 289 | tile_out_w = 32; 290 | } 291 | 292 | dim3 block_size; 293 | dim3 grid_size; 294 | 295 | if (tile_out_h > 0 && tile_out_w > 0) { 296 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 297 | p.loop_x = 1; 298 | block_size = dim3(32 * 8, 1, 1); 299 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 300 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 301 | (p.major_dim - 1) / p.loop_major + 1); 302 | } else { 303 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 304 | p.loop_x = 4; 305 | block_size = dim3(4, 32, 1); 306 | grid_size = dim3((p.out_h * p.minor_dim - 1) / block_size.x + 1, 307 | (p.out_w - 1) / (p.loop_x * block_size.y) + 1, 308 | (p.major_dim - 1) / p.loop_major + 1); 309 | } 310 | 311 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 312 | switch (mode) { 313 | case 1: 314 | upfirdn2d_kernel 315 | <<>>(out.data_ptr(), 316 | x.data_ptr(), 317 | k.data_ptr(), p); 318 | 319 | break; 320 | 321 | case 2: 322 | upfirdn2d_kernel 323 | <<>>(out.data_ptr(), 324 | x.data_ptr(), 325 | k.data_ptr(), p); 326 | 327 | break; 328 | 329 | case 3: 330 | upfirdn2d_kernel 331 | <<>>(out.data_ptr(), 332 | x.data_ptr(), 333 | k.data_ptr(), p); 334 | 335 | break; 336 | 337 | case 4: 338 | upfirdn2d_kernel 339 | <<>>(out.data_ptr(), 340 | x.data_ptr(), 341 | k.data_ptr(), p); 342 | 343 | break; 344 | 345 | case 5: 346 | upfirdn2d_kernel 347 | <<>>(out.data_ptr(), 348 | x.data_ptr(), 349 | k.data_ptr(), p); 350 | 351 | break; 352 | 353 | case 6: 354 | upfirdn2d_kernel 355 | <<>>(out.data_ptr(), 356 | x.data_ptr(), 357 | k.data_ptr(), p); 358 | 359 | break; 360 | 361 | default: 362 | upfirdn2d_kernel_large<<>>( 363 | out.data_ptr(), x.data_ptr(), 364 | k.data_ptr(), p); 365 | } 366 | }); 367 | 368 | return out; 369 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils 
import *
from StyleGAN2 import run_fn

def parse_args():
    desc = "Pytorch implementation of StyleGAN2"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='[train, test, draw]')
    parser.add_argument('--dataset', type=str, default='FFHQ', help='dataset_name')
    parser.add_argument('--nsml', type=str2bool, default=False, help='NAVER NSML use or not')

    parser.add_argument('--n_total_image', type=int, default=25000, help='The total number of training images to show, in thousands (kimg)')
    parser.add_argument('--img_size', type=int, default=256, help='The size of image')
    parser.add_argument('--batch_size', type=int, default=8, help='batch size for each gpu')

    parser.add_argument('--lazy_regularization', type=str2bool, default=True, help='lazy_regularization')
    parser.add_argument('--d_reg_every', type=int, default=16, help='interval of applying r1 regularization')
    parser.add_argument('--g_reg_every', type=int, default=4, help='interval of applying path length regularization')
    parser.add_argument("--lr", type=float, default=0.002, help="learning rate")
    parser.add_argument('--channel_multiplier', type=int, default=2, help="channel multiplier factor for the model. config-f = 2, else = 1")
    parser.add_argument("--mixing_prob", type=float, default=0.9, help="probability of latent code mixing")


    parser.add_argument('--print_freq', type=int, default=2000, help='The number of image_print_freq')
    parser.add_argument('--save_freq', type=int, default=10000, help='The number of ckpt_save_freq')
    parser.add_argument('--n_sample', type=int, default=64, help='number of the samples generated during training')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())


"""checking arguments"""
def check_args(args):
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one', flush=True)

    return args

"""main"""
def main():

    args = vars(parse_args())

    # run
    multi_gpu_run(ddp_fn=run_fn, args=args)



if __name__ == '__main__':
    main()
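The FID reported during training comes from `metric/cal_fid.py` below: Inception features are gathered across GPUs, a Gaussian is fitted to the real and to the fake features, and the Fréchet distance between the two Gaussians is returned. As a self-contained sanity check of that formula, here is a small numpy sketch; `fid_from_feats` and the toy features are illustrative only:

```
import numpy as np
from scipy import linalg

def fid_from_feats(feats1, feats2):
    # ||mu1 - mu2||^2 + Tr(cov1 + cov2 - 2 * sqrtm(cov1 @ cov2)),
    # matching get_statistics() and frechet_distance() below
    mu1, cov1 = np.mean(feats1, axis=0), np.cov(feats1, rowvar=False)
    mu2, cov2 = np.mean(feats2, axis=0), np.cov(feats2, rowvar=False)
    cc, _ = linalg.sqrtm(np.dot(cov1, cov2), disp=False)
    return np.real(np.sum((mu1 - mu2) ** 2) + np.trace(cov1 + cov2 - 2 * cc))

feats = np.random.randn(1000, 16)
print(fid_from_feats(feats, feats))  # ~0 for identical feature sets
```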
--------------------------------------------------------------------------------
/metric/cal_fid.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
from torchvision import models
import torch.distributed as dist
import math
from tqdm import tqdm
from torchvision import transforms
from scipy import linalg
import pickle, os

class GatherLayer(torch.autograd.Function):
    # all_gather as an autograd Function: collects tensors from every rank
    # while keeping a gradient path to the local shard
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out


class InceptionV3(nn.Module):
    def __init__(self):
        super().__init__()
        inception = models.inception_v3(pretrained=True)
        self.block1 = nn.Sequential(
            inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.block2 = nn.Sequential(
            inception.Conv2d_3b_1x1, inception.Conv2d_4a_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2))
        self.block3 = nn.Sequential(
            inception.Mixed_5b, inception.Mixed_5c,
            inception.Mixed_5d, inception.Mixed_6a,
            inception.Mixed_6b, inception.Mixed_6c,
            inception.Mixed_6d, inception.Mixed_6e)
        self.block4 = nn.Sequential(
            inception.Mixed_7a, inception.Mixed_7b,
            inception.Mixed_7c,
            nn.AdaptiveAvgPool2d(output_size=(1, 1)))

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        return x.view(x.size(0), -1)

def extract_real_feature(data_loader, inception, device):
    feats = []

    for img in tqdm(data_loader):
        img = img.to(device)
        feat = inception(img)

        feats.append(feat)

    feats = gather_feats(feats)

    return feats

def normalize_fake_img(imgs):
    # map generator output from [-1, 1] to ImageNet-normalized 299x299 inputs
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    imgs = (imgs + 1) / 2  # -1 ~ 1 to 0 ~ 1
    imgs = torch.clamp(imgs, 0, 1, out=None)
    imgs = F.interpolate(imgs, size=(299, 299), mode="bilinear")
    imgs = transforms.Normalize(mean=mean, std=std)(imgs)

    return imgs

def gather_feats(feats):
    feats = torch.cat(feats, dim=0)
    feats = torch.cat(GatherLayer.apply(feats), dim=0)
    feats = feats.detach().cpu().numpy()

    return feats

def extract_fake_feature(generator, inception, num_gpus, device, latent_dim, fake_samples=50000, batch_size=16):
    num_batches = int(math.ceil(float(fake_samples) / float(batch_size * num_gpus)))
    feats = []
    for _ in tqdm(range(num_batches)):
        z = [torch.randn([batch_size, latent_dim], device=device)]
        fake_img = generator(z)

        fake_img = normalize_fake_img(fake_img)

        feat = inception(fake_img)

        feats.append(feat)

    feats = gather_feats(feats)

    return feats

def get_statistics(feats):
    mu = np.mean(feats, axis=0)
    cov = np.cov(feats, rowvar=False)

    return mu, cov

def frechet_distance(mu, cov, mu2, cov2):
    cc, _ = linalg.sqrtm(np.dot(cov, cov2), disp=False)
    dist = np.sum((mu - mu2) ** 2) + np.trace(cov + cov2 - 2 * cc)
    return np.real(dist)

@torch.no_grad()
def calculate_fid(data_loader, generator_model, inception_model, dataset_name, rank, device,
                  latent_dim, fake_samples=50000, batch_size=16):

    num_gpus = torch.cuda.device_count()

    generator_model = generator_model.eval()
    inception_model = inception_model.eval()

    pickle_name = '{}_mu_cov.pickle'.format(dataset_name)
    cache = os.path.exists(pickle_name)

    if cache:
        with open(pickle_name, 'rb') as f:
real_mu, real_cov = pickle.load(f) 131 | else: 132 | real_feats = extract_real_feature(data_loader, inception_model, device=device) 133 | real_mu, real_cov = get_statistics(real_feats) 134 | 135 | if rank == 0: 136 | with open(pickle_name, 'wb') as f: 137 | pickle.dump((real_mu, real_cov), f, protocol=pickle.HIGHEST_PROTOCOL) 138 | 139 | 140 | fake_feats = extract_fake_feature(generator_model, inception_model, num_gpus, device, latent_dim, fake_samples, batch_size) 141 | fake_mu, fake_cov = get_statistics(fake_feats) 142 | 143 | fid = frechet_distance(real_mu, real_cov, fake_mu, fake_cov) 144 | return fid 145 | 146 | 147 | 148 | 149 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | from ops import * 2 | import random 3 | 4 | class Generator(nn.Module): 5 | def __init__( 6 | self, 7 | size, 8 | style_dim=512, 9 | n_mlp=8, 10 | channel_multiplier=2, 11 | blur_kernel=[1, 3, 3, 1], 12 | lr_mlp=0.01, 13 | ): 14 | super().__init__() 15 | 16 | self.size = size 17 | 18 | self.style_dim = style_dim 19 | 20 | layers = [PixelNorm()] 21 | 22 | for i in range(n_mlp): 23 | layers.append( 24 | EqualLinear( 25 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 26 | ) 27 | ) 28 | 29 | self.style = nn.Sequential(*layers) 30 | 31 | self.channels = { 32 | 4: 512, 33 | 8: 512, 34 | 16: 512, 35 | 32: 512, 36 | 64: 256 * channel_multiplier, 37 | 128: 128 * channel_multiplier, 38 | 256: 64 * channel_multiplier, 39 | 512: 32 * channel_multiplier, 40 | 1024: 16 * channel_multiplier, 41 | } 42 | 43 | self.input = ConstantInput(self.channels[4]) 44 | self.conv1 = StyledConv( 45 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 46 | ) 47 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 48 | 49 | self.log_size = int(math.log(size, 2)) 50 | self.num_layers = (self.log_size - 2) * 2 + 1 51 | 52 | self.convs = nn.ModuleList() 53 | self.upsamples = nn.ModuleList() 54 | self.to_rgbs = nn.ModuleList() 55 | self.noises = nn.Module() 56 | 57 | in_channel = self.channels[4] 58 | 59 | for layer_idx in range(self.num_layers): 60 | res = (layer_idx + 5) // 2 61 | shape = [1, 1, 2 ** res, 2 ** res] 62 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 63 | 64 | for i in range(3, self.log_size + 1): 65 | out_channel = self.channels[2 ** i] 66 | 67 | self.convs.append( 68 | StyledConv( 69 | in_channel, 70 | out_channel, 71 | 3, 72 | style_dim, 73 | upsample=True, 74 | blur_kernel=blur_kernel, 75 | ) 76 | ) 77 | 78 | self.convs.append( 79 | StyledConv( 80 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 81 | ) 82 | ) 83 | 84 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 85 | 86 | in_channel = out_channel 87 | 88 | self.n_latent = self.log_size * 2 - 2 89 | 90 | def make_noise(self): 91 | device = self.input.input.device 92 | 93 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 94 | 95 | for i in range(3, self.log_size + 1): 96 | for _ in range(2): 97 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 98 | 99 | return noises 100 | 101 | def mean_latent(self, n_latent): 102 | latent_in = torch.randn( 103 | n_latent, self.style_dim, device=self.input.input.device 104 | ) 105 | latent = self.style(latent_in).mean(0, keepdim=True) 106 | 107 | return latent 108 | 109 | def get_latent(self, input): 110 | return self.style(input) 111 | 112 | def forward( 113 | self, 114 | styles, 115 
| return_latents=False, 116 | inject_index=None, 117 | truncation=1, 118 | truncation_latent=None, 119 | input_is_latent=False, 120 | noise=None, 121 | randomize_noise=True, 122 | ): 123 | if not input_is_latent: 124 | styles = [self.style(s) for s in styles] 125 | 126 | if noise is None: 127 | if randomize_noise: 128 | noise = [None] * self.num_layers 129 | else: 130 | noise = [ 131 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 132 | ] 133 | 134 | if truncation < 1: 135 | style_t = [] 136 | 137 | for style in styles: 138 | style_t.append( 139 | truncation_latent + truncation * (style - truncation_latent) 140 | ) 141 | 142 | styles = style_t 143 | 144 | if len(styles) < 2: 145 | inject_index = self.n_latent 146 | 147 | if styles[0].ndim < 3: 148 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 149 | 150 | else: 151 | latent = styles[0] 152 | 153 | else: 154 | if inject_index is None: 155 | inject_index = random.randint(1, self.n_latent - 1) 156 | 157 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 158 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 159 | 160 | latent = torch.cat([latent, latent2], 1) 161 | 162 | out = self.input(latent) 163 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 164 | 165 | skip = self.to_rgb1(out, latent[:, 1]) 166 | 167 | i = 1 168 | for conv1, conv2, noise1, noise2, to_rgb in zip( 169 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 170 | ): 171 | out = conv1(out, latent[:, i], noise=noise1) 172 | out = conv2(out, latent[:, i + 1], noise=noise2) 173 | skip = to_rgb(out, latent[:, i + 2], skip) 174 | 175 | i += 2 176 | 177 | image = skip 178 | 179 | if return_latents: 180 | return image, latent 181 | 182 | else: 183 | return image 184 | 185 | 186 | class ConvLayer(nn.Sequential): 187 | def __init__( 188 | self, 189 | in_channel, 190 | out_channel, 191 | kernel_size, 192 | downsample=False, 193 | blur_kernel=[1, 3, 3, 1], 194 | bias=True, 195 | activate=True, 196 | ): 197 | layers = [] 198 | 199 | if downsample: 200 | factor = 2 201 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 202 | pad0 = (p + 1) // 2 203 | pad1 = p // 2 204 | 205 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 206 | 207 | stride = 2 208 | self.padding = 0 209 | 210 | else: 211 | stride = 1 212 | self.padding = kernel_size // 2 213 | 214 | layers.append( 215 | EqualConv2d( 216 | in_channel, 217 | out_channel, 218 | kernel_size, 219 | padding=self.padding, 220 | stride=stride, 221 | bias=bias and not activate, 222 | ) 223 | ) 224 | 225 | if activate: 226 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 227 | 228 | super().__init__(*layers) 229 | 230 | 231 | class ResBlock(nn.Module): 232 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 233 | super().__init__() 234 | 235 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 236 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 237 | 238 | self.skip = ConvLayer( 239 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 240 | ) 241 | 242 | def forward(self, input): 243 | out = self.conv1(input) 244 | out = self.conv2(out) 245 | 246 | skip = self.skip(input) 247 | out = (out + skip) / math.sqrt(2) 248 | 249 | return out 250 | 251 | 252 | class Discriminator(nn.Module): 253 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 254 | super().__init__() 255 | 256 | channels = { 257 | 4: 512, 258 | 8: 512, 259 | 16: 512, 260 | 32: 512, 261 | 
64: 256 * channel_multiplier, 262 | 128: 128 * channel_multiplier, 263 | 256: 64 * channel_multiplier, 264 | 512: 32 * channel_multiplier, 265 | 1024: 16 * channel_multiplier, 266 | } 267 | 268 | convs = [ConvLayer(3, channels[size], 1)] 269 | 270 | log_size = int(math.log(size, 2)) 271 | 272 | in_channel = channels[size] 273 | 274 | for i in range(log_size, 2, -1): 275 | out_channel = channels[2 ** (i - 1)] 276 | 277 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 278 | 279 | in_channel = out_channel 280 | 281 | self.convs = nn.Sequential(*convs) 282 | 283 | self.stddev_group = 4 284 | self.stddev_feat = 1 285 | 286 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 287 | self.final_linear = nn.Sequential( 288 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 289 | EqualLinear(channels[4], 1), 290 | ) 291 | 292 | def forward(self, input): 293 | out = self.convs(input) 294 | 295 | batch, channel, height, width = out.shape 296 | group = min(batch, self.stddev_group) 297 | stddev = out.view( 298 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 299 | ) 300 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 301 | stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2) 302 | stddev = stddev.repeat(group, 1, height, width) 303 | out = torch.cat([out, stddev], 1) 304 | 305 | out = self.final_conv(out) 306 | 307 | out = out.view(batch, -1) 308 | out = self.final_linear(out) 309 | 310 | return out -------------------------------------------------------------------------------- /ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn # module (class) API 3 | from torch.nn import functional as F # functional API 4 | from torch import autograd 5 | from cuda import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 6 | import math 7 | import random 8 | class PixelNorm(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def forward(self, input): 13 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 14 | 15 | 16 | def make_kernel(k): 17 | k = torch.tensor(k, dtype=torch.float32) 18 | 19 | if k.ndim == 1: 20 | k = k[None, :] * k[:, None] 21 | 22 | k /= k.sum() 23 | 24 | return k 25 | 26 | 27 | class Upsample(nn.Module): 28 | def __init__(self, kernel, factor=2): 29 | super().__init__() 30 | 31 | self.factor = factor 32 | kernel = make_kernel(kernel) * (factor ** 2) 33 | self.register_buffer("kernel", kernel) 34 | 35 | p = kernel.shape[0] - factor 36 | 37 | pad0 = (p + 1) // 2 + factor - 1 38 | pad1 = p // 2 39 | 40 | self.pad = (pad0, pad1) 41 | 42 | def forward(self, input): 43 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 44 | 45 | return out 46 | 47 | 48 | class Downsample(nn.Module): 49 | def __init__(self, kernel, factor=2): 50 | super().__init__() 51 | 52 | self.factor = factor 53 | kernel = make_kernel(kernel) 54 | self.register_buffer("kernel", kernel) 55 | 56 | p = kernel.shape[0] - factor 57 | 58 | pad0 = (p + 1) // 2 59 | pad1 = p // 2 60 | 61 | self.pad = (pad0, pad1) 62 | 63 | def forward(self, input): 64 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 65 | 66 | return out 67 | 68 | 69 | class Blur(nn.Module): 70 | def __init__(self, kernel, pad, upsample_factor=1): 71 | super().__init__() 72 | 73 | kernel = make_kernel(kernel) 74 | 75 | if upsample_factor > 1: 76 | kernel = kernel * (upsample_factor ** 2) 77 |
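# upfirdn2d upsamples by zero-insertion, which spreads each input sample over
# factor**2 output positions; pre-scaling the kernel by upsample_factor**2
# (mirroring Upsample above) keeps the magnitude of the resampled signal unchanged.
78 | 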
self.register_buffer("kernel", kernel) 79 | 80 | self.pad = pad 81 | 82 | def forward(self, input): 83 | out = upfirdn2d(input, self.kernel, pad=self.pad) 84 | 85 | return out 86 | 87 | 88 | class EqualConv2d(nn.Module): 89 | def __init__( 90 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 91 | ): 92 | super().__init__() 93 | 94 | self.weight = nn.Parameter( 95 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 96 | ) 97 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 98 | 99 | self.stride = stride 100 | self.padding = padding 101 | 102 | if bias: 103 | self.bias = nn.Parameter(torch.zeros(out_channel)) 104 | 105 | else: 106 | self.bias = None 107 | 108 | def forward(self, input): 109 | out = conv2d_gradfix.conv2d( 110 | input, 111 | self.weight * self.scale, 112 | bias=self.bias, 113 | stride=self.stride, 114 | padding=self.padding, 115 | ) 116 | 117 | return out 118 | 119 | def __repr__(self): 120 | return ( 121 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 122 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 123 | ) 124 | 125 | 126 | class EqualLinear(nn.Module): 127 | def __init__( 128 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 129 | ): 130 | super().__init__() 131 | 132 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 133 | 134 | if bias: 135 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 136 | 137 | else: 138 | self.bias = None 139 | 140 | self.activation = activation 141 | 142 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 143 | self.lr_mul = lr_mul 144 | 145 | def forward(self, input): 146 | if self.activation: 147 | out = F.linear(input, self.weight * self.scale) 148 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 149 | 150 | else: 151 | out = F.linear( 152 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 153 | ) 154 | 155 | return out 156 | 157 | def __repr__(self): 158 | return ( 159 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 160 | ) 161 | 162 | 163 | class ModulatedConv2d(nn.Module): 164 | def __init__( 165 | self, 166 | in_channel, 167 | out_channel, 168 | kernel_size, 169 | style_dim, 170 | demodulate=True, 171 | upsample=False, 172 | downsample=False, 173 | blur_kernel=[1, 3, 3, 1], 174 | fused=True, 175 | ): 176 | super().__init__() 177 | 178 | self.eps = 1e-8 179 | self.kernel_size = kernel_size 180 | self.in_channel = in_channel 181 | self.out_channel = out_channel 182 | self.upsample = upsample 183 | self.downsample = downsample 184 | 185 | if upsample: 186 | factor = 2 187 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 188 | pad0 = (p + 1) // 2 + factor - 1 189 | pad1 = p // 2 + 1 190 | 191 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 192 | 193 | if downsample: 194 | factor = 2 195 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 196 | pad0 = (p + 1) // 2 197 | pad1 = p // 2 198 | 199 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 200 | 201 | fan_in = in_channel * kernel_size ** 2 202 | self.scale = 1 / math.sqrt(fan_in) 203 | self.padding = kernel_size // 2 204 | 205 | self.weight = nn.Parameter( 206 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 207 | ) 208 | 209 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 210 | 211 | self.demodulate = demodulate 212 | self.fused = fused 213 | 214 | def __repr__(self): 215 | return ( 216 | 
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 217 | f"upsample={self.upsample}, downsample={self.downsample})" 218 | ) 219 | 220 | def forward(self, input, style): 221 | batch, in_channel, height, width = input.shape 222 | 223 | if not self.fused: 224 | weight = self.scale * self.weight.squeeze(0) 225 | style = self.modulation(style) 226 | 227 | if self.demodulate: 228 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 229 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 230 | 231 | input = input * style.reshape(batch, in_channel, 1, 1) 232 | 233 | if self.upsample: 234 | weight = weight.transpose(0, 1) 235 | out = conv2d_gradfix.conv_transpose2d( 236 | input, weight, padding=0, stride=2 237 | ) 238 | out = self.blur(out) 239 | 240 | elif self.downsample: 241 | input = self.blur(input) 242 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 243 | 244 | else: 245 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 246 | 247 | if self.demodulate: 248 | out = out * dcoefs.view(batch, -1, 1, 1) 249 | 250 | return out 251 | 252 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 253 | weight = self.scale * self.weight * style 254 | 255 | if self.demodulate: 256 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 257 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 258 | 259 | weight = weight.view( 260 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 261 | ) 262 | 263 | if self.upsample: 264 | input = input.view(1, batch * in_channel, height, width) 265 | weight = weight.view( 266 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 267 | ) 268 | weight = weight.transpose(1, 2).reshape( 269 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 270 | ) 271 | out = conv2d_gradfix.conv_transpose2d( 272 | input, weight, padding=0, stride=2, groups=batch 273 | ) 274 | _, _, height, width = out.shape 275 | out = out.view(batch, self.out_channel, height, width) 276 | out = self.blur(out) 277 | 278 | elif self.downsample: 279 | input = self.blur(input) 280 | _, _, height, width = input.shape 281 | input = input.view(1, batch * in_channel, height, width) 282 | out = conv2d_gradfix.conv2d( 283 | input, weight, padding=0, stride=2, groups=batch 284 | ) 285 | _, _, height, width = out.shape 286 | out = out.view(batch, self.out_channel, height, width) 287 | 288 | else: 289 | input = input.view(1, batch * in_channel, height, width) 290 | out = conv2d_gradfix.conv2d( 291 | input, weight, padding=self.padding, groups=batch 292 | ) 293 | _, _, height, width = out.shape 294 | out = out.view(batch, self.out_channel, height, width) 295 | 296 | return out 297 | 298 | 299 | class NoiseInjection(nn.Module): 300 | def __init__(self): 301 | super().__init__() 302 | 303 | self.weight = nn.Parameter(torch.zeros(1)) 304 | 305 | def forward(self, image, noise=None): 306 | if noise is None: 307 | batch, _, height, width = image.shape 308 | noise = image.new_empty(batch, 1, height, width).normal_() 309 | 310 | return image + self.weight * noise 311 | 312 | 313 | class ConstantInput(nn.Module): 314 | def __init__(self, channel, size=4): 315 | super().__init__() 316 | 317 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 318 | 319 | def forward(self, input): 320 | batch = input.shape[0] 321 | out = self.input.repeat(batch, 1, 1, 1) 322 | 323 | return out 324 | 325 | 326 | class StyledConv(nn.Module): 327 | def 
__init__( 328 | self, 329 | in_channel, 330 | out_channel, 331 | kernel_size, 332 | style_dim, 333 | upsample=False, 334 | blur_kernel=[1, 3, 3, 1], 335 | demodulate=True, 336 | ): 337 | super().__init__() 338 | 339 | self.conv = ModulatedConv2d( 340 | in_channel, 341 | out_channel, 342 | kernel_size, 343 | style_dim, 344 | upsample=upsample, 345 | blur_kernel=blur_kernel, 346 | demodulate=demodulate, 347 | ) 348 | 349 | self.noise = NoiseInjection() 350 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 351 | # self.activate = ScaledLeakyReLU(0.2) 352 | self.activate = FusedLeakyReLU(out_channel) 353 | 354 | def forward(self, input, style, noise=None): 355 | out = self.conv(input, style) 356 | out = self.noise(out, noise=noise) 357 | # out = out + self.bias 358 | out = self.activate(out) 359 | 360 | return out 361 | 362 | 363 | class ToRGB(nn.Module): 364 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 365 | super().__init__() 366 | 367 | if upsample: 368 | self.upsample = Upsample(blur_kernel) 369 | 370 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 371 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 372 | 373 | def forward(self, input, style, skip=None): 374 | out = self.conv(input, style) 375 | out = out + self.bias 376 | 377 | if skip is not None: 378 | skip = self.upsample(skip) 379 | 380 | out = out + skip 381 | 382 | return out 383 | 384 | def d_logistic_loss(real_pred, fake_pred): 385 | real_loss = F.softplus(-real_pred) 386 | fake_loss = F.softplus(fake_pred) 387 | 388 | return real_loss.mean() + fake_loss.mean() 389 | 390 | def d_r1_loss(real_pred, real_img): 391 | with conv2d_gradfix.no_weight_gradients(): 392 | grad_real, = autograd.grad( 393 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 394 | ) 395 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 396 | 397 | return grad_penalty 398 | 399 | def g_nonsaturating_loss(fake_pred): 400 | loss = F.softplus(-fake_pred).mean() 401 | 402 | return loss 403 | 404 | 405 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 406 | noise = torch.randn_like(fake_img) / math.sqrt( 407 | fake_img.shape[2] * fake_img.shape[3] 408 | ) 409 | grad, = autograd.grad( 410 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 411 | ) 412 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 413 | 414 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 415 | 416 | path_penalty = (path_lengths - path_mean).pow(2).mean() 417 | 418 | return path_penalty, path_mean.detach(), path_lengths 419 | 420 | def apply_gradients(loss, optim): 421 | optim.zero_grad() 422 | loss.backward() 423 | optim.step() 424 | 425 | def moving_average(model1, model2, decay=0.999): 426 | with torch.no_grad(): 427 | par1 = dict(model1.named_parameters()) 428 | par2 = dict(model2.named_parameters()) 429 | 430 | for k in par1.keys(): 431 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 432 | 433 | def make_noise(batch, latent_dim, n_noise, device): 434 | if n_noise == 1: 435 | return torch.randn([batch, latent_dim], device=device) 436 | 437 | noises = torch.randn([n_noise, batch, latent_dim], device=device).unbind(0) 438 | 439 | return noises 440 | 441 | 442 | def mixing_noise(batch, latent_dim, prob, device): 443 | if prob > 0 and random.random() < prob: 444 | return make_noise(batch, latent_dim, 2, device) 445 | 446 | else: 447 | return [make_noise(batch, latent_dim, 1, device)] 
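# Usage sketch (illustrative; nothing in this file calls it): with prob=0.9,
# mixing_noise usually returns two latent codes, in which case the Generator
# in networks.py mixes the two styles at a random injection index; otherwise
# a single code drives every layer.
#
#   styles = mixing_noise(batch=8, latent_dim=512, prob=0.9, device=device)
#   fake_img = generator(styles)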
448 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | import os, re 6 | from glob import glob 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel 9 | import torch.multiprocessing as torch_multiprocessing 10 | 11 | class ImageDataset(Dataset): 12 | def __init__(self, dataset_path, img_size, fid_transform=False): 13 | self.samples = self.listdir(dataset_path) 14 | 15 | # interpolation=transforms.InterpolationMode.BICUBIC, antialias=True 16 | if fid_transform: 17 | mean = [0.485, 0.456, 0.406] 18 | std = [0.229, 0.224, 0.225] 19 | transform_list = [ 20 | transforms.Resize(size=[img_size, img_size]), 21 | transforms.ToTensor(), # [0, 255] -> [0, 1] 22 | transforms.Normalize(mean=mean, std=std, inplace=True), # ImageNet normalization for InceptionV3 23 | ] 24 | else: 25 | transform_list = [ 26 | transforms.Resize(size=[img_size, img_size]), 27 | transforms.RandomHorizontalFlip(p=0.5), 28 | transforms.ToTensor(), # [0, 255] -> [0, 1] 29 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), # [0, 1] -> [-1, 1] 30 | ] 31 | 32 | self.transform = transforms.Compose(transform_list) 33 | 34 | def listdir(self, dir_path): 35 | extensions = ['png', 'jpg', 'jpeg', 'JPG'] 36 | file_path = [] 37 | for ext in extensions: 38 | file_path += glob(os.path.join(dir_path, '*.' + ext)) 39 | 40 | file_path.sort() 41 | return file_path 42 | 43 | def __getitem__(self, index): 44 | sample_path = self.samples[index] 45 | img = Image.open(sample_path).convert('RGB') 46 | img = self.transform(img) 47 | 48 | return img 49 | 50 | def __len__(self): 51 | return len(self.samples) 52 | 53 | 54 | def check_folder(log_dir): 55 | if not os.path.exists(log_dir): 56 | os.makedirs(log_dir) 57 | return log_dir 58 | 59 | 60 | def str2bool(x): 61 | return x.lower() == 'true' 62 | 63 | 64 | def multi_gpu_run(ddp_fn, args): # called once from the launcher process 65 | # ddp_fn = train_fn 66 | world_size = torch.cuda.device_count() # ngpus 67 | torch_multiprocessing.spawn(fn=ddp_fn, args=(args, world_size), nprocs=world_size, join=True) 68 | 69 | 70 | def build_init_procss(rank, world_size, device): # called once in each spawned process 71 | os.environ["MASTER_ADDR"] = "127.0.0.1" # localhost 72 | os.environ["MASTER_PORT"] = "12355" 73 | 74 | # initialize the process group 75 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 76 | synchronize() 77 | torch.cuda.set_device(device) 78 | 79 | 80 | def distributed_sampler(dataset, rank, num_replicas, shuffle): 81 | return torch.utils.data.distributed.DistributedSampler(dataset, rank=rank, num_replicas=num_replicas, shuffle=shuffle) 82 | # return torch.utils.data.RandomSampler(dataset) 83 | 84 | 85 | def infinite_iterator(loader): 86 | while True: 87 | for batch in loader: 88 | yield batch 89 | 90 | def find_latest_ckpt(folder): 91 | files = [] 92 | for fname in os.listdir(folder): 93 | s = re.findall(r'\d+', fname) 94 | if len(s) == 1: 95 | files.append((int(s[0]), fname)) 96 | if files: 97 | file_name = max(files)[1] 98 | index = os.path.splitext(file_name)[0] 99 | return file_name, index 100 | else: 101 | return None, 0 102 | 103 | 104 | def broadcast_params(model): 105 | params = model.parameters() 106 | for param in params: 107 | dist.broadcast(param.data, src=0) 108 | dist.barrier() 109 | torch.cuda.synchronize() 110 | 111 |
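# A sketch of the multi-GPU bootstrap these helpers support; `train_fn` is a
# hypothetical worker (the real one lives in main.py):
#
#   def train_fn(rank, args, world_size):   # invoked by torch_multiprocessing.spawn
#       device = torch.device('cuda', rank)
#       build_init_procss(rank, world_size, device)
#       ...                                 # build models, wrap via dataparallel_and_sync
#
#   multi_gpu_run(train_fn, args)           # launcher side: one process per GPU
112 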
| def dataparallel_and_sync(model, local_rank, find_unused_parameters=False): 113 | # DistributedDataParallel 114 | model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, find_unused_parameters=find_unused_parameters) 115 | 116 | # broadcast 117 | broadcast_params(model) 118 | 119 | model = model.module 120 | 121 | return model 122 | 123 | 124 | def cleanup(): 125 | dist.destroy_process_group() 126 | 127 | def get_rank(): 128 | if not dist.is_available(): 129 | return 0 130 | 131 | if not dist.is_initialized(): 132 | return 0 133 | 134 | return dist.get_rank() 135 | 136 | def get_world_size(): 137 | if not dist.is_available(): 138 | return 1 139 | 140 | if not dist.is_initialized(): 141 | return 1 142 | 143 | return dist.get_world_size() 144 | 145 | def synchronize(): 146 | if not dist.is_available(): 147 | return 148 | 149 | if not dist.is_initialized(): 150 | return 151 | 152 | world_size = dist.get_world_size() 153 | 154 | if world_size == 1: 155 | return 156 | 157 | dist.barrier() 158 | 159 | def reduce_loss_dict(loss_dict): 160 | world_size = get_world_size() 161 | 162 | if world_size < 2: 163 | return loss_dict 164 | 165 | with torch.no_grad(): 166 | keys = [] 167 | losses = [] 168 | 169 | for k in sorted(loss_dict.keys()): 170 | keys.append(k) 171 | losses.append(loss_dict[k]) 172 | 173 | losses = torch.stack(losses, 0) 174 | dist.reduce(losses, dst=0) 175 | 176 | if dist.get_rank() == 0: 177 | losses /= world_size 178 | 179 | reduced_losses = {k: v.mean().item() for k, v in zip(keys, losses)} 180 | 181 | return reduced_losses 182 | 183 | def get_val(x): 184 | x_val = x.mean().item() 185 | 186 | return x_val 187 | 188 | def requires_grad(model, flag=True): 189 | for p in model.parameters(): 190 | p.requires_grad = flag 191 | 192 | def count_parameters(model): 193 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 194 | 195 | 196 | --------------------------------------------------------------------------------
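A minimal, single-process sketch of how the pieces above fit together. This is not the repository's entry point (main.py drives the actual training loop), and the sizes, learning rate, and betas below are illustrative only; note also that calculate_fid in metric/cal_fid.py assumes an initialized torch.distributed process group, since GatherLayer calls dist.all_gather.

import torch
from networks import Generator, Discriminator
from ops import mixing_noise, d_logistic_loss, g_nonsaturating_loss, apply_gradients, moving_average
from utils import requires_grad

device = torch.device('cuda')
size, latent_dim, batch = 256, 512, 8  # illustrative values

g = Generator(size).to(device)
g_ema = Generator(size).to(device)
g_ema.load_state_dict(g.state_dict())  # EMA copy, updated by moving_average below
d = Discriminator(size).to(device)
g_optim = torch.optim.Adam(g.parameters(), lr=0.002, betas=(0.0, 0.99))
d_optim = torch.optim.Adam(d.parameters(), lr=0.002, betas=(0.0, 0.99))

real_img = torch.randn(batch, 3, size, size, device=device)  # stand-in for a data batch

# Discriminator step: logistic loss on real images vs. detached fakes.
requires_grad(g, False)
requires_grad(d, True)
styles = mixing_noise(batch, latent_dim, prob=0.9, device=device)
fake_img = g(styles)
d_loss = d_logistic_loss(d(real_img), d(fake_img.detach()))
apply_gradients(d_loss, d_optim)

# Generator step, then the exponential moving average of G's weights.
requires_grad(g, True)
requires_grad(d, False)
styles = mixing_noise(batch, latent_dim, prob=0.9, device=device)
g_loss = g_nonsaturating_loss(d(g(styles)))
apply_gradients(g_loss, g_optim)
moving_average(g_ema, g, decay=0.999)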