├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── checkpoint └── .gitignore ├── config.py ├── config ├── config-r.jsonnet └── config-t.jsonnet ├── doc ├── sample1.gif ├── sample2.gif ├── sample3.gif ├── sample4.gif ├── sample5.gif └── sample6.gif ├── generate.py ├── model.py ├── op ├── __init__.py ├── filtered_lrelu.cpp ├── filtered_lrelu.cu ├── filtered_lrelu.h ├── filtered_lrelu.py ├── filtered_lrelu_ns.cu ├── filtered_lrelu_rd.cu └── filtered_lrelu_wr.cu ├── prepare_data.py ├── sample.webm ├── sample └── .gitignore ├── stylegan2 ├── .gitignore ├── LICENSE ├── LICENSE-FID ├── LICENSE-LPIPS ├── LICENSE-NVIDIA ├── README.md ├── apply_factor.py ├── calc_inception.py ├── checkpoint │ └── .gitignore ├── closed_form_factorization.py ├── convert_weight.py ├── dataset.py ├── distributed.py ├── doc │ ├── sample-metfaces.png │ ├── sample.png │ ├── stylegan2-church-config-f.png │ └── stylegan2-ffhq-config-f.png ├── factor_index-13_degree-5.0.png ├── fid.py ├── generate.py ├── inception.py ├── inception_ffhq.pkl ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── model.py ├── non_leaking.py ├── op │ ├── __init__.py │ ├── conv2d_gradfix.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── ppl.py ├── prepare_data.py ├── projector.py ├── sample │ └── .gitignore ├── swagan.py └── train.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 
19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | wandb/ 132 | train_orig.py 133 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04 2 | 3 | ARG APT_INSTALL="apt-get install -y --no-install-recommends" 4 | ARG PIP_INSTALL="python -m pip --no-cache-dir install --upgrade" 5 | ARG GIT_CLONE="git clone --depth 10" 6 | 7 | ENV HOME /root 8 | 9 | WORKDIR $HOME 10 | 11 | RUN rm -rf /var/lib/apt/lists/* \ 12 | /etc/apt/sources.list.d/cuda.list \ 13 | /etc/apt/sources.list.d/nvidia-ml.list 14 | 15 | RUN apt-get update 16 | 17 | ARG DEBIAN_FRONTEND=noninteractive 18 | 19 | RUN $APT_INSTALL build-essential software-properties-common ca-certificates \ 20 | wget git zlib1g-dev nasm cmake 21 | 22 | RUN add-apt-repository ppa:deadsnakes/ppa 23 | 24 | RUN apt-get update 25 | 26 | RUN $APT_INSTALL python3.7 python3.7-dev 27 | 28 | RUN wget -O $HOME/get-pip.py https://bootstrap.pypa.io/get-pip.py 29 | 30 | RUN python3.7 $HOME/get-pip.py 31 | 32 | RUN ln -s /usr/bin/python3.7 /usr/local/bin/python3 33 | RUN ln -s /usr/bin/python3.7 /usr/local/bin/python 34 | 35 | RUN $PIP_INSTALL setuptools 36 | RUN $PIP_INSTALL numpy scipy nltk lmdb cython pydantic pyhocon 37 | 38 | RUN $PIP_INSTALL torch==1.7.1+cu92 torchvision==0.8.2+cu92 -f 
https://download.pytorch.org/whl/torch_stable.html 39 | 40 | RUN $PIP_INSTALL tensorfn rich 41 | 42 | ENV FORCE_CUDA="1" 43 | ENV TORCH_CUDA_ARCH_LIST="Pascal;Volta;Turing" 44 | 45 | RUN $APT_INSTALL libsm6 libxext6 libxrender1 46 | RUN $PIP_INSTALL opencv-python-headless 47 | 48 | RUN python -m pip uninstall -y pillow pil jpeg libtiff libjpeg-turbo 49 | 50 | RUN $GIT_CLONE https://github.com/libjpeg-turbo/libjpeg-turbo.git 51 | WORKDIR libjpeg-turbo 52 | RUN mkdir build 53 | WORKDIR build 54 | RUN cmake -G"Unix Makefiles" -DCMAKE_INSTALL_PREFIX=libjpeg-turbo -DWITH_JPEG8=1 .. 55 | RUN make 56 | RUN make install 57 | WORKDIR libjpeg-turbo 58 | RUN mv include/jerror.h include/jmorecfg.h include/jpeglib.h include/turbojpeg.h /usr/include 59 | RUN mv include/jconfig.h /usr/include/x86_64-linux-gnu 60 | RUN mv lib/*.* /usr/lib/x86_64-linux-gnu 61 | RUN mv lib/pkgconfig/* /usr/lib/x86_64-linux-gnu/pkgconfig 62 | RUN ldconfig 63 | 64 | RUN CFLAGS="${CFLAGS} -mavx2" $PIP_INSTALL --force-reinstall --no-binary :all: --compile pillow-simd 65 | 66 | WORKDIR $HOME 67 | 68 | RUN ldconfig 69 | RUN apt-get clean 70 | RUN apt-get autoremove 71 | RUN rm -rf /var/lib/apt/lists/* /tmp/* ~/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial 
portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ### LICENSE for filtered_lrelu custom kernel 24 | 25 | Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. 26 | 27 | 28 | NVIDIA Source Code License for StyleGAN3 29 | 30 | 31 | ======================================================================= 32 | 33 | 1. Definitions 34 | 35 | "Licensor" means any person or entity that distributes its Work. 36 | 37 | "Software" means the original work of authorship made available under 38 | this License. 39 | 40 | "Work" means the Software and any additions to or derivative works of 41 | the Software that are made available under this License. 42 | 43 | The terms "reproduce," "reproduction," "derivative works," and 44 | "distribution" have the meaning as provided under U.S. copyright law; 45 | provided, however, that for the purposes of this License, derivative 46 | works shall not include works that remain separable from, or merely 47 | link (or bind by name) to the interfaces of, the Work. 48 | 49 | Works, including the Software, are "made available" under this License 50 | by including in or with the Work either (a) a copyright notice 51 | referencing the applicability of this License to the Work, or (b) a 52 | copy of this License. 53 | 54 | 2. License Grants 55 | 56 | 2.1 Copyright Grant. 
Subject to the terms and conditions of this 57 | License, each Licensor grants to you a perpetual, worldwide, 58 | non-exclusive, royalty-free, copyright license to reproduce, 59 | prepare derivative works of, publicly display, publicly perform, 60 | sublicense and distribute its Work and any resulting derivative 61 | works in any form. 62 | 63 | 3. Limitations 64 | 65 | 3.1 Redistribution. You may reproduce or distribute the Work only 66 | if (a) you do so under this License, (b) you include a complete 67 | copy of this License with your distribution, and (c) you retain 68 | without modification any copyright, patent, trademark, or 69 | attribution notices that are present in the Work. 70 | 71 | 3.2 Derivative Works. You may specify that additional or different 72 | terms apply to the use, reproduction, and distribution of your 73 | derivative works of the Work ("Your Terms") only if (a) Your Terms 74 | provide that the use limitation in Section 3.3 applies to your 75 | derivative works, and (b) you identify the specific derivative 76 | works that are subject to Your Terms. Notwithstanding Your Terms, 77 | this License (including the redistribution requirements in Section 78 | 3.1) will continue to apply to the Work itself. 79 | 80 | 3.3 Use Limitation. The Work and any derivative works thereof only 81 | may be used or intended for use non-commercially. Notwithstanding 82 | the foregoing, NVIDIA and its affiliates may use the Work and any 83 | derivative works commercially. As used herein, "non-commercially" 84 | means for research or evaluation purposes only. 85 | 86 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 87 | against any Licensor (including any claim, cross-claim or 88 | counterclaim in a lawsuit) to enforce any patents that you allege 89 | are infringed by any Work, then your rights under this License from 90 | such Licensor (including the grant in Section 2.1) will terminate 91 | immediately. 92 | 93 | 3.5 Trademarks. 
This License does not grant any rights to use any 94 | Licensor’s or its affiliates’ names, logos, or trademarks, except 95 | as necessary to reproduce the notices described in this License. 96 | 97 | 3.6 Termination. If you violate any term of this License, then your 98 | rights under this License (including the grant in Section 2.1) will 99 | terminate immediately. 100 | 101 | 4. Disclaimer of Warranty. 102 | 103 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 104 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 105 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 106 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 107 | THIS LICENSE. 108 | 109 | 5. Limitation of Liability. 110 | 111 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 112 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 113 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 114 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 115 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 116 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 117 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 118 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 119 | THE POSSIBILITY OF SUCH DAMAGES. 120 | 121 | ======================================================================= -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # alias-free-gan-pytorch 2 | 3 | Unofficial implementation of Alias-Free Generative Adversarial Networks. 
(https://arxiv.org/abs/2106.12423) This implementation contains a lot of my guesses, so there are likely many differences from the official implementation. 4 | 5 | ## Usage 6 | 7 | First, create LMDB datasets: 8 | 9 | > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH 10 | 11 | This converts the images to JPEG and pre-resizes them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list to the size argument, in case you want to try other resolutions later. 12 | 13 | Then you can train the model in a distributed setting: 14 | 15 | > python train.py --n_gpu N_GPU --conf config/config-t.jsonnet training.batch=BATCH_SIZE path=LMDB_PATH 16 | 17 | train.py supports Weights & Biases logging. If you want to use it, add the `wandb=true` argument to the script. 18 | 19 | ## Sample 20 | 21 | ![Latent translation sample 1](doc/sample1.gif) 22 | ![Latent translation sample 2](doc/sample2.gif) 23 | ![Latent translation sample 3](doc/sample3.gif) 24 | ![Latent translation sample 4](doc/sample4.gif) 25 | ![Latent translation sample 5](doc/sample5.gif) 26 | ![Latent translation sample 6](doc/sample6.gif) -------------------------------------------------------------------------------- /checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, Tuple, Sequence, List 2 | 3 | from tensorfn.config import ( 4 | get_models, 5 | get_model, 6 | MainConfig, 7 | Config, 8 | Optimizer, 9 | Scheduler, 10 | DataLoader, 11 | checker, 12 | Checker, 13 | TypedConfig, 14 | Instance, 15 | ) 16 | from pydantic import StrictStr, StrictInt, StrictBool 17 | 18 | 19 | class Training(Config): 20 | size: StrictInt 21
| iter: StrictInt = 800000 22 | batch: StrictInt = 16 23 | n_sample: StrictInt = 32 24 | r1: float = 10 25 | d_reg_every: StrictInt = 16 26 | lr_g: float = 2e-3 27 | lr_d: float = 2e-3 28 | augment: StrictBool = False 29 | augment_p: float = 0 30 | ada_target: float = 0.6 31 | ada_length: StrictInt = 500 * 1000 32 | ada_every: StrictInt = 256 33 | start_iter: StrictInt = 0 34 | 35 | 36 | class GANConfig(MainConfig): 37 | generator: Instance 38 | discriminator: Instance 39 | training: Training 40 | path: StrictStr = None 41 | wandb: StrictBool = False 42 | logger: StrictStr = "rich" 43 | -------------------------------------------------------------------------------- /config/config-r.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | generator: { 3 | __target: 'model.Generator', 4 | style_dim: 512, 5 | n_mlp: 2, 6 | kernel_size: 1, 7 | n_taps: 6, 8 | filter_parameters: { 9 | __target: 'model.filter_parameters', 10 | n_layer: 14, 11 | n_critical: 2, 12 | sr_max: $.training.size, 13 | cutoff_0: 2, 14 | cutoff_n: self.sr_max / 2, 15 | stopband_0: std.pow(2, 2.1), 16 | stopband_n: self.cutoff_n * std.pow(2, 0.3), 17 | channel_max: 1024, 18 | channel_base: std.pow(2, 15) 19 | }, 20 | margin: 10, 21 | lr_mlp: 0.01, 22 | use_jinc: true 23 | }, 24 | 25 | discriminator: { 26 | __target: 'stylegan2.model.Discriminator', 27 | size: $.training.size, 28 | channel_multiplier: 2 29 | }, 30 | 31 | training: { 32 | size: 256, 33 | iter: 800000, 34 | batch: 16, 35 | n_sample: 32, 36 | r1: 2, 37 | d_reg_every: 16, 38 | lr_g: 3e-3, 39 | lr_d: 2.5e-3, 40 | augment: false, 41 | augment_p: 0, 42 | ada_target: 0.6, 43 | ada_length: 500 * 1000, 44 | ada_every: 256, 45 | start_iter: 0 46 | } 47 | } -------------------------------------------------------------------------------- /config/config-t.jsonnet: -------------------------------------------------------------------------------- 1 | { 2 | generator: { 3 | __target: 'model.Generator', 4 | 
style_dim: 512, 5 | n_mlp: 2, 6 | kernel_size: 3, 7 | n_taps: 6, 8 | filter_parameters: { 9 | __target: 'model.filter_parameters', 10 | n_layer: 14, 11 | n_critical: 2, 12 | sr_max: $.training.size, 13 | cutoff_0: 2, 14 | cutoff_n: self.sr_max / 2, 15 | stopband_0: std.pow(2, 2.1), 16 | stopband_n: self.cutoff_n * std.pow(2, 0.3), 17 | channel_max: 512, 18 | channel_base: std.pow(2, 14) 19 | }, 20 | margin: 10, 21 | lr_mlp: 0.01, 22 | ema: std.pow(0.5, 32 / (20 * 1e3)) 23 | }, 24 | 25 | discriminator: { 26 | __target: 'stylegan2.model.Discriminator', 27 | size: $.training.size, 28 | channel_multiplier: 2 29 | }, 30 | 31 | training: { 32 | size: 256, 33 | iter: 800000, 34 | batch: 16, 35 | n_sample: 32, 36 | r1: 10, 37 | d_reg_every: 16, 38 | lr_g: 2.5e-3, 39 | lr_d: 2.5e-3, 40 | augment: false, 41 | augment_p: 0, 42 | ada_target: 0.6, 43 | ada_length: 500 * 1000, 44 | ada_every: 256, 45 | start_iter: 0 46 | } 47 | } -------------------------------------------------------------------------------- /doc/sample1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample1.gif -------------------------------------------------------------------------------- /doc/sample2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample2.gif -------------------------------------------------------------------------------- /doc/sample3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample3.gif -------------------------------------------------------------------------------- /doc/sample4.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample4.gif -------------------------------------------------------------------------------- /doc/sample5.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample5.gif -------------------------------------------------------------------------------- /doc/sample6.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/doc/sample6.gif -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | import cv2 6 | from tqdm import tqdm 7 | import numpy as np 8 | 9 | from model import Generator 10 | from config import GANConfig 11 | 12 | if __name__ == "__main__": 13 | device = "cuda" 14 | 15 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 16 | 17 | parser.add_argument( 18 | "--n_img", type=int, default=16, help="number of images to be generated" 19 | ) 20 | parser.add_argument( 21 | "--n_row", type=int, default=4, help="number of samples per row" 22 | ) 23 | parser.add_argument( 24 | "--truncation", type=float, default=0.5, help="truncation ratio" 25 | ) 26 | parser.add_argument( 27 | "--truncation_mean", 28 | type=int, 29 | default=4096, 30 | help="number of vectors to calculate mean for the truncation", 31 | ) 32 | parser.add_argument("--n_frame", type=int, default=120) 33 | parser.add_argument("--radius", type=float, default=30) 34 | parser.add_argument( 35 | "ckpt", 
metavar="CKPT", type=str, help="path to the model checkpoint" 36 | ) 37 | 38 | args = parser.parse_args() 39 | 40 | ckpt = torch.load(args.ckpt, map_location=lambda storage, loc: storage) 41 | conf = GANConfig(**ckpt["conf"]) 42 | generator = conf.generator.make().to(device) 43 | generator.load_state_dict(ckpt["g_ema"], strict=False) 44 | generator.eval() 45 | 46 | mean_latent = generator.mean_latent(args.truncation_mean) 47 | x = torch.randn(args.n_img, conf.generator["style_dim"], device=device) 48 | 49 | theta = np.radians(np.linspace(0, 360, args.n_frame)) 50 | x_2 = np.cos(theta) * args.radius 51 | y_2 = np.sin(theta) * args.radius 52 | 53 | trans_x = x_2.tolist() 54 | trans_y = y_2.tolist() 55 | 56 | images = [] 57 | 58 | transform_p = generator.get_transform( 59 | x, truncation=args.truncation, truncation_latent=mean_latent 60 | ) 61 | 62 | with torch.no_grad(): 63 | for i, (t_x, t_y) in enumerate(tqdm(zip(trans_x, trans_y), total=args.n_frame)): 64 | transform_p[:, 2] = t_y 65 | transform_p[:, 3] = t_x 66 | 67 | img = generator( 68 | x, 69 | truncation=args.truncation, 70 | truncation_latent=mean_latent, 71 | transform=transform_p, 72 | ) 73 | images.append( 74 | utils.make_grid( 75 | img.cpu(), normalize=True, nrow=args.n_row, value_range=(-1, 1) 76 | ) 77 | .mul(255) 78 | .permute(1, 2, 0) 79 | .numpy() 80 | .astype(np.uint8) 81 | ) 82 | 83 | videodims = (images[0].shape[1], images[0].shape[0]) 84 | fourcc = cv2.VideoWriter_fourcc(*"VP90") 85 | video = cv2.VideoWriter("sample.webm", fourcc, 24, videodims) 86 | 87 | for i in tqdm(images): 88 | video.write(cv2.cvtColor(i, cv2.COLOR_RGB2BGR)) 89 | 90 | video.release() 91 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .filtered_lrelu import filtered_lrelu 2 | -------------------------------------------------------------------------------- /op/filtered_lrelu.h: 
-------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 
45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 
85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /op/filtered_lrelu.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | from collections import abc 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | from torch.autograd import Function 15 | from torch.utils.cpp_extension import load 16 | 17 | from stylegan2.op import upfirdn2d, fused_leaky_relu 18 | 19 | module_path = os.path.dirname(__file__) 20 | filtered_lrelu_op = load( 21 | "filtered_lrelu", 22 | sources=[ 23 | os.path.join(module_path, "filtered_lrelu.cpp"), 24 | os.path.join(module_path, "filtered_lrelu_wr.cu"), 25 | os.path.join(module_path, "filtered_lrelu_rd.cu"), 26 | os.path.join(module_path, "filtered_lrelu_ns.cu"), 27 | # os.path.join(module_path, "filtered_lrelu.h"), 28 | os.path.join(module_path, "filtered_lrelu.cu"), 29 | ], 30 | extra_cuda_cflags=["--use_fast_math"], 31 | ) 32 | 33 | 34 | def format_padding(padding): 35 | if not isinstance(padding, abc.Iterable): 36 | padding = (padding, padding) 37 | 38 | if len(padding) == 2: 39 | padding = (padding[0], padding[0], padding[1], padding[1]) 40 | 41 | 
return padding 42 | 43 | 44 | def filtered_lrelu( 45 | x, bias, up_filter, down_filter, up, down, padding, negative_slope=0.2, clamp=None 46 | ): 47 | padding = format_padding(padding) 48 | 49 | if x.device.type == "cuda": 50 | return FilteredLReLU.apply( 51 | x, 52 | bias, 53 | up_filter, 54 | down_filter, 55 | up, 56 | down, 57 | padding, 58 | negative_slope, 59 | 2 ** 0.5, 60 | clamp, 61 | False, 62 | None, 63 | 0, 64 | 0, 65 | ) 66 | 67 | return filtered_lrelu_upfirdn2d( 68 | x, bias, up_filter, down_filter, up, down, padding, negative_slope, clamp 69 | ) 70 | 71 | 72 | def filtered_lrelu_upfirdn2d( 73 | x, bias, up_filter, down_filter, up, down, padding, negative_slope=0.2, clamp=None 74 | ): 75 | if bias is not None: 76 | x = x + bias.view(1, -1, 1, 1) 77 | 78 | x = upfirdn2d(x, up_filter, up=up, pad=padding, gain=up ** 2) 79 | x = fused_leaky_relu(x, negative_slope=negative_slope) 80 | 81 | if clamp is not None: 82 | x = x.clamp(-clamp, clamp) 83 | 84 | x = upfirdn2d(x, down_filter, down=down) 85 | 86 | return x 87 | 88 | 89 | class FilteredLReLU(Function): 90 | @staticmethod 91 | def forward( 92 | ctx, 93 | x, 94 | bias, 95 | up_filter, 96 | down_filter, 97 | up, 98 | down, 99 | padding, 100 | negative_slope, 101 | gain, 102 | clamp, 103 | flip_filter, 104 | sign, 105 | sign_offset_x, 106 | sign_offset_y, 107 | ): 108 | if up_filter is None: 109 | up_filter = torch.ones([1, 1], dtype=torch.float32, device=x.device) 110 | 111 | if down_filter is None: 112 | down_filter = torch.ones([1, 1], dtype=torch.float32, device=x.device) 113 | 114 | if up == 1 and up_filter.ndim == 1 and up_filter.shape[0] == 1: 115 | up_filter = up_filter.square()[None] 116 | 117 | if down == 1 and down_filter.ndim == 1 and down_filter.shape[0] == 1: 118 | down_filter = down_filter.square()[None] 119 | 120 | clamp = float(clamp if clamp is not None else "inf") 121 | 122 | if sign is None: 123 | sign = torch.empty([0]) 124 | 125 | if bias is None: 126 | bias = torch.zeros([x.shape[1]], 
dtype=x.dtype, device=x.device) 127 | 128 | write_signs = (sign.numel() == 0) and (x.requires_grad or bias.requires_grad) 129 | 130 | # strides = [x.stride(i) for i in range(x.ndim) if x.shape[i] > 1] 131 | # if any(a < b for a, b in zip(strides[:-1], strides[1:])): 132 | 133 | pad_x0, pad_x1, pad_y0, pad_y1 = padding 134 | 135 | if x.dtype in (torch.float16, torch.float32): 136 | # if torch.cuda.current_stream(x.device) != torch.cuda.default_stream(x.device): 137 | 138 | y, sign_out, return_code = filtered_lrelu_op.filtered_lrelu( 139 | x, 140 | up_filter, 141 | down_filter, 142 | bias, 143 | sign, 144 | up, 145 | down, 146 | pad_x0, 147 | pad_x1, 148 | pad_y0, 149 | pad_y1, 150 | sign_offset_x, 151 | sign_offset_y, 152 | gain, 153 | negative_slope, 154 | clamp, 155 | flip_filter, 156 | write_signs, 157 | ) 158 | 159 | else: 160 | return_code = -1 161 | 162 | if return_code < 0: 163 | y = x + bias.view(-1, 1, 1) 164 | y = upfirdn2d( 165 | y, up_filter, up=up, pad=padding, flip_filter=flip_filter, gain=up ** 2 166 | ) 167 | sign_out = filtered_lrelu_op.filtered_lrelu_act_( 168 | y, 169 | sign, 170 | sign_offset_x, 171 | sign_offset_y, 172 | gain, 173 | negative_slope, 174 | clamp, 175 | write_signs, 176 | ) 177 | y = upfirdn2d(y, down_filter, down=down, flip_filter=flip_filter) 178 | 179 | ctx.save_for_backward( 180 | up_filter, down_filter, (sign if sign.numel() else sign_out) 181 | ) 182 | ctx.x_shape = x.shape 183 | ctx.y_shape = y.shape 184 | ctx.sign_offsets = sign_offset_x, sign_offset_y 185 | ctx.padding = padding 186 | ctx.args = up, down, negative_slope, gain, flip_filter 187 | 188 | return y 189 | 190 | @staticmethod 191 | def backward(ctx, dy): 192 | up_filter, down_filter, sign = ctx.saved_tensors 193 | _, _, x_h, x_w = ctx.x_shape 194 | _, _, y_h, y_w = ctx.y_shape 195 | sign_offset_x, sign_offset_y = ctx.sign_offsets 196 | pad_x0, pad_x1, pad_y0, pad_y1 = ctx.padding 197 | up, down, negative_slope, gain, flip_filter = ctx.args 198 | 199 | dx = None 
200 | dup_filter = None 201 | ddown_filter = None 202 | dbias = None 203 | dsign = None 204 | dsign_offset_x = None 205 | dsign_offset_y = None 206 | 207 | if ctx.needs_input_grad[0] or ctx.need_input_grad[1]: 208 | padding = [ 209 | (up_filter.shape[-1] - 1) + (down_filter.shape[-1] - 1) - pad_x0, 210 | x_w * up - y_w * down + pad_x0 - (up - 1), 211 | (up_filter.shape[0] - 1) + (down_filter.shape[0] - 1) - pad_y0, 212 | x_h * up - y_h * down + pad_y0 - (up - 1), 213 | ] 214 | gain2 = gain * (up ** 2) / (down ** 2) 215 | sign_offset_x = sign_offset_x - (up_filter.shape[-1] - 1) + pad_x0 216 | sign_offset_y = sign_offset_y - (up_filter.shape[0] - 1) + pad_y0 217 | 218 | dx = FilteredLReLU.apply( 219 | dy, 220 | None, 221 | down_filter, 222 | up_filter, 223 | down, 224 | up, 225 | padding, 226 | negative_slope, 227 | gain2, 228 | None, 229 | not flip_filter, 230 | sign, 231 | sign_offset_x, 232 | sign_offset_y, 233 | ) 234 | 235 | if ctx.needs_input_grad[1]: 236 | dbias = dx.sum((0, 2, 3)) 237 | 238 | return ( 239 | dx, 240 | dbias, 241 | dup_filter, 242 | ddown_filter, 243 | None, 244 | None, 245 | None, 246 | None, 247 | None, 248 | None, 249 | None, 250 | dsign, 251 | dsign_offset_x, 252 | dsign_offset_y, 253 | ) 254 | -------------------------------------------------------------------------------- /op/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | // Explicit template arguments are required (none is deducible from the parameter lists); here signWrite=false, signRead=false. NOTE(review): argument order assumed to be <T, index_t, signWrite, signRead> / <T, signWrite, signRead> / <signWrite, signRead> -- confirm against filtered_lrelu.cu. 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /op/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | // Explicit template arguments are required (none is deducible from the parameter lists); here signWrite=false, signRead=true. NOTE(review): argument order assumed to be <T, index_t, signWrite, signRead> / <T, signWrite, signRead> / <signWrite, signRead> -- confirm against filtered_lrelu.cu. 13 | // Full op, 32-bit indexing.
14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, false, true>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, false, true>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, false, true>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, false, true>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<false, true>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /op/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | // Explicit template arguments are required (none is deducible from the parameter lists); here signWrite=true, signRead=false. NOTE(review): argument order assumed to be <T, index_t, signWrite, signRead> / <T, signWrite, signRead> / <signWrite, signRead> -- confirm against filtered_lrelu.cu. 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int32_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing.
18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<c10::Half, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel<float, int64_t, true, false>(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel<c10::Half, true, false>(void); 23 | template void* choose_filtered_lrelu_act_kernel<float, true, false>(void); 24 | template void* choose_filtered_lrelu_act_kernel<double, true, false>(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters<true, false>(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | # Resize so the shorter side equals `size`, center-crop to size x size, and return the JPEG-encoded bytes. 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format="jpeg", quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | # Encode `img` once per entry of `sizes`; returns the JPEG byte strings in the same order. 23 | def resize_multiple( 24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100 25 | ): 26 | imgs = [] 27 | 28 | for size in sizes: 29 | imgs.append(resize_and_convert(img, size, resample, quality)) 30 | 31 | return imgs 32 | 33 | # Pool worker: maps an (index, path) pair to (index, list of encoded images). 34 | def resize_worker(img_file, sizes, resample): 35 | i, file = img_file 36 | img = Image.open(file) 37 | img = img.convert("RGB") 38 | out = resize_multiple(img, sizes=sizes, resample=resample) 39 | 40 | return i, out 41 | 42 | # Encode every image of `dataset` in parallel and write it to the LMDB `env` under '{size}-{index zero-padded to 5 digits}' keys, followed by a 'length' entry holding the image count. 43 | def prepare( 44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS 45 | ): 46 | resize_fn =
partial(resize_worker, sizes=sizes, resample=resample) 47 | 48 | files = sorted(dataset.imgs, key=lambda x: x[0]) 49 | files = [(i, file) for i, (file, label) in enumerate(files)] 50 | total = 0 51 | 52 | with multiprocessing.Pool(n_worker) as pool: 53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 54 | for size, img in zip(sizes, imgs): 55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 56 | 57 | with env.begin(write=True) as txn: 58 | txn.put(key, img) 59 | 60 | total += 1 61 | 62 | with env.begin(write=True) as txn: 63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Preprocess images for model training") 68 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset") 69 | parser.add_argument( 70 | "--size", 71 | type=str, 72 | default="128,256,512,1024", 73 | help="resolutions of images for the dataset", 74 | ) 75 | parser.add_argument( 76 | "--n_worker", 77 | type=int, 78 | default=8, 79 | help="number of workers for preparing dataset", 80 | ) 81 | parser.add_argument( 82 | "--resample", 83 | type=str, 84 | default="lanczos", 85 | help="resampling methods for resizing images", 86 | ) 87 | parser.add_argument("path", type=str, help="path to the image dataset") 88 | 89 | args = parser.parse_args() 90 | 91 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 92 | resample = resample_map[args.resample] 93 | 94 | sizes = [int(s.strip()) for s in args.size.split(",")] 95 | 96 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 97 | 98 | imgset = datasets.ImageFolder(args.path) 99 | 100 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 101 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 102 | -------------------------------------------------------------------------------- /sample.webm: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/sample.webm -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /stylegan2/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | wandb/ 132 | *.lmdb/ 133 | *.pkl 134 | -------------------------------------------------------------------------------- /stylegan2/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-FID: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /stylegan2/LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 
28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 
100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /stylegan2/README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN 2 in PyTorch 2 | 3 | Implementation of Analyzing and Improving the Image Quality of StyleGAN (https://arxiv.org/abs/1912.04958) in PyTorch 4 | 5 | ## Notice 6 | 7 | I have tried to match the official implementation as closely as possible, but there may be some details I missed, so please use this implementation with care. 8 | 9 | ## Requirements 10 | 11 | I have tested on: 12 | 13 | - PyTorch 1.3.1 14 | - CUDA 10.1/10.2 15 | 16 | ## Usage 17 | 18 | First, create lmdb datasets: 19 | 20 | > python prepare_data.py --out LMDB_PATH --n_worker N_WORKER --size SIZE1,SIZE2,SIZE3,... DATASET_PATH 21 | 22 | This will convert images to JPEG and pre-resize them. This implementation does not use progressive growing, but you can create datasets at multiple resolutions by passing a comma-separated list to the size argument, in case you want to try other resolutions later. 23 | 24 | Then you can train the model in a distributed setting: 25 | 26 | > python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --batch BATCH_SIZE LMDB_PATH 27 | 28 | train.py supports Weights & Biases logging. If you want to use it, add the --wandb argument to the script. 29 | 30 | #### SWAGAN 31 | 32 | This implementation experimentally supports SWAGAN: A Style-based Wavelet-driven Generative Model (https://arxiv.org/abs/2102.06108). You can train SWAGAN with: 33 | 34 | > python -m torch.distributed.launch --nproc_per_node=N_GPU --master_port=PORT train.py --arch swagan --batch BATCH_SIZE LMDB_PATH 35 | 36 | As noted in the paper, SWAGAN trains much faster (about 2x at 256px).
37 | 38 | ### Convert weights from official checkpoints 39 | 40 | You need to clone the official repository (https://github.com/NVlabs/stylegan2), as it is required to load the official checkpoints. 41 | 42 | For example, if you cloned the repository into ~/stylegan2 and downloaded stylegan2-ffhq-config-f.pkl, you can convert it like this: 43 | 44 | > python convert_weight.py --repo ~/stylegan2 stylegan2-ffhq-config-f.pkl 45 | 46 | This will create the converted stylegan2-ffhq-config-f.pt file. 47 | 48 | ### Generate samples 49 | 50 | > python generate.py --sample N_FACES --pics N_PICS --ckpt PATH_CHECKPOINT 51 | 52 | If you trained at a different resolution, set the size accordingly (for example, --size 256). 53 | 54 | ### Project images to latent spaces 55 | 56 | > python projector.py --ckpt [CHECKPOINT] --size [GENERATOR_OUTPUT_SIZE] FILE1 FILE2 ... 57 | 58 | ### Closed-Form Factorization (https://arxiv.org/abs/2007.06600) 59 | 60 | You can use `closed_form_factorization.py` and `apply_factor.py` to discover meaningful latent semantic factors or directions in an unsupervised manner. 61 | 62 | First, extract the eigenvectors of the weight matrices using `closed_form_factorization.py`: 63 | 64 | > python closed_form_factorization.py [CHECKPOINT] 65 | 66 | This will create a factor file that contains the eigenvectors (default: factor.pt). You can then use `apply_factor.py` to test the meaning of the extracted directions: 67 | 68 | > python apply_factor.py -i [INDEX_OF_EIGENVECTOR] -d [DEGREE_OF_MOVE] -n [NUMBER_OF_SAMPLES] --ckpt [CHECKPOINT] [FACTOR_FILE] 69 | 70 | For example, 71 | 72 | > python apply_factor.py -i 19 -d 5 -n 10 --ckpt [CHECKPOINT] factor.pt 73 | 74 | This will generate 10 random samples, together with samples generated from latents moved along eigenvector index 19 by a degree of +/-5.
75 | 76 | ![Sample of closed form factorization](factor_index-13_degree-5.0.png) 77 | 78 | ## Pretrained Checkpoints 79 | 80 | [Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO) 81 | 82 | I have trained the 256px model on FFHQ for 550k iterations and got an FID of about 4.5. Data preprocessing, resolution, or the training loop could account for this difference, but I currently don't know the exact reason for the FID differences. 83 | 84 | ## Samples 85 | 86 | ![Sample with truncation](doc/sample.png) 87 | 88 | Sample from FFHQ at 110,000 iterations (trained on 3.52M images). 89 | 90 | ![MetFaces sample with non-leaking augmentations](doc/sample-metfaces.png) 91 | 92 | Sample from MetFaces with non-leaking augmentations at 150,000 iterations (trained on 4.8M images). 93 | 94 | ### Samples from converted weights 95 | 96 | ![Sample from FFHQ](doc/stylegan2-ffhq-config-f.png) 97 | 98 | Sample from FFHQ (1024px) 99 | 100 | ![Sample from LSUN Church](doc/stylegan2-church-config-f.png) 101 | 102 | Sample from LSUN Church (256px) 103 | 104 | ## License 105 | 106 | Model details and custom CUDA kernel code are from the official repository: https://github.com/NVlabs/stylegan2 107 | 108 | Code for Learned Perceptual Image Patch Similarity (LPIPS) came from https://github.com/richzhang/PerceptualSimilarity 109 | 110 | To match FID scores more closely to the official TensorFlow implementation, I have used the FID Inception V3 implementation from https://github.com/mseitzer/pytorch-fid 111 | -------------------------------------------------------------------------------- /stylegan2/apply_factor.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | 6 | from model import Generator 7 | 8 | 9 | if __name__ == "__main__": 10 | torch.set_grad_enabled(False) 11 | 12 | parser = argparse.ArgumentParser(description="Apply closed form factorization") 13 | 14 | parser.add_argument( 15 |
"-i", "--index", type=int, default=0, help="index of eigenvector" 16 | ) 17 | parser.add_argument( 18 | "-d", 19 | "--degree", 20 | type=float, 21 | default=5, 22 | help="scalar factors for moving latent vectors along eigenvector", 23 | ) 24 | parser.add_argument( 25 | "--channel_multiplier", 26 | type=int, 27 | default=2, 28 | help='channel multiplier factor. config-f = 2, else = 1', 29 | ) 30 | parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints") 31 | parser.add_argument( 32 | "--size", type=int, default=256, help="output image size of the generator" 33 | ) 34 | parser.add_argument( 35 | "-n", "--n_sample", type=int, default=7, help="number of samples created" 36 | ) 37 | parser.add_argument( 38 | "--truncation", type=float, default=0.7, help="truncation factor" 39 | ) 40 | parser.add_argument( 41 | "--device", type=str, default="cuda", help="device to run the model" 42 | ) 43 | parser.add_argument( 44 | "--out_prefix", 45 | type=str, 46 | default="factor", 47 | help="filename prefix to result samples", 48 | ) 49 | parser.add_argument( 50 | "factor", 51 | type=str, 52 | help="name of the closed form factorization result factor file", 53 | ) 54 | 55 | args = parser.parse_args() 56 | 57 | eigvec = torch.load(args.factor)["eigvec"].to(args.device) 58 | ckpt = torch.load(args.ckpt) 59 | g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device) 60 | g.load_state_dict(ckpt["g_ema"], strict=False) 61 | 62 | trunc = g.mean_latent(4096) 63 | 64 | latent = torch.randn(args.n_sample, 512, device=args.device) 65 | latent = g.get_latent(latent) 66 | 67 | direction = args.degree * eigvec[:, args.index].unsqueeze(0) 68 | 69 | img, _ = g( 70 | [latent], 71 | truncation=args.truncation, 72 | truncation_latent=trunc, 73 | input_is_latent=True, 74 | ) 75 | img1, _ = g( 76 | [latent + direction], 77 | truncation=args.truncation, 78 | truncation_latent=trunc, 79 | input_is_latent=True, 80 | ) 81 | img2, _ = g( 82 | 
[latent - direction], 83 | truncation=args.truncation, 84 | truncation_latent=trunc, 85 | input_is_latent=True, 86 | ) 87 | 88 | grid = utils.save_image( 89 | torch.cat([img1, img, img2], 0), 90 | f"{args.out_prefix}_index-{args.index}_degree-{args.degree}.png", 91 | normalize=True, 92 | range=(-1, 1), 93 | nrow=args.n_sample, 94 | ) 95 | -------------------------------------------------------------------------------- /stylegan2/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.models import inception_v3, Inception3 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from inception import InceptionV3 15 | from dataset import MultiResolutionDataset 16 | 17 | 18 | class Inception3Feature(Inception3): 19 | def forward(self, x): 20 | if x.shape[2] != 299 or x.shape[3] != 299: 21 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True) 22 | 23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 27 | 28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 31 | 32 | x = self.Mixed_5b(x) # 35 x 35 x 192 33 | x = self.Mixed_5c(x) # 35 x 35 x 256 34 | x = self.Mixed_5d(x) # 35 x 35 x 288 35 | 36 | x = self.Mixed_6a(x) # 35 x 35 x 288 37 | x = self.Mixed_6b(x) # 17 x 17 x 768 38 | x = self.Mixed_6c(x) # 17 x 17 x 768 39 | x = self.Mixed_6d(x) # 17 x 17 x 768 40 | x = self.Mixed_6e(x) # 17 x 17 x 768 41 | 42 | x = self.Mixed_7a(x) # 17 x 17 x 768 43 | x = self.Mixed_7b(x) # 8 x 8 x 1280 44 | x = 
self.Mixed_7c(x) # 8 x 8 x 2048 45 | 46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 47 | 48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 49 | 50 | 51 | def load_patched_inception_v3(): 52 | # inception = inception_v3(pretrained=True) 53 | # inception_feat = Inception3Feature() 54 | # inception_feat.load_state_dict(inception.state_dict()) 55 | inception_feat = InceptionV3([3], normalize_input=False) 56 | 57 | return inception_feat 58 | 59 | 60 | @torch.no_grad() 61 | def extract_features(loader, inception, device): 62 | pbar = tqdm(loader) 63 | 64 | feature_list = [] 65 | 66 | for img in pbar: 67 | img = img.to(device) 68 | feature = inception(img)[0].view(img.shape[0], -1) 69 | feature_list.append(feature.to("cpu")) 70 | 71 | features = torch.cat(feature_list, 0) 72 | 73 | return features 74 | 75 | 76 | if __name__ == "__main__": 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | 79 | parser = argparse.ArgumentParser( 80 | description="Calculate Inception v3 features for datasets" 81 | ) 82 | parser.add_argument( 83 | "--size", 84 | type=int, 85 | default=256, 86 | help="image sizes used for embedding calculation", 87 | ) 88 | parser.add_argument( 89 | "--batch", default=64, type=int, help="batch size for inception networks" 90 | ) 91 | parser.add_argument( 92 | "--n_sample", 93 | type=int, 94 | default=50000, 95 | help="number of samples used for embedding calculation", 96 | ) 97 | parser.add_argument( 98 | "--flip", action="store_true", help="apply random flipping to real images" 99 | ) 100 | parser.add_argument("path", metavar="PATH", help="path to datset lmdb file") 101 | 102 | args = parser.parse_args() 103 | 104 | inception = load_patched_inception_v3() 105 | inception = nn.DataParallel(inception).eval().to(device) 106 | 107 | transform = transforms.Compose( 108 | [ 109 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0), 110 | transforms.ToTensor(), 111 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 
0.5]), 112 | ] 113 | ) 114 | 115 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 116 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 117 | 118 | features = extract_features(loader, inception, device).numpy() 119 | 120 | features = features[: args.n_sample] 121 | 122 | print(f"extracted {features.shape[0]} features") 123 | 124 | mean = np.mean(features, 0) 125 | cov = np.cov(features, rowvar=False) 126 | 127 | name = os.path.splitext(os.path.basename(args.path))[0] 128 | 129 | with open(f"inception_{name}.pkl", "wb") as f: 130 | pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f) 131 | -------------------------------------------------------------------------------- /stylegan2/checkpoint/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | -------------------------------------------------------------------------------- /stylegan2/closed_form_factorization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser( 8 | description="Extract factor/eigenvectors of latent spaces using closed form factorization" 9 | ) 10 | 11 | parser.add_argument( 12 | "--out", type=str, default="factor.pt", help="name of the result factor file" 13 | ) 14 | parser.add_argument("ckpt", type=str, help="name of the model checkpoint") 15 | 16 | args = parser.parse_args() 17 | 18 | ckpt = torch.load(args.ckpt) 19 | modulate = { 20 | k: v 21 | for k, v in ckpt["g_ema"].items() 22 | if "modulation" in k and "to_rgbs" not in k and "weight" in k 23 | } 24 | 25 | weight_mat = [] 26 | for k, v in modulate.items(): 27 | weight_mat.append(v) 28 | 29 | W = torch.cat(weight_mat, 0) 30 | eigvec = torch.svd(W).V.to("cpu") 31 | 32 | torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out) 33 | 34 | 
-------------------------------------------------------------------------------- /stylegan2/convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import pickle 5 | import math 6 | 7 | import torch 8 | import numpy as np 9 | from torchvision import utils 10 | 11 | from model import Generator, Discriminator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + "/weight"].value().eval() 16 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 17 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 18 | noise = vars[source_name + "/noise_strength"].value().eval() 19 | bias = vars[source_name + "/bias"].value().eval() 20 | 21 | dic = { 22 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 24 | "conv.modulation.bias": mod_bias + 1, 25 | "noise.weight": np.array([noise]), 26 | "activate.bias": bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + ".conv.weight"] = torch.flip( 36 | dic_torch[target_name + ".conv.weight"], [3, 4] 37 | ) 38 | 39 | return dic_torch 40 | 41 | 42 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 43 | weight = vars[source_name + "/weight"].value().eval() 44 | 45 | dic = {"weight": weight.transpose((3, 2, 0, 1))} 46 | 47 | if bias: 48 | dic["bias"] = vars[source_name + "/bias"].value().eval() 49 | 50 | dic_torch = {} 51 | 52 | dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) 53 | 54 | if bias: 55 | dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) 56 | 57 | return dic_torch 58 | 59 | 60 | def convert_torgb(vars, source_name, target_name): 61 | weight = vars[source_name + "/weight"].value().eval() 62 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 63 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 64 | bias = vars[source_name + "/bias"].value().eval() 65 | 66 | dic = { 67 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 68 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 69 | "conv.modulation.bias": mod_bias + 1, 70 | "bias": bias.reshape((1, 3, 1, 1)), 71 | } 72 | 73 | dic_torch = {} 74 | 75 | for k, v in dic.items(): 76 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 77 | 78 | return dic_torch 79 | 80 | 81 | def convert_dense(vars, source_name, target_name): 82 | weight = vars[source_name + "/weight"].value().eval() 83 | bias = vars[source_name + "/bias"].value().eval() 84 | 85 | dic = {"weight": weight.transpose((1, 0)), "bias": bias} 86 | 87 | dic_torch = {} 88 | 89 | for k, v in dic.items(): 90 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 91 | 92 | return dic_torch 93 | 94 | 95 | def update(state_dict, new): 96 | for k, v in new.items(): 97 | if k not in state_dict: 98 | raise KeyError(k + " is not found") 99 | 100 | if v.shape != state_dict[k].shape: 101 | raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") 102 | 103 | state_dict[k] = v 104 | 105 | 106 | def discriminator_fill_statedict(statedict, vars, size): 107 | log_size = int(math.log(size, 2)) 108 | 109 | update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) 110 | 111 | conv_i = 1 112 | 113 | for i in range(log_size - 2, 0, -1): 114 | reso = 4 * 2 ** i 115 | update( 116 | statedict, 117 | convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), 118 | ) 119 | update( 120 | statedict, 121 | convert_conv( 122 | vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1 123 | ), 124 | ) 125 | update( 126 | statedict, 127 | convert_conv( 128 | vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False 129 | ), 130 | ) 131 | conv_i += 1 132 | 133 | update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) 134 | update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) 135 | update(statedict, convert_dense(vars, f"Output", "final_linear.1")) 136 | 137 | return statedict 138 | 139 | 140 | def fill_statedict(state_dict, vars, size, n_mlp): 141 | log_size = int(math.log(size, 2)) 142 | 143 | for i in range(n_mlp): 144 | update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) 145 | 146 | update( 147 | state_dict, 148 | { 149 | "input.input": torch.from_numpy( 150 | vars["G_synthesis/4x4/Const/const"].value().eval() 151 | ) 152 | }, 153 | ) 154 | 155 | update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1")) 156 | 157 | for i in range(log_size - 2): 158 | reso = 4 * 2 ** (i + 1) 159 | update( 160 | state_dict, 161 | convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), 
162 | ) 163 | 164 | update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) 165 | 166 | conv_i = 0 167 | 168 | for i in range(log_size - 2): 169 | reso = 4 * 2 ** (i + 1) 170 | update( 171 | state_dict, 172 | convert_modconv( 173 | vars, 174 | f"G_synthesis/{reso}x{reso}/Conv0_up", 175 | f"convs.{conv_i}", 176 | flip=True, 177 | ), 178 | ) 179 | update( 180 | state_dict, 181 | convert_modconv( 182 | vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}" 183 | ), 184 | ) 185 | conv_i += 2 186 | 187 | for i in range(0, (log_size - 2) * 2 + 1): 188 | update( 189 | state_dict, 190 | { 191 | f"noises.noise_{i}": torch.from_numpy( 192 | vars[f"G_synthesis/noise{i}"].value().eval() 193 | ) 194 | }, 195 | ) 196 | 197 | return state_dict 198 | 199 | 200 | if __name__ == "__main__": 201 | device = "cuda" 202 | 203 | parser = argparse.ArgumentParser( 204 | description="Tensorflow to pytorch model checkpoint converter" 205 | ) 206 | parser.add_argument( 207 | "--repo", 208 | type=str, 209 | required=True, 210 | help="path to the offical StyleGAN2 repository with dnnlib/ folder", 211 | ) 212 | parser.add_argument( 213 | "--gen", action="store_true", help="convert the generator weights" 214 | ) 215 | parser.add_argument( 216 | "--disc", action="store_true", help="convert the discriminator weights" 217 | ) 218 | parser.add_argument( 219 | "--channel_multiplier", 220 | type=int, 221 | default=2, 222 | help="channel multiplier factor. 
config-f = 2, else = 1", 223 | ) 224 | parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights") 225 | 226 | args = parser.parse_args() 227 | 228 | sys.path.append(args.repo) 229 | 230 | import dnnlib 231 | from dnnlib import tflib 232 | 233 | tflib.init_tf() 234 | 235 | with open(args.path, "rb") as f: 236 | generator, discriminator, g_ema = pickle.load(f) 237 | 238 | size = g_ema.output_shape[2] 239 | 240 | n_mlp = 0 241 | mapping_layers_names = g_ema.__getstate__()['components']['mapping'].list_layers() 242 | for layer in mapping_layers_names: 243 | if layer[0].startswith('Dense'): 244 | n_mlp += 1 245 | 246 | g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier) 247 | state_dict = g.state_dict() 248 | state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp) 249 | 250 | g.load_state_dict(state_dict) 251 | 252 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 253 | 254 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 255 | 256 | if args.gen: 257 | g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier) 258 | g_train_state = g_train.state_dict() 259 | g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp) 260 | ckpt["g"] = g_train_state 261 | 262 | if args.disc: 263 | disc = Discriminator(size, channel_multiplier=args.channel_multiplier) 264 | d_state = disc.state_dict() 265 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 266 | ckpt["d"] = d_state 267 | 268 | name = os.path.splitext(os.path.basename(args.path))[0] 269 | torch.save(ckpt, name + ".pt") 270 | 271 | batch_size = {256: 16, 512: 9, 1024: 4} 272 | n_sample = batch_size.get(size, 25) 273 | 274 | g = g.to(device) 275 | 276 | z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") 277 | 278 | with torch.no_grad(): 279 | img_pt, _ = g( 280 | [torch.from_numpy(z).to(device)], 281 | truncation=0.5, 282 | 
truncation_latent=latent_avg.to(device), 283 | randomize_noise=False, 284 | ) 285 | 286 | Gs_kwargs = dnnlib.EasyDict() 287 | Gs_kwargs.randomize_noise = False 288 | img_tf = g_ema.run(z, None, **Gs_kwargs) 289 | img_tf = torch.from_numpy(img_tf).to(device) 290 | 291 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp( 292 | 0.0, 1.0 293 | ) 294 | 295 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 296 | 297 | print(img_diff.abs().max()) 298 | 299 | utils.save_image( 300 | img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1) 301 | ) 302 | -------------------------------------------------------------------------------- /stylegan2/dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MultiResolutionDataset(Dataset): 9 | def __init__(self, path, transform, resolution=256): 10 | self.env = lmdb.open( 11 | path, 12 | max_readers=32, 13 | readonly=True, 14 | lock=False, 15 | readahead=False, 16 | meminit=False, 17 | ) 18 | 19 | if not self.env: 20 | raise IOError('Cannot open lmdb dataset', path) 21 | 22 | with self.env.begin(write=False) as txn: 23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) 24 | 25 | self.resolution = resolution 26 | self.transform = transform 27 | 28 | def __len__(self): 29 | return self.length 30 | 31 | def __getitem__(self, index): 32 | with self.env.begin(write=False) as txn: 33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') 34 | img_bytes = txn.get(key) 35 | 36 | buffer = BytesIO(img_bytes) 37 | img = Image.open(buffer) 38 | img = self.transform(img) 39 | 40 | return img 41 | -------------------------------------------------------------------------------- /stylegan2/distributed.py: -------------------------------------------------------------------------------- 1 | 
import math 2 | import pickle 3 | 4 | import torch 5 | from torch import distributed as dist 6 | from torch.utils.data.sampler import Sampler 7 | 8 | 9 | def get_rank(): 10 | if not dist.is_available(): 11 | return 0 12 | 13 | if not dist.is_initialized(): 14 | return 0 15 | 16 | return dist.get_rank() 17 | 18 | 19 | def synchronize(): 20 | if not dist.is_available(): 21 | return 22 | 23 | if not dist.is_initialized(): 24 | return 25 | 26 | world_size = dist.get_world_size() 27 | 28 | if world_size == 1: 29 | return 30 | 31 | dist.barrier() 32 | 33 | 34 | def get_world_size(): 35 | if not dist.is_available(): 36 | return 1 37 | 38 | if not dist.is_initialized(): 39 | return 1 40 | 41 | return dist.get_world_size() 42 | 43 | 44 | def reduce_sum(tensor): 45 | if not dist.is_available(): 46 | return tensor 47 | 48 | if not dist.is_initialized(): 49 | return tensor 50 | 51 | tensor = tensor.clone() 52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 53 | 54 | return tensor 55 | 56 | 57 | def gather_grad(params): 58 | world_size = get_world_size() 59 | 60 | if world_size == 1: 61 | return 62 | 63 | for param in params: 64 | if param.grad is not None: 65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 66 | param.grad.data.div_(world_size) 67 | 68 | 69 | def all_gather(data): 70 | world_size = get_world_size() 71 | 72 | if world_size == 1: 73 | return [data] 74 | 75 | buffer = pickle.dumps(data) 76 | storage = torch.ByteStorage.from_buffer(buffer) 77 | tensor = torch.ByteTensor(storage).to('cuda') 78 | 79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 81 | dist.all_gather(size_list, local_size) 82 | size_list = [int(size.item()) for size in size_list] 83 | max_size = max(size_list) 84 | 85 | tensor_list = [] 86 | for _ in size_list: 87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 88 | 89 | if local_size != max_size: 90 | padding = 
torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 91 | tensor = torch.cat((tensor, padding), 0) 92 | 93 | dist.all_gather(tensor_list, tensor) 94 | 95 | data_list = [] 96 | 97 | for size, tensor in zip(size_list, tensor_list): 98 | buffer = tensor.cpu().numpy().tobytes()[:size] 99 | data_list.append(pickle.loads(buffer)) 100 | 101 | return data_list 102 | 103 | 104 | def reduce_loss_dict(loss_dict): 105 | world_size = get_world_size() 106 | 107 | if world_size < 2: 108 | return loss_dict 109 | 110 | with torch.no_grad(): 111 | keys = [] 112 | losses = [] 113 | 114 | for k in sorted(loss_dict.keys()): 115 | keys.append(k) 116 | losses.append(loss_dict[k]) 117 | 118 | losses = torch.stack(losses, 0) 119 | dist.reduce(losses, dst=0) 120 | 121 | if dist.get_rank() == 0: 122 | losses /= world_size 123 | 124 | reduced_losses = {k: v for k, v in zip(keys, losses)} 125 | 126 | return reduced_losses 127 | -------------------------------------------------------------------------------- /stylegan2/doc/sample-metfaces.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/doc/sample-metfaces.png -------------------------------------------------------------------------------- /stylegan2/doc/sample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/doc/sample.png -------------------------------------------------------------------------------- /stylegan2/doc/stylegan2-church-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/doc/stylegan2-church-config-f.png 
-------------------------------------------------------------------------------- /stylegan2/doc/stylegan2-ffhq-config-f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/doc/stylegan2-ffhq-config-f.png -------------------------------------------------------------------------------- /stylegan2/factor_index-13_degree-5.0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/factor_index-13_degree-5.0.png -------------------------------------------------------------------------------- /stylegan2/fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import linalg 8 | from tqdm import tqdm 9 | 10 | from model import Generator 11 | from calc_inception import load_patched_inception_v3 12 | 13 | 14 | @torch.no_grad() 15 | def extract_feature_from_samples( 16 | generator, inception, truncation, truncation_latent, batch_size, n_sample, device 17 | ): 18 | n_batch = n_sample // batch_size 19 | resid = n_sample - (n_batch * batch_size) 20 | batch_sizes = [batch_size] * n_batch + ([resid] if resid > 0 else [])  # skip an empty remainder batch when n_sample divides evenly 21 | features = [] 22 | 23 | for batch in tqdm(batch_sizes): 24 | latent = torch.randn(batch, 512, device=device) 25 | img, _ = generator([latent], truncation=truncation, truncation_latent=truncation_latent)  # use the argument, not the script-level global `g` 26 | feat = inception(img)[0].view(img.shape[0], -1) 27 | features.append(feat.to("cpu")) 28 | 29 | features = torch.cat(features, 0) 30 | 31 | return features 32 | 33 | 34 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 35 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 36 | 37 | if not
np.isfinite(cov_sqrt).all(): 38 | print("product of cov matrices is singular") 39 | offset = np.eye(sample_cov.shape[0]) * eps 40 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 41 | 42 | if np.iscomplexobj(cov_sqrt): 43 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 44 | m = np.max(np.abs(cov_sqrt.imag)) 45 | 46 | raise ValueError(f"Imaginary component {m}") 47 | 48 | cov_sqrt = cov_sqrt.real 49 | 50 | mean_diff = sample_mean - real_mean 51 | mean_norm = mean_diff @ mean_diff 52 | 53 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 54 | 55 | fid = mean_norm + trace 56 | 57 | return fid 58 | 59 | 60 | if __name__ == "__main__": 61 | device = "cuda" 62 | 63 | parser = argparse.ArgumentParser(description="Calculate FID scores") 64 | 65 | parser.add_argument("--truncation", type=float, default=1, help="truncation factor") 66 | parser.add_argument( 67 | "--truncation_mean", 68 | type=int, 69 | default=4096, 70 | help="number of samples to calculate mean for truncation", 71 | ) 72 | parser.add_argument( 73 | "--batch", type=int, default=64, help="batch size for the generator" 74 | ) 75 | parser.add_argument( 76 | "--n_sample", 77 | type=int, 78 | default=50000, 79 | help="number of the samples for calculating FID", 80 | ) 81 | parser.add_argument( 82 | "--size", type=int, default=256, help="image sizes for generator" 83 | ) 84 | parser.add_argument( 85 | "--inception", 86 | type=str, 87 | default=None, 88 | required=True, 89 | help="path to precomputed inception embedding", 90 | ) 91 | parser.add_argument( 92 | "ckpt", metavar="CHECKPOINT", help="path to generator checkpoint" 93 | ) 94 | 95 | args = parser.parse_args() 96 | 97 | ckpt = torch.load(args.ckpt) 98 | 99 | g = Generator(args.size, 512, 8).to(device) 100 | g.load_state_dict(ckpt["g_ema"]) 101 | g = nn.DataParallel(g) 102 | g.eval() 103 | 104 | if args.truncation < 1: 105 | with torch.no_grad(): 106 | mean_latent = 
g.module.mean_latent(args.truncation_mean)  # g is an nn.DataParallel wrapper; custom methods live on g.module 107 | 108 | else: 109 | mean_latent = None 110 | 111 | inception = nn.DataParallel(load_patched_inception_v3()).to(device) 112 | inception.eval() 113 | 114 | features = extract_feature_from_samples( 115 | g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device 116 | ).numpy() 117 | print(f"extracted {features.shape[0]} features") 118 | 119 | sample_mean = np.mean(features, 0) 120 | sample_cov = np.cov(features, rowvar=False) 121 | 122 | with open(args.inception, "rb") as f: 123 | embeds = pickle.load(f) 124 | real_mean = embeds["mean"] 125 | real_cov = embeds["cov"] 126 | 127 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 128 | 129 | print("fid:", fid) 130 | -------------------------------------------------------------------------------- /stylegan2/generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torchvision import utils 5 | from model import Generator 6 | from tqdm import tqdm 7 | 8 | 9 | def generate(args, g_ema, device, mean_latent): 10 | 11 | with torch.no_grad(): 12 | g_ema.eval() 13 | for i in tqdm(range(args.pics)): 14 | sample_z = torch.randn(args.sample, args.latent, device=device) 15 | 16 | sample, _ = g_ema( 17 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 18 | ) 19 | 20 | utils.save_image( 21 | sample, 22 | f"sample/{str(i).zfill(6)}.png", 23 | nrow=1, 24 | normalize=True, 25 | range=(-1, 1), 26 | ) 27 | 28 | 29 | if __name__ == "__main__": 30 | device = "cuda" 31 | 32 | parser = argparse.ArgumentParser(description="Generate samples from the generator") 33 | 34 | parser.add_argument( 35 | "--size", type=int, default=1024, help="output image size of the generator" 36 | ) 37 | parser.add_argument( 38 | "--sample", 39 | type=int, 40 | default=1, 41 | help="number of samples to be generated for each image", 42 | ) 43 | parser.add_argument( 44 | "--pics", type=int,
default=20, help="number of images to be generated" 45 | ) 46 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 47 | parser.add_argument( 48 | "--truncation_mean", 49 | type=int, 50 | default=4096, 51 | help="number of vectors to calculate mean for the truncation", 52 | ) 53 | parser.add_argument( 54 | "--ckpt", 55 | type=str, 56 | default="stylegan2-ffhq-config-f.pt", 57 | help="path to the model checkpoint", 58 | ) 59 | parser.add_argument( 60 | "--channel_multiplier", 61 | type=int, 62 | default=2, 63 | help="channel multiplier of the generator. config-f = 2, else = 1", 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | args.latent = 512 69 | args.n_mlp = 8 70 | 71 | g_ema = Generator( 72 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 73 | ).to(device) 74 | checkpoint = torch.load(args.ckpt) 75 | 76 | g_ema.load_state_dict(checkpoint["g_ema"]) 77 | 78 | if args.truncation < 1: 79 | with torch.no_grad(): 80 | mean_latent = g_ema.mean_latent(args.truncation_mean) 81 | else: 82 | mean_latent = None 83 | 84 | generate(args, g_ema, device, mean_latent) 85 | -------------------------------------------------------------------------------- /stylegan2/inception.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import models 5 | 6 | try: 7 | from torchvision.models.utils import load_state_dict_from_url 8 | except ImportError: 9 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 10 | 11 | # Inception weights ported to Pytorch from 12 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 13 | FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' 14 | 15 | 16 | class InceptionV3(nn.Module): 17 | """Pretrained InceptionV3 network returning feature 
maps""" 18 | 19 | # Index of default block of inception to return, 20 | # corresponds to output of final average pooling 21 | DEFAULT_BLOCK_INDEX = 3 22 | 23 | # Maps feature dimensionality to their output blocks indices 24 | BLOCK_INDEX_BY_DIM = { 25 | 64: 0, # First max pooling features 26 | 192: 1, # Second max pooling featurs 27 | 768: 2, # Pre-aux classifier features 28 | 2048: 3 # Final average pooling features 29 | } 30 | 31 | def __init__(self, 32 | output_blocks=[DEFAULT_BLOCK_INDEX], 33 | resize_input=True, 34 | normalize_input=True, 35 | requires_grad=False, 36 | use_fid_inception=True): 37 | """Build pretrained InceptionV3 38 | 39 | Parameters 40 | ---------- 41 | output_blocks : list of int 42 | Indices of blocks to return features of. Possible values are: 43 | - 0: corresponds to output of first max pooling 44 | - 1: corresponds to output of second max pooling 45 | - 2: corresponds to output which is fed to aux classifier 46 | - 3: corresponds to output of final average pooling 47 | resize_input : bool 48 | If true, bilinearly resizes input to width and height 299 before 49 | feeding input to model. As the network without fully connected 50 | layers is fully convolutional, it should be able to handle inputs 51 | of arbitrary size, so resizing might not be strictly needed 52 | normalize_input : bool 53 | If true, scales the input from range (0, 1) to the range the 54 | pretrained Inception network expects, namely (-1, 1) 55 | requires_grad : bool 56 | If true, parameters of the model require gradients. Possibly useful 57 | for finetuning the network 58 | use_fid_inception : bool 59 | If true, uses the pretrained Inception model used in Tensorflow's 60 | FID implementation. If false, uses the pretrained Inception model 61 | available in torchvision. The FID Inception model has different 62 | weights and a slightly different structure from torchvision's 63 | Inception model. 
If you want to compute FID scores, you are 64 | strongly advised to set this parameter to true to get comparable 65 | results. 66 | """ 67 | super(InceptionV3, self).__init__() 68 | 69 | self.resize_input = resize_input 70 | self.normalize_input = normalize_input 71 | self.output_blocks = sorted(output_blocks) 72 | self.last_needed_block = max(output_blocks) 73 | 74 | assert self.last_needed_block <= 3, \ 75 | 'Last possible output block index is 3' 76 | 77 | self.blocks = nn.ModuleList() 78 | 79 | if use_fid_inception: 80 | inception = fid_inception_v3() 81 | else: 82 | inception = models.inception_v3(pretrained=True) 83 | 84 | # Block 0: input to maxpool1 85 | block0 = [ 86 | inception.Conv2d_1a_3x3, 87 | inception.Conv2d_2a_3x3, 88 | inception.Conv2d_2b_3x3, 89 | nn.MaxPool2d(kernel_size=3, stride=2) 90 | ] 91 | self.blocks.append(nn.Sequential(*block0)) 92 | 93 | # Block 1: maxpool1 to maxpool2 94 | if self.last_needed_block >= 1: 95 | block1 = [ 96 | inception.Conv2d_3b_1x1, 97 | inception.Conv2d_4a_3x3, 98 | nn.MaxPool2d(kernel_size=3, stride=2) 99 | ] 100 | self.blocks.append(nn.Sequential(*block1)) 101 | 102 | # Block 2: maxpool2 to aux classifier 103 | if self.last_needed_block >= 2: 104 | block2 = [ 105 | inception.Mixed_5b, 106 | inception.Mixed_5c, 107 | inception.Mixed_5d, 108 | inception.Mixed_6a, 109 | inception.Mixed_6b, 110 | inception.Mixed_6c, 111 | inception.Mixed_6d, 112 | inception.Mixed_6e, 113 | ] 114 | self.blocks.append(nn.Sequential(*block2)) 115 | 116 | # Block 3: aux classifier to final avgpool 117 | if self.last_needed_block >= 3: 118 | block3 = [ 119 | inception.Mixed_7a, 120 | inception.Mixed_7b, 121 | inception.Mixed_7c, 122 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 123 | ] 124 | self.blocks.append(nn.Sequential(*block3)) 125 | 126 | for param in self.parameters(): 127 | param.requires_grad = requires_grad 128 | 129 | def forward(self, inp): 130 | """Get Inception feature maps 131 | 132 | Parameters 133 | ---------- 134 | inp : 
torch.autograd.Variable 135 | Input tensor of shape Bx3xHxW. Values are expected to be in 136 | range (0, 1) 137 | 138 | Returns 139 | ------- 140 | List of torch.autograd.Variable, corresponding to the selected output 141 | block, sorted ascending by index 142 | """ 143 | outp = [] 144 | x = inp 145 | 146 | if self.resize_input: 147 | x = F.interpolate(x, 148 | size=(299, 299), 149 | mode='bilinear', 150 | align_corners=False) 151 | 152 | if self.normalize_input: 153 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 154 | 155 | for idx, block in enumerate(self.blocks): 156 | x = block(x) 157 | if idx in self.output_blocks: 158 | outp.append(x) 159 | 160 | if idx == self.last_needed_block: 161 | break 162 | 163 | return outp 164 | 165 | 166 | def fid_inception_v3(): 167 | """Build pretrained Inception model for FID computation 168 | 169 | The Inception model for FID computation uses a different set of weights 170 | and has a slightly different structure than torchvision's Inception. 171 | 172 | This method first constructs torchvision's Inception and then patches the 173 | necessary parts that are different in the FID Inception model. 
174 | """ 175 | inception = models.inception_v3(num_classes=1008, 176 | aux_logits=False, 177 | pretrained=False) 178 | inception.Mixed_5b = FIDInceptionA(192, pool_features=32) 179 | inception.Mixed_5c = FIDInceptionA(256, pool_features=64) 180 | inception.Mixed_5d = FIDInceptionA(288, pool_features=64) 181 | inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128) 182 | inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160) 183 | inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160) 184 | inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192) 185 | inception.Mixed_7b = FIDInceptionE_1(1280) 186 | inception.Mixed_7c = FIDInceptionE_2(2048) 187 | 188 | state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True) 189 | inception.load_state_dict(state_dict) 190 | return inception 191 | 192 | 193 | class FIDInceptionA(models.inception.InceptionA): 194 | """InceptionA block patched for FID computation""" 195 | def __init__(self, in_channels, pool_features): 196 | super(FIDInceptionA, self).__init__(in_channels, pool_features) 197 | 198 | def forward(self, x): 199 | branch1x1 = self.branch1x1(x) 200 | 201 | branch5x5 = self.branch5x5_1(x) 202 | branch5x5 = self.branch5x5_2(branch5x5) 203 | 204 | branch3x3dbl = self.branch3x3dbl_1(x) 205 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 206 | branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl) 207 | 208 | # Patch: Tensorflow's average pool does not use the padded zero's in 209 | # its average calculation 210 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 211 | count_include_pad=False) 212 | branch_pool = self.branch_pool(branch_pool) 213 | 214 | outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool] 215 | return torch.cat(outputs, 1) 216 | 217 | 218 | class FIDInceptionC(models.inception.InceptionC): 219 | """InceptionC block patched for FID computation""" 220 | def __init__(self, in_channels, channels_7x7): 221 | super(FIDInceptionC, self).__init__(in_channels, channels_7x7) 
222 | 223 | def forward(self, x): 224 | branch1x1 = self.branch1x1(x) 225 | 226 | branch7x7 = self.branch7x7_1(x) 227 | branch7x7 = self.branch7x7_2(branch7x7) 228 | branch7x7 = self.branch7x7_3(branch7x7) 229 | 230 | branch7x7dbl = self.branch7x7dbl_1(x) 231 | branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl) 232 | branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl) 233 | branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl) 234 | branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl) 235 | 236 | # Patch: Tensorflow's average pool does not use the padded zero's in 237 | # its average calculation 238 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 239 | count_include_pad=False) 240 | branch_pool = self.branch_pool(branch_pool) 241 | 242 | outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool] 243 | return torch.cat(outputs, 1) 244 | 245 | 246 | class FIDInceptionE_1(models.inception.InceptionE): 247 | """First InceptionE block patched for FID computation""" 248 | def __init__(self, in_channels): 249 | super(FIDInceptionE_1, self).__init__(in_channels) 250 | 251 | def forward(self, x): 252 | branch1x1 = self.branch1x1(x) 253 | 254 | branch3x3 = self.branch3x3_1(x) 255 | branch3x3 = [ 256 | self.branch3x3_2a(branch3x3), 257 | self.branch3x3_2b(branch3x3), 258 | ] 259 | branch3x3 = torch.cat(branch3x3, 1) 260 | 261 | branch3x3dbl = self.branch3x3dbl_1(x) 262 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 263 | branch3x3dbl = [ 264 | self.branch3x3dbl_3a(branch3x3dbl), 265 | self.branch3x3dbl_3b(branch3x3dbl), 266 | ] 267 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 268 | 269 | # Patch: Tensorflow's average pool does not use the padded zero's in 270 | # its average calculation 271 | branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1, 272 | count_include_pad=False) 273 | branch_pool = self.branch_pool(branch_pool) 274 | 275 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 276 | return torch.cat(outputs, 1) 277 | 278 | 279 | class 
FIDInceptionE_2(models.inception.InceptionE): 280 | """Second InceptionE block patched for FID computation""" 281 | def __init__(self, in_channels): 282 | super(FIDInceptionE_2, self).__init__(in_channels) 283 | 284 | def forward(self, x): 285 | branch1x1 = self.branch1x1(x) 286 | 287 | branch3x3 = self.branch3x3_1(x) 288 | branch3x3 = [ 289 | self.branch3x3_2a(branch3x3), 290 | self.branch3x3_2b(branch3x3), 291 | ] 292 | branch3x3 = torch.cat(branch3x3, 1) 293 | 294 | branch3x3dbl = self.branch3x3dbl_1(x) 295 | branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) 296 | branch3x3dbl = [ 297 | self.branch3x3dbl_3a(branch3x3dbl), 298 | self.branch3x3dbl_3b(branch3x3dbl), 299 | ] 300 | branch3x3dbl = torch.cat(branch3x3dbl, 1) 301 | 302 | # Patch: The FID Inception model uses max pooling instead of average 303 | # pooling. This is likely an error in this specific Inception 304 | # implementation, as other Inception models use average pooling here 305 | # (which matches the description in the paper). 
306 | branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1) 307 | branch_pool = self.branch_pool(branch_pool) 308 | 309 | outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool] 310 | return torch.cat(outputs, 1) 311 | -------------------------------------------------------------------------------- /stylegan2/inception_ffhq.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rosinality/alias-free-gan-pytorch/f14d54ce2d973880b0c352614b2d63088c9026ae/stylegan2/inception_ffhq.pkl -------------------------------------------------------------------------------- /stylegan2/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from lpips import dist_model 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | super(PerceptualLoss, self).__init__() 17 | print('Setting up Perceptual loss...') 18 | self.use_gpu = use_gpu 19 | self.spatial = spatial 20 | self.gpu_ids = gpu_ids 21 | self.model = dist_model.DistModel() 22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 23 | print('...[%s] initialized'%self.model.name()) 24 | print('...Done') 25 | 26 | def forward(self, pred, target, normalize=False): 27 | """ 28 | Pred and target are Variables. 
29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | Inputs pred and target are Nx3xHxW 33 | Output pytorch Variable N long 34 | """ 35 | 36 | if normalize: 37 | target = 2 * target - 1 38 | pred = 2 * pred - 1 39 | 40 | return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 54 | 55 | def rgb2lab(in_img,mean_cent=False): 56 | from skimage import color 57 | img_lab = color.rgb2lab(in_img) 58 | if(mean_cent): 59 | img_lab[:,:,0] = img_lab[:,:,0]-50 60 | return img_lab 61 | 62 | def tensor2np(tensor_obj): 63 | # change dimension of a tensor object into a numpy array 64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 65 | 66 | def np2tensor(np_obj): 67 | # change dimenion of np array into tensor array 68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 69 | 70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 71 | # image tensor to lab tensor 72 | from skimage import color 73 | 74 | img = tensor2im(image_tensor) 75 | img_lab = color.rgb2lab(img) 76 | if(mc_only): 77 | img_lab[:,:,0] = img_lab[:,:,0]-50 78 | if(to_norm and not mc_only): 79 | img_lab[:,:,0] = img_lab[:,:,0]-50 80 | img_lab = img_lab/100. 81 | 82 | return np2tensor(img_lab) 83 | 84 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 85 | from skimage import color 86 | import warnings 87 | warnings.filterwarnings("ignore") 88 | 89 | lab = tensor2np(lab_tensor)*100. 
90 | lab[:,:,0] = lab[:,:,0]+50 91 | 92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 93 | if(return_inbnd): 94 | # convert back to lab, see if we match 95 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 96 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 98 | return (im2tensor(rgb_back),mask) 99 | else: 100 | return im2tensor(rgb_back) 101 | 102 | def rgb2lab(input): 103 | from skimage import color 104 | return color.rgb2lab(input / 255.) 105 | 106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 107 | image_numpy = image_tensor[0].cpu().float().numpy() 108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 109 | return image_numpy.astype(imtype) 110 | 111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 112 | return torch.Tensor((image / factor - cent) 113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 114 | 115 | def tensor2vec(vector_tensor): 116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 117 | 118 | def voc_ap(rec, prec, use_07_metric=False): 119 | """ ap = voc_ap(rec, prec, [use_07_metric]) 120 | Compute VOC AP given precision and recall. 121 | If use_07_metric is true, uses the 122 | VOC 07 11 point method (default:False). 123 | """ 124 | if use_07_metric: 125 | # 11 point metric 126 | ap = 0. 127 | for t in np.arange(0., 1.1, 0.1): 128 | if np.sum(rec >= t) == 0: 129 | p = 0 130 | else: 131 | p = np.max(prec[rec >= t]) 132 | ap = ap + p / 11. 
133 | else: 134 | # correct AP calculation 135 | # first append sentinel values at the end 136 | mrec = np.concatenate(([0.], rec, [1.])) 137 | mpre = np.concatenate(([0.], prec, [0.])) 138 | 139 | # compute the precision envelope 140 | for i in range(mpre.size - 1, 0, -1): 141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 142 | 143 | # to calculate area under PR curve, look for points 144 | # where X axis (recall) changes value 145 | i = np.where(mrec[1:] != mrec[:-1])[0] 146 | 147 | # and sum (\Delta recall) * prec 148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 149 | return ap 150 | 151 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 152 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 153 | image_numpy = image_tensor[0].cpu().float().numpy() 154 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 155 | return image_numpy.astype(imtype) 156 | 157 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 158 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 159 | return torch.Tensor((image / factor - cent) 160 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 161 | -------------------------------------------------------------------------------- /stylegan2/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | from pdb import set_trace as st 6 | from IPython import embed 7 | 8 | class BaseModel(): 9 | def __init__(self): 10 | pass; 11 | 12 | def name(self): 13 | return 'BaseModel' 14 | 15 | def initialize(self, use_gpu=True, gpu_ids=[0]): 16 | self.use_gpu = use_gpu 17 | self.gpu_ids = gpu_ids 18 | 19 | def forward(self): 20 | pass 21 | 22 | def get_image_paths(self): 23 | pass 24 | 25 | def optimize_parameters(self): 26 | pass 27 | 28 | def get_current_visuals(self): 29 | return self.input 30 | 31 | def 
get_current_errors(self): 32 | return {} 33 | 34 | def save(self, label): 35 | pass 36 | 37 | # helper saving function that can be used by subclasses 38 | def save_network(self, network, path, network_label, epoch_label): 39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 40 | save_path = os.path.join(path, save_filename) 41 | torch.save(network.state_dict(), save_path) 42 | 43 | # helper loading function that can be used by subclasses 44 | def load_network(self, network, network_label, epoch_label): 45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 46 | save_path = os.path.join(self.save_dir, save_filename) 47 | print('Loading network from %s'%save_path) 48 | network.load_state_dict(torch.load(save_path)) 49 | 50 | def update_learning_rate(): 51 | pass 52 | 53 | def get_image_paths(self): 54 | return self.image_paths 55 | 56 | def save_done(self, flag=False): 57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 59 | -------------------------------------------------------------------------------- /stylegan2/lpips/dist_model.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import numpy as np 6 | import torch 7 | from torch import nn 8 | import os 9 | from collections import OrderedDict 10 | from torch.autograd import Variable 11 | import itertools 12 | from .base_model import BaseModel 13 | from scipy.ndimage import zoom 14 | import fractions 15 | import functools 16 | import skimage.transform 17 | from tqdm import tqdm 18 | 19 | from IPython import embed 20 | 21 | from . 
import networks_basic as networks 22 | import lpips as util 23 | 24 | class DistModel(BaseModel): 25 | def name(self): 26 | return self.model_name 27 | 28 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 29 | use_gpu=True, printNet=False, spatial=False, 30 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 31 | ''' 32 | INPUTS 33 | model - ['net-lin'] for linearly calibrated network 34 | ['net'] for off-the-shelf network 35 | ['L2'] for L2 distance in Lab colorspace 36 | ['SSIM'] for ssim in RGB colorspace 37 | net - ['squeeze','alex','vgg'] 38 | model_path - if None, will look in weights/[NET_NAME].pth 39 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 40 | use_gpu - bool - whether or not to use a GPU 41 | printNet - bool - whether or not to print network architecture out 42 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 43 | spatial_shape - if given, output spatial shape. if None then spatial shape is determined automatically via spatial_factor (see below). 44 | spatial_factor - if given, specifies upsampling factor relative to the largest spatial extent of a convolutional layer. if None then resized to size of input images. 45 | spatial_order - spline order of filter for upsampling in spatial mode, by default 1 (bilinear). 
46 | is_train - bool - [True] for training mode 47 | lr - float - initial learning rate 48 | beta1 - float - initial momentum term for adam 49 | version - 0.1 for latest, 0.0 was original (with a bug) 50 | gpu_ids - int array - [0] by default, gpus to use 51 | ''' 52 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 53 | 54 | self.model = model 55 | self.net = net 56 | self.is_train = is_train 57 | self.spatial = spatial 58 | self.gpu_ids = gpu_ids 59 | self.model_name = '%s [%s]'%(model,net) 60 | 61 | if(self.model == 'net-lin'): # pretrained net + linear layer 62 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 63 | use_dropout=True, spatial=spatial, version=version, lpips=True) 64 | kw = {} 65 | if not use_gpu: 66 | kw['map_location'] = 'cpu' 67 | if(model_path is None): 68 | import inspect 69 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 70 | 71 | if(not is_train): 72 | print('Loading model from: %s'%model_path) 73 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 74 | 75 | elif(self.model=='net'): # pretrained network 76 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 77 | elif(self.model in ['L2','l2']): 78 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 79 | self.model_name = 'L2' 80 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 81 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 82 | self.model_name = 'SSIM' 83 | else: 84 | raise ValueError("Model [%s] not recognized." 
% self.model) 85 | 86 | self.parameters = list(self.net.parameters()) 87 | 88 | if self.is_train: # training mode 89 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 90 | self.rankLoss = networks.BCERankingLoss() 91 | self.parameters += list(self.rankLoss.net.parameters()) 92 | self.lr = lr 93 | self.old_lr = lr 94 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 95 | else: # test mode 96 | self.net.eval() 97 | 98 | if(use_gpu): 99 | self.net.to(gpu_ids[0]) 100 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 101 | if(self.is_train): 102 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 103 | 104 | if(printNet): 105 | print('---------- Networks initialized -------------') 106 | networks.print_network(self.net) 107 | print('-----------------------------------------------') 108 | 109 | def forward(self, in0, in1, retPerLayer=False): 110 | ''' Function computes the distance between image patches in0 and in1 111 | INPUTS 112 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 113 | OUTPUT 114 | computed distances between in0 and in1 115 | ''' 116 | 117 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 118 | 119 | # ***** TRAINING FUNCTIONS ***** 120 | def optimize_parameters(self): 121 | self.forward_train() 122 | self.optimizer_net.zero_grad() 123 | self.backward_train() 124 | self.optimizer_net.step() 125 | self.clamp_weights() 126 | 127 | def clamp_weights(self): 128 | for module in self.net.modules(): 129 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 130 | module.weight.data = torch.clamp(module.weight.data,min=0) 131 | 132 | def set_input(self, data): 133 | self.input_ref = data['ref'] 134 | self.input_p0 = data['p0'] 135 | self.input_p1 = data['p1'] 136 | self.input_judge = data['judge'] 137 | 138 | if(self.use_gpu): 139 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 140 | 
self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 141 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 142 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 143 | 144 | self.var_ref = Variable(self.input_ref,requires_grad=True) 145 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 146 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 147 | 148 | def forward_train(self): # run forward pass 149 | # print(self.net.module.scaling_layer.shift) 150 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 151 | 152 | self.d0 = self.forward(self.var_ref, self.var_p0) 153 | self.d1 = self.forward(self.var_ref, self.var_p1) 154 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 155 | 156 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 157 | 158 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 159 | 160 | return self.loss_total 161 | 162 | def backward_train(self): 163 | torch.mean(self.loss_total).backward() 164 | 165 | def compute_accuracy(self,d0,d1,judge): 166 | ''' d0, d1 are Variables, judge is a Tensor ''' 167 | d1_lt_d0 = (d1 %f' % (type,self.old_lr, lr)) 210 | self.old_lr = lr 211 | 212 | def score_2afc_dataset(data_loader, func, name=''): 213 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 214 | distance function 'func' in dataset 'data_loader' 215 | INPUTS 216 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 217 | func - callable distance function - calling d=func(in0,in1) should take 2 218 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 219 | OUTPUTS 220 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 221 | [1] - dictionary with following elements 222 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 223 | gts - N array in 
[0,1], preferred patch selected by human evaluators 224 | (closer to "0" for left patch p0, "1" for right patch p1, 225 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 226 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 227 | CONSTS 228 | N - number of test triplets in data_loader 229 | ''' 230 | 231 | d0s = [] 232 | d1s = [] 233 | gts = [] 234 | 235 | for data in tqdm(data_loader.load_data(), desc=name): 236 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist() 237 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist() 238 | gts+=data['judge'].cpu().numpy().flatten().tolist() 239 | 240 | d0s = np.array(d0s) 241 | d1s = np.array(d1s) 242 | gts = np.array(gts) 243 | scores = (d0s 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | if bias: 39 | grad_bias = grad_input.sum(dim).detach() 40 | 41 | else: 42 | grad_bias = empty 43 | 44 | return grad_input, grad_bias 45 | 46 | @staticmethod 47 | def backward(ctx, gradgrad_input, gradgrad_bias): 48 | out, = ctx.saved_tensors 49 | gradgrad_out = fused.fused_bias_act( 50 | gradgrad_input.contiguous(), 51 | gradgrad_bias, 52 | out, 53 | 3, 54 | 1, 55 | ctx.negative_slope, 56 | ctx.scale, 57 | ) 58 | 59 | return gradgrad_out, None, None, None, None 60 | 61 | 62 | class FusedLeakyReLUFunction(Function): 63 | @staticmethod 64 | def forward(ctx, input, bias, negative_slope, scale): 65 | empty = input.new_empty(0) 66 | 67 | ctx.bias = bias is not None 68 | 69 | if bias is None: 70 | bias = empty 71 | 72 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 73 | ctx.save_for_backward(out) 74 | ctx.negative_slope = negative_slope 75 | ctx.scale = scale 76 | 77 | return out 78 | 79 | @staticmethod 80 | def backward(ctx, grad_output): 81 | out, = ctx.saved_tensors 82 | 83 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 84 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale 85 | ) 
86 | 87 | if not ctx.bias: 88 | grad_bias = None 89 | 90 | return grad_input, grad_bias, None, None 91 | 92 | 93 | class FusedLeakyReLU(nn.Module): 94 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 95 | super().__init__() 96 | 97 | if bias: 98 | self.bias = nn.Parameter(torch.zeros(channel)) 99 | 100 | else: 101 | self.bias = None 102 | 103 | self.negative_slope = negative_slope 104 | self.scale = scale 105 | 106 | def forward(self, input): 107 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 108 | 109 | 110 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 111 | if input.device.type == "cpu": 112 | if bias is not None: 113 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 114 | return ( 115 | F.leaky_relu( 116 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 117 | ) 118 | * scale 119 | ) 120 | 121 | else: 122 | return F.leaky_relu(input, negative_slope=0.2) * scale 123 | 124 | else: 125 | return FusedLeakyReLUFunction.apply( 126 | input.contiguous(), bias, negative_slope, scale 127 | ) 128 | -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 6 | const torch::Tensor &bias, 7 | const torch::Tensor &refer, int act, int grad, 8 | float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) \ 11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 12 | #define CHECK_CONTIGUOUS(x) \ 13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 14 | #define CHECK_INPUT(x) \ 15 | CHECK_CUDA(x); \ 16 | CHECK_CONTIGUOUS(x) 17 | 18 | torch::Tensor fused_bias_act(const torch::Tensor &input, 19 | const torch::Tensor &bias, 20 | const torch::Tensor &refer, int act, int grad, 21 | float alpha, float scale) { 22 | 
CHECK_INPUT(input); 23 | CHECK_INPUT(bias); 24 | 25 | at::DeviceGuard guard(input.device()); 26 | 27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 32 | } -------------------------------------------------------------------------------- /stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | 15 | #include 16 | #include 17 | 18 | template 19 | static __global__ void 20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b, 21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha, 22 | scalar_t scale, int loop_x, int size_x, int step_b, 23 | int size_b, int use_bias, int use_ref) { 24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 25 | 26 | scalar_t zero = 0.0; 27 | 28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; 29 | loop_idx++, xi += blockDim.x) { 30 | scalar_t x = p_x[xi]; 31 | 32 | if (use_bias) { 33 | x += p_b[(xi / step_b) % size_b]; 34 | } 35 | 36 | scalar_t ref = use_ref ? p_ref[xi] : zero; 37 | 38 | scalar_t y; 39 | 40 | switch (act * 10 + grad) { 41 | default: 42 | case 10: 43 | y = x; 44 | break; 45 | case 11: 46 | y = x; 47 | break; 48 | case 12: 49 | y = 0.0; 50 | break; 51 | 52 | case 30: 53 | y = (x > 0.0) ? x : x * alpha; 54 | break; 55 | case 31: 56 | y = (ref > 0.0) ? 
x : x * alpha; 57 | break; 58 | case 32: 59 | y = 0.0; 60 | break; 61 | } 62 | 63 | out[xi] = y * scale; 64 | } 65 | } 66 | 67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input, 68 | const torch::Tensor &bias, 69 | const torch::Tensor &refer, int act, int grad, 70 | float alpha, float scale) { 71 | int curDevice = -1; 72 | cudaGetDevice(&curDevice); 73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 74 | 75 | auto x = input.contiguous(); 76 | auto b = bias.contiguous(); 77 | auto ref = refer.contiguous(); 78 | 79 | int use_bias = b.numel() ? 1 : 0; 80 | int use_ref = ref.numel() ? 1 : 0; 81 | 82 | int size_x = x.numel(); 83 | int size_b = b.numel(); 84 | int step_b = 1; 85 | 86 | for (int i = 1 + 1; i < x.dim(); i++) { 87 | step_b *= x.size(i); 88 | } 89 | 90 | int loop_x = 4; 91 | int block_size = 4 * 32; 92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 93 | 94 | auto y = torch::empty_like(x); 95 | 96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 97 | x.scalar_type(), "fused_bias_act_kernel", [&] { 98 | fused_bias_act_kernel<<>>( 99 | y.data_ptr(), x.data_ptr(), 100 | b.data_ptr(), ref.data_ptr(), act, grad, alpha, 101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref); 102 | }); 103 | 104 | return y; 105 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input, 5 | const torch::Tensor &kernel, int up_x, int up_y, 6 | int down_x, int down_y, int pad_x0, int pad_x1, 7 | int pad_y0, int pad_y1, bool flip, float gain); 8 | 9 | #define CHECK_CUDA(x) \ 10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) \ 12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 13 | #define CHECK_INPUT(x) \ 14 | CHECK_CUDA(x); \ 15 | CHECK_CONTIGUOUS(x) 16 | 17 | torch::Tensor 
upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel, 18 | int up_x, int up_y, int down_x, int down_y, int pad_x0, 19 | int pad_x1, int pad_y0, int pad_y1, bool flip, 20 | float gain) { 21 | CHECK_INPUT(input); 22 | CHECK_INPUT(kernel); 23 | 24 | at::DeviceGuard guard(input.device()); 25 | 26 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, 27 | pad_y0, pad_y1, flip, gain); 28 | } 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 32 | } -------------------------------------------------------------------------------- /stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import os 3 | 4 | import torch 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | upfirdn2d_op = load( 12 | "upfirdn2d", 13 | sources=[ 14 | os.path.join(module_path, "upfirdn2d.cpp"), 15 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class UpFirDn2dBackward(Function): 21 | @staticmethod 22 | def forward( 23 | ctx, 24 | grad_output, 25 | kernel, 26 | up, 27 | down, 28 | pad, 29 | g_pad, 30 | in_size, 31 | out_size, 32 | flip_filter, 33 | gain, 34 | ): 35 | 36 | up_x, up_y = up 37 | down_x, down_y = down 38 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 39 | 40 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 41 | 42 | grad_input = upfirdn2d_op.upfirdn2d( 43 | grad_output, 44 | kernel, 45 | down_x, 46 | down_y, 47 | up_x, 48 | up_y, 49 | g_pad_x0, 50 | g_pad_x1, 51 | g_pad_y0, 52 | g_pad_y1, 53 | flip_filter, 54 | gain, 55 | ) 56 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 57 | 58 | ctx.save_for_backward(kernel) 59 | 60 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 61 | 62 | ctx.up_x = up_x 63 
| ctx.up_y = up_y 64 | ctx.down_x = down_x 65 | ctx.down_y = down_y 66 | ctx.pad_x0 = pad_x0 67 | ctx.pad_x1 = pad_x1 68 | ctx.pad_y0 = pad_y0 69 | ctx.pad_y1 = pad_y1 70 | ctx.in_size = in_size 71 | ctx.out_size = out_size 72 | ctx.flip_filter = flip_filter 73 | ctx.gain = gain 74 | 75 | return grad_input 76 | 77 | @staticmethod 78 | def backward(ctx, gradgrad_input): 79 | (kernel,) = ctx.saved_tensors 80 | 81 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 82 | 83 | gradgrad_out = upfirdn2d_op.upfirdn2d( 84 | gradgrad_input, 85 | kernel, 86 | ctx.up_x, 87 | ctx.up_y, 88 | ctx.down_x, 89 | ctx.down_y, 90 | ctx.pad_x0, 91 | ctx.pad_x1, 92 | ctx.pad_y0, 93 | ctx.pad_y1, 94 | not ctx.flip_filter, 95 | ctx.gain, 96 | ) 97 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 98 | gradgrad_out = gradgrad_out.view( 99 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 100 | ) 101 | 102 | return gradgrad_out, None, None, None, None, None, None, None, None, None 103 | 104 | 105 | class UpFirDn2d(Function): 106 | @staticmethod 107 | def forward(ctx, input, kernel, up, down, pad, flip_filter, gain): 108 | up_x, up_y = up 109 | down_x, down_y = down 110 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 111 | 112 | kernel_h, kernel_w = kernel.shape 113 | batch, channel, in_h, in_w = input.shape 114 | ctx.in_size = input.shape 115 | 116 | input = input.reshape(-1, in_h, in_w, 1) 117 | 118 | ctx.save_for_backward(kernel) 119 | 120 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 121 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 122 | ctx.out_size = (out_h, out_w) 123 | 124 | ctx.up = (up_x, up_y) 125 | ctx.down = (down_x, down_y) 126 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 127 | 128 | g_pad_x0 = kernel_w - pad_x0 - 1 129 | g_pad_y0 = kernel_h - pad_y0 - 1 130 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 131 | g_pad_y1 = 
in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 132 | 133 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 134 | ctx.flip_filter = flip_filter 135 | ctx.gain = gain 136 | 137 | out = upfirdn2d_op.upfirdn2d( 138 | input, 139 | kernel, 140 | up_x, 141 | up_y, 142 | down_x, 143 | down_y, 144 | pad_x0, 145 | pad_x1, 146 | pad_y0, 147 | pad_y1, 148 | flip_filter, 149 | gain, 150 | ) 151 | # out = out.view(major, out_h, out_w, minor) 152 | out = out.view(-1, channel, out_h, out_w) 153 | 154 | return out 155 | 156 | @staticmethod 157 | def backward(ctx, grad_output): 158 | (kernel,) = ctx.saved_tensors 159 | 160 | grad_input = None 161 | 162 | if ctx.needs_input_grad[0]: 163 | grad_input = UpFirDn2dBackward.apply( 164 | grad_output, 165 | kernel, 166 | ctx.up, 167 | ctx.down, 168 | ctx.pad, 169 | ctx.g_pad, 170 | ctx.in_size, 171 | ctx.out_size, 172 | not ctx.flip_filter, 173 | ctx.gain, 174 | ) 175 | 176 | return grad_input, None, None, None, None, None, None 177 | 178 | 179 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0), flip_filter=False, gain=1): 180 | if not isinstance(up, abc.Iterable): 181 | up = (up, up) 182 | 183 | if not isinstance(down, abc.Iterable): 184 | down = (down, down) 185 | 186 | if len(pad) == 2: 187 | pad = (pad[0], pad[1], pad[0], pad[1]) 188 | 189 | if input.device.type == "cpu": 190 | out = upfirdn2d_native( 191 | input, kernel, *up, *down, *pad, flip_filter=flip_filter, gain=gain 192 | ) 193 | 194 | else: 195 | gain = gain ** (kernel.ndim / 2) 196 | 197 | if kernel.ndim == 1: 198 | out = UpFirDn2d.apply( 199 | input, 200 | kernel.unsqueeze(0), 201 | (up[0], 1), 202 | (down[0], 1), 203 | (*pad[:2], 0, 0), 204 | flip_filter, 205 | gain, 206 | ) 207 | out = UpFirDn2d.apply( 208 | out, 209 | kernel.unsqueeze(1), 210 | (1, up[1]), 211 | (1, down[1]), 212 | (0, 0, *pad[2:]), 213 | flip_filter, 214 | gain, 215 | ) 216 | 217 | else: 218 | out = UpFirDn2d.apply(input, kernel, up, down, pad, flip_filter, gain) 219 | 220 | return out 221 | 
def upfirdn2d_native(
    input,
    kernel,
    up_x,
    up_y,
    down_x,
    down_y,
    pad_x0,
    pad_x1,
    pad_y0,
    pad_y1,
    flip_filter=False,
    gain=1,
):
    """Pure-PyTorch fallback for upfirdn2d: upsample, pad, FIR-filter, downsample.

    Every ``(N, channel)`` plane of ``input`` is processed independently:

    1. zero-insertion upsampling by ``(up_x, up_y)`` -- ``up - 1`` zeros are
       inserted after every sample along each axis;
    2. padding by ``(pad_x0, pad_x1, pad_y0, pad_y1)`` -- positive amounts add
       zeros, negative amounts crop;
    3. valid (no implicit padding) filtering with ``kernel``;
    4. downsampling by keeping every ``down_y``-th row and ``down_x``-th column.

    Args:
        input: 4-D tensor ``(N, channel, in_h, in_w)``.
        kernel: FIR taps. A 2-D kernel is applied as-is; any other kernel is
            treated as a 1-D tap vector applied vertically and then
            horizontally (separable filtering).
        up_x, up_y: upsampling factors.
        down_x, down_y: downsampling factors.
        pad_x0, pad_x1, pad_y0, pad_y1: left/right/top/bottom padding applied
            after upsampling (negative values crop).
        flip_filter: when False the kernel is flipped so that ``F.conv2d``
            (a cross-correlation) computes a true convolution; when True the
            kernel is used unflipped.
        gain: overall scale of the result. A 2-D kernel is scaled by ``gain``;
            a 1-D kernel by ``gain ** 0.5`` per pass, i.e. ``gain`` overall.

    Returns:
        Tensor of shape ``(-1, channel, out_h, out_w)`` where
        ``out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y``
        and ``out_w`` is defined analogously.
    """
    _, channel, in_h, in_w = input.shape
    # Fold every (sample, channel) pair into its own single-channel plane.
    planes = input.reshape(-1, in_h, in_w, 1)
    minor = planes.shape[3]  # always 1 after the reshape above

    # 1) Zero-insertion upsampling: give each pixel its own (up_y, up_x) cell
    #    and zero-fill everything except the cell's top-left corner.
    up_h = in_h * up_y
    up_w = in_w * up_x
    x = planes.view(-1, in_h, 1, in_w, 1, minor)
    x = F.pad(x, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
    x = x.view(-1, up_h, up_w, minor)

    # 2) Padding. F.pad only receives the non-negative parts; negative
    #    amounts are realised as a crop of the padded tensor.
    grow = [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
    x = F.pad(x, grow)
    crop_top, crop_bottom = max(-pad_y0, 0), max(-pad_y1, 0)
    crop_left, crop_right = max(-pad_x0, 0), max(-pad_x1, 0)
    x = x[
        :,
        crop_top : x.shape[1] - crop_bottom,
        crop_left : x.shape[2] - crop_right,
        :,
    ]

    padded_h = up_h + pad_y0 + pad_y1
    padded_w = up_w + pad_x0 + pad_x1
    # NCHW layout with one channel per plane, as F.conv2d expects.
    x = x.permute(0, 3, 1, 2).reshape([-1, 1, padded_h, padded_w])

    # 3) FIR filtering.
    taps = kernel * (gain ** (kernel.ndim / 2))
    if not flip_filter:
        # conv2d correlates; flipping every axis turns that into convolution.
        taps = torch.flip(taps, list(range(kernel.ndim)))

    if kernel.ndim != 2:
        # Separable path: the same tap vector vertically, then horizontally.
        kernel_h = kernel_w = taps.shape[0]
        x = F.conv2d(x, taps.view(1, 1, kernel_h, 1))
        x = F.conv2d(x, taps.view(1, 1, 1, kernel_h))
    else:
        kernel_h, kernel_w = taps.shape
        x = F.conv2d(x, taps.view(1, 1, kernel_h, kernel_w))

    valid_h = padded_h - kernel_h + 1
    valid_w = padded_w - kernel_w + 1
    x = x.reshape(-1, minor, valid_h, valid_w).permute(0, 2, 3, 1)

    # 4) Downsampling by striding.
    x = x[:, ::down_y, ::down_x, :]

    out_h = (padded_h - kernel_h + down_y) // down_y
    out_w = (padded_w - kernel_w + down_x) // down_x
    return x.view(-1, channel, out_h, out_w)


# --------------------------------------------------------------------------------
# /stylegan2/ppl.py:
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import lpips 9 | from model import Generator 10 | 11 | 12 | def normalize(x): 13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 14 | 15 | 16 | def slerp(a, b, t): 17 | a = normalize(a) 18 | b = normalize(b) 19 | d = (a * b).sum(-1, keepdim=True) 20 | p = t * torch.acos(d) 21 | c = normalize(b - d * a) 22 | d = a * torch.cos(p) + c * torch.sin(p) 23 | 24 | return normalize(d) 25 | 26 | 27 | def lerp(a, b, t): 28 | return a + (b - a) * t 29 | 30 | 31 | if __name__ == "__main__": 32 | device = "cuda" 33 | 34 | parser = argparse.ArgumentParser(description="Perceptual Path Length calculator") 35 | 36 | parser.add_argument( 37 | "--space", choices=["z", "w"], help="space that PPL calculated with" 38 | ) 39 | parser.add_argument( 40 | "--batch", type=int, default=64, help="batch size for the models" 41 | ) 42 | parser.add_argument( 43 | "--n_sample", 44 | type=int, 45 | default=5000, 46 | help="number of the samples for calculating PPL", 47 | ) 48 | parser.add_argument( 49 | "--size", type=int, default=256, help="output image sizes of the generator" 50 | ) 51 | parser.add_argument( 52 | "--eps", type=float, default=1e-4, help="epsilon for numerical stability" 53 | ) 54 | parser.add_argument( 55 | "--crop", action="store_true", help="apply center crop to the images" 56 | ) 57 | parser.add_argument( 58 | "--sampling", 59 | default="end", 60 | choices=["end", "full"], 61 | help="set endpoint sampling method", 62 | ) 63 | parser.add_argument( 64 | "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints" 65 | ) 66 | 67 | args = parser.parse_args() 68 | 69 | latent_dim = 512 70 | 71 | ckpt = torch.load(args.ckpt) 72 | 73 | g = Generator(args.size, latent_dim, 8).to(device) 74 | g.load_state_dict(ckpt["g_ema"]) 75 | g.eval() 76 | 77 | percept = 
lpips.PerceptualLoss( 78 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 79 | ) 80 | 81 | distances = [] 82 | 83 | n_batch = args.n_sample // args.batch 84 | resid = args.n_sample - (n_batch * args.batch) 85 | batch_sizes = [args.batch] * n_batch + [resid] 86 | 87 | with torch.no_grad(): 88 | for batch in tqdm(batch_sizes): 89 | noise = g.make_noise() 90 | 91 | inputs = torch.randn([batch * 2, latent_dim], device=device) 92 | if args.sampling == "full": 93 | lerp_t = torch.rand(batch, device=device) 94 | else: 95 | lerp_t = torch.zeros(batch, device=device) 96 | 97 | if args.space == "w": 98 | latent = g.get_latent(inputs) 99 | latent_t0, latent_t1 = latent[::2], latent[1::2] 100 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 101 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 102 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 103 | 104 | image, _ = g([latent_e], input_is_latent=True, noise=noise) 105 | 106 | if args.crop: 107 | c = image.shape[2] // 8 108 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 109 | 110 | factor = image.shape[2] // 256 111 | 112 | if factor > 1: 113 | image = F.interpolate( 114 | image, size=(256, 256), mode="bilinear", align_corners=False 115 | ) 116 | 117 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / ( 118 | args.eps ** 2 119 | ) 120 | distances.append(dist.to("cpu").numpy()) 121 | 122 | distances = np.concatenate(distances, 0) 123 | 124 | lo = np.percentile(distances, 1, interpolation="lower") 125 | hi = np.percentile(distances, 99, interpolation="higher") 126 | filtered_dist = np.extract( 127 | np.logical_and(lo <= distances, distances <= hi), distances 128 | ) 129 | 130 | print("ppl:", filtered_dist.mean()) 131 | -------------------------------------------------------------------------------- /stylegan2/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io 
import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | import lmdb 8 | from tqdm import tqdm 9 | from torchvision import datasets 10 | from torchvision.transforms import functional as trans_fn 11 | 12 | 13 | def resize_and_convert(img, size, resample, quality=100): 14 | img = trans_fn.resize(img, size, resample) 15 | img = trans_fn.center_crop(img, size) 16 | buffer = BytesIO() 17 | img.save(buffer, format="jpeg", quality=quality) 18 | val = buffer.getvalue() 19 | 20 | return val 21 | 22 | 23 | def resize_multiple( 24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100 25 | ): 26 | imgs = [] 27 | 28 | for size in sizes: 29 | imgs.append(resize_and_convert(img, size, resample, quality)) 30 | 31 | return imgs 32 | 33 | 34 | def resize_worker(img_file, sizes, resample): 35 | i, file = img_file 36 | img = Image.open(file) 37 | img = img.convert("RGB") 38 | out = resize_multiple(img, sizes=sizes, resample=resample) 39 | 40 | return i, out 41 | 42 | 43 | def prepare( 44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS 45 | ): 46 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 47 | 48 | files = sorted(dataset.imgs, key=lambda x: x[0]) 49 | files = [(i, file) for i, (file, label) in enumerate(files)] 50 | total = 0 51 | 52 | with multiprocessing.Pool(n_worker) as pool: 53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 54 | for size, img in zip(sizes, imgs): 55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 56 | 57 | with env.begin(write=True) as txn: 58 | txn.put(key, img) 59 | 60 | total += 1 61 | 62 | with env.begin(write=True) as txn: 63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 64 | 65 | 66 | if __name__ == "__main__": 67 | parser = argparse.ArgumentParser(description="Preprocess images for model training") 68 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset") 69 | 
parser.add_argument( 70 | "--size", 71 | type=str, 72 | default="128,256,512,1024", 73 | help="resolutions of images for the dataset", 74 | ) 75 | parser.add_argument( 76 | "--n_worker", 77 | type=int, 78 | default=8, 79 | help="number of workers for preparing dataset", 80 | ) 81 | parser.add_argument( 82 | "--resample", 83 | type=str, 84 | default="lanczos", 85 | help="resampling methods for resizing images", 86 | ) 87 | parser.add_argument("path", type=str, help="path to the image dataset") 88 | 89 | args = parser.parse_args() 90 | 91 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 92 | resample = resample_map[args.resample] 93 | 94 | sizes = [int(s.strip()) for s in args.size.split(",")] 95 | 96 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 97 | 98 | imgset = datasets.ImageFolder(args.path) 99 | 100 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 101 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 102 | -------------------------------------------------------------------------------- /stylegan2/projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import torch 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torchvision import transforms 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import lpips 13 | from model import Generator 14 | 15 | 16 | def noise_regularize(noises): 17 | loss = 0 18 | 19 | for noise in noises: 20 | size = noise.shape[2] 21 | 22 | while True: 23 | loss = ( 24 | loss 25 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 26 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 27 | ) 28 | 29 | if size <= 8: 30 | break 31 | 32 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2]) 33 | noise = noise.mean([3, 5]) 34 | size //= 2 35 | 36 | return loss 37 | 38 | 39 | def 
noise_normalize_(noises): 40 | for noise in noises: 41 | mean = noise.mean() 42 | std = noise.std() 43 | 44 | noise.data.add_(-mean).div_(std) 45 | 46 | 47 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 48 | lr_ramp = min(1, (1 - t) / rampdown) 49 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 50 | lr_ramp = lr_ramp * min(1, t / rampup) 51 | 52 | return initial_lr * lr_ramp 53 | 54 | 55 | def latent_noise(latent, strength): 56 | noise = torch.randn_like(latent) * strength 57 | 58 | return latent + noise 59 | 60 | 61 | def make_image(tensor): 62 | return ( 63 | tensor.detach() 64 | .clamp_(min=-1, max=1) 65 | .add(1) 66 | .div_(2) 67 | .mul(255) 68 | .type(torch.uint8) 69 | .permute(0, 2, 3, 1) 70 | .to("cpu") 71 | .numpy() 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | device = "cuda" 77 | 78 | parser = argparse.ArgumentParser( 79 | description="Image projector to the generator latent spaces" 80 | ) 81 | parser.add_argument( 82 | "--ckpt", type=str, required=True, help="path to the model checkpoint" 83 | ) 84 | parser.add_argument( 85 | "--size", type=int, default=256, help="output image sizes of the generator" 86 | ) 87 | parser.add_argument( 88 | "--lr_rampup", 89 | type=float, 90 | default=0.05, 91 | help="duration of the learning rate warmup", 92 | ) 93 | parser.add_argument( 94 | "--lr_rampdown", 95 | type=float, 96 | default=0.25, 97 | help="duration of the learning rate decay", 98 | ) 99 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate") 100 | parser.add_argument( 101 | "--noise", type=float, default=0.05, help="strength of the noise level" 102 | ) 103 | parser.add_argument( 104 | "--noise_ramp", 105 | type=float, 106 | default=0.75, 107 | help="duration of the noise level decay", 108 | ) 109 | parser.add_argument("--step", type=int, default=1000, help="optimize iterations") 110 | parser.add_argument( 111 | "--noise_regularize", 112 | type=float, 113 | default=1e5, 114 | help="weight of the noise 
regularization", 115 | ) 116 | parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss") 117 | parser.add_argument( 118 | "--w_plus", 119 | action="store_true", 120 | help="allow to use distinct latent codes to each layers", 121 | ) 122 | parser.add_argument( 123 | "files", metavar="FILES", nargs="+", help="path to image files to be projected" 124 | ) 125 | 126 | args = parser.parse_args() 127 | 128 | n_mean_latent = 10000 129 | 130 | resize = min(args.size, 256) 131 | 132 | transform = transforms.Compose( 133 | [ 134 | transforms.Resize(resize), 135 | transforms.CenterCrop(resize), 136 | transforms.ToTensor(), 137 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 138 | ] 139 | ) 140 | 141 | imgs = [] 142 | 143 | for imgfile in args.files: 144 | img = transform(Image.open(imgfile).convert("RGB")) 145 | imgs.append(img) 146 | 147 | imgs = torch.stack(imgs, 0).to(device) 148 | 149 | g_ema = Generator(args.size, 512, 8) 150 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 151 | g_ema.eval() 152 | g_ema = g_ema.to(device) 153 | 154 | with torch.no_grad(): 155 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 156 | latent_out = g_ema.style(noise_sample) 157 | 158 | latent_mean = latent_out.mean(0) 159 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 160 | 161 | percept = lpips.PerceptualLoss( 162 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda") 163 | ) 164 | 165 | noises_single = g_ema.make_noise() 166 | noises = [] 167 | for noise in noises_single: 168 | noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_()) 169 | 170 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1) 171 | 172 | if args.w_plus: 173 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) 174 | 175 | latent_in.requires_grad = True 176 | 177 | for noise in noises: 178 | noise.requires_grad = True 179 | 180 | optimizer = optim.Adam([latent_in] + 
noises, lr=args.lr) 181 | 182 | pbar = tqdm(range(args.step)) 183 | latent_path = [] 184 | 185 | for i in pbar: 186 | t = i / args.step 187 | lr = get_lr(t, args.lr) 188 | optimizer.param_groups[0]["lr"] = lr 189 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 190 | latent_n = latent_noise(latent_in, noise_strength.item()) 191 | 192 | img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises) 193 | 194 | batch, channel, height, width = img_gen.shape 195 | 196 | if height > 256: 197 | factor = height // 256 198 | 199 | img_gen = img_gen.reshape( 200 | batch, channel, height // factor, factor, width // factor, factor 201 | ) 202 | img_gen = img_gen.mean([3, 5]) 203 | 204 | p_loss = percept(img_gen, imgs).sum() 205 | n_loss = noise_regularize(noises) 206 | mse_loss = F.mse_loss(img_gen, imgs) 207 | 208 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss 209 | 210 | optimizer.zero_grad() 211 | loss.backward() 212 | optimizer.step() 213 | 214 | noise_normalize_(noises) 215 | 216 | if (i + 1) % 100 == 0: 217 | latent_path.append(latent_in.detach().clone()) 218 | 219 | pbar.set_description( 220 | ( 221 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};" 222 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" 223 | ) 224 | ) 225 | 226 | img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises) 227 | 228 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt" 229 | 230 | img_ar = make_image(img_gen) 231 | 232 | result_file = {} 233 | for i, input_name in enumerate(args.files): 234 | noise_single = [] 235 | for noise in noises: 236 | noise_single.append(noise[i : i + 1]) 237 | 238 | result_file[input_name] = { 239 | "img": img_gen[i], 240 | "latent": latent_in[i], 241 | "noise": noise_single, 242 | } 243 | 244 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png" 245 | pil_img = Image.fromarray(img_ar[i]) 246 | pil_img.save(img_name) 
247 | 248 | torch.save(result_file, filename) 249 | -------------------------------------------------------------------------------- /stylegan2/sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /stylegan2/swagan.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 12 | from model import ModulatedConv2d, StyledConv, ConstantInput, PixelNorm, Upsample, Downsample, Blur, EqualLinear, ConvLayer 13 | 14 | def get_haar_wavelet(in_channels): 15 | haar_wav_l = 1 / (2 ** 0.5) * torch.ones(1, 2) 16 | haar_wav_h = 1 / (2 ** 0.5) * torch.ones(1, 2) 17 | haar_wav_h[0, 0] = -1 * haar_wav_h[0, 0] 18 | 19 | haar_wav_ll = haar_wav_l.T * haar_wav_l 20 | haar_wav_lh = haar_wav_h.T * haar_wav_l 21 | haar_wav_hl = haar_wav_l.T * haar_wav_h 22 | haar_wav_hh = haar_wav_h.T * haar_wav_h 23 | 24 | return haar_wav_ll, haar_wav_lh, haar_wav_hl, haar_wav_hh 25 | 26 | 27 | class HaarTransform(nn.Module): 28 | def __init__(self, in_channels): 29 | super().__init__() 30 | 31 | ll, lh, hl, hh = get_haar_wavelet(in_channels) 32 | 33 | self.register_buffer('ll', ll) 34 | self.register_buffer('lh', lh) 35 | self.register_buffer('hl', hl) 36 | self.register_buffer('hh', hh) 37 | 38 | def forward(self, input): 39 | ll = upfirdn2d(input, self.ll, down=2) 40 | lh = upfirdn2d(input, self.lh, down=2) 41 | hl = upfirdn2d(input, self.hl, down=2) 42 | hh = upfirdn2d(input, self.hh, down=2) 43 | 44 | return torch.cat((ll, lh, hl, hh), 1) 45 | 46 | class InverseHaarTransform(nn.Module): 47 | def __init__(self, in_channels): 48 | super().__init__() 49 | 50 
| ll, lh, hl, hh = get_haar_wavelet(in_channels) 51 | 52 | self.register_buffer('ll', ll) 53 | self.register_buffer('lh', -lh) 54 | self.register_buffer('hl', -hl) 55 | self.register_buffer('hh', hh) 56 | 57 | def forward(self, input): 58 | ll, lh, hl, hh = input.chunk(4, 1) 59 | ll = upfirdn2d(ll, self.ll, up=2, pad=(1, 0, 1, 0)) 60 | lh = upfirdn2d(lh, self.lh, up=2, pad=(1, 0, 1, 0)) 61 | hl = upfirdn2d(hl, self.hl, up=2, pad=(1, 0, 1, 0)) 62 | hh = upfirdn2d(hh, self.hh, up=2, pad=(1, 0, 1, 0)) 63 | 64 | return ll + lh + hl + hh 65 | 66 | 67 | class ToRGB(nn.Module): 68 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 69 | super().__init__() 70 | 71 | if upsample: 72 | self.iwt = InverseHaarTransform(3) 73 | self.upsample = Upsample(blur_kernel) 74 | self.dwt = HaarTransform(3) 75 | 76 | self.conv = ModulatedConv2d(in_channel, 3 * 4, 1, style_dim, demodulate=False) 77 | self.bias = nn.Parameter(torch.zeros(1, 3 * 4, 1, 1)) 78 | 79 | def forward(self, input, style, skip=None): 80 | out = self.conv(input, style) 81 | out = out + self.bias 82 | 83 | if skip is not None: 84 | skip = self.iwt(skip) 85 | skip = self.upsample(skip) 86 | skip = self.dwt(skip) 87 | 88 | out = out + skip 89 | 90 | return out 91 | 92 | 93 | class Generator(nn.Module): 94 | def __init__( 95 | self, 96 | size, 97 | style_dim, 98 | n_mlp, 99 | channel_multiplier=2, 100 | blur_kernel=[1, 3, 3, 1], 101 | lr_mlp=0.01, 102 | ): 103 | super().__init__() 104 | 105 | self.size = size 106 | 107 | self.style_dim = style_dim 108 | 109 | layers = [PixelNorm()] 110 | 111 | for i in range(n_mlp): 112 | layers.append( 113 | EqualLinear( 114 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 115 | ) 116 | ) 117 | 118 | self.style = nn.Sequential(*layers) 119 | 120 | self.channels = { 121 | 4: 512, 122 | 8: 512, 123 | 16: 512, 124 | 32: 512, 125 | 64: 256 * channel_multiplier, 126 | 128: 128 * channel_multiplier, 127 | 256: 64 * channel_multiplier, 128 | 
512: 32 * channel_multiplier, 129 | 1024: 16 * channel_multiplier, 130 | } 131 | 132 | self.input = ConstantInput(self.channels[4]) 133 | self.conv1 = StyledConv( 134 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 135 | ) 136 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 137 | 138 | self.log_size = int(math.log(size, 2)) - 1 139 | self.num_layers = (self.log_size - 2) * 2 + 1 140 | 141 | self.convs = nn.ModuleList() 142 | self.upsamples = nn.ModuleList() 143 | self.to_rgbs = nn.ModuleList() 144 | self.noises = nn.Module() 145 | 146 | in_channel = self.channels[4] 147 | 148 | for layer_idx in range(self.num_layers): 149 | res = (layer_idx + 5) // 2 150 | shape = [1, 1, 2 ** res, 2 ** res] 151 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 152 | 153 | for i in range(3, self.log_size + 1): 154 | out_channel = self.channels[2 ** i] 155 | 156 | self.convs.append( 157 | StyledConv( 158 | in_channel, 159 | out_channel, 160 | 3, 161 | style_dim, 162 | upsample=True, 163 | blur_kernel=blur_kernel, 164 | ) 165 | ) 166 | 167 | self.convs.append( 168 | StyledConv( 169 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 170 | ) 171 | ) 172 | 173 | self.to_rgbs.append(ToRGB(out_channel, style_dim)) 174 | 175 | in_channel = out_channel 176 | 177 | self.iwt = InverseHaarTransform(3) 178 | 179 | self.n_latent = self.log_size * 2 - 2 180 | 181 | def make_noise(self): 182 | device = self.input.input.device 183 | 184 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 185 | 186 | for i in range(3, self.log_size + 1): 187 | for _ in range(2): 188 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 189 | 190 | return noises 191 | 192 | def mean_latent(self, n_latent): 193 | latent_in = torch.randn( 194 | n_latent, self.style_dim, device=self.input.input.device 195 | ) 196 | latent = self.style(latent_in).mean(0, keepdim=True) 197 | 198 | return latent 199 | 200 | def 
get_latent(self, input): 201 | return self.style(input) 202 | 203 | def forward( 204 | self, 205 | styles, 206 | return_latents=False, 207 | inject_index=None, 208 | truncation=1, 209 | truncation_latent=None, 210 | input_is_latent=False, 211 | noise=None, 212 | randomize_noise=True, 213 | ): 214 | if not input_is_latent: 215 | styles = [self.style(s) for s in styles] 216 | 217 | if noise is None: 218 | if randomize_noise: 219 | noise = [None] * self.num_layers 220 | else: 221 | noise = [ 222 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 223 | ] 224 | 225 | if truncation < 1: 226 | style_t = [] 227 | 228 | for style in styles: 229 | style_t.append( 230 | truncation_latent + truncation * (style - truncation_latent) 231 | ) 232 | 233 | styles = style_t 234 | 235 | if len(styles) < 2: 236 | inject_index = self.n_latent 237 | 238 | if styles[0].ndim < 3: 239 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 240 | 241 | else: 242 | latent = styles[0] 243 | 244 | else: 245 | if inject_index is None: 246 | inject_index = random.randint(1, self.n_latent - 1) 247 | 248 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 249 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 250 | 251 | latent = torch.cat([latent, latent2], 1) 252 | 253 | out = self.input(latent) 254 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 255 | 256 | skip = self.to_rgb1(out, latent[:, 1]) 257 | 258 | i = 1 259 | for conv1, conv2, noise1, noise2, to_rgb in zip( 260 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 261 | ): 262 | out = conv1(out, latent[:, i], noise=noise1) 263 | out = conv2(out, latent[:, i + 1], noise=noise2) 264 | skip = to_rgb(out, latent[:, i + 2], skip) 265 | 266 | i += 2 267 | 268 | image = self.iwt(skip) 269 | 270 | if return_latents: 271 | return image, latent 272 | 273 | else: 274 | return image, None 275 | 276 | 277 | class ConvBlock(nn.Module): 278 | def __init__(self, 
in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 279 | super().__init__() 280 | 281 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 282 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 283 | 284 | def forward(self, input): 285 | out = self.conv1(input) 286 | out = self.conv2(out) 287 | 288 | return out 289 | 290 | 291 | class FromRGB(nn.Module): 292 | def __init__(self, out_channel, downsample=True, blur_kernel=[1, 3, 3, 1]): 293 | super().__init__() 294 | 295 | self.downsample = downsample 296 | 297 | if downsample: 298 | self.iwt = InverseHaarTransform(3) 299 | self.downsample = Downsample(blur_kernel) 300 | self.dwt = HaarTransform(3) 301 | 302 | self.conv = ConvLayer(3 * 4, out_channel, 1) 303 | 304 | def forward(self, input, skip=None): 305 | if self.downsample: 306 | input = self.iwt(input) 307 | input = self.downsample(input) 308 | input = self.dwt(input) 309 | 310 | out = self.conv(input) 311 | 312 | if skip is not None: 313 | out = out + skip 314 | 315 | return input, out 316 | 317 | 318 | class Discriminator(nn.Module): 319 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 320 | super().__init__() 321 | 322 | channels = { 323 | 4: 512, 324 | 8: 512, 325 | 16: 512, 326 | 32: 512, 327 | 64: 256 * channel_multiplier, 328 | 128: 128 * channel_multiplier, 329 | 256: 64 * channel_multiplier, 330 | 512: 32 * channel_multiplier, 331 | 1024: 16 * channel_multiplier, 332 | } 333 | 334 | self.dwt = HaarTransform(3) 335 | 336 | self.from_rgbs = nn.ModuleList() 337 | self.convs = nn.ModuleList() 338 | 339 | log_size = int(math.log(size, 2)) - 1 340 | 341 | in_channel = channels[size] 342 | 343 | for i in range(log_size, 2, -1): 344 | out_channel = channels[2 ** (i - 1)] 345 | 346 | self.from_rgbs.append(FromRGB(in_channel, downsample=i != log_size)) 347 | self.convs.append(ConvBlock(in_channel, out_channel, blur_kernel)) 348 | 349 | in_channel = out_channel 350 | 351 | self.from_rgbs.append(FromRGB(channels[4])) 
352 | 353 | self.stddev_group = 4 354 | self.stddev_feat = 1 355 | 356 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 357 | self.final_linear = nn.Sequential( 358 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 359 | EqualLinear(channels[4], 1), 360 | ) 361 | 362 | def forward(self, input): 363 | input = self.dwt(input) 364 | out = None 365 | 366 | for from_rgb, conv in zip(self.from_rgbs, self.convs): 367 | input, out = from_rgb(input, out) 368 | out = conv(out) 369 | 370 | _, out = self.from_rgbs[-1](input, out) 371 | 372 | batch, channel, height, width = out.shape 373 | group = min(batch, self.stddev_group) 374 | stddev = out.view( 375 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 376 | ) 377 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 378 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 379 | stddev = stddev.repeat(group, 1, height, width) 380 | out = torch.cat([out, stddev], 1) 381 | 382 | out = self.final_conv(out) 383 | 384 | out = out.view(batch, -1) 385 | out = self.final_linear(out) 386 | 387 | return out 388 | 389 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import os 5 | 6 | import numpy as np 7 | import torch 8 | from torch import nn, autograd, optim 9 | from torch.nn import functional as F 10 | from torch.utils import data 11 | import torch.distributed as dist 12 | from torchvision import transforms, utils 13 | from tqdm import tqdm 14 | from tensorfn import load_arg_config, distributed as dist, get_logger 15 | 16 | try: 17 | import wandb 18 | 19 | except ImportError: 20 | wandb = None 21 | 22 | 23 | from stylegan2.dataset import MultiResolutionDataset 24 | from stylegan2.distributed import ( 25 | get_rank, 26 | synchronize, 27 | reduce_loss_dict, 28 | reduce_sum, 29 | 
get_world_size, 30 | ) 31 | from stylegan2.op import conv2d_gradfix 32 | from stylegan2.non_leaking import augment, AdaptiveAugment 33 | from stylegan2.model import Discriminator 34 | from config import GANConfig 35 | from model import Generator, filter_parameters 36 | 37 | 38 | def data_sampler(dataset, shuffle, distributed): 39 | if distributed: 40 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 41 | 42 | if shuffle: 43 | return data.RandomSampler(dataset) 44 | 45 | else: 46 | return data.SequentialSampler(dataset) 47 | 48 | 49 | def requires_grad(model, flag=True): 50 | for p in model.parameters(): 51 | p.requires_grad = flag 52 | 53 | 54 | @torch.no_grad() 55 | def accumulate(model1, model2, decay=0.999): 56 | par1 = dict(model1.named_parameters()) 57 | par2 = dict(model2.named_parameters()) 58 | 59 | for k in par1.keys(): 60 | par1[k].data.mul_(decay).add_(par2[k].data, alpha=1 - decay) 61 | 62 | buf1 = dict(model1.named_buffers()) 63 | buf2 = dict(model2.named_buffers()) 64 | 65 | for k in buf1.keys(): 66 | buf1[k].detach().copy_(buf2[k].detach()) 67 | 68 | 69 | def sample_data(loader): 70 | while True: 71 | for batch in loader: 72 | yield batch 73 | 74 | 75 | def d_logistic_loss(real_pred, fake_pred): 76 | real_loss = F.softplus(-real_pred) 77 | fake_loss = F.softplus(fake_pred) 78 | 79 | return real_loss.mean() + fake_loss.mean() 80 | 81 | 82 | def d_r1_loss(real_pred, real_img): 83 | with conv2d_gradfix.no_weight_gradients(): 84 | (grad_real,) = autograd.grad( 85 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 86 | ) 87 | grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean() 88 | 89 | return grad_penalty 90 | 91 | 92 | def g_nonsaturating_loss(fake_pred): 93 | loss = F.softplus(-fake_pred).mean() 94 | 95 | return loss 96 | 97 | 98 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 99 | noise = torch.randn_like(fake_img) / math.sqrt( 100 | fake_img.shape[2] * 
fake_img.shape[3] 101 | ) 102 | (grad,) = autograd.grad( 103 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 104 | ) 105 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 106 | 107 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 108 | 109 | path_penalty = (path_lengths - path_mean).pow(2).mean() 110 | 111 | return path_penalty, path_mean.detach(), path_lengths 112 | 113 | 114 | def make_noise(batch, latent_dim, n_noise, device): 115 | if n_noise == 1: 116 | return torch.randn(batch, latent_dim, device=device) 117 | 118 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 119 | 120 | return noises 121 | 122 | 123 | def mixing_noise(batch, latent_dim, prob, device): 124 | if prob > 0 and random.random() < prob: 125 | return make_noise(batch, latent_dim, 2, device) 126 | 127 | else: 128 | return [make_noise(batch, latent_dim, 1, device)] 129 | 130 | 131 | def set_grad_none(model, targets): 132 | for n, p in model.named_parameters(): 133 | if n in targets: 134 | p.grad = None 135 | 136 | 137 | def train(conf, loader, generator, discriminator, g_optim, d_optim, g_ema, device): 138 | loader = sample_data(loader) 139 | 140 | pbar = range(conf.training.iter) 141 | 142 | if get_rank() == 0: 143 | pbar = tqdm( 144 | pbar, initial=conf.training.start_iter, dynamic_ncols=True, smoothing=0.01 145 | ) 146 | 147 | mean_path_length = 0 148 | 149 | d_loss_val = 0 150 | r1_loss = torch.tensor(0.0, device=device) 151 | g_loss_val = 0 152 | path_loss = torch.tensor(0.0, device=device) 153 | path_lengths = torch.tensor(0.0, device=device) 154 | mean_path_length_avg = 0 155 | loss_dict = {} 156 | 157 | if conf.distributed: 158 | g_module = generator.module 159 | d_module = discriminator.module 160 | 161 | else: 162 | g_module = generator 163 | d_module = discriminator 164 | 165 | accum = 0.5 ** (32 / (10 * 1000)) 166 | ada_aug_p = conf.training.augment_p if conf.training.augment_p > 0 else 0.0 167 | 
r_t_stat = 0 168 | 169 | if conf.training.augment and conf.training.augment_p == 0: 170 | ada_augment = AdaptiveAugment( 171 | conf.training.ada_target, conf.training.ada_length, 8, device 172 | ) 173 | 174 | sample_z = torch.randn( 175 | conf.training.n_sample, conf.generator["style_dim"], device=device 176 | ) 177 | 178 | for idx in pbar: 179 | i = idx + conf.training.start_iter 180 | 181 | if i > conf.training.iter: 182 | print("Done!") 183 | 184 | break 185 | 186 | real_img = next(loader) 187 | real_img = real_img.to(device) 188 | 189 | requires_grad(generator, False) 190 | requires_grad(discriminator, True) 191 | 192 | noise = make_noise(conf.training.batch, conf.generator["style_dim"], 1, device) 193 | generator.eval() 194 | fake_img = generator(noise) 195 | generator.train() 196 | 197 | if conf.training.augment: 198 | real_img_aug, _ = augment(real_img, ada_aug_p) 199 | fake_img, _ = augment(fake_img, ada_aug_p) 200 | 201 | else: 202 | real_img_aug = real_img 203 | 204 | fake_pred = discriminator(fake_img) 205 | real_pred = discriminator(real_img_aug) 206 | d_loss = d_logistic_loss(real_pred, fake_pred) 207 | 208 | loss_dict["d"] = d_loss 209 | loss_dict["real_score"] = real_pred.mean() 210 | loss_dict["fake_score"] = fake_pred.mean() 211 | 212 | discriminator.zero_grad() 213 | d_loss.backward() 214 | d_optim.step() 215 | 216 | if conf.training.augment and conf.training.augment_p == 0: 217 | ada_aug_p = ada_augment.tune(real_pred) 218 | r_t_stat = ada_augment.r_t_stat 219 | 220 | d_regularize = i % conf.training.d_reg_every == 0 221 | 222 | if d_regularize: 223 | real_img.requires_grad = True 224 | 225 | if conf.training.augment: 226 | real_img_aug, _ = augment(real_img, ada_aug_p) 227 | 228 | else: 229 | real_img_aug = real_img 230 | 231 | real_pred = discriminator(real_img_aug) 232 | r1_loss = d_r1_loss(real_pred, real_img) 233 | 234 | discriminator.zero_grad() 235 | ( 236 | conf.training.r1 / 2 * r1_loss * conf.training.d_reg_every 237 | + 0 * 
real_pred[0] 238 | ).backward() 239 | 240 | d_optim.step() 241 | 242 | loss_dict["r1"] = r1_loss 243 | 244 | requires_grad(generator, True) 245 | requires_grad(discriminator, False) 246 | 247 | noise = make_noise(conf.training.batch, conf.generator["style_dim"], 1, device) 248 | fake_img = generator(noise) 249 | 250 | if conf.training.augment: 251 | fake_img, _ = augment(fake_img, ada_aug_p) 252 | 253 | fake_pred = discriminator(fake_img) 254 | g_loss = g_nonsaturating_loss(fake_pred) 255 | 256 | loss_dict["g"] = g_loss 257 | 258 | generator.zero_grad() 259 | g_loss.backward() 260 | g_optim.step() 261 | 262 | accumulate(g_ema, g_module, accum) 263 | 264 | loss_reduced = reduce_loss_dict(loss_dict) 265 | 266 | d_loss_val = loss_reduced["d"].mean().item() 267 | g_loss_val = loss_reduced["g"].mean().item() 268 | r1_val = loss_reduced["r1"].mean().item() 269 | real_score_val = loss_reduced["real_score"].mean().item() 270 | fake_score_val = loss_reduced["fake_score"].mean().item() 271 | 272 | if get_rank() == 0: 273 | pbar.set_description( 274 | ( 275 | f"d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; " 276 | f"augment: {ada_aug_p:.4f}" 277 | ) 278 | ) 279 | 280 | if wandb and conf.wandb: 281 | wandb.log( 282 | { 283 | "Generator": g_loss_val, 284 | "Discriminator": d_loss_val, 285 | "Augment": ada_aug_p, 286 | "Rt": r_t_stat, 287 | "R1": r1_val, 288 | "Real Score": real_score_val, 289 | "Fake Score": fake_score_val, 290 | } 291 | ) 292 | 293 | if i % 100 == 0: 294 | generator.zero_grad() 295 | discriminator.zero_grad() 296 | with torch.no_grad(): 297 | g_ema.eval() 298 | sample = g_ema(sample_z).cpu() 299 | utils.save_image( 300 | sample, 301 | f"sample/{str(i).zfill(6)}.png", 302 | nrow=int(conf.training.n_sample ** 0.5), 303 | normalize=True, 304 | value_range=(-1, 1), 305 | ) 306 | sample = None # cleanup memory 307 | 308 | if i % 10000 == 0: 309 | torch.save( 310 | { 311 | "g": g_module.state_dict(), 312 | "d": d_module.state_dict(), 313 | "g_ema": 
g_ema.state_dict(), 314 | "g_optim": g_optim.state_dict(), 315 | "d_optim": d_optim.state_dict(), 316 | "conf": conf.dict(), 317 | "ada_aug_p": ada_aug_p, 318 | }, 319 | f"checkpoint/{str(i).zfill(6)}.pt", 320 | ) 321 | 322 | 323 | def main(conf): 324 | device = "cuda" 325 | conf.distributed = conf.n_gpu > 1 326 | 327 | logger = get_logger(mode=conf.logger) 328 | logger.info(conf.dict()) 329 | 330 | generator = conf.generator.make().to(device) 331 | g_ema = conf.generator.make().to(device) 332 | discriminator = conf.discriminator.make().to(device) 333 | accumulate(g_ema, generator, 0) 334 | 335 | logger.info(generator) 336 | logger.info(discriminator) 337 | 338 | d_reg_ratio = conf.training.d_reg_every / (conf.training.d_reg_every + 1) 339 | 340 | g_optim = optim.Adam(generator.parameters(), lr=conf.training.lr_g, betas=(0, 0.99)) 341 | d_optim = optim.Adam( 342 | discriminator.parameters(), 343 | lr=conf.training.lr_d * d_reg_ratio, 344 | betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio), 345 | ) 346 | 347 | if conf.ckpt is not None: 348 | logger.info(f"load model: {conf.ckpt}") 349 | 350 | ckpt = torch.load(conf.ckpt, map_location=lambda storage, loc: storage) 351 | 352 | try: 353 | ckpt_name = os.path.basename(conf.ckpt) 354 | conf.training.start_iter = int(os.path.splitext(ckpt_name)[0]) 355 | 356 | except ValueError: 357 | pass 358 | 359 | generator.load_state_dict(ckpt["g"]) 360 | discriminator.load_state_dict(ckpt["d"]) 361 | g_ema.load_state_dict(ckpt["g_ema"]) 362 | 363 | g_optim.load_state_dict(ckpt["g_optim"]) 364 | d_optim.load_state_dict(ckpt["d_optim"]) 365 | 366 | if conf.distributed: 367 | generator = nn.parallel.DistributedDataParallel( 368 | generator, 369 | device_ids=[dist.get_local_rank()], 370 | output_device=dist.get_local_rank(), 371 | broadcast_buffers=True, 372 | ) 373 | 374 | discriminator = nn.parallel.DistributedDataParallel( 375 | discriminator, 376 | device_ids=[dist.get_local_rank()], 377 | output_device=dist.get_local_rank(), 378 | 
broadcast_buffers=False, 379 | ) 380 | 381 | transform = transforms.Compose( 382 | [ 383 | transforms.RandomHorizontalFlip(), 384 | transforms.ToTensor(), 385 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 386 | ] 387 | ) 388 | 389 | dataset = MultiResolutionDataset(conf.path, transform, conf.training.size) 390 | loader = data.DataLoader( 391 | dataset, 392 | batch_size=conf.training.batch, 393 | sampler=data_sampler(dataset, shuffle=True, distributed=conf.distributed), 394 | drop_last=True, 395 | ) 396 | 397 | if get_rank() == 0 and wandb is not None and conf.wandb: 398 | wandb.init(project="alias free gan") 399 | 400 | train(conf, loader, generator, discriminator, g_optim, d_optim, g_ema, device) 401 | 402 | 403 | if __name__ == "__main__": 404 | conf = load_arg_config(GANConfig) 405 | 406 | dist.launch( 407 | main, conf.n_gpu, conf.n_machine, conf.machine_rank, conf.dist_url, args=(conf,) 408 | ) 409 | --------------------------------------------------------------------------------