├── .gitignore ├── LICENSE ├── README.md ├── checkpoint └── .gitignore ├── cog.yaml ├── dataset.py ├── doc ├── sample.png ├── sample_ffhq.png ├── sample_ffhq_new.png ├── sample_mixing.png ├── sample_mixing_ffhq.png ├── sample_mixing_ffhq_new.png └── sample_prev.png ├── generate.py ├── lpips ├── __init__.py ├── base_model.py ├── dist_model.py ├── networks_basic.py ├── pretrained_networks.py └── weights │ ├── v0.0 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth │ └── v0.1 │ ├── alex.pth │ ├── squeeze.pth │ └── vgg.pth ├── model.py ├── predict.py ├── prepare_data.py ├── projector.py ├── sample └── .gitignore └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.lmdb 107 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | # License for LPIPS 24 | 25 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 26 | All rights reserved. 27 | 28 | Redistribution and use in source and binary forms, with or without 29 | modification, are permitted provided that the following conditions are met: 30 | 31 | * Redistributions of source code must retain the above copyright notice, this 32 | list of conditions and the following disclaimer. 33 | 34 | * Redistributions in binary form must reproduce the above copyright notice, 35 | this list of conditions and the following disclaimer in the documentation 36 | and/or other materials provided with the distribution. 37 | 38 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 39 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 40 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 41 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 42 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 43 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 44 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 45 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 46 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 47 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 48 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Style-Based GAN in PyTorch 2 | 3 | ##### Update (2019/09/01) 4 | 5 | I found bugs in the implementation thanks to @adambielski and @TropComplique! (https://github.com/rosinality/style-based-gan-pytorch/issues/33, https://github.com/rosinality/style-based-gan-pytorch/issues/34) I have fixed these and updated the checkpoints. 6 | 7 | ##### Update (2019/07/04) 8 | 9 | * The trainer now uses a pre-resized lmdb dataset for more stable data loading and training. 10 | * The model architecture now matches the official implementation more closely. 11 | 12 | Implementation of A Style-Based Generator Architecture for Generative Adversarial Networks (https://arxiv.org/abs/1812.04948) in PyTorch 13 | 14 | * [Demo and Docker image on Replicate](https://replicate.ai/rosinality/style-based-gan-pytorch) 15 | 16 | Usage: 17 | 18 | First, prepare an lmdb dataset: 19 | 20 | > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER DATASET_PATH 21 | 22 | This converts the images to JPEG and pre-resizes them to each training resolution (8/16/32/64/128/256/512/1024); the layout of the resulting lmdb is sketched below. Then you can train StyleGAN.
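The prepared lmdb stores every image once per resolution under keys of the form `{size}-{index:05d}` (for example `8-00000`, `1024-00000`), plus a `length` entry; these are exactly the keys `MultiResolutionDataset` in dataset.py reads back during training. A quick way to sanity-check the result (an illustrative snippet, not part of the repository):

    import lmdb

    # open the prepared dataset read-only, with the same flags dataset.py uses
    env = lmdb.open('LMDB_PATH', max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)

    with env.begin(write=False) as txn:
        print(int(txn.get(b'length').decode('utf-8')))   # number of images
        print(len(txn.get(b'256-00000')))                 # JPEG size in bytes of image 0 at 256px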
23 | 24 | For CelebA: 25 | 26 | > python train.py --mixing LMDB_PATH 27 | 28 | For FFHQ: 29 | 30 | > python train.py --mixing --loss r1 --sched LMDB_PATH 31 | 32 | Resolution | Model & Optimizer 33 | -----------|------------------- 34 | 256px | [Link](https://drive.google.com/open?id=1QlXFPIOFzsJyjZ1AtfpnVhqW4Z0r8GLZ) 35 | 512px | [Link](https://drive.google.com/open?id=13f0tXPX0EfHdac0zcudfC8osD4OdsxZQ) 36 | 1024px | [Link](https://drive.google.com/open?id=1NJMqp2AN1de8cPXTBzYC7mX2wXF9ox-i) 37 | 38 | Model & optimizer checkpoints are saved at the end of each resolution's training phase (that is, the 512px checkpoint is saved at the end of 512px training). 39 | 40 | ## Sample 41 | 42 | ![Sample of the model trained on FFHQ](doc/sample_ffhq_new.png) 43 | ![Style mixing sample of the model trained on FFHQ](doc/sample_mixing_ffhq_new.png) 44 | 45 | 512px samples from the generator trained on FFHQ. 46 | 47 | ## Old Checkpoints 48 | 49 | Resolution | Model & Optimizer | Running average of generator 50 | -----------|-------------------|------------------------------ 51 | 128px | [Link](https://drive.google.com/open?id=1Fc0d8tTjS7Fcmr8gyHk8M0P-VMiRNeMl) | 100k iter [Link](https://drive.google.com/open?id=1b4MKSVTbWoY15NkzsM58T0QCvTE9d_Ch) 52 | 256px | [Link](https://drive.google.com/open?id=1K2G1p-m1BQNoTEKJDBGAtFI1fC4eBjcd) | 140k iter [Link](https://drive.google.com/open?id=1n01mlc1mPpQyeUnnWNGeZiY7vp6JgakM) 53 | 512px | [Link](https://drive.google.com/open?id=1Ls8NA56UnJWGJkRXXyJoDdz4a7uizBtw) | 180k iter [Link](https://drive.google.com/open?id=15lnKHnldIidQnXAlQ8PHo2W4XUTaIfq-) 54 | 55 | These are checkpoints from the old version of the code. Because the gradient penalty and discriminator activations have changed, it is better to start further training from the new checkpoints, but the old ones can still be used for sampling since the generator architecture is unchanged. 56 | 57 | The running average of the generator was saved at the listed iterations, so the two files in each row come from different iterations. (Yes, this is my mistake.)
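## Generating Samples

To sample from a downloaded checkpoint you can use generate.py (an illustrative invocation; set `--size` to the resolution the checkpoint was trained at, e.g. 512 or 1024):

> python generate.py --size 512 CHECKPOINT_PATH

This loads the running average of the generator (`g_running`) from the checkpoint and writes a `sample.png` grid plus a series of `sample_mixing_*.png` style-mixing grids; `--n_row` and `--n_col` control the grid layout.

## Resuming Training

Training can be restarted from one of the `checkpoint/train_step-*.model` files written during training by passing `--ckpt` together with an `--init_size` matching the resolution that checkpoint had reached. An illustrative command, not from the original README (step 6 corresponds to 256px, since 4 * 2 ** 6 = 256):

> python train.py --mixing --loss r1 --sched --init_size 256 --ckpt checkpoint/train_step-6.model LMDB_PATH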
-------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.model 2 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | predict: predict.py:Predictor 2 | build: 3 | python_version: 3.8 4 | python_packages: 5 | - torch==1.7.0 6 | - torchvision==0.8.1 7 | - tqdm==4.59.0 8 | - pillow==8.1.2 9 | - lmdb==1.1.1 10 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MultiResolutionDataset(Dataset): 9 | def __init__(self, path, transform, resolution=8): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | 40 | return img 41 | -------------------------------------------------------------------------------- /doc/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample.png -------------------------------------------------------------------------------- /doc/sample_ffhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_ffhq.png -------------------------------------------------------------------------------- /doc/sample_ffhq_new.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_ffhq_new.png -------------------------------------------------------------------------------- /doc/sample_mixing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_mixing.png -------------------------------------------------------------------------------- /doc/sample_mixing_ffhq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_mixing_ffhq.png -------------------------------------------------------------------------------- /doc/sample_mixing_ffhq_new.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_mixing_ffhq_new.png -------------------------------------------------------------------------------- /doc/sample_prev.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/style-based-gan-pytorch/07fa60be77b093dd13a46597138df409ffc3b9bc/doc/sample_prev.png -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | 4 | import torch 5 | from torchvision import utils 6 | 7 | from model import StyledGenerator 8 | 9 | 10 | @torch.no_grad() 11 | def get_mean_style(generator, device): 12 | mean_style = None 13 | 14 | for i in range(10): 15 | style = generator.mean_style(torch.randn(1024, 512).to(device)) 16 | 17 | if mean_style is None: 18 | mean_style = style 19 | 20 | else: 21 | mean_style += style 22 | 23 | mean_style /= 10 24 | return mean_style 25 | 26 | @torch.no_grad() 27 | def sample(generator, step, mean_style, n_sample, device): 28 | image = generator( 29 | torch.randn(n_sample, 512).to(device), 30 | step=step, 31 | alpha=1, 32 | mean_style=mean_style, 33 | style_weight=0.7, 34 | ) 35 | 36 | return image 37 | 38 | @torch.no_grad() 39 | def style_mixing(generator, step, mean_style, n_source, n_target, device): 40 | source_code = torch.randn(n_source, 512).to(device) 41 | target_code = torch.randn(n_target, 512).to(device) 42 | 43 | shape = 4 * 2 ** step 44 | alpha = 1 45 | 46 | images = [torch.ones(1, 3, shape, shape).to(device) * -1] 47 | 48 | source_image = generator( 49 | source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7 50 | ) 51 | target_image = generator( 52 | target_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7 53 | ) 54 | 55 | images.append(source_image) 56 | 57 | for i in range(n_target): 58 | image = generator( 59 | [target_code[i].unsqueeze(0).repeat(n_source, 1), source_code], 60 | step=step, 61 | alpha=alpha, 62 | mean_style=mean_style, 63 | style_weight=0.7, 64 | mixing_range=(0, 1), 65 | ) 66 | images.append(target_image[i].unsqueeze(0)) 67 | images.append(image) 68 | 69 | images = torch.cat(images, 0) 70 | 71 | return images 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--size', type=int, default=1024, help='size of the image') 77 | parser.add_argument('--n_row', type=int, default=3, help='number of rows of sample matrix') 78 | parser.add_argument('--n_col', type=int, default=5, help='number of columns of sample matrix') 79 | parser.add_argument('path', type=str, help='path to checkpoint file') 80 | 81 | args = parser.parse_args() 82 | 83 | device = 'cuda' 84 | 85 | generator = StyledGenerator(512).to(device) 86 | generator.load_state_dict(torch.load(args.path)['g_running']) 87 | generator.eval() 88 | 89 | mean_style = get_mean_style(generator, device) 90 | 91 | step = int(math.log(args.size, 2)) - 2 92 | 93 | img = sample(generator, step, mean_style, args.n_row * args.n_col, device) 94 | utils.save_image(img, 'sample.png', nrow=args.n_col, normalize=True, range=(-1, 1)) 95 | 96 | for j in range(20): 97 | img = style_mixing(generator, step, mean_style, args.n_col, args.n_row, device) 98 | utils.save_image( 99 | img, f'sample_mixing_{j}.png', nrow=args.n_col + 1, normalize=True, range=(-1, 1) 100 | ) 101 | 
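The helpers above can also be imported from another script; this is essentially what predict.py (further below) does. A condensed sketch, assuming the 1024px checkpoint has been downloaded as stylegan-1024px-new.model (the filename predict.py expects):

    import math

    import torch
    from torchvision import utils

    from generate import get_mean_style, sample
    from model import StyledGenerator

    device = 'cuda'
    generator = StyledGenerator(512).to(device)
    generator.load_state_dict(torch.load('stylegan-1024px-new.model')['g_running'])
    generator.eval()

    mean_style = get_mean_style(generator, device)         # average style vector used for truncation
    step = int(math.log(1024, 2)) - 2                      # 8 -> 1024px output
    img = sample(generator, step, mean_style, 1, device)   # one sample, truncated with style_weight=0.7
    utils.save_image(img, 'output.png', normalize=True, range=(-1, 1))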
-------------------------------------------------------------------------------- /lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from lpips import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 
90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def get_current_errors(self): 32 | return {} 33 | 34 | def 
save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." % self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = 
data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in [0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s inject_index[crossover]: 422 | crossover = min(crossover + 1, len(style)) 423 | 424 | style_step = style[crossover] 425 | 426 | else: 427 | if mixing_range[0] <= i <= mixing_range[1]: 428 | style_step = style[1] 429 | 430 | else: 431 | style_step = style[0] 432 | 433 | if i > 0 and step > 0: 434 | out_prev = out 435 | 436 | out = conv(out, style_step, noise[i]) 437 | 438 | if i == step: 439 | out = to_rgb(out) 440 | 441 | if i > 0 and 0 <= alpha < 1: 442 | skip_rgb = self.to_rgb[i - 1](out_prev) 443 | skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest') 444 | out = (1 - alpha) * skip_rgb + alpha * out 445 | 446 | break 447 | 448 
| return out 449 | 450 | 451 | class StyledGenerator(nn.Module): 452 | def __init__(self, code_dim=512, n_mlp=8): 453 | super().__init__() 454 | 455 | self.generator = Generator(code_dim) 456 | 457 | layers = [PixelNorm()] 458 | for i in range(n_mlp): 459 | layers.append(EqualLinear(code_dim, code_dim)) 460 | layers.append(nn.LeakyReLU(0.2)) 461 | 462 | self.style = nn.Sequential(*layers) 463 | 464 | def forward( 465 | self, 466 | input, 467 | noise=None, 468 | step=0, 469 | alpha=-1, 470 | mean_style=None, 471 | style_weight=0, 472 | mixing_range=(-1, -1), 473 | ): 474 | styles = [] 475 | if type(input) not in (list, tuple): 476 | input = [input] 477 | 478 | for i in input: 479 | styles.append(self.style(i)) 480 | 481 | batch = input[0].shape[0] 482 | 483 | if noise is None: 484 | noise = [] 485 | 486 | for i in range(step + 1): 487 | size = 4 * 2 ** i 488 | noise.append(torch.randn(batch, 1, size, size, device=input[0].device)) 489 | 490 | if mean_style is not None: 491 | styles_norm = [] 492 | 493 | for style in styles: 494 | styles_norm.append(mean_style + style_weight * (style - mean_style)) 495 | 496 | styles = styles_norm 497 | 498 | return self.generator(styles, noise, step, alpha, mixing_range=mixing_range) 499 | 500 | def mean_style(self, input): 501 | style = self.style(input).mean(0, keepdim=True) 502 | 503 | return style 504 | 505 | 506 | class Discriminator(nn.Module): 507 | def __init__(self, fused=True, from_rgb_activate=False): 508 | super().__init__() 509 | 510 | self.progression = nn.ModuleList( 511 | [ 512 | ConvBlock(16, 32, 3, 1, downsample=True, fused=fused), # 512 513 | ConvBlock(32, 64, 3, 1, downsample=True, fused=fused), # 256 514 | ConvBlock(64, 128, 3, 1, downsample=True, fused=fused), # 128 515 | ConvBlock(128, 256, 3, 1, downsample=True, fused=fused), # 64 516 | ConvBlock(256, 512, 3, 1, downsample=True), # 32 517 | ConvBlock(512, 512, 3, 1, downsample=True), # 16 518 | ConvBlock(512, 512, 3, 1, downsample=True), # 8 519 | ConvBlock(512, 512, 3, 1, downsample=True), # 4 520 | ConvBlock(513, 512, 3, 1, 4, 0), 521 | ] 522 | ) 523 | 524 | def make_from_rgb(out_channel): 525 | if from_rgb_activate: 526 | return nn.Sequential(EqualConv2d(3, out_channel, 1), nn.LeakyReLU(0.2)) 527 | 528 | else: 529 | return EqualConv2d(3, out_channel, 1) 530 | 531 | self.from_rgb = nn.ModuleList( 532 | [ 533 | make_from_rgb(16), 534 | make_from_rgb(32), 535 | make_from_rgb(64), 536 | make_from_rgb(128), 537 | make_from_rgb(256), 538 | make_from_rgb(512), 539 | make_from_rgb(512), 540 | make_from_rgb(512), 541 | make_from_rgb(512), 542 | ] 543 | ) 544 | 545 | # self.blur = Blur() 546 | 547 | self.n_layer = len(self.progression) 548 | 549 | self.linear = EqualLinear(512, 1) 550 | 551 | def forward(self, input, step=0, alpha=-1): 552 | for i in range(step, -1, -1): 553 | index = self.n_layer - i - 1 554 | 555 | if i == step: 556 | out = self.from_rgb[index](input) 557 | 558 | if i == 0: 559 | out_std = torch.sqrt(out.var(0, unbiased=False) + 1e-8) 560 | mean_std = out_std.mean() 561 | mean_std = mean_std.expand(out.size(0), 1, 4, 4) 562 | out = torch.cat([out, mean_std], 1) 563 | 564 | out = self.progression[index](out) 565 | 566 | if i > 0: 567 | if i == step and 0 <= alpha < 1: 568 | skip_rgb = F.avg_pool2d(input, 2) 569 | skip_rgb = self.from_rgb[index + 1](skip_rgb) 570 | 571 | out = (1 - alpha) * skip_rgb + alpha * out 572 | 573 | out = out.squeeze(2).squeeze(2) 574 | # print(input.size(), out.size(), step) 575 | out = self.linear(out) 576 | 577 | return out 578 | 
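A minimal usage sketch of the two modules above (illustrative, not part of the original file): generate a batch at 256px, where step = log2(size) - 2 = 6 and alpha = 1 means the current resolution is fully faded in.

    import torch

    from model import StyledGenerator, Discriminator

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = StyledGenerator(code_dim=512).to(device)
    discriminator = Discriminator(from_rgb_activate=True).to(device)

    step = 6                                           # 4 * 2 ** 6 = 256px
    z = torch.randn(4, 512, device=device)
    fake = generator(z, step=step, alpha=1)            # noise maps are sampled internally
    score = discriminator(fake, step=step, alpha=1)    # shape (4, 1)
    print(fake.shape)                                  # torch.Size([4, 3, 256, 256])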
-------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import tempfile 4 | from pathlib import Path 5 | import torch 6 | from torchvision import utils 7 | import cog 8 | 9 | from generate import sample, get_mean_style 10 | from model import StyledGenerator 11 | 12 | SIZE = 1024 13 | 14 | 15 | class Predictor(cog.Predictor): 16 | def setup(self): 17 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | self.generator = StyledGenerator(512).to(self.device) 19 | print("Loading checkpoint") 20 | self.generator.load_state_dict( 21 | torch.load( 22 | "stylegan-1024px-new.model", 23 | map_location=self.device, 24 | )["g_running"], 25 | ) 26 | self.generator.eval() 27 | 28 | @cog.input("seed", type=int, default=-1, help="Random seed, -1 for random") 29 | def predict(self, seed): 30 | if seed < 0: 31 | seed = int.from_bytes(os.urandom(2), "big") 32 | torch.manual_seed(seed) 33 | print(f"seed: {seed}") 34 | 35 | mean_style = get_mean_style(self.generator, self.device) 36 | step = int(math.log(SIZE, 2)) - 2 37 | img = sample(self.generator, step, mean_style, 1, self.device) 38 | output_path = Path(tempfile.mkdtemp()) / "output.png" 39 | utils.save_image(img, output_path, normalize=True) 40 | return output_path 41 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, quality=100): 14 | img = trans_fn.resize(img, size, Image.LANCZOS) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format='jpeg', quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple(img, sizes=(8, 16, 32, 64, 128, 256, 512, 1024), quality=100): 24 | imgs = [] 25 | 26 | for size in sizes: 27 | imgs.append(resize_and_convert(img, size, quality)) 28 | 29 | return imgs 30 | 31 | 32 | def resize_worker(img_file, sizes): 33 | i, file = img_file 34 | img = Image.open(file) 35 | img = img.convert('RGB') 36 | out = resize_multiple(img, sizes=sizes) 37 | 38 | return i, out 39 | 40 | 41 | def prepare(transaction, dataset, n_worker, sizes=(8, 16, 32, 64, 128, 256, 512, 1024)): 42 | resize_fn = partial(resize_worker, sizes=sizes) 43 | 44 | files = sorted(dataset.imgs, key=lambda x: x[0]) 45 | files = [(i, file) for i, (file, label) in enumerate(files)] 46 | total = 0 47 | 48 | with multiprocessing.Pool(n_worker) as pool: 49 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 50 | for size, img in zip(sizes, imgs): 51 | key = f'{size}-{str(i).zfill(5)}'.encode('utf-8') 52 | transaction.put(key, img) 53 | 54 | total += 1 55 | 56 | transaction.put('length'.encode('utf-8'), str(total).encode('utf-8')) 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--out', type=str) 62 | parser.add_argument('--n_worker', type=int, default=8) 63 | parser.add_argument('path', type=str) 64 | 65 | args = parser.parse_args() 66 | 67 | imgset = datasets.ImageFolder(args.path) 68 | 69 | with 
lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 70 | with env.begin(write=True) as txn: 71 | prepare(txn, imgset, args.n_worker) 72 | -------------------------------------------------------------------------------- /projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import torch 5 | from torch import optim 6 | from torch.nn import functional as F 7 | from torchvision import transforms 8 | from PIL import Image 9 | from tqdm import tqdm 10 | 11 | import lpips 12 | from model import StyledGenerator 13 | 14 | def noise_regularize(noises): 15 | loss = 0 16 | 17 | for noise in noises: 18 | size = noise.shape[2] 19 | 20 | while True: 21 | loss = ( 22 | loss 23 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 24 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 25 | ) 26 | 27 | if size <= 8: 28 | break 29 | 30 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) 31 | noise = noise.mean([3, 5]) 32 | size //= 2 33 | 34 | return loss 35 | 36 | 37 | def noise_normalize_(noises): 38 | for noise in noises: 39 | mean = noise.mean() 40 | std = noise.std() 41 | 42 | noise.data.add_(-mean).div_(std) 43 | 44 | 45 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 46 | lr_ramp = min(1, (1 - t) / rampdown) 47 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 48 | lr_ramp = lr_ramp * min(1, t / rampup) 49 | 50 | return initial_lr * lr_ramp 51 | 52 | 53 | def latent_noise(latent, strength): 54 | noise = torch.randn_like(latent) * strength 55 | 56 | return latent + noise 57 | 58 | 59 | def make_image(tensor): 60 | return ( 61 | tensor.detach() 62 | .clamp_(min=-1, max=1) 63 | .add(1) 64 | .div_(2) 65 | .mul(255) 66 | .type(torch.uint8) 67 | .permute(0, 2, 3, 1) 68 | .to("cpu") 69 | .numpy() 70 | ) 71 | 72 | def make_noise(device,size): 73 | noises = [] 74 | step = int(math.log(size, 2)) - 2 75 | for i in range(step + 1): 76 | size = 4 * 2 ** i 77 | noises.append(torch.randn(1, 1, size, size, device=device)) 78 | return noises 79 | 80 | if __name__ == "__main__": 81 | device = "cuda" 82 | 83 | parser = argparse.ArgumentParser( 84 | description="Image projector to the generator latent spaces" 85 | ) 86 | parser.add_argument( 87 | "--ckpt", type=str, required=True, help="path to the model checkpoint" 88 | ) 89 | parser.add_argument( 90 | "--size", type=int, default=256, help="output image sizes of the generator" 91 | ) 92 | parser.add_argument( 93 | "--lr_rampup", 94 | type=float, 95 | default=0.05, 96 | help="duration of the learning rate warmup", 97 | ) 98 | parser.add_argument( 99 | "--lr_rampdown", 100 | type=float, 101 | default=0.25, 102 | help="duration of the learning rate decay", 103 | ) 104 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate") 105 | parser.add_argument( 106 | "--noise", type=float, default=0.05, help="strength of the noise level" 107 | ) 108 | parser.add_argument( 109 | "--noise_ramp", 110 | type=float, 111 | default=0.75, 112 | help="duration of the noise level decay", 113 | ) 114 | parser.add_argument("--step", type=int, default=1000, help="optimize iterations") 115 | parser.add_argument( 116 | "--noise_regularize", 117 | type=float, 118 | default=1e5, 119 | help="weight of the noise regularization", 120 | ) 121 | parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss") 122 | parser.add_argument( 123 | "--w_plus", 124 | action="store_true", 125 | help="allow to use distinct latent 
codes to each layers", 126 | ) 127 | parser.add_argument( 128 | "--files", metavar="FILES", nargs="+", help="path to image files to be projected" 129 | ) 130 | 131 | args = parser.parse_args() 132 | 133 | n_mean_latent = 10000 134 | 135 | resize = min(args.size, 256) 136 | 137 | transform = transforms.Compose( 138 | [ 139 | transforms.Resize(resize), 140 | transforms.CenterCrop(resize), 141 | transforms.ToTensor(), 142 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 143 | ] 144 | ) 145 | 146 | imgs = [] 147 | 148 | for imgfile in args.files: 149 | img = transform(Image.open(imgfile).convert("RGB")) 150 | imgs.append(img) 151 | 152 | imgs = torch.stack(imgs, 0).to(device) 153 | g_ema = StyledGenerator(512) 154 | g_ema.load_state_dict(torch.load(args.ckpt)["g_running"], strict=False) 155 | g_ema.eval() 156 | g_ema = g_ema.to(device) 157 | step = int(math.log(args.size, 2)) - 2 158 | with torch.no_grad(): 159 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 160 | latent_out = g_ema.style(noise_sample) 161 | 162 | latent_mean = latent_out.mean(0) 163 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 164 | 165 | percept = lpips.PerceptualLoss( 166 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 167 | ) 168 | 169 | noises_single = make_noise(device,args.size) 170 | noises = [] 171 | for noise in noises_single: 172 | noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_()) 173 | 174 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1) 175 | 176 | if args.w_plus: 177 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) 178 | 179 | latent_in.requires_grad = True 180 | 181 | for noise in noises: 182 | noise.requires_grad = True 183 | 184 | optimizer = optim.Adam([latent_in] + noises, lr=args.lr) 185 | 186 | pbar = tqdm(range(args.step)) 187 | latent_path = [] 188 | 189 | for i in pbar: 190 | 191 | t = i / args.step 192 | lr = get_lr(t, args.lr) 193 | optimizer.param_groups[0]["lr"] = lr 194 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 195 | latent_n = latent_noise(latent_in, noise_strength.item()) 196 | latent_n.to(device) 197 | img_gen = g_ema([latent_n], noise=noises, step=step) 198 | batch, channel, height, width = img_gen.shape 199 | 200 | if height > 256: 201 | factor = height // 256 202 | 203 | img_gen = img_gen.reshape( 204 | batch, channel, height // factor, factor, width // factor, factor 205 | ) 206 | img_gen = img_gen.mean([3, 5]) 207 | 208 | p_loss = percept(img_gen, imgs).sum() 209 | n_loss = noise_regularize(noises) 210 | mse_loss = F.mse_loss(img_gen, imgs) 211 | 212 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss 213 | 214 | optimizer.zero_grad() 215 | loss.backward() 216 | optimizer.step() 217 | 218 | noise_normalize_(noises) 219 | 220 | if (i + 1) % 100 == 0: 221 | latent_path.append(latent_in.detach().clone()) 222 | 223 | pbar.set_description( 224 | ( 225 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};" 226 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" 227 | ) 228 | ) 229 | img_gen = g_ema([latent_path[-1]], noise=noises,step=step) 230 | 231 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt" 232 | 233 | img_ar = make_image(img_gen) 234 | 235 | result_file = {} 236 | for i, input_name in enumerate(args.files): 237 | noise_single = [] 238 | for noise in noises: 239 | noise_single.append(noise[i : i + 1]) 240 | 241 | result_file[input_name] = { 242 | "img": 
img_gen[i], 243 | "latent": latent_in[i], 244 | "noise": noise_single, 245 | } 246 | 247 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png" 248 | pil_img = Image.fromarray(img_ar[i]) 249 | pil_img.save(img_name) 250 | 251 | torch.save(result_file, filename) 252 | -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import math 4 | 5 | from tqdm import tqdm 6 | import numpy as np 7 | from PIL import Image 8 | 9 | import torch 10 | from torch import nn, optim 11 | from torch.nn import functional as F 12 | from torch.autograd import Variable, grad 13 | from torch.utils.data import DataLoader 14 | from torchvision import datasets, transforms, utils 15 | 16 | from dataset import MultiResolutionDataset 17 | from model import StyledGenerator, Discriminator 18 | 19 | 20 | def requires_grad(model, flag=True): 21 | for p in model.parameters(): 22 | p.requires_grad = flag 23 | 24 | 25 | def accumulate(model1, model2, decay=0.999): 26 | par1 = dict(model1.named_parameters()) 27 | par2 = dict(model2.named_parameters()) 28 | 29 | for k in par1.keys(): 30 | par1[k].data.mul_(decay).add_(1 - decay, par2[k].data) 31 | 32 | 33 | def sample_data(dataset, batch_size, image_size=4): 34 | dataset.resolution = image_size 35 | loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=1, drop_last=True) 36 | 37 | return loader 38 | 39 | 40 | def adjust_lr(optimizer, lr): 41 | for group in optimizer.param_groups: 42 | mult = group.get('mult', 1) 43 | group['lr'] = lr * mult 44 | 45 | 46 | def train(args, dataset, generator, discriminator): 47 | step = int(math.log2(args.init_size)) - 2 48 | resolution = 4 * 2 ** step 49 | loader = sample_data( 50 | dataset, args.batch.get(resolution, args.batch_default), resolution 51 | ) 52 | data_loader = iter(loader) 53 | 54 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.001)) 55 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.001)) 56 | 57 | pbar = tqdm(range(3_000_000)) 58 | 59 | requires_grad(generator, False) 60 | requires_grad(discriminator, True) 61 | 62 | disc_loss_val = 0 63 | gen_loss_val = 0 64 | grad_loss_val = 0 65 | 66 | alpha = 0 67 | used_sample = 0 68 | 69 | max_step = int(math.log2(args.max_size)) - 2 70 | final_progress = False 71 | 72 | for i in pbar: 73 | discriminator.zero_grad() 74 | 75 | alpha = min(1, 1 / args.phase * (used_sample + 1)) 76 | 77 | if (resolution == args.init_size and args.ckpt is None) or final_progress: 78 | alpha = 1 79 | 80 | if used_sample > args.phase * 2: 81 | used_sample = 0 82 | step += 1 83 | 84 | if step > max_step: 85 | step = max_step 86 | final_progress = True 87 | ckpt_step = step + 1 88 | 89 | else: 90 | alpha = 0 91 | ckpt_step = step 92 | 93 | resolution = 4 * 2 ** step 94 | 95 | loader = sample_data( 96 | dataset, args.batch.get(resolution, args.batch_default), resolution 97 | ) 98 | data_loader = iter(loader) 99 | 100 | torch.save( 101 | { 102 | 'generator': generator.module.state_dict(), 103 | 'discriminator': discriminator.module.state_dict(), 104 | 'g_optimizer': g_optimizer.state_dict(), 105 | 'd_optimizer': d_optimizer.state_dict(), 106 | 'g_running': g_running.state_dict(), 107 | }, 108 | 
f'checkpoint/train_step-{ckpt_step}.model', 109 | ) 110 | 111 | adjust_lr(g_optimizer, args.lr.get(resolution, 0.001)) 112 | adjust_lr(d_optimizer, args.lr.get(resolution, 0.001)) 113 | 114 | try: 115 | real_image = next(data_loader) 116 | 117 | except (OSError, StopIteration): 118 | data_loader = iter(loader) 119 | real_image = next(data_loader) 120 | 121 | used_sample += real_image.shape[0] 122 | 123 | b_size = real_image.size(0) 124 | real_image = real_image.cuda() 125 | 126 | if args.loss == 'wgan-gp': 127 | real_predict = discriminator(real_image, step=step, alpha=alpha) 128 | real_predict = real_predict.mean() - 0.001 * (real_predict ** 2).mean() 129 | (-real_predict).backward() 130 | 131 | elif args.loss == 'r1': 132 | real_image.requires_grad = True 133 | real_scores = discriminator(real_image, step=step, alpha=alpha) 134 | real_predict = F.softplus(-real_scores).mean() 135 | real_predict.backward(retain_graph=True) 136 | 137 | grad_real = grad( 138 | outputs=real_scores.sum(), inputs=real_image, create_graph=True 139 | )[0] 140 | grad_penalty = ( 141 | grad_real.view(grad_real.size(0), -1).norm(2, dim=1) ** 2 142 | ).mean() 143 | grad_penalty = 10 / 2 * grad_penalty 144 | grad_penalty.backward() 145 | if i%10 == 0: 146 | grad_loss_val = grad_penalty.item() 147 | 148 | if args.mixing and random.random() < 0.9: 149 | gen_in11, gen_in12, gen_in21, gen_in22 = torch.randn( 150 | 4, b_size, code_size, device='cuda' 151 | ).chunk(4, 0) 152 | gen_in1 = [gen_in11.squeeze(0), gen_in12.squeeze(0)] 153 | gen_in2 = [gen_in21.squeeze(0), gen_in22.squeeze(0)] 154 | 155 | else: 156 | gen_in1, gen_in2 = torch.randn(2, b_size, code_size, device='cuda').chunk( 157 | 2, 0 158 | ) 159 | gen_in1 = gen_in1.squeeze(0) 160 | gen_in2 = gen_in2.squeeze(0) 161 | 162 | fake_image = generator(gen_in1, step=step, alpha=alpha) 163 | fake_predict = discriminator(fake_image, step=step, alpha=alpha) 164 | 165 | if args.loss == 'wgan-gp': 166 | fake_predict = fake_predict.mean() 167 | fake_predict.backward() 168 | 169 | eps = torch.rand(b_size, 1, 1, 1).cuda() 170 | x_hat = eps * real_image.data + (1 - eps) * fake_image.data 171 | x_hat.requires_grad = True 172 | hat_predict = discriminator(x_hat, step=step, alpha=alpha) 173 | grad_x_hat = grad( 174 | outputs=hat_predict.sum(), inputs=x_hat, create_graph=True 175 | )[0] 176 | grad_penalty = ( 177 | (grad_x_hat.view(grad_x_hat.size(0), -1).norm(2, dim=1) - 1) ** 2 178 | ).mean() 179 | grad_penalty = 10 * grad_penalty 180 | grad_penalty.backward() 181 | if i%10 == 0: 182 | grad_loss_val = grad_penalty.item() 183 | disc_loss_val = (-real_predict + fake_predict).item() 184 | 185 | elif args.loss == 'r1': 186 | fake_predict = F.softplus(fake_predict).mean() 187 | fake_predict.backward() 188 | if i%10 == 0: 189 | disc_loss_val = (real_predict + fake_predict).item() 190 | 191 | d_optimizer.step() 192 | 193 | if (i + 1) % n_critic == 0: 194 | generator.zero_grad() 195 | 196 | requires_grad(generator, True) 197 | requires_grad(discriminator, False) 198 | 199 | fake_image = generator(gen_in2, step=step, alpha=alpha) 200 | 201 | predict = discriminator(fake_image, step=step, alpha=alpha) 202 | 203 | if args.loss == 'wgan-gp': 204 | loss = -predict.mean() 205 | 206 | elif args.loss == 'r1': 207 | loss = F.softplus(-predict).mean() 208 | 209 | if i%10 == 0: 210 | gen_loss_val = loss.item() 211 | 212 | loss.backward() 213 | g_optimizer.step() 214 | accumulate(g_running, generator.module) 215 | 216 | requires_grad(generator, False) 217 | requires_grad(discriminator, True) 218 | 
219 | if (i + 1) % 100 == 0: 220 | images = [] 221 | 222 | gen_i, gen_j = args.gen_sample.get(resolution, (10, 5)) 223 | 224 | with torch.no_grad(): 225 | for _ in range(gen_i): 226 | images.append( 227 | g_running( 228 | torch.randn(gen_j, code_size).cuda(), step=step, alpha=alpha 229 | ).data.cpu() 230 | ) 231 | 232 | utils.save_image( 233 | torch.cat(images, 0), 234 | f'sample/{str(i + 1).zfill(6)}.png', 235 | nrow=gen_i, 236 | normalize=True, 237 | range=(-1, 1), 238 | ) 239 | 240 | if (i + 1) % 10000 == 0: 241 | torch.save( 242 | g_running.state_dict(), f'checkpoint/{str(i + 1).zfill(6)}.model' 243 | ) 244 | 245 | state_msg = ( 246 | f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};' 247 | f' Grad: {grad_loss_val:.3f}; Alpha: {alpha:.5f}' 248 | ) 249 | 250 | pbar.set_description(state_msg) 251 | 252 | 253 | if __name__ == '__main__': 254 | code_size = 512 255 | batch_size = 16 256 | n_critic = 1 257 | 258 | parser = argparse.ArgumentParser(description='Progressive Growing of GANs') 259 | 260 | parser.add_argument('path', type=str, help='path of specified dataset') 261 | parser.add_argument( 262 | '--phase', 263 | type=int, 264 | default=600_000, 265 | help='number of samples used for each training phases', 266 | ) 267 | parser.add_argument('--lr', default=0.001, type=float, help='learning rate') 268 | parser.add_argument('--sched', action='store_true', help='use lr scheduling') 269 | parser.add_argument('--init_size', default=8, type=int, help='initial image size') 270 | parser.add_argument('--max_size', default=1024, type=int, help='max image size') 271 | parser.add_argument( 272 | '--ckpt', default=None, type=str, help='load from previous checkpoints' 273 | ) 274 | parser.add_argument( 275 | '--no_from_rgb_activate', 276 | action='store_true', 277 | help='use activate in from_rgb (original implementation)', 278 | ) 279 | parser.add_argument( 280 | '--mixing', action='store_true', help='use mixing regularization' 281 | ) 282 | parser.add_argument( 283 | '--loss', 284 | type=str, 285 | default='wgan-gp', 286 | choices=['wgan-gp', 'r1'], 287 | help='class of gan loss', 288 | ) 289 | 290 | args = parser.parse_args() 291 | 292 | generator = nn.DataParallel(StyledGenerator(code_size)).cuda() 293 | discriminator = nn.DataParallel( 294 | Discriminator(from_rgb_activate=not args.no_from_rgb_activate) 295 | ).cuda() 296 | g_running = StyledGenerator(code_size).cuda() 297 | g_running.train(False) 298 | 299 | g_optimizer = optim.Adam( 300 | generator.module.generator.parameters(), lr=args.lr, betas=(0.0, 0.99) 301 | ) 302 | g_optimizer.add_param_group( 303 | { 304 | 'params': generator.module.style.parameters(), 305 | 'lr': args.lr * 0.01, 306 | 'mult': 0.01, 307 | } 308 | ) 309 | d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr, betas=(0.0, 0.99)) 310 | 311 | accumulate(g_running, generator.module, 0) 312 | 313 | if args.ckpt is not None: 314 | ckpt = torch.load(args.ckpt) 315 | 316 | generator.module.load_state_dict(ckpt['generator']) 317 | discriminator.module.load_state_dict(ckpt['discriminator']) 318 | g_running.load_state_dict(ckpt['g_running']) 319 | g_optimizer.load_state_dict(ckpt['g_optimizer']) 320 | d_optimizer.load_state_dict(ckpt['d_optimizer']) 321 | 322 | transform = transforms.Compose( 323 | [ 324 | transforms.RandomHorizontalFlip(), 325 | transforms.ToTensor(), 326 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 327 | ] 328 | ) 329 | 330 | dataset = MultiResolutionDataset(args.path, transform) 331 | 332 | if 
args.sched: 333 | args.lr = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 334 | args.batch = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 32, 256: 32} 335 | 336 | else: 337 | args.lr = {} 338 | args.batch = {} 339 | 340 | args.gen_sample = {512: (8, 4), 1024: (4, 2)} 341 | 342 | args.batch_default = 32 343 | 344 | train(args, dataset, generator, discriminator) 345 | --------------------------------------------------------------------------------