├── .gitignore
├── LICENSE
├── README.md
├── UI.png
├── draggan.png
├── draggan.py
├── generate.py
├── gui.py
├── op
│   ├── __init__.py
│   ├── conv2d_gradfix.py
│   ├── fused_act.py
│   └── upfirdn2d.py
├── sample
│   ├── .gitignore
│   ├── end.png
│   ├── multi-point.png
│   └── start.png
└── stylegan2.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Mr. Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # DragGAN
4 | Implementation of [DragGAN: Interactive Point-based Manipulation on the Generative Image Manifold](https://arxiv.org/abs/2305.10973).
5 |
6 | ```shell
7 | # gui
8 | pip install dearpygui
9 | # run demo
10 | python gui.py
11 | ```
12 |
13 |
14 |
15 | 
16 |
17 | # TODO
18 | - [x] GUI
19 | - [x] drag it
20 | - [ ] load real image
21 | - [ ] mask
22 |
23 | # StyleGAN2 Pre-Trained Model
24 | Rosinality's pre-trained model (256px) trained on FFHQ for 550k iterations \[[Link](https://drive.google.com/open?id=1PQutd-JboOCOZqmd95XWxWrO8gGEvRcO)\].
25 |
26 | # References
27 | - https://github.com/rosinality/stylegan2-pytorch
28 |
--------------------------------------------------------------------------------
/UI.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/UI.png
--------------------------------------------------------------------------------
/draggan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/draggan.png
--------------------------------------------------------------------------------
/draggan.py:
--------------------------------------------------------------------------------
1 | import torch, math
2 | import numpy as np
3 | from stylegan2 import Generator
4 | import torch.nn.functional as functional
5 |
6 | def linear(feature, p0, p1, d, axis=0):
7 | f0 = feature[..., p0[0], p0[1]]
8 | f1 = feature[..., p1[0], p1[1]]
9 | weight = abs(d[axis])
10 | f = (1 - weight) * f0 + weight * f1
11 | return f
12 |
13 | def bilinear(feature, qi, d):
14 | y0, x0 = qi
15 | dy, dx = d
16 | d = (dx, dy)
17 | dx = 1 if dx >= 0 else -1
18 | dy = 1 if dy >= 0 else -1
19 | x1 = x0 + dx
20 | y1 = y0 + dy
21 | fx1 = linear(feature, (x0, y0), (x1, y0), d, axis=0)
22 | fx2 = linear(feature, (x0, y1), (x1, y1), d, axis=0)
23 | weight = abs(d[1])
24 | fx = (1 - weight) * fx1 + weight * fx2
25 | return fx
26 |
27 | def motion_supervision(F0, F, pi, ti, r1=3, M=None):
28 | F = functional.interpolate(F, [256, 256], mode="bilinear")
29 | F0 = functional.interpolate(F0, [256, 256], mode="bilinear")
30 |
31 | dw, dh = ti[0] - pi[0], ti[1] - pi[1]
32 | norm = math.sqrt(dw**2 + dh**2)
33 | w = (max(0, pi[0] - r1), min(256, pi[0] + r1))
34 | h = (max(0, pi[1] - r1), min(256, pi[1] + r1))
35 | d = torch.tensor(
36 | (dw / norm, dh / norm),
37 | dtype=F.dtype, device=F.device,
38 | ).reshape(1, 1, 1, 2)
39 | grid_h, grid_w = torch.meshgrid(
40 | torch.tensor(range(h[0], h[1])),
41 | torch.tensor(range(w[0], w[1])),
42 | indexing='xy',
43 | )
44 | grid = torch.stack([grid_w, grid_h], dim=-1).unsqueeze(0)
45 |     grid = (grid / 255 - 0.5) * 2  # map pixel coordinates to [-1, 1] for grid_sample
46 |     grid_d = grid + 2 * d / 255  # the same patch shifted one pixel step toward the target
47 |
48 | sample = functional.grid_sample(
49 | F, grid, mode='bilinear', padding_mode='border',
50 | align_corners=True,
51 | )
52 | sample_d = functional.grid_sample(
53 | F, grid_d, mode='bilinear', padding_mode='border',
54 | align_corners=True,
55 | )
56 |
57 | loss = (sample_d - sample.detach()).abs().mean(1).sum()
58 |
59 | return loss
60 |
61 | @torch.no_grad()
62 | def point_tracking(F0, F, pi, p0, r2=12):
63 | F = functional.interpolate(F, [256, 256], mode="bilinear")
64 | F0 = functional.interpolate(F0, [256, 256], mode="bilinear")
65 | x = (max(0, pi[0] - r2), min(256, pi[0] + r2))
66 | y = (max(0, pi[1] - r2), min(256, pi[1] + r2))
67 | base = F0[..., p0[1], p0[0]].reshape(1, -1, 1, 1)
68 | diff = (F[..., y[0]:y[1], x[0]:x[1]] - base).abs().mean(1)
69 | idx = diff.argmin()
70 | dy = int(idx / (x[1] - x[0]))
71 | dx = int(idx % (x[1] - x[0]))
72 | npi = (x[0] + dx, y[0] + dy)
73 | return npi
74 |
75 | def requires_grad(model, flag=True):
76 | for p in model.parameters():
77 | p.requires_grad = flag
78 |
79 | class DragGAN():
80 | def __init__(self, device, layer_index=6):
81 | self.generator = Generator(256, 512, 8).to(device)
82 | requires_grad(self.generator, False)
83 | self._device = device
84 | self.layer_index = layer_index
85 | self.latent = None
86 | self.F0 = None
87 | self.optimizer = None
88 | self.p0 = None
89 |
90 | def load_ckpt(self, path):
91 | print(f'loading checkpoint from {path}')
92 | ckpt = torch.load(path, map_location=self._device)
93 | self.generator.load_state_dict(ckpt["g_ema"], strict=False)
94 |         print('loading checkpoint succeeded!')
95 |
96 | def to(self, device):
97 | if self._device != device:
98 | self.generator = self.generator.to(device)
99 | self._device = device
100 |
101 | @torch.no_grad()
102 | def generate_image(self, seed):
103 | z = torch.from_numpy(
104 | np.random.RandomState(seed).randn(1, 512).astype(np.float32)
105 | ).to(self._device)
106 | image, self.latent, self.F0 = self.generator(
107 | [z], return_latents=True, return_features=True, randomize_noise=False,
108 | )
109 | image, self.F0 = image[0], self.F0[self.layer_index*2+1].detach()
110 | image = image.detach().cpu().permute(1, 2, 0).numpy()
111 | image = (image / 2 + 0.5).clip(0, 1).reshape(-1)
112 | return image
113 |
114 | @property
115 | def device(self):
116 | return self._device
117 |
118 | def __call__(self, *args, **kwargs):
119 | return self.generator(*args, **kwargs)
120 |
121 | def step(self, points):
122 | if self.optimizer is None:
123 | len_pts = (len(points) // 2) * 2
124 | if len_pts == 0:
125 | print('Select at least one pair of points')
126 | return False, None
127 | self.trainable = self.latent[:, :self.layer_index*2, :].detach(
128 | ).requires_grad_(True)
129 | self.fixed = self.latent[:, self.layer_index*2:, :].detach(
130 | ).requires_grad_(False)
131 | self.optimizer = torch.optim.Adam([self.trainable], lr=2e-3)
132 | points = points[:len_pts]
133 |             self.p0 = points[::2]  # initial handle points; even entries are handles, odd entries their targets
134 | self.optimizer.zero_grad()
135 | trainable_fixed = torch.cat([self.trainable, self.fixed], dim=1)
136 | image, _, features = self.generator(
137 | [trainable_fixed], input_is_latent=True,
138 | return_features=True, randomize_noise=False,
139 | )
140 | features = features[self.layer_index*2+1]
141 | loss = 0
142 | for i in range(len(self.p0)):
143 | loss += motion_supervision(self.F0, features, points[2*i], points[2*i+1])
144 | print(loss)
145 | loss.backward()
146 | self.optimizer.step()
147 | image, _, features = self.generator(
148 | [trainable_fixed], input_is_latent=True,
149 | return_features=True, randomize_noise=False,
150 | )
151 | features = features[self.layer_index*2+1]
152 | image = image[0].detach().cpu().permute(1, 2, 0).numpy()
153 | image = (image / 2 + 0.5).clip(0, 1).reshape(-1)
154 | for i in range(len(self.p0)):
155 | points[2*i] = point_tracking(self.F0, features, points[2*i], self.p0[i])
156 | return True, (points, image)
157 |
--------------------------------------------------------------------------------
/generate.py:
--------------------------------------------------------------------------------
1 | import argparse, torch
2 | from torchvision import utils
3 | from stylegan2 import Generator
4 | from tqdm import tqdm
5 |
6 | def generate(args, g_ema, device, mean_latent):
7 | with torch.no_grad():
8 | g_ema.eval()
9 | for i in tqdm(range(args.pics)):
10 | sample_z = torch.randn(args.sample, args.latent, device=device)
11 |
12 | sample, _, _ = g_ema(
13 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
14 | )
15 |
16 | utils.save_image(
17 | sample,
18 | f"sample/{str(i).zfill(6)}.png",
19 | nrow=1,
20 | normalize=True,
21 | range=(-1, 1),
22 | )
23 |
24 | if __name__ == "__main__":
25 | parser = argparse.ArgumentParser(description="Generate samples from the generator")
26 |
27 | parser.add_argument(
28 | "--size", type=int, default=256, help="output image size of the generator"
29 | )
30 | parser.add_argument(
31 |         "--device", type=str, default='cuda', help="device to run the generator on"
32 | )
33 | parser.add_argument(
34 | "--sample",
35 | type=int,
36 | default=1,
37 | help="number of samples to be generated for each image",
38 | )
39 | parser.add_argument(
40 | "--pics", type=int, default=20, help="number of images to be generated"
41 | )
42 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
43 | parser.add_argument(
44 | "--truncation_mean",
45 | type=int,
46 | default=4096,
47 | help="number of vectors to calculate mean for the truncation",
48 | )
49 | parser.add_argument(
50 | "--ckpt",
51 | type=str,
52 | default="stylegan2-ffhq-config-f.pt",
53 | help="path to the model checkpoint",
54 | )
55 | parser.add_argument(
56 | "--channel_multiplier",
57 | type=int,
58 | default=2,
59 | help="channel multiplier of the generator. config-f = 2, else = 1",
60 | )
61 |
62 | args = parser.parse_args()
63 |
64 | args.latent = 512
65 | args.n_mlp = 8
66 | device = torch.device(args.device)
67 |
68 | g_ema = Generator(
69 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
70 | ).to(device)
71 | checkpoint = torch.load(args.ckpt, map_location=device)
72 |
73 | g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
74 |
75 | if args.truncation < 1:
76 | with torch.no_grad():
77 | mean_latent = g_ema.mean_latent(args.truncation_mean)
78 | else:
79 | mean_latent = None
80 |
81 | generate(args, g_ema, device, mean_latent)
82 |
--------------------------------------------------------------------------------
/gui.py:
--------------------------------------------------------------------------------
1 | import dearpygui.dearpygui as dpg
2 | import numpy as np
3 | from draggan import DragGAN
4 | from array import array
5 | import threading
6 |
7 | add_point = 0
8 | point_color = [(1, 0, 0), (0, 0, 1)]
9 | points, steps = [], 0
10 | dragging = False
11 | # mvFormat_Float_rgb not currently supported on macOS
12 | # More details: https://dearpygui.readthedocs.io/en/latest/documentation/textures.html#formats
13 | texture_format = dpg.mvFormat_Float_rgba
14 | image_width, image_height, rgb_channel, rgba_channel = 256, 256, 3, 4
15 | image_pixels = image_height * image_width
16 | model = DragGAN('cpu')
17 |
18 | dpg.create_context()
19 | dpg.create_viewport(title='DragGAN', width=800, height=650)
20 |
21 | raw_data_size = image_width * image_height * rgba_channel
22 | raw_data = array('f', [1] * raw_data_size)
23 | with dpg.texture_registry(show=False):
24 | dpg.add_raw_texture(
25 | width=image_width, height=image_height, default_value=raw_data,
26 | format=texture_format, tag="image"
27 | )
28 |
29 | def update_image(new_image):
30 | # Convert image data (rgb) to raw_data (rgba)
31 | for i in range(0, image_pixels):
32 | rd_base, im_base = i * rgba_channel, i * rgb_channel
33 | raw_data[rd_base:rd_base + rgb_channel] = array(
34 | 'f', new_image[im_base:im_base + rgb_channel]
35 | )
36 |
37 | def generate_image(sender, app_data, user_data):
38 | seed = dpg.get_value('seed')
39 | image = model.generate_image(seed)
40 | update_image(image)
41 |
42 | def change_device(sender, app_data):
43 | model.to(app_data)
44 |
45 | def dragging_thread():
46 | global points, steps, dragging
47 | while (dragging):
48 | status, ret = model.step(points)
49 | if status:
50 | points, image = ret
51 | else:
52 | dragging = False
53 | return
54 | update_image(image)
55 | for i in range(len(points)):
56 | draw_point(*points[i], point_color[i%2])
57 | steps += 1
58 | dpg.set_value('steps', f'steps: {steps}')
59 |
60 | width, height = 260, 200
61 | posx, posy = 0, 0
62 | with dpg.window(
63 | label='Network & Latent', width=width, height=height, pos=(posx, posy),
64 | no_move=True, no_close=True, no_collapse=True, no_resize=True,
65 | ):
66 | dpg.add_text('device', pos=(5, 20))
67 | dpg.add_combo(
68 | ('cpu', 'cuda'), default_value='cpu', width=60, pos=(70, 20),
69 | callback=change_device,
70 | )
71 |
72 | dpg.add_text('weight', pos=(5, 40))
73 |
74 | def select_cb(sender, app_data):
75 | selections = app_data['selections']
76 | if selections:
77 | for fn in selections:
78 | model.load_ckpt(selections[fn])
79 | break
80 |
81 | def cancel_cb(sender, app_data):
82 | ...
83 |
84 | with dpg.file_dialog(
85 | directory_selector=False, show=False, callback=select_cb, id='weight selector',
86 | cancel_callback=cancel_cb, width=700 ,height=400
87 | ):
88 | dpg.add_file_extension('.*')
89 | dpg.add_button(
90 | label="select weight", callback=lambda: dpg.show_item("weight selector"),
91 | pos=(70, 40),
92 | )
93 |
94 | dpg.add_text('latent', pos=(5, 60))
95 | dpg.add_input_int(
96 | label='seed', width=100, pos=(70, 60), tag='seed', default_value=512,
97 | )
98 | dpg.add_input_float(
99 | label='step size', width=54, pos=(70, 80), step=-1, default_value=0.002,
100 | )
101 | dpg.add_button(label="reset", width=54, pos=(70, 100), callback=None)
102 | dpg.add_radio_button(
103 | items=('w', 'w+'), pos=(130, 100), horizontal=True, default_value='w+',
104 | )
105 | dpg.add_button(label="generate", pos=(70, 120), callback=generate_image)
106 |
107 | posy += height + 2
108 | with dpg.window(
109 | label='Drag', width=width, height=height, pos=(posx, posy),
110 | no_move=True, no_close=True, no_collapse=True, no_resize=True,
111 | ):
112 | def add_point_cb():
113 | global add_point
114 | add_point += 2
115 |
116 | def reset_point_cb():
117 | global points
118 | points = []
119 |
120 | def start_cb():
121 | global dragging
122 | if dragging: return
123 | dragging = True
124 | threading.Thread(target=dragging_thread).start()
125 |
126 | def stop_cb():
127 | global dragging
128 | dragging = False
129 | print('stop dragging...')
130 |
131 | dpg.add_text('drag', pos=(5, 20))
132 | dpg.add_button(label="add point", width=80, pos=(70, 20), callback=add_point_cb)
133 | dpg.add_button(label="reset point", width=80, pos=(155, 20), callback=reset_point_cb)
134 | dpg.add_button(label="start", width=80, pos=(70, 40), callback=start_cb)
135 | dpg.add_button(label="stop", width=80, pos=(155, 40), callback=stop_cb)
136 | dpg.add_text('steps: 0', tag='steps', pos=(70, 60))
137 |
138 | dpg.add_text('mask', pos=(5, 80))
139 | dpg.add_button(label="fixed area", width=80, pos=(70, 80), callback=None)
140 | dpg.add_button(label="reset mask", width=80, pos=(70, 100), callback=None)
141 | dpg.add_checkbox(label='show mask', pos=(155, 100), default_value=False)
142 | dpg.add_input_int(label='radius', width=100, pos=(70, 120), default_value=50)
143 | dpg.add_input_float(label='lambda', width=100, pos=(70, 140), default_value=20)
144 |
145 | posy += height + 2
146 | with dpg.window(
147 | label='Capture', width=width, height=height, pos=(posx, posy),
148 | no_move=True, no_close=True, no_collapse=True, no_resize=True,
149 | ):
150 | dpg.add_text('capture', pos=(5, 20))
151 | dpg.add_input_text(pos=(70, 20), default_value='capture')
152 | dpg.add_button(label="save image", width=80, pos=(70, 40), callback=None)
153 |
154 | def draw_point(x, y, color):
155 | x_start, x_end = max(0, x - 2), min(image_width, x + 2)
156 | y_start, y_end = max(0, y - 2), min(image_height, y + 2)
157 | for x in range(x_start, x_end):
158 | for y in range(y_start, y_end):
159 | offset = (y * image_width + x) * rgba_channel
160 | raw_data[offset:offset + rgb_channel] = array('f', color[:rgb_channel])
161 |
162 | def select_point(sender, app_data):
163 | global add_point, points
164 | if add_point <= 0: return
165 | ms_pos = dpg.get_mouse_pos(local=False)
166 | id_pos = dpg.get_item_pos('image_data')
167 | iw_pos = dpg.get_item_pos('Image Win')
168 | ix = int(ms_pos[0]-id_pos[0]-iw_pos[0])
169 | iy = int(ms_pos[1]-id_pos[1]-iw_pos[1])
170 | draw_point(ix, iy, point_color[add_point % 2])
171 | points.append(np.array([ix, iy]))
172 | print(points)
173 | add_point -= 1
174 |
175 | posx, posy = 2 + width, 0
176 | with dpg.window(
177 | label='Image', pos=(posx, posy), tag='Image Win',
178 | no_move=True, no_close=True, no_collapse=True, no_resize=True,
179 | ):
180 | dpg.add_image("image", show=True, tag='image_data', pos=(10, 30))
181 |
182 | with dpg.item_handler_registry(tag='double_clicked_handler'):
183 | dpg.add_item_double_clicked_handler(callback=select_point)
184 | dpg.bind_item_handler_registry("image_data", "double_clicked_handler")
185 |
186 | dpg.setup_dearpygui()
187 | dpg.show_viewport()
188 | dpg.start_dearpygui()
189 | dpg.destroy_context()
190 |
--------------------------------------------------------------------------------
/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import warnings
3 |
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 |
8 | enabled = True
9 | weight_gradients_disabled = False
10 |
11 |
12 | @contextlib.contextmanager
13 | def no_weight_gradients():
14 | global weight_gradients_disabled
15 |
16 | old = weight_gradients_disabled
17 | weight_gradients_disabled = True
18 | yield
19 | weight_gradients_disabled = old
20 |
21 |
22 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
23 | if could_use_op(input):
24 | return conv2d_gradfix(
25 | transpose=False,
26 | weight_shape=weight.shape,
27 | stride=stride,
28 | padding=padding,
29 | output_padding=0,
30 | dilation=dilation,
31 | groups=groups,
32 | ).apply(input, weight, bias)
33 |
34 | return F.conv2d(
35 | input=input,
36 | weight=weight,
37 | bias=bias,
38 | stride=stride,
39 | padding=padding,
40 | dilation=dilation,
41 | groups=groups,
42 | )
43 |
44 |
45 | def conv_transpose2d(
46 | input,
47 | weight,
48 | bias=None,
49 | stride=1,
50 | padding=0,
51 | output_padding=0,
52 | groups=1,
53 | dilation=1,
54 | ):
55 | if could_use_op(input):
56 | return conv2d_gradfix(
57 | transpose=True,
58 | weight_shape=weight.shape,
59 | stride=stride,
60 | padding=padding,
61 | output_padding=output_padding,
62 | groups=groups,
63 | dilation=dilation,
64 | ).apply(input, weight, bias)
65 |
66 | return F.conv_transpose2d(
67 | input=input,
68 | weight=weight,
69 | bias=bias,
70 | stride=stride,
71 | padding=padding,
72 | output_padding=output_padding,
73 | dilation=dilation,
74 | groups=groups,
75 | )
76 |
77 |
78 | def could_use_op(input):
79 | if (not enabled) or (not torch.backends.cudnn.enabled):
80 | return False
81 |
82 | if input.device.type != "cuda":
83 | return False
84 |
85 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
86 | return True
87 |
88 | warnings.warn(
89 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
90 | )
91 |
92 | return False
93 |
94 |
95 | def ensure_tuple(xs, ndim):
96 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
97 |
98 | return xs
99 |
100 |
101 | conv2d_gradfix_cache = dict()
102 |
103 |
104 | def conv2d_gradfix(
105 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
106 | ):
107 | ndim = 2
108 | weight_shape = tuple(weight_shape)
109 | stride = ensure_tuple(stride, ndim)
110 | padding = ensure_tuple(padding, ndim)
111 | output_padding = ensure_tuple(output_padding, ndim)
112 | dilation = ensure_tuple(dilation, ndim)
113 |
114 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
115 | if key in conv2d_gradfix_cache:
116 | return conv2d_gradfix_cache[key]
117 |
118 | common_kwargs = dict(
119 | stride=stride, padding=padding, dilation=dilation, groups=groups
120 | )
121 |
122 | def calc_output_padding(input_shape, output_shape):
123 | if transpose:
124 | return [0, 0]
125 |
126 | return [
127 | input_shape[i + 2]
128 | - (output_shape[i + 2] - 1) * stride[i]
129 | - (1 - 2 * padding[i])
130 | - dilation[i] * (weight_shape[i + 2] - 1)
131 | for i in range(ndim)
132 | ]
133 |
134 | class Conv2d(autograd.Function):
135 | @staticmethod
136 | def forward(ctx, input, weight, bias):
137 | if not transpose:
138 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
139 |
140 | else:
141 | out = F.conv_transpose2d(
142 | input=input,
143 | weight=weight,
144 | bias=bias,
145 | output_padding=output_padding,
146 | **common_kwargs,
147 | )
148 |
149 | ctx.save_for_backward(input, weight)
150 |
151 | return out
152 |
153 | @staticmethod
154 | def backward(ctx, grad_output):
155 | input, weight = ctx.saved_tensors
156 | grad_input, grad_weight, grad_bias = None, None, None
157 |
158 | if ctx.needs_input_grad[0]:
159 | p = calc_output_padding(
160 | input_shape=input.shape, output_shape=grad_output.shape
161 | )
162 | grad_input = conv2d_gradfix(
163 | transpose=(not transpose),
164 | weight_shape=weight_shape,
165 | output_padding=p,
166 | **common_kwargs,
167 | ).apply(grad_output, weight, None)
168 |
169 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
170 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
171 |
172 | if ctx.needs_input_grad[2]:
173 | grad_bias = grad_output.sum((0, 2, 3))
174 |
175 | return grad_input, grad_weight, grad_bias
176 |
177 | class Conv2dGradWeight(autograd.Function):
178 | @staticmethod
179 | def forward(ctx, grad_output, input):
180 | op = torch._C._jit_get_operation(
181 | "aten::cudnn_convolution_backward_weight"
182 | if not transpose
183 | else "aten::cudnn_convolution_transpose_backward_weight"
184 | )
185 | flags = [
186 | torch.backends.cudnn.benchmark,
187 | torch.backends.cudnn.deterministic,
188 | torch.backends.cudnn.allow_tf32,
189 | ]
190 | grad_weight = op(
191 | weight_shape,
192 | grad_output,
193 | input,
194 | padding,
195 | stride,
196 | dilation,
197 | groups,
198 | *flags,
199 | )
200 | ctx.save_for_backward(grad_output, input)
201 |
202 | return grad_weight
203 |
204 | @staticmethod
205 | def backward(ctx, grad_grad_weight):
206 | grad_output, input = ctx.saved_tensors
207 | grad_grad_output, grad_grad_input = None, None
208 |
209 | if ctx.needs_input_grad[0]:
210 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
211 |
212 | if ctx.needs_input_grad[1]:
213 | p = calc_output_padding(
214 | input_shape=input.shape, output_shape=grad_output.shape
215 | )
216 | grad_grad_input = conv2d_gradfix(
217 | transpose=(not transpose),
218 | weight_shape=weight_shape,
219 | output_padding=p,
220 | **common_kwargs,
221 | ).apply(grad_output, grad_grad_weight, None)
222 |
223 | return grad_grad_output, grad_grad_input
224 |
225 | conv2d_gradfix_cache[key] = Conv2d
226 |
227 | return Conv2d
228 |
--------------------------------------------------------------------------------
/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 |
8 | class FusedLeakyReLU(nn.Module):
9 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
10 | super().__init__()
11 |
12 | if bias:
13 | self.bias = nn.Parameter(torch.zeros(channel))
14 |
15 | else:
16 | self.bias = None
17 |
18 | self.negative_slope = negative_slope
19 | self.scale = scale
20 |
21 | def forward(self, input):
22 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
23 |
24 |
25 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
26 | if bias is not None:
27 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
28 | return (
29 | F.leaky_relu(
30 |                 input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
31 | )
32 | * scale
33 | )
34 |
35 | else:
36 |         return F.leaky_relu(input, negative_slope=negative_slope) * scale
37 |
--------------------------------------------------------------------------------
/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import torch
3 | from torch.nn import functional as F
4 |
5 |
6 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
7 | if not isinstance(up, abc.Iterable):
8 | up = (up, up)
9 |
10 | if not isinstance(down, abc.Iterable):
11 | down = (down, down)
12 |
13 | if len(pad) == 2:
14 | pad = (pad[0], pad[1], pad[0], pad[1])
15 |
16 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
17 |
18 | return out
19 |
20 |
21 | def upfirdn2d_native(
22 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
23 | ):
24 | _, channel, in_h, in_w = input.shape
25 | input = input.reshape(-1, in_h, in_w, 1)
26 |
27 | _, in_h, in_w, minor = input.shape
28 | kernel_h, kernel_w = kernel.shape
29 |
30 | out = input.view(-1, in_h, 1, in_w, 1, minor)
31 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
32 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
33 |
34 | out = F.pad(
35 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
36 | )
37 | out = out[
38 | :,
39 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
40 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
41 | :,
42 | ]
43 |
44 | out = out.permute(0, 3, 1, 2)
45 | out = out.reshape(
46 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
47 | )
48 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
49 | out = F.conv2d(out, w)
50 | out = out.reshape(
51 | -1,
52 | minor,
53 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
54 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
55 | )
56 | out = out.permute(0, 2, 3, 1)
57 | out = out[:, ::down_y, ::down_x, :]
58 |
59 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
60 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
61 |
62 | return out.view(-1, channel, out_h, out_w)
63 |
--------------------------------------------------------------------------------
/sample/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
--------------------------------------------------------------------------------
/sample/end.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/end.png
--------------------------------------------------------------------------------
/sample/multi-point.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/multi-point.png
--------------------------------------------------------------------------------
/sample/start.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiauZhang/DragGAN/65540aed91113772b662b0710553e6c5a9906a3d/sample/start.png
--------------------------------------------------------------------------------
/stylegan2.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import functools
4 | import operator
5 |
6 | import torch
7 | from torch import nn
8 | from torch.nn import functional as F
9 | from torch.autograd import Function
10 |
11 | from op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d, conv2d_gradfix
12 |
13 |
14 | class PixelNorm(nn.Module):
15 | def __init__(self):
16 | super().__init__()
17 |
18 | def forward(self, input):
19 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
20 |
21 |
22 | def make_kernel(k):
23 | k = torch.tensor(k, dtype=torch.float32)
24 |
25 | if k.ndim == 1:
26 | k = k[None, :] * k[:, None]
27 |
28 | k /= k.sum()
29 |
30 | return k
31 |
32 |
33 | class Upsample(nn.Module):
34 | def __init__(self, kernel, factor=2):
35 | super().__init__()
36 |
37 | self.factor = factor
38 | kernel = make_kernel(kernel) * (factor ** 2)
39 | self.register_buffer("kernel", kernel)
40 |
41 | p = kernel.shape[0] - factor
42 |
43 | pad0 = (p + 1) // 2 + factor - 1
44 | pad1 = p // 2
45 |
46 | self.pad = (pad0, pad1)
47 |
48 | def forward(self, input):
49 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
50 |
51 | return out
52 |
53 |
54 | class Downsample(nn.Module):
55 | def __init__(self, kernel, factor=2):
56 | super().__init__()
57 |
58 | self.factor = factor
59 | kernel = make_kernel(kernel)
60 | self.register_buffer("kernel", kernel)
61 |
62 | p = kernel.shape[0] - factor
63 |
64 | pad0 = (p + 1) // 2
65 | pad1 = p // 2
66 |
67 | self.pad = (pad0, pad1)
68 |
69 | def forward(self, input):
70 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
71 |
72 | return out
73 |
74 |
75 | class Blur(nn.Module):
76 | def __init__(self, kernel, pad, upsample_factor=1):
77 | super().__init__()
78 |
79 | kernel = make_kernel(kernel)
80 |
81 | if upsample_factor > 1:
82 | kernel = kernel * (upsample_factor ** 2)
83 |
84 | self.register_buffer("kernel", kernel)
85 |
86 | self.pad = pad
87 |
88 | def forward(self, input):
89 | out = upfirdn2d(input, self.kernel, pad=self.pad)
90 |
91 | return out
92 |
93 |
94 | class EqualConv2d(nn.Module):
95 | def __init__(
96 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True
97 | ):
98 | super().__init__()
99 |
100 | self.weight = nn.Parameter(
101 | torch.randn(out_channel, in_channel, kernel_size, kernel_size)
102 | )
103 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
104 |
105 | self.stride = stride
106 | self.padding = padding
107 |
108 | if bias:
109 | self.bias = nn.Parameter(torch.zeros(out_channel))
110 |
111 | else:
112 | self.bias = None
113 |
114 | def forward(self, input):
115 | out = conv2d_gradfix.conv2d(
116 | input,
117 | self.weight * self.scale,
118 | bias=self.bias,
119 | stride=self.stride,
120 | padding=self.padding,
121 | )
122 |
123 | return out
124 |
125 | def __repr__(self):
126 | return (
127 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},"
128 | f" {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})"
129 | )
130 |
131 |
132 | class EqualLinear(nn.Module):
133 | def __init__(
134 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
135 | ):
136 | super().__init__()
137 |
138 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
139 |
140 | if bias:
141 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
142 |
143 | else:
144 | self.bias = None
145 |
146 | self.activation = activation
147 |
148 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul
149 | self.lr_mul = lr_mul
150 |
151 | def forward(self, input):
152 | if self.activation:
153 | out = F.linear(input, self.weight * self.scale)
154 | out = fused_leaky_relu(out, self.bias * self.lr_mul)
155 |
156 | else:
157 | out = F.linear(
158 | input, self.weight * self.scale, bias=self.bias * self.lr_mul
159 | )
160 |
161 | return out
162 |
163 | def __repr__(self):
164 | return (
165 | f"{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})"
166 | )
167 |
168 |
169 | class ModulatedConv2d(nn.Module):
170 | def __init__(
171 | self,
172 | in_channel,
173 | out_channel,
174 | kernel_size,
175 | style_dim,
176 | demodulate=True,
177 | upsample=False,
178 | downsample=False,
179 | blur_kernel=[1, 3, 3, 1],
180 | fused=True,
181 | ):
182 | super().__init__()
183 |
184 | self.eps = 1e-8
185 | self.kernel_size = kernel_size
186 | self.in_channel = in_channel
187 | self.out_channel = out_channel
188 | self.upsample = upsample
189 | self.downsample = downsample
190 |
191 | if upsample:
192 | factor = 2
193 | p = (len(blur_kernel) - factor) - (kernel_size - 1)
194 | pad0 = (p + 1) // 2 + factor - 1
195 | pad1 = p // 2 + 1
196 |
197 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
198 |
199 | if downsample:
200 | factor = 2
201 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
202 | pad0 = (p + 1) // 2
203 | pad1 = p // 2
204 |
205 | self.blur = Blur(blur_kernel, pad=(pad0, pad1))
206 |
207 | fan_in = in_channel * kernel_size ** 2
208 | self.scale = 1 / math.sqrt(fan_in)
209 | self.padding = kernel_size // 2
210 |
211 | self.weight = nn.Parameter(
212 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
213 | )
214 |
215 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
216 |
217 | self.demodulate = demodulate
218 | self.fused = fused
219 |
220 | def __repr__(self):
221 | return (
222 | f"{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, "
223 | f"upsample={self.upsample}, downsample={self.downsample})"
224 | )
225 |
226 | def forward(self, input, style):
227 | batch, in_channel, height, width = input.shape
228 |
229 | if not self.fused:
230 | weight = self.scale * self.weight.squeeze(0)
231 | style = self.modulation(style)
232 |
233 | if self.demodulate:
234 | w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1, 1)
235 | dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()
236 |
237 | input = input * style.reshape(batch, in_channel, 1, 1)
238 |
239 | if self.upsample:
240 | weight = weight.transpose(0, 1)
241 | out = conv2d_gradfix.conv_transpose2d(
242 | input, weight, padding=0, stride=2
243 | )
244 | out = self.blur(out)
245 |
246 | elif self.downsample:
247 | input = self.blur(input)
248 | out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)
249 |
250 | else:
251 | out = conv2d_gradfix.conv2d(input, weight, padding=self.padding)
252 |
253 | if self.demodulate:
254 | out = out * dcoefs.view(batch, -1, 1, 1)
255 |
256 | return out
257 |
258 | style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
259 | weight = self.scale * self.weight * style
260 |
261 | if self.demodulate:
262 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
263 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
264 |
265 | weight = weight.view(
266 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
267 | )
268 |
269 | if self.upsample:
270 | input = input.view(1, batch * in_channel, height, width)
271 | weight = weight.view(
272 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
273 | )
274 | weight = weight.transpose(1, 2).reshape(
275 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
276 | )
277 | out = conv2d_gradfix.conv_transpose2d(
278 | input, weight, padding=0, stride=2, groups=batch
279 | )
280 | _, _, height, width = out.shape
281 | out = out.view(batch, self.out_channel, height, width)
282 | out = self.blur(out)
283 |
284 | elif self.downsample:
285 | input = self.blur(input)
286 | _, _, height, width = input.shape
287 | input = input.view(1, batch * in_channel, height, width)
288 | out = conv2d_gradfix.conv2d(
289 | input, weight, padding=0, stride=2, groups=batch
290 | )
291 | _, _, height, width = out.shape
292 | out = out.view(batch, self.out_channel, height, width)
293 |
294 | else:
295 | input = input.view(1, batch * in_channel, height, width)
296 | out = conv2d_gradfix.conv2d(
297 | input, weight, padding=self.padding, groups=batch
298 | )
299 | _, _, height, width = out.shape
300 | out = out.view(batch, self.out_channel, height, width)
301 |
302 | return out
303 |
304 |
305 | class NoiseInjection(nn.Module):
306 | def __init__(self):
307 | super().__init__()
308 |
309 | self.weight = nn.Parameter(torch.zeros(1))
310 |
311 | def forward(self, image, noise=None):
312 | if noise is None:
313 | batch, _, height, width = image.shape
314 | noise = image.new_empty(batch, 1, height, width).normal_()
315 |
316 | return image + self.weight * noise
317 |
318 |
319 | class ConstantInput(nn.Module):
320 | def __init__(self, channel, size=4):
321 | super().__init__()
322 |
323 | self.input = nn.Parameter(torch.randn(1, channel, size, size))
324 |
325 | def forward(self, input):
326 | batch = input.shape[0]
327 | out = self.input.repeat(batch, 1, 1, 1)
328 |
329 | return out
330 |
331 |
332 | class StyledConv(nn.Module):
333 | def __init__(
334 | self,
335 | in_channel,
336 | out_channel,
337 | kernel_size,
338 | style_dim,
339 | upsample=False,
340 | blur_kernel=[1, 3, 3, 1],
341 | demodulate=True,
342 | ):
343 | super().__init__()
344 |
345 | self.conv = ModulatedConv2d(
346 | in_channel,
347 | out_channel,
348 | kernel_size,
349 | style_dim,
350 | upsample=upsample,
351 | blur_kernel=blur_kernel,
352 | demodulate=demodulate,
353 | )
354 |
355 | self.noise = NoiseInjection()
356 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
357 | # self.activate = ScaledLeakyReLU(0.2)
358 | self.activate = FusedLeakyReLU(out_channel)
359 |
360 | def forward(self, input, style, noise=None):
361 | out = self.conv(input, style)
362 | out = self.noise(out, noise=noise)
363 | # out = out + self.bias
364 | out = self.activate(out)
365 |
366 | return out
367 |
368 |
369 | class ToRGB(nn.Module):
370 | def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
371 | super().__init__()
372 |
373 | if upsample:
374 | self.upsample = Upsample(blur_kernel)
375 |
376 | self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
377 | self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
378 |
379 | def forward(self, input, style, skip=None):
380 | out = self.conv(input, style)
381 | out = out + self.bias
382 |
383 | if skip is not None:
384 | skip = self.upsample(skip)
385 |
386 | out = out + skip
387 |
388 | return out
389 |
390 | def append_if(condition, var, elem):
391 | if (condition):
392 | var.append(elem)
393 |
394 | class Generator(nn.Module):
395 | def __init__(
396 | self,
397 | size,
398 | style_dim,
399 | n_mlp,
400 | channel_multiplier=2,
401 | blur_kernel=[1, 3, 3, 1],
402 | lr_mlp=0.01,
403 | ):
404 | super().__init__()
405 |
406 | self.size = size
407 |
408 | self.style_dim = style_dim
409 |
410 | layers = [PixelNorm()]
411 |
412 | for i in range(n_mlp):
413 | layers.append(
414 | EqualLinear(
415 | style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
416 | )
417 | )
418 |
419 | self.style = nn.Sequential(*layers)
420 |
421 | self.channels = {
422 | 4: 512,
423 | 8: 512,
424 | 16: 512,
425 | 32: 512,
426 | 64: 256 * channel_multiplier,
427 | 128: 128 * channel_multiplier,
428 | 256: 64 * channel_multiplier,
429 | 512: 32 * channel_multiplier,
430 | 1024: 16 * channel_multiplier,
431 | }
432 |
433 | self.input = ConstantInput(self.channels[4])
434 | self.conv1 = StyledConv(
435 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
436 | )
437 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
438 |
439 | self.log_size = int(math.log(size, 2))
440 | self.num_layers = (self.log_size - 2) * 2 + 1
441 |
442 | self.convs = nn.ModuleList()
443 | self.upsamples = nn.ModuleList()
444 | self.to_rgbs = nn.ModuleList()
445 | self.noises = nn.Module()
446 |
447 | in_channel = self.channels[4]
448 |
449 | for layer_idx in range(self.num_layers):
450 | res = (layer_idx + 5) // 2
451 | shape = [1, 1, 2 ** res, 2 ** res]
452 | self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
453 |
454 | for i in range(3, self.log_size + 1):
455 | out_channel = self.channels[2 ** i]
456 |
457 | self.convs.append(
458 | StyledConv(
459 | in_channel,
460 | out_channel,
461 | 3,
462 | style_dim,
463 | upsample=True,
464 | blur_kernel=blur_kernel,
465 | )
466 | )
467 |
468 | self.convs.append(
469 | StyledConv(
470 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
471 | )
472 | )
473 |
474 | self.to_rgbs.append(ToRGB(out_channel, style_dim))
475 |
476 | in_channel = out_channel
477 |
478 | self.n_latent = self.log_size * 2 - 2
479 |
480 | def make_noise(self):
481 | device = self.input.input.device
482 |
483 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
484 |
485 | for i in range(3, self.log_size + 1):
486 | for _ in range(2):
487 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
488 |
489 | return noises
490 |
491 | def mean_latent(self, n_latent):
492 | latent_in = torch.randn(
493 | n_latent, self.style_dim, device=self.input.input.device
494 | )
495 | latent = self.style(latent_in).mean(0, keepdim=True)
496 |
497 | return latent
498 |
499 | def get_latent(self, input):
500 | return self.style(input)
501 |
502 | def forward(
503 | self,
504 | styles,
505 | return_latents=False,
506 | inject_index=None,
507 | truncation=1,
508 | truncation_latent=None,
509 | input_is_latent=False,
510 | noise=None,
511 | randomize_noise=True,
512 | return_features=False,
513 | ):
514 | if not input_is_latent:
515 | styles = [self.style(s) for s in styles]
516 |
517 | if noise is None:
518 | if randomize_noise:
519 | noise = [None] * self.num_layers
520 | else:
521 | noise = [
522 | getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
523 | ]
524 |
525 | if truncation < 1:
526 | style_t = []
527 |
528 | for style in styles:
529 | style_t.append(
530 | truncation_latent + truncation * (style - truncation_latent)
531 | )
532 |
533 | styles = style_t
534 |
535 | if len(styles) < 2:
536 | inject_index = self.n_latent
537 |
538 | if styles[0].ndim < 3:
539 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
540 |
541 | else:
542 | latent = styles[0]
543 |
544 | else:
545 | if inject_index is None:
546 | inject_index = random.randint(1, self.n_latent - 1)
547 |
548 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
549 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
550 |
551 | latent = torch.cat([latent, latent2], 1)
552 |
553 | features = []
554 | out = self.input(latent)
555 | append_if(return_features, features, out)
556 | out = self.conv1(out, latent[:, 0], noise=noise[0])
557 | append_if(return_features, features, out)
558 |
559 | skip = self.to_rgb1(out, latent[:, 1])
560 |
561 | i = 1
562 | for conv1, conv2, noise1, noise2, to_rgb in zip(
563 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
564 | ):
565 | out = conv1(out, latent[:, i], noise=noise1)
566 | append_if(return_features, features, out)
567 | out = conv2(out, latent[:, i + 1], noise=noise2)
568 | append_if(return_features, features, out)
569 | skip = to_rgb(out, latent[:, i + 2], skip)
570 |
571 | i += 2
572 |
573 | image = skip
574 |
575 | if return_latents:
576 | return image, latent, features
577 |
578 | else:
579 | return image, None, features
580 |
581 |
582 | class ConvLayer(nn.Sequential):
583 | def __init__(
584 | self,
585 | in_channel,
586 | out_channel,
587 | kernel_size,
588 | downsample=False,
589 | blur_kernel=[1, 3, 3, 1],
590 | bias=True,
591 | activate=True,
592 | ):
593 | layers = []
594 |
595 | if downsample:
596 | factor = 2
597 | p = (len(blur_kernel) - factor) + (kernel_size - 1)
598 | pad0 = (p + 1) // 2
599 | pad1 = p // 2
600 |
601 | layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
602 |
603 | stride = 2
604 | self.padding = 0
605 |
606 | else:
607 | stride = 1
608 | self.padding = kernel_size // 2
609 |
610 | layers.append(
611 | EqualConv2d(
612 | in_channel,
613 | out_channel,
614 | kernel_size,
615 | padding=self.padding,
616 | stride=stride,
617 | bias=bias and not activate,
618 | )
619 | )
620 |
621 | if activate:
622 | layers.append(FusedLeakyReLU(out_channel, bias=bias))
623 |
624 | super().__init__(*layers)
625 |
626 |
627 | class ResBlock(nn.Module):
628 | def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
629 | super().__init__()
630 |
631 | self.conv1 = ConvLayer(in_channel, in_channel, 3)
632 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
633 |
634 | self.skip = ConvLayer(
635 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False
636 | )
637 |
638 | def forward(self, input):
639 | out = self.conv1(input)
640 | out = self.conv2(out)
641 |
642 | skip = self.skip(input)
643 | out = (out + skip) / math.sqrt(2)
644 |
645 | return out
646 |
647 |
648 | class Discriminator(nn.Module):
649 | def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
650 | super().__init__()
651 |
652 | channels = {
653 | 4: 512,
654 | 8: 512,
655 | 16: 512,
656 | 32: 512,
657 | 64: 256 * channel_multiplier,
658 | 128: 128 * channel_multiplier,
659 | 256: 64 * channel_multiplier,
660 | 512: 32 * channel_multiplier,
661 | 1024: 16 * channel_multiplier,
662 | }
663 |
664 | convs = [ConvLayer(3, channels[size], 1)]
665 |
666 | log_size = int(math.log(size, 2))
667 |
668 | in_channel = channels[size]
669 |
670 | for i in range(log_size, 2, -1):
671 | out_channel = channels[2 ** (i - 1)]
672 |
673 | convs.append(ResBlock(in_channel, out_channel, blur_kernel))
674 |
675 | in_channel = out_channel
676 |
677 | self.convs = nn.Sequential(*convs)
678 |
679 | self.stddev_group = 4
680 | self.stddev_feat = 1
681 |
682 | self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
683 | self.final_linear = nn.Sequential(
684 | EqualLinear(channels[4] * 4 * 4, channels[4], activation="fused_lrelu"),
685 | EqualLinear(channels[4], 1),
686 | )
687 |
688 | def forward(self, input):
689 | out = self.convs(input)
690 |
691 | batch, channel, height, width = out.shape
692 | group = min(batch, self.stddev_group)
693 | stddev = out.view(
694 | group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
695 | )
696 | stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
697 | stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
698 | stddev = stddev.repeat(group, 1, height, width)
699 | out = torch.cat([out, stddev], 1)
700 |
701 | out = self.final_conv(out)
702 |
703 | out = out.view(batch, -1)
704 | out = self.final_linear(out)
705 |
706 | return out
707 |
708 |
--------------------------------------------------------------------------------