├── .gitignore
├── LICENSE
├── README.md
├── UI.png
├── draggan.png
├── draggan.py
├── generate.py
├── gui.py
├── op
│   ├── __init__.py
│   ├── conv2d_gradfix.py
│   ├── fused_act.py
│   └── upfirdn2d.py
├── sample
│   ├── .gitignore
│   ├── end.png
│   ├── multi-point.png
│   └── start.png
└── stylegan2.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Mr.
Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![Architecture of DragGAN](draggan.png) 2 | 3 | # DragGAN 4 | Implementation of [DragGAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://arxiv.org/abs/2305.10973). 5 | 6 | ```shell 7 | # gui 8 | pip install dearpygui 9 | # run demo 10 | python gui.py 11 | ``` 12 | 13 | ![Demo UI](UI.png) 14 | 15 | ![Demo Start](sample/start.png) ![Demo End](sample/end.png) ![Multi Point Control](sample/multi-point.png) 16 | 17 | # TODO 18 | - [x] GUI 19 | - [x] drag it 20 | - [ ] load real image 21 | - [ ] mask 22 | 23 | # StyleGAN2 Pre-Trained Model 24 | Rosinality's pre-trained model (256px), trained on FFHQ for 550k iterations \[[Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO)\].
25 | 26 | # References 27 | - https://github.com/rosinality/stylegan2-pytorch 28 | -------------------------------------------------------------------------------- /UI.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/UI.png -------------------------------------------------------------------------------- /draggan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/draggan.png -------------------------------------------------------------------------------- /draggan.py: -------------------------------------------------------------------------------- 1 | import torch, math 2 | import numpy as np 3 | from stylegan2 import Generator 4 | import torch.nn.functional as functional 5 | 6 | def linear(feature, p0, p1, d, axis=0): 7 | f0 = feature[..., p0[0], p0[1]] 8 | f1 = feature[..., p1[0], p1[1]] 9 | weight = abs(d[axis]) 10 | f = (1 - weight) * f0 + weight * f1 11 | return f 12 | 13 | def bilinear(feature, qi, d): 14 | y0, x0 = qi 15 | dy, dx = d 16 | d = (dx, dy) 17 | dx = 1 if dx >= 0 else -1 18 | dy = 1 if dy >= 0 else -1 19 | x1 = x0 + dx 20 | y1 = y0 + dy 21 | fx1 = linear(feature, (x0, y0), (x1, y0), d, axis=0) 22 | fx2 = linear(feature, (x0, y1), (x1, y1), d, axis=0) 23 | weight = abs(d[1]) 24 | fx = (1 - weight) * fx1 + weight * fx2 25 | return fx 26 | 27 | def motion_supervision(F0, F, pi, ti, r1=3, M=None): 28 | F = functional.interpolate(F, [256, 256], mode="bilinear") 29 | F0 = functional.interpolate(F0, [256, 256], mode="bilinear") 30 | 31 | dw, dh = ti[0] - pi[0], ti[1] - pi[1] 32 | norm = math.sqrt(dw**2 + dh**2) 33 | w = (max(0, pi[0] - r1), min(256, pi[0] + r1)) 34 | h = (max(0, pi[1] - r1), min(256, pi[1] + r1)) 35 | d = torch.tensor( 36 | (dw / norm, dh / norm), 37 | dtype=F.dtype, device=F.device, 38 | ).reshape(1, 1, 1, 2) 39 | grid_h, grid_w = torch.meshgrid( 40 | torch.tensor(range(h[0], h[1])), 41 | torch.tensor(range(w[0], w[1])), 42 | indexing='xy', 43 | ) 44 | grid = torch.stack([grid_w, grid_h], dim=-1).unsqueeze(0) 45 | grid = (grid / 255 - 0.5) * 2 46 | grid_d = grid + 2 * d / 255 47 | 48 | sample = functional.grid_sample( 49 | F, grid, mode='bilinear', padding_mode='border', 50 | align_corners=True, 51 | ) 52 | sample_d = functional.grid_sample( 53 | F, grid_d, mode='bilinear', padding_mode='border', 54 | align_corners=True, 55 | ) 56 | 57 | loss = (sample_d - sample.detach()).abs().mean(1).sum() 58 | 59 | return loss 60 | 61 | @torch.no_grad() 62 | def point_tracking(F0, F, pi, p0, r2=12): 63 | F = functional.interpolate(F, [256, 256], mode="bilinear") 64 | F0 = functional.interpolate(F0, [256, 256], mode="bilinear") 65 | x = (max(0, pi[0] - r2), min(256, pi[0] + r2)) 66 | y = (max(0, pi[1] - r2), min(256, pi[1] + r2)) 67 | base = F0[..., p0[1], p0[0]].reshape(1, -1, 1, 1) 68 | diff = (F[..., y[0]:y[1], x[0]:x[1]] - base).abs().mean(1) 69 | idx = diff.argmin() 70 | dy = int(idx / (x[1] - x[0])) 71 | dx = int(idx % (x[1] - x[0])) 72 | npi = (x[0] + dx, y[0] + dy) 73 | return npi 74 | 75 | def requires_grad(model, flag=True): 76 | for p in model.parameters(): 77 | p.requires_grad = flag 78 | 79 | class DragGAN(): 80 | def __init__(self, device, layer_index=6): 81 | self.generator = Generator(256, 512, 8).to(device) 82 | requires_grad(self.generator, False) 83 | self._device = device 84 | self.layer_index = 
layer_index 85 | self.latent = None 86 | self.F0 = None 87 | self.optimizer = None 88 | self.p0 = None 89 | 90 | def load_ckpt(self, path): 91 | print(f'loading checkpoint from {path}') 92 | ckpt = torch.load(path, map_location=self._device) 93 | self.generator.load_state_dict(ckpt["g_ema"], strict=False) 94 | print('loading checkpoint succeeded!') 95 | 96 | def to(self, device): 97 | if self._device != device: 98 | self.generator = self.generator.to(device) 99 | self._device = device 100 | 101 | @torch.no_grad() 102 | def generate_image(self, seed): 103 | z = torch.from_numpy( 104 | np.random.RandomState(seed).randn(1, 512).astype(np.float32) 105 | ).to(self._device) 106 | image, self.latent, self.F0 = self.generator( 107 | [z], return_latents=True, return_features=True, randomize_noise=False, 108 | ) 109 | image, self.F0 = image[0], self.F0[self.layer_index*2+1].detach() 110 | image = image.detach().cpu().permute(1, 2, 0).numpy() 111 | image = (image / 2 + 0.5).clip(0, 1).reshape(-1) 112 | return image 113 | 114 | @property 115 | def device(self): 116 | return self._device 117 | 118 | def __call__(self, *args, **kwargs): 119 | return self.generator(*args, **kwargs) 120 | 121 | def step(self, points): 122 | if self.optimizer is None: 123 | len_pts = (len(points) // 2) * 2 124 | if len_pts == 0: 125 | print('Select at least one pair of points') 126 | return False, None 127 | self.trainable = self.latent[:, :self.layer_index*2, :].detach( 128 | ).requires_grad_(True) 129 | self.fixed = self.latent[:, self.layer_index*2:, :].detach( 130 | ).requires_grad_(False) 131 | self.optimizer = torch.optim.Adam([self.trainable], lr=2e-3) 132 | points = points[:len_pts] 133 | self.p0 = points[::2] 134 | self.optimizer.zero_grad() 135 | trainable_fixed = torch.cat([self.trainable, self.fixed], dim=1) 136 | image, _, features = self.generator( 137 | [trainable_fixed], input_is_latent=True, 138 | return_features=True, randomize_noise=False, 139 | ) 140 | features = features[self.layer_index*2+1] 141 | loss = 0 142 | for i in range(len(self.p0)): 143 | loss += motion_supervision(self.F0, features, points[2*i], points[2*i+1]) 144 | print(loss) 145 | loss.backward() 146 | self.optimizer.step() 147 | image, _, features = self.generator( 148 | [trainable_fixed], input_is_latent=True, 149 | return_features=True, randomize_noise=False, 150 | ) 151 | features = features[self.layer_index*2+1] 152 | image = image[0].detach().cpu().permute(1, 2, 0).numpy() 153 | image = (image / 2 + 0.5).clip(0, 1).reshape(-1) 154 | for i in range(len(self.p0)): 155 | points[2*i] = point_tracking(self.F0, features, points[2*i], self.p0[i]) 156 | return True, (points, image) 157 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse, torch 2 | from torchvision import utils 3 | from stylegan2 import Generator 4 | from tqdm import tqdm 5 | 6 | def generate(args, g_ema, device, mean_latent): 7 | with torch.no_grad(): 8 | g_ema.eval() 9 | for i in tqdm(range(args.pics)): 10 | sample_z = torch.randn(args.sample, args.latent, device=device) 11 | 12 | sample, _, _ = g_ema( 13 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent 14 | ) 15 | 16 | utils.save_image( 17 | sample, 18 | f"sample/{str(i).zfill(6)}.png", 19 | nrow=1, 20 | normalize=True, 21 | range=(-1, 1), 22 | ) 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser(description="Generate samples from
the generator") 26 | 27 | parser.add_argument( 28 | "--size", type=int, default=256, help="output image size of the generator" 29 | ) 30 | parser.add_argument( 31 | "--device", type=str, default='cuda', help="output image size of the generator" 32 | ) 33 | parser.add_argument( 34 | "--sample", 35 | type=int, 36 | default=1, 37 | help="number of samples to be generated for each image", 38 | ) 39 | parser.add_argument( 40 | "--pics", type=int, default=20, help="number of images to be generated" 41 | ) 42 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio") 43 | parser.add_argument( 44 | "--truncation_mean", 45 | type=int, 46 | default=4096, 47 | help="number of vectors to calculate mean for the truncation", 48 | ) 49 | parser.add_argument( 50 | "--ckpt", 51 | type=str, 52 | default="stylegan2-ffhq-config-f.pt", 53 | help="path to the model checkpoint", 54 | ) 55 | parser.add_argument( 56 | "--channel_multiplier", 57 | type=int, 58 | default=2, 59 | help="channel multiplier of the generator. config-f = 2, else = 1", 60 | ) 61 | 62 | args = parser.parse_args() 63 | 64 | args.latent = 512 65 | args.n_mlp = 8 66 | device = torch.device(args.device) 67 | 68 | g_ema = Generator( 69 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier 70 | ).to(device) 71 | checkpoint = torch.load(args.ckpt, map_location=device) 72 | 73 | g_ema.load_state_dict(checkpoint["g_ema"], strict=False) 74 | 75 | if args.truncation < 1: 76 | with torch.no_grad(): 77 | mean_latent = g_ema.mean_latent(args.truncation_mean) 78 | else: 79 | mean_latent = None 80 | 81 | generate(args, g_ema, device, mean_latent) 82 | -------------------------------------------------------------------------------- /gui.py: -------------------------------------------------------------------------------- 1 | import dearpygui.dearpygui as dpg 2 | import numpy as np 3 | from draggan import DragGAN 4 | from array import array 5 | import threading 6 | 7 | add_point = 0 8 | point_color = [(1, 0, 0), (0, 0, 1)] 9 | points, steps = [], 0 10 | dragging = False 11 | # mvFormat_Float_rgb not currently supported on macOS 12 | # More details: https://dearpygui.readthedocs.io/en/latest/documentation/textures.html#formats 13 | texture_format = dpg.mvFormat_Float_rgba 14 | image_width, image_height, rgb_channel, rgba_channel = 256, 256, 3, 4 15 | image_pixels = image_height * image_width 16 | model = DragGAN('cpu') 17 | 18 | dpg.create_context() 19 | dpg.create_viewport(title='DragGAN', width=800, height=650) 20 | 21 | raw_data_size = image_width * image_height * rgba_channel 22 | raw_data = array('f', [1] * raw_data_size) 23 | with dpg.texture_registry(show=False): 24 | dpg.add_raw_texture( 25 | width=image_width, height=image_height, default_value=raw_data, 26 | format=texture_format, tag="image" 27 | ) 28 | 29 | def update_image(new_image): 30 | # Convert image data (rgb) to raw_data (rgba) 31 | for i in range(0, image_pixels): 32 | rd_base, im_base = i * rgba_channel, i * rgb_channel 33 | raw_data[rd_base:rd_base + rgb_channel] = array( 34 | 'f', new_image[im_base:im_base + rgb_channel] 35 | ) 36 | 37 | def generate_image(sender, app_data, user_data): 38 | seed = dpg.get_value('seed') 39 | image = model.generate_image(seed) 40 | update_image(image) 41 | 42 | def change_device(sender, app_data): 43 | model.to(app_data) 44 | 45 | def dragging_thread(): 46 | global points, steps, dragging 47 | while (dragging): 48 | status, ret = model.step(points) 49 | if status: 50 | points, image = ret 51 | else: 52 
| dragging = False 53 | return 54 | update_image(image) 55 | for i in range(len(points)): 56 | draw_point(*points[i], point_color[i%2]) 57 | steps += 1 58 | dpg.set_value('steps', f'steps: {steps}') 59 | 60 | width, height = 260, 200 61 | posx, posy = 0, 0 62 | with dpg.window( 63 | label='Network & Latent', width=width, height=height, pos=(posx, posy), 64 | no_move=True, no_close=True, no_collapse=True, no_resize=True, 65 | ): 66 | dpg.add_text('device', pos=(5, 20)) 67 | dpg.add_combo( 68 | ('cpu', 'cuda'), default_value='cpu', width=60, pos=(70, 20), 69 | callback=change_device, 70 | ) 71 | 72 | dpg.add_text('weight', pos=(5, 40)) 73 | 74 | def select_cb(sender, app_data): 75 | selections = app_data['selections'] 76 | if selections: 77 | for fn in selections: 78 | model.load_ckpt(selections[fn]) 79 | break 80 | 81 | def cancel_cb(sender, app_data): 82 | ... 83 | 84 | with dpg.file_dialog( 85 | directory_selector=False, show=False, callback=select_cb, id='weight selector', 86 | cancel_callback=cancel_cb, width=700 ,height=400 87 | ): 88 | dpg.add_file_extension('.*') 89 | dpg.add_button( 90 | label="select weight", callback=lambda: dpg.show_item("weight selector"), 91 | pos=(70, 40), 92 | ) 93 | 94 | dpg.add_text('latent', pos=(5, 60)) 95 | dpg.add_input_int( 96 | label='seed', width=100, pos=(70, 60), tag='seed', default_value=512, 97 | ) 98 | dpg.add_input_float( 99 | label='step size', width=54, pos=(70, 80), step=-1, default_value=0.002, 100 | ) 101 | dpg.add_button(label="reset", width=54, pos=(70, 100), callback=None) 102 | dpg.add_radio_button( 103 | items=('w', 'w+'), pos=(130, 100), horizontal=True, default_value='w+', 104 | ) 105 | dpg.add_button(label="generate", pos=(70, 120), callback=generate_image) 106 | 107 | posy += height + 2 108 | with dpg.window( 109 | label='Drag', width=width, height=height, pos=(posx, posy), 110 | no_move=True, no_close=True, no_collapse=True, no_resize=True, 111 | ): 112 | def add_point_cb(): 113 | global add_point 114 | add_point += 2 115 | 116 | def reset_point_cb(): 117 | global points 118 | points = [] 119 | 120 | def start_cb(): 121 | global dragging 122 | if dragging: return 123 | dragging = True 124 | threading.Thread(target=dragging_thread).start() 125 | 126 | def stop_cb(): 127 | global dragging 128 | dragging = False 129 | print('stop dragging...') 130 | 131 | dpg.add_text('drag', pos=(5, 20)) 132 | dpg.add_button(label="add point", width=80, pos=(70, 20), callback=add_point_cb) 133 | dpg.add_button(label="reset point", width=80, pos=(155, 20), callback=reset_point_cb) 134 | dpg.add_button(label="start", width=80, pos=(70, 40), callback=start_cb) 135 | dpg.add_button(label="stop", width=80, pos=(155, 40), callback=stop_cb) 136 | dpg.add_text('steps: 0', tag='steps', pos=(70, 60)) 137 | 138 | dpg.add_text('mask', pos=(5, 80)) 139 | dpg.add_button(label="fixed area", width=80, pos=(70, 80), callback=None) 140 | dpg.add_button(label="reset mask", width=80, pos=(70, 100), callback=None) 141 | dpg.add_checkbox(label='show mask', pos=(155, 100), default_value=False) 142 | dpg.add_input_int(label='radius', width=100, pos=(70, 120), default_value=50) 143 | dpg.add_input_float(label='lambda', width=100, pos=(70, 140), default_value=20) 144 | 145 | posy += height + 2 146 | with dpg.window( 147 | label='Capture', width=width, height=height, pos=(posx, posy), 148 | no_move=True, no_close=True, no_collapse=True, no_resize=True, 149 | ): 150 | dpg.add_text('capture', pos=(5, 20)) 151 | dpg.add_input_text(pos=(70, 20), default_value='capture') 152 | 
dpg.add_button(label="save image", width=80, pos=(70, 40), callback=None) 153 | 154 | def draw_point(x, y, color): 155 | x_start, x_end = max(0, x - 2), min(image_width, x + 2) 156 | y_start, y_end = max(0, y - 2), min(image_height, y + 2) 157 | for x in range(x_start, x_end): 158 | for y in range(y_start, y_end): 159 | offset = (y * image_width + x) * rgba_channel 160 | raw_data[offset:offset + rgb_channel] = array('f', color[:rgb_channel]) 161 | 162 | def select_point(sender, app_data): 163 | global add_point, points 164 | if add_point <= 0: return 165 | ms_pos = dpg.get_mouse_pos(local=False) 166 | id_pos = dpg.get_item_pos('image_data') 167 | iw_pos = dpg.get_item_pos('Image Win') 168 | ix = int(ms_pos[0]-id_pos[0]-iw_pos[0]) 169 | iy = int(ms_pos[1]-id_pos[1]-iw_pos[1]) 170 | draw_point(ix, iy, point_color[add_point % 2]) 171 | points.append(np.array([ix, iy])) 172 | print(points) 173 | add_point -= 1 174 | 175 | posx, posy = 2 + width, 0 176 | with dpg.window( 177 | label='Image', pos=(posx, posy), tag='Image Win', 178 | no_move=True, no_close=True, no_collapse=True, no_resize=True, 179 | ): 180 | dpg.add_image("image", show=True, tag='image_data', pos=(10, 30)) 181 | 182 | with dpg.item_handler_registry(tag='double_clicked_handler'): 183 | dpg.add_item_double_clicked_handler(callback=select_point) 184 | dpg.bind_item_handler_registry("image_data", "double_clicked_handler") 185 | 186 | dpg.setup_dearpygui() 187 | dpg.show_viewport() 188 | dpg.start_dearpygui() 189 | dpg.destroy_context() 190 | -------------------------------------------------------------------------------- /op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import warnings 3 | 4 | import torch 5 | from torch import autograd 6 | from torch.nn import functional as F 7 | 8 | enabled = True 9 | weight_gradients_disabled = False 10 | 11 | 12 | @contextlib.contextmanager 13 | def no_weight_gradients(): 14 | global weight_gradients_disabled 15 | 16 | old = weight_gradients_disabled 17 | weight_gradients_disabled = True 18 | yield 19 | weight_gradients_disabled = old 20 | 21 | 22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 23 | if could_use_op(input): 24 | return conv2d_gradfix( 25 | transpose=False, 26 | weight_shape=weight.shape, 27 | stride=stride, 28 | padding=padding, 29 | output_padding=0, 30 | dilation=dilation, 31 | groups=groups, 32 | ).apply(input, weight, bias) 33 | 34 | return F.conv2d( 35 | input=input, 36 | weight=weight, 37 | bias=bias, 38 | stride=stride, 39 | padding=padding, 40 | dilation=dilation, 41 | groups=groups, 42 | ) 43 | 44 | 45 | def conv_transpose2d( 46 | input, 47 | weight, 48 | bias=None, 49 | stride=1, 50 | padding=0, 51 | output_padding=0, 52 | groups=1, 53 | dilation=1, 54 | ): 55 | if could_use_op(input): 56 | return conv2d_gradfix( 57 | transpose=True, 58 | weight_shape=weight.shape, 59 | stride=stride, 60 | padding=padding, 61 | output_padding=output_padding, 62 | groups=groups, 63 | dilation=dilation, 64 | ).apply(input, weight, bias) 65 | 66 | return F.conv_transpose2d( 67 | input=input, 68 | weight=weight, 69 | bias=bias, 70 | stride=stride, 71 | padding=padding, 72 | 
output_padding=output_padding, 73 | dilation=dilation, 74 | groups=groups, 75 | ) 76 | 77 | 78 | def could_use_op(input): 79 | if (not enabled) or (not torch.backends.cudnn.enabled): 80 | return False 81 | 82 | if input.device.type != "cuda": 83 | return False 84 | 85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]): 86 | return True 87 | 88 | warnings.warn( 89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()." 90 | ) 91 | 92 | return False 93 | 94 | 95 | def ensure_tuple(xs, ndim): 96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 97 | 98 | return xs 99 | 100 | 101 | conv2d_gradfix_cache = dict() 102 | 103 | 104 | def conv2d_gradfix( 105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups 106 | ): 107 | ndim = 2 108 | weight_shape = tuple(weight_shape) 109 | stride = ensure_tuple(stride, ndim) 110 | padding = ensure_tuple(padding, ndim) 111 | output_padding = ensure_tuple(output_padding, ndim) 112 | dilation = ensure_tuple(dilation, ndim) 113 | 114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 115 | if key in conv2d_gradfix_cache: 116 | return conv2d_gradfix_cache[key] 117 | 118 | common_kwargs = dict( 119 | stride=stride, padding=padding, dilation=dilation, groups=groups 120 | ) 121 | 122 | def calc_output_padding(input_shape, output_shape): 123 | if transpose: 124 | return [0, 0] 125 | 126 | return [ 127 | input_shape[i + 2] 128 | - (output_shape[i + 2] - 1) * stride[i] 129 | - (1 - 2 * padding[i]) 130 | - dilation[i] * (weight_shape[i + 2] - 1) 131 | for i in range(ndim) 132 | ] 133 | 134 | class Conv2d(autograd.Function): 135 | @staticmethod 136 | def forward(ctx, input, weight, bias): 137 | if not transpose: 138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 139 | 140 | else: 141 | out = F.conv_transpose2d( 142 | input=input, 143 | weight=weight, 144 | bias=bias, 145 | output_padding=output_padding, 146 | **common_kwargs, 147 | ) 148 | 149 | ctx.save_for_backward(input, weight) 150 | 151 | return out 152 | 153 | @staticmethod 154 | def backward(ctx, grad_output): 155 | input, weight = ctx.saved_tensors 156 | grad_input, grad_weight, grad_bias = None, None, None 157 | 158 | if ctx.needs_input_grad[0]: 159 | p = calc_output_padding( 160 | input_shape=input.shape, output_shape=grad_output.shape 161 | ) 162 | grad_input = conv2d_gradfix( 163 | transpose=(not transpose), 164 | weight_shape=weight_shape, 165 | output_padding=p, 166 | **common_kwargs, 167 | ).apply(grad_output, weight, None) 168 | 169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 170 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 171 | 172 | if ctx.needs_input_grad[2]: 173 | grad_bias = grad_output.sum((0, 2, 3)) 174 | 175 | return grad_input, grad_weight, grad_bias 176 | 177 | class Conv2dGradWeight(autograd.Function): 178 | @staticmethod 179 | def forward(ctx, grad_output, input): 180 | op = torch._C._jit_get_operation( 181 | "aten::cudnn_convolution_backward_weight" 182 | if not transpose 183 | else "aten::cudnn_convolution_transpose_backward_weight" 184 | ) 185 | flags = [ 186 | torch.backends.cudnn.benchmark, 187 | torch.backends.cudnn.deterministic, 188 | torch.backends.cudnn.allow_tf32, 189 | ] 190 | grad_weight = op( 191 | weight_shape, 192 | grad_output, 193 | input, 194 | padding, 195 | stride, 196 | dilation, 197 | groups, 198 | *flags, 199 | ) 200 | ctx.save_for_backward(grad_output, input) 201 
| 202 | return grad_weight 203 | 204 | @staticmethod 205 | def backward(ctx, grad_grad_weight): 206 | grad_output, input = ctx.saved_tensors 207 | grad_grad_output, grad_grad_input = None, None 208 | 209 | if ctx.needs_input_grad[0]: 210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None) 211 | 212 | if ctx.needs_input_grad[1]: 213 | p = calc_output_padding( 214 | input_shape=input.shape, output_shape=grad_output.shape 215 | ) 216 | grad_grad_input = conv2d_gradfix( 217 | transpose=(not transpose), 218 | weight_shape=weight_shape, 219 | output_padding=p, 220 | **common_kwargs, 221 | ).apply(grad_output, grad_grad_weight, None) 222 | 223 | return grad_grad_output, grad_grad_input 224 | 225 | conv2d_gradfix_cache[key] = Conv2d 226 | 227 | return Conv2d 228 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class FusedLeakyReLU(nn.Module): 9 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5): 10 | super().__init__() 11 | 12 | if bias: 13 | self.bias = nn.Parameter(torch.zeros(channel)) 14 | 15 | else: 16 | self.bias = None 17 | 18 | self.negative_slope = negative_slope 19 | self.scale = scale 20 | 21 | def forward(self, input): 22 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 23 | 24 | 25 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5): 26 | if bias is not None: 27 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 28 | return ( 29 | F.leaky_relu( 30 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 31 | ) 32 | * scale 33 | ) 34 | 35 | else: 36 | return F.leaky_relu(input, negative_slope=0.2) * scale 37 | -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 7 | if not isinstance(up, abc.Iterable): 8 | up = (up, up) 9 | 10 | if not isinstance(down, abc.Iterable): 11 | down = (down, down) 12 | 13 | if len(pad) == 2: 14 | pad = (pad[0], pad[1], pad[0], pad[1]) 15 | 16 | out = upfirdn2d_native(input, kernel, *up, *down, *pad) 17 | 18 | return out 19 | 20 | 21 | def upfirdn2d_native( 22 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 23 | ): 24 | _, channel, in_h, in_w = input.shape 25 | input = input.reshape(-1, in_h, in_w, 1) 26 | 27 | _, in_h, in_w, minor = input.shape 28 | kernel_h, kernel_w = kernel.shape 29 | 30 | out = input.view(-1, in_h, 1, in_w, 1, minor) 31 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 32 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 33 | 34 | out = F.pad( 35 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 36 | ) 37 | out = out[ 38 | :, 39 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 40 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 41 | :, 42 | ] 43 | 44 | out = out.permute(0, 3, 1, 2) 45 | out = out.reshape( 46 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 47 | ) 48 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 49 | out = F.conv2d(out, w) 50 | out = out.reshape( 51 | -1, 52 | minor, 53 | in_h * up_y + 
pad_y0 + pad_y1 - kernel_h + 1, 54 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 55 | ) 56 | out = out.permute(0, 2, 3, 1) 57 | out = out[:, ::down_y, ::down_x, :] 58 | 59 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y 60 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x 61 | 62 | return out.view(-1, channel, out_h, out_w) 63 | -------------------------------------------------------------------------------- /sample/.gitignore: -------------------------------------------------------------------------------- 1 | *.png -------------------------------------------------------------------------------- /sample/end.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/end.png -------------------------------------------------------------------------------- /sample/multi-point.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/multi-point.png -------------------------------------------------------------------------------- /sample/start.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/start.png -------------------------------------------------------------------------------- /stylegan2.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import functools 4 | import operator 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | from torch.autograd import Function 10 | 11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix 12 | 13 | 14 | class PixelNorm(nn.Module): 15 | def __init__(self): 16 | super().__init__() 17 | 18 | def forward(self, input): 19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 20 | 21 | 22 | def make_kernel(k): 23 | k = torch.tensor(k, dtype=torch.float32) 24 | 25 | if k.ndim == 1: 26 | k = k[None, :] * k[:, None] 27 | 28 | k /= k.sum() 29 | 30 | return k 31 | 32 | 33 | class Upsample(nn.Module): 34 | def __init__(self, kernel, factor=2): 35 | super().__init__() 36 | 37 | self.factor = factor 38 | kernel = make_kernel(kernel) * (factor ** 2) 39 | self.register_buffer("kernel", kernel) 40 | 41 | p = kernel.shape[0] - factor 42 | 43 | pad0 = (p + 1) // 2 + factor - 1 44 | pad1 = p // 2 45 | 46 | self.pad = (pad0, pad1) 47 | 48 | def forward(self, input): 49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 50 | 51 | return out 52 | 53 | 54 | class Downsample(nn.Module): 55 | def __init__(self, kernel, factor=2): 56 | super().__init__() 57 | 58 | self.factor = factor 59 | kernel = make_kernel(kernel) 60 | self.register_buffer("kernel", kernel) 61 | 62 | p = kernel.shape[0] - factor 63 | 64 | pad0 = (p + 1) // 2 65 | pad1 = p // 2 66 | 67 | self.pad = (pad0, pad1) 68 | 69 | def forward(self, input): 70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 71 | 72 | return out 73 | 74 | 75 | class Blur(nn.Module): 76 | def __init__(self, kernel, pad, upsample_factor=1): 77 | super().__init__() 78 | 79 | kernel = make_kernel(kernel) 80 | 81 | if upsample_factor > 1: 82 | kernel = kernel * (upsample_factor ** 2) 83 | 84 | 
self.register_buffer("kernel", kernel) 85 | 86 | self.pad = pad 87 | 88 | def forward(self, input): 89 | out = upfirdn2d(input, self.kernel, pad=self.pad) 90 | 91 | return out 92 | 93 | 94 | class EqualConv2d(nn.Module): 95 | def __init__( 96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True 97 | ): 98 | super().__init__() 99 | 100 | self.weight = nn.Parameter( 101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 102 | ) 103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 104 | 105 | self.stride = stride 106 | self.padding = padding 107 | 108 | if bias: 109 | self.bias = nn.Parameter(torch.zeros(out_channel)) 110 | 111 | else: 112 | self.bias = None 113 | 114 | def forward(self, input): 115 | out = conv2d_gradfix.conv2d( 116 | input, 117 | self.weight * self.scale, 118 | bias=self.bias, 119 | stride=self.stride, 120 | padding=self.padding, 121 | ) 122 | 123 | return out 124 | 125 | def __repr__(self): 126 | return ( 127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]}," 128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})" 129 | ) 130 | 131 | 132 | class EqualLinear(nn.Module): 133 | def __init__( 134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 135 | ): 136 | super().__init__() 137 | 138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 139 | 140 | if bias: 141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 142 | 143 | else: 144 | self.bias = None 145 | 146 | self.activation = activation 147 | 148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 149 | self.lr_mul = lr_mul 150 | 151 | def forward(self, input): 152 | if self.activation: 153 | out = F.linear(input, self.weight * self.scale) 154 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 155 | 156 | else: 157 | out = F.linear( 158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 159 | ) 160 | 161 | return out 162 | 163 | def __repr__(self): 164 | return ( 165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})" 166 | ) 167 | 168 | 169 | class ModulatedConv2d(nn.Module): 170 | def __init__( 171 | self, 172 | in_channel, 173 | out_channel, 174 | kernel_size, 175 | style_dim, 176 | demodulate=True, 177 | upsample=False, 178 | downsample=False, 179 | blur_kernel=[1, 3, 3, 1], 180 | fused=True, 181 | ): 182 | super().__init__() 183 | 184 | self.eps = 1e-8 185 | self.kernel_size = kernel_size 186 | self.in_channel = in_channel 187 | self.out_channel = out_channel 188 | self.upsample = upsample 189 | self.downsample = downsample 190 | 191 | if upsample: 192 | factor = 2 193 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 194 | pad0 = (p + 1) // 2 + factor - 1 195 | pad1 = p // 2 + 1 196 | 197 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 198 | 199 | if downsample: 200 | factor = 2 201 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 202 | pad0 = (p + 1) // 2 203 | pad1 = p // 2 204 | 205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 206 | 207 | fan_in = in_channel * kernel_size ** 2 208 | self.scale = 1 / math.sqrt(fan_in) 209 | self.padding = kernel_size // 2 210 | 211 | self.weight = nn.Parameter( 212 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 213 | ) 214 | 215 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 216 | 217 | self.demodulate = demodulate 218 | self.fused = fused 219 | 220 | def __repr__(self): 221 | return ( 222 | 
f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, " 223 | f"upsample={self.upsample}, downsample={self.downsample})" 224 | ) 225 | 226 | def forward(self, input, style): 227 | batch, in_channel, height, width = input.shape 228 | 229 | if not self.fused: 230 | weight = self.scale * self.weight.squeeze(0) 231 | style = self.modulation(style) 232 | 233 | if self.demodulate: 234 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1) 235 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt() 236 | 237 | input = input * style.reshape(batch, in_channel, 1, 1) 238 | 239 | if self.upsample: 240 | weight = weight.transpose(0, 1) 241 | out = conv2d_gradfix.conv_transpose2d( 242 | input, weight, padding=0, stride=2 243 | ) 244 | out = self.blur(out) 245 | 246 | elif self.downsample: 247 | input = self.blur(input) 248 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2) 249 | 250 | else: 251 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding) 252 | 253 | if self.demodulate: 254 | out = out * dcoefs.view(batch, -1, 1, 1) 255 | 256 | return out 257 | 258 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1) 259 | weight = self.scale * self.weight * style 260 | 261 | if self.demodulate: 262 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 263 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 264 | 265 | weight = weight.view( 266 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 267 | ) 268 | 269 | if self.upsample: 270 | input = input.view(1, batch * in_channel, height, width) 271 | weight = weight.view( 272 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 273 | ) 274 | weight = weight.transpose(1, 2).reshape( 275 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 276 | ) 277 | out = conv2d_gradfix.conv_transpose2d( 278 | input, weight, padding=0, stride=2, groups=batch 279 | ) 280 | _, _, height, width = out.shape 281 | out = out.view(batch, self.out_channel, height, width) 282 | out = self.blur(out) 283 | 284 | elif self.downsample: 285 | input = self.blur(input) 286 | _, _, height, width = input.shape 287 | input = input.view(1, batch * in_channel, height, width) 288 | out = conv2d_gradfix.conv2d( 289 | input, weight, padding=0, stride=2, groups=batch 290 | ) 291 | _, _, height, width = out.shape 292 | out = out.view(batch, self.out_channel, height, width) 293 | 294 | else: 295 | input = input.view(1, batch * in_channel, height, width) 296 | out = conv2d_gradfix.conv2d( 297 | input, weight, padding=self.padding, groups=batch 298 | ) 299 | _, _, height, width = out.shape 300 | out = out.view(batch, self.out_channel, height, width) 301 | 302 | return out 303 | 304 | 305 | class NoiseInjection(nn.Module): 306 | def __init__(self): 307 | super().__init__() 308 | 309 | self.weight = nn.Parameter(torch.zeros(1)) 310 | 311 | def forward(self, image, noise=None): 312 | if noise is None: 313 | batch, _, height, width = image.shape 314 | noise = image.new_empty(batch, 1, height, width).normal_() 315 | 316 | return image + self.weight * noise 317 | 318 | 319 | class ConstantInput(nn.Module): 320 | def __init__(self, channel, size=4): 321 | super().__init__() 322 | 323 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 324 | 325 | def forward(self, input): 326 | batch = input.shape[0] 327 | out = self.input.repeat(batch, 1, 1, 1) 328 | 329 | return out 330 | 331 | 332 | class StyledConv(nn.Module): 333 | def 
__init__( 334 | self, 335 | in_channel, 336 | out_channel, 337 | kernel_size, 338 | style_dim, 339 | upsample=False, 340 | blur_kernel=[1, 3, 3, 1], 341 | demodulate=True, 342 | ): 343 | super().__init__() 344 | 345 | self.conv = ModulatedConv2d( 346 | in_channel, 347 | out_channel, 348 | kernel_size, 349 | style_dim, 350 | upsample=upsample, 351 | blur_kernel=blur_kernel, 352 | demodulate=demodulate, 353 | ) 354 | 355 | self.noise = NoiseInjection() 356 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 357 | # self.activate = ScaledLeakyReLU(0.2) 358 | self.activate = FusedLeakyReLU(out_channel) 359 | 360 | def forward(self, input, style, noise=None): 361 | out = self.conv(input, style) 362 | out = self.noise(out, noise=noise) 363 | # out = out + self.bias 364 | out = self.activate(out) 365 | 366 | return out 367 | 368 | 369 | class ToRGB(nn.Module): 370 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]): 371 | super().__init__() 372 | 373 | if upsample: 374 | self.upsample = Upsample(blur_kernel) 375 | 376 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False) 377 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) 378 | 379 | def forward(self, input, style, skip=None): 380 | out = self.conv(input, style) 381 | out = out + self.bias 382 | 383 | if skip is not None: 384 | skip = self.upsample(skip) 385 | 386 | out = out + skip 387 | 388 | return out 389 | 390 | def append_if(condition, var, elem): 391 | if (condition): 392 | var.append(elem) 393 | 394 | class Generator(nn.Module): 395 | def __init__( 396 | self, 397 | size, 398 | style_dim, 399 | n_mlp, 400 | channel_multiplier=2, 401 | blur_kernel=[1, 3, 3, 1], 402 | lr_mlp=0.01, 403 | ): 404 | super().__init__() 405 | 406 | self.size = size 407 | 408 | self.style_dim = style_dim 409 | 410 | layers = [PixelNorm()] 411 | 412 | for i in range(n_mlp): 413 | layers.append( 414 | EqualLinear( 415 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu" 416 | ) 417 | ) 418 | 419 | self.style = nn.Sequential(*layers) 420 | 421 | self.channels = { 422 | 4: 512, 423 | 8: 512, 424 | 16: 512, 425 | 32: 512, 426 | 64: 256 * channel_multiplier, 427 | 128: 128 * channel_multiplier, 428 | 256: 64 * channel_multiplier, 429 | 512: 32 * channel_multiplier, 430 | 1024: 16 * channel_multiplier, 431 | } 432 | 433 | self.input = ConstantInput(self.channels[4]) 434 | self.conv1 = StyledConv( 435 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 436 | ) 437 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False) 438 | 439 | self.log_size = int(math.log(size, 2)) 440 | self.num_layers = (self.log_size - 2) * 2 + 1 441 | 442 | self.convs = nn.ModuleList() 443 | self.upsamples = nn.ModuleList() 444 | self.to_rgbs = nn.ModuleList() 445 | self.noises = nn.Module() 446 | 447 | in_channel = self.channels[4] 448 | 449 | for layer_idx in range(self.num_layers): 450 | res = (layer_idx + 5) // 2 451 | shape = [1, 1, 2 ** res, 2 ** res] 452 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape)) 453 | 454 | for i in range(3, self.log_size + 1): 455 | out_channel = self.channels[2 ** i] 456 | 457 | self.convs.append( 458 | StyledConv( 459 | in_channel, 460 | out_channel, 461 | 3, 462 | style_dim, 463 | upsample=True, 464 | blur_kernel=blur_kernel, 465 | ) 466 | ) 467 | 468 | self.convs.append( 469 | StyledConv( 470 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 471 | ) 472 | ) 473 | 474 | 
self.to_rgbs.append(ToRGB(out_channel, style_dim)) 475 | 476 | in_channel = out_channel 477 | 478 | self.n_latent = self.log_size * 2 - 2 479 | 480 | def make_noise(self): 481 | device = self.input.input.device 482 | 483 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 484 | 485 | for i in range(3, self.log_size + 1): 486 | for _ in range(2): 487 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 488 | 489 | return noises 490 | 491 | def mean_latent(self, n_latent): 492 | latent_in = torch.randn( 493 | n_latent, self.style_dim, device=self.input.input.device 494 | ) 495 | latent = self.style(latent_in).mean(0, keepdim=True) 496 | 497 | return latent 498 | 499 | def get_latent(self, input): 500 | return self.style(input) 501 | 502 | def forward( 503 | self, 504 | styles, 505 | return_latents=False, 506 | inject_index=None, 507 | truncation=1, 508 | truncation_latent=None, 509 | input_is_latent=False, 510 | noise=None, 511 | randomize_noise=True, 512 | return_features=False, 513 | ): 514 | if not input_is_latent: 515 | styles = [self.style(s) for s in styles] 516 | 517 | if noise is None: 518 | if randomize_noise: 519 | noise = [None] * self.num_layers 520 | else: 521 | noise = [ 522 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers) 523 | ] 524 | 525 | if truncation < 1: 526 | style_t = [] 527 | 528 | for style in styles: 529 | style_t.append( 530 | truncation_latent + truncation * (style - truncation_latent) 531 | ) 532 | 533 | styles = style_t 534 | 535 | if len(styles) < 2: 536 | inject_index = self.n_latent 537 | 538 | if styles[0].ndim < 3: 539 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 540 | 541 | else: 542 | latent = styles[0] 543 | 544 | else: 545 | if inject_index is None: 546 | inject_index = random.randint(1, self.n_latent - 1) 547 | 548 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 549 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 550 | 551 | latent = torch.cat([latent, latent2], 1) 552 | 553 | features = [] 554 | out = self.input(latent) 555 | append_if(return_features, features, out) 556 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 557 | append_if(return_features, features, out) 558 | 559 | skip = self.to_rgb1(out, latent[:, 1]) 560 | 561 | i = 1 562 | for conv1, conv2, noise1, noise2, to_rgb in zip( 563 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 564 | ): 565 | out = conv1(out, latent[:, i], noise=noise1) 566 | append_if(return_features, features, out) 567 | out = conv2(out, latent[:, i + 1], noise=noise2) 568 | append_if(return_features, features, out) 569 | skip = to_rgb(out, latent[:, i + 2], skip) 570 | 571 | i += 2 572 | 573 | image = skip 574 | 575 | if return_latents: 576 | return image, latent, features 577 | 578 | else: 579 | return image, None, features 580 | 581 | 582 | class ConvLayer(nn.Sequential): 583 | def __init__( 584 | self, 585 | in_channel, 586 | out_channel, 587 | kernel_size, 588 | downsample=False, 589 | blur_kernel=[1, 3, 3, 1], 590 | bias=True, 591 | activate=True, 592 | ): 593 | layers = [] 594 | 595 | if downsample: 596 | factor = 2 597 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 598 | pad0 = (p + 1) // 2 599 | pad1 = p // 2 600 | 601 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 602 | 603 | stride = 2 604 | self.padding = 0 605 | 606 | else: 607 | stride = 1 608 | self.padding = kernel_size // 2 609 | 610 | layers.append( 611 | EqualConv2d( 612 | in_channel, 613 | out_channel, 614 
| kernel_size, 615 | padding=self.padding, 616 | stride=stride, 617 | bias=bias and not activate, 618 | ) 619 | ) 620 | 621 | if activate: 622 | layers.append(FusedLeakyReLU(out_channel, bias=bias)) 623 | 624 | super().__init__(*layers) 625 | 626 | 627 | class ResBlock(nn.Module): 628 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]): 629 | super().__init__() 630 | 631 | self.conv1 = ConvLayer(in_channel, in_channel, 3) 632 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True) 633 | 634 | self.skip = ConvLayer( 635 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False 636 | ) 637 | 638 | def forward(self, input): 639 | out = self.conv1(input) 640 | out = self.conv2(out) 641 | 642 | skip = self.skip(input) 643 | out = (out + skip) / math.sqrt(2) 644 | 645 | return out 646 | 647 | 648 | class Discriminator(nn.Module): 649 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]): 650 | super().__init__() 651 | 652 | channels = { 653 | 4: 512, 654 | 8: 512, 655 | 16: 512, 656 | 32: 512, 657 | 64: 256 * channel_multiplier, 658 | 128: 128 * channel_multiplier, 659 | 256: 64 * channel_multiplier, 660 | 512: 32 * channel_multiplier, 661 | 1024: 16 * channel_multiplier, 662 | } 663 | 664 | convs = [ConvLayer(3, channels[size], 1)] 665 | 666 | log_size = int(math.log(size, 2)) 667 | 668 | in_channel = channels[size] 669 | 670 | for i in range(log_size, 2, -1): 671 | out_channel = channels[2 ** (i - 1)] 672 | 673 | convs.append(ResBlock(in_channel, out_channel, blur_kernel)) 674 | 675 | in_channel = out_channel 676 | 677 | self.convs = nn.Sequential(*convs) 678 | 679 | self.stddev_group = 4 680 | self.stddev_feat = 1 681 | 682 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3) 683 | self.final_linear = nn.Sequential( 684 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"), 685 | EqualLinear(channels[4], 1), 686 | ) 687 | 688 | def forward(self, input): 689 | out = self.convs(input) 690 | 691 | batch, channel, height, width = out.shape 692 | group = min(batch, self.stddev_group) 693 | stddev = out.view( 694 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width 695 | ) 696 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8) 697 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2) 698 | stddev = stddev.repeat(group, 1, height, width) 699 | out = torch.cat([out, stddev], 1) 700 | 701 | out = self.final_conv(out) 702 | 703 | out = out.view(batch, -1) 704 | out = self.final_linear(out) 705 | 706 | return out 707 | 708 | --------------------------------------------------------------------------------
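
A minimal headless sketch of the `DragGAN` class from `draggan.py`, for running the drag optimization without the Dear PyGui front end in `gui.py`. The checkpoint filename and the point coordinates below are assumptions for illustration: use the 256px FFHQ weights linked in the README, and pass points as a flat list of alternating handle/target pairs in (x, y) pixel coordinates, exactly as `gui.py` builds them.

```python
# Headless DragGAN sketch: generate an image from a seed, then run a fixed
# number of drag steps. The checkpoint path is an assumption; point it at the
# 256px FFHQ weights linked in the README.
import numpy as np
from draggan import DragGAN

model = DragGAN('cuda')                    # or 'cpu', as offered in the GUI's device combo
model.load_ckpt('stylegan2-ffhq-256.pt')   # assumed local filename of the downloaded weights

image = model.generate_image(512)          # flat float32 RGB values in [0, 1]
image = image.reshape(256, 256, 3)         # recover H x W x C for saving or inspection

# One handle/target pair: drag the feature at (100, 120) toward (140, 120).
points = [np.array([100, 120]), np.array([140, 120])]

for _ in range(50):                        # the number of drag steps is a free choice
    ok, ret = model.step(points)
    if not ok:                             # step() returns False when no complete point pair is selected
        break
    points, image = ret                    # tracked handle points and the updated flat image
```

Each `step` call optimizes only the first `2 * layer_index` entries of the w+ latent with a single motion-supervision update, then re-locates every handle point by nearest-neighbour feature matching in `point_tracking`, which is what `dragging_thread` in `gui.py` does once per loop iteration.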