├── .gitignore
├── README.md
├── align.py
├── attack_utils.py
├── autoencoding.py
├── choices.py
├── config.py
├── config_base.py
├── dataset.py
├── dataset_util.py
├── demo.py
├── diffusion
│   ├── __init__.py
│   ├── base.py
│   ├── diffusion.py
│   └── resample.py
├── dist_utils.py
├── experiment.py
├── external
│   └── face_makeup
│       ├── .gitignore
│       ├── LICENSE
│       ├── README.md
│       ├── cp
│       │   └── 79999_iter.pth
│       ├── imgs
│       │   ├── 116.jpg
│       │   └── 6.jpg
│       ├── makeup.py
│       ├── makeup
│       │   ├── 116_0.png
│       │   ├── 116_1.png
│       │   ├── 116_2.png
│       │   ├── 116_3.png
│       │   ├── 116_4.png
│       │   ├── 116_5.png
│       │   ├── 116_6.png
│       │   ├── 116_lip_ori.png
│       │   └── 116_ori.png
│       ├── model.py
│       ├── resnet.py
│       └── test.py
├── iterative_projected_gradient_fast.py
├── lmdb_writer.py
├── metrics.py
├── model
│   ├── __init__.py
│   ├── blocks.py
│   ├── latentnet.py
│   ├── nn.py
│   ├── unet.py
│   └── unet_autoenc.py
├── pipeline.png
├── predict.py
├── renderer.py
├── requirements.txt
├── ssim.py
└── templates.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | .idea
3 | *.DS_Store
4 | checkpoints
5 | assets
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DiffProtect: Generate Adversarial Examples with Diffusion Models for Facial Privacy Protection
2 | Welcome to the official repository for the method presented in [DiffProtect: Generate Adversarial Examples with Diffusion Models for Facial Privacy Protection](https://arxiv.org/abs/2305.13625)
3 | by [Jiang Liu*](https://joellliu.github.io/), [Chun Pong Lau*](https://samuel930930.github.io/), and [Rama Chellappa](https://engineering.jhu.edu/ece/faculty/rama-chellappa/).
4 | 
5 | ## Updates
6 | **Aug 1st, 2023** Preview code release.
7 |
8 | ## Setting Up
9 | ### Install dependencies
10 | ```shell
11 | pip install -r requirements.txt
12 | ```
13 | ### Prepare model checkpoints
14 | 1. Download [DiffAE](https://github.com/phizaz/diffae/tree/master) checkpoint [ffhq256_autoenc](https://vistec-my.sharepoint.com/:f:/g/personal/nattanatc_pro_vistec_ac_th/Ev2D_RNV2llIvm2yXyKgUxAB6w8ffg0C9NWSOtFqPMYQuw?e=f2kWUa) to `./checkpoints`
15 | 2. Download the weights for victim models from [here](https://drive.google.com/file/d/19_Y0jR789BGciogjjoGtWNEv-5QBiCB7/view?usp=sharing) and extract to `./assets`
16 |
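After both downloads, the loading code in this repository expects roughly the following layout (inferred from the `torch.load` calls in `demo.py` and `attack_utils.py`; the DiffAE folder name is typically `ffhq256_autoenc`):

```
checkpoints/
└── ffhq256_autoenc/
    └── last.ckpt
assets/
└── models/
    ├── ir152.pth
    ├── irse50.pth
    ├── facenet.pth
    └── mobile_face.pth
```
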
17 | ## Run the code
18 | ```shell
19 | python demo.py
20 | ```
21 |
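`demo.py` exposes the main attack settings as command-line flags (see the argument parser at the bottom of `demo.py`). A typical invocation with the default hyperparameters, pointing at your own aligned source images and target image, might look like:

```shell
python demo.py \
    --source_dir assets/datasets/CelebA-HQ_align \
    --target_path assets/datasets/target/085807.jpg \
    --save_path assets/datasets/val \
    --model_names ir152 irse50 facenet \
    --eps 0.02 --iter 50 --T_atk 5 --T_inf 100
```
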
22 | ## Citation
23 | Please cite our paper if you find this codebase helpful :)
24 |
25 | ```
26 | @article{liu2023diffprotect,
27 | title={DiffProtect: Generate Adversarial Examples with Diffusion Models for Facial Privacy Protection},
28 | author={Liu, Jiang and Lau, Chun Pong and Chellappa, Rama},
29 | journal={arXiv preprint arXiv:2305.13625},
30 | year={2023}
31 | }
32 | ```
33 |
--------------------------------------------------------------------------------
/align.py:
--------------------------------------------------------------------------------
1 | import bz2
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 |
7 | import dlib
8 | import numpy as np
9 | import PIL.Image
10 | import requests
11 | import scipy.ndimage
12 | from tqdm import tqdm
13 | from argparse import ArgumentParser
14 |
15 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
16 |
17 |
18 | def image_align(src_file,
19 | dst_file,
20 | face_landmarks,
21 | output_size=1024,
22 | transform_size=4096,
23 | enable_padding=True):
24 | # Align function from FFHQ dataset pre-processing step
25 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
26 |
27 | lm = np.array(face_landmarks)
28 | lm_chin = lm[0:17] # left-right
29 | lm_eyebrow_left = lm[17:22] # left-right
30 | lm_eyebrow_right = lm[22:27] # left-right
31 | lm_nose = lm[27:31] # top-down
32 | lm_nostrils = lm[31:36] # top-down
33 | lm_eye_left = lm[36:42] # left-clockwise
34 | lm_eye_right = lm[42:48] # left-clockwise
35 | lm_mouth_outer = lm[48:60] # left-clockwise
36 | lm_mouth_inner = lm[60:68] # left-clockwise
37 |
38 | # Calculate auxiliary vectors.
39 | eye_left = np.mean(lm_eye_left, axis=0)
40 | eye_right = np.mean(lm_eye_right, axis=0)
41 | eye_avg = (eye_left + eye_right) * 0.5
42 | eye_to_eye = eye_right - eye_left
43 | mouth_left = lm_mouth_outer[0]
44 | mouth_right = lm_mouth_outer[6]
45 | mouth_avg = (mouth_left + mouth_right) * 0.5
46 | eye_to_mouth = mouth_avg - eye_avg
47 |
48 | # Choose oriented crop rectangle.
49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
50 | x /= np.hypot(*x)
51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
52 | y = np.flipud(x) * [-1, 1]
53 | c = eye_avg + eye_to_mouth * 0.1
54 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
55 | qsize = np.hypot(*x) * 2
56 |
57 | # Load in-the-wild image.
58 | if not os.path.isfile(src_file):
59 | print(
60 | '\nCannot find source image. Please run "--wilds" before "--align".'
61 | )
62 | return
63 | img = PIL.Image.open(src_file)
64 | img = img.convert('RGB')
65 |
66 | # Shrink.
67 | shrink = int(np.floor(qsize / output_size * 0.5))
68 | if shrink > 1:
69 | rsize = (int(np.rint(float(img.size[0]) / shrink)),
70 | int(np.rint(float(img.size[1]) / shrink)))
71 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
72 | quad /= shrink
73 | qsize /= shrink
74 |
75 | # Crop.
76 | border = max(int(np.rint(qsize * 0.1)), 3)
77 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
78 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
79 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
80 | min(crop[2] + border,
81 | img.size[0]), min(crop[3] + border, img.size[1]))
82 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
83 | img = img.crop(crop)
84 | quad -= crop[0:2]
85 |
86 | # Pad.
87 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
88 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
89 | pad = (max(-pad[0] + border,
90 | 0), max(-pad[1] + border,
91 | 0), max(pad[2] - img.size[0] + border,
92 | 0), max(pad[3] - img.size[1] + border, 0))
93 | if enable_padding and max(pad) > border - 4:
94 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
95 | img = np.pad(np.float32(img),
96 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
97 | h, w, _ = img.shape
98 | y, x, _ = np.ogrid[:h, :w, :1]
99 | mask = np.maximum(
100 | 1.0 -
101 | np.minimum(np.float32(x) / pad[0],
102 | np.float32(w - 1 - x) / pad[2]), 1.0 -
103 | np.minimum(np.float32(y) / pad[1],
104 | np.float32(h - 1 - y) / pad[3]))
105 | blur = qsize * 0.02
106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
107 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
108 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)),
110 | 'RGB')
111 | quad += pad[:2]
112 |
113 | # Transform.
114 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
115 | (quad + 0.5).flatten(), PIL.Image.BILINEAR)
116 | if output_size < transform_size:
117 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
118 |
119 | # Save aligned image.
120 | img.save(dst_file, 'PNG')
121 |
122 |
123 | class LandmarksDetector:
124 | def __init__(self, predictor_model_path):
125 | """
126 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
127 | """
128 | self.detector = dlib.get_frontal_face_detector(
129 | ) # cnn_face_detection_model_v1 also can be used
130 | self.shape_predictor = dlib.shape_predictor(predictor_model_path)
131 |
132 | def get_landmarks(self, image):
133 | img = dlib.load_rgb_image(image)
134 | dets = self.detector(img, 1)
135 |
136 | for detection in dets:
137 | face_landmarks = [
138 | (item.x, item.y)
139 | for item in self.shape_predictor(img, detection).parts()
140 | ]
141 | yield face_landmarks
142 |
143 |
144 | def unpack_bz2(src_path):
145 | dst_path = src_path[:-4]
146 | if os.path.exists(dst_path):
147 | print('cached')
148 | return dst_path
149 | data = bz2.BZ2File(src_path).read()
150 | with open(dst_path, 'wb') as fp:
151 | fp.write(data)
152 | return dst_path
153 |
154 |
155 | def work_landmark(raw_img_path, img_name, face_landmarks):
156 | face_img_name = '%s.png' % (os.path.splitext(img_name)[0], )
157 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
158 | if os.path.exists(aligned_face_path):
159 | return
160 | image_align(raw_img_path,
161 | aligned_face_path,
162 | face_landmarks,
163 | output_size=256)
164 |
165 |
166 | def get_file(src, tgt):
167 | if os.path.exists(tgt):
168 | print('cached')
169 | return tgt
170 | tgt_dir = os.path.dirname(tgt)
171 | if not os.path.exists(tgt_dir):
172 | os.makedirs(tgt_dir)
173 | file = requests.get(src)
174 | open(tgt, 'wb').write(file.content)
175 | return tgt
176 |
177 |
178 | if __name__ == "__main__":
179 | """
180 | Extracts and aligns all faces from images using dlib and the alignment function from the original FFHQ dataset preparation step.
181 | Usage: python align.py -i <input_imgs_dir> -o <output_imgs_dir>
182 | """
183 | parser = ArgumentParser()
184 | parser.add_argument("-i",
185 | "--input_imgs_path",
186 | type=str,
187 | default="assets/datasets/CelebA-HQ_hard",
188 | help="input images directory path")
189 | parser.add_argument("-o",
190 | "--output_imgs_path",
191 | type=str,
192 | default="assets/datasets/CelebA-HQ_align_hard",
193 | help="output images directory path")
194 |
195 | args = parser.parse_args()
196 |
197 | # this takes a very long time ...
198 | landmarks_model_path = unpack_bz2(
199 | get_file(
200 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',
201 | 'temp/shape_predictor_68_face_landmarks.dat.bz2'))
202 |
203 | # RAW_IMAGES_DIR = sys.argv[1]
204 | # ALIGNED_IMAGES_DIR = sys.argv[2]
205 | RAW_IMAGES_DIR = args.input_imgs_path
206 | ALIGNED_IMAGES_DIR = args.output_imgs_path
207 |
208 | if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR)
209 |
210 | files = os.listdir(RAW_IMAGES_DIR)
211 | print(f'total img files {len(files)}')
212 | with tqdm(total=len(files)) as progress:
213 |
214 | def cb(*args):
215 | # print('update')
216 | progress.update()
217 |
218 | def err_cb(e):
219 | print('error:', e)
220 |
221 | with Pool(8) as pool:
222 | res = []
223 | landmarks_detector = LandmarksDetector(landmarks_model_path)
224 | for img_name in files:
225 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
226 | # print('img_name:', img_name)
227 | for i, face_landmarks in enumerate(
228 | landmarks_detector.get_landmarks(raw_img_path),
229 | start=1):
230 | # assert i == 1, f'{i}'
231 | # print(i, face_landmarks)
232 | # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
233 | # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
234 | # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256)
235 |
236 | work_landmark(raw_img_path, img_name, face_landmarks)
237 | progress.update()
238 |
239 | # job = pool.apply_async(
240 | # work_landmark,
241 | # (raw_img_path, img_name, face_landmarks),
242 | # callback=cb,
243 | # error_callback=err_cb,
244 | # )
245 | # res.append(job)
246 |
247 | # pool.close()
248 | # pool.join()
249 | print(f"output aligned images at: {ALIGNED_IMAGES_DIR}")
250 |
--------------------------------------------------------------------------------
/attack_utils.py:
--------------------------------------------------------------------------------
1 | from assets.models import irse, ir152, facenet
2 | import torch
3 | import cv2
4 | from advertorch.utils import NormalizeByChannelMeanStd
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 |
8 | def preprocess(im, mean, std, device):
9 | if len(im.size()) == 3:
10 | im = im.transpose(0, 2).transpose(1, 2).unsqueeze(0)
11 | elif len(im.size()) == 4:
12 | im = im.transpose(1, 3).transpose(2, 3)
13 |
14 | mean = torch.tensor(mean).to(device)
15 | mean = mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
16 | std = torch.tensor(std).to(device)
17 | std = std.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
18 | im = (im - mean) / std
19 | return im
20 |
21 |
22 | def read_img(data_dir, mean, std, device):
23 | img = cv2.imread(data_dir)
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255
25 | img = torch.from_numpy(img).to(torch.float32).to(device)
26 | img = preprocess(img, mean, std, device)
27 | return img
28 |
29 | class Net(torch.nn.Module):
30 | def __init__(self, test_models, decoder=None):
31 | super(Net, self).__init__()
32 | self.test_models = test_models
33 | self.decoder = decoder
34 | self.norm = NormalizeByChannelMeanStd([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).cuda()
35 |
36 | def forward(self, z, xT=None, T=1):
37 |
38 | if xT is None: # input is images
39 | x = z
40 | else: # input are latent codes
41 | #x = self.decoder.render(xT, z, T)
42 | x = self.decoder.render(z, xT, T)
43 |
44 | x = self.norm(x)
45 | features = []
46 | for model_name in self.test_models.keys():
47 | input_size = self.test_models[model_name][0]
48 | fr_model = self.test_models[model_name][1]
49 | source_resize = F.interpolate(x, size=input_size, mode='bilinear')
50 | emb_source = fr_model(source_resize)
51 | features.append(emb_source)
52 | # avg_feature = torch.mean(torch.stack(features), dim=0)
53 | avg_feature = features
54 |
55 | return avg_feature
56 |
57 |
58 | def cos_simi(emb_1, emb_2):
59 | return torch.mean(torch.sum(torch.mul(emb_2, emb_1), dim=1) / emb_2.norm(dim=1) / emb_1.norm(dim=1))
60 |
61 | def Cos_Loss(source_feature, target_feature):
62 | cos_loss_list = []
63 | for i in range(len(source_feature)):
64 | cos_loss_list.append(1 - cos_simi(source_feature[i], target_feature[i].detach()))
65 | # print(1 - cos_simi(source_feature[i], target_feature[i]))
66 | cos_loss = torch.mean(torch.stack(cos_loss_list))
67 | return cos_loss
68 |
69 |
70 | class OhemCELoss(nn.Module):
71 | def __init__(self, thresh, n_min, ignore_lb=255, *args, **kwargs):
72 | super(OhemCELoss, self).__init__()
73 | self.thresh = -torch.log(torch.tensor(thresh, dtype=torch.float)).cuda()
74 | self.n_min = n_min
75 | self.ignore_lb = ignore_lb
76 | self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none')
77 |
78 | def forward(self, logits, labels):
79 | N, C, H, W = logits.size()
80 | loss = self.criteria(logits, labels).view(-1)
81 | loss, _ = torch.sort(loss, descending=True)
82 | if loss[self.n_min] > self.thresh:
83 | loss = loss[loss>self.thresh]
84 | else:
85 | loss = loss[:self.n_min]
86 | return torch.mean(loss)
87 |
88 |
89 | def load_test_models(model_names):
90 | test_models = {}
91 | device='cuda'
92 | for model_name in model_names:
93 | if model_name == 'ir152':
94 | test_models[model_name] = []
95 | test_models[model_name].append((112, 112))
96 | fr_model = ir152.IR_152((112, 112))
97 | fr_model.load_state_dict(torch.load('./assets/models/ir152.pth'))
98 | fr_model.to(device)
99 | fr_model.eval()
100 | test_models[model_name].append(fr_model)
101 | if model_name == 'irse50':
102 | test_models[model_name] = []
103 | test_models[model_name].append((112, 112))
104 | fr_model = irse.Backbone(50, 0.6, 'ir_se')
105 | fr_model.load_state_dict(torch.load('./assets/models/irse50.pth'))
106 | fr_model.to(device)
107 | fr_model.eval()
108 | test_models[model_name].append(fr_model)
109 | if model_name == 'facenet':
110 | test_models[model_name] = []
111 | test_models[model_name].append((160, 160))
112 | fr_model = facenet.InceptionResnetV1(num_classes=8631, device=device)
113 | fr_model.load_state_dict(torch.load('./assets/models/facenet.pth'))
114 | fr_model.to(device)
115 | fr_model.eval()
116 | test_models[model_name].append(fr_model)
117 | if model_name == 'mobile_face':
118 | test_models[model_name] = []
119 | test_models[model_name].append((112, 112))
120 | fr_model = irse.MobileFaceNet(512)
121 | fr_model.load_state_dict(torch.load('./assets/models/mobile_face.pth'))
122 | fr_model.to(device)
123 | fr_model.eval()
124 | test_models[model_name].append(fr_model)
125 | return test_models
126 |
127 |
128 | class Net(torch.nn.Module):
129 | def __init__(self, test_models, decoder=None):
130 | super(Net, self).__init__()
131 | self.test_models = test_models
132 | self.decoder = decoder
133 | self.norm = NormalizeByChannelMeanStd([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).cuda()
134 |
135 | def forward(self, z, xT=None, T=1):
136 |
137 | if xT is None: # input is images
138 | x = z
139 | else: # input are latent codes
140 | # x = self.decoder.render(xT, z, T)
141 | x = self.decoder.render(z, xT, T)
142 | # print(x.size())
143 |
144 | x = self.norm(x)
145 | features = []
146 | for model_name in self.test_models.keys():
147 | input_size = self.test_models[model_name][0]
148 | fr_model = self.test_models[model_name][1]
149 | source_resize = F.interpolate(x, size=input_size, mode='bilinear')
150 | emb_source = fr_model(source_resize)
151 | features.append(emb_source)
152 | # avg_feature = torch.mean(torch.stack(features), dim=0)
153 | avg_feature = features
154 | # print(len(avg_feature), avg_feature[0].size())
155 |
156 | return avg_feature
157 |
158 |
159 | class Net_fast(torch.nn.Module):
160 | def __init__(self, test_models, decoder=None):
161 | super(Net_fast, self).__init__()
162 | self.test_models = test_models
163 | self.decoder = decoder
164 | self.norm = NormalizeByChannelMeanStd([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).cuda()
165 | # self.model = self.decoder.ema_model
166 |
167 | def forward(self, z, xT=None, T=1, t=None, predict=True, is_last=False):
168 |
169 | if xT is None: # input is images
170 | x = z
171 | x_pred = z
172 | else: # input are latent codes
173 | # x = self.decoder.render(xT, z, T)
174 | # x = self.decoder.render(z, xT, T)
175 | x, x_pred = self.decoder.render(z, xT, T, True, t, is_last)
176 | # print(x.size())
177 | # x = out["pred_xstart"]
178 | # x_pred = out["sample"]
179 |
180 | if predict:
181 | # if is_last:
182 | # x_norm = self.norm(x_pred)
183 | # else:
184 | # x_norm = self.norm(x)
185 | x_norm = self.norm(x)
186 | # x_norm = x_pred
187 | features = []
188 | for model_name in self.test_models.keys():
189 | input_size = self.test_models[model_name][0]
190 | fr_model = self.test_models[model_name][1]
191 | source_resize = F.interpolate(x_norm, size=input_size, mode='bilinear')
192 | emb_source = fr_model(source_resize)
193 | features.append(emb_source)
194 | # avg_feature = torch.mean(torch.stack(features), dim=0)
195 | avg_feature = features
196 | # print(len(avg_feature), avg_feature[0].size())
197 |
198 | return avg_feature, x_pred, x
199 |
200 | else:
201 | return x_pred, x
202 |
203 | def encode_all(self, z, x, T=1):
204 |
205 | xT = self.decoder.encode_stochastic_all(x, z, T)
206 |
207 | return xT
208 |
209 | # def encode_t(self, z, x, T=1):
210 | #
211 | # xT = self.decoder.encode_stochastic_all(x, z, T)
212 | #
213 | # return xT
214 |
--------------------------------------------------------------------------------
/autoencoding.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | import torch
3 | from assets.models import irse, ir152, facenet
4 | from advertorch.utils import NormalizeByChannelMeanStd
5 | from iterative_projected_gradient import PGDAttack
6 | from advertorch.context import ctx_noparamgrad_and_eval
7 | from torchvision.transforms import ToPILImage, ToTensor
8 |
9 | class Net(torch.nn.Module):
10 | def __init__(self, test_models, decoder=None):
11 | super(Net, self).__init__()
12 | self.test_models = test_models
13 | self.decoder = decoder
14 | self.norm = NormalizeByChannelMeanStd([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]).cuda()
15 |
16 | def forward(self, z, xT=None, T=1):
17 |
18 | if xT is None: # input is images
19 | x = z
20 | else: # input are latent codes
21 | #x = self.decoder.render(xT, z, T)
22 | x = self.decoder.render(z, xT, T)
23 |
24 | x = self.norm(x)
25 | features = []
26 | for model_name in self.test_models.keys():
27 | input_size = self.test_models[model_name][0]
28 | fr_model = self.test_models[model_name][1]
29 | source_resize = F.interpolate(x, size=input_size, mode='bilinear')
30 | emb_source = fr_model(source_resize)
31 | features.append(emb_source)
32 | # avg_feature = torch.mean(torch.stack(features), dim=0)
33 | avg_feature = features
34 |
35 | return avg_feature
36 |
37 | def cos_simi(emb_1, emb_2):
38 | return torch.mean(torch.sum(torch.mul(emb_2, emb_1), dim=1) / emb_2.norm(dim=1) / emb_1.norm(dim=1))
39 |
40 | def Cos_Loss(source_feature, target_feature):
41 | cos_loss_list = []
42 | for i in range(len(source_feature)):
43 | cos_loss_list.append(1 - cos_simi(source_feature[i], target_feature[i].detach()))
44 | # print(1 - cos_simi(source_feature[i], target_feature[i]))
45 | cos_loss = torch.mean(torch.stack(cos_loss_list))
46 | return cos_loss
47 |
48 | device = 'cuda:0'
49 | conf = ffhq256_autoenc()
50 | model = LitModel(conf)
51 | state = torch.load(f'checkpoints/{conf.name}/last.ckpt', map_location='cpu')
52 | model.load_state_dict(state['state_dict'], strict=False)
53 | model.ema_model.eval()
54 | model.ema_model.to(device)
55 |
56 | data = ImageDataset('imgs_align', image_size=conf.img_size, exts=['jpg', 'JPG', 'png'], do_augment=False)
57 | batch = data[1]['img'][None]
58 |
59 | cond = model.encode(batch.to(device))
60 | xT = model.encode_stochastic(batch.to(device), cond, T=250)
61 |
62 |
63 | model_names = ['ir152', 'irse50', 'facenet']
64 | th_dict = {'ir152': (0.094632, 0.166788, 0.227922), 'irse50': (0.144840, 0.241045, 0.312703),
65 | 'facenet': (0.256587, 0.409131, 0.591191), 'mobile_face': (0.183635, 0.301611, 0.380878)}
66 |
67 | test_models = {}
68 | for model_name in model_names:
69 | if model_name == 'ir152':
70 | test_models[model_name] = []
71 | test_models[model_name].append((112, 112))
72 | fr_model = ir152.IR_152((112, 112))
73 | fr_model.load_state_dict(torch.load('./assets/models/ir152.pth'))
74 | fr_model.to(device)
75 | fr_model.eval()
76 | test_models[model_name].append(fr_model)
77 | if model_name == 'irse50':
78 | test_models[model_name] = []
79 | test_models[model_name].append((112, 112))
80 | fr_model = irse.Backbone(50, 0.6, 'ir_se')
81 | fr_model.load_state_dict(torch.load('./assets/models/irse50.pth'))
82 | fr_model.to(device)
83 | fr_model.eval()
84 | test_models[model_name].append(fr_model)
85 | if model_name == 'facenet':
86 | test_models[model_name] = []
87 | test_models[model_name].append((160, 160))
88 | fr_model = facenet.InceptionResnetV1(num_classes=8631, device=device)
89 | fr_model.load_state_dict(torch.load('./assets/models/facenet.pth'))
90 | fr_model.to(device)
91 | fr_model.eval()
92 | test_models[model_name].append(fr_model)
101 | if model_name == 'mobile_face':
102 | test_models[model_name] = []
103 | test_models[model_name].append((112, 112))
104 | fr_model = irse.MobileFaceNet(512)
105 | fr_model.load_state_dict(torch.load('./assets/models/mobile_face.pth'))
106 | fr_model.to(device)
107 | fr_model.eval()
108 | test_models[model_name].append(fr_model)
109 |
110 |
111 | net = Net(test_models, model).to(device)
112 | test_attacker = PGDAttack(predict=net,
113 | loss_fn=Cos_Loss,
114 | eps=0.1,
115 | eps_iter=0.02,
116 | nb_iter=10,
117 | clip_min=-1e6,
118 | clip_max=1e6,
119 | targeted=True,
120 | rand_init=False)
121 |
122 |
123 | target = data[0]['img'][None]
124 |
125 | target_embeding = net(target.to(device))
126 |
127 | z_adv = test_attacker.perturb(cond, target_embeding, xT, T=5)
128 |
129 |
--------------------------------------------------------------------------------
/choices.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from torch import nn
3 |
4 |
5 | class TrainMode(Enum):
6 | # manipulate mode = training the classifier
7 | manipulate = 'manipulate'
8 | # default training mode!
9 | diffusion = 'diffusion'
10 | # default latent training mode!
11 | # fitting a DDPM to a given latent
12 | latent_diffusion = 'latentdiffusion'
13 |
14 | def is_manipulate(self):
15 | return self in [
16 | TrainMode.manipulate,
17 | ]
18 |
19 | def is_diffusion(self):
20 | return self in [
21 | TrainMode.diffusion,
22 | TrainMode.latent_diffusion,
23 | ]
24 |
25 | def is_autoenc(self):
26 | # the network possibly does autoencoding
27 | return self in [
28 | TrainMode.diffusion,
29 | ]
30 |
31 | def is_latent_diffusion(self):
32 | return self in [
33 | TrainMode.latent_diffusion,
34 | ]
35 |
36 | def use_latent_net(self):
37 | return self.is_latent_diffusion()
38 |
39 | def require_dataset_infer(self):
40 | """
41 | Does training in this mode require the latent variables to be available?
42 | """
43 | # this will precalculate all the latents beforehand
44 | # and the dataset will be all the predicted latents
45 | return self in [
46 | TrainMode.latent_diffusion,
47 | TrainMode.manipulate,
48 | ]
49 |
50 |
51 | class ManipulateMode(Enum):
52 | """
53 | how to train the classifier to manipulate
54 | """
55 | # train on whole celeba attr dataset
56 | celebahq_all = 'celebahq_all'
57 | # celeba with D2C's crop
58 | d2c_fewshot = 'd2cfewshot'
59 | d2c_fewshot_allneg = 'd2cfewshotallneg'
60 |
61 | def is_celeba_attr(self):
62 | return self in [
63 | ManipulateMode.d2c_fewshot,
64 | ManipulateMode.d2c_fewshot_allneg,
65 | ManipulateMode.celebahq_all,
66 | ]
67 |
68 | def is_single_class(self):
69 | return self in [
70 | ManipulateMode.d2c_fewshot,
71 | ManipulateMode.d2c_fewshot_allneg,
72 | ]
73 |
74 | def is_fewshot(self):
75 | return self in [
76 | ManipulateMode.d2c_fewshot,
77 | ManipulateMode.d2c_fewshot_allneg,
78 | ]
79 |
80 | def is_fewshot_allneg(self):
81 | return self in [
82 | ManipulateMode.d2c_fewshot_allneg,
83 | ]
84 |
85 |
86 | class ModelType(Enum):
87 | """
88 | Kinds of the backbone models
89 | """
90 |
91 | # unconditional ddpm
92 | ddpm = 'ddpm'
93 | # autoencoding ddpm cannot do unconditional generation
94 | autoencoder = 'autoencoder'
95 |
96 | def has_autoenc(self):
97 | return self in [
98 | ModelType.autoencoder,
99 | ]
100 |
101 | def can_sample(self):
102 | return self in [ModelType.ddpm]
103 |
104 |
105 | class ModelName(Enum):
106 | """
107 | List of all supported model classes
108 | """
109 |
110 | beatgans_ddpm = 'beatgans_ddpm'
111 | beatgans_autoenc = 'beatgans_autoenc'
112 |
113 |
114 | class ModelMeanType(Enum):
115 | """
116 | Which type of output the model predicts.
117 | """
118 |
119 | eps = 'eps' # the model predicts epsilon
120 |
121 |
122 | class ModelVarType(Enum):
123 | """
124 | What is used as the model's output variance.
125 |
126 | The LEARNED_RANGE option has been added to allow the model to predict
127 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
128 | """
129 |
130 | # posterior beta_t
131 | fixed_small = 'fixed_small'
132 | # beta_t
133 | fixed_large = 'fixed_large'
134 |
135 |
136 | class LossType(Enum):
137 | mse = 'mse' # use raw MSE loss (and KL when learning variances)
138 | l1 = 'l1'
139 |
140 |
141 | class GenerativeType(Enum):
142 | """
143 | How a sample is generated
144 | """
145 |
146 | ddpm = 'ddpm'
147 | ddim = 'ddim'
148 |
149 |
150 | class OptimizerType(Enum):
151 | adam = 'adam'
152 | adamw = 'adamw'
153 |
154 |
155 | class Activation(Enum):
156 | none = 'none'
157 | relu = 'relu'
158 | lrelu = 'lrelu'
159 | silu = 'silu'
160 | tanh = 'tanh'
161 |
162 | def get_act(self):
163 | if self == Activation.none:
164 | return nn.Identity()
165 | elif self == Activation.relu:
166 | return nn.ReLU()
167 | elif self == Activation.lrelu:
168 | return nn.LeakyReLU(negative_slope=0.2)
169 | elif self == Activation.silu:
170 | return nn.SiLU()
171 | elif self == Activation.tanh:
172 | return nn.Tanh()
173 | else:
174 | raise NotImplementedError()
175 |
176 |
177 | class ManipulateLossType(Enum):
178 | bce = 'bce'
179 | mse = 'mse'
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from model.unet import ScaleAt
2 | from model.latentnet import *
3 | from diffusion.resample import UniformSampler
4 | from diffusion.diffusion import space_timesteps
5 | from typing import Tuple
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | from config_base import BaseConfig
10 | from dataset import *
11 | from diffusion import *
12 | from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
13 | from model import *
14 | from choices import *
15 | from multiprocessing import get_context
16 | import os
17 | from dataset_util import *
18 | from torch.utils.data.distributed import DistributedSampler
19 |
20 | data_paths = {
21 | 'ffhqlmdb256':
22 | os.path.expanduser('datasets/ffhq256.lmdb'),
23 | # used for training a classifier
24 | 'celeba':
25 | os.path.expanduser('datasets/celeba'),
26 | # used for training DPM models
27 | 'celebalmdb':
28 | os.path.expanduser('datasets/celeba.lmdb'),
29 | 'celebahq':
30 | os.path.expanduser('datasets/celebahq256.lmdb'),
31 | 'horse256':
32 | os.path.expanduser('datasets/horse256.lmdb'),
33 | 'bedroom256':
34 | os.path.expanduser('datasets/bedroom256.lmdb'),
35 | 'celeba_anno':
36 | os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'),
37 | 'celebahq_anno':
38 | os.path.expanduser(
39 | 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
40 | 'celeba_relight':
41 | os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'),
42 | }
43 |
44 |
45 | @dataclass
46 | class PretrainConfig(BaseConfig):
47 | name: str
48 | path: str
49 |
50 |
51 | @dataclass
52 | class TrainConfig(BaseConfig):
53 | # random seed
54 | seed: int = 0
55 | train_mode: TrainMode = TrainMode.diffusion
56 | train_cond0_prob: float = 0
57 | train_pred_xstart_detach: bool = True
58 | train_interpolate_prob: float = 0
59 | train_interpolate_img: bool = False
60 | manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
61 | manipulate_cls: str = None
62 | manipulate_shots: int = None
63 | manipulate_loss: ManipulateLossType = ManipulateLossType.bce
64 | manipulate_znormalize: bool = False
65 | manipulate_seed: int = 0
66 | accum_batches: int = 1
67 | autoenc_mid_attn: bool = True
68 | batch_size: int = 16
69 | batch_size_eval: int = None
70 | beatgans_gen_type: GenerativeType = GenerativeType.ddim
71 | beatgans_loss_type: LossType = LossType.mse
72 | beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
73 | beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
74 | beatgans_rescale_timesteps: bool = False
75 | latent_infer_path: str = None
76 | latent_znormalize: bool = False
77 | latent_gen_type: GenerativeType = GenerativeType.ddim
78 | latent_loss_type: LossType = LossType.mse
79 | latent_model_mean_type: ModelMeanType = ModelMeanType.eps
80 | latent_model_var_type: ModelVarType = ModelVarType.fixed_large
81 | latent_rescale_timesteps: bool = False
82 | latent_T_eval: int = 1_000
83 | latent_clip_sample: bool = False
84 | latent_beta_scheduler: str = 'linear'
85 | beta_scheduler: str = 'linear'
86 | data_name: str = ''
87 | data_val_name: str = None
88 | diffusion_type: str = None
89 | dropout: float = 0.1
90 | ema_decay: float = 0.9999
91 | eval_num_images: int = 5_000
92 | eval_every_samples: int = 200_000
93 | eval_ema_every_samples: int = 200_000
94 | fid_use_torch: bool = True
95 | fp16: bool = False
96 | grad_clip: float = 1
97 | img_size: int = 64
98 | lr: float = 0.0001
99 | optimizer: OptimizerType = OptimizerType.adam
100 | weight_decay: float = 0
101 | model_conf: ModelConfig = None
102 | model_name: ModelName = None
103 | model_type: ModelType = None
104 | net_attn: Tuple[int] = None
105 | net_beatgans_attn_head: int = 1
106 | # not necessarily the same as the number of style channels
107 | net_beatgans_embed_channels: int = 512
108 | net_resblock_updown: bool = True
109 | net_enc_use_time: bool = False
110 | net_enc_pool: str = 'adaptivenonzero'
111 | net_beatgans_gradient_checkpoint: bool = False
112 | net_beatgans_resnet_two_cond: bool = False
113 | net_beatgans_resnet_use_zero_module: bool = True
114 | net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
115 | net_beatgans_resnet_cond_channels: int = None
116 | net_ch_mult: Tuple[int] = None
117 | net_ch: int = 64
118 | net_enc_attn: Tuple[int] = None
119 | net_enc_k: int = None
120 | # number of resblocks for the encoder (half-unet)
121 | net_enc_num_res_blocks: int = 2
122 | net_enc_channel_mult: Tuple[int] = None
123 | net_enc_grad_checkpoint: bool = False
124 | net_autoenc_stochastic: bool = False
125 | net_latent_activation: Activation = Activation.silu
126 | net_latent_channel_mult: Tuple[int] = (1, 2, 4)
127 | net_latent_condition_bias: float = 0
128 | net_latent_dropout: float = 0
129 | net_latent_layers: int = None
130 | net_latent_net_last_act: Activation = Activation.none
131 | net_latent_net_type: LatentNetType = LatentNetType.none
132 | net_latent_num_hid_channels: int = 1024
133 | net_latent_num_time_layers: int = 2
134 | net_latent_skip_layers: Tuple[int] = None
135 | net_latent_time_emb_channels: int = 64
136 | net_latent_use_norm: bool = False
137 | net_latent_time_last_act: bool = False
138 | net_num_res_blocks: int = 2
139 | # number of resblocks for the UNET
140 | net_num_input_res_blocks: int = None
141 | net_enc_num_cls: int = None
142 | num_workers: int = 4
143 | parallel: bool = False
144 | postfix: str = ''
145 | sample_size: int = 64
146 | sample_every_samples: int = 20_000
147 | save_every_samples: int = 100_000
148 | style_ch: int = 512
149 | T_eval: int = 1_000
150 | T_sampler: str = 'uniform'
151 | T: int = 1_000
152 | total_samples: int = 10_000_000
153 | warmup: int = 0
154 | pretrain: PretrainConfig = None
155 | continue_from: PretrainConfig = None
156 | eval_programs: Tuple[str] = None
157 | # if present load the checkpoint from this path instead
158 | eval_path: str = None
159 | base_dir: str = 'checkpoints'
160 | use_cache_dataset: bool = False
161 | data_cache_dir: str = os.path.expanduser('~/cache')
162 | work_cache_dir: str = os.path.expanduser('~/mycache')
163 | # to be overridden
164 | name: str = ''
165 |
166 | def __post_init__(self):
167 | self.batch_size_eval = self.batch_size_eval or self.batch_size
168 | self.data_val_name = self.data_val_name or self.data_name
169 |
170 | def scale_up_gpus(self, num_gpus, num_nodes=1):
171 | self.eval_ema_every_samples *= num_gpus * num_nodes
172 | self.eval_every_samples *= num_gpus * num_nodes
173 | self.sample_every_samples *= num_gpus * num_nodes
174 | self.batch_size *= num_gpus * num_nodes
175 | self.batch_size_eval *= num_gpus * num_nodes
176 | return self
177 |
178 | @property
179 | def batch_size_effective(self):
180 | return self.batch_size * self.accum_batches
181 |
182 | @property
183 | def fid_cache(self):
184 | # we try to use the local dirs to reduce the load over network drives
185 | # hopefully, this would reduce the disconnection problems with sshfs
186 | return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
187 |
188 | @property
189 | def data_path(self):
190 | # may use the cache dir
191 | path = data_paths[self.data_name]
192 | if self.use_cache_dataset and path is not None:
193 | path = use_cached_dataset_path(
194 | path, f'{self.data_cache_dir}/{self.data_name}')
195 | return path
196 |
197 | @property
198 | def logdir(self):
199 | return f'{self.base_dir}/{self.name}'
200 |
201 | @property
202 | def generate_dir(self):
203 | # we try to use the local dirs to reduce the load over network drives
204 | # hopefully, this would reduce the disconnection problems with sshfs
205 | return f'{self.work_cache_dir}/gen_images/{self.name}'
206 |
207 | def _make_diffusion_conf(self, T=None):
208 | if self.diffusion_type == 'beatgans':
209 | # can use T < self.T for evaluation
210 | # follows the guided-diffusion repo conventions
211 | # t's are evenly spaced
212 | if self.beatgans_gen_type == GenerativeType.ddpm:
213 | section_counts = [T]
214 | elif self.beatgans_gen_type == GenerativeType.ddim:
215 | section_counts = f'ddim{T}'
216 | else:
217 | raise NotImplementedError()
218 |
219 | return SpacedDiffusionBeatGansConfig(
220 | gen_type=self.beatgans_gen_type,
221 | model_type=self.model_type,
222 | betas=get_named_beta_schedule(self.beta_scheduler, self.T),
223 | model_mean_type=self.beatgans_model_mean_type,
224 | model_var_type=self.beatgans_model_var_type,
225 | loss_type=self.beatgans_loss_type,
226 | rescale_timesteps=self.beatgans_rescale_timesteps,
227 | use_timesteps=space_timesteps(num_timesteps=self.T,
228 | section_counts=section_counts),
229 | fp16=self.fp16,
230 | )
231 | else:
232 | raise NotImplementedError()
233 |
234 | def _make_latent_diffusion_conf(self, T=None):
235 | # can use T < self.T for evaluation
236 | # follows the guided-diffusion repo conventions
237 | # t's are evenly spaced
238 | if self.latent_gen_type == GenerativeType.ddpm:
239 | section_counts = [T]
240 | elif self.latent_gen_type == GenerativeType.ddim:
241 | section_counts = f'ddim{T}'
242 | else:
243 | raise NotImplementedError()
244 |
245 | return SpacedDiffusionBeatGansConfig(
246 | train_pred_xstart_detach=self.train_pred_xstart_detach,
247 | gen_type=self.latent_gen_type,
248 | # latent's model is always ddpm
249 | model_type=ModelType.ddpm,
250 | # latent shares the beta scheduler and full T
251 | betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
252 | model_mean_type=self.latent_model_mean_type,
253 | model_var_type=self.latent_model_var_type,
254 | loss_type=self.latent_loss_type,
255 | rescale_timesteps=self.latent_rescale_timesteps,
256 | use_timesteps=space_timesteps(num_timesteps=self.T,
257 | section_counts=section_counts),
258 | fp16=self.fp16,
259 | )
260 |
261 | @property
262 | def model_out_channels(self):
263 | return 3
264 |
265 | def make_T_sampler(self):
266 | if self.T_sampler == 'uniform':
267 | return UniformSampler(self.T)
268 | else:
269 | raise NotImplementedError()
270 |
271 | def make_diffusion_conf(self):
272 | return self._make_diffusion_conf(self.T)
273 |
274 | def make_eval_diffusion_conf(self):
275 | return self._make_diffusion_conf(T=self.T_eval)
276 |
277 | def make_latent_diffusion_conf(self):
278 | return self._make_latent_diffusion_conf(T=self.T)
279 |
280 | def make_latent_eval_diffusion_conf(self):
281 | # latent can have different eval T
282 | return self._make_latent_diffusion_conf(T=self.latent_T_eval)
283 |
284 | def make_dataset(self, path=None, **kwargs):
285 | if self.data_name == 'ffhqlmdb256':
286 | return FFHQlmdb(path=path or self.data_path,
287 | image_size=self.img_size,
288 | **kwargs)
289 | elif self.data_name == 'horse256':
290 | return Horse_lmdb(path=path or self.data_path,
291 | image_size=self.img_size,
292 | **kwargs)
293 | elif self.data_name == 'bedroom256':
294 | return Horse_lmdb(path=path or self.data_path,
295 | image_size=self.img_size,
296 | **kwargs)
297 | elif self.data_name == 'celebalmdb':
298 | # always use d2c crop
299 | return CelebAlmdb(path=path or self.data_path,
300 | image_size=self.img_size,
301 | original_resolution=None,
302 | crop_d2c=True,
303 | **kwargs)
304 | else:
305 | raise NotImplementedError()
306 |
307 | def make_loader(self,
308 | dataset,
309 | shuffle: bool,
310 | num_worker: bool = None,
311 | drop_last: bool = True,
312 | batch_size: int = None,
313 | parallel: bool = False):
314 | if parallel and distributed.is_initialized():
315 | # drop last to make sure that there are no added special indexes
316 | sampler = DistributedSampler(dataset,
317 | shuffle=shuffle,
318 | drop_last=True)
319 | else:
320 | sampler = None
321 | return DataLoader(
322 | dataset,
323 | batch_size=batch_size or self.batch_size,
324 | sampler=sampler,
325 | # when a sampler is given, it handles shuffling instead of this option
326 | shuffle=False if sampler else shuffle,
327 | num_workers=num_worker or self.num_workers,
328 | pin_memory=True,
329 | drop_last=drop_last,
330 | multiprocessing_context=get_context('fork'),
331 | )
332 |
333 | def make_model_conf(self):
334 | if self.model_name == ModelName.beatgans_ddpm:
335 | self.model_type = ModelType.ddpm
336 | self.model_conf = BeatGANsUNetConfig(
337 | attention_resolutions=self.net_attn,
338 | channel_mult=self.net_ch_mult,
339 | conv_resample=True,
340 | dims=2,
341 | dropout=self.dropout,
342 | embed_channels=self.net_beatgans_embed_channels,
343 | image_size=self.img_size,
344 | in_channels=3,
345 | model_channels=self.net_ch,
346 | num_classes=None,
347 | num_head_channels=-1,
348 | num_heads_upsample=-1,
349 | num_heads=self.net_beatgans_attn_head,
350 | num_res_blocks=self.net_num_res_blocks,
351 | num_input_res_blocks=self.net_num_input_res_blocks,
352 | out_channels=self.model_out_channels,
353 | resblock_updown=self.net_resblock_updown,
354 | use_checkpoint=self.net_beatgans_gradient_checkpoint,
355 | use_new_attention_order=False,
356 | resnet_two_cond=self.net_beatgans_resnet_two_cond,
357 | resnet_use_zero_module=self.
358 | net_beatgans_resnet_use_zero_module,
359 | )
360 | elif self.model_name in [
361 | ModelName.beatgans_autoenc,
362 | ]:
363 | cls = BeatGANsAutoencConfig
364 | # supports both autoenc and vaeddpm
365 | if self.model_name == ModelName.beatgans_autoenc:
366 | self.model_type = ModelType.autoencoder
367 | else:
368 | raise NotImplementedError()
369 |
370 | if self.net_latent_net_type == LatentNetType.none:
371 | latent_net_conf = None
372 | elif self.net_latent_net_type == LatentNetType.skip:
373 | latent_net_conf = MLPSkipNetConfig(
374 | num_channels=self.style_ch,
375 | skip_layers=self.net_latent_skip_layers,
376 | num_hid_channels=self.net_latent_num_hid_channels,
377 | num_layers=self.net_latent_layers,
378 | num_time_emb_channels=self.net_latent_time_emb_channels,
379 | activation=self.net_latent_activation,
380 | use_norm=self.net_latent_use_norm,
381 | condition_bias=self.net_latent_condition_bias,
382 | dropout=self.net_latent_dropout,
383 | last_act=self.net_latent_net_last_act,
384 | num_time_layers=self.net_latent_num_time_layers,
385 | time_last_act=self.net_latent_time_last_act,
386 | )
387 | else:
388 | raise NotImplementedError()
389 |
390 | self.model_conf = cls(
391 | attention_resolutions=self.net_attn,
392 | channel_mult=self.net_ch_mult,
393 | conv_resample=True,
394 | dims=2,
395 | dropout=self.dropout,
396 | embed_channels=self.net_beatgans_embed_channels,
397 | enc_out_channels=self.style_ch,
398 | enc_pool=self.net_enc_pool,
399 | enc_num_res_block=self.net_enc_num_res_blocks,
400 | enc_channel_mult=self.net_enc_channel_mult,
401 | enc_grad_checkpoint=self.net_enc_grad_checkpoint,
402 | enc_attn_resolutions=self.net_enc_attn,
403 | image_size=self.img_size,
404 | in_channels=3,
405 | model_channels=self.net_ch,
406 | num_classes=None,
407 | num_head_channels=-1,
408 | num_heads_upsample=-1,
409 | num_heads=self.net_beatgans_attn_head,
410 | num_res_blocks=self.net_num_res_blocks,
411 | num_input_res_blocks=self.net_num_input_res_blocks,
412 | out_channels=self.model_out_channels,
413 | resblock_updown=self.net_resblock_updown,
414 | use_checkpoint=self.net_beatgans_gradient_checkpoint,
415 | use_new_attention_order=False,
416 | resnet_two_cond=self.net_beatgans_resnet_two_cond,
417 | resnet_use_zero_module=self.
418 | net_beatgans_resnet_use_zero_module,
419 | latent_net_conf=latent_net_conf,
420 | resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
421 | )
422 | else:
423 | raise NotImplementedError(self.model_name)
424 |
425 | return self.model_conf
426 |
--------------------------------------------------------------------------------
/config_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from copy import deepcopy
4 | from dataclasses import dataclass
5 |
6 |
7 | @dataclass
8 | class BaseConfig:
9 | def clone(self):
10 | return deepcopy(self)
11 |
12 | def inherit(self, another):
13 | """inherit common keys from a given config"""
14 | common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15 | for k in common_keys:
16 | setattr(self, k, getattr(another, k))
17 |
18 | def propagate(self):
19 | """push down the configuration to all members"""
20 | for k, v in self.__dict__.items():
21 | if isinstance(v, BaseConfig):
22 | v.inherit(self)
23 | v.propagate()
24 |
25 | def save(self, save_path):
26 | """save config to json file"""
27 | dirname = os.path.dirname(save_path)
28 | if not os.path.exists(dirname):
29 | os.makedirs(dirname)
30 | conf = self.as_dict_jsonable()
31 | with open(save_path, 'w') as f:
32 | json.dump(conf, f)
33 |
34 | def load(self, load_path):
35 | """load json config"""
36 | with open(load_path) as f:
37 | conf = json.load(f)
38 | self.from_dict(conf)
39 |
40 | def from_dict(self, dict, strict=False):
41 | for k, v in dict.items():
42 | if not hasattr(self, k):
43 | if strict:
44 | raise ValueError(f"loading extra '{k}'")
45 | else:
46 | print(f"loading extra '{k}'")
47 | continue
48 | if isinstance(self.__dict__[k], BaseConfig):
49 | self.__dict__[k].from_dict(v)
50 | else:
51 | self.__dict__[k] = v
52 |
53 | def as_dict_jsonable(self):
54 | conf = {}
55 | for k, v in self.__dict__.items():
56 | if isinstance(v, BaseConfig):
57 | conf[k] = v.as_dict_jsonable()
58 | else:
59 | if jsonable(v):
60 | conf[k] = v
61 | else:
62 | # ignore not jsonable
63 | pass
64 | return conf
65 |
66 |
67 | def jsonable(x):
68 | try:
69 | json.dumps(x)
70 | return True
71 | except TypeError:
72 | return False
73 |
--------------------------------------------------------------------------------
/dataset_util.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from dist_utils import *
4 |
5 |
6 | def use_cached_dataset_path(source_path, cache_path):
7 | if get_rank() == 0:
8 | if not os.path.exists(cache_path):
9 | # shutil.rmtree(cache_path)
10 | print(f'copying the data: {source_path} to {cache_path}')
11 | shutil.copytree(source_path, cache_path)
12 | barrier()
13 | return cache_path
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import argparse
6 | import torch.nn.functional as F
7 | from templates import *
8 | from iterative_projected_gradient_fast import PGDAttack
9 | from attack_utils import load_test_models, Cos_Loss, read_img, Net_fast, OhemCELoss
10 | from torchvision.utils import save_image
11 | import glob
12 | import cv2
13 | import numpy as np
14 |
15 | from external.face_makeup.model import BiSeNet
16 | from external.face_makeup.test import vis_parsing_maps
17 |
18 |
19 | class Timer(object):
20 | """A simple timer."""
21 | def __init__(self):
22 | self.total_time = 0.
23 | self.calls = 0
24 | self.start_time = 0.
25 | self.diff = 0.
26 | self.average_time = 0.
27 |
28 | def tic(self):
29 | # using time.time instead of time.clock because time.clock
30 | # does not normalize for multithreading
31 | self.start_time = time.time()
32 |
33 | def toc(self, average=True):
34 | self.diff = time.time() - self.start_time
35 | self.total_time += self.diff
36 | self.calls += 1
37 | self.average_time = self.total_time / self.calls
38 | if average:
39 | return self.average_time
40 | else:
41 | return self.diff
42 |
43 | def clear(self):
44 | self.total_time = 0.
45 | self.calls = 0
46 | self.start_time = 0.
47 | self.diff = 0.
48 | self.average_time = 0.
49 |
50 | def vis_parsing_maps(im, parsing_anno, stride=1, save_im=True, save_path='parsing_map_on_im.png'):
51 | # Colors for all 20 parts
52 | part_colors = [[0, 0, 0], [255, 85, 0], [255, 170, 0],
53 | [255, 0, 85], [255, 0, 170],
54 | [0, 255, 0], [85, 255, 0], [170, 255, 0],
55 | [0, 255, 85], [0, 255, 170],
56 | [0, 0, 255], [85, 0, 255], [170, 0, 255],
57 | [0, 85, 255], [0, 170, 255],
58 | [255, 255, 0], [255, 255, 85], [255, 255, 170],
59 | [255, 0, 255], [255, 85, 255], [255, 170, 255],
60 | [0, 255, 255], [85, 255, 255], [170, 255, 255]]
61 |
62 | im = np.array(im)
63 | vis_im = im.copy().astype(np.uint8)
64 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
65 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
66 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
67 |
68 | num_of_class = np.max(vis_parsing_anno)
69 |
70 | for pi in range(0, num_of_class + 1):
71 | index = np.where(vis_parsing_anno == pi)
72 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
73 |
74 | # pi = 0
75 | # index = np.where(vis_parsing_anno == pi)
76 | # vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
77 |
78 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
79 | # print(vis_parsing_anno_color.shape, vis_im.shape)
80 | # vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
81 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0., vis_parsing_anno_color, 1., 0)
82 |
83 | # Save result or not
84 | if save_im:
85 | cv2.imwrite(save_path, vis_parsing_anno)
86 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
87 | return vis_parsing_anno
88 |
89 | def generate(args):
90 | eps = args.eps
91 | iter = args.iter
92 | T_enc = args.T_enc
93 | T_atk = args.T_atk
94 | T_inf = args.T_inf
95 | start_t = args.start_t
96 | attack_iter = args.attack_iter
97 | attack_inf_iter = args.attack_inf_iter
98 | repeat_times = args.repeat_times
99 | cnt_skip = args.cnt_skip
100 | lam = args.lam
101 | vis_full = args.vis_full
102 |
103 | # set up FR models
104 | model_names = args.model_names # ['ir152', 'irse50', 'facenet']
105 | test_models = load_test_models(model_names)
106 |
107 | # set up diffae model
108 | conf = ffhq256_autoenc()
109 | model = LitModel(conf)
110 | state = torch.load(f'checkpoints/{conf.name}/last.ckpt', map_location='cpu')
111 | model.load_state_dict(state['state_dict'], strict=False)
112 | model.ema_model.eval()
113 | model.ema_model.cuda()
114 |
115 | # set up face parsing model
116 | if lam > 0:
117 | net_parse = BiSeNet(n_classes=19)
118 | net_parse.load_state_dict(torch.load('./external/face_makeup/cp/79999_iter.pth'))
119 | net_parse.cuda()
120 | net_parse.eval()
121 | loss_parse_fn = OhemCELoss(thresh=0.7, n_min=256 * 256 // 16, ignore_lb=0)
122 | else:
123 | loss_parse_fn = None
124 | net_parse = None
125 |
126 | # set up attacker
127 | net = Net_fast(test_models, model).cuda()
128 | test_attacker = PGDAttack(predict=net,
129 | loss_fn=Cos_Loss,
130 | eps=eps,
131 | eps_iter=2 * eps / (iter * attack_iter * repeat_times),
132 | nb_iter=iter,
133 | T_atk=T_atk,
134 | T_enc=T_enc,
135 | start_t=start_t,
136 | attack_iter=attack_iter,
137 | attack_inf_iter=attack_inf_iter,
138 | repeat_times=repeat_times,
139 | cnt_skip=cnt_skip,
140 | lam=lam,
141 | clip_min=-1e6,
142 | clip_max=1e6,
143 | targeted=True,
144 | rand_init=False,
145 | loss_parse_fn=loss_parse_fn,
146 | parse=net_parse,
147 | vis_full = args.vis_full)
148 |
149 | # Target embedding
150 | target = read_img(args.target_path, 0., 1., 'cuda')
151 | with torch.no_grad():
152 | target_embeding, _, _ = net(target)
153 |
154 | save_path = args.save_path
155 | os.makedirs(save_path, exist_ok=True)
156 | source_paths = os.listdir(args.source_dir)
157 | timer = Timer()
158 | time_list = []
159 | for source_path in tqdm(source_paths):
160 | source_name = source_path.replace('.jpg', '.png')
161 | source_path = os.path.join(args.source_dir, source_path)
162 | src_img = read_img(source_path, 0.5, 0.5, 'cuda')
163 |
164 | timer.tic()
165 |
166 | # Semantic Regularization
167 | if lam > 0:
168 | parse_map = net_parse((src_img + 1) / 2)[0]
169 | parse_map = parse_map.squeeze(0).argmax(0).unsqueeze(0).detach()
170 |
171 | parse_map[parse_map == 17] = 0 # denote hair as background, which we ignore in the loss
172 |
173 | # encode images
174 | cond = model.encode(src_img)
175 |
176 | xT_all = model.encode_stochastic_all(src_img, cond, T=T_enc) # xT_all is [x_1, ..., x_T=Noise]
177 |
178 | xT = xT_all[start_t * (T_enc // T_atk) - 1] # Make the stochastic noise consistent from the encoding step to the attack step
179 |
180 | if lam > 0:
181 | if vis_full:
182 | z_adv, adv_img, img_list = test_attacker.perturb(cond, target_embeding, xT, parse_map=parse_map)
183 | else:
184 | z_adv, adv_img = test_attacker.perturb(cond, target_embeding, xT, parse_map=parse_map)
185 | else:
186 | if vis_full:
187 | z_adv, adv_img, img_list = test_attacker.perturb(cond, target_embeding, xT)
188 | else:
189 | z_adv, adv_img = test_attacker.perturb(cond, target_embeding, xT)
190 |
191 | # reconstruct adv images
192 | xT = xT_all[-1]
193 | adv_render_img = model.render(xT, z_adv, T=T_inf)
194 | avg_time = timer.toc()
195 | time_list.append(avg_time)
196 |
197 | save_name = os.path.join(save_path, source_name)
198 | save_image(adv_render_img, save_name)
199 |
200 | # Full visualization
201 | if vis_full:
202 | vis_path = os.path.join(args.save_path, 'vis_full')
203 | os.makedirs(vis_path, exist_ok=True)
204 | source_name = source_name.split('.')[0]
205 | vis_input_name = os.path.join(vis_path, source_name + '_input.png')
206 | vis_encode_name = os.path.join(vis_path, source_name + '_encode.png')
207 | vis_adv_name = os.path.join(vis_path, source_name + '_adv.png')
208 | vis_input_parse_name = os.path.join(vis_path, source_name + '_input_parse.png')
209 | vis_adv_parse_name = os.path.join(vis_path, source_name + '_adv_parse.png')
210 |
211 | parse_map_input = net_parse((src_img + 1) / 2)[0]
212 | parse_map_input = parse_map_input.squeeze(0).cpu().detach().numpy().argmax(0)
213 |
214 | parse_map_adv = net_parse(adv_render_img)[0]
215 | parse_map_adv = parse_map_adv.squeeze(0).cpu().detach().numpy().argmax(0)
216 |
217 | save_image((src_img + 1) / 2, vis_input_name)
218 | save_image(xT, vis_encode_name)
219 | save_image(adv_render_img, vis_adv_name)
220 | vis_parsing_maps(Image.open(vis_input_name), parse_map_input, save_path=vis_input_parse_name)
221 | vis_parsing_maps(Image.open(vis_adv_name), parse_map_adv, save_path=vis_adv_parse_name)
222 |
223 | for i, x_tmp in enumerate(img_list):
224 | vis_tmp_name = os.path.join(vis_path, source_name + f'_{str(i)}.png')
225 | save_image((x_tmp + 1) / 2, vis_tmp_name)
226 |
227 | return
228 |
229 | print('Finished! Image saved in:', os.path.abspath(args.save_path))
230 | result_fn = os.path.join(args.save_path, "time.txt")
231 | f = open(result_fn, 'a')
232 | print('Time: ', round(np.average(time_list),2))
233 | f.write(f"Time: {round(np.average(time_list),2)}\n")
234 |
235 | def attack_local_models(args, attack=True):
236 | test_models = load_test_models(args.test_model_names)
237 | th_dict = {'ir152': (0.094632, 0.166788, 0.227922), 'irse50': (0.144840, 0.241045, 0.312703),
238 | 'facenet': (0.256587, 0.409131, 0.591191), 'mobile_face': (0.183635, 0.301611, 0.380878), 'cosface': (0.144840, 0.241045, 0.312703), 'arcface': (0.144840, 0.241045, 0.312703)}
239 |
240 | result_fn = os.path.join(args.save_path, "result.txt")
241 | f = open(result_fn, 'a')
242 | print('Is Adversarial Attack:', attack)
243 | f.write(f"Is Adversarial Attack: {attack}\n")
244 |
245 | combined_dir = os.path.join(args.save_path, "combined")
246 | os.makedirs(combined_dir, exist_ok=True)
247 |
248 | for test_model in test_models.keys():
249 | size = test_models[test_model][0]
250 | model = test_models[test_model][1]
251 |
252 | target = read_img(args.test_path, 0.5, 0.5, 'cuda')
253 |
254 | target_embbeding = model.forward((F.interpolate(target, size=size, mode='bilinear')))
255 |
256 | FAR01 = 0
257 | FAR001 = 0
258 | FAR0001 = 0
259 | total = 0
260 | if attack:
261 | for img_path in glob.glob(os.path.join(args.save_path, "*.png")):
262 |
263 | adv_example = read_img(img_path, 0.5, 0.5, 'cuda')
264 | ae_embbeding = model.forward((F.interpolate(adv_example, size=size, mode='bilinear')))
265 | fn = img_path.split("/")[-1]
266 | clean_img = cv2.imread(os.path.join(args.clean_path, fn))
267 |
268 |
269 | cos_simi = torch.cosine_similarity(ae_embbeding, target_embbeding)
270 |
271 | if cos_simi.item() > th_dict[test_model][0]:
272 | FAR01 += 1
273 | if cos_simi.item() > th_dict[test_model][1]:
274 | FAR001 += 1
275 | if cos_simi.item() > th_dict[test_model][2]:
276 | FAR0001 += 1
277 | total += 1
278 |
279 | # combine the clean and adv image for visualization
280 | adv_img = cv2.imread(img_path)
281 | # fn = img_path.split("/")[-1]
282 | # clean_img = cv2.imread(os.path.join(args.clean_path, fn))
283 | if 'AMT' in args.save_path:
284 | continue
285 | combined_img = np.concatenate([clean_img, adv_img], 1)
286 | combined_fn = f"{fn.split('.')[0]}_{cos_simi.item():.4f}.png"
287 | cv2.imwrite(os.path.join(combined_dir, combined_fn), combined_img)
288 |
289 | else:
290 | for img in tqdm(os.listdir(args.clean_path), desc=test_model + ' clean'):
291 | adv_example = read_img(os.path.join(args.clean_path, img), 0.5, 0.5, 'cuda')
292 | ae_embbeding = model.forward((F.interpolate(adv_example, size=size, mode='bilinear')))
293 |
294 | cos_simi = torch.cosine_similarity(ae_embbeding, target_embbeding)
295 | if cos_simi.item() > th_dict[test_model][0]:
296 | FAR01 += 1
297 | if cos_simi.item() > th_dict[test_model][1]:
298 | FAR001 += 1
299 | if cos_simi.item() > th_dict[test_model][2]:
300 | FAR0001 += 1
301 | total += 1
302 |
303 |
304 | result_str = f"{test_model} ASR in FAR@0.1: {FAR01/total:.4f}, ASR in FAR@0.01: {FAR001/total:.4f}, ASR in FAR@0.001: {FAR0001/total:.4f}\n"
305 | print(result_str)
306 | f.write(result_str)
307 |
308 |
309 | if __name__ == '__main__':
310 | parser = argparse.ArgumentParser()
311 | parser.add_argument("--source_dir", default="assets/datasets/CelebA-HQ_align", help="path to source images")
312 | parser.add_argument("--save_path", default="assets/datasets/val", help="path to generated images")
313 | parser.add_argument("--clean_path", default="assets/datasets/CelebA-HQ_align", help="path to clean images")
314 | parser.add_argument("--target_path", default="assets/datasets/target/085807.jpg", help="path to target images")
315 | parser.add_argument("--test_path", default="assets/datasets/test/047073.jpg", help="path to test images")
316 | parser.add_argument('--device', type=str, default='0', help='cuda device')
317 | parser.add_argument("--model_path", default="checkpoints/G.pth", help="model for loading")
318 | parser.add_argument("--test_model_names", nargs='+', default=['mobile_face'], help="model for testing")
319 | parser.add_argument("--model_names", nargs='+', default=['ir152', 'irse50', 'facenet'], help="model for attacking")
320 | parser.add_argument("--eps", type=float, default=0.02, help="latent attack budget")
321 | parser.add_argument("--iter", type=int, default=50, help="Attack iterations (Global loop)")
322 | parser.add_argument("--T_enc", type=int, default=100, help="DDIM steps during image encoding")
323 | parser.add_argument("--T_atk", type=int, default=5, help="DDIM steps during attack")
324 | parser.add_argument("--T_inf", type=int, default=100, help="DDIM steps during inference")
325 | parser.add_argument("--start_t", type=int, default=5, help="starting point of attack")
326 | parser.add_argument("--attack_iter", type=int, default=1, help="number of attack iteration")
327 | parser.add_argument("--attack_inf_iter", type=int, default=4, help="number of attack inference iteration")
328 | parser.add_argument("--repeat_times", type=int, default=1, help="number of repeatance of attack inference and attack steps ")
329 | parser.add_argument("--cnt_skip", type=int, default=0, help="number of skip of attack inference and attack steps ")
330 | parser.add_argument("--lam", type=float, default=0., help="hyperparameter of face parsing ")
331 | parser.add_argument("--vis_full", action='store_true', default=False,
332 | help="compare with other method")
333 |
334 | # T_atk = (attack_iter + attack_inf_iter) * repeat_times
335 |
336 | args = parser.parse_args()
337 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
338 |
339 | args.save_path = os.path.join(args.save_path,
340 | f"save_eps{args.eps}_Tenc{args.T_enc}_iter{args.iter}_Tatk{args.T_atk}_Tstart{args.start_t}_Tinf{args.T_inf}"
341 | f"_atk{args.attack_iter}_atkinf{args.attack_inf_iter}_repeat{args.repeat_times}_skip{args.cnt_skip}_lam{args.lam}")
342 |
343 | generate(args)
344 | attack_local_models(args, attack=False)
345 | attack_local_models(args, attack=True)
346 |
--------------------------------------------------------------------------------
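The script above ties its DDIM-step hyperparameters together through the relation noted in its trailing comment, `T_atk = (attack_iter + attack_inf_iter) * repeat_times`, derives the output directory from the full hyperparameter set, and counts protection success by thresholding cosine similarity against the per-model FAR operating points in `th_dict`. A minimal sketch using only the defaults and threshold values shown above (the similarity value is hypothetical):

```python
# Defaults from the argparse block above.
attack_iter, attack_inf_iter, repeat_times = 1, 4, 1
T_atk = 5

# Relation stated in the script's comment: the attack schedule must sum to T_atk.
assert (attack_iter + attack_inf_iter) * repeat_times == T_atk

# The run directory is composed from the hyperparameters the same way as args.save_path.
eps, T_enc, n_iter, start_t, T_inf, cnt_skip, lam = 0.02, 100, 50, 5, 100, 0, 0.0
run_dir = (f"save_eps{eps}_Tenc{T_enc}_iter{n_iter}_Tatk{T_atk}_Tstart{start_t}_Tinf{T_inf}"
           f"_atk{attack_iter}_atkinf{attack_inf_iter}_repeat{repeat_times}_skip{cnt_skip}_lam{lam}")
print(run_dir)  # save_eps0.02_Tenc100_iter50_Tatk5_Tstart5_Tinf100_atk1_atkinf4_repeat1_skip0_lam0.0

# Evaluation counts a success at FAR@0.1/0.01/0.001 whenever the cosine similarity to the
# target embedding exceeds the corresponding threshold, e.g. for 'mobile_face':
th_mobile_face = (0.183635, 0.301611, 0.380878)
cos_simi = 0.35  # hypothetical similarity value
print([cos_simi > th for th in th_mobile_face])  # [True, True, False]
```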
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
4 |
5 | Sampler = Union[SpacedDiffusionBeatGans]
6 | SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
7 |
--------------------------------------------------------------------------------
/diffusion/diffusion.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from dataclasses import dataclass
3 |
4 |
5 | def space_timesteps(num_timesteps, section_counts):
6 | """
7 | Create a list of timesteps to use from an original diffusion process,
8 | given the number of timesteps we want to take from equally-sized portions
9 | of the original process.
10 |
11 | For example, if there are 300 timesteps and the section counts are [10,15,20]
12 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
13 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
14 |
15 | If the stride is a string starting with "ddim", then the fixed striding
16 | from the DDIM paper is used, and only one section is allowed.
17 |
18 | :param num_timesteps: the number of diffusion steps in the original
19 | process to divide up.
20 | :param section_counts: either a list of numbers, or a string containing
21 | comma-separated numbers, indicating the step count
22 | per section. As a special case, use "ddimN" where N
23 | is a number of steps to use the striding from the
24 | DDIM paper.
25 | :return: a set of diffusion steps from the original process to use.
26 | """
27 | if isinstance(section_counts, str):
28 | if section_counts.startswith("ddim"):
29 | desired_count = int(section_counts[len("ddim"):])
30 | for i in range(1, num_timesteps):
31 | if len(range(0, num_timesteps, i)) == desired_count:
32 | return set(range(0, num_timesteps, i))
33 | raise ValueError(
34 |                 f"cannot create exactly {desired_count} steps with an integer stride"
35 | )
36 | section_counts = [int(x) for x in section_counts.split(",")]
37 | size_per = num_timesteps // len(section_counts)
38 | extra = num_timesteps % len(section_counts)
39 | start_idx = 0
40 | all_steps = []
41 | for i, section_count in enumerate(section_counts):
42 | size = size_per + (1 if i < extra else 0)
43 | if size < section_count:
44 | raise ValueError(
45 | f"cannot divide section of {size} steps into {section_count}")
46 | if section_count <= 1:
47 | frac_stride = 1
48 | else:
49 | frac_stride = (size - 1) / (section_count - 1)
50 | cur_idx = 0.0
51 | taken_steps = []
52 | for _ in range(section_count):
53 | taken_steps.append(start_idx + round(cur_idx))
54 | cur_idx += frac_stride
55 | all_steps += taken_steps
56 | start_idx += size
57 | return set(all_steps)
58 |
59 |
60 | @dataclass
61 | class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62 | use_timesteps: Tuple[int] = None
63 |
64 | def make_sampler(self):
65 | return SpacedDiffusionBeatGans(self)
66 |
67 |
68 | class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69 | """
70 | A diffusion process which can skip steps in a base diffusion process.
71 |
72 | :param use_timesteps: a collection (sequence or set) of timesteps from the
73 | original diffusion process to retain.
74 | :param kwargs: the kwargs to create the base diffusion process.
75 | """
76 | def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77 | self.conf = conf
78 | self.use_timesteps = set(conf.use_timesteps)
79 |         # how the new t's map to the old t's
80 | self.timestep_map = []
81 | self.original_num_steps = len(conf.betas)
82 |
83 | base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84 | last_alpha_cumprod = 1.0
85 | new_betas = []
86 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87 | if i in self.use_timesteps:
88 | # getting the new betas of the new timesteps
89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90 | last_alpha_cumprod = alpha_cumprod
91 | self.timestep_map.append(i)
92 | conf.betas = np.array(new_betas)
93 | super().__init__(conf)
94 |
95 | def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96 | return super().p_mean_variance(self._wrap_model(model), *args,
97 | **kwargs)
98 |
99 | def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100 | return super().training_losses(self._wrap_model(model), *args,
101 | **kwargs)
102 |
103 | def condition_mean(self, cond_fn, *args, **kwargs):
104 | return super().condition_mean(self._wrap_model(cond_fn), *args,
105 | **kwargs)
106 |
107 | def condition_score(self, cond_fn, *args, **kwargs):
108 | return super().condition_score(self._wrap_model(cond_fn), *args,
109 | **kwargs)
110 |
111 | def _wrap_model(self, model: Model):
112 | if isinstance(model, _WrappedModel):
113 | return model
114 | return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115 | self.original_num_steps)
116 |
117 | def _scale_timesteps(self, t):
118 | # Scaling is done by the wrapped model.
119 | return t
120 |
121 |
122 | class _WrappedModel:
123 | """
124 |     Wraps the model so that the supplied (spaced) t's are converted back to the original timestep scale.
125 | """
126 | def __init__(self, model, timestep_map, rescale_timesteps,
127 | original_num_steps):
128 | self.model = model
129 | self.timestep_map = timestep_map
130 | self.rescale_timesteps = rescale_timesteps
131 | self.original_num_steps = original_num_steps
132 |
133 | def forward(self, x, t, t_cond=None, **kwargs):
134 | """
135 | Args:
136 |             t: t's with different ranges (can be << T due to smaller eval T) need to be converted to the original t's
137 | t_cond: the same as t but can be of different values
138 | """
139 | map_tensor = th.tensor(self.timestep_map,
140 | device=t.device,
141 | dtype=t.dtype)
142 |
143 | def do(t):
144 | new_ts = map_tensor[t]
145 | if self.rescale_timesteps:
146 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147 | return new_ts
148 |
149 | if t_cond is not None:
150 | # support t_cond
151 | t_cond = do(t_cond)
152 |
153 | return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
154 |
155 | def __getattr__(self, name):
156 | # allow for calling the model's methods
157 | if hasattr(self.model, name):
158 | func = getattr(self.model, name)
159 | return func
160 | raise AttributeError(name)
161 |
--------------------------------------------------------------------------------
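`space_timesteps()` above converts a step-count specification into the subset of original timesteps that the spaced sampler retains. A small usage sketch, assuming the repository root is on `PYTHONPATH`; the first call simply restates the docstring's example:

```python
from diffusion.diffusion import space_timesteps

# 300 original steps, three equal sections keeping 10, 15 and 20 timesteps
# respectively -> 45 retained timesteps overall.
steps = space_timesteps(num_timesteps=300, section_counts=[10, 15, 20])
assert len(steps) == 45 and max(steps) < 300

# DDIM-style fixed striding: keep 50 evenly strided steps out of 1000.
ddim_steps = space_timesteps(1000, "ddim50")
assert ddim_steps == set(range(0, 1000, 20))
```

A set like this is what `SpacedDiffusionBeatGansConfig.use_timesteps` expects; the resulting `SpacedDiffusionBeatGans` then skips every other step of the base process and remaps the spaced t's back to the original scale through `_WrappedModel`.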
/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 |         return UniformSampler(diffusion.num_timesteps)  # UniformSampler expects the step count
17 | else:
18 | raise NotImplementedError(f"unknown schedule sampler: {name}")
19 |
20 |
21 | class ScheduleSampler(ABC):
22 | """
23 | A distribution over timesteps in the diffusion process, intended to reduce
24 | variance of the objective.
25 |
26 | By default, samplers perform unbiased importance sampling, in which the
27 | objective's mean is unchanged.
28 | However, subclasses may override sample() to change how the resampled
29 | terms are reweighted, allowing for actual changes in the objective.
30 | """
31 | @abstractmethod
32 | def weights(self):
33 | """
34 | Get a numpy array of weights, one per diffusion step.
35 |
36 | The weights needn't be normalized, but must be positive.
37 | """
38 |
39 | def sample(self, batch_size, device):
40 | """
41 | Importance-sample timesteps for a batch.
42 |
43 | :param batch_size: the number of timesteps.
44 | :param device: the torch device to save to.
45 | :return: a tuple (timesteps, weights):
46 | - timesteps: a tensor of timestep indices.
47 | - weights: a tensor of weights to scale the resulting losses.
48 | """
49 | w = self.weights()
50 | p = w / np.sum(w)
51 | indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52 | indices = th.from_numpy(indices_np).long().to(device)
53 | weights_np = 1 / (len(p) * p[indices_np])
54 | weights = th.from_numpy(weights_np).float().to(device)
55 | return indices, weights
56 |
57 |
58 | class UniformSampler(ScheduleSampler):
59 | def __init__(self, num_timesteps):
60 | self._weights = np.ones([num_timesteps])
61 |
62 | def weights(self):
63 | return self._weights
64 |
--------------------------------------------------------------------------------
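A quick sketch of drawing training timesteps with `UniformSampler`; note that in this copy of the code the class takes the number of timesteps directly rather than a diffusion object:

```python
import torch as th
from diffusion.resample import UniformSampler

sampler = UniformSampler(num_timesteps=1000)
t, weights = sampler.sample(batch_size=8, device=th.device("cpu"))
# t: (8,) random timestep indices in [0, 1000); weights are all 1.0 because the
# importance distribution is uniform, so the loss reweighting is a no-op here.
```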
/dist_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from torch import distributed
3 |
4 |
5 | def barrier():
6 | if distributed.is_initialized():
7 | distributed.barrier()
8 | else:
9 | pass
10 |
11 |
12 | def broadcast(data, src):
13 | if distributed.is_initialized():
14 | distributed.broadcast(data, src)
15 | else:
16 | pass
17 |
18 |
19 | def all_gather(data: List, src):
20 | if distributed.is_initialized():
21 | distributed.all_gather(data, src)
22 | else:
23 | data[0] = src
24 |
25 |
26 | def get_rank():
27 | if distributed.is_initialized():
28 | return distributed.get_rank()
29 | else:
30 | return 0
31 |
32 |
33 | def get_world_size():
34 | if distributed.is_initialized():
35 | return distributed.get_world_size()
36 | else:
37 | return 1
38 |
39 |
40 | def chunk_size(size, rank, world_size):
41 | extra = rank < size % world_size
42 | return size // world_size + extra
--------------------------------------------------------------------------------
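These helpers fall back to single-process behaviour when `torch.distributed` is not initialized, and `chunk_size()` splits `size` items as evenly as possible across ranks, giving the first `size % world_size` ranks one extra item. A small illustration, assuming the repository root is importable:

```python
from dist_utils import chunk_size, get_rank, get_world_size

# 10 items over 3 ranks -> ranks 0, 1, 2 process 4, 3, 3 items respectively.
assert [chunk_size(10, rank, 3) for rank in range(3)] == [4, 3, 3]

# Without torch.distributed initialized, the fallbacks apply.
assert get_world_size() == 1 and get_rank() == 0
```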
/external/face_makeup/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/external/face_makeup/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 zll
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/external/face_makeup/README.md:
--------------------------------------------------------------------------------
1 | # face-makeup.PyTorch
2 | Lip and hair color editor using face parsing maps.
3 |
4 | *(Results table: images in ./makeup/ comparing the Original Input against three recolored "Color" variants, shown for Hair and Lip.)*
5 |
42 | ### Using PyTorch 1.0 and python 3.x
43 |
44 | ## Demo
45 | Change hair and lip color:
46 | ```Shell
47 | python makeup.py --img-path imgs/116.jpg
48 | ```
49 | ### Try other colors:
50 | Change the color list in **makeup.py** (line 83):
51 | ```
52 | colors = [[230, 50, 20], [20, 70, 180], [20, 70, 180]]
53 | ```
54 | ### Train face parsing model (optional)
55 | Follow this repo [zllrunning/face-parsing.PyTorch](https://github.com/zllrunning/face-parsing.PyTorch)
--------------------------------------------------------------------------------
/external/face_makeup/cp/79999_iter.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/cp/79999_iter.pth
--------------------------------------------------------------------------------
/external/face_makeup/imgs/116.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/imgs/116.jpg
--------------------------------------------------------------------------------
/external/face_makeup/imgs/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/imgs/6.jpg
--------------------------------------------------------------------------------
/external/face_makeup/makeup.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import numpy as np
4 | from skimage.filters import gaussian
5 | from test import evaluate
6 | import argparse
7 |
8 |
9 | def parse_args():
10 | parse = argparse.ArgumentParser()
11 | parse.add_argument('--img-path', default='imgs/116.jpg')
12 | return parse.parse_args()
13 |
14 |
15 | def sharpen(img):
16 | img = img * 1.0
17 | gauss_out = gaussian(img, sigma=5, multichannel=True)
18 |
19 | alpha = 1.5
20 | img_out = (img - gauss_out) * alpha + img
21 |
22 | img_out = img_out / 255.0
23 |
24 | mask_1 = img_out < 0
25 | mask_2 = img_out > 1
26 |
27 | img_out = img_out * (1 - mask_1)
28 | img_out = img_out * (1 - mask_2) + mask_2
29 | img_out = np.clip(img_out, 0, 1)
30 | img_out = img_out * 255
31 | return np.array(img_out, dtype=np.uint8)
32 |
33 |
34 | def hair(image, parsing, part=17, color=[230, 50, 20]):
35 | b, g, r = color #[10, 50, 250] # [10, 250, 10]
36 | tar_color = np.zeros_like(image)
37 | tar_color[:, :, 0] = b
38 | tar_color[:, :, 1] = g
39 | tar_color[:, :, 2] = r
40 |
41 | image_hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
42 | tar_hsv = cv2.cvtColor(tar_color, cv2.COLOR_BGR2HSV)
43 |
44 | if part == 12 or part == 13:
45 | image_hsv[:, :, 0:2] = tar_hsv[:, :, 0:2]
46 | else:
47 | image_hsv[:, :, 0:1] = tar_hsv[:, :, 0:1]
48 |
49 | changed = cv2.cvtColor(image_hsv, cv2.COLOR_HSV2BGR)
50 |
51 | if part == 17:
52 | changed = sharpen(changed)
53 |
54 | changed[parsing != part] = image[parsing != part]
55 | return changed
56 |
57 |
58 | if __name__ == '__main__':
59 | # 1 face
60 | # 11 teeth
61 | # 12 upper lip
62 | # 13 lower lip
63 | # 17 hair
64 |
65 | args = parse_args()
66 |
67 | table = {
68 | 'hair': 17,
69 | 'upper_lip': 12,
70 | 'lower_lip': 13
71 | }
72 |
73 | image_path = args.img_path
74 | cp = 'cp/79999_iter.pth'
75 |
76 | image = cv2.imread(image_path)
77 | ori = image.copy()
78 | parsing = evaluate(image_path, cp)
79 | parsing = cv2.resize(parsing, image.shape[0:2], interpolation=cv2.INTER_NEAREST)
80 |
81 | parts = [table['hair'], table['upper_lip'], table['lower_lip']]
82 |
83 | colors = [[230, 50, 20], [20, 70, 180], [20, 70, 180]]
84 |
85 | for part, color in zip(parts, colors):
86 | image = hair(image, parsing, part, color)
87 |
88 | cv2.imshow('image', cv2.resize(ori, (512, 512)))
89 | cv2.imshow('color', cv2.resize(image, (512, 512)))
90 |
91 | cv2.waitKey(0)
92 | cv2.destroyAllWindows()
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_0.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_1.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_2.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_3.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_4.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_5.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_6.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_lip_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_lip_ori.png
--------------------------------------------------------------------------------
/external/face_makeup/makeup/116_ori.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/external/face_makeup/makeup/116_ori.png
--------------------------------------------------------------------------------
/external/face_makeup/model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torchvision
9 |
10 | from .resnet import Resnet18
11 | # from modules.bn import InPlaceABNSync as BatchNorm2d
12 |
13 |
14 | class ConvBNReLU(nn.Module):
15 | def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
16 | super(ConvBNReLU, self).__init__()
17 | self.conv = nn.Conv2d(in_chan,
18 | out_chan,
19 | kernel_size = ks,
20 | stride = stride,
21 | padding = padding,
22 | bias = False)
23 | self.bn = nn.BatchNorm2d(out_chan)
24 | self.init_weight()
25 |
26 | def forward(self, x):
27 | x = self.conv(x)
28 | x = F.relu(self.bn(x))
29 | return x
30 |
31 | def init_weight(self):
32 | for ly in self.children():
33 | if isinstance(ly, nn.Conv2d):
34 | nn.init.kaiming_normal_(ly.weight, a=1)
35 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
36 |
37 | class BiSeNetOutput(nn.Module):
38 | def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
39 | super(BiSeNetOutput, self).__init__()
40 | self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
41 | self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
42 | self.init_weight()
43 |
44 | def forward(self, x):
45 | x = self.conv(x)
46 | x = self.conv_out(x)
47 | return x
48 |
49 | def init_weight(self):
50 | for ly in self.children():
51 | if isinstance(ly, nn.Conv2d):
52 | nn.init.kaiming_normal_(ly.weight, a=1)
53 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
54 |
55 | def get_params(self):
56 | wd_params, nowd_params = [], []
57 | for name, module in self.named_modules():
58 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
59 | wd_params.append(module.weight)
60 | if not module.bias is None:
61 | nowd_params.append(module.bias)
62 | elif isinstance(module, nn.BatchNorm2d):
63 | nowd_params += list(module.parameters())
64 | return wd_params, nowd_params
65 |
66 |
67 | class AttentionRefinementModule(nn.Module):
68 | def __init__(self, in_chan, out_chan, *args, **kwargs):
69 | super(AttentionRefinementModule, self).__init__()
70 | self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
71 | self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
72 | self.bn_atten = nn.BatchNorm2d(out_chan)
73 | self.sigmoid_atten = nn.Sigmoid()
74 | self.init_weight()
75 |
76 | def forward(self, x):
77 | feat = self.conv(x)
78 | atten = F.avg_pool2d(feat, feat.size()[2:])
79 | atten = self.conv_atten(atten)
80 | atten = self.bn_atten(atten)
81 | atten = self.sigmoid_atten(atten)
82 | out = torch.mul(feat, atten)
83 | return out
84 |
85 | def init_weight(self):
86 | for ly in self.children():
87 | if isinstance(ly, nn.Conv2d):
88 | nn.init.kaiming_normal_(ly.weight, a=1)
89 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
90 |
91 |
92 | class ContextPath(nn.Module):
93 | def __init__(self, *args, **kwargs):
94 | super(ContextPath, self).__init__()
95 | self.resnet = Resnet18()
96 | self.arm16 = AttentionRefinementModule(256, 128)
97 | self.arm32 = AttentionRefinementModule(512, 128)
98 | self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
99 | self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
100 | self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
101 |
102 | self.init_weight()
103 |
104 | def forward(self, x):
105 | H0, W0 = x.size()[2:]
106 | feat8, feat16, feat32 = self.resnet(x)
107 | H8, W8 = feat8.size()[2:]
108 | H16, W16 = feat16.size()[2:]
109 | H32, W32 = feat32.size()[2:]
110 |
111 | avg = F.avg_pool2d(feat32, feat32.size()[2:])
112 | avg = self.conv_avg(avg)
113 | avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
114 |
115 | feat32_arm = self.arm32(feat32)
116 | feat32_sum = feat32_arm + avg_up
117 | feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
118 | feat32_up = self.conv_head32(feat32_up)
119 |
120 | feat16_arm = self.arm16(feat16)
121 | feat16_sum = feat16_arm + feat32_up
122 | feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
123 | feat16_up = self.conv_head16(feat16_up)
124 |
125 | return feat8, feat16_up, feat32_up # x8, x8, x16
126 |
127 | def init_weight(self):
128 | for ly in self.children():
129 | if isinstance(ly, nn.Conv2d):
130 | nn.init.kaiming_normal_(ly.weight, a=1)
131 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
132 |
133 | def get_params(self):
134 | wd_params, nowd_params = [], []
135 | for name, module in self.named_modules():
136 | if isinstance(module, (nn.Linear, nn.Conv2d)):
137 | wd_params.append(module.weight)
138 | if not module.bias is None:
139 | nowd_params.append(module.bias)
140 | elif isinstance(module, nn.BatchNorm2d):
141 | nowd_params += list(module.parameters())
142 | return wd_params, nowd_params
143 |
144 |
145 | ### This is not used, since it is replaced by the resnet feature of the same size
146 | class SpatialPath(nn.Module):
147 | def __init__(self, *args, **kwargs):
148 | super(SpatialPath, self).__init__()
149 | self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
150 | self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
151 | self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
152 | self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
153 | self.init_weight()
154 |
155 | def forward(self, x):
156 | feat = self.conv1(x)
157 | feat = self.conv2(feat)
158 | feat = self.conv3(feat)
159 | feat = self.conv_out(feat)
160 | return feat
161 |
162 | def init_weight(self):
163 | for ly in self.children():
164 | if isinstance(ly, nn.Conv2d):
165 | nn.init.kaiming_normal_(ly.weight, a=1)
166 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
167 |
168 | def get_params(self):
169 | wd_params, nowd_params = [], []
170 | for name, module in self.named_modules():
171 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
172 | wd_params.append(module.weight)
173 | if not module.bias is None:
174 | nowd_params.append(module.bias)
175 | elif isinstance(module, nn.BatchNorm2d):
176 | nowd_params += list(module.parameters())
177 | return wd_params, nowd_params
178 |
179 |
180 | class FeatureFusionModule(nn.Module):
181 | def __init__(self, in_chan, out_chan, *args, **kwargs):
182 | super(FeatureFusionModule, self).__init__()
183 | self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
184 | self.conv1 = nn.Conv2d(out_chan,
185 | out_chan//4,
186 | kernel_size = 1,
187 | stride = 1,
188 | padding = 0,
189 | bias = False)
190 | self.conv2 = nn.Conv2d(out_chan//4,
191 | out_chan,
192 | kernel_size = 1,
193 | stride = 1,
194 | padding = 0,
195 | bias = False)
196 | self.relu = nn.ReLU(inplace=True)
197 | self.sigmoid = nn.Sigmoid()
198 | self.init_weight()
199 |
200 | def forward(self, fsp, fcp):
201 | fcat = torch.cat([fsp, fcp], dim=1)
202 | feat = self.convblk(fcat)
203 | atten = F.avg_pool2d(feat, feat.size()[2:])
204 | atten = self.conv1(atten)
205 | atten = self.relu(atten)
206 | atten = self.conv2(atten)
207 | atten = self.sigmoid(atten)
208 | feat_atten = torch.mul(feat, atten)
209 | feat_out = feat_atten + feat
210 | return feat_out
211 |
212 | def init_weight(self):
213 | for ly in self.children():
214 | if isinstance(ly, nn.Conv2d):
215 | nn.init.kaiming_normal_(ly.weight, a=1)
216 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
217 |
218 | def get_params(self):
219 | wd_params, nowd_params = [], []
220 | for name, module in self.named_modules():
221 | if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
222 | wd_params.append(module.weight)
223 | if not module.bias is None:
224 | nowd_params.append(module.bias)
225 | elif isinstance(module, nn.BatchNorm2d):
226 | nowd_params += list(module.parameters())
227 | return wd_params, nowd_params
228 |
229 |
230 | class BiSeNet(nn.Module):
231 | def __init__(self, n_classes, *args, **kwargs):
232 | super(BiSeNet, self).__init__()
233 | self.cp = ContextPath()
234 | ## here self.sp is deleted
235 | self.ffm = FeatureFusionModule(256, 256)
236 | self.conv_out = BiSeNetOutput(256, 256, n_classes)
237 | self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
238 | self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
239 | self.init_weight()
240 |
241 | def forward(self, x):
242 | H, W = x.size()[2:]
243 | feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
244 | feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
245 | feat_fuse = self.ffm(feat_sp, feat_cp8)
246 |
247 | feat_out = self.conv_out(feat_fuse)
248 | feat_out16 = self.conv_out16(feat_cp8)
249 | feat_out32 = self.conv_out32(feat_cp16)
250 |
251 | feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
252 | feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
253 | feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
254 | return feat_out, feat_out16, feat_out32
255 |
256 | def init_weight(self):
257 | for ly in self.children():
258 | if isinstance(ly, nn.Conv2d):
259 | nn.init.kaiming_normal_(ly.weight, a=1)
260 | if not ly.bias is None: nn.init.constant_(ly.bias, 0)
261 |
262 | def get_params(self):
263 | wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
264 | for name, child in self.named_children():
265 | child_wd_params, child_nowd_params = child.get_params()
266 | if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
267 | lr_mul_wd_params += child_wd_params
268 | lr_mul_nowd_params += child_nowd_params
269 | else:
270 | wd_params += child_wd_params
271 | nowd_params += child_nowd_params
272 | return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
273 |
274 |
275 | if __name__ == "__main__":
276 | net = BiSeNet(19)
277 | net.cuda()
278 | net.eval()
279 | in_ten = torch.randn(16, 3, 640, 480).cuda()
280 | out, out16, out32 = net(in_ten)
281 | print(out.shape)
282 |
283 | net.get_params()
284 |
--------------------------------------------------------------------------------
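`BiSeNet` above is the 19-class face-parsing backbone; each of its three outputs is upsampled back to the input resolution, and the per-pixel label map used elsewhere in this repo is the argmax over the class dimension of the first output. A minimal shape check, assuming the module is importable as a package (constructing it downloads the ImageNet ResNet-18 weights used by `Resnet18.init_weight`):

```python
import torch
from external.face_makeup.model import BiSeNet  # import path is an assumption

net = BiSeNet(n_classes=19).eval()
with torch.no_grad():
    out, out16, out32 = net(torch.randn(1, 3, 512, 512))

print(out.shape)            # torch.Size([1, 19, 512, 512])
print(out.argmax(1).shape)  # torch.Size([1, 512, 512]) -- the parsing label map
```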
/external/face_makeup/resnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.model_zoo as modelzoo
8 |
9 | # from modules.bn import InPlaceABNSync as BatchNorm2d
10 |
11 | resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12 |
13 |
14 | def conv3x3(in_planes, out_planes, stride=1):
15 | """3x3 convolution with padding"""
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 |
20 | class BasicBlock(nn.Module):
21 | def __init__(self, in_chan, out_chan, stride=1):
22 | super(BasicBlock, self).__init__()
23 | self.conv1 = conv3x3(in_chan, out_chan, stride)
24 | self.bn1 = nn.BatchNorm2d(out_chan)
25 | self.conv2 = conv3x3(out_chan, out_chan)
26 | self.bn2 = nn.BatchNorm2d(out_chan)
27 | self.relu = nn.ReLU(inplace=True)
28 | self.downsample = None
29 | if in_chan != out_chan or stride != 1:
30 | self.downsample = nn.Sequential(
31 | nn.Conv2d(in_chan, out_chan,
32 | kernel_size=1, stride=stride, bias=False),
33 | nn.BatchNorm2d(out_chan),
34 | )
35 |
36 | def forward(self, x):
37 | residual = self.conv1(x)
38 | residual = F.relu(self.bn1(residual))
39 | residual = self.conv2(residual)
40 | residual = self.bn2(residual)
41 |
42 | shortcut = x
43 | if self.downsample is not None:
44 | shortcut = self.downsample(x)
45 |
46 | out = shortcut + residual
47 | out = self.relu(out)
48 | return out
49 |
50 |
51 | def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52 | layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53 | for i in range(bnum-1):
54 | layers.append(BasicBlock(out_chan, out_chan, stride=1))
55 | return nn.Sequential(*layers)
56 |
57 |
58 | class Resnet18(nn.Module):
59 | def __init__(self):
60 | super(Resnet18, self).__init__()
61 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62 | bias=False)
63 | self.bn1 = nn.BatchNorm2d(64)
64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65 | self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66 | self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67 | self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68 | self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69 | self.init_weight()
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = F.relu(self.bn1(x))
74 | x = self.maxpool(x)
75 |
76 | x = self.layer1(x)
77 | feat8 = self.layer2(x) # 1/8
78 | feat16 = self.layer3(feat8) # 1/16
79 | feat32 = self.layer4(feat16) # 1/32
80 | return feat8, feat16, feat32
81 |
82 | def init_weight(self):
83 | state_dict = modelzoo.load_url(resnet18_url)
84 | self_state_dict = self.state_dict()
85 | for k, v in state_dict.items():
86 | if 'fc' in k: continue
87 | self_state_dict.update({k: v})
88 | self.load_state_dict(self_state_dict)
89 |
90 | def get_params(self):
91 | wd_params, nowd_params = [], []
92 | for name, module in self.named_modules():
93 | if isinstance(module, (nn.Linear, nn.Conv2d)):
94 | wd_params.append(module.weight)
95 | if not module.bias is None:
96 | nowd_params.append(module.bias)
97 | elif isinstance(module, nn.BatchNorm2d):
98 | nowd_params += list(module.parameters())
99 | return wd_params, nowd_params
100 |
101 |
102 | if __name__ == "__main__":
103 | net = Resnet18()
104 | x = torch.randn(16, 3, 224, 224)
105 | out = net(x)
106 | print(out[0].size())
107 | print(out[1].size())
108 | print(out[2].size())
109 | net.get_params()
110 |
--------------------------------------------------------------------------------
/external/face_makeup/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- encoding: utf-8 -*-
3 |
4 | import torch
5 | import os
6 | from .model import BiSeNet
7 | import os.path as osp
8 | import numpy as np
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | import cv2
12 |
13 |
14 | def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='parsing_map_on_im.png'):
15 | # Colors for all 20 parts
16 | part_colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0],
17 | [255, 0, 85], [255, 0, 170],
18 | [0, 255, 0], [85, 255, 0], [170, 255, 0],
19 | [0, 255, 85], [0, 255, 170],
20 | [0, 0, 255], [85, 0, 255], [170, 0, 255],
21 | [0, 85, 255], [0, 170, 255],
22 | [255, 255, 0], [255, 255, 85], [255, 255, 170],
23 | [255, 0, 255], [255, 85, 255], [255, 170, 255],
24 | [0, 255, 255], [85, 255, 255], [170, 255, 255]]
25 |
26 | im = np.array(im)
27 | vis_im = im.copy().astype(np.uint8)
28 | vis_parsing_anno = parsing_anno.copy().astype(np.uint8)
29 | vis_parsing_anno = cv2.resize(vis_parsing_anno, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)
30 | vis_parsing_anno_color = np.zeros((vis_parsing_anno.shape[0], vis_parsing_anno.shape[1], 3)) + 255
31 |
32 | num_of_class = np.max(vis_parsing_anno)
33 |
34 | for pi in range(1, num_of_class + 1):
35 | index = np.where(vis_parsing_anno == pi)
36 | vis_parsing_anno_color[index[0], index[1], :] = part_colors[pi]
37 |
38 | vis_parsing_anno_color = vis_parsing_anno_color.astype(np.uint8)
39 | # print(vis_parsing_anno_color.shape, vis_im.shape)
40 | vis_im = cv2.addWeighted(cv2.cvtColor(vis_im, cv2.COLOR_RGB2BGR), 0.4, vis_parsing_anno_color, 0.6, 0)
41 |
42 | # Save result or not
43 | if save_im:
44 | cv2.imwrite(save_path[:-4] +'.png', vis_parsing_anno)
45 | cv2.imwrite(save_path, vis_im, [int(cv2.IMWRITE_JPEG_QUALITY), 100])
46 | return vis_parsing_anno
47 | # return vis_im
48 |
49 |
50 | def evaluate(image_path='./imgs/116.jpg', cp='cp/79999_iter.pth'):
51 |
52 | # if not os.path.exists(respth):
53 | # os.makedirs(respth)
54 |
55 | n_classes = 19
56 | net = BiSeNet(n_classes=n_classes)
57 | net.cuda()
58 | net.load_state_dict(torch.load(cp))
59 | net.eval()
60 |
61 | to_tensor = transforms.Compose([
62 | transforms.ToTensor(),
63 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
64 | ])
65 |
66 | with torch.no_grad():
67 | img = Image.open(image_path)
68 | image = img.resize((512, 512), Image.BILINEAR)
69 | img = to_tensor(image)
70 | img = torch.unsqueeze(img, 0)
71 | img = img.cuda()
72 | out = net(img)[0]
73 | parsing = out.squeeze(0).cpu().numpy().argmax(0)
74 | # print(parsing)
75 | # print(np.unique(parsing))
76 |
77 | # vis_parsing_maps(image, parsing, stride=1, save_im=False, save_path=osp.join(respth, dspth))
78 | return parsing
79 |
80 | if __name__ == "__main__":
81 | evaluate(image_path='../../assets/datasets/save_1000_PGD_attribute/as_05_adv_01_mask_16_progressive/000865.png', cp='79999_iter.pth')
82 |
83 |
84 |
--------------------------------------------------------------------------------
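`evaluate()` runs the parser on a single image and returns a 512x512 label map, and `vis_parsing_maps()` overlays it on the image, in the same spirit as the `..._parse.png` visualizations saved by the demo script. A sketch, assuming a CUDA device (the network is moved to the GPU), a package-style import, and a placeholder RGB input image:

```python
from PIL import Image
from external.face_makeup.test import evaluate, vis_parsing_maps  # import path is an assumption

img_path = 'assets/example_face.png'  # placeholder input (RGB)
ckpt = 'external/face_makeup/cp/79999_iter.pth'

parsing = evaluate(image_path=img_path, cp=ckpt)            # (512, 512) integer label map
vis_parsing_maps(Image.open(img_path).resize((512, 512)),   # match the 512x512 parsing map
                 parsing, stride=1, save_im=True,
                 save_path='assets/example_face_parse.png')
```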
/lmdb_writer.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | from contextlib import contextmanager
9 | from torch.utils.data import Dataset
10 | from multiprocessing import Process, Queue
11 | import os
12 | import shutil
13 |
14 |
15 | def convert(x, format, quality=100):
16 | # to prevent locking!
17 | torch.set_num_threads(1)
18 |
19 | buffer = BytesIO()
20 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
21 | x = x.to(torch.uint8)
22 | x = x.numpy()
23 | img = Image.fromarray(x)
24 | img.save(buffer, format=format, quality=quality)
25 | val = buffer.getvalue()
26 | return val
27 |
28 |
29 | @contextmanager
30 | def nullcontext():
31 | yield
32 |
33 |
34 | class _WriterWorker(Process):
35 | def __init__(self, path, format, quality, zfill, q):
36 | super().__init__()
37 | if os.path.exists(path):
38 | shutil.rmtree(path)
39 |
40 | self.path = path
41 | self.format = format
42 | self.quality = quality
43 | self.zfill = zfill
44 | self.q = q
45 | self.i = 0
46 |
47 | def run(self):
48 | if not os.path.exists(self.path):
49 | os.makedirs(self.path)
50 |
51 | with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
52 | while True:
53 | job = self.q.get()
54 | if job is None:
55 | break
56 | with env.begin(write=True) as txn:
57 | for x in job:
58 | key = f"{str(self.i).zfill(self.zfill)}".encode(
59 | "utf-8")
60 | x = convert(x, self.format, self.quality)
61 | txn.put(key, x)
62 | self.i += 1
63 |
64 | with env.begin(write=True) as txn:
65 | txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
66 |
67 |
68 | class LMDBImageWriter:
69 | def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
70 | self.path = path
71 | self.format = format
72 | self.quality = quality
73 | self.zfill = zfill
74 | self.queue = None
75 | self.worker = None
76 |
77 | def __enter__(self):
78 | self.queue = Queue(maxsize=3)
79 |         self.worker = _WriterWorker(self.path, self.format, self.quality,
80 | self.zfill, self.queue)
81 | self.worker.start()
82 |
83 | def put_images(self, tensor):
84 | """
85 | Args:
86 | tensor: (n, c, h, w) [0-1] tensor
87 | """
88 | self.queue.put(tensor.cpu())
89 | # with self.env.begin(write=True) as txn:
90 | # for x in tensor:
91 | # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8")
92 | # x = convert(x, self.format, self.quality)
93 | # txn.put(key, x)
94 | # self.i += 1
95 |
96 | def __exit__(self, *args, **kwargs):
97 | self.queue.put(None)
98 | self.queue.close()
99 | self.worker.join()
100 |
101 |
102 | class LMDBImageReader(Dataset):
103 | def __init__(self, path, zfill: int = 7):
104 | self.zfill = zfill
105 | self.env = lmdb.open(
106 | path,
107 | max_readers=32,
108 | readonly=True,
109 | lock=False,
110 | readahead=False,
111 | meminit=False,
112 | )
113 |
114 | if not self.env:
115 | raise IOError('Cannot open lmdb dataset', path)
116 |
117 | with self.env.begin(write=False) as txn:
118 | self.length = int(
119 | txn.get('length'.encode('utf-8')).decode('utf-8'))
120 |
121 | def __len__(self):
122 | return self.length
123 |
124 | def __getitem__(self, index):
125 | with self.env.begin(write=False) as txn:
126 | key = f'{str(index).zfill(self.zfill)}'.encode('utf-8')
127 | img_bytes = txn.get(key)
128 |
129 | buffer = BytesIO(img_bytes)
130 | img = Image.open(buffer)
131 | return img
132 |
--------------------------------------------------------------------------------
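`LMDBImageWriter` streams `(n, c, h, w)` tensors in `[0, 1]` to a background process that encodes and stores them, and `LMDBImageReader` serves them back as PIL images. A minimal round-trip sketch (the database path is a placeholder; the writer object is constructed first because entering its context starts the worker without returning a new value):

```python
import torch
from lmdb_writer import LMDBImageWriter, LMDBImageReader

db_path = 'assets/tmp_lmdb'  # placeholder output directory

writer = LMDBImageWriter(db_path, format='webp', quality=100)
with writer:                                     # starts the background writer process
    writer.put_images(torch.rand(4, 3, 64, 64))  # (n, c, h, w) tensor in [0, 1]
# leaving the context flushes the queue and writes the 'length' key

reader = LMDBImageReader(db_path)
print(len(reader), reader[0].size)  # 4 (64, 64)
```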
/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import torch
5 | import torchvision
6 | from pytorch_fid import fid_score
7 | from torch import distributed
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.distributed import DistributedSampler
10 | from tqdm.autonotebook import tqdm, trange
11 |
12 | from renderer import *
13 | from config import *
14 | from diffusion import Sampler
15 | from dist_utils import *
16 | import lpips
17 | from ssim import ssim
18 |
19 |
20 | def make_subset_loader(conf: TrainConfig,
21 | dataset: Dataset,
22 | batch_size: int,
23 | shuffle: bool,
24 | parallel: bool,
25 | drop_last=True):
26 | dataset = SubsetDataset(dataset, size=conf.eval_num_images)
27 | if parallel and distributed.is_initialized():
28 | sampler = DistributedSampler(dataset, shuffle=shuffle)
29 | else:
30 | sampler = None
31 | return DataLoader(
32 | dataset,
33 | batch_size=batch_size,
34 | sampler=sampler,
35 |         # with a sampler, shuffling is handled by the sampler instead of this option
36 | shuffle=False if sampler else shuffle,
37 | num_workers=conf.num_workers,
38 | pin_memory=True,
39 | drop_last=drop_last,
40 | multiprocessing_context=get_context('fork'),
41 | )
42 |
43 |
44 | def evaluate_lpips(
45 | sampler: Sampler,
46 | model: Model,
47 | conf: TrainConfig,
48 | device,
49 | val_data: Dataset,
50 | latent_sampler: Sampler = None,
51 | use_inverted_noise: bool = False,
52 | ):
53 | """
54 | compare the generated images from autoencoder on validation dataset
55 |
56 | Args:
57 |         use_inverted_noise: whether the noise is also inverted from DDIM
58 | """
59 | lpips_fn = lpips.LPIPS(net='alex').to(device)
60 | val_loader = make_subset_loader(conf,
61 | dataset=val_data,
62 | batch_size=conf.batch_size_eval,
63 | shuffle=False,
64 | parallel=True)
65 |
66 | model.eval()
67 | with torch.no_grad():
68 | scores = {
69 | 'lpips': [],
70 | 'mse': [],
71 | 'ssim': [],
72 | 'psnr': [],
73 | }
74 | for batch in tqdm(val_loader, desc='lpips'):
75 | imgs = batch['img'].to(device)
76 |
77 | if use_inverted_noise:
78 | # inverse the noise
79 | # with condition from the encoder
80 | model_kwargs = {}
81 | if conf.model_type.has_autoenc():
82 | with torch.no_grad():
83 | model_kwargs = model.encode(imgs)
84 | x_T = sampler.ddim_reverse_sample_loop(
85 | model=model,
86 | x=imgs,
87 | clip_denoised=True,
88 | model_kwargs=model_kwargs)
89 | x_T = x_T['sample']
90 | else:
91 | x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size),
92 | device=device)
93 |
94 | if conf.model_type == ModelType.ddpm:
95 | # the case where you want to calculate the inversion capability of the DDIM model
96 | assert use_inverted_noise
97 | pred_imgs = render_uncondition(
98 | conf=conf,
99 | model=model,
100 | x_T=x_T,
101 | sampler=sampler,
102 | latent_sampler=latent_sampler,
103 | )
104 | else:
105 | pred_imgs = render_condition(conf=conf,
106 | model=model,
107 | x_T=x_T,
108 | x_start=imgs,
109 | cond=None,
110 | sampler=sampler)
111 | # # returns {'cond', 'cond2'}
112 | # conds = model.encode(imgs)
113 | # pred_imgs = sampler.sample(model=model,
114 | # noise=x_T,
115 | # model_kwargs=conds)
116 |
117 | # (n, 1, 1, 1) => (n, )
118 | scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1))
119 |
120 | # need to normalize into [0, 1]
121 | norm_imgs = (imgs + 1) / 2
122 | norm_pred_imgs = (pred_imgs + 1) / 2
123 | # (n, )
124 | scores['ssim'].append(
125 | ssim(norm_imgs, norm_pred_imgs, size_average=False))
126 | # (n, )
127 | scores['mse'].append(
128 | (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3]))
129 | # (n, )
130 | scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs))
131 | # (N, )
132 | for key in scores.keys():
133 | scores[key] = torch.cat(scores[key]).float()
134 | model.train()
135 |
136 | barrier()
137 |
138 | # support multi-gpu
139 | outs = {
140 | key: [
141 | torch.zeros(len(scores[key]), device=device)
142 | for i in range(get_world_size())
143 | ]
144 | for key in scores.keys()
145 | }
146 | for key in scores.keys():
147 | all_gather(outs[key], scores[key])
148 |
149 | # final scores
150 | for key in scores.keys():
151 | scores[key] = torch.cat(outs[key]).mean().item()
152 |
153 |     # {'lpips', 'mse', 'ssim', 'psnr'}
154 | return scores
155 |
156 |
157 | def psnr(img1, img2):
158 | """
159 | Args:
160 | img1: (n, c, h, w)
161 | """
162 | v_max = 1.
163 | # (n,)
164 | mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3])
165 | return 20 * torch.log10(v_max / torch.sqrt(mse))
166 |
167 |
168 | def evaluate_fid(
169 | sampler: Sampler,
170 | model: Model,
171 | conf: TrainConfig,
172 | device,
173 | train_data: Dataset,
174 | val_data: Dataset,
175 | latent_sampler: Sampler = None,
176 | conds_mean=None,
177 | conds_std=None,
178 | remove_cache: bool = True,
179 | clip_latent_noise: bool = False,
180 | ):
181 | assert conf.fid_cache is not None
182 | if get_rank() == 0:
183 | # no parallel
184 |         # validation data for comparing FID
185 | val_loader = make_subset_loader(conf,
186 | dataset=val_data,
187 | batch_size=conf.batch_size_eval,
188 | shuffle=False,
189 | parallel=False)
190 |
191 | # put the val images to a directory
192 | cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}'
193 | if (os.path.exists(cache_dir)
194 | and len(os.listdir(cache_dir)) < conf.eval_num_images):
195 | shutil.rmtree(cache_dir)
196 |
197 | if not os.path.exists(cache_dir):
198 | # write files to the cache
199 | # the images are normalized, hence need to denormalize first
200 | loader_to_path(val_loader, cache_dir, denormalize=True)
201 |
202 | # create the generate dir
203 | if os.path.exists(conf.generate_dir):
204 | shutil.rmtree(conf.generate_dir)
205 | os.makedirs(conf.generate_dir)
206 |
207 | barrier()
208 |
209 | world_size = get_world_size()
210 | rank = get_rank()
211 | batch_size = chunk_size(conf.batch_size_eval, rank, world_size)
212 |
213 | def filename(idx):
214 | return world_size * idx + rank
215 |
216 | model.eval()
217 | with torch.no_grad():
218 | if conf.model_type.can_sample():
219 | eval_num_images = chunk_size(conf.eval_num_images, rank,
220 | world_size)
221 | desc = "generating images"
222 | for i in trange(0, eval_num_images, batch_size, desc=desc):
223 | batch_size = min(batch_size, eval_num_images - i)
224 | x_T = torch.randn(
225 | (batch_size, 3, conf.img_size, conf.img_size),
226 | device=device)
227 | batch_images = render_uncondition(
228 | conf=conf,
229 | model=model,
230 | x_T=x_T,
231 | sampler=sampler,
232 | latent_sampler=latent_sampler,
233 | conds_mean=conds_mean,
234 | conds_std=conds_std).cpu()
235 |
236 | batch_images = (batch_images + 1) / 2
237 | # keep the generated images
238 | for j in range(len(batch_images)):
239 | img_name = filename(i + j)
240 | torchvision.utils.save_image(
241 | batch_images[j],
242 | os.path.join(conf.generate_dir, f'{img_name}.png'))
243 | elif conf.model_type == ModelType.autoencoder:
244 | if conf.train_mode.is_latent_diffusion():
245 | # evaluate autoencoder + latent diffusion (doesn't give the images)
246 | model: BeatGANsAutoencModel
247 | eval_num_images = chunk_size(conf.eval_num_images, rank,
248 | world_size)
249 | desc = "generating images"
250 | for i in trange(0, eval_num_images, batch_size, desc=desc):
251 | batch_size = min(batch_size, eval_num_images - i)
252 | x_T = torch.randn(
253 | (batch_size, 3, conf.img_size, conf.img_size),
254 | device=device)
255 | batch_images = render_uncondition(
256 | conf=conf,
257 | model=model,
258 | x_T=x_T,
259 | sampler=sampler,
260 | latent_sampler=latent_sampler,
261 | conds_mean=conds_mean,
262 | conds_std=conds_std,
263 | clip_latent_noise=clip_latent_noise,
264 | ).cpu()
265 | batch_images = (batch_images + 1) / 2
266 | # keep the generated images
267 | for j in range(len(batch_images)):
268 | img_name = filename(i + j)
269 | torchvision.utils.save_image(
270 | batch_images[j],
271 | os.path.join(conf.generate_dir, f'{img_name}.png'))
272 | else:
273 |             # evaluate the autoencoder (given the images)
274 | # to make the FID fair, autoencoder must not see the validation dataset
275 | # also shuffle to make it closer to unconditional generation
276 | train_loader = make_subset_loader(conf,
277 | dataset=train_data,
278 | batch_size=batch_size,
279 | shuffle=True,
280 | parallel=True)
281 |
282 | i = 0
283 | for batch in tqdm(train_loader, desc='generating images'):
284 | imgs = batch['img'].to(device)
285 | x_T = torch.randn(
286 | (len(imgs), 3, conf.img_size, conf.img_size),
287 | device=device)
288 | batch_images = render_condition(
289 | conf=conf,
290 | model=model,
291 | x_T=x_T,
292 | x_start=imgs,
293 | cond=None,
294 | sampler=sampler,
295 | latent_sampler=latent_sampler).cpu()
296 | # model: BeatGANsAutoencModel
297 | # # returns {'cond', 'cond2'}
298 | # conds = model.encode(imgs)
299 | # batch_images = sampler.sample(model=model,
300 | # noise=x_T,
301 | # model_kwargs=conds).cpu()
302 | # denormalize the images
303 | batch_images = (batch_images + 1) / 2
304 | # keep the generated images
305 | for j in range(len(batch_images)):
306 | img_name = filename(i + j)
307 | torchvision.utils.save_image(
308 | batch_images[j],
309 | os.path.join(conf.generate_dir, f'{img_name}.png'))
310 | i += len(imgs)
311 | else:
312 | raise NotImplementedError()
313 | model.train()
314 |
315 | barrier()
316 |
317 | if get_rank() == 0:
318 | fid = fid_score.calculate_fid_given_paths(
319 | [cache_dir, conf.generate_dir],
320 | batch_size,
321 | device=device,
322 | dims=2048)
323 |
324 | # remove the cache
325 | if remove_cache and os.path.exists(conf.generate_dir):
326 | shutil.rmtree(conf.generate_dir)
327 |
328 | barrier()
329 |
330 | if get_rank() == 0:
331 |         # need to float it! otherwise the broadcast value is wrong
332 | fid = torch.tensor(float(fid), device=device)
333 | broadcast(fid, 0)
334 | else:
335 | fid = torch.tensor(0., device=device)
336 | broadcast(fid, 0)
337 | fid = fid.item()
338 | print(f'fid ({get_rank()}):', fid)
339 |
340 | return fid
341 |
342 |
343 | def loader_to_path(loader: DataLoader, path: str, denormalize: bool):
344 | # not process safe!
345 |
346 | if not os.path.exists(path):
347 | os.makedirs(path)
348 |
349 | # write the loader to files
350 | i = 0
351 | for batch in tqdm(loader, desc='copy images'):
352 | imgs = batch['img']
353 | if denormalize:
354 | imgs = (imgs + 1) / 2
355 | for j in range(len(imgs)):
356 | torchvision.utils.save_image(imgs[j],
357 | os.path.join(path, f'{i+j}.png'))
358 | i += len(imgs)
359 |
--------------------------------------------------------------------------------
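`psnr()` above assumes inputs in `[0, 1]` with a peak value of 1. A tiny numeric check (values chosen purely to make the arithmetic obvious; importing `metrics` pulls in the rest of the evaluation stack from `requirements.txt`):

```python
import torch
from metrics import psnr

a = torch.zeros(1, 3, 8, 8)
b = torch.full((1, 3, 8, 8), 0.1)  # uniform error of 0.1

# MSE = 0.01, so PSNR = 20 * log10(1 / sqrt(0.01)) = 20 dB.
print(psnr(a, b))  # ~tensor([20.])
```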
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3 | from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4 |
5 | Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6 | ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
7 |
--------------------------------------------------------------------------------
/model/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | from abc import abstractmethod
3 | from dataclasses import dataclass
4 | from numbers import Number
5 | import numpy as np
6 | import torch as th
7 | import torch.nn.functional as F
8 | from choices import *
9 | from config_base import BaseConfig
10 | from torch import nn
11 |
12 | from .nn import (avg_pool_nd, conv_nd, linear, normalization,
13 | timestep_embedding, torch_checkpoint, zero_module)
14 |
15 |
16 | class ScaleAt(Enum):
17 | after_norm = 'afternorm'
18 |
19 |
20 | class TimestepBlock(nn.Module):
21 | """
22 | Any module where forward() takes timestep embeddings as a second argument.
23 | """
24 | @abstractmethod
25 | def forward(self, x, emb=None, cond=None, lateral=None):
26 | """
27 | Apply the module to `x` given `emb` timestep embeddings.
28 | """
29 |
30 |
31 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32 | """
33 | A sequential module that passes timestep embeddings to the children that
34 | support it as an extra input.
35 | """
36 | def forward(self, x, emb=None, cond=None, lateral=None):
37 | for layer in self:
38 | if isinstance(layer, TimestepBlock):
39 | x = layer(x, emb=emb, cond=cond, lateral=lateral)
40 | else:
41 | x = layer(x)
42 | return x
43 |
44 |
45 | @dataclass
46 | class ResBlockConfig(BaseConfig):
47 | channels: int
48 | emb_channels: int
49 | dropout: float
50 | out_channels: int = None
51 | # condition the resblock with time (and encoder's output)
52 | use_condition: bool = True
53 | # whether to use 3x3 conv for skip path when the channels aren't matched
54 | use_conv: bool = False
55 | # dimension of conv (always 2 = 2d)
56 | dims: int = 2
57 | # gradient checkpoint
58 | use_checkpoint: bool = False
59 | up: bool = False
60 | down: bool = False
61 | # whether to condition with both time & encoder's output
62 | two_cond: bool = False
63 | # number of encoders' output channels
64 | cond_emb_channels: int = None
65 |     # whether the block takes a lateral (encoder skip) connection; suggested: False
66 | has_lateral: bool = False
67 | lateral_channels: int = None
68 | # whether to init the convolution with zero weights
69 | # this is default from BeatGANs and seems to help learning
70 | use_zero_module: bool = True
71 |
72 | def __post_init__(self):
73 | self.out_channels = self.out_channels or self.channels
74 | self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
75 |
76 | def make_model(self):
77 | return ResBlock(self)
78 |
79 |
80 | class ResBlock(TimestepBlock):
81 | """
82 | A residual block that can optionally change the number of channels.
83 |
84 | total layers:
85 | in_layers
86 | - norm
87 | - act
88 | - conv
89 | out_layers
90 | - norm
91 | - (modulation)
92 | - act
93 | - conv
94 | """
95 | def __init__(self, conf: ResBlockConfig):
96 | super().__init__()
97 | self.conf = conf
98 |
99 | #############################
100 | # IN LAYERS
101 | #############################
102 | assert conf.lateral_channels is None
103 | layers = [
104 | normalization(conf.channels),
105 | nn.SiLU(),
106 | conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
107 | ]
108 | self.in_layers = nn.Sequential(*layers)
109 |
110 | self.updown = conf.up or conf.down
111 |
112 | if conf.up:
113 | self.h_upd = Upsample(conf.channels, False, conf.dims)
114 | self.x_upd = Upsample(conf.channels, False, conf.dims)
115 | elif conf.down:
116 | self.h_upd = Downsample(conf.channels, False, conf.dims)
117 | self.x_upd = Downsample(conf.channels, False, conf.dims)
118 | else:
119 | self.h_upd = self.x_upd = nn.Identity()
120 |
121 | #############################
122 | # OUT LAYERS CONDITIONS
123 | #############################
124 | if conf.use_condition:
125 | # condition layers for the out_layers
126 | self.emb_layers = nn.Sequential(
127 | nn.SiLU(),
128 | linear(conf.emb_channels, 2 * conf.out_channels),
129 | )
130 |
131 | if conf.two_cond:
132 | self.cond_emb_layers = nn.Sequential(
133 | nn.SiLU(),
134 | linear(conf.cond_emb_channels, conf.out_channels),
135 | )
136 | #############################
137 | # OUT LAYERS (ignored when there is no condition)
138 | #############################
139 | # original version
140 | conv = conv_nd(conf.dims,
141 | conf.out_channels,
142 | conf.out_channels,
143 | 3,
144 | padding=1)
145 | if conf.use_zero_module:
146 |             # zero out the weights
147 | # it seems to help training
148 | conv = zero_module(conv)
149 |
150 | # construct the layers
151 | # - norm
152 | # - (modulation)
153 | # - act
154 | # - dropout
155 | # - conv
156 | layers = []
157 | layers += [
158 | normalization(conf.out_channels),
159 | nn.SiLU(),
160 | nn.Dropout(p=conf.dropout),
161 | conv,
162 | ]
163 | self.out_layers = nn.Sequential(*layers)
164 |
165 | #############################
166 | # SKIP LAYERS
167 | #############################
168 | if conf.out_channels == conf.channels:
169 |             # cannot be used with gatedconv; also, gatedconv is always used as the first block
170 | self.skip_connection = nn.Identity()
171 | else:
172 | if conf.use_conv:
173 | kernel_size = 3
174 | padding = 1
175 | else:
176 | kernel_size = 1
177 | padding = 0
178 |
179 | self.skip_connection = conv_nd(conf.dims,
180 | conf.channels,
181 | conf.out_channels,
182 | kernel_size,
183 | padding=padding)
184 |
185 | def forward(self, x, emb=None, cond=None, lateral=None):
186 | """
187 | Apply the block to a Tensor, conditioned on a timestep embedding.
188 |
189 | Args:
190 | x: input
191 | lateral: lateral connection from the encoder
192 | """
193 | return torch_checkpoint(self._forward, (x, emb, cond, lateral),
194 | self.conf.use_checkpoint)
195 |
196 | def _forward(
197 | self,
198 | x,
199 | emb=None,
200 | cond=None,
201 | lateral=None,
202 | ):
203 | """
204 | Args:
205 |             lateral: required when "has_lateral" is set and the block is non-gated; with a gated block it may be supplied optionally
206 | """
207 | if self.conf.has_lateral:
208 |             # lateral may be supplied even when it is not required;
209 |             # the block uses it only if "has_lateral" is set
210 | assert lateral is not None
211 | x = th.cat([x, lateral], dim=1)
212 |
213 | if self.updown:
214 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
215 | h = in_rest(x)
216 | h = self.h_upd(h)
217 | x = self.x_upd(x)
218 | h = in_conv(h)
219 | else:
220 | h = self.in_layers(x)
221 |
222 | if self.conf.use_condition:
223 |             # it's possible that the network may not receive the time emb
224 | # this happens with autoenc and setting the time_at
225 | if emb is not None:
226 | emb_out = self.emb_layers(emb).type(h.dtype)
227 | else:
228 | emb_out = None
229 |
230 | if self.conf.two_cond:
231 | # it's possible that the network is two_cond
232 | # but it doesn't get the second condition
233 | # in which case, we ignore the second condition
234 | # and treat as if the network has one condition
235 | if cond is None:
236 | cond_out = None
237 | else:
238 | cond_out = self.cond_emb_layers(cond).type(h.dtype)
239 |
240 | if cond_out is not None:
241 | while len(cond_out.shape) < len(h.shape):
242 | cond_out = cond_out[..., None]
243 | else:
244 | cond_out = None
245 |
246 | # this is the new refactored code
247 | h = apply_conditions(
248 | h=h,
249 | emb=emb_out,
250 | cond=cond_out,
251 | layers=self.out_layers,
252 | scale_bias=1,
253 | in_channels=self.conf.out_channels,
254 | up_down_layer=None,
255 | )
256 |
257 | return self.skip_connection(x) + h
258 |
259 |
260 | def apply_conditions(
261 | h,
262 | emb=None,
263 | cond=None,
264 | layers: nn.Sequential = None,
265 | scale_bias: float = 1,
266 | in_channels: int = 512,
267 | up_down_layer: nn.Module = None,
268 | ):
269 | """
270 | apply conditions on the feature maps
271 |
272 | Args:
273 | emb: time conditional (ready to scale + shift)
274 |         cond: encoder's conditional (ready to scale + shift)
275 | """
276 | two_cond = emb is not None and cond is not None
277 |
278 | if emb is not None:
279 | # adjusting shapes
280 | while len(emb.shape) < len(h.shape):
281 | emb = emb[..., None]
282 |
283 | if two_cond:
284 | # adjusting shapes
285 | while len(cond.shape) < len(h.shape):
286 | cond = cond[..., None]
287 | # time first
288 | scale_shifts = [emb, cond]
289 | else:
290 | # "cond" is not used with single cond mode
291 | scale_shifts = [emb]
292 |
293 | # support scale, shift or shift only
294 | for i, each in enumerate(scale_shifts):
295 | if each is None:
296 | # special case: the condition is not provided
297 | a = None
298 | b = None
299 | else:
300 | if each.shape[1] == in_channels * 2:
301 | a, b = th.chunk(each, 2, dim=1)
302 | else:
303 | a = each
304 | b = None
305 | scale_shifts[i] = (a, b)
306 |
307 | # condition scale bias could be a list
308 | if isinstance(scale_bias, Number):
309 | biases = [scale_bias] * len(scale_shifts)
310 | else:
311 | # a list
312 | biases = scale_bias
313 |
314 |     # by default, the scale & shift are applied after the group norm but BEFORE SiLU
315 | pre_layers, post_layers = layers[0], layers[1:]
316 |
317 |     # split the post layers to be able to scale up or down before the conv
318 | # post layers will contain only the conv
319 | mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
320 |
321 | h = pre_layers(h)
322 | # scale and shift for each condition
323 | for i, (scale, shift) in enumerate(scale_shifts):
324 | # if scale is None, it indicates that the condition is not provided
325 | if scale is not None:
326 | h = h * (biases[i] + scale)
327 | if shift is not None:
328 | h = h + shift
329 | h = mid_layers(h)
330 |
331 | # upscale or downscale if any just before the last conv
332 | if up_down_layer is not None:
333 | h = up_down_layer(h)
334 | h = post_layers(h)
335 | return h
336 |
337 |
338 | class Upsample(nn.Module):
339 | """
340 | An upsampling layer with an optional convolution.
341 |
342 | :param channels: channels in the inputs and outputs.
343 | :param use_conv: a bool determining if a convolution is applied.
344 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
345 | upsampling occurs in the inner-two dimensions.
346 | """
347 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
348 | super().__init__()
349 | self.channels = channels
350 | self.out_channels = out_channels or channels
351 | self.use_conv = use_conv
352 | self.dims = dims
353 | if use_conv:
354 | self.conv = conv_nd(dims,
355 | self.channels,
356 | self.out_channels,
357 | 3,
358 | padding=1)
359 |
360 | def forward(self, x):
361 | assert x.shape[1] == self.channels
362 | if self.dims == 3:
363 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
364 | mode="nearest")
365 | else:
366 | x = F.interpolate(x, scale_factor=2, mode="nearest")
367 | if self.use_conv:
368 | x = self.conv(x)
369 | return x
370 |
371 |
372 | class Downsample(nn.Module):
373 | """
374 | A downsampling layer with an optional convolution.
375 |
376 | :param channels: channels in the inputs and outputs.
377 | :param use_conv: a bool determining if a convolution is applied.
378 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
379 | downsampling occurs in the inner-two dimensions.
380 | """
381 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
382 | super().__init__()
383 | self.channels = channels
384 | self.out_channels = out_channels or channels
385 | self.use_conv = use_conv
386 | self.dims = dims
387 | stride = 2 if dims != 3 else (1, 2, 2)
388 | if use_conv:
389 | self.op = conv_nd(dims,
390 | self.channels,
391 | self.out_channels,
392 | 3,
393 | stride=stride,
394 | padding=1)
395 | else:
396 | assert self.channels == self.out_channels
397 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
398 |
399 | def forward(self, x):
400 | assert x.shape[1] == self.channels
401 | return self.op(x)
402 |
403 |
404 | class AttentionBlock(nn.Module):
405 | """
406 | An attention block that allows spatial positions to attend to each other.
407 |
408 | Originally ported from here, but adapted to the N-d case.
409 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
410 | """
411 | def __init__(
412 | self,
413 | channels,
414 | num_heads=1,
415 | num_head_channels=-1,
416 | use_checkpoint=False,
417 | use_new_attention_order=False,
418 | ):
419 | super().__init__()
420 | self.channels = channels
421 | if num_head_channels == -1:
422 | self.num_heads = num_heads
423 | else:
424 | assert (
425 | channels % num_head_channels == 0
426 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
427 | self.num_heads = channels // num_head_channels
428 | self.use_checkpoint = use_checkpoint
429 | self.norm = normalization(channels)
430 | self.qkv = conv_nd(1, channels, channels * 3, 1)
431 | if use_new_attention_order:
432 | # split qkv before split heads
433 | self.attention = QKVAttention(self.num_heads)
434 | else:
435 | # split heads before split qkv
436 | self.attention = QKVAttentionLegacy(self.num_heads)
437 |
438 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
439 |
440 | def forward(self, x):
441 | return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
442 |
443 | def _forward(self, x):
444 | b, c, *spatial = x.shape
445 | x = x.reshape(b, c, -1)
446 | qkv = self.qkv(self.norm(x))
447 | h = self.attention(qkv)
448 | h = self.proj_out(h)
449 | return (x + h).reshape(b, c, *spatial)
450 |
451 |
452 | def count_flops_attn(model, _x, y):
453 | """
454 | A counter for the `thop` package to count the operations in an
455 | attention operation.
456 | Meant to be used like:
457 | macs, params = thop.profile(
458 | model,
459 | inputs=(inputs, timestamps),
460 | custom_ops={QKVAttention: QKVAttention.count_flops},
461 | )
462 | """
463 | b, c, *spatial = y[0].shape
464 | num_spatial = int(np.prod(spatial))
465 | # We perform two matmuls with the same number of ops.
466 | # The first computes the weight matrix, the second computes
467 | # the combination of the value vectors.
468 | matmul_ops = 2 * b * (num_spatial**2) * c
469 | model.total_ops += th.DoubleTensor([matmul_ops])
470 |
471 |
472 | class QKVAttentionLegacy(nn.Module):
473 | """
474 |     A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
475 | """
476 | def __init__(self, n_heads):
477 | super().__init__()
478 | self.n_heads = n_heads
479 |
480 | def forward(self, qkv):
481 | """
482 | Apply QKV attention.
483 |
484 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
485 | :return: an [N x (H * C) x T] tensor after attention.
486 | """
487 | bs, width, length = qkv.shape
488 | assert width % (3 * self.n_heads) == 0
489 | ch = width // (3 * self.n_heads)
490 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
491 | dim=1)
492 | scale = 1 / math.sqrt(math.sqrt(ch))
493 | weight = th.einsum(
494 | "bct,bcs->bts", q * scale,
495 | k * scale) # More stable with f16 than dividing afterwards
496 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
497 | a = th.einsum("bts,bcs->bct", weight, v)
498 | return a.reshape(bs, -1, length)
499 |
500 | @staticmethod
501 | def count_flops(model, _x, y):
502 | return count_flops_attn(model, _x, y)
503 |
504 |
505 | class QKVAttention(nn.Module):
506 | """
507 | A module which performs QKV attention and splits in a different order.
508 | """
509 | def __init__(self, n_heads):
510 | super().__init__()
511 | self.n_heads = n_heads
512 |
513 | def forward(self, qkv):
514 | """
515 | Apply QKV attention.
516 |
517 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
518 | :return: an [N x (H * C) x T] tensor after attention.
519 | """
520 | bs, width, length = qkv.shape
521 | assert width % (3 * self.n_heads) == 0
522 | ch = width // (3 * self.n_heads)
523 | q, k, v = qkv.chunk(3, dim=1)
524 | scale = 1 / math.sqrt(math.sqrt(ch))
525 | weight = th.einsum(
526 | "bct,bcs->bts",
527 | (q * scale).view(bs * self.n_heads, ch, length),
528 | (k * scale).view(bs * self.n_heads, ch, length),
529 | ) # More stable with f16 than dividing afterwards
530 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
531 | a = th.einsum("bts,bcs->bct", weight,
532 | v.reshape(bs * self.n_heads, ch, length))
533 | return a.reshape(bs, -1, length)
534 |
535 | @staticmethod
536 | def count_flops(model, _x, y):
537 | return count_flops_attn(model, _x, y)
538 |
539 |
540 | class AttentionPool2d(nn.Module):
541 | """
542 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
543 | """
544 | def __init__(
545 | self,
546 | spacial_dim: int,
547 | embed_dim: int,
548 | num_heads_channels: int,
549 | output_dim: int = None,
550 | ):
551 | super().__init__()
552 | self.positional_embedding = nn.Parameter(
553 | th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
554 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
555 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
556 | self.num_heads = embed_dim // num_heads_channels
557 | self.attention = QKVAttention(self.num_heads)
558 |
559 | def forward(self, x):
560 | b, c, *_spatial = x.shape
561 | x = x.reshape(b, c, -1) # NC(HW)
562 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
563 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
564 | x = self.qkv_proj(x)
565 | x = self.attention(x)
566 | x = self.c_proj(x)
567 | return x[:, :, 0]
568 |
--------------------------------------------------------------------------------
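The heart of `ResBlock` is `apply_conditions`: a time embedding of size `2*C` is chunked into a scale and a shift and applied right after the group norm (before SiLU), while an embedding with only `C` channels contributes a scale alone. A minimal sketch with illustrative shapes, calling the function above directly:

```python
# Sketch of the scale/shift conditioning performed by apply_conditions().
import torch as th
from torch import nn
from model.blocks import apply_conditions
from model.nn import normalization

C = 64
h = th.randn(2, C, 16, 16)      # feature map
emb = th.randn(2, 2 * C)        # time embedding -> (scale, shift)
cond = th.randn(2, C)           # encoder embedding -> scale only

# same layer layout as ResBlock.out_layers: norm, act, dropout, conv
layers = nn.Sequential(
    normalization(C),
    nn.SiLU(),
    nn.Dropout(p=0.1),
    nn.Conv2d(C, C, 3, padding=1),
)

out = apply_conditions(h=h, emb=emb, cond=cond, layers=layers,
                       scale_bias=1, in_channels=C)
print(out.shape)  # torch.Size([2, 64, 16, 16])
```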
/model/latentnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from typing import NamedTuple, Tuple
5 |
6 | import torch
7 | from choices import *
8 | from config_base import BaseConfig
9 | from torch import nn
10 | from torch.nn import init
11 |
12 | from .blocks import *
13 | from .nn import timestep_embedding
14 | from .unet import *
15 |
16 |
17 | class LatentNetType(Enum):
18 | none = 'none'
19 | # injecting inputs into the hidden layers
20 | skip = 'skip'
21 |
22 |
23 | class LatentNetReturn(NamedTuple):
24 | pred: torch.Tensor = None
25 |
26 |
27 | @dataclass
28 | class MLPSkipNetConfig(BaseConfig):
29 | """
30 | default MLP for the latent DPM in the paper!
31 | """
32 | num_channels: int
33 | skip_layers: Tuple[int]
34 | num_hid_channels: int
35 | num_layers: int
36 | num_time_emb_channels: int = 64
37 | activation: Activation = Activation.silu
38 | use_norm: bool = True
39 | condition_bias: float = 1
40 | dropout: float = 0
41 | last_act: Activation = Activation.none
42 | num_time_layers: int = 2
43 | time_last_act: bool = False
44 |
45 | def make_model(self):
46 | return MLPSkipNet(self)
47 |
48 |
49 | class MLPSkipNet(nn.Module):
50 | """
51 | concat x to hidden layers
52 |
53 | default MLP for the latent DPM in the paper!
54 | """
55 | def __init__(self, conf: MLPSkipNetConfig):
56 | super().__init__()
57 | self.conf = conf
58 |
59 | layers = []
60 | for i in range(conf.num_time_layers):
61 | if i == 0:
62 | a = conf.num_time_emb_channels
63 | b = conf.num_channels
64 | else:
65 | a = conf.num_channels
66 | b = conf.num_channels
67 | layers.append(nn.Linear(a, b))
68 | if i < conf.num_time_layers - 1 or conf.time_last_act:
69 | layers.append(conf.activation.get_act())
70 | self.time_embed = nn.Sequential(*layers)
71 |
72 | self.layers = nn.ModuleList([])
73 | for i in range(conf.num_layers):
74 | if i == 0:
75 | act = conf.activation
76 | norm = conf.use_norm
77 | cond = True
78 | a, b = conf.num_channels, conf.num_hid_channels
79 | dropout = conf.dropout
80 | elif i == conf.num_layers - 1:
81 | act = Activation.none
82 | norm = False
83 | cond = False
84 | a, b = conf.num_hid_channels, conf.num_channels
85 | dropout = 0
86 | else:
87 | act = conf.activation
88 | norm = conf.use_norm
89 | cond = True
90 | a, b = conf.num_hid_channels, conf.num_hid_channels
91 | dropout = conf.dropout
92 |
93 | if i in conf.skip_layers:
94 | a += conf.num_channels
95 |
96 | self.layers.append(
97 | MLPLNAct(
98 | a,
99 | b,
100 | norm=norm,
101 | activation=act,
102 | cond_channels=conf.num_channels,
103 | use_cond=cond,
104 | condition_bias=conf.condition_bias,
105 | dropout=dropout,
106 | ))
107 | self.last_act = conf.last_act.get_act()
108 |
109 | def forward(self, x, t, **kwargs):
110 | t = timestep_embedding(t, self.conf.num_time_emb_channels)
111 | cond = self.time_embed(t)
112 | h = x
113 | for i in range(len(self.layers)):
114 | if i in self.conf.skip_layers:
115 | # injecting input into the hidden layers
116 | h = torch.cat([h, x], dim=1)
117 | h = self.layers[i].forward(x=h, cond=cond)
118 | h = self.last_act(h)
119 | return LatentNetReturn(h)
120 |
121 |
122 | class MLPLNAct(nn.Module):
123 | def __init__(
124 | self,
125 | in_channels: int,
126 | out_channels: int,
127 | norm: bool,
128 | use_cond: bool,
129 | activation: Activation,
130 | cond_channels: int,
131 | condition_bias: float = 0,
132 | dropout: float = 0,
133 | ):
134 | super().__init__()
135 | self.activation = activation
136 | self.condition_bias = condition_bias
137 | self.use_cond = use_cond
138 |
139 | self.linear = nn.Linear(in_channels, out_channels)
140 | self.act = activation.get_act()
141 | if self.use_cond:
142 | self.linear_emb = nn.Linear(cond_channels, out_channels)
143 | self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144 | if norm:
145 | self.norm = nn.LayerNorm(out_channels)
146 | else:
147 | self.norm = nn.Identity()
148 |
149 | if dropout > 0:
150 | self.dropout = nn.Dropout(p=dropout)
151 | else:
152 | self.dropout = nn.Identity()
153 |
154 | self.init_weights()
155 |
156 | def init_weights(self):
157 | for module in self.modules():
158 | if isinstance(module, nn.Linear):
159 | if self.activation == Activation.relu:
160 | init.kaiming_normal_(module.weight,
161 | a=0,
162 | nonlinearity='relu')
163 | elif self.activation == Activation.lrelu:
164 | init.kaiming_normal_(module.weight,
165 | a=0.2,
166 | nonlinearity='leaky_relu')
167 | elif self.activation == Activation.silu:
168 | init.kaiming_normal_(module.weight,
169 | a=0,
170 | nonlinearity='relu')
171 | else:
172 | # leave it as default
173 | pass
174 |
175 | def forward(self, x, cond=None):
176 | x = self.linear(x)
177 | if self.use_cond:
178 | # (n, c) or (n, c * 2)
179 | cond = self.cond_layers(cond)
180 | cond = (cond, None)
181 |
182 | # scale shift first
183 | x = x * (self.condition_bias + cond[0])
184 | if cond[1] is not None:
185 | x = x + cond[1]
186 | # then norm
187 | x = self.norm(x)
188 | else:
189 | # no condition
190 | x = self.norm(x)
191 | x = self.act(x)
192 | x = self.dropout(x)
193 | return x
--------------------------------------------------------------------------------
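A short sketch of the latent DPM defined above: `MLPSkipNet` denoises the semantic latent, re-injecting the input at the configured skip layers. The sizes below are illustrative assumptions; the actual values come from the training templates.

```python
# Sketch: forward pass of the latent-DPM MLP (illustrative hyperparameters).
import torch
from model.latentnet import MLPSkipNetConfig

conf = MLPSkipNetConfig(
    num_channels=512,        # size of the semantic latent
    skip_layers=(1, 2, 3),   # layers that re-receive the input
    num_hid_channels=1024,
    num_layers=5,
)
net = conf.make_model()

z_t = torch.randn(4, 512)            # noisy latent codes
t = torch.randint(0, 1000, (4,))     # diffusion timesteps
out = net(z_t, t)
print(out.pred.shape)                # torch.Size([4, 512])
```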
/model/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | from enum import Enum
6 | import math
7 | from typing import Optional
8 |
9 | import torch as th
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | import torch.nn.functional as F
14 |
15 |
16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17 | class SiLU(nn.Module):
18 | # @th.jit.script
19 | def forward(self, x):
20 | return x * th.sigmoid(x)
21 |
22 |
23 | class GroupNorm32(nn.GroupNorm):
24 | def forward(self, x):
25 | return super().forward(x.float()).type(x.dtype)
26 |
27 |
28 | def conv_nd(dims, *args, **kwargs):
29 | """
30 | Create a 1D, 2D, or 3D convolution module.
31 | """
32 | if dims == 1:
33 | return nn.Conv1d(*args, **kwargs)
34 | elif dims == 2:
35 | return nn.Conv2d(*args, **kwargs)
36 | elif dims == 3:
37 | return nn.Conv3d(*args, **kwargs)
38 | raise ValueError(f"unsupported dimensions: {dims}")
39 |
40 |
41 | def linear(*args, **kwargs):
42 | """
43 | Create a linear module.
44 | """
45 | return nn.Linear(*args, **kwargs)
46 |
47 |
48 | def avg_pool_nd(dims, *args, **kwargs):
49 | """
50 | Create a 1D, 2D, or 3D average pooling module.
51 | """
52 | if dims == 1:
53 | return nn.AvgPool1d(*args, **kwargs)
54 | elif dims == 2:
55 | return nn.AvgPool2d(*args, **kwargs)
56 | elif dims == 3:
57 | return nn.AvgPool3d(*args, **kwargs)
58 | raise ValueError(f"unsupported dimensions: {dims}")
59 |
60 |
61 | def update_ema(target_params, source_params, rate=0.99):
62 | """
63 | Update target parameters to be closer to those of source parameters using
64 | an exponential moving average.
65 |
66 | :param target_params: the target parameter sequence.
67 | :param source_params: the source parameter sequence.
68 | :param rate: the EMA rate (closer to 1 means slower).
69 | """
70 | for targ, src in zip(target_params, source_params):
71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72 |
73 |
74 | def zero_module(module):
75 | """
76 | Zero out the parameters of a module and return it.
77 | """
78 | for p in module.parameters():
79 | p.detach().zero_()
80 | return module
81 |
82 |
83 | def scale_module(module, scale):
84 | """
85 | Scale the parameters of a module and return it.
86 | """
87 | for p in module.parameters():
88 | p.detach().mul_(scale)
89 | return module
90 |
91 |
92 | def mean_flat(tensor):
93 | """
94 | Take the mean over all non-batch dimensions.
95 | """
96 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
97 |
98 |
99 | def normalization(channels):
100 | """
101 | Make a standard normalization layer.
102 |
103 | :param channels: number of input channels.
104 | :return: an nn.Module for normalization.
105 | """
106 | return GroupNorm32(min(32, channels), channels)
107 |
108 |
109 | def timestep_embedding(timesteps, dim, max_period=10000):
110 | """
111 | Create sinusoidal timestep embeddings.
112 |
113 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
114 | These may be fractional.
115 | :param dim: the dimension of the output.
116 | :param max_period: controls the minimum frequency of the embeddings.
117 | :return: an [N x dim] Tensor of positional embeddings.
118 | """
119 | half = dim // 2
120 | freqs = th.exp(-math.log(max_period) *
121 | th.arange(start=0, end=half, dtype=th.float32) /
122 | half).to(device=timesteps.device)
123 | args = timesteps[:, None].float() * freqs[None]
124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
125 | if dim % 2:
126 | embedding = th.cat(
127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128 | return embedding
129 |
130 |
131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False):
132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
133 | if flag:
134 | return torch.utils.checkpoint.checkpoint(
135 | func, *args, preserve_rng_state=preserve_rng_state)
136 | else:
137 | return func(*args)
138 |
--------------------------------------------------------------------------------
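A quick sketch of the sinusoidal `timestep_embedding` above: each timestep maps to `dim/2` cosine features followed by `dim/2` sine features.

```python
# Sketch: sinusoidal timestep embeddings.
import torch as th
from model.nn import timestep_embedding

t = th.tensor([0, 10, 500, 999])
emb = timestep_embedding(t, dim=128)
print(emb.shape)    # torch.Size([4, 128])
print(emb[0, :4])   # cos(0 * freqs) -> ones for t = 0
```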
/model/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from numbers import Number
4 | from typing import NamedTuple, Tuple, Union
5 |
6 | import numpy as np
7 | import torch as th
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from choices import *
11 | from config_base import BaseConfig
12 | from .blocks import *
13 |
14 | from .nn import (conv_nd, linear, normalization, timestep_embedding,
15 | torch_checkpoint, zero_module)
16 |
17 |
18 | @dataclass
19 | class BeatGANsUNetConfig(BaseConfig):
20 | image_size: int = 64
21 | in_channels: int = 3
22 | # base channels, will be multiplied
23 | model_channels: int = 64
24 | # output of the unet
25 | # suggest: 3
26 | # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27 | out_channels: int = 3
28 | # how many repeating resblocks per resolution
29 | # the decoding side would have "one more" resblock
30 | # default: 2
31 | num_res_blocks: int = 2
32 | # you can also set the number of resblocks specifically for the input blocks
33 | # default: None = above
34 | num_input_res_blocks: int = None
35 | # number of time embed channels and style channels
36 | embed_channels: int = 512
37 | # at what resolutions you want to do self-attention of the feature maps
38 | # attentions generally improve performance
39 | # default: [16]
40 | # beatgans: [32, 16, 8]
41 | attention_resolutions: Tuple[int] = (16, )
42 | # number of time embed channels
43 | time_embed_channels: int = None
44 | # dropout applies to the resblocks (on feature maps)
45 | dropout: float = 0.1
46 | channel_mult: Tuple[int] = (1, 2, 4, 8)
47 | input_channel_mult: Tuple[int] = None
48 | conv_resample: bool = True
49 | # always 2 = 2d conv
50 | dims: int = 2
51 | # don't use this, legacy from BeatGANs
52 | num_classes: int = None
53 | use_checkpoint: bool = False
54 | # number of attention heads
55 | num_heads: int = 1
56 | # or specify the number of channels per attention head
57 | num_head_channels: int = -1
58 |     # number of attention heads used in the decoder (output) blocks; -1 = same as num_heads
59 | num_heads_upsample: int = -1
60 | # use resblock for upscale/downscale blocks (expensive)
61 | # default: True (BeatGANs)
62 | resblock_updown: bool = True
63 | # never tried
64 | use_new_attention_order: bool = False
65 | resnet_two_cond: bool = False
66 | resnet_cond_channels: int = None
67 | # init the decoding conv layers with zero weights, this speeds up training
68 |     # default: True (BeatGANs)
69 | resnet_use_zero_module: bool = True
70 | # gradient checkpoint the attention operation
71 | attn_checkpoint: bool = False
72 |
73 | def make_model(self):
74 | return BeatGANsUNetModel(self)
75 |
76 |
77 | class BeatGANsUNetModel(nn.Module):
78 | def __init__(self, conf: BeatGANsUNetConfig):
79 | super().__init__()
80 | self.conf = conf
81 |
82 | if conf.num_heads_upsample == -1:
83 | self.num_heads_upsample = conf.num_heads
84 |
85 | self.dtype = th.float32
86 |
87 | self.time_emb_channels = conf.time_embed_channels or conf.model_channels
88 | self.time_embed = nn.Sequential(
89 | linear(self.time_emb_channels, conf.embed_channels),
90 | nn.SiLU(),
91 | linear(conf.embed_channels, conf.embed_channels),
92 | )
93 |
94 | if conf.num_classes is not None:
95 | self.label_emb = nn.Embedding(conf.num_classes,
96 | conf.embed_channels)
97 |
98 | ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
99 | self.input_blocks = nn.ModuleList([
100 | TimestepEmbedSequential(
101 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
102 | ])
103 |
104 | kwargs = dict(
105 | use_condition=True,
106 | two_cond=conf.resnet_two_cond,
107 | use_zero_module=conf.resnet_use_zero_module,
108 | # style channels for the resnet block
109 | cond_emb_channels=conf.resnet_cond_channels,
110 | )
111 |
112 | self._feature_size = ch
113 |
114 | # input_block_chans = [ch]
115 | input_block_chans = [[] for _ in range(len(conf.channel_mult))]
116 | input_block_chans[0].append(ch)
117 |
118 | # number of blocks at each resolution
119 | self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
120 | self.input_num_blocks[0] = 1
121 | self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
122 |
123 | ds = 1
124 | resolution = conf.image_size
125 | for level, mult in enumerate(conf.input_channel_mult
126 | or conf.channel_mult):
127 | for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
128 | layers = [
129 | ResBlockConfig(
130 | ch,
131 | conf.embed_channels,
132 | conf.dropout,
133 | out_channels=int(mult * conf.model_channels),
134 | dims=conf.dims,
135 | use_checkpoint=conf.use_checkpoint,
136 | **kwargs,
137 | ).make_model()
138 | ]
139 | ch = int(mult * conf.model_channels)
140 | if resolution in conf.attention_resolutions:
141 | layers.append(
142 | AttentionBlock(
143 | ch,
144 | use_checkpoint=conf.use_checkpoint
145 | or conf.attn_checkpoint,
146 | num_heads=conf.num_heads,
147 | num_head_channels=conf.num_head_channels,
148 | use_new_attention_order=conf.
149 | use_new_attention_order,
150 | ))
151 | self.input_blocks.append(TimestepEmbedSequential(*layers))
152 | self._feature_size += ch
153 | # input_block_chans.append(ch)
154 | input_block_chans[level].append(ch)
155 | self.input_num_blocks[level] += 1
156 | # print(input_block_chans)
157 | if level != len(conf.channel_mult) - 1:
158 | resolution //= 2
159 | out_ch = ch
160 | self.input_blocks.append(
161 | TimestepEmbedSequential(
162 | ResBlockConfig(
163 | ch,
164 | conf.embed_channels,
165 | conf.dropout,
166 | out_channels=out_ch,
167 | dims=conf.dims,
168 | use_checkpoint=conf.use_checkpoint,
169 | down=True,
170 | **kwargs,
171 | ).make_model() if conf.
172 | resblock_updown else Downsample(ch,
173 | conf.conv_resample,
174 | dims=conf.dims,
175 | out_channels=out_ch)))
176 | ch = out_ch
177 | # input_block_chans.append(ch)
178 | input_block_chans[level + 1].append(ch)
179 | self.input_num_blocks[level + 1] += 1
180 | ds *= 2
181 | self._feature_size += ch
182 |
183 | self.middle_block = TimestepEmbedSequential(
184 | ResBlockConfig(
185 | ch,
186 | conf.embed_channels,
187 | conf.dropout,
188 | dims=conf.dims,
189 | use_checkpoint=conf.use_checkpoint,
190 | **kwargs,
191 | ).make_model(),
192 | AttentionBlock(
193 | ch,
194 | use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
195 | num_heads=conf.num_heads,
196 | num_head_channels=conf.num_head_channels,
197 | use_new_attention_order=conf.use_new_attention_order,
198 | ),
199 | ResBlockConfig(
200 | ch,
201 | conf.embed_channels,
202 | conf.dropout,
203 | dims=conf.dims,
204 | use_checkpoint=conf.use_checkpoint,
205 | **kwargs,
206 | ).make_model(),
207 | )
208 | self._feature_size += ch
209 |
210 | self.output_blocks = nn.ModuleList([])
211 | for level, mult in list(enumerate(conf.channel_mult))[::-1]:
212 | for i in range(conf.num_res_blocks + 1):
213 | # print(input_block_chans)
214 | # ich = input_block_chans.pop()
215 | try:
216 | ich = input_block_chans[level].pop()
217 | except IndexError:
218 | # this happens only when num_res_block > num_enc_res_block
219 |                     # we will not have enough lateral (skip) connections for all decoder blocks
220 | ich = 0
221 | # print('pop:', ich)
222 | layers = [
223 | ResBlockConfig(
224 | # only direct channels when gated
225 | channels=ch + ich,
226 | emb_channels=conf.embed_channels,
227 | dropout=conf.dropout,
228 | out_channels=int(conf.model_channels * mult),
229 | dims=conf.dims,
230 | use_checkpoint=conf.use_checkpoint,
231 | # lateral channels are described here when gated
232 | has_lateral=True if ich > 0 else False,
233 | lateral_channels=None,
234 | **kwargs,
235 | ).make_model()
236 | ]
237 | ch = int(conf.model_channels * mult)
238 | if resolution in conf.attention_resolutions:
239 | layers.append(
240 | AttentionBlock(
241 | ch,
242 | use_checkpoint=conf.use_checkpoint
243 | or conf.attn_checkpoint,
244 | num_heads=self.num_heads_upsample,
245 | num_head_channels=conf.num_head_channels,
246 | use_new_attention_order=conf.
247 | use_new_attention_order,
248 | ))
249 | if level and i == conf.num_res_blocks:
250 | resolution *= 2
251 | out_ch = ch
252 | layers.append(
253 | ResBlockConfig(
254 | ch,
255 | conf.embed_channels,
256 | conf.dropout,
257 | out_channels=out_ch,
258 | dims=conf.dims,
259 | use_checkpoint=conf.use_checkpoint,
260 | up=True,
261 | **kwargs,
262 | ).make_model() if (
263 | conf.resblock_updown
264 | ) else Upsample(ch,
265 | conf.conv_resample,
266 | dims=conf.dims,
267 | out_channels=out_ch))
268 | ds //= 2
269 | self.output_blocks.append(TimestepEmbedSequential(*layers))
270 | self.output_num_blocks[level] += 1
271 | self._feature_size += ch
272 |
273 | # print(input_block_chans)
274 | # print('inputs:', self.input_num_blocks)
275 | # print('outputs:', self.output_num_blocks)
276 |
277 | if conf.resnet_use_zero_module:
278 | self.out = nn.Sequential(
279 | normalization(ch),
280 | nn.SiLU(),
281 | zero_module(
282 | conv_nd(conf.dims,
283 | input_ch,
284 | conf.out_channels,
285 | 3,
286 | padding=1)),
287 | )
288 | else:
289 | self.out = nn.Sequential(
290 | normalization(ch),
291 | nn.SiLU(),
292 | conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
293 | )
294 |
295 | def forward(self, x, t, y=None, **kwargs):
296 | """
297 | Apply the model to an input batch.
298 |
299 | :param x: an [N x C x ...] Tensor of inputs.
300 |         :param t: a 1-D batch of timesteps.
301 | :param y: an [N] Tensor of labels, if class-conditional.
302 | :return: an [N x C x ...] Tensor of outputs.
303 | """
304 | assert (y is not None) == (
305 | self.conf.num_classes is not None
306 | ), "must specify y if and only if the model is class-conditional"
307 |
308 | # hs = []
309 | hs = [[] for _ in range(len(self.conf.channel_mult))]
310 | emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
311 |
312 | if self.conf.num_classes is not None:
313 | raise NotImplementedError()
314 | # assert y.shape == (x.shape[0], )
315 | # emb = emb + self.label_emb(y)
316 |
317 | # new code supports input_num_blocks != output_num_blocks
318 | h = x.type(self.dtype)
319 | k = 0
320 | for i in range(len(self.input_num_blocks)):
321 | for j in range(self.input_num_blocks[i]):
322 | h = self.input_blocks[k](h, emb=emb)
323 | # print(i, j, h.shape)
324 | hs[i].append(h)
325 | k += 1
326 | assert k == len(self.input_blocks)
327 |
328 | h = self.middle_block(h, emb=emb)
329 | k = 0
330 | for i in range(len(self.output_num_blocks)):
331 | for j in range(self.output_num_blocks[i]):
332 |                 # take the lateral connection from the same layer (in reverse)
333 | # until there is no more, use None
334 | try:
335 | lateral = hs[-i - 1].pop()
336 | # print(i, j, lateral.shape)
337 | except IndexError:
338 | lateral = None
339 | # print(i, j, lateral)
340 | h = self.output_blocks[k](h, emb=emb, lateral=lateral)
341 | k += 1
342 |
343 | h = h.type(x.dtype)
344 | pred = self.out(h)
345 | return Return(pred=pred)
346 |
347 |
348 | class Return(NamedTuple):
349 | pred: th.Tensor
350 |
351 |
352 | @dataclass
353 | class BeatGANsEncoderConfig(BaseConfig):
354 | image_size: int
355 | in_channels: int
356 | model_channels: int
357 | out_hid_channels: int
358 | out_channels: int
359 | num_res_blocks: int
360 | attention_resolutions: Tuple[int]
361 | dropout: float = 0
362 | channel_mult: Tuple[int] = (1, 2, 4, 8)
363 | use_time_condition: bool = True
364 | conv_resample: bool = True
365 | dims: int = 2
366 | use_checkpoint: bool = False
367 | num_heads: int = 1
368 | num_head_channels: int = -1
369 | resblock_updown: bool = False
370 | use_new_attention_order: bool = False
371 | pool: str = 'adaptivenonzero'
372 |
373 | def make_model(self):
374 | return BeatGANsEncoderModel(self)
375 |
376 |
377 | class BeatGANsEncoderModel(nn.Module):
378 | """
379 | The half UNet model with attention and timestep embedding.
380 |
381 | For usage, see UNet.
382 | """
383 | def __init__(self, conf: BeatGANsEncoderConfig):
384 | super().__init__()
385 | self.conf = conf
386 | self.dtype = th.float32
387 |
388 | if conf.use_time_condition:
389 | time_embed_dim = conf.model_channels * 4
390 | self.time_embed = nn.Sequential(
391 | linear(conf.model_channels, time_embed_dim),
392 | nn.SiLU(),
393 | linear(time_embed_dim, time_embed_dim),
394 | )
395 | else:
396 | time_embed_dim = None
397 |
398 | ch = int(conf.channel_mult[0] * conf.model_channels)
399 | self.input_blocks = nn.ModuleList([
400 | TimestepEmbedSequential(
401 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
402 | ])
403 | self._feature_size = ch
404 | input_block_chans = [ch]
405 | ds = 1
406 | resolution = conf.image_size
407 | for level, mult in enumerate(conf.channel_mult):
408 | for _ in range(conf.num_res_blocks):
409 | layers = [
410 | ResBlockConfig(
411 | ch,
412 | time_embed_dim,
413 | conf.dropout,
414 | out_channels=int(mult * conf.model_channels),
415 | dims=conf.dims,
416 | use_condition=conf.use_time_condition,
417 | use_checkpoint=conf.use_checkpoint,
418 | ).make_model()
419 | ]
420 | ch = int(mult * conf.model_channels)
421 | if resolution in conf.attention_resolutions:
422 | layers.append(
423 | AttentionBlock(
424 | ch,
425 | use_checkpoint=conf.use_checkpoint,
426 | num_heads=conf.num_heads,
427 | num_head_channels=conf.num_head_channels,
428 | use_new_attention_order=conf.
429 | use_new_attention_order,
430 | ))
431 | self.input_blocks.append(TimestepEmbedSequential(*layers))
432 | self._feature_size += ch
433 | input_block_chans.append(ch)
434 | if level != len(conf.channel_mult) - 1:
435 | resolution //= 2
436 | out_ch = ch
437 | self.input_blocks.append(
438 | TimestepEmbedSequential(
439 | ResBlockConfig(
440 | ch,
441 | time_embed_dim,
442 | conf.dropout,
443 | out_channels=out_ch,
444 | dims=conf.dims,
445 | use_condition=conf.use_time_condition,
446 | use_checkpoint=conf.use_checkpoint,
447 | down=True,
448 | ).make_model() if (
449 | conf.resblock_updown
450 | ) else Downsample(ch,
451 | conf.conv_resample,
452 | dims=conf.dims,
453 | out_channels=out_ch)))
454 | ch = out_ch
455 | input_block_chans.append(ch)
456 | ds *= 2
457 | self._feature_size += ch
458 |
459 | self.middle_block = TimestepEmbedSequential(
460 | ResBlockConfig(
461 | ch,
462 | time_embed_dim,
463 | conf.dropout,
464 | dims=conf.dims,
465 | use_condition=conf.use_time_condition,
466 | use_checkpoint=conf.use_checkpoint,
467 | ).make_model(),
468 | AttentionBlock(
469 | ch,
470 | use_checkpoint=conf.use_checkpoint,
471 | num_heads=conf.num_heads,
472 | num_head_channels=conf.num_head_channels,
473 | use_new_attention_order=conf.use_new_attention_order,
474 | ),
475 | ResBlockConfig(
476 | ch,
477 | time_embed_dim,
478 | conf.dropout,
479 | dims=conf.dims,
480 | use_condition=conf.use_time_condition,
481 | use_checkpoint=conf.use_checkpoint,
482 | ).make_model(),
483 | )
484 | self._feature_size += ch
485 | if conf.pool == "adaptivenonzero":
486 | self.out = nn.Sequential(
487 | normalization(ch),
488 | nn.SiLU(),
489 | nn.AdaptiveAvgPool2d((1, 1)),
490 | conv_nd(conf.dims, ch, conf.out_channels, 1),
491 | nn.Flatten(),
492 | )
493 | else:
494 | raise NotImplementedError(f"Unexpected {conf.pool} pooling")
495 |
496 | def forward(self, x, t=None, return_2d_feature=False):
497 | """
498 | Apply the model to an input batch.
499 |
500 | :param x: an [N x C x ...] Tensor of inputs.
501 |         :param t: a 1-D batch of timesteps.
502 | :return: an [N x K] Tensor of outputs.
503 | """
504 | if self.conf.use_time_condition:
505 |             emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
506 | else:
507 | emb = None
508 |
509 | results = []
510 | h = x.type(self.dtype)
511 | for module in self.input_blocks:
512 | h = module(h, emb=emb)
513 | if self.conf.pool.startswith("spatial"):
514 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
515 | h = self.middle_block(h, emb=emb)
516 | if self.conf.pool.startswith("spatial"):
517 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
518 | h = th.cat(results, axis=-1)
519 | else:
520 | h = h.type(x.dtype)
521 |
522 | h_2d = h
523 | h = self.out(h)
524 |
525 | if return_2d_feature:
526 | return h, h_2d
527 | else:
528 | return h
529 |
530 | def forward_flatten(self, x):
531 | """
532 |         transform the last 2D feature map into a flattened vector
533 | """
534 | h = self.out(x)
535 | return h
536 |
537 |
538 | class SuperResModel(BeatGANsUNetModel):
539 | """
540 | A UNetModel that performs super-resolution.
541 |
542 | Expects an extra kwarg `low_res` to condition on a low-resolution image.
543 | """
544 | def __init__(self, image_size, in_channels, *args, **kwargs):
545 | super().__init__(image_size, in_channels * 2, *args, **kwargs)
546 |
547 | def forward(self, x, timesteps, low_res=None, **kwargs):
548 | _, _, new_height, new_width = x.shape
549 | upsampled = F.interpolate(low_res, (new_height, new_width),
550 | mode="bilinear")
551 | x = th.cat([x, upsampled], dim=1)
552 | return super().forward(x, timesteps, **kwargs)
553 |
--------------------------------------------------------------------------------
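A sketch of a forward pass through the plain `BeatGANsUNetModel` above: the model predicts the noise for a batch of noisy images at given timesteps. The configuration values are illustrative defaults, not the FFHQ-256 settings used by DiffProtect.

```python
# Sketch: noise prediction with the plain UNet (illustrative config).
import torch as th
from model.unet import BeatGANsUNetConfig

conf = BeatGANsUNetConfig(image_size=64, model_channels=64,
                          attention_resolutions=(16, ))
unet = conf.make_model()

x_t = th.randn(2, 3, 64, 64)          # noisy images
t = th.randint(0, 1000, (2, ))        # timesteps
with th.no_grad():
    eps = unet(x_t, t).pred           # predicted noise, same shape as x_t
print(eps.shape)                      # torch.Size([2, 3, 64, 64])
```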
/model/unet_autoenc.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.functional import silu
6 |
7 | from .latentnet import *
8 | from .unet import *
9 | from choices import *
10 |
11 |
12 | @dataclass
13 | class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14 | # number of style channels
15 | enc_out_channels: int = 512
16 | enc_attn_resolutions: Tuple[int] = None
17 | enc_pool: str = 'depthconv'
18 | enc_num_res_block: int = 2
19 | enc_channel_mult: Tuple[int] = None
20 | enc_grad_checkpoint: bool = False
21 | latent_net_conf: MLPSkipNetConfig = None
22 |
23 | def make_model(self):
24 | return BeatGANsAutoencModel(self)
25 |
26 |
27 | class BeatGANsAutoencModel(BeatGANsUNetModel):
28 | def __init__(self, conf: BeatGANsAutoencConfig):
29 | super().__init__(conf)
30 | self.conf = conf
31 |
32 | # having only time, cond
33 | self.time_embed = TimeStyleSeperateEmbed(
34 | time_channels=conf.model_channels,
35 | time_out_channels=conf.embed_channels,
36 | )
37 |
38 | self.encoder = BeatGANsEncoderConfig(
39 | image_size=conf.image_size,
40 | in_channels=conf.in_channels,
41 | model_channels=conf.model_channels,
42 | out_hid_channels=conf.enc_out_channels,
43 | out_channels=conf.enc_out_channels,
44 | num_res_blocks=conf.enc_num_res_block,
45 | attention_resolutions=(conf.enc_attn_resolutions
46 | or conf.attention_resolutions),
47 | dropout=conf.dropout,
48 | channel_mult=conf.enc_channel_mult or conf.channel_mult,
49 | use_time_condition=False,
50 | conv_resample=conf.conv_resample,
51 | dims=conf.dims,
52 | use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53 | num_heads=conf.num_heads,
54 | num_head_channels=conf.num_head_channels,
55 | resblock_updown=conf.resblock_updown,
56 | use_new_attention_order=conf.use_new_attention_order,
57 | pool=conf.enc_pool,
58 | ).make_model()
59 |
60 | if conf.latent_net_conf is not None:
61 | self.latent_net = conf.latent_net_conf.make_model()
62 |
63 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64 | """
65 |         Reparameterization trick to sample from N(mu, var) using
66 |         N(0, 1).
67 |         :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68 |         :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
69 | :return: (Tensor) [B x D]
70 | """
71 | assert self.conf.is_stochastic
72 | std = torch.exp(0.5 * logvar)
73 | eps = torch.randn_like(std)
74 | return eps * std + mu
75 |
76 | def sample_z(self, n: int, device):
77 | assert self.conf.is_stochastic
78 | return torch.randn(n, self.conf.enc_out_channels, device=device)
79 |
80 | def noise_to_cond(self, noise: Tensor):
81 | raise NotImplementedError()
82 | assert self.conf.noise_net_conf is not None
83 | return self.noise_net.forward(noise)
84 |
85 | def encode(self, x):
86 | cond = self.encoder.forward(x)
87 | return {'cond': cond}
88 |
89 | @property
90 | def stylespace_sizes(self):
91 | modules = list(self.input_blocks.modules()) + list(
92 | self.middle_block.modules()) + list(self.output_blocks.modules())
93 | sizes = []
94 | for module in modules:
95 | if isinstance(module, ResBlock):
96 | linear = module.cond_emb_layers[-1]
97 | sizes.append(linear.weight.shape[0])
98 | return sizes
99 |
100 | def encode_stylespace(self, x, return_vector: bool = True):
101 | """
102 | encode to style space
103 | """
104 | modules = list(self.input_blocks.modules()) + list(
105 | self.middle_block.modules()) + list(self.output_blocks.modules())
106 | # (n, c)
107 | cond = self.encoder.forward(x)
108 | S = []
109 | for module in modules:
110 | if isinstance(module, ResBlock):
111 | # (n, c')
112 | s = module.cond_emb_layers.forward(cond)
113 | S.append(s)
114 |
115 | if return_vector:
116 | # (n, sum_c)
117 | return torch.cat(S, dim=1)
118 | else:
119 | return S
120 |
121 | def forward(self,
122 | x,
123 | t,
124 | y=None,
125 | x_start=None,
126 | cond=None,
127 | style=None,
128 | noise=None,
129 | t_cond=None,
130 | **kwargs):
131 | """
132 | Apply the model to an input batch.
133 |
134 | Args:
135 | x_start: the original image to encode
136 | cond: output of the encoder
137 | noise: random noise (to predict the cond)
138 | """
139 |
140 | if t_cond is None:
141 | t_cond = t
142 |
143 | if noise is not None:
144 | # if the noise is given, we predict the cond from noise
145 | cond = self.noise_to_cond(noise)
146 |
147 | if cond is None:
148 | if x is not None:
149 | assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
150 |
151 | tmp = self.encode(x_start)
152 | cond = tmp['cond']
153 |
154 | if t is not None:
155 | _t_emb = timestep_embedding(t, self.conf.model_channels)
156 | _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
157 | else:
158 | # this happens when training only autoenc
159 | _t_emb = None
160 | _t_cond_emb = None
161 |
162 | if self.conf.resnet_two_cond:
163 | res = self.time_embed.forward(
164 | time_emb=_t_emb,
165 | cond=cond,
166 | time_cond_emb=_t_cond_emb,
167 | )
168 | else:
169 | raise NotImplementedError()
170 |
171 | if self.conf.resnet_two_cond:
172 | # two cond: first = time emb, second = cond_emb
173 | emb = res.time_emb
174 | cond_emb = res.emb
175 | else:
176 | # one cond = combined of both time and cond
177 | emb = res.emb
178 | cond_emb = None
179 |
180 | # override the style if given
181 | style = style or res.style
182 |
183 | assert (y is not None) == (
184 | self.conf.num_classes is not None
185 | ), "must specify y if and only if the model is class-conditional"
186 |
187 | if self.conf.num_classes is not None:
188 | raise NotImplementedError()
189 | # assert y.shape == (x.shape[0], )
190 | # emb = emb + self.label_emb(y)
191 |
192 | # where in the model to supply time conditions
193 | enc_time_emb = emb
194 | mid_time_emb = emb
195 | dec_time_emb = emb
196 | # where in the model to supply style conditions
197 | enc_cond_emb = cond_emb
198 | mid_cond_emb = cond_emb
199 | dec_cond_emb = cond_emb
200 |
201 | # hs = []
202 | hs = [[] for _ in range(len(self.conf.channel_mult))]
203 |
204 | if x is not None:
205 | h = x.type(self.dtype)
206 |
207 | # input blocks
208 | k = 0
209 | for i in range(len(self.input_num_blocks)):
210 | for j in range(self.input_num_blocks[i]):
211 | h = self.input_blocks[k](h,
212 | emb=enc_time_emb,
213 | cond=enc_cond_emb)
214 |
215 | # print(i, j, h.shape)
216 | hs[i].append(h)
217 | k += 1
218 | assert k == len(self.input_blocks)
219 |
220 | # middle blocks
221 | h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
222 | else:
223 | # no lateral connections
224 |             # happens when training only the autoencoder
225 | h = None
226 | hs = [[] for _ in range(len(self.conf.channel_mult))]
227 |
228 | # output blocks
229 | k = 0
230 | for i in range(len(self.output_num_blocks)):
231 | for j in range(self.output_num_blocks[i]):
232 |                 # take the lateral connection from the same layer (in reverse)
233 | # until there is no more, use None
234 | try:
235 | lateral = hs[-i - 1].pop()
236 | # print(i, j, lateral.shape)
237 | except IndexError:
238 | lateral = None
239 | # print(i, j, lateral)
240 |
241 | h = self.output_blocks[k](h,
242 | emb=dec_time_emb,
243 | cond=dec_cond_emb,
244 | lateral=lateral)
245 | k += 1
246 |
247 | pred = self.out(h)
248 | return AutoencReturn(pred=pred, cond=cond)
249 |
250 |
251 | class AutoencReturn(NamedTuple):
252 | pred: Tensor
253 | cond: Tensor = None
254 |
255 |
256 | class EmbedReturn(NamedTuple):
257 | # style and time
258 | emb: Tensor = None
259 | # time only
260 | time_emb: Tensor = None
261 | # style only (but could depend on time)
262 | style: Tensor = None
263 |
264 |
265 | class TimeStyleSeperateEmbed(nn.Module):
266 |     # embeds time with an MLP and passes the style (cond) through unchanged
267 | def __init__(self, time_channels, time_out_channels):
268 | super().__init__()
269 | self.time_embed = nn.Sequential(
270 | linear(time_channels, time_out_channels),
271 | nn.SiLU(),
272 | linear(time_out_channels, time_out_channels),
273 | )
274 | self.style = nn.Identity()
275 |
276 | def forward(self, time_emb=None, cond=None, **kwargs):
277 | if time_emb is None:
278 | # happens with autoenc training mode
279 | time_emb = None
280 | else:
281 | time_emb = self.time_embed(time_emb)
282 | style = self.style(cond)
283 | return EmbedReturn(emb=style, time_emb=time_emb, style=style)
284 |
--------------------------------------------------------------------------------
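The autoencoder model above encodes an image into a semantic code with `encode()` and feeds it to every ResBlock as the second condition. A minimal sketch: `resnet_two_cond` and `enc_pool` are set the way `forward()` and the encoder require, while the remaining values are illustrative assumptions and not the released `ffhq256_autoenc` settings.

```python
# Sketch: semantic encoding + conditional denoising with the autoencoder UNet.
import torch
from model.unet_autoenc import BeatGANsAutoencConfig

conf = BeatGANsAutoencConfig(
    image_size=64,
    model_channels=64,
    resnet_two_cond=True,          # forward() only implements the two-condition path
    enc_pool='adaptivenonzero',    # the only pooling the encoder implements
    enc_out_channels=512,
)
model = conf.make_model()

x0 = torch.randn(2, 3, 64, 64)     # images to encode
x_t = torch.randn(2, 3, 64, 64)    # noisy images to denoise
t = torch.randint(0, 1000, (2, ))

cond = model.encode(x0)['cond']          # semantic latent, shape (2, 512)
out = model(x=x_t, t=t, cond=cond)       # AutoencReturn(pred=..., cond=...)
print(out.pred.shape, out.cond.shape)
```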
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joellliu/DiffProtect/79b80f4b68caa4b6549c808eb0964a4398914f41/pipeline.png
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls
2 | # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
3 | # bunzip2 shape_predictor_68_face_landmarks.dat.bz2
4 |
5 | import os
6 | import torch
7 | from torchvision.utils import save_image
8 | import tempfile
9 | from templates import *
10 | from templates_cls import *
11 | from experiment_classifier import ClsModel
12 | from align import LandmarksDetector, image_align
13 | from cog import BasePredictor, Path, Input, BaseModel
14 |
15 |
16 | class ModelOutput(BaseModel):
17 | image: Path
18 |
19 |
20 | class Predictor(BasePredictor):
21 | def setup(self):
22 | self.aligned_dir = "aligned"
23 | os.makedirs(self.aligned_dir, exist_ok=True)
24 | self.device = "cuda:0"
25 |
26 | # Model Initialization
27 | model_config = ffhq256_autoenc()
28 | self.model = LitModel(model_config)
29 | state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
30 | self.model.load_state_dict(state["state_dict"], strict=False)
31 | self.model.ema_model.eval()
32 | self.model.ema_model.to(self.device)
33 |
34 | # Classifier Initialization
35 | classifier_config = ffhq256_autoenc_cls()
36 | classifier_config.pretrain = None # a bit faster
37 | self.classifier = ClsModel(classifier_config)
38 | state_class = torch.load(
39 | "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
40 | )
41 | print("latent step:", state_class["global_step"])
42 | self.classifier.load_state_dict(state_class["state_dict"], strict=False)
43 | self.classifier.to(self.device)
44 |
45 | self.landmarks_detector = LandmarksDetector(
46 | "shape_predictor_68_face_landmarks.dat"
47 | )
48 |
49 | def predict(
50 | self,
51 | image: Path = Input(
52 |             description="Input image for face manipulation. The image will be aligned and cropped; "
53 |             "the outputs are the aligned and the manipulated images.",
54 | ),
55 | target_class: str = Input(
56 | default="Bangs",
57 | choices=[
58 | "5_o_Clock_Shadow",
59 | "Arched_Eyebrows",
60 | "Attractive",
61 | "Bags_Under_Eyes",
62 | "Bald",
63 | "Bangs",
64 | "Big_Lips",
65 | "Big_Nose",
66 | "Black_Hair",
67 | "Blond_Hair",
68 | "Blurry",
69 | "Brown_Hair",
70 | "Bushy_Eyebrows",
71 | "Chubby",
72 | "Double_Chin",
73 | "Eyeglasses",
74 | "Goatee",
75 | "Gray_Hair",
76 | "Heavy_Makeup",
77 | "High_Cheekbones",
78 | "Male",
79 | "Mouth_Slightly_Open",
80 | "Mustache",
81 | "Narrow_Eyes",
82 | "Beard",
83 | "Oval_Face",
84 | "Pale_Skin",
85 | "Pointy_Nose",
86 | "Receding_Hairline",
87 | "Rosy_Cheeks",
88 | "Sideburns",
89 | "Smiling",
90 | "Straight_Hair",
91 | "Wavy_Hair",
92 | "Wearing_Earrings",
93 | "Wearing_Hat",
94 | "Wearing_Lipstick",
95 | "Wearing_Necklace",
96 | "Wearing_Necktie",
97 | "Young",
98 | ],
99 | description="Choose manipulation direction.",
100 | ),
101 | manipulation_amplitude: float = Input(
102 | default=0.3,
103 | ge=-0.5,
104 | le=0.5,
105 |             description="If set too high, it can produce artifacts because the manipulation dominates the original image information.",
106 | ),
107 | T_step: int = Input(
108 | default=100,
109 | choices=[50, 100, 125, 200, 250, 500],
110 |             description="Number of steps for generation.",
111 | ),
112 | T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
113 | ) -> List[ModelOutput]:
114 |
115 | img_size = 256
116 | print("Aligning image...")
117 | for i, face_landmarks in enumerate(
118 | self.landmarks_detector.get_landmarks(str(image)), start=1
119 | ):
120 | image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)
121 |
122 | data = ImageDataset(
123 | self.aligned_dir,
124 | image_size=img_size,
125 | exts=["jpg", "jpeg", "JPG", "png"],
126 | do_augment=False,
127 | )
128 |
129 | print("Encoding and Manipulating the aligned image...")
130 | cls_manipulation_amplitude = manipulation_amplitude
131 | interpreted_target_class = target_class
132 | if (
133 | target_class not in CelebAttrDataset.id_to_cls
134 | and f"No_{target_class}" in CelebAttrDataset.id_to_cls
135 | ):
136 | cls_manipulation_amplitude = -manipulation_amplitude
137 | interpreted_target_class = f"No_{target_class}"
138 |
139 | batch = data[0]["img"][None]
140 |
141 | semantic_latent = self.model.encode(batch.to(self.device))
142 | stochastic_latent = self.model.encode_stochastic(
143 | batch.to(self.device), semantic_latent, T=T_inv
144 | )
145 |
146 | cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
147 | class_direction = self.classifier.classifier.weight[cls_id]
148 | normalized_class_direction = F.normalize(class_direction[None, :], dim=1)
149 |
150 | normalized_semantic_latent = self.classifier.normalize(semantic_latent)
151 | normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
152 | normalized_manipulated_semantic_latent = (
153 | normalized_semantic_latent
154 | + normalized_manipulation_amp * normalized_class_direction
155 | )
156 |
157 | manipulated_semantic_latent = self.classifier.denormalize(
158 | normalized_manipulated_semantic_latent
159 | )
160 |
161 | # Render Manipulated image
162 | manipulated_img = self.model.render(
163 | stochastic_latent, manipulated_semantic_latent, T=T_step
164 | )[0]
165 | original_img = data[0]["img"]
166 |
167 | model_output = []
168 | out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
169 | save_image(convert2rgb(original_img), str(out_path))
170 | model_output.append(ModelOutput(image=out_path))
171 |
172 | out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
173 | save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
174 | model_output.append(ModelOutput(image=out_path))
175 | return model_output
176 |
177 |
178 | def convert2rgb(img, adjust_scale=True):
179 | convert_img = torch.tensor(img)
180 | if adjust_scale:
181 | convert_img = (convert_img + 1) / 2
182 | return convert_img.cpu()
183 |
--------------------------------------------------------------------------------
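The core of `predict.py` is the latent manipulation step: the semantic latent is z-normalized, shifted along the unit classifier weight vector for the chosen attribute, then denormalized and re-rendered. Below is a minimal sketch of that arithmetic in plain PyTorch; `semantic_latent`, `class_weight`, `latent_mean`, and `latent_std` are assumed stand-ins for what the repo's `LitModel`/`ClsModel` objects would provide.

```python
import math
import torch
import torch.nn.functional as F

# Assumed stand-ins: a 512-d semantic latent, one classifier weight row
# (self.classifier.classifier.weight[cls_id]), and the statistics that
# ClsModel.normalize / denormalize would apply.
semantic_latent = torch.randn(1, 512)
class_weight = torch.randn(512)
latent_mean, latent_std = torch.zeros(512), torch.ones(512)
amplitude = 0.3  # manipulation_amplitude from the predict() input

direction = F.normalize(class_weight[None, :], dim=1)           # unit class direction
normalized = (semantic_latent - latent_mean) / latent_std       # ~ ClsModel.normalize
shifted = normalized + amplitude * math.sqrt(512) * direction   # step scaled by sqrt(latent dim)
manipulated_latent = shifted * latent_std + latent_mean         # ~ ClsModel.denormalize
```

The `sqrt(512)` factor scales the step to the typical norm of a z-normalized 512-d latent, so the amplitude stays roughly comparable across latent dimensions. The manipulated latent is then decoded with `self.model.render(stochastic_latent, manipulated_semantic_latent, T=T_step)` as in the code above.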
/renderer.py:
--------------------------------------------------------------------------------
1 | from config import *
2 |
3 | from torch.cuda import amp
4 |
5 |
6 | def render_uncondition(conf: TrainConfig,
7 | model: BeatGANsAutoencModel,
8 | x_T,
9 | sampler: Sampler,
10 | latent_sampler: Sampler,
11 | conds_mean=None,
12 | conds_std=None,
13 | clip_latent_noise: bool = False):
14 | device = x_T.device
15 | if conf.train_mode == TrainMode.diffusion:
16 | assert conf.model_type.can_sample()
17 | return sampler.sample(model=model, noise=x_T)
18 | elif conf.train_mode.is_latent_diffusion():
19 | model: BeatGANsAutoencModel
20 | if conf.train_mode == TrainMode.latent_diffusion:
21 | latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
22 | else:
23 | raise NotImplementedError()
24 |
25 | if clip_latent_noise:
26 | latent_noise = latent_noise.clip(-1, 1)
27 |
28 | cond = latent_sampler.sample(
29 | model=model.latent_net,
30 | noise=latent_noise,
31 | clip_denoised=conf.latent_clip_sample,
32 | )
33 |
34 | if conf.latent_znormalize:
35 | cond = cond * conds_std.to(device) + conds_mean.to(device)
36 |
37 |         # run the image diffusion conditioned on the sampled latent code
38 | return sampler.sample(model=model, noise=x_T, cond=cond)
39 | else:
40 | raise NotImplementedError()
41 |
42 |
43 | def render_condition(
44 | conf: TrainConfig,
45 | model: BeatGANsAutoencModel,
46 | x_T,
47 | sampler: Sampler,
48 | x_start=None,
49 | cond=None,
50 | ):
51 | if conf.train_mode == TrainMode.diffusion:
52 | assert conf.model_type.has_autoenc()
53 | # returns {'cond', 'cond2'}
54 | if cond is None:
55 | cond = model.encode(x_start)
56 | return sampler.sample(model=model,
57 | noise=x_T,
58 | model_kwargs={'cond': cond})
59 | else:
60 | raise NotImplementedError()
61 |
62 | def render_condition_inter(
63 | conf: TrainConfig,
64 | model: BeatGANsAutoencModel,
65 | x_T,
66 | sampler: Sampler,
67 | x_start=None,
68 | cond=None,
69 | t=None
70 | ):
71 | if conf.train_mode == TrainMode.diffusion:
72 | assert conf.model_type.has_autoenc()
73 | # returns {'cond', 'cond2'}
74 | if cond is None:
75 | cond = model.encode(x_start)
76 |
77 | out = sampler.ddim_sample(model=model, x=x_T, t=t, model_kwargs={'cond': cond})
78 |
79 | return out
80 | else:
81 | raise NotImplementedError()
--------------------------------------------------------------------------------
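`render_condition` runs the full conditional sampler from the noise `x_T`, while `render_condition_inter` performs a single `ddim_sample` step at an explicit timestep `t` to obtain an intermediate result. A hedged usage sketch follows; `conf`, `model`, `sampler`, and the input batch `x` are assumed to have been built elsewhere (e.g. as in `experiment.py` / `predict.py`), and the output keys are the ones a guided-diffusion-style `ddim_sample` typically returns.

```python
import torch
from renderer import render_condition, render_condition_inter

# Assumed to exist already: conf (TrainConfig), model (BeatGANsAutoencModel),
# sampler (a Sampler built from conf), and an image batch x scaled to [-1, 1].
cond = model.encode(x)            # semantic latent
x_T = torch.randn_like(x)         # starting noise

# Full reverse process conditioned on the semantic latent:
x_rec = render_condition(conf, model, x_T, sampler, cond=cond)

# Single DDIM step at timestep t (typically a dict with 'sample', 'pred_xstart', ...):
t = torch.full((x.shape[0],), 50, device=x.device, dtype=torch.long)
out = render_condition_inter(conf, model, x_T, sampler, cond=cond, t=t)
```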
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.4.5
2 | torchmetrics==0.5.0
3 | torch==1.8.1
4 | torchvision
5 | scipy==1.5.4
6 | numpy==1.19.5
7 | tqdm
8 | pytorch-fid==0.2.0
9 | pandas==1.1.5
10 | lpips==0.1.4
11 | lmdb==1.2.1
12 | ftfy
13 | regex
14 | advertorch
15 |
--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([
10 | exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
11 | for x in range(window_size)
12 | ])
13 | return gauss / gauss.sum()
14 |
15 |
16 | def create_window(window_size, channel):
17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
18 | _2D_window = _1D_window.mm(
19 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
20 | window = Variable(
21 | _2D_window.expand(channel, 1, window_size, window_size).contiguous())
22 | return window
23 |
24 |
25 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
26 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
27 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
28 |
29 | mu1_sq = mu1.pow(2)
30 | mu2_sq = mu2.pow(2)
31 | mu1_mu2 = mu1 * mu2
32 |
33 | sigma1_sq = F.conv2d(
34 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
35 | sigma2_sq = F.conv2d(
36 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
37 | sigma12 = F.conv2d(
38 | img1 * img2, window, padding=window_size // 2,
39 | groups=channel) - mu1_mu2
40 |
41 | C1 = 0.01**2
42 | C2 = 0.03**2
43 |
44 | ssim_map = ((2 * mu1_mu2 + C1) *
45 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
46 | (sigma1_sq + sigma2_sq + C2))
47 |
48 | if size_average:
49 | return ssim_map.mean()
50 | else:
51 | return ssim_map.mean(1).mean(1).mean(1)
52 |
53 |
54 | class SSIM(torch.nn.Module):
55 | def __init__(self, window_size=11, size_average=True):
56 | super(SSIM, self).__init__()
57 | self.window_size = window_size
58 | self.size_average = size_average
59 | self.channel = 1
60 | self.window = create_window(window_size, self.channel)
61 |
62 | def forward(self, img1, img2):
63 | (_, channel, _, _) = img1.size()
64 |
65 | if channel == self.channel and self.window.data.type(
66 | ) == img1.data.type():
67 | window = self.window
68 | else:
69 | window = create_window(self.window_size, channel)
70 |
71 | if img1.is_cuda:
72 | window = window.cuda(img1.get_device())
73 | window = window.type_as(img1)
74 |
75 | self.window = window
76 | self.channel = channel
77 |
78 | return _ssim(img1, img2, window, self.window_size, channel,
79 | self.size_average)
80 |
81 |
82 | def ssim(img1, img2, window_size=11, size_average=True):
83 | (_, channel, _, _) = img1.size()
84 | window = create_window(window_size, channel)
85 |
86 | if img1.is_cuda:
87 | window = window.cuda(img1.get_device())
88 | window = window.type_as(img1)
89 |
90 | return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
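A quick usage sketch for the module above: both the functional `ssim()` and the `SSIM` module expect NCHW tensors, and identical inputs yield a score of 1.0. Note that `C1 = 0.01**2` and `C2 = 0.03**2` follow the usual convention for a dynamic range of 1, so inputs in `[-1, 1]` should be rescaled to `[0, 1]` first.

```python
import torch
from ssim import ssim, SSIM

img1 = torch.rand(1, 3, 256, 256)                               # values in [0, 1]
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)

print(ssim(img1, img1).item())   # ~1.0 for identical inputs
print(ssim(img1, img2).item())   # decreases as the images diverge

criterion = SSIM(window_size=11)  # module form, e.g. usable as a (1 - SSIM) loss term
print(criterion(img1, img2).item())
```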
/templates.py:
--------------------------------------------------------------------------------
1 | from experiment import *
2 |
3 |
4 | def ddpm():
5 | """
6 | base configuration for all DDIM-based models.
7 | """
8 | conf = TrainConfig()
9 | conf.batch_size = 32
10 | conf.beatgans_gen_type = GenerativeType.ddim
11 | conf.beta_scheduler = 'linear'
12 | conf.data_name = 'ffhq'
13 | conf.diffusion_type = 'beatgans'
14 | conf.eval_ema_every_samples = 200_000
15 | conf.eval_every_samples = 200_000
16 | conf.fp16 = True
17 | conf.lr = 1e-4
18 | conf.model_name = ModelName.beatgans_ddpm
19 | conf.net_attn = (16, )
20 | conf.net_beatgans_attn_head = 1
21 | conf.net_beatgans_embed_channels = 512
22 | conf.net_ch_mult = (1, 2, 4, 8)
23 | conf.net_ch = 64
24 | conf.sample_size = 32
25 | conf.T_eval = 20
26 | conf.T = 1000
27 | conf.make_model_conf()
28 | return conf
29 |
30 |
31 | def autoenc_base():
32 | """
33 | base configuration for all Diff-AE models.
34 | """
35 | conf = TrainConfig()
36 | conf.batch_size = 32
37 | conf.beatgans_gen_type = GenerativeType.ddim
38 | conf.beta_scheduler = 'linear'
39 | conf.data_name = 'ffhq'
40 | conf.diffusion_type = 'beatgans'
41 | conf.eval_ema_every_samples = 200_000
42 | conf.eval_every_samples = 200_000
43 | conf.fp16 = True
44 | conf.lr = 1e-4
45 | conf.model_name = ModelName.beatgans_autoenc
46 | conf.net_attn = (16, )
47 | conf.net_beatgans_attn_head = 1
48 | conf.net_beatgans_embed_channels = 512
49 | conf.net_beatgans_resnet_two_cond = True
50 | conf.net_ch_mult = (1, 2, 4, 8)
51 | conf.net_ch = 64
52 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53 | conf.net_enc_pool = 'adaptivenonzero'
54 | conf.sample_size = 32
55 | conf.T_eval = 20
56 | conf.T = 1000
57 | conf.make_model_conf()
58 | return conf
59 |
60 |
61 | def ffhq64_ddpm():
62 | conf = ddpm()
63 | conf.data_name = 'ffhqlmdb256'
64 | conf.warmup = 0
65 | conf.total_samples = 72_000_000
66 | conf.scale_up_gpus(4)
67 | return conf
68 |
69 |
70 | def ffhq64_autoenc():
71 | conf = autoenc_base()
72 | conf.data_name = 'ffhqlmdb256'
73 | conf.warmup = 0
74 | conf.total_samples = 72_000_000
75 | conf.net_ch_mult = (1, 2, 4, 8)
76 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
77 | conf.eval_every_samples = 1_000_000
78 | conf.eval_ema_every_samples = 1_000_000
79 | conf.scale_up_gpus(4)
80 | conf.make_model_conf()
81 | return conf
82 |
83 |
84 | def celeba64d2c_ddpm():
85 | conf = ffhq128_ddpm()
86 | conf.data_name = 'celebalmdb'
87 | conf.eval_every_samples = 10_000_000
88 | conf.eval_ema_every_samples = 10_000_000
89 | conf.total_samples = 72_000_000
90 | conf.name = 'celeba64d2c_ddpm'
91 | return conf
92 |
93 |
94 | def celeba64d2c_autoenc():
95 | conf = ffhq64_autoenc()
96 | conf.data_name = 'celebalmdb'
97 | conf.eval_every_samples = 10_000_000
98 | conf.eval_ema_every_samples = 10_000_000
99 | conf.total_samples = 72_000_000
100 | conf.name = 'celeba64d2c_autoenc'
101 | return conf
102 |
103 |
104 | def ffhq128_ddpm():
105 | conf = ddpm()
106 | conf.data_name = 'ffhqlmdb256'
107 | conf.warmup = 0
108 | conf.total_samples = 48_000_000
109 | conf.img_size = 128
110 | conf.net_ch = 128
111 | # channels:
112 | # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
113 | # sizes:
114 | # 128 => 128 => 64 => 32 => 16 => 8
115 | conf.net_ch_mult = (1, 1, 2, 3, 4)
116 | conf.eval_every_samples = 1_000_000
117 | conf.eval_ema_every_samples = 1_000_000
118 | conf.scale_up_gpus(4)
119 | conf.eval_ema_every_samples = 10_000_000
120 | conf.eval_every_samples = 10_000_000
121 | conf.make_model_conf()
122 | return conf
123 |
124 |
125 | def ffhq128_autoenc_base():
126 | conf = autoenc_base()
127 | conf.data_name = 'ffhqlmdb256'
128 | conf.scale_up_gpus(4)
129 | conf.img_size = 128
130 | conf.net_ch = 128
131 | # final resolution = 8x8
132 | conf.net_ch_mult = (1, 1, 2, 3, 4)
133 | # final resolution = 4x4
134 | conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
135 | conf.eval_ema_every_samples = 10_000_000
136 | conf.eval_every_samples = 10_000_000
137 | conf.make_model_conf()
138 | return conf
139 |
140 |
141 | def ffhq256_autoenc():
142 | conf = ffhq128_autoenc_base()
143 | conf.img_size = 256
144 | conf.net_ch = 128
145 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
146 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
147 | conf.eval_every_samples = 10_000_000
148 | conf.eval_ema_every_samples = 10_000_000
149 | conf.total_samples = 200_000_000
150 | conf.batch_size = 64
151 | conf.make_model_conf()
152 | conf.name = 'ffhq256_autoenc'
153 | return conf
154 |
155 |
156 | def ffhq256_autoenc_eco():
157 | conf = ffhq128_autoenc_base()
158 | conf.img_size = 256
159 | conf.net_ch = 128
160 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
161 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
162 | conf.eval_every_samples = 10_000_000
163 | conf.eval_ema_every_samples = 10_000_000
164 | conf.total_samples = 200_000_000
165 | conf.batch_size = 64
166 | conf.make_model_conf()
167 | conf.name = 'ffhq256_autoenc_eco'
168 | return conf
169 |
170 |
171 | def ffhq128_ddpm_72M():
172 | conf = ffhq128_ddpm()
173 | conf.total_samples = 72_000_000
174 | conf.name = 'ffhq128_ddpm_72M'
175 | return conf
176 |
177 |
178 | def ffhq128_autoenc_72M():
179 | conf = ffhq128_autoenc_base()
180 | conf.total_samples = 72_000_000
181 | conf.name = 'ffhq128_autoenc_72M'
182 | return conf
183 |
184 |
185 | def ffhq128_ddpm_130M():
186 | conf = ffhq128_ddpm()
187 | conf.total_samples = 130_000_000
188 | conf.eval_ema_every_samples = 10_000_000
189 | conf.eval_every_samples = 10_000_000
190 | conf.name = 'ffhq128_ddpm_130M'
191 | return conf
192 |
193 |
194 | def ffhq128_autoenc_130M():
195 | conf = ffhq128_autoenc_base()
196 | conf.total_samples = 130_000_000
197 | conf.eval_ema_every_samples = 10_000_000
198 | conf.eval_every_samples = 10_000_000
199 | conf.name = 'ffhq128_autoenc_130M'
200 | return conf
201 |
202 |
203 | def horse128_ddpm():
204 | conf = ffhq128_ddpm()
205 | conf.data_name = 'horse256'
206 | conf.total_samples = 130_000_000
207 | conf.eval_ema_every_samples = 10_000_000
208 | conf.eval_every_samples = 10_000_000
209 | conf.name = 'horse128_ddpm'
210 | return conf
211 |
212 |
213 | def horse128_autoenc():
214 | conf = ffhq128_autoenc_base()
215 | conf.data_name = 'horse256'
216 | conf.total_samples = 130_000_000
217 | conf.eval_ema_every_samples = 10_000_000
218 | conf.eval_every_samples = 10_000_000
219 | conf.name = 'horse128_autoenc'
220 | return conf
221 |
222 |
223 | def bedroom128_ddpm():
224 | conf = ffhq128_ddpm()
225 | conf.data_name = 'bedroom256'
226 | conf.eval_ema_every_samples = 10_000_000
227 | conf.eval_every_samples = 10_000_000
228 | conf.total_samples = 120_000_000
229 | conf.name = 'bedroom128_ddpm'
230 | return conf
231 |
232 |
233 | def bedroom128_autoenc():
234 | conf = ffhq128_autoenc_base()
235 | conf.data_name = 'bedroom256'
236 | conf.eval_ema_every_samples = 10_000_000
237 | conf.eval_every_samples = 10_000_000
238 | conf.total_samples = 120_000_000
239 | conf.name = 'bedroom128_autoenc'
240 | return conf
241 |
242 |
243 | def pretrain_celeba64d2c_72M():
244 | conf = celeba64d2c_autoenc()
245 | conf.pretrain = PretrainConfig(
246 | name='72M',
247 | path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
248 | )
249 | conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
250 | return conf
251 |
252 |
253 | def pretrain_ffhq128_autoenc72M():
254 | conf = ffhq128_autoenc_base()
255 | conf.postfix = ''
256 | conf.pretrain = PretrainConfig(
257 | name='72M',
258 | path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
259 | )
260 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
261 | return conf
262 |
263 |
264 | def pretrain_ffhq128_autoenc130M():
265 | conf = ffhq128_autoenc_base()
266 | conf.pretrain = PretrainConfig(
267 | name='130M',
268 | path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
269 | )
270 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
271 | return conf
272 |
273 |
274 | def pretrain_ffhq256_autoenc():
275 | conf = ffhq256_autoenc()
276 | conf.pretrain = PretrainConfig(
277 | name='90M',
278 | path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
279 | )
280 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
281 | return conf
282 |
283 |
284 | def pretrain_horse128():
285 | conf = horse128_autoenc()
286 | conf.pretrain = PretrainConfig(
287 | name='82M',
288 | path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
289 | )
290 | conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
291 | return conf
292 |
293 |
294 | def pretrain_bedroom128():
295 | conf = bedroom128_autoenc()
296 | conf.pretrain = PretrainConfig(
297 | name='120M',
298 | path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
299 | )
300 | conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
301 | return conf
302 |
--------------------------------------------------------------------------------
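The templates compose by mutation: each function starts from a base config and overrides fields, so `ffhq256_autoenc()` is `ffhq128_autoenc_base()` with a larger image size and different channel multipliers. A small sketch of how the predictor code consumes such a template, mirroring the setup in `predict.py` (the checkpoint path and the `LitModel` import from `experiment.py` are assumptions based on that setup):

```python
import torch
from templates import ffhq256_autoenc
from experiment import LitModel   # assumed location, re-exported via templates' star import

conf = ffhq256_autoenc()
print(conf.img_size, conf.net_ch, conf.net_ch_mult)   # 256 128 (1, 1, 2, 2, 4, 4)

model = LitModel(conf)
state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
model.load_state_dict(state["state_dict"], strict=False)
model.ema_model.eval()
```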