├── README.md ├── configs └── diffswap │ └── default-project.yaml ├── data └── portrait_jpg │ ├── source │ └── 0.jpg │ └── target │ └── 0.jpg ├── data_preprocessing ├── align │ ├── __init__.py │ ├── align_trans.py │ ├── box_utils.py │ ├── detector.py │ ├── face_align.py │ ├── face_align_portrait.py │ ├── face_resize.py │ ├── first_stage.py │ ├── get_nets.py │ ├── matlab_cp2tform.py │ ├── onet.npy │ ├── pnet.npy │ ├── rnet.npy │ └── visualization_utils.py └── detection │ ├── detcect_faces_portrait.py │ ├── merge_mtcnn_portrait.py │ └── run_detect_faces_portrait.sh ├── ldm ├── data │ └── portrait.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── ddim.py │ │ └── ddpm.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── face_embedder.py │ │ └── modules.py │ └── x_transformer.py └── util.py ├── pipeline.py ├── requirements.txt ├── src └── arcface_torch │ ├── README.md │ ├── backbones │ ├── __init__.py │ ├── iresnet.py │ ├── iresnet2060.py │ ├── mobilefacenet.py │ └── vit.py │ ├── configs │ ├── 3millions.py │ ├── __init__.py │ ├── base.py │ ├── glint360k_mbf.py │ ├── glint360k_r100.py │ ├── glint360k_r50.py │ ├── ms1mv2_mbf.py │ ├── ms1mv2_r100.py │ ├── ms1mv2_r50.py │ ├── ms1mv3_mbf.py │ ├── ms1mv3_r100.py │ ├── ms1mv3_r50.py │ ├── wf12m_conflict_r50.py │ ├── wf12m_conflict_r50_pfc03_filter04.py │ ├── wf12m_flip_pfc01_filter04_r50.py │ ├── wf12m_flip_r50.py │ ├── wf12m_mbf.py │ ├── wf12m_pfc02_r100.py │ ├── wf12m_r100.py │ ├── wf12m_r50.py │ ├── wf42m_pfc0008_32gpu_r100.py │ ├── wf42m_pfc02_16gpus_mbf_bs8k.py │ ├── wf42m_pfc02_16gpus_r100.py │ ├── wf42m_pfc02_16gpus_r50_bs8k.py │ ├── wf42m_pfc02_32gpus_r50_bs4k.py │ ├── wf42m_pfc02_8gpus_r50_bs4k.py │ ├── wf42m_pfc02_r100.py │ ├── wf42m_pfc02_r100_16gpus.py │ ├── wf42m_pfc02_r100_32gpus.py │ ├── wf42m_pfc03_32gpu_r100.py │ ├── wf42m_pfc03_32gpu_r18.py │ ├── wf42m_pfc03_32gpu_r200.py │ ├── wf42m_pfc03_32gpu_r50.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_b.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_l.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_s.py │ ├── wf42m_pfc03_40epoch_64gpu_vit_t.py │ ├── wf42m_pfc03_40epoch_8gpu_vit_b.py │ ├── wf42m_pfc03_40epoch_8gpu_vit_t.py │ ├── wf4m_mbf.py │ ├── wf4m_r100.py │ └── wf4m_r50.py │ ├── dataset.py │ ├── dist.sh │ ├── docs │ ├── eval.md │ ├── install.md │ ├── install_dali.md │ ├── modelzoo.md │ ├── prepare_webface42m.md │ └── speed_benchmark.md │ ├── eval │ ├── __init__.py │ └── verification.py │ ├── eval_ijbc.py │ ├── flops.py │ ├── inference.py │ ├── losses.py │ ├── lr_scheduler.py │ ├── onnx_helper.py │ ├── onnx_ijbc.py │ ├── partial_fc.py │ ├── partial_fc_v2.py │ ├── requirement.txt │ ├── run.sh │ ├── torch2onnx.py │ ├── train.py │ ├── train_v2.py │ └── utils │ ├── __init__.py │ ├── plot.py │ ├── utils_callbacks.py │ ├── utils_config.py │ ├── utils_distributed_sampler.py │ └── utils_logging.py ├── tests ├── face_swap.sh └── faceswap_portrait.py └── utils ├── blending ├── __init__.py ├── blending.py └── blending_mask.py └── portrait.py /README.md: -------------------------------------------------------------------------------- 1 | # DiffSwap 2 | 3 | 4 | Created by [Wenliang Zhao](https://wl-zhao.github.io/), [Yongming Rao](https://raoyongming.github.io/), Weikang Shi, [Zuyan Liu](https://scholar.google.com/citations?user=7npgHqAAAAAJ&hl=en), [Jie Zhou](https://scholar.google.com/citations?user=6a79aPwAAAAJ&hl=en&authuser=1), [Jiwen 
Lu](https://scholar.google.com/citations?user=TN8uDQoAAAAJ&hl=en&authuser=1)† 5 | 6 | This repository contains PyTorch implementation for paper "DiffSwap: High-Fidelity and Controllable Face Swapping via 3D-Aware Masked Diffusion" 7 | 8 | [[paper]](https://openaccess.thecvf.com/content/CVPR2023/papers/Zhao_DiffSwap_High-Fidelity_and_Controllable_Face_Swapping_via_3D-Aware_Masked_Diffusion_CVPR_2023_paper.pdf) 9 | 10 | ## Installation 11 | Please first install the environment following [stable-diffusion](https://github.com/CompVis/stable-diffusion), and then run `pip install -r requirements.txt`. 12 | 13 | Please download the checkpoints from [[here]](https://cloud.tsinghua.edu.cn/d/9575c106b9324df7bfe3/), and put them under the `checkpoints/` folder. 14 | The resulting file structure should be: 15 | 16 | ``` 17 | ├── checkpoints 18 | │ ├── diffswap.pth 19 | │ ├── glint360k_r100.pth 20 | │ └── shape_predictor_68_face_landmarks.dat 21 | ``` 22 | 23 | ## Inference 24 | We provide a sample code to perform face swapping given the portrait source and target images. Please put the source images and target images in `data/portrait_jpg` and run 25 | ``` 26 | python pipeline.py 27 | ``` 28 | the swapped results are saved in `data/portrait/swap_res_ori`. 29 | 30 | ## Citation 31 | If you find our work useful in your research, please consider citing: 32 | ``` 33 | @article{zhao2023diffswap, 34 | title={DiffSwap: High-Fidelity and Controllable Face Swapping via 3D-Aware Masked Diffusion}, 35 | author={Zhao, Wenliang and Rao, Yongming and Shi, Weikang and Liu, Zuyan and Zhou, Jie and Lu, Jiwen}, 36 | journal={CVPR}, 37 | year={2023} 38 | } 39 | ``` 40 | -------------------------------------------------------------------------------- /configs/diffswap/default-project.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-07 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | ckpt_path: null 6 | linear_start: 0.0015 7 | linear_end: 0.0195 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: image 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | cond_stage_key: faceattr 16 | conditioning_key: crossattn 17 | monitor: val/loss_simple_ema 18 | unet_config: 19 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 20 | params: 21 | image_size: 64 22 | in_channels: 3 23 | out_channels: 3 24 | model_channels: 224 25 | attention_resolutions: 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | use_spatial_transformer: true 37 | transformer_depth: 1 38 | context_dim: 256 39 | first_stage_config: 40 | target: ldm.models.autoencoder.VQModelInterface 41 | params: 42 | embed_dim: 3 43 | n_embed: 8192 44 | ddconfig: 45 | double_z: false 46 | z_channels: 3 47 | resolution: 256 48 | in_channels: 3 49 | out_ch: 3 50 | ch: 128 51 | ch_mult: 52 | - 1 53 | - 2 54 | - 4 55 | num_res_blocks: 2 56 | attn_resolutions: [] 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.FaceEmbedder 62 | params: 63 | lmk_dim: 256 64 | comb_mode: stack 65 | keys: 66 | - image 67 | - landmark 68 | attention: true 69 | merge_eyes: true 70 | face_model: r100 71 | affine_crop: true 72 | -------------------------------------------------------------------------------- /data/portrait_jpg/source/0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/data/portrait_jpg/source/0.jpg -------------------------------------------------------------------------------- /data/portrait_jpg/target/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/data/portrait_jpg/target/0.jpg -------------------------------------------------------------------------------- /data_preprocessing/align/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data_preprocessing/align/box_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | 5 | def nms(boxes, overlap_threshold = 0.5, mode = 'union'): 6 | """Non-maximum suppression. 7 | 8 | Arguments: 9 | boxes: a float numpy array of shape [n, 5], 10 | where each row is (xmin, ymin, xmax, ymax, score). 11 | overlap_threshold: a float number. 12 | mode: 'union' or 'min'. 13 | 14 | Returns: 15 | list with indices of the selected boxes 16 | """ 17 | 18 | # if there are no boxes, return the empty list 19 | if len(boxes) == 0: 20 | return [] 21 | 22 | # list of picked indices 23 | pick = [] 24 | 25 | # grab the coordinates of the bounding boxes 26 | x1, y1, x2, y2, score = [boxes[:, i] for i in range(5)] 27 | 28 | area = (x2 - x1 + 1.0)*(y2 - y1 + 1.0) 29 | ids = np.argsort(score) # in increasing order 30 | 31 | while len(ids) > 0: 32 | 33 | # grab index of the largest value 34 | last = len(ids) - 1 35 | i = ids[last] 36 | pick.append(i) 37 | 38 | # compute intersections 39 | # of the box with the largest score 40 | # with the rest of boxes 41 | 42 | # left top corner of intersection boxes 43 | ix1 = np.maximum(x1[i], x1[ids[:last]]) 44 | iy1 = np.maximum(y1[i], y1[ids[:last]]) 45 | 46 | # right bottom corner of intersection boxes 47 | ix2 = np.minimum(x2[i], x2[ids[:last]]) 48 | iy2 = np.minimum(y2[i], y2[ids[:last]]) 49 | 50 | # width and height of intersection boxes 51 | w = np.maximum(0.0, ix2 - ix1 + 1.0) 52 | h = np.maximum(0.0, iy2 - iy1 + 1.0) 53 | 54 | # intersections' areas 55 | inter = w * h 56 | if mode == 'min': 57 | overlap = inter/np.minimum(area[i], area[ids[:last]]) 58 | elif mode == 'union': 59 | # intersection over union (IoU) 60 | overlap = inter/(area[i] + area[ids[:last]] - inter) 61 | 62 | # delete all boxes where overlap is too big 63 | ids = np.delete( 64 | ids, 65 | np.concatenate([[last], np.where(overlap > overlap_threshold)[0]]) 66 | ) 67 | 68 | return pick 69 | 70 | 71 | def convert_to_square(bboxes): 72 | """Convert bounding boxes to a square form. 73 | 74 | Arguments: 75 | bboxes: a float numpy array of shape [n, 5]. 76 | 77 | Returns: 78 | a float numpy array of shape [n, 5], 79 | squared bounding boxes. 
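        Each box keeps its center; the output side length is max(h, w).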
80 | """ 81 | 82 | square_bboxes = np.zeros_like(bboxes) 83 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 84 | h = y2 - y1 + 1.0 85 | w = x2 - x1 + 1.0 86 | max_side = np.maximum(h, w) 87 | square_bboxes[:, 0] = x1 + w*0.5 - max_side*0.5 88 | square_bboxes[:, 1] = y1 + h*0.5 - max_side*0.5 89 | square_bboxes[:, 2] = square_bboxes[:, 0] + max_side - 1.0 90 | square_bboxes[:, 3] = square_bboxes[:, 1] + max_side - 1.0 91 | return square_bboxes 92 | 93 | 94 | def calibrate_box(bboxes, offsets): 95 | """Transform bounding boxes to be more like true bounding boxes. 96 | 'offsets' is one of the outputs of the nets. 97 | 98 | Arguments: 99 | bboxes: a float numpy array of shape [n, 5]. 100 | offsets: a float numpy array of shape [n, 4]. 101 | 102 | Returns: 103 | a float numpy array of shape [n, 5]. 104 | """ 105 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 106 | w = x2 - x1 + 1.0 107 | h = y2 - y1 + 1.0 108 | w = np.expand_dims(w, 1) 109 | h = np.expand_dims(h, 1) 110 | 111 | # this is what happening here: 112 | # tx1, ty1, tx2, ty2 = [offsets[:, i] for i in range(4)] 113 | # x1_true = x1 + tx1*w 114 | # y1_true = y1 + ty1*h 115 | # x2_true = x2 + tx2*w 116 | # y2_true = y2 + ty2*h 117 | # below is just more compact form of this 118 | 119 | # are offsets always such that 120 | # x1 < x2 and y1 < y2 ? 121 | 122 | translation = np.hstack([w, h, w, h])*offsets 123 | bboxes[:, 0:4] = bboxes[:, 0:4] + translation 124 | return bboxes 125 | 126 | 127 | def get_image_boxes(bounding_boxes, img, size = 24): 128 | """Cut out boxes from the image. 129 | 130 | Arguments: 131 | bounding_boxes: a float numpy array of shape [n, 5]. 132 | img: an instance of PIL.Image. 133 | size: an integer, size of cutouts. 134 | 135 | Returns: 136 | a float numpy array of shape [n, 3, size, size]. 137 | """ 138 | 139 | num_boxes = len(bounding_boxes) 140 | width, height = img.size 141 | 142 | [dy, edy, dx, edx, y, ey, x, ex, w, h] = correct_bboxes(bounding_boxes, width, height) 143 | img_boxes = np.zeros((num_boxes, 3, size, size), 'float32') 144 | 145 | for i in range(num_boxes): 146 | img_box = np.zeros((h[i], w[i], 3), 'uint8') 147 | 148 | img_array = np.asarray(img, 'uint8') 149 | img_box[dy[i]:(edy[i] + 1), dx[i]:(edx[i] + 1), :] =\ 150 | img_array[y[i]:(ey[i] + 1), x[i]:(ex[i] + 1), :] 151 | 152 | # resize 153 | img_box = Image.fromarray(img_box) 154 | img_box = img_box.resize((size, size), Image.BILINEAR) 155 | img_box = np.asarray(img_box, 'float32') 156 | 157 | img_boxes[i, :, :, :] = _preprocess(img_box) 158 | 159 | return img_boxes 160 | 161 | 162 | def correct_bboxes(bboxes, width, height): 163 | """Crop boxes that are too big and get coordinates 164 | with respect to cutouts. 165 | 166 | Arguments: 167 | bboxes: a float numpy array of shape [n, 5], 168 | where each row is (xmin, ymin, xmax, ymax, score). 169 | width: a float number. 170 | height: a float number. 171 | 172 | Returns: 173 | dy, dx, edy, edx: a int numpy arrays of shape [n], 174 | coordinates of the boxes with respect to the cutouts. 175 | y, x, ey, ex: a int numpy arrays of shape [n], 176 | corrected ymin, xmin, ymax, xmax. 177 | h, w: a int numpy arrays of shape [n], 178 | just heights and widths of boxes. 179 | 180 | in the following order: 181 | [dy, edy, dx, edx, y, ey, x, ex, w, h]. 
182 | """ 183 | 184 | x1, y1, x2, y2 = [bboxes[:, i] for i in range(4)] 185 | w, h = x2 - x1 + 1.0, y2 - y1 + 1.0 186 | num_boxes = bboxes.shape[0] 187 | 188 | # 'e' stands for end 189 | # (x, y) -> (ex, ey) 190 | x, y, ex, ey = x1, y1, x2, y2 191 | 192 | # we need to cut out a box from the image. 193 | # (x, y, ex, ey) are corrected coordinates of the box 194 | # in the image. 195 | # (dx, dy, edx, edy) are coordinates of the box in the cutout 196 | # from the image. 197 | dx, dy = np.zeros((num_boxes,)), np.zeros((num_boxes,)) 198 | edx, edy = w.copy() - 1.0, h.copy() - 1.0 199 | 200 | # if box's bottom right corner is too far right 201 | ind = np.where(ex > width - 1.0)[0] 202 | edx[ind] = w[ind] + width - 2.0 - ex[ind] 203 | ex[ind] = width - 1.0 204 | 205 | # if box's bottom right corner is too low 206 | ind = np.where(ey > height - 1.0)[0] 207 | edy[ind] = h[ind] + height - 2.0 - ey[ind] 208 | ey[ind] = height - 1.0 209 | 210 | # if box's top left corner is too far left 211 | ind = np.where(x < 0.0)[0] 212 | dx[ind] = 0.0 - x[ind] 213 | x[ind] = 0.0 214 | 215 | # if box's top left corner is too high 216 | ind = np.where(y < 0.0)[0] 217 | dy[ind] = 0.0 - y[ind] 218 | y[ind] = 0.0 219 | 220 | return_list = [dy, edy, dx, edx, y, ey, x, ex, w, h] 221 | return_list = [i.astype('int32') for i in return_list] 222 | 223 | return return_list 224 | 225 | 226 | def _preprocess(img): 227 | """Preprocessing step before feeding the network. 228 | 229 | Arguments: 230 | img: a float numpy array of shape [h, w, c]. 231 | 232 | Returns: 233 | a float numpy array of shape [1, c, h, w]. 234 | """ 235 | img = img.transpose((2, 0, 1)) 236 | img = np.expand_dims(img, 0) 237 | img = (img - 127.5) * 0.0078125 238 | return img 239 | -------------------------------------------------------------------------------- /data_preprocessing/align/detector.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import Variable 4 | from get_nets import PNet, RNet, ONet 5 | from box_utils import nms, calibrate_box, get_image_boxes, convert_to_square 6 | from first_stage import run_first_stage 7 | 8 | 9 | def detect_faces(image, min_face_size = 20.0, 10 | thresholds=[0.6, 0.7, 0.8], 11 | nms_thresholds=[0.7, 0.7, 0.7]): 12 | """ 13 | Arguments: 14 | image: an instance of PIL.Image. 15 | min_face_size: a float number. 16 | thresholds: a list of length 3. 17 | nms_thresholds: a list of length 3. 18 | 19 | Returns: 20 | two float numpy arrays of shapes [n_boxes, 4] and [n_boxes, 10], 21 | bounding boxes and facial landmarks. 
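        Each landmark row is laid out as five x-coordinates followed by five
        y-coordinates ([x1..x5, y1..y5]); see the 'compute landmark points'
        step below.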
22 | """ 23 | 24 | # LOAD MODELS 25 | pnet = PNet() 26 | rnet = RNet() 27 | onet = ONet() 28 | onet.eval() 29 | 30 | # BUILD AN IMAGE PYRAMID 31 | width, height = image.size 32 | min_length = min(height, width) 33 | 34 | min_detection_size = 12 35 | factor = 0.707 # sqrt(0.5) 36 | 37 | # scales for scaling the image 38 | scales = [] 39 | 40 | # scales the image so that 41 | # minimum size that we can detect equals to 42 | # minimum face size that we want to detect 43 | m = min_detection_size/min_face_size 44 | min_length *= m 45 | 46 | factor_count = 0 47 | while min_length > min_detection_size: 48 | scales.append(m*factor**factor_count) 49 | min_length *= factor 50 | factor_count += 1 51 | 52 | # STAGE 1 53 | 54 | # it will be returned 55 | bounding_boxes = [] 56 | 57 | # run P-Net on different scales 58 | for s in scales: 59 | boxes = run_first_stage(image, pnet, scale = s, threshold = thresholds[0]) 60 | bounding_boxes.append(boxes) 61 | 62 | # collect boxes (and offsets, and scores) from different scales 63 | bounding_boxes = [i for i in bounding_boxes if i is not None] 64 | bounding_boxes = np.vstack(bounding_boxes) 65 | 66 | keep = nms(bounding_boxes[:, 0:5], nms_thresholds[0]) 67 | bounding_boxes = bounding_boxes[keep] 68 | 69 | # use offsets predicted by pnet to transform bounding boxes 70 | bounding_boxes = calibrate_box(bounding_boxes[:, 0:5], bounding_boxes[:, 5:]) 71 | # shape [n_boxes, 5] 72 | 73 | bounding_boxes = convert_to_square(bounding_boxes) 74 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 75 | 76 | # STAGE 2 77 | 78 | img_boxes = get_image_boxes(bounding_boxes, image, size = 24) 79 | img_boxes = Variable(torch.FloatTensor(img_boxes), volatile = True) 80 | output = rnet(img_boxes) 81 | offsets = output[0].data.numpy() # shape [n_boxes, 4] 82 | probs = output[1].data.numpy() # shape [n_boxes, 2] 83 | 84 | keep = np.where(probs[:, 1] > thresholds[1])[0] 85 | bounding_boxes = bounding_boxes[keep] 86 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) 87 | offsets = offsets[keep] 88 | 89 | keep = nms(bounding_boxes, nms_thresholds[1]) 90 | bounding_boxes = bounding_boxes[keep] 91 | bounding_boxes = calibrate_box(bounding_boxes, offsets[keep]) 92 | bounding_boxes = convert_to_square(bounding_boxes) 93 | bounding_boxes[:, 0:4] = np.round(bounding_boxes[:, 0:4]) 94 | 95 | # STAGE 3 96 | 97 | img_boxes = get_image_boxes(bounding_boxes, image, size = 48) 98 | if len(img_boxes) == 0: 99 | return [], [] 100 | img_boxes = Variable(torch.FloatTensor(img_boxes), volatile = True) 101 | output = onet(img_boxes) 102 | landmarks = output[0].data.numpy() # shape [n_boxes, 10] 103 | offsets = output[1].data.numpy() # shape [n_boxes, 4] 104 | probs = output[2].data.numpy() # shape [n_boxes, 2] 105 | 106 | keep = np.where(probs[:, 1] > thresholds[2])[0] 107 | bounding_boxes = bounding_boxes[keep] 108 | bounding_boxes[:, 4] = probs[keep, 1].reshape((-1, )) 109 | offsets = offsets[keep] 110 | landmarks = landmarks[keep] 111 | 112 | # compute landmark points 113 | width = bounding_boxes[:, 2] - bounding_boxes[:, 0] + 1.0 114 | height = bounding_boxes[:, 3] - bounding_boxes[:, 1] + 1.0 115 | xmin, ymin = bounding_boxes[:, 0], bounding_boxes[:, 1] 116 | landmarks[:, 0:5] = np.expand_dims(xmin, 1) + np.expand_dims(width, 1)*landmarks[:, 0:5] 117 | landmarks[:, 5:10] = np.expand_dims(ymin, 1) + np.expand_dims(height, 1)*landmarks[:, 5:10] 118 | 119 | bounding_boxes = calibrate_box(bounding_boxes, offsets) 120 | keep = nms(bounding_boxes, nms_thresholds[2], mode = 'min') 121 | 
bounding_boxes = bounding_boxes[keep] 122 | landmarks = landmarks[keep] 123 | 124 | return bounding_boxes, landmarks 125 | -------------------------------------------------------------------------------- /data_preprocessing/align/face_align.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from detector import detect_faces 3 | from align_trans import get_reference_facial_points, warp_and_crop_face 4 | import numpy as np 5 | import os 6 | from tqdm import tqdm 7 | import argparse 8 | 9 | 10 | if __name__ == '__main__': 11 | parser = argparse.ArgumentParser(description = "face alignment") 12 | parser.add_argument("-source_root", "--source_root", help = "specify your source dir", default = "./data/test", type = str) 13 | parser.add_argument("-dest_root", "--dest_root", help = "specify your destination dir", default = "./data/test_Aligned", type = str) 14 | parser.add_argument("-crop_size", "--crop_size", help = "specify size of aligned faces, align and crop with padding", default = 112, type = int) 15 | args = parser.parse_args() 16 | 17 | source_root = args.source_root # specify your source dir 18 | dest_root = args.dest_root # specify your destination dir 19 | crop_size = args.crop_size # specify size of aligned faces, align and crop with padding 20 | scale = crop_size / 112. 21 | reference = get_reference_facial_points(default_square = True) * scale 22 | 23 | cwd = os.getcwd() # delete '.DS_Store' existed in the source_root 24 | os.chdir(source_root) 25 | os.system("find . -name '*.DS_Store' -type f -delete") 26 | os.chdir(cwd) 27 | 28 | if not os.path.isdir(dest_root): 29 | os.mkdir(dest_root) 30 | 31 | for subfolder in tqdm(os.listdir(source_root)): 32 | if not os.path.isdir(os.path.join(dest_root, subfolder)): 33 | os.mkdir(os.path.join(dest_root, subfolder)) 34 | for image_name in os.listdir(os.path.join(source_root, subfolder)): 35 | print("Processing\t{}".format(os.path.join(source_root, subfolder, image_name))) 36 | img = Image.open(os.path.join(source_root, subfolder, image_name)) 37 | try: # Handle exception 38 | _, landmarks = detect_faces(img) 39 | except Exception: 40 | print("{} is discarded due to exception!".format(os.path.join(source_root, subfolder, image_name))) 41 | continue 42 | if len(landmarks) == 0: # If the landmarks cannot be detected, the img will be discarded 43 | print("{} is discarded due to non-detected landmarks!".format(os.path.join(source_root, subfolder, image_name))) 44 | continue 45 | facial5points = [[landmarks[0][j], landmarks[0][j + 5]] for j in range(5)] 46 | warped_face = warp_and_crop_face(np.array(img), facial5points, reference, crop_size=(crop_size, crop_size)) 47 | img_warped = Image.fromarray(warped_face) 48 | if image_name.split('.')[-1].lower() not in ['jpg', 'jpeg']: #not from jpg 49 | image_name = '.'.join(image_name.split('.')[:-1]) + '.jpg' 50 | img_warped.save(os.path.join(dest_root, subfolder, image_name)) 51 | -------------------------------------------------------------------------------- /data_preprocessing/align/face_align_portrait.py: -------------------------------------------------------------------------------- 1 | # final count: source: 7 target: 1054 2 | from PIL import Image 3 | from .align_trans import get_reference_facial_points, warp_and_crop_face 4 | from torchvision.transforms import ToTensor, ToPILImage 5 | import numpy as np 6 | import os 7 | from tqdm import tqdm 8 | import argparse 9 | import json 10 | from pathlib import Path 11 | from tqdm import tqdm 
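# This script reads the merged MTCNN detections in data/portrait/mtcnn/mtcnn_256.json,
# keeps the largest face per image, saves 112x112 aligned crops, and writes
# data/portrait/affine_theta.json -- a 2x3 theta per image that ldm/data/portrait.py
# and the FaceEmbedder later feed to F.affine_grid / F.grid_sample to crop the
# face region for the ArcFace identity model.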
12 | import cv2 13 | import torch 14 | import warnings 15 | warnings.filterwarnings("ignore") 16 | 17 | def compute_area(item): 18 | return -np.prod(item['box'][-2:]) 19 | 20 | keys = ['left_eye', 'right_eye', 'nose', 'mouth_left', 'mouth_right'] 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser(description = "face alignment") 23 | parser.add_argument("-dest_root", "--dest_root", help = "specify your destination dir", default = "data/portrait/align112x112", type = str) 24 | parser.add_argument("-crop_size", "--crop_size", help = "specify size of aligned faces, align and crop with padding", default = 112, type = int) 25 | args = parser.parse_args() 26 | 27 | crop_size = args.crop_size 28 | scale = crop_size / 112. 29 | reference = get_reference_facial_points(default_square = True) * scale 30 | 31 | # unsequential 32 | results_all = json.load(open('data/portrait/mtcnn/mtcnn_256.json')) 33 | default_tfm = np.array([[ 2.90431126e-01, 1.89934467e-03, -1.88962605e+01], 34 | [-1.90354592e-03, 2.90477119e-01, -1.70081139e+01]]) 35 | 36 | H1 = W1 = 256 37 | H2 = W2 = 112 38 | A = np.array([[2 / (W1 - 1), 0, -1], [0, 2 / (H1 - 1), -1], [0, 0, 1]]) 39 | B = np.linalg.inv(np.array([[2 / (W2 - 1), 0, -1], [0, 2 / (H2 - 1), -1], [0, 0, 1]])) 40 | C = np.array([[0, 0, 1]]) 41 | 42 | def tfm2theta(tfm): 43 | ttt = np.concatenate([tfm, C], axis=0) 44 | ttt = np.linalg.inv(ttt) 45 | theta = A @ ttt @ B 46 | return theta[:2] 47 | 48 | all_tfms = 0 49 | 50 | use_torch = False 51 | to_tensor = ToTensor() 52 | to_pil = ToPILImage() 53 | save_img = True 54 | error_path = 'data/portrait/error_img.json' 55 | if os.path.exists(error_path): 56 | error_img = json.load(open(error_path, 'r')) 57 | else: 58 | error_img = {'source': [], 'target': []} 59 | affine_theta_all = {} 60 | for type in ['source', 'target']: 61 | cnt = 0 62 | affine_theta_all[type] = {} 63 | img_list = os.listdir(os.path.join('data/portrait/align', type)) 64 | for img, value in tqdm(results_all[type].items()): 65 | value = sorted(value, key=compute_area) 66 | if len(value) == 0: 67 | error_img[type].append(img) 68 | continue 69 | value = value[0] 70 | facial5points = [value['keypoints'][key] for key in keys] 71 | tfm = warp_and_crop_face(None, facial5points, reference, crop_size=(crop_size, crop_size), return_tfm=True) 72 | 73 | all_tfms += tfm 74 | cnt += 1 75 | 76 | theta = tfm2theta(tfm).tolist() 77 | affine_theta_all[type][img] = theta 78 | 79 | if save_img: 80 | image = Image.open(os.path.join('data/portrait/align', type, img)) 81 | if not use_torch: 82 | face_img = cv2.warpAffine(np.array(image), tfm, (crop_size, crop_size)) 83 | img_warped = Image.fromarray(face_img) 84 | else: 85 | image = to_tensor(image)[None].float() 86 | image = torch.nn.functional.interpolate(image, size=(256, 256)) 87 | theta = torch.tensor(tfm2theta(tfm)[None]).float() 88 | grid = torch.nn.functional.affine_grid(theta, size=(1, 3, crop_size, crop_size)) 89 | image = torch.nn.functional.grid_sample(image, grid) 90 | img_warped = to_pil(image[0]) 91 | 92 | os.makedirs(os.path.join(args.dest_root, type), exist_ok=True) 93 | img_warped.save(os.path.join(args.dest_root, type, img)) 94 | print('type: {}, cnt: {}'.format(type, cnt)) 95 | 96 | json.dump(affine_theta_all, open('data/portrait/affine_theta.json', 'w'), indent=4) 97 | json.dump(error_img, open(error_path, 'w'), indent=4) -------------------------------------------------------------------------------- /data_preprocessing/align/face_resize.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from tqdm import tqdm 4 | 5 | 6 | def mkdir(path): 7 | if not os.path.exists(path): 8 | os.mkdir(path) 9 | 10 | 11 | def process_image(img): 12 | 13 | size = img.shape 14 | h, w = size[0], size[1] 15 | scale = max(w, h) / float(min_side) 16 | new_w, new_h = int(w / scale), int(h / scale) 17 | resize_img = cv2.resize(img, (new_w, new_h)) 18 | if new_w % 2 != 0 and new_h % 2 == 0: 19 | top, bottom, left, right = (min_side - new_h) / 2, (min_side - new_h) / 2, (min_side - new_w) / 2 + 1, ( 20 | min_side - new_w) / 2 21 | elif new_h % 2 != 0 and new_w % 2 == 0: 22 | top, bottom, left, right = (min_side - new_h) / 2 + 1, (min_side - new_h) / 2, (min_side - new_w) / 2, ( 23 | min_side - new_w) / 2 24 | elif new_h % 2 == 0 and new_w % 2 == 0: 25 | top, bottom, left, right = (min_side - new_h) / 2, (min_side - new_h) / 2, (min_side - new_w) / 2, ( 26 | min_side - new_w) / 2 27 | else: 28 | top, bottom, left, right = (min_side - new_h) / 2 + 1, (min_side - new_h) / 2, (min_side - new_w) / 2 + 1, ( 29 | min_side - new_w) / 2 30 | pad_img = cv2.copyMakeBorder(resize_img, top, bottom, left, right, cv2.BORDER_CONSTANT, 31 | value=[0, 0, 0]) 32 | 33 | return pad_img 34 | 35 | 36 | def main(source_root): 37 | 38 | dest_root = "/media/pc/6T/jasonjzhao/data/MS-Celeb-1M_Resized" 39 | mkdir(dest_root) 40 | cwd = os.getcwd() # delete '.DS_Store' existed in the source_root 41 | os.chdir(source_root) 42 | os.system("find . -name '*.DS_Store' -type f -delete") 43 | os.chdir(cwd) 44 | 45 | if not os.path.isdir(dest_root): 46 | os.mkdir(dest_root) 47 | 48 | for subfolder in tqdm(os.listdir(source_root)): 49 | if not os.path.isdir(os.path.join(dest_root, subfolder)): 50 | os.mkdir(os.path.join(dest_root, subfolder)) 51 | for image_name in os.listdir(os.path.join(source_root, subfolder)): 52 | print("Processing\t{}".format(os.path.join(source_root, subfolder, image_name))) 53 | img = cv2.imread(os.path.join(source_root, subfolder, image_name)) 54 | if type(img) == type(None): 55 | print("damaged image %s, del it" % (img)) 56 | os.remove(img) 57 | continue 58 | size = img.shape 59 | h, w = size[0], size[1] 60 | if max(w, h) > 512: 61 | img_pad = process_image(img) 62 | else: 63 | img_pad = img 64 | cv2.imwrite(os.path.join(dest_root, subfolder, image_name.split('.')[0] + '.jpg'), img_pad) 65 | 66 | 67 | if __name__ == "__main__": 68 | min_side = 512 69 | main(source_root = "/media/pc/6T/jasonjzhao/data/MS-Celeb-1M/database/base") -------------------------------------------------------------------------------- /data_preprocessing/align/first_stage.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import math 4 | from PIL import Image 5 | import numpy as np 6 | from box_utils import nms, _preprocess 7 | 8 | 9 | def run_first_stage(image, net, scale, threshold): 10 | """Run P-Net, generate bounding boxes, and do NMS. 11 | 12 | Arguments: 13 | image: an instance of PIL.Image. 14 | net: an instance of pytorch's nn.Module, P-Net. 15 | scale: a float number, 16 | scale width and height of the image by this number. 17 | threshold: a float number, 18 | threshold on the probability of a face when generating 19 | bounding boxes from predictions of the net. 20 | 21 | Returns: 22 | a float numpy array of shape [n_boxes, 9], 23 | bounding boxes with scores and offsets (4 + 1 + 4). 
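        Column order: [xmin, ymin, xmax, ymax, score, tx1, ty1, tx2, ty2],
        with box coordinates already mapped back to the original image scale.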
24 | """ 25 | 26 | # scale the image and convert it to a float array 27 | width, height = image.size 28 | sw, sh = math.ceil(width*scale), math.ceil(height*scale) 29 | img = image.resize((sw, sh), Image.BILINEAR) 30 | img = np.asarray(img, 'float32') 31 | 32 | img = Variable(torch.FloatTensor(_preprocess(img)), volatile = True) 33 | output = net(img) 34 | probs = output[1].data.numpy()[0, 1, :, :] 35 | offsets = output[0].data.numpy() 36 | # probs: probability of a face at each sliding window 37 | # offsets: transformations to true bounding boxes 38 | 39 | boxes = _generate_bboxes(probs, offsets, scale, threshold) 40 | if len(boxes) == 0: 41 | return None 42 | 43 | keep = nms(boxes[:, 0:5], overlap_threshold = 0.5) 44 | return boxes[keep] 45 | 46 | 47 | def _generate_bboxes(probs, offsets, scale, threshold): 48 | """Generate bounding boxes at places 49 | where there is probably a face. 50 | 51 | Arguments: 52 | probs: a float numpy array of shape [n, m]. 53 | offsets: a float numpy array of shape [1, 4, n, m]. 54 | scale: a float number, 55 | width and height of the image were scaled by this number. 56 | threshold: a float number. 57 | 58 | Returns: 59 | a float numpy array of shape [n_boxes, 9] 60 | """ 61 | 62 | # applying P-Net is equivalent, in some sense, to 63 | # moving 12x12 window with stride 2 64 | stride = 2 65 | cell_size = 12 66 | 67 | # indices of boxes where there is probably a face 68 | inds = np.where(probs > threshold) 69 | 70 | if inds[0].size == 0: 71 | return np.array([]) 72 | 73 | # transformations of bounding boxes 74 | tx1, ty1, tx2, ty2 = [offsets[0, i, inds[0], inds[1]] for i in range(4)] 75 | # they are defined as: 76 | # w = x2 - x1 + 1 77 | # h = y2 - y1 + 1 78 | # x1_true = x1 + tx1*w 79 | # x2_true = x2 + tx2*w 80 | # y1_true = y1 + ty1*h 81 | # y2_true = y2 + ty2*h 82 | 83 | offsets = np.array([tx1, ty1, tx2, ty2]) 84 | score = probs[inds[0], inds[1]] 85 | 86 | # P-Net is applied to scaled images 87 | # so we need to rescale bounding boxes back 88 | bounding_boxes = np.vstack([ 89 | np.round((stride*inds[1] + 1.0)/scale), 90 | np.round((stride*inds[0] + 1.0)/scale), 91 | np.round((stride*inds[1] + 1.0 + cell_size)/scale), 92 | np.round((stride*inds[0] + 1.0 + cell_size)/scale), 93 | score, offsets 94 | ]) 95 | # why one is added? 96 | 97 | return bounding_boxes.T -------------------------------------------------------------------------------- /data_preprocessing/align/get_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import numpy as np 6 | 7 | 8 | class Flatten(nn.Module): 9 | 10 | def __init__(self): 11 | super(Flatten, self).__init__() 12 | 13 | def forward(self, x): 14 | """ 15 | Arguments: 16 | x: a float tensor with shape [batch_size, c, h, w]. 17 | Returns: 18 | a float tensor with shape [batch_size, c*h*w]. 
19 | """ 20 | 21 | # without this pretrained model isn't working 22 | x = x.transpose(3, 2).contiguous() 23 | 24 | return x.view(x.size(0), -1) 25 | 26 | 27 | class PNet(nn.Module): 28 | 29 | def __init__(self): 30 | 31 | super(PNet, self).__init__() 32 | 33 | # suppose we have input with size HxW, then 34 | # after first layer: H - 2, 35 | # after pool: ceil((H - 2)/2), 36 | # after second conv: ceil((H - 2)/2) - 2, 37 | # after last conv: ceil((H - 2)/2) - 4, 38 | # and the same for W 39 | 40 | self.features = nn.Sequential(OrderedDict([ 41 | ('conv1', nn.Conv2d(3, 10, 3, 1)), 42 | ('prelu1', nn.PReLU(10)), 43 | ('pool1', nn.MaxPool2d(2, 2, ceil_mode = True)), 44 | 45 | ('conv2', nn.Conv2d(10, 16, 3, 1)), 46 | ('prelu2', nn.PReLU(16)), 47 | 48 | ('conv3', nn.Conv2d(16, 32, 3, 1)), 49 | ('prelu3', nn.PReLU(32)) 50 | ])) 51 | 52 | self.conv4_1 = nn.Conv2d(32, 2, 1, 1) 53 | self.conv4_2 = nn.Conv2d(32, 4, 1, 1) 54 | 55 | weights = np.load("./pnet.npy", allow_pickle=True)[()] 56 | for n, p in self.named_parameters(): 57 | p.data = torch.FloatTensor(weights[n]) 58 | 59 | def forward(self, x): 60 | """ 61 | Arguments: 62 | x: a float tensor with shape [batch_size, 3, h, w]. 63 | Returns: 64 | b: a float tensor with shape [batch_size, 4, h', w']. 65 | a: a float tensor with shape [batch_size, 2, h', w']. 66 | """ 67 | x = self.features(x) 68 | a = self.conv4_1(x) 69 | b = self.conv4_2(x) 70 | a = F.softmax(a) 71 | return b, a 72 | 73 | 74 | class RNet(nn.Module): 75 | 76 | def __init__(self): 77 | 78 | super(RNet, self).__init__() 79 | 80 | self.features = nn.Sequential(OrderedDict([ 81 | ('conv1', nn.Conv2d(3, 28, 3, 1)), 82 | ('prelu1', nn.PReLU(28)), 83 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode = True)), 84 | 85 | ('conv2', nn.Conv2d(28, 48, 3, 1)), 86 | ('prelu2', nn.PReLU(48)), 87 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode = True)), 88 | 89 | ('conv3', nn.Conv2d(48, 64, 2, 1)), 90 | ('prelu3', nn.PReLU(64)), 91 | 92 | ('flatten', Flatten()), 93 | ('conv4', nn.Linear(576, 128)), 94 | ('prelu4', nn.PReLU(128)) 95 | ])) 96 | 97 | self.conv5_1 = nn.Linear(128, 2) 98 | self.conv5_2 = nn.Linear(128, 4) 99 | 100 | weights = np.load("./rnet.npy", allow_pickle=True)[()] 101 | for n, p in self.named_parameters(): 102 | p.data = torch.FloatTensor(weights[n]) 103 | 104 | def forward(self, x): 105 | """ 106 | Arguments: 107 | x: a float tensor with shape [batch_size, 3, h, w]. 108 | Returns: 109 | b: a float tensor with shape [batch_size, 4]. 110 | a: a float tensor with shape [batch_size, 2]. 
111 | """ 112 | x = self.features(x) 113 | a = self.conv5_1(x) 114 | b = self.conv5_2(x) 115 | a = F.softmax(a) 116 | return b, a 117 | 118 | 119 | class ONet(nn.Module): 120 | 121 | def __init__(self): 122 | 123 | super(ONet, self).__init__() 124 | 125 | self.features = nn.Sequential(OrderedDict([ 126 | ('conv1', nn.Conv2d(3, 32, 3, 1)), 127 | ('prelu1', nn.PReLU(32)), 128 | ('pool1', nn.MaxPool2d(3, 2, ceil_mode = True)), 129 | 130 | ('conv2', nn.Conv2d(32, 64, 3, 1)), 131 | ('prelu2', nn.PReLU(64)), 132 | ('pool2', nn.MaxPool2d(3, 2, ceil_mode = True)), 133 | 134 | ('conv3', nn.Conv2d(64, 64, 3, 1)), 135 | ('prelu3', nn.PReLU(64)), 136 | ('pool3', nn.MaxPool2d(2, 2, ceil_mode = True)), 137 | 138 | ('conv4', nn.Conv2d(64, 128, 2, 1)), 139 | ('prelu4', nn.PReLU(128)), 140 | 141 | ('flatten', Flatten()), 142 | ('conv5', nn.Linear(1152, 256)), 143 | ('drop5', nn.Dropout(0.25)), 144 | ('prelu5', nn.PReLU(256)), 145 | ])) 146 | 147 | self.conv6_1 = nn.Linear(256, 2) 148 | self.conv6_2 = nn.Linear(256, 4) 149 | self.conv6_3 = nn.Linear(256, 10) 150 | 151 | weights = np.load("./onet.npy", allow_pickle=True)[()] 152 | for n, p in self.named_parameters(): 153 | p.data = torch.FloatTensor(weights[n]) 154 | 155 | def forward(self, x): 156 | """ 157 | Arguments: 158 | x: a float tensor with shape [batch_size, 3, h, w]. 159 | Returns: 160 | c: a float tensor with shape [batch_size, 10]. 161 | b: a float tensor with shape [batch_size, 4]. 162 | a: a float tensor with shape [batch_size, 2]. 163 | """ 164 | x = self.features(x) 165 | a = self.conv6_1(x) 166 | b = self.conv6_2(x) 167 | c = self.conv6_3(x) 168 | a = F.softmax(a) 169 | return c, b, a -------------------------------------------------------------------------------- /data_preprocessing/align/onet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/data_preprocessing/align/onet.npy -------------------------------------------------------------------------------- /data_preprocessing/align/pnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/data_preprocessing/align/pnet.npy -------------------------------------------------------------------------------- /data_preprocessing/align/rnet.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/data_preprocessing/align/rnet.npy -------------------------------------------------------------------------------- /data_preprocessing/align/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageDraw 2 | 3 | 4 | def show_results(img, bounding_boxes, facial_landmarks = []): 5 | """Draw bounding boxes and facial landmarks. 6 | Arguments: 7 | img: an instance of PIL.Image. 8 | bounding_boxes: a float numpy array of shape [n, 5]. 9 | facial_landmarks: a float numpy array of shape [n, 10]. 10 | Returns: 11 | an instance of PIL.Image. 
12 | """ 13 | img_copy = img.copy() 14 | draw = ImageDraw.Draw(img_copy) 15 | 16 | for b in bounding_boxes: 17 | draw.rectangle([ 18 | (b[0], b[1]), (b[2], b[3]) 19 | ], outline = 'white') 20 | 21 | inx = 0 22 | for p in facial_landmarks: 23 | for i in range(5): 24 | draw.ellipse([ 25 | (p[i] - 1.0, p[i + 5] - 1.0), 26 | (p[i] + 1.0, p[i + 5] + 1.0) 27 | ], outline = 'blue') 28 | 29 | return img_copy -------------------------------------------------------------------------------- /data_preprocessing/detection/detcect_faces_portrait.py: -------------------------------------------------------------------------------- 1 | from mtcnn import MTCNN 2 | import cv2 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import sys 7 | 8 | 9 | if __name__ == '__main__': 10 | gpu_num = 1 11 | gpu = sys.argv[1] 12 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu) 13 | mtcnn_path = 'data/portrait/mtcnn' 14 | data_path = 'data/portrait/align' 15 | detector = MTCNN() 16 | results_all = {} 17 | for type in ['source', 'target']: 18 | results_all[type] = {} 19 | img_list = os.listdir(os.path.join(data_path, type)) 20 | img_list.sort() 21 | count = 0 22 | for img_idx, img in enumerate(tqdm(img_list)): 23 | if img_idx % gpu_num != int(gpu): 24 | continue 25 | 26 | count += 1 27 | image = cv2.cvtColor(cv2.imread(os.path.join(data_path, type, img)), cv2.COLOR_BGR2RGB) 28 | result = detector.detect_faces(image) 29 | results_all[type][img] = result 30 | 31 | os.makedirs(mtcnn_path, exist_ok=True) 32 | print(f'gpu {gpu} process {count} images') 33 | json.dump(results_all, open(os.path.join(mtcnn_path, f'mtcnn_{gpu}.json'), 'w'), indent=4) -------------------------------------------------------------------------------- /data_preprocessing/detection/merge_mtcnn_portrait.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | if __name__ == '__main__': 5 | gpu_num = 1 6 | mtcnns = {} 7 | for type in ['source', 'target']: 8 | mtcnns[type] = {} 9 | for i in range(gpu_num): 10 | tmp = json.load(open('data/portrait/mtcnn/mtcnn_{}.json'.format(i), 'r')) 11 | mtcnns[type].update(tmp[type]) 12 | 13 | for i, j in mtcnns[type].items(): 14 | print(type, i) 15 | json.dump(mtcnns, open('data/portrait/mtcnn/mtcnn_256.json', 'w'), indent=4) -------------------------------------------------------------------------------- /data_preprocessing/detection/run_detect_faces_portrait.sh: -------------------------------------------------------------------------------- 1 | python data_preprocessing/detection/detcect_faces_portrait.py 0 -------------------------------------------------------------------------------- /ldm/data/portrait.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import pickle 4 | from PIL import Image 5 | import io 6 | from torchvision import transforms 7 | import cv2 8 | import pdb 9 | import PIL 10 | import numpy as np 11 | from einops import rearrange 12 | from scipy.spatial import ConvexHull 13 | import random 14 | import json 15 | from math import floor 16 | import os 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | class Portrait(Dataset): 21 | def __init__(self, root, size = 256, base_res = 256, flip=False, 22 | interpolation="bicubic", dilate=False, convex_hull=True): 23 | super().__init__() 24 | self.size = size 25 | self.root = root 26 | self.base_res = base_res 27 | self.error_img = json.load(open(f'{root}/error_img.json')) 
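        # error_img.json is produced by data_preprocessing/align/face_align_portrait.py
        # and lists the source/target images for which MTCNN found no face;
        # they are filtered out of src_list / tgt_list below.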
28 | 29 | self.lmk_path = f'{root}/landmark/landmark_256.pkl' 30 | self.landmarks = pickle.load(open(self.lmk_path,'rb')) 31 | 32 | self.src_list = os.listdir(f'{root}/source') 33 | self.src_list = [x for x in self.src_list if x not in self.error_img['source']] 34 | self.src_list.sort() 35 | self.tgt_list = os.listdir(f'{root}/target') 36 | self.tgt_list = [x for x in self.tgt_list if x not in self.error_img['target']] 37 | self.tgt_list.sort() 38 | 39 | print(f'len(self.src_list): {len(self.src_list)}') 40 | self.affine_thetas = json.load(open(f'{root}/affine_theta.json')) 41 | 42 | self.interpolation = {"linear": PIL.Image.LINEAR, 43 | "bilinear": PIL.Image.BILINEAR, 44 | "bicubic": PIL.Image.BICUBIC, 45 | "lanczos": PIL.Image.LANCZOS, 46 | }[interpolation] 47 | self.convex_hull = convex_hull 48 | 49 | 50 | all_indices = np.arange(0, 68) 51 | self.landmark_indices = { 52 | # 'face': all_indices[:17].tolist() + all_indices[17:27].tolist(), 53 | 'l_eye': all_indices[36:42].tolist(), 54 | 'r_eye': all_indices[42:48].tolist(), 55 | 'nose': all_indices[27:36].tolist(), 56 | 'mouth': all_indices[48:68].tolist(), 57 | } 58 | self.dilate = dilate 59 | if dilate: 60 | self.dilate_kernel = np.ones((11, 11), np.uint8) 61 | 62 | def __len__(self): 63 | return len(self.src_list) * len(self.tgt_list) # 9 * 1039 = 9351 64 | 65 | def __getitem__(self, index): # index: 0 - 9350 66 | src_index = floor(index / len(self.tgt_list)) # 0 - 8 67 | tgt_index = index - len(self.tgt_list) * src_index # 0 - 1038 68 | batch = {} 69 | 70 | for type in ['source', 'target']: 71 | if type == 'source': 72 | image = Image.open(os.path.join(f'{self.root}/align/{type}', self.src_list[src_index])).convert('RGB') 73 | affine_theta = np.array(self.affine_thetas[type][self.src_list[src_index]], dtype=np.float32) 74 | landmark = torch.tensor(self.landmarks[type][self.src_list[src_index]]) / self.base_res 75 | elif type == 'target': 76 | image = Image.open(os.path.join(f'{self.root}/align/{type}', self.tgt_list[tgt_index])).convert('RGB') 77 | affine_theta = np.array(self.affine_thetas[type][self.tgt_list[tgt_index]], dtype=np.float32) 78 | landmark = torch.tensor(self.landmarks[type][self.tgt_list[tgt_index]]) / self.base_res 79 | 80 | image = image.resize((self.size, self.size), resample=self.interpolation) 81 | image = np.array(image).astype(np.uint8) 82 | image = (image / 127.5 - 1.0).astype(np.float32) 83 | 84 | if type == 'source': 85 | batch['mask_organ_src'] = self.mask_organ_src(landmark) 86 | batch['image_src'] = image 87 | # batch['image_src']'s identity 88 | batch['affine_theta_src'] = affine_theta 89 | batch['src'] = self.src_list[src_index] 90 | else: 91 | batch['image'] = image 92 | # batch['image']'s identity 93 | batch['landmark'] = landmark 94 | batch['affine_theta'] = affine_theta 95 | batch['target'] = self.tgt_list[tgt_index] 96 | 97 | if self.convex_hull: 98 | mask_dict = self.extract_convex_hulls(batch['landmark']) 99 | batch.update(mask_dict) 100 | return batch 101 | 102 | def mask_organ_src(self, landmark): 103 | mask_organ = [] 104 | for key, indices in self.landmark_indices.items(): 105 | mask_key = self.extract_convex_hull(landmark[indices]) 106 | if self.dilate: 107 | # mask_key = mask_key[:, :, None] 108 | # mask_key = repeat(mask_key, 'h w -> h w k', k=3) 109 | # print(mask_key.shape, type(mask_key)) 110 | mask_key = mask_key.astype(np.uint8) 111 | mask_key = cv2.dilate(mask_key, self.dilate_kernel, iterations=1) 112 | mask_organ.append(mask_key) 113 | return np.stack(mask_organ) 114 | 115 | 116 | 
def extract_convex_hulls(self, landmark): 117 | mask_dict = {} 118 | mask_organ = [] 119 | for key, indices in self.landmark_indices.items(): 120 | mask_key = self.extract_convex_hull(landmark[indices]) 121 | if self.dilate: 122 | # mask_key = mask_key[:, :, None] 123 | # mask_key = repeat(mask_key, 'h w -> h w k', k=3) 124 | # print(mask_key.shape, type(mask_key)) 125 | mask_key = mask_key.astype(np.uint8) 126 | mask_key = cv2.dilate(mask_key, self.dilate_kernel, iterations=1) 127 | mask_organ.append(mask_key) 128 | mask_organ = np.stack(mask_organ) # (4, 256, 256) 129 | mask_dict['mask_organ'] = mask_organ 130 | mask_dict['mask'] = self.extract_convex_hull(landmark) 131 | return mask_dict 132 | 133 | def extract_convex_hull(self, landmark): 134 | landmark = landmark * self.size 135 | hull = ConvexHull(landmark) 136 | image = np.zeros((self.size, self.size)) 137 | points = [landmark[hull.vertices, :1], landmark[hull.vertices, 1:]] 138 | points = np.concatenate(points, axis=-1).astype('int32') 139 | mask = cv2.fillPoly(image, pts=[points], color=(255,255,255)) 140 | mask = mask > 0 141 | return mask 142 | 143 | def visualize(batch): 144 | n = len(batch['image']) 145 | os.makedirs('ldm/data/debug', exist_ok=True) 146 | for i in range(n): 147 | print(i) 148 | print(batch['src'][i], batch['target'][i]) 149 | image = (batch['image'][i] + 1) / 2 * 255 150 | image = rearrange(image, 'h w c -> c h w') 151 | image = image.numpy().transpose(1, 2, 0).astype('uint8').copy() 152 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 153 | cv2.imwrite(f'ldm/data/debug/{i}_tgt.png', image) 154 | 155 | image_src = (batch['image_src'][i] + 1) / 2 * 255 156 | image_src = rearrange(image_src, 'h w c -> c h w') 157 | image_src = image_src.numpy().transpose(1, 2, 0).astype('uint8').copy() 158 | image_src = cv2.cvtColor(image_src, cv2.COLOR_RGB2BGR) 159 | cv2.imwrite(f'ldm/data/debug/{i}_src.png', image_src) 160 | 161 | lmk = (batch['landmark'][i] * image.shape[0]).numpy().astype('int32') 162 | for k in range(68): 163 | image = cv2.circle(image, (lmk[k, 0], lmk[k, 1]), 3, (255, 0, 255), thickness=-1) 164 | cv2.imwrite(f'ldm/data/debug/{i}_lmk.png', image) 165 | 166 | mask = (batch['mask'][i].numpy() * 255).astype('uint8') #[:, :, None] 167 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) 168 | cv2.imwrite(f'ldm/data/debug/{i}_mask.png', mask) 169 | 170 | mask_organ = (batch['mask_organ'][i][0].numpy() * 255).astype('uint8') #[:, :, None] 171 | mask_organ = cv2.cvtColor(mask_organ, cv2.COLOR_GRAY2BGR) 172 | cv2.imwrite(f'ldm/data/debug/{i}_mask_organ.png', mask_organ) 173 | 174 | mask_organ_src = (batch['mask_organ_src'][i][1].numpy() * 255).astype('uint8') #[:, :, None] 175 | mask_organ_src = cv2.cvtColor(mask_organ_src, cv2.COLOR_GRAY2BGR) 176 | cv2.imwrite(f'ldm/data/debug/{i}_mask_organ_src.png', mask_organ_src) 177 | 178 | if __name__ == '__main__': 179 | dataset = Portrait('data/portrait') 180 | dataloader = DataLoader(dataset, batch_size=8, shuffle=True) 181 | 182 | for batch in dataloader: 183 | visualize(batch) 184 | break -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
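    # Closed form computed below, with logvar = log(sigma^2):
    #   KL(N(mu1, s1^2) || N(mu2, s2^2))
    #     = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
    #              + (mu1 - mu2)^2 * exp(-logvar2) - 1)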
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
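        A typical validation loop looks like (sketch):
            ema.store(model.parameters())
            ema.copy_to(model)  # evaluate / save checkpoints with EMA weights
            ema.restore(model.parameters())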
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/face_embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from src.arcface_torch.backbones import get_model 6 | from src.arcface_torch.utils.utils_config import get_config 7 | 8 | from einops import rearrange 9 | from ldm.modules.attention import CrossAttention 10 | 11 | from ldm.models.diffusion.ddpm import disabled_train 12 | 13 | 14 | # including l_eye r_eye nose mouse 15 | class FaceEmbedder(nn.Module): 16 | def __init__(self, lmk_dim=128, keys=None, pair=False, comb_mode='concat', merge_eyes=False, \ 17 | attention=False, face_model='r50', face_dataset='glint360k', affine_crop=False, use_blur=False): 18 | super().__init__() 19 | self.pair = pair 20 | cfg = get_config(f'src/arcface_torch/configs/{face_dataset}_{face_model}.py') 21 | self.face_model = get_model(cfg.network, dropout=0.0, 22 | fp16=cfg.fp16, num_features=cfg.embedding_size) 23 | ckpt_path = f'checkpoints/{face_dataset}_{face_model}.pth' 24 | a, b = self.face_model.load_state_dict(torch.load(ckpt_path, map_location='cpu'), strict=False) 25 | print('loading face model:', a, b) 26 | print('build model:', face_dataset, face_model, ckpt_path) 27 | 28 | self.landmark_encoder = nn.Sequential( 29 | nn.Linear(68 * 2, lmk_dim * 4), 30 | nn.GELU(), 31 | nn.Linear(lmk_dim * 4, lmk_dim) 32 | ) 33 | assert 'image' in keys and 'landmark' in keys 34 | self.swap = False 35 | self.comb_mode = comb_mode 36 | if lmk_dim != 512 and self.comb_mode == 'stack': 37 | self.id_fc = nn.Sequential( 38 | nn.Linear(512, 2 * lmk_dim), 39 | nn.GELU(), 40 | nn.Linear(2 * lmk_dim, lmk_dim) 41 | ) 42 | else: 43 | self.id_fc = None 44 | 45 | self.organ_fc = nn.Sequential( 46 | nn.Linear(3, lmk_dim), 47 | nn.GELU(), 48 | nn.Linear(lmk_dim, lmk_dim) 49 | ) 50 | 51 | if attention: 52 | self.organ_norm = nn.LayerNorm(lmk_dim) 53 | self.organ_attention = CrossAttention( 54 | query_dim=lmk_dim, 55 | context_dim=lmk_dim, 56 | dim_head=32, 57 | heads=4, 58 | ) 59 | print('[FaceEmbedder]: building organ attention') 60 | else: 61 | self.organ_attention = None 62 | self.organ_keys = ['l_eye', 'r_eye', 'mouth', 'nose'] 63 | self.merge_eyes = merge_eyes 64 | self.affine_crop = affine_crop 65 | self.face_model.eval() 66 | self.face_model.train = disabled_train 67 | self.use_blur = use_blur 68 | 69 | def encode_face(self, image): 70 | # image: (b c h w) 71 | id_feat = self.face_model(image) 72 | return id_feat 73 | 74 | def extract_organ_feats(self, z, mask_organ): 75 | if self.merge_eyes: 76 | mask_organ_ = mask_organ[:, 1:] 77 | mask_organ_[:, 0] = torch.logical_or(mask_organ[:, 0], mask_organ[:, 1]) 78 | mask_organ = mask_organ_ 79 | else: 80 | pass 81 | 82 | h, w = z.shape[-2:] 83 | mask_organ = F.interpolate(mask_organ.float(), (h, w), mode='nearest') 84 | sum1 = torch.einsum('bchw,bkhw->bkc', z, mask_organ) 85 | sum2 = mask_organ.sum(dim=(-1, -2))[..., None] 86 | return sum1 / (sum2 + 1e-6) 87 | 88 | def affine_crop_face(self, img, affine_theta): 89 | grid = F.affine_grid(affine_theta, size=(img.size(0), 3, 112, 112)) 90 | img = F.grid_sample(img, grid) 91 | return img 92 | 93 | def forward(self, cond, swap=False): 94 | swap = swap or self.swap 95 | 96 | image, landmark = cond['image'], cond['landmark'] 97 | z = cond['z'] 98 
| 99 | if swap: 100 | assert 'image_src' in cond 101 | image = cond['image_src'] 102 | z = cond['z_src'] 103 | mask_organ = cond['mask_organ_src'] 104 | affine_theta = cond['affine_theta_src'] 105 | else: 106 | mask_organ = cond['mask_organ'] 107 | affine_theta = cond['affine_theta'] 108 | 109 | 110 | # extract organ features using z 111 | organ_feats = self.extract_organ_feats(z, mask_organ) 112 | cond['organ_feat'] = organ_feats 113 | organ_feats = self.organ_fc(organ_feats) 114 | if self.organ_attention is not None: 115 | organ_feats = organ_feats + self.organ_attention(self.organ_norm(organ_feats)) 116 | 117 | B = image.size(0) 118 | 119 | with torch.no_grad(): 120 | image = rearrange(image, 'b h w c -> b c h w') 121 | if self.affine_crop: 122 | image = self.affine_crop_face(image, affine_theta) 123 | else: 124 | image = F.interpolate(image, (112, 112), mode='bicubic') 125 | id_feat = self.encode_face(image) 126 | 127 | cond['id_feat'] = id_feat 128 | 129 | if self.id_fc: 130 | id_feat = self.id_fc(id_feat) 131 | 132 | lmk_feat = self.landmark_encoder(landmark.reshape(B, -1)) 133 | # print(self.landmark_encoder[0].weight.sum()) 134 | if self.comb_mode == 'concat': 135 | raise NotImplementedError() 136 | out = torch.cat([id_feat, lmk_feat], dim=1)[:, None] 137 | elif self.comb_mode == 'stack': 138 | out = torch.stack([lmk_feat, id_feat], dim=1) 139 | out = torch.cat([out, organ_feats], dim=1) 140 | else: 141 | raise NotImplementedError() 142 | # if self.training: 143 | 144 | if self.use_blur: 145 | return { 146 | 'c_concat': cond['z_blur'], 147 | 'c_crossattn': out 148 | } 149 | else: 150 | return out 151 | 152 | def parameters(self): 153 | params = list(self.landmark_encoder.parameters()) 154 | if self.id_fc: 155 | params += list(self.id_fc.parameters()) 156 | params += list(self.organ_fc.parameters()) 157 | if self.organ_attention: 158 | params += list(self.organ_attention.parameters()) 159 | params += list(self.organ_norm.parameters()) 160 | return params 161 | 162 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | import cv2 16 | from einops import rearrange 17 | 18 | 19 | def log_txt_as_img(wh, xc, size=10): 20 | # wh a tuple of (width, height) 21 | # xc a list of captions to plot 22 | b = len(xc) 23 | txts = list() 24 | for bi in range(b): 25 | txt = Image.new("RGB", wh, color="white") 26 | draw = ImageDraw.Draw(txt) 27 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 28 | nc = int(40 * (wh[0] / 256)) 29 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 30 | 31 | try: 32 | draw.text((0, 0), lines, fill="black", font=font) 33 | except UnicodeEncodeError: 34 | print("Cant encode string for logging. 
Skipping.") 35 | 36 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 37 | txts.append(txt) 38 | txts = np.stack(txts) 39 | txts = torch.tensor(txts) 40 | return txts 41 | 42 | def log_lmk_and_img(image, landmark): 43 | n = len(image) 44 | images = [] 45 | images_with_lmk = [] 46 | for i in range(n): 47 | img = (image[i] + 1) / 2 * 255 48 | img = rearrange(img, 'h w c -> c h w') 49 | img = img.cpu().numpy().transpose(1, 2, 0).astype('uint8').copy() 50 | images.append(img.transpose(2, 0, 1) / 127.5 - 1.0) 51 | 52 | lmk = (landmark[i] * img.shape[0]).cpu().numpy().astype('int32') 53 | 54 | # pdb.set_trace() 55 | for k in range(68): 56 | img = cv2.circle(img, (lmk[k, 0], lmk[k, 1]), 3, (255, 0, 255), thickness=-1) 57 | # pdb.set_trace() 58 | images_with_lmk.append(img.transpose(2, 0, 1) / 127.5 - 1.0) 59 | return torch.tensor(images), torch.tensor(images_with_lmk) 60 | 61 | 62 | def ismap(x): 63 | if not isinstance(x, torch.Tensor): 64 | return False 65 | return (len(x.shape) == 4) and (x.shape[1] > 3) 66 | 67 | 68 | def isimage(x): 69 | if not isinstance(x, torch.Tensor): 70 | return False 71 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 72 | 73 | 74 | def exists(x): 75 | return x is not None 76 | 77 | 78 | def default(val, d): 79 | if exists(val): 80 | return val 81 | return d() if isfunction(d) else d 82 | 83 | 84 | def mean_flat(tensor): 85 | """ 86 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 87 | Take the mean over all non-batch dimensions. 88 | """ 89 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 90 | 91 | 92 | def count_params(model, verbose=False): 93 | total_params = sum(p.numel() for p in model.parameters()) 94 | if verbose: 95 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 96 | return total_params 97 | 98 | 99 | def instantiate_from_config(config): 100 | if not "target" in config: 101 | if config == '__is_first_stage__': 102 | return None 103 | elif config == "__is_unconditional__": 104 | return None 105 | raise KeyError("Expected key `target` to instantiate.") 106 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 107 | 108 | 109 | def get_obj_from_str(string, reload=False): 110 | module, cls = string.rsplit(".", 1) 111 | if reload: 112 | module_imp = importlib.import_module(module) 113 | importlib.reload(module_imp) 114 | return getattr(importlib.import_module(module, package=None), cls) 115 | 116 | 117 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 118 | # create dummy dataset instance 119 | 120 | # run prefetching 121 | if idx_to_fn: 122 | res = func(data, worker_id=idx) 123 | else: 124 | res = func(data) 125 | Q.put([idx, res]) 126 | Q.put("Done") 127 | 128 | 129 | def parallel_data_prefetch( 130 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 131 | ): 132 | # if target_data_type not in ["ndarray", "list"]: 133 | # raise ValueError( 134 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 135 | # ) 136 | if isinstance(data, np.ndarray) and target_data_type == "list": 137 | raise ValueError("list expected but function got ndarray.") 138 | elif isinstance(data, abc.Iterable): 139 | if isinstance(data, dict): 140 | print( 141 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 
142 | ) 143 | data = list(data.values()) 144 | if target_data_type == "ndarray": 145 | data = np.asarray(data) 146 | else: 147 | data = list(data) 148 | else: 149 | raise TypeError( 150 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 151 | ) 152 | 153 | if cpu_intensive: 154 | Q = mp.Queue(1000) 155 | proc = mp.Process 156 | else: 157 | Q = Queue(1000) 158 | proc = Thread 159 | # spawn processes 160 | if target_data_type == "ndarray": 161 | arguments = [ 162 | [func, Q, part, i, use_worker_id] 163 | for i, part in enumerate(np.array_split(data, n_proc)) 164 | ] 165 | else: 166 | step = ( 167 | int(len(data) / n_proc + 1) 168 | if len(data) % n_proc != 0 169 | else int(len(data) / n_proc) 170 | ) 171 | arguments = [ 172 | [func, Q, part, i, use_worker_id] 173 | for i, part in enumerate( 174 | [data[i: i + step] for i in range(0, len(data), step)] 175 | ) 176 | ] 177 | processes = [] 178 | for i in range(n_proc): 179 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 180 | processes += [p] 181 | 182 | # start processes 183 | print(f"Start prefetching...") 184 | import time 185 | 186 | start = time.time() 187 | gather_res = [[] for _ in range(n_proc)] 188 | try: 189 | for p in processes: 190 | p.start() 191 | 192 | k = 0 193 | while k < n_proc: 194 | # get result 195 | res = Q.get() 196 | if res == "Done": 197 | k += 1 198 | else: 199 | gather_res[res[0]] = res[1] 200 | 201 | except Exception as e: 202 | print("Exception: ", e) 203 | for p in processes: 204 | p.terminate() 205 | 206 | raise e 207 | finally: 208 | for p in processes: 209 | p.join() 210 | print(f"Prefetching complete. [{time.time() - start} sec.]") 211 | 212 | if target_data_type == 'ndarray': 213 | if not isinstance(gather_res[0], np.ndarray): 214 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 215 | 216 | # order outputs 217 | return np.concatenate(gather_res, axis=0) 218 | elif target_data_type == 'list': 219 | out = [] 220 | for r in gather_res: 221 | out.extend(r) 222 | return out 223 | else: 224 | return gather_res 225 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mtcnn==0.1.1 2 | tensorflow==2.13.0 3 | openai-clip==1.0.1 4 | easydict==1.10 5 | kornia==0.7.0 6 | dlib==19.24.2 7 | pytorch_lightning==1.4.2 8 | torchmetrics==0.6.0 9 | omegaconf==2.3.0 10 | einops==0.6.1 -------------------------------------------------------------------------------- /src/arcface_torch/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200 2 | from .mobilefacenet import get_mbf 3 | 4 | 5 | def get_model(name, **kwargs): 6 | # resnet 7 | if name == "r18": 8 | return iresnet18(False, **kwargs) 9 | elif name == "r34": 10 | return iresnet34(False, **kwargs) 11 | elif name == "r50": 12 | return iresnet50(False, **kwargs) 13 | elif name == "r100": 14 | return iresnet100(False, **kwargs) 15 | elif name == "r200": 16 | return iresnet200(False, **kwargs) 17 | elif name == "r2060": 18 | from .iresnet2060 import iresnet2060 19 | return iresnet2060(False, **kwargs) 20 | 21 | elif name == "mbf": 22 | fp16 = kwargs.get("fp16", False) 23 | num_features = kwargs.get("num_features", 512) 24 | return get_mbf(fp16=fp16, num_features=num_features) 25 | 26 | elif name == "mbf_large": 27 
| from .mobilefacenet import get_mbf_large 28 | fp16 = kwargs.get("fp16", False) 29 | num_features = kwargs.get("num_features", 512) 30 | return get_mbf_large(fp16=fp16, num_features=num_features) 31 | 32 | elif name == "vit_t": 33 | num_features = kwargs.get("num_features", 512) 34 | from .vit import VisionTransformer 35 | return VisionTransformer( 36 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 37 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 38 | 39 | elif name == "vit_t_dp005_mask0": # For WebFace42M 40 | num_features = kwargs.get("num_features", 512) 41 | from .vit import VisionTransformer 42 | return VisionTransformer( 43 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=256, depth=12, 44 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 45 | 46 | elif name == "vit_s": 47 | num_features = kwargs.get("num_features", 512) 48 | from .vit import VisionTransformer 49 | return VisionTransformer( 50 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 51 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1) 52 | 53 | elif name == "vit_s_dp005_mask_0": # For WebFace42M 54 | num_features = kwargs.get("num_features", 512) 55 | from .vit import VisionTransformer 56 | return VisionTransformer( 57 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=12, 58 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.0) 59 | 60 | elif name == "vit_b": 61 | # this is a feature 62 | num_features = kwargs.get("num_features", 512) 63 | from .vit import VisionTransformer 64 | return VisionTransformer( 65 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 66 | num_heads=8, drop_path_rate=0.1, norm_layer="ln", mask_ratio=0.1, using_checkpoint=True) 67 | 68 | elif name == "vit_b_dp005_mask_005": # For WebFace42M 69 | # this is a feature 70 | num_features = kwargs.get("num_features", 512) 71 | from .vit import VisionTransformer 72 | return VisionTransformer( 73 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=512, depth=24, 74 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 75 | 76 | elif name == "vit_l_dp005_mask_005": # For WebFace42M 77 | # this is a feature 78 | num_features = kwargs.get("num_features", 512) 79 | from .vit import VisionTransformer 80 | return VisionTransformer( 81 | img_size=112, patch_size=9, num_classes=num_features, embed_dim=768, depth=24, 82 | num_heads=8, drop_path_rate=0.05, norm_layer="ln", mask_ratio=0.05, using_checkpoint=True) 83 | 84 | else: 85 | raise ValueError() 86 | -------------------------------------------------------------------------------- /src/arcface_torch/backbones/iresnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.utils.checkpoint import checkpoint 4 | 5 | __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] 6 | using_ckpt = False 7 | 8 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 9 | """3x3 convolution with padding""" 10 | return nn.Conv2d(in_planes, 11 | out_planes, 12 | kernel_size=3, 13 | stride=stride, 14 | padding=dilation, 15 | groups=groups, 16 | bias=False, 17 | dilation=dilation) 18 | 19 | 20 | def conv1x1(in_planes, out_planes, stride=1): 21 | """1x1 convolution""" 22 | return nn.Conv2d(in_planes, 23 | out_planes, 24 | kernel_size=1, 25 | 
stride=stride, 26 | bias=False) 27 | 28 | 29 | class IBasicBlock(nn.Module): 30 | expansion = 1 31 | def __init__(self, inplanes, planes, stride=1, downsample=None, 32 | groups=1, base_width=64, dilation=1): 33 | super(IBasicBlock, self).__init__() 34 | if groups != 1 or base_width != 64: 35 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 36 | if dilation > 1: 37 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 38 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) 39 | self.conv1 = conv3x3(inplanes, planes) 40 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) 41 | self.prelu = nn.PReLU(planes) 42 | self.conv2 = conv3x3(planes, planes, stride) 43 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward_impl(self, x): 48 | identity = x 49 | out = self.bn1(x) 50 | out = self.conv1(out) 51 | out = self.bn2(out) 52 | out = self.prelu(out) 53 | out = self.conv2(out) 54 | out = self.bn3(out) 55 | if self.downsample is not None: 56 | identity = self.downsample(x) 57 | out += identity 58 | return out 59 | 60 | def forward(self, x): 61 | if self.training and using_ckpt: 62 | return checkpoint(self.forward_impl, x) 63 | else: 64 | return self.forward_impl(x) 65 | 66 | 67 | class IResNet(nn.Module): 68 | fc_scale = 7 * 7 69 | def __init__(self, 70 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 71 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 72 | super(IResNet, self).__init__() 73 | self.extra_gflops = 0.0 74 | self.fp16 = fp16 75 | self.inplanes = 64 76 | self.dilation = 1 77 | if replace_stride_with_dilation is None: 78 | replace_stride_with_dilation = [False, False, False] 79 | if len(replace_stride_with_dilation) != 3: 80 | raise ValueError("replace_stride_with_dilation should be None " 81 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 82 | self.groups = groups 83 | self.base_width = width_per_group 84 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 85 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 86 | self.prelu = nn.PReLU(self.inplanes) 87 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 88 | self.layer2 = self._make_layer(block, 89 | 128, 90 | layers[1], 91 | stride=2, 92 | dilate=replace_stride_with_dilation[0]) 93 | self.layer3 = self._make_layer(block, 94 | 256, 95 | layers[2], 96 | stride=2, 97 | dilate=replace_stride_with_dilation[1]) 98 | self.layer4 = self._make_layer(block, 99 | 512, 100 | layers[3], 101 | stride=2, 102 | dilate=replace_stride_with_dilation[2]) 103 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) 104 | self.dropout = nn.Dropout(p=dropout, inplace=True) 105 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 106 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 107 | nn.init.constant_(self.features.weight, 1.0) 108 | self.features.weight.requires_grad = False 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | nn.init.normal_(m.weight, 0, 0.1) 113 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 114 | nn.init.constant_(m.weight, 1) 115 | nn.init.constant_(m.bias, 0) 116 | 117 | if zero_init_residual: 118 | for m in self.modules(): 119 | if isinstance(m, IBasicBlock): 120 | nn.init.constant_(m.bn2.weight, 0) 121 | 122 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 123 | downsample = None 124 | 
previous_dilation = self.dilation 125 | if dilate: 126 | self.dilation *= stride 127 | stride = 1 128 | if stride != 1 or self.inplanes != planes * block.expansion: 129 | downsample = nn.Sequential( 130 | conv1x1(self.inplanes, planes * block.expansion, stride), 131 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 132 | ) 133 | layers = [] 134 | layers.append( 135 | block(self.inplanes, planes, stride, downsample, self.groups, 136 | self.base_width, previous_dilation)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append( 140 | block(self.inplanes, 141 | planes, 142 | groups=self.groups, 143 | base_width=self.base_width, 144 | dilation=self.dilation)) 145 | 146 | return nn.Sequential(*layers) 147 | 148 | def forward(self, x): 149 | with torch.cuda.amp.autocast(self.fp16): 150 | x = self.conv1(x) 151 | x = self.bn1(x) 152 | x = self.prelu(x) 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | x = self.bn2(x) 158 | x = torch.flatten(x, 1) 159 | x = self.dropout(x) 160 | x = self.fc(x.float() if self.fp16 else x) 161 | x = self.features(x) 162 | return x 163 | 164 | 165 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 166 | model = IResNet(block, layers, **kwargs) 167 | if pretrained: 168 | raise ValueError() 169 | return model 170 | 171 | 172 | def iresnet18(pretrained=False, progress=True, **kwargs): 173 | return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, 174 | progress, **kwargs) 175 | 176 | 177 | def iresnet34(pretrained=False, progress=True, **kwargs): 178 | return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, 179 | progress, **kwargs) 180 | 181 | 182 | def iresnet50(pretrained=False, progress=True, **kwargs): 183 | return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, 184 | progress, **kwargs) 185 | 186 | 187 | def iresnet100(pretrained=False, progress=True, **kwargs): 188 | return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, 189 | progress, **kwargs) 190 | 191 | 192 | def iresnet200(pretrained=False, progress=True, **kwargs): 193 | return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, 194 | progress, **kwargs) 195 | -------------------------------------------------------------------------------- /src/arcface_torch/backbones/iresnet2060.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | assert torch.__version__ >= "1.8.1" 5 | from torch.utils.checkpoint import checkpoint_sequential 6 | 7 | __all__ = ['iresnet2060'] 8 | 9 | 10 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 11 | """3x3 convolution with padding""" 12 | return nn.Conv2d(in_planes, 13 | out_planes, 14 | kernel_size=3, 15 | stride=stride, 16 | padding=dilation, 17 | groups=groups, 18 | bias=False, 19 | dilation=dilation) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, 25 | out_planes, 26 | kernel_size=1, 27 | stride=stride, 28 | bias=False) 29 | 30 | 31 | class IBasicBlock(nn.Module): 32 | expansion = 1 33 | 34 | def __init__(self, inplanes, planes, stride=1, downsample=None, 35 | groups=1, base_width=64, dilation=1): 36 | super(IBasicBlock, self).__init__() 37 | if groups != 1 or base_width != 64: 38 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 39 | if dilation > 1: 40 | raise NotImplementedError("Dilation > 1 
not supported in BasicBlock") 41 | self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, ) 42 | self.conv1 = conv3x3(inplanes, planes) 43 | self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, ) 44 | self.prelu = nn.PReLU(planes) 45 | self.conv2 = conv3x3(planes, planes, stride) 46 | self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, ) 47 | self.downsample = downsample 48 | self.stride = stride 49 | 50 | def forward(self, x): 51 | identity = x 52 | out = self.bn1(x) 53 | out = self.conv1(out) 54 | out = self.bn2(out) 55 | out = self.prelu(out) 56 | out = self.conv2(out) 57 | out = self.bn3(out) 58 | if self.downsample is not None: 59 | identity = self.downsample(x) 60 | out += identity 61 | return out 62 | 63 | 64 | class IResNet(nn.Module): 65 | fc_scale = 7 * 7 66 | 67 | def __init__(self, 68 | block, layers, dropout=0, num_features=512, zero_init_residual=False, 69 | groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): 70 | super(IResNet, self).__init__() 71 | self.fp16 = fp16 72 | self.inplanes = 64 73 | self.dilation = 1 74 | if replace_stride_with_dilation is None: 75 | replace_stride_with_dilation = [False, False, False] 76 | if len(replace_stride_with_dilation) != 3: 77 | raise ValueError("replace_stride_with_dilation should be None " 78 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 79 | self.groups = groups 80 | self.base_width = width_per_group 81 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) 82 | self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) 83 | self.prelu = nn.PReLU(self.inplanes) 84 | self.layer1 = self._make_layer(block, 64, layers[0], stride=2) 85 | self.layer2 = self._make_layer(block, 86 | 128, 87 | layers[1], 88 | stride=2, 89 | dilate=replace_stride_with_dilation[0]) 90 | self.layer3 = self._make_layer(block, 91 | 256, 92 | layers[2], 93 | stride=2, 94 | dilate=replace_stride_with_dilation[1]) 95 | self.layer4 = self._make_layer(block, 96 | 512, 97 | layers[3], 98 | stride=2, 99 | dilate=replace_stride_with_dilation[2]) 100 | self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, ) 101 | self.dropout = nn.Dropout(p=dropout, inplace=True) 102 | self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) 103 | self.features = nn.BatchNorm1d(num_features, eps=1e-05) 104 | nn.init.constant_(self.features.weight, 1.0) 105 | self.features.weight.requires_grad = False 106 | 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.normal_(m.weight, 0, 0.1) 110 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 111 | nn.init.constant_(m.weight, 1) 112 | nn.init.constant_(m.bias, 0) 113 | 114 | if zero_init_residual: 115 | for m in self.modules(): 116 | if isinstance(m, IBasicBlock): 117 | nn.init.constant_(m.bn2.weight, 0) 118 | 119 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 120 | downsample = None 121 | previous_dilation = self.dilation 122 | if dilate: 123 | self.dilation *= stride 124 | stride = 1 125 | if stride != 1 or self.inplanes != planes * block.expansion: 126 | downsample = nn.Sequential( 127 | conv1x1(self.inplanes, planes * block.expansion, stride), 128 | nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), 129 | ) 130 | layers = [] 131 | layers.append( 132 | block(self.inplanes, planes, stride, downsample, self.groups, 133 | self.base_width, previous_dilation)) 134 | self.inplanes = planes * block.expansion 135 | for _ in range(1, blocks): 136 | layers.append( 137 | block(self.inplanes, 138 | planes, 139 | 
groups=self.groups, 140 | base_width=self.base_width, 141 | dilation=self.dilation)) 142 | 143 | return nn.Sequential(*layers) 144 | 145 | def checkpoint(self, func, num_seg, x): 146 | if self.training: 147 | return checkpoint_sequential(func, num_seg, x) 148 | else: 149 | return func(x) 150 | 151 | def forward(self, x): 152 | with torch.cuda.amp.autocast(self.fp16): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.prelu(x) 156 | x = self.layer1(x) 157 | x = self.checkpoint(self.layer2, 20, x) 158 | x = self.checkpoint(self.layer3, 100, x) 159 | x = self.layer4(x) 160 | x = self.bn2(x) 161 | x = torch.flatten(x, 1) 162 | x = self.dropout(x) 163 | x = self.fc(x.float() if self.fp16 else x) 164 | x = self.features(x) 165 | return x 166 | 167 | 168 | def _iresnet(arch, block, layers, pretrained, progress, **kwargs): 169 | model = IResNet(block, layers, **kwargs) 170 | if pretrained: 171 | raise ValueError() 172 | return model 173 | 174 | 175 | def iresnet2060(pretrained=False, progress=True, **kwargs): 176 | return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs) 177 | -------------------------------------------------------------------------------- /src/arcface_torch/backbones/mobilefacenet.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py 3 | Original author cavalleria 4 | ''' 5 | 6 | import torch.nn as nn 7 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module 8 | import torch 9 | 10 | 11 | class Flatten(Module): 12 | def forward(self, x): 13 | return x.view(x.size(0), -1) 14 | 15 | 16 | class ConvBlock(Module): 17 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 18 | super(ConvBlock, self).__init__() 19 | self.layers = nn.Sequential( 20 | Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False), 21 | BatchNorm2d(num_features=out_c), 22 | PReLU(num_parameters=out_c) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.layers(x) 27 | 28 | 29 | class LinearBlock(Module): 30 | def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): 31 | super(LinearBlock, self).__init__() 32 | self.layers = nn.Sequential( 33 | Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False), 34 | BatchNorm2d(num_features=out_c) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.layers(x) 39 | 40 | 41 | class DepthWise(Module): 42 | def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1): 43 | super(DepthWise, self).__init__() 44 | self.residual = residual 45 | self.layers = nn.Sequential( 46 | ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)), 47 | ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride), 48 | LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1)) 49 | ) 50 | 51 | def forward(self, x): 52 | short_cut = None 53 | if self.residual: 54 | short_cut = x 55 | x = self.layers(x) 56 | if self.residual: 57 | output = short_cut + x 58 | else: 59 | output = x 60 | return output 61 | 62 | 63 | class Residual(Module): 64 | def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)): 65 | super(Residual, self).__init__() 66 | modules = [] 67 | for _ in range(num_block): 68 | modules.append(DepthWise(c, c, 
True, kernel, stride, padding, groups)) 69 | self.layers = Sequential(*modules) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class GDC(Module): 76 | def __init__(self, embedding_size): 77 | super(GDC, self).__init__() 78 | self.layers = nn.Sequential( 79 | LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)), 80 | Flatten(), 81 | Linear(512, embedding_size, bias=False), 82 | BatchNorm1d(embedding_size)) 83 | 84 | def forward(self, x): 85 | return self.layers(x) 86 | 87 | 88 | class MobileFaceNet(Module): 89 | def __init__(self, fp16=False, num_features=512, blocks=(1, 4, 6, 2), scale=2): 90 | super(MobileFaceNet, self).__init__() 91 | self.scale = scale 92 | self.fp16 = fp16 93 | self.layers = nn.ModuleList() 94 | self.layers.append( 95 | ConvBlock(3, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)) 96 | ) 97 | if blocks[0] == 1: 98 | self.layers.append( 99 | ConvBlock(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64) 100 | ) 101 | else: 102 | self.layers.append( 103 | Residual(64 * self.scale, num_block=blocks[0], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 104 | ) 105 | 106 | self.layers.extend( 107 | [ 108 | DepthWise(64 * self.scale, 64 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128), 109 | Residual(64 * self.scale, num_block=blocks[1], groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 110 | DepthWise(64 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256), 111 | Residual(128 * self.scale, num_block=blocks[2], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 112 | DepthWise(128 * self.scale, 128 * self.scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512), 113 | Residual(128 * self.scale, num_block=blocks[3], groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)), 114 | ]) 115 | 116 | self.conv_sep = ConvBlock(128 * self.scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0)) 117 | self.features = GDC(num_features) 118 | self._initialize_weights() 119 | 120 | def _initialize_weights(self): 121 | for m in self.modules(): 122 | if isinstance(m, nn.Conv2d): 123 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 124 | if m.bias is not None: 125 | m.bias.data.zero_() 126 | elif isinstance(m, nn.BatchNorm2d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | elif isinstance(m, nn.Linear): 130 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 131 | if m.bias is not None: 132 | m.bias.data.zero_() 133 | 134 | def forward(self, x): 135 | with torch.cuda.amp.autocast(self.fp16): 136 | for func in self.layers: 137 | x = func(x) 138 | x = self.conv_sep(x.float() if self.fp16 else x) 139 | x = self.features(x) 140 | return x 141 | 142 | 143 | def get_mbf(fp16, num_features, blocks=(1, 4, 6, 2), scale=2): 144 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 145 | 146 | def get_mbf_large(fp16, num_features, blocks=(2, 8, 12, 4), scale=4): 147 | return MobileFaceNet(fp16, num_features, blocks, scale=scale) 148 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/3millions.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # configs for test speed 4 | 5 | config = edict() 6 | config.margin_list = (1.0, 0.0, 0.4) 7 | config.network = "mbf" 8 | config.resume = False 9 | 
config.output = None 10 | config.embedding_size = 512 11 | config.sample_rate = 0.1 12 | config.fp16 = True 13 | config.momentum = 0.9 14 | config.weight_decay = 5e-4 15 | config.batch_size = 512 # total_batch_size = batch_size * num_gpus 16 | config.lr = 0.1 # batch size is 512 17 | 18 | config.rec = "synthetic" 19 | config.num_classes = 30 * 10000 20 | config.num_image = 100000 21 | config.num_epoch = 30 22 | config.warmup_epoch = -1 23 | config.val_targets = [] 24 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/src/arcface_torch/configs/__init__.py -------------------------------------------------------------------------------- /src/arcface_torch/configs/base.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | 9 | # Margin Base Softmax 10 | config.margin_list = (1.0, 0.5, 0.0) 11 | config.network = "r50" 12 | config.resume = False 13 | config.save_all_states = False 14 | config.output = "ms1mv3_arcface_r50" 15 | 16 | config.embedding_size = 512 17 | 18 | # Partial FC 19 | config.sample_rate = 1 20 | config.interclass_filtering_threshold = 0 21 | 22 | config.fp16 = False 23 | config.batch_size = 128 24 | 25 | # For SGD 26 | config.optimizer = "sgd" 27 | config.lr = 0.1 28 | config.momentum = 0.9 29 | config.weight_decay = 5e-4 30 | 31 | # For AdamW 32 | # config.optimizer = "adamw" 33 | # config.lr = 0.001 34 | # config.weight_decay = 0.1 35 | 36 | config.verbose = 2000 37 | config.frequent = 10 38 | 39 | # For Large Sacle Dataset, such as WebFace42M 40 | config.dali = False 41 | 42 | # Gradient ACC 43 | config.gradient_acc = 1 44 | 45 | # setup seed 46 | config.seed = 2048 47 | 48 | # dataload numworkers 49 | config.num_workers = 2 50 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/glint360k_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/glint360k_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 
| config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/glint360k_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/glint360k" 23 | config.num_classes = 360232 24 | config.num_image = 17091657 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/ms1mv2_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/ms1mv2_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- 
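Every config under `src/arcface_torch/configs/` follows the same `edict` pattern, and `FaceEmbedder` above resolves one of them purely by filename (`{face_dataset}_{face_model}.py`) before building the matching backbone. Below is a minimal sketch of that loading path on its own, mirroring the calls made in `face_embedder.py`; the `checkpoints/{face_dataset}_{face_model}.pth` location follows the same convention and is an assumption about your local setup.

```
import torch

from src.arcface_torch.backbones import get_model
from src.arcface_torch.utils.utils_config import get_config

# Any dataset/backbone pair with a config file under src/arcface_torch/configs/ works here.
face_dataset, face_model = 'glint360k', 'r100'

# get_config loads the edict defined in the chosen config module.
cfg = get_config(f'src/arcface_torch/configs/{face_dataset}_{face_model}.py')

# get_model dispatches on the network name ("r100" -> iresnet100) and forwards the kwargs.
backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size)

# Assumed local checkpoint path, following the convention used in FaceEmbedder.__init__.
state = torch.load(f'checkpoints/{face_dataset}_{face_model}.pth', map_location='cpu')
missing, unexpected = backbone.load_state_dict(state, strict=False)
backbone.eval()

# The backbone consumes 112x112 aligned faces and returns cfg.embedding_size-dim features.
feats = backbone(torch.randn(1, 3, 112, 112))  # -> torch.Size([1, 512])
```

The same `cfg.network` string is what selects between the iresnet, MobileFaceNet and ViT branches of `get_model`, so switching backbones only requires pointing at a different config file (and the matching checkpoint).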
/src/arcface_torch/configs/ms1mv2_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/faces_emore" 23 | config.num_classes = 85742 24 | config.num_image = 5822653 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/ms1mv3_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 40 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/ms1mv3_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/ms1mv3_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.5, 0.0) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | 
config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/ms1m-retinaface-t1" 23 | config.num_classes = 93431 24 | config.num_image = 5179510 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_conflict_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_conflict_r50_pfc03_filter04.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_Conflict" 24 | config.num_classes = 1017970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_flip_pfc01_filter04_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.1 14 | config.interclass_filtering_threshold = 0.4 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_flip_r50.py: 
-------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M_FLIP40" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | 2 | from easydict import EasyDict as edict 3 | 4 | # make training faster 5 | # our RAM is 256G 6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 7 | 8 | config = edict() 9 | config.margin_list = (1.0, 0.0, 0.4) 10 | config.network = "r100" 11 | config.resume = False 12 | config.output = None 13 | config.embedding_size = 512 14 | config.sample_rate = 0.2 15 | config.interclass_filtering_threshold = 0 16 | config.fp16 = True 17 | config.weight_decay = 5e-4 18 | config.batch_size = 128 19 | config.optimizer = "sgd" 20 | config.lr = 0.1 21 | config.verbose = 2000 22 | config.dali = False 23 | 24 | config.rec = "/train_tmp/WebFace12M" 25 | config.num_classes = 617970 26 | config.num_image = 12720066 27 | config.num_epoch = 20 28 | config.warmup_epoch = 0 29 | config.val_targets = [] 30 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_r100.py: -------------------------------------------------------------------------------- 1 | 2 | from easydict import EasyDict as edict 3 | 4 | # make training faster 5 | # our RAM is 256G 6 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 7 | 8 | config = edict() 9 | config.margin_list = (1.0, 0.0, 0.4) 10 | config.network = "r100" 11 | config.resume = False 12 | config.output = None 13 | config.embedding_size = 512 14 | config.sample_rate = 1.0 15 | config.interclass_filtering_threshold = 0 16 | config.fp16 = True 17 | 
config.weight_decay = 5e-4 18 | config.batch_size = 128 19 | config.optimizer = "sgd" 20 | config.lr = 0.1 21 | config.verbose = 2000 22 | config.dali = False 23 | 24 | config.rec = "/train_tmp/WebFace12M" 25 | config.num_classes = 617970 26 | config.num_image = 12720066 27 | config.num_epoch = 20 28 | config.warmup_epoch = 0 29 | config.val_targets = [] 30 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf12m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.interclass_filtering_threshold = 0 15 | config.fp16 = True 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.optimizer = "sgd" 19 | config.lr = 0.1 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace12M" 24 | config.num_classes = 617970 25 | config.num_image = 12720066 26 | config.num_epoch = 20 27 | config.warmup_epoch = 0 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc0008_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_16gpus_mbf_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_16gpus_r100.py: -------------------------------------------------------------------------------- 1 
| from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 256 18 | config.lr = 0.3 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 1 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_16gpus_r50_bs8k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.6 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 4 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_32gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_8gpus_r50_bs4k.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 512 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = 
"/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 2 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_r100_16gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.2 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc02_r100_32gpus.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.2 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 10000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_32gpu_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o 
size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_32gpu_r18.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r18" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_32gpu_r200.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r200" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_32gpu_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.4 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 
42474557 25 | config.num_epoch = 20 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = ["lfw", "cfp_fp", "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_l.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_l_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_s.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_s_dp005_mask_0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_64gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs 
/train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 384 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_b.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_b_dp005_mask_005" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 256 17 | config.gradient_acc = 12 # total batchsize is 256 * 12 18 | config.optimizer = "adamw" 19 | config.lr = 0.001 20 | config.verbose = 2000 21 | config.dali = False 22 | 23 | config.rec = "/train_tmp/WebFace42M" 24 | config.num_classes = 2059906 25 | config.num_image = 42474557 26 | config.num_epoch = 40 27 | config.warmup_epoch = config.num_epoch // 10 28 | config.val_targets = [] 29 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf42m_pfc03_40epoch_8gpu_vit_t.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "vit_t_dp005_mask0" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 0.3 14 | config.fp16 = True 15 | config.weight_decay = 0.1 16 | config.batch_size = 512 17 | config.optimizer = "adamw" 18 | config.lr = 0.001 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace42M" 23 | config.num_classes = 2059906 24 | config.num_image = 42474557 25 | config.num_epoch = 40 26 | config.warmup_epoch = config.num_epoch // 10 27 | config.val_targets = [] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf4m_mbf.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "mbf" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 1e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 
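# Step-count arithmetic implied by a config like this one (a rough sketch; it assumes
# the usual 8-GPU launch from run.sh, i.e. total_batch_size = 128 * 8 = 1024):
#   steps_per_epoch ~ num_image // total_batch_size = 4235242 // 1024 ~ 4135
#   total_step      ~ steps_per_epoch * num_epoch   = 4135 * 20      ~ 82700
#   warmup_step     = steps_per_epoch * warmup_epoch = 0 here, so there is no linear
#   warmup phase and PolyScheduler in train.py decays from config.lr right away.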
24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf4m_r100.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r100" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/configs/wf4m_r50.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | # make training faster 4 | # our RAM is 256G 5 | # mount -t tmpfs -o size=140G tmpfs /train_tmp 6 | 7 | config = edict() 8 | config.margin_list = (1.0, 0.0, 0.4) 9 | config.network = "r50" 10 | config.resume = False 11 | config.output = None 12 | config.embedding_size = 512 13 | config.sample_rate = 1.0 14 | config.fp16 = True 15 | config.momentum = 0.9 16 | config.weight_decay = 5e-4 17 | config.batch_size = 128 18 | config.lr = 0.1 19 | config.verbose = 2000 20 | config.dali = False 21 | 22 | config.rec = "/train_tmp/WebFace4M" 23 | config.num_classes = 205990 24 | config.num_image = 4235242 25 | config.num_epoch = 20 26 | config.warmup_epoch = 0 27 | config.val_targets = ['lfw', 'cfp_fp', "agedb_30"] 28 | -------------------------------------------------------------------------------- /src/arcface_torch/dataset.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import os 3 | import queue as Queue 4 | import threading 5 | from typing import Iterable 6 | 7 | import mxnet as mx 8 | import numpy as np 9 | import torch 10 | from functools import partial 11 | from torch import distributed 12 | from torch.utils.data import DataLoader, Dataset 13 | from torchvision import transforms 14 | from torchvision.datasets import ImageFolder 15 | from utils.utils_distributed_sampler import DistributedSampler 16 | from utils.utils_distributed_sampler import get_dist_info, worker_init_fn 17 | 18 | 19 | def get_dataloader( 20 | root_dir, 21 | local_rank, 22 | batch_size, 23 | dali = False, 24 | seed = 2048, 25 | num_workers = 2, 26 | ) -> Iterable: 27 | 28 | rec = os.path.join(root_dir, 'train.rec') 29 | idx = os.path.join(root_dir, 'train.idx') 30 | train_set = None 31 | 32 | # Synthetic 33 | if root_dir == "synthetic": 34 | train_set = SyntheticDataset() 35 | dali = False 36 | 37 | # Mxnet RecordIO 38 | elif os.path.exists(rec) and os.path.exists(idx): 39 | train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank) 40 | 41 | # Image Folder 42 | else: 43 | transform = transforms.Compose([ 44 | transforms.RandomHorizontalFlip(), 45 | transforms.ToTensor(), 46 | transforms.Normalize(mean=[0.5, 
0.5, 0.5], std=[0.5, 0.5, 0.5]), 47 | ]) 48 | train_set = ImageFolder(root_dir, transform) 49 | 50 | # DALI 51 | if dali: 52 | return dali_data_iter( 53 | batch_size=batch_size, rec_file=rec, idx_file=idx, 54 | num_threads=2, local_rank=local_rank) 55 | 56 | rank, world_size = get_dist_info() 57 | train_sampler = DistributedSampler( 58 | train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed) 59 | 60 | if seed is None: 61 | init_fn = None 62 | else: 63 | init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) 64 | 65 | train_loader = DataLoaderX( 66 | local_rank=local_rank, 67 | dataset=train_set, 68 | batch_size=batch_size, 69 | sampler=train_sampler, 70 | num_workers=num_workers, 71 | pin_memory=True, 72 | drop_last=True, 73 | worker_init_fn=init_fn, 74 | ) 75 | 76 | return train_loader 77 | 78 | class BackgroundGenerator(threading.Thread): 79 | def __init__(self, generator, local_rank, max_prefetch=6): 80 | super(BackgroundGenerator, self).__init__() 81 | self.queue = Queue.Queue(max_prefetch) 82 | self.generator = generator 83 | self.local_rank = local_rank 84 | self.daemon = True 85 | self.start() 86 | 87 | def run(self): 88 | torch.cuda.set_device(self.local_rank) 89 | for item in self.generator: 90 | self.queue.put(item) 91 | self.queue.put(None) 92 | 93 | def next(self): 94 | next_item = self.queue.get() 95 | if next_item is None: 96 | raise StopIteration 97 | return next_item 98 | 99 | def __next__(self): 100 | return self.next() 101 | 102 | def __iter__(self): 103 | return self 104 | 105 | 106 | class DataLoaderX(DataLoader): 107 | 108 | def __init__(self, local_rank, **kwargs): 109 | super(DataLoaderX, self).__init__(**kwargs) 110 | self.stream = torch.cuda.Stream(local_rank) 111 | self.local_rank = local_rank 112 | 113 | def __iter__(self): 114 | self.iter = super(DataLoaderX, self).__iter__() 115 | self.iter = BackgroundGenerator(self.iter, self.local_rank) 116 | self.preload() 117 | return self 118 | 119 | def preload(self): 120 | self.batch = next(self.iter, None) 121 | if self.batch is None: 122 | return None 123 | with torch.cuda.stream(self.stream): 124 | for k in range(len(self.batch)): 125 | self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True) 126 | 127 | def __next__(self): 128 | torch.cuda.current_stream().wait_stream(self.stream) 129 | batch = self.batch 130 | if batch is None: 131 | raise StopIteration 132 | self.preload() 133 | return batch 134 | 135 | 136 | class MXFaceDataset(Dataset): 137 | def __init__(self, root_dir, local_rank): 138 | super(MXFaceDataset, self).__init__() 139 | self.transform = transforms.Compose( 140 | [transforms.ToPILImage(), 141 | transforms.RandomHorizontalFlip(), 142 | transforms.ToTensor(), 143 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 144 | ]) 145 | self.root_dir = root_dir 146 | self.local_rank = local_rank 147 | path_imgrec = os.path.join(root_dir, 'train.rec') 148 | path_imgidx = os.path.join(root_dir, 'train.idx') 149 | self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r') 150 | s = self.imgrec.read_idx(0) 151 | header, _ = mx.recordio.unpack(s) 152 | if header.flag > 0: 153 | self.header0 = (int(header.label[0]), int(header.label[1])) 154 | self.imgidx = np.array(range(1, int(header.label[0]))) 155 | else: 156 | self.imgidx = np.array(list(self.imgrec.keys)) 157 | 158 | def __getitem__(self, index): 159 | idx = self.imgidx[index] 160 | s = self.imgrec.read_idx(idx) 161 | header, img = mx.recordio.unpack(s) 162 | label = 
header.label 163 | if not isinstance(label, numbers.Number): 164 | label = label[0] 165 | label = torch.tensor(label, dtype=torch.long) 166 | sample = mx.image.imdecode(img).asnumpy() 167 | if self.transform is not None: 168 | sample = self.transform(sample) 169 | return sample, label 170 | 171 | def __len__(self): 172 | return len(self.imgidx) 173 | 174 | 175 | class SyntheticDataset(Dataset): 176 | def __init__(self): 177 | super(SyntheticDataset, self).__init__() 178 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 179 | img = np.transpose(img, (2, 0, 1)) 180 | img = torch.from_numpy(img).squeeze(0).float() 181 | img = ((img / 255) - 0.5) / 0.5 182 | self.img = img 183 | self.label = 1 184 | 185 | def __getitem__(self, index): 186 | return self.img, self.label 187 | 188 | def __len__(self): 189 | return 1000000 190 | 191 | 192 | def dali_data_iter( 193 | batch_size: int, rec_file: str, idx_file: str, num_threads: int, 194 | initial_fill=32768, random_shuffle=True, 195 | prefetch_queue_depth=1, local_rank=0, name="reader", 196 | mean=(127.5, 127.5, 127.5), 197 | std=(127.5, 127.5, 127.5)): 198 | """ 199 | Parameters: 200 | ---------- 201 | initial_fill: int 202 | Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored. 203 | 204 | """ 205 | rank: int = distributed.get_rank() 206 | world_size: int = distributed.get_world_size() 207 | import nvidia.dali.fn as fn 208 | import nvidia.dali.types as types 209 | from nvidia.dali.pipeline import Pipeline 210 | from nvidia.dali.plugin.pytorch import DALIClassificationIterator 211 | 212 | pipe = Pipeline( 213 | batch_size=batch_size, num_threads=num_threads, 214 | device_id=local_rank, prefetch_queue_depth=prefetch_queue_depth, ) 215 | condition_flip = fn.random.coin_flip(probability=0.5) 216 | with pipe: 217 | jpegs, labels = fn.readers.mxnet( 218 | path=rec_file, index_path=idx_file, initial_fill=initial_fill, 219 | num_shards=world_size, shard_id=rank, 220 | random_shuffle=random_shuffle, pad_last_batch=False, name=name) 221 | images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB) 222 | images = fn.crop_mirror_normalize( 223 | images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip) 224 | pipe.set_outputs(images, labels) 225 | pipe.build() 226 | return DALIWarper(DALIClassificationIterator(pipelines=[pipe], reader_name=name, )) 227 | 228 | 229 | @torch.no_grad() 230 | class DALIWarper(object): 231 | def __init__(self, dali_iter): 232 | self.iter = dali_iter 233 | 234 | def __next__(self): 235 | data_dict = self.iter.__next__()[0] 236 | tensor_data = data_dict['data'].cuda() 237 | tensor_label: torch.Tensor = data_dict['label'].cuda().long() 238 | tensor_label.squeeze_() 239 | return tensor_data, tensor_label 240 | 241 | def __iter__(self): 242 | return self 243 | 244 | def reset(self): 245 | self.iter.reset() 246 | -------------------------------------------------------------------------------- /src/arcface_torch/dist.sh: -------------------------------------------------------------------------------- 1 | ip_list=("ip1" "ip2" "ip3" "ip4") 2 | 3 | config=wf42m_pfc03_32gpu_r100 4 | 5 | for((node_rank=0;node_rank<${#ip_list[*]};node_rank++)); 6 | do 7 | ssh face@${ip_list[node_rank]} "cd `pwd`;PATH=$PATH \ 8 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \ 9 | python -m torch.distributed.launch \ 10 | --nproc_per_node=8 \ 11 | --nnodes=${#ip_list[*]} \ 12 | --node_rank=$node_rank \ 13 | --master_addr=${ip_list[0]} \ 14 | --master_port=22345 train.py 
configs/$config" & 15 | done 16 | -------------------------------------------------------------------------------- /src/arcface_torch/docs/eval.md: -------------------------------------------------------------------------------- 1 | ## Eval on ICCV2021-MFR 2 | 3 | coming soon. 4 | 5 | 6 | ## Eval IJBC 7 | You can eval ijbc with pytorch or onnx. 8 | 9 | 10 | 1. Eval IJBC With Onnx 11 | ```shell 12 | CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50 13 | ``` 14 | 15 | 2. Eval IJBC With Pytorch 16 | ```shell 17 | CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \ 18 | --model-prefix ms1mv3_arcface_r50/backbone.pth \ 19 | --image-path IJB_release/IJBC \ 20 | --result-dir ms1mv3_arcface_r50 \ 21 | --batch-size 128 \ 22 | --job ms1mv3_arcface_r50 \ 23 | --target IJBC \ 24 | --network iresnet50 25 | ``` 26 | 27 | 28 | ## Inference 29 | 30 | ```shell 31 | python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50 32 | ``` 33 | 34 | 35 | ## Result 36 | 37 | | Datasets | Backbone | **MFR-ALL** | IJB-C(1E-4) | IJB-C(1E-5) | 38 | |:---------------|:--------------------|:------------|:------------|:------------| 39 | | WF12M-PFC-0.05 | r100 | 94.05 | 97.51 | 95.75 | 40 | | WF12M-PFC-0.1 | r100 | 94.49 | 97.56 | 95.92 | 41 | | WF12M-PFC-0.2 | r100 | 94.75 | 97.60 | 95.90 | 42 | | WF12M-PFC-0.3 | r100 | 94.71 | 97.64 | 96.01 | 43 | | WF12M | r100 | 94.69 | 97.59 | 95.97 | -------------------------------------------------------------------------------- /src/arcface_torch/docs/install.md: -------------------------------------------------------------------------------- 1 | ## [v1.11.0](https://pytorch.org/) 2 | 3 | ## [v1.9.0](https://pytorch.org/get-started/previous-versions/#linux-and-windows-7) 4 | ### Linux and Windows 5 | ```shell 6 | # CUDA 11.1 7 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 8 | 9 | # CUDA 10.2 10 | pip install torch==1.9.0+cu102 torchvision==0.10.0+cu102 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 11 | ``` 12 | -------------------------------------------------------------------------------- /src/arcface_torch/docs/install_dali.md: -------------------------------------------------------------------------------- 1 | TODO 2 | -------------------------------------------------------------------------------- /src/arcface_torch/docs/modelzoo.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/src/arcface_torch/docs/modelzoo.md -------------------------------------------------------------------------------- /src/arcface_torch/docs/prepare_webface42m.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## 1. Download Datasets and Unzip 5 | 6 | Download WebFace42M from [https://www.face-benchmark.org/download.html](https://www.face-benchmark.org/download.html). 7 | The raw data of `WebFace42M` will have 10 directories after being unarchived: 8 | `WebFace4M` contains 1 directory: `0`. 9 | `WebFace12M` contains 3 directories: `0,1,2`. 10 | `WebFace42M` contains 10 directories: `0,1,2,3,4,5,6,7,8,9`. 11 | 12 | ## 2. 
Create Shuffled Rec File for DALI 13 | 14 | Note: Shuffled rec is very important to DALI, and rec without shuffled can cause performance degradation, origin insightface style rec file 15 | do not support Nvidia DALI, you must follow this command [mxnet.tools.im2rec](https://github.com/apache/incubator-mxnet/blob/master/tools/im2rec.py) to generate a shuffled rec file. 16 | 17 | ```shell 18 | # directories and files for yours datsaets 19 | /WebFace42M_Root 20 | ├── 0_0_0000000 21 | │   ├── 0_0.jpg 22 | │   ├── 0_1.jpg 23 | │   ├── 0_2.jpg 24 | │   ├── 0_3.jpg 25 | │   └── 0_4.jpg 26 | ├── 0_0_0000001 27 | │   ├── 0_5.jpg 28 | │   ├── 0_6.jpg 29 | │   ├── 0_7.jpg 30 | │   ├── 0_8.jpg 31 | │   └── 0_9.jpg 32 | ├── 0_0_0000002 33 | │   ├── 0_10.jpg 34 | │   ├── 0_11.jpg 35 | │   ├── 0_12.jpg 36 | │   ├── 0_13.jpg 37 | │   ├── 0_14.jpg 38 | │   ├── 0_15.jpg 39 | │   ├── 0_16.jpg 40 | │   └── 0_17.jpg 41 | ├── 0_0_0000003 42 | │   ├── 0_18.jpg 43 | │   ├── 0_19.jpg 44 | │   └── 0_20.jpg 45 | ├── 0_0_0000004 46 | 47 | 48 | 49 | # 1) create train.lst using follow command 50 | python -m mxnet.tools.im2rec --list --recursive train WebFace42M_Root 51 | 52 | # 2) create train.rec and train.idx using train.lst using following command 53 | python -m mxnet.tools.im2rec --num-thread 16 --quality 100 train WebFace42M_Root 54 | ``` 55 | 56 | Finally, you will get three files: `train.lst`, `train.rec`, `train.idx`. which `train.idx`, `train.rec` are using for training. 57 | -------------------------------------------------------------------------------- /src/arcface_torch/docs/speed_benchmark.md: -------------------------------------------------------------------------------- 1 | ## Test Training Speed 2 | 3 | - Test Commands 4 | 5 | You need to use the following two commands to test the Partial FC training performance. 6 | The number of identites is **3 millions** (synthetic data), turn mixed precision training on, backbone is resnet50, 7 | batch size is 1024. 
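The two commands below read `configs/3millions` / `configs/3millions_pfc`. As a rough sketch of what such a synthetic benchmark config can look like (illustrative values only, not the repo's actual files; the essential parts are `rec = "synthetic"`, which makes `get_dataloader` build `SyntheticDataset`, and the 3-million class count):

```python
from easydict import EasyDict as edict

config = edict()
config.margin_list = (1.0, 0.0, 0.4)
config.network = "r50"          # resnet50 backbone, as in the description above
config.fp16 = True              # mixed precision on
config.sample_rate = 0.1        # Partial FC 0.1; 1.0 corresponds to plain model parallel
config.batch_size = 128         # 128 per GPU x 8 GPUs = 1024 in total
config.lr = 0.2                 # matches the learning rate visible in the logs below
config.rec = "synthetic"        # dataset.py returns SyntheticDataset for this value
config.num_classes = 3_000_000
config.num_image = 1_000_000    # SyntheticDataset exposes 1e6 samples per epoch
config.num_epoch = 1
config.warmup_epoch = 0
config.val_targets = []         # no verification sets during a speed test
```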
8 | ```shell 9 | # Model Parallel 10 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions 11 | # Partial FC 0.1 12 | python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc 13 | ``` 14 | 15 | - GPU Memory 16 | 17 | ``` 18 | # (Model Parallel) gpustat -i 19 | [0] Tesla V100-SXM2-32GB | 64'C, 94 % | 30338 / 32510 MB 20 | [1] Tesla V100-SXM2-32GB | 60'C, 99 % | 28876 / 32510 MB 21 | [2] Tesla V100-SXM2-32GB | 60'C, 99 % | 28872 / 32510 MB 22 | [3] Tesla V100-SXM2-32GB | 69'C, 99 % | 28872 / 32510 MB 23 | [4] Tesla V100-SXM2-32GB | 66'C, 99 % | 28888 / 32510 MB 24 | [5] Tesla V100-SXM2-32GB | 60'C, 99 % | 28932 / 32510 MB 25 | [6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB 26 | [7] Tesla V100-SXM2-32GB | 65'C, 99 % | 28860 / 32510 MB 27 | 28 | # (Partial FC 0.1) gpustat -i 29 | [0] Tesla V100-SXM2-32GB | 60'C, 95 % | 10488 / 32510 MB │······················· 30 | [1] Tesla V100-SXM2-32GB | 60'C, 97 % | 10344 / 32510 MB │······················· 31 | [2] Tesla V100-SXM2-32GB | 61'C, 95 % | 10340 / 32510 MB │······················· 32 | [3] Tesla V100-SXM2-32GB | 66'C, 95 % | 10340 / 32510 MB │······················· 33 | [4] Tesla V100-SXM2-32GB | 65'C, 94 % | 10356 / 32510 MB │······················· 34 | [5] Tesla V100-SXM2-32GB | 61'C, 95 % | 10400 / 32510 MB │······················· 35 | [6] Tesla V100-SXM2-32GB | 68'C, 96 % | 10384 / 32510 MB │······················· 36 | [7] Tesla V100-SXM2-32GB | 64'C, 95 % | 10328 / 32510 MB │······················· 37 | ``` 38 | 39 | - Training Speed 40 | 41 | ```python 42 | # (Model Parallel) trainging.log 43 | Training: Speed 2271.33 samples/sec Loss 1.1624 LearningRate 0.2000 Epoch: 0 Global Step: 100 44 | Training: Speed 2269.94 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 45 | Training: Speed 2272.67 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 46 | Training: Speed 2266.55 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 47 | Training: Speed 2272.54 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 48 | 49 | # (Partial FC 0.1) trainging.log 50 | Training: Speed 5299.56 samples/sec Loss 1.0965 LearningRate 0.2000 Epoch: 0 Global Step: 100 51 | Training: Speed 5296.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 150 52 | Training: Speed 5304.37 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 200 53 | Training: Speed 5274.43 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 250 54 | Training: Speed 5300.10 samples/sec Loss 0.0000 LearningRate 0.2000 Epoch: 0 Global Step: 300 55 | ``` 56 | 57 | In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, 58 | and the training speed is 2.5 times faster than the model parallel. 59 | 60 | 61 | ## Speed Benchmark 62 | 63 | 1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. 
(Larger is better) 64 | 65 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | 66 | | :--- | :--- | :--- | :--- | 67 | |125000 | 4681 | 4824 | 5004 | 68 | |250000 | 4047 | 4521 | 4976 | 69 | |500000 | 3087 | 4013 | 4900 | 70 | |1000000 | 2090 | 3449 | 4803 | 71 | |1400000 | 1672 | 3043 | 4738 | 72 | |2000000 | - | 2593 | 4626 | 73 | |4000000 | - | 1748 | 4208 | 74 | |5500000 | - | 1389 | 3975 | 75 | |8000000 | - | - | 3565 | 76 | |16000000 | - | - | 2679 | 77 | |29000000 | - | - | 1855 | 78 | 79 | 2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better) 80 | 81 | | Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 | 82 | | :--- | :--- | :--- | :--- | 83 | |125000 | 7358 | 5306 | 4868 | 84 | |250000 | 9940 | 5826 | 5004 | 85 | |500000 | 14220 | 7114 | 5202 | 86 | |1000000 | 23708 | 9966 | 5620 | 87 | |1400000 | 32252 | 11178 | 6056 | 88 | |2000000 | - | 13978 | 6472 | 89 | |4000000 | - | 23238 | 8284 | 90 | |5500000 | - | 32188 | 9854 | 91 | |8000000 | - | - | 12310 | 92 | |16000000 | - | - | 19950 | 93 | |29000000 | - | - | 32324 | 94 | -------------------------------------------------------------------------------- /src/arcface_torch/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/src/arcface_torch/eval/__init__.py -------------------------------------------------------------------------------- /src/arcface_torch/flops.py: -------------------------------------------------------------------------------- 1 | from ptflops import get_model_complexity_info 2 | from backbones import get_model 3 | import argparse 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser(description='') 7 | parser.add_argument('n', type=str, default="r100") 8 | args = parser.parse_args() 9 | net = get_model(args.n) 10 | macs, params = get_model_complexity_info( 11 | net, (3, 112, 112), as_strings=False, 12 | print_per_layer_stat=True, verbose=True) 13 | gmacs = macs / (1000**3) 14 | print("%.3f GFLOPs"%gmacs) 15 | print("%.3f Mparams"%(params/(1000**2))) 16 | 17 | if hasattr(net, "extra_gflops"): 18 | print("%.3f Extra-GFLOPs"%net.extra_gflops) 19 | print("%.3f Total-GFLOPs"%(gmacs+net.extra_gflops)) 20 | 21 | -------------------------------------------------------------------------------- /src/arcface_torch/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from backbones import get_model 8 | 9 | 10 | @torch.no_grad() 11 | def inference(weight, name, img): 12 | if img is None: 13 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) 14 | else: 15 | img = cv2.imread(img) 16 | img = cv2.resize(img, (112, 112)) 17 | 18 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 19 | img = np.transpose(img, (2, 0, 1)) 20 | img = torch.from_numpy(img).unsqueeze(0).float() 21 | img.div_(255).sub_(0.5).div_(0.5) 22 | net = get_model(name, fp16=False) 23 | net.load_state_dict(torch.load(weight)) 24 | net.eval() 25 | feat = net(img).numpy() 26 | print(feat) 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser(description='PyTorch ArcFace Training') 31 | parser.add_argument('--network', type=str, default='r50', help='backbone network') 32 | parser.add_argument('--weight', type=str, default='') 33 | 
parser.add_argument('--img', type=str, default=None) 34 | args = parser.parse_args() 35 | inference(args.weight, args.network, args.img) 36 | -------------------------------------------------------------------------------- /src/arcface_torch/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | 5 | class CombinedMarginLoss(torch.nn.Module): 6 | def __init__(self, 7 | s, 8 | m1, 9 | m2, 10 | m3, 11 | interclass_filtering_threshold=0): 12 | super().__init__() 13 | self.s = s 14 | self.m1 = m1 15 | self.m2 = m2 16 | self.m3 = m3 17 | self.interclass_filtering_threshold = interclass_filtering_threshold 18 | 19 | # For ArcFace 20 | self.cos_m = math.cos(self.m2) 21 | self.sin_m = math.sin(self.m2) 22 | self.theta = math.cos(math.pi - self.m2) 23 | self.sinmm = math.sin(math.pi - self.m2) * self.m2 24 | self.easy_margin = False 25 | 26 | 27 | def forward(self, logits, labels): 28 | index_positive = torch.where(labels != -1)[0] 29 | 30 | if self.interclass_filtering_threshold > 0: 31 | with torch.no_grad(): 32 | dirty = logits > self.interclass_filtering_threshold 33 | dirty = dirty.float() 34 | mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device) 35 | mask.scatter_(1, labels[index_positive], 0) 36 | dirty[index_positive] *= mask 37 | tensor_mul = 1 - dirty 38 | logits = tensor_mul * logits 39 | 40 | target_logit = logits[index_positive, labels[index_positive].view(-1)] 41 | 42 | if self.m1 == 1.0 and self.m3 == 0.0: 43 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 44 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 45 | if self.easy_margin: 46 | final_target_logit = torch.where( 47 | target_logit > 0, cos_theta_m, target_logit) 48 | else: 49 | final_target_logit = torch.where( 50 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) 51 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 52 | logits = logits * self.s 53 | 54 | elif self.m3 > 0: 55 | final_target_logit = target_logit - self.m3 56 | logits[index_positive, labels[index_positive].view(-1)] = final_target_logit 57 | logits = logits * self.s 58 | else: 59 | raise 60 | 61 | return logits 62 | 63 | class ArcFace(torch.nn.Module): 64 | """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): 65 | """ 66 | def __init__(self, s=64.0, margin=0.5): 67 | super(ArcFace, self).__init__() 68 | self.scale = s 69 | self.cos_m = math.cos(margin) 70 | self.sin_m = math.sin(margin) 71 | self.theta = math.cos(math.pi - margin) 72 | self.sinmm = math.sin(math.pi - margin) * margin 73 | self.easy_margin = False 74 | 75 | 76 | def forward(self, logits: torch.Tensor, labels: torch.Tensor): 77 | index = torch.where(labels != -1)[0] 78 | target_logit = logits[index, labels[index].view(-1)] 79 | 80 | sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) 81 | cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m # cos(target+margin) 82 | if self.easy_margin: 83 | final_target_logit = torch.where( 84 | target_logit > 0, cos_theta_m, target_logit) 85 | else: 86 | final_target_logit = torch.where( 87 | target_logit > self.theta, cos_theta_m, target_logit - self.sinmm) 88 | 89 | logits[index, labels[index].view(-1)] = final_target_logit 90 | logits = logits * self.scale 91 | return logits 92 | 93 | 94 | class CosFace(torch.nn.Module): 95 | def __init__(self, s=64.0, m=0.40): 96 | super(CosFace, self).__init__() 97 | self.s = s 98 | self.m = m 99 | 100 | 
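    # Margin math, for reference (a sketch of the two losses in this file): ArcFace
    # above adds the margin to the angle, so the target logit becomes s * cos(theta_y + m),
    # expanded as cos(theta_y)*cos(m) - sin(theta_y)*sin(m); CosFace below subtracts the
    # margin from the cosine itself, giving s * (cos(theta_y) - m). In both cases entries
    # whose label is -1 (e.g. classes not sampled by Partial FC) are left unchanged and
    # every logit is scaled by s at the end.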
def forward(self, logits: torch.Tensor, labels: torch.Tensor): 101 | index = torch.where(labels != -1)[0] 102 | target_logit = logits[index, labels[index].view(-1)] 103 | final_target_logit = target_logit - self.m 104 | logits[index, labels[index].view(-1)] = final_target_logit 105 | logits = logits * self.s 106 | return logits 107 | -------------------------------------------------------------------------------- /src/arcface_torch/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | 3 | 4 | class PolyScheduler(_LRScheduler): 5 | def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1): 6 | self.base_lr = base_lr 7 | self.warmup_lr_init = 0.0001 8 | self.max_steps: int = max_steps 9 | self.warmup_steps: int = warmup_steps 10 | self.power = 2 11 | super(PolyScheduler, self).__init__(optimizer, -1, False) 12 | self.last_epoch = last_epoch 13 | 14 | def get_warmup_lr(self): 15 | alpha = float(self.last_epoch) / float(self.warmup_steps) 16 | return [self.base_lr * alpha for _ in self.optimizer.param_groups] 17 | 18 | def get_lr(self): 19 | if self.last_epoch == -1: 20 | return [self.warmup_lr_init for _ in self.optimizer.param_groups] 21 | if self.last_epoch < self.warmup_steps: 22 | return self.get_warmup_lr() 23 | else: 24 | alpha = pow( 25 | 1 26 | - float(self.last_epoch - self.warmup_steps) 27 | / float(self.max_steps - self.warmup_steps), 28 | self.power, 29 | ) 30 | return [self.base_lr * alpha for _ in self.optimizer.param_groups] 31 | -------------------------------------------------------------------------------- /src/arcface_torch/requirement.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | easydict 3 | mxnet 4 | onnx 5 | sklearn 6 | -------------------------------------------------------------------------------- /src/arcface_torch/run.sh: -------------------------------------------------------------------------------- 1 | 2 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch \ 3 | --nproc_per_node=8 \ 4 | --nnodes=1 \ 5 | --node_rank=0 \ 6 | --master_addr="127.0.0.1" \ 7 | --master_port=12345 train.py $@ 8 | 9 | ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh 10 | -------------------------------------------------------------------------------- /src/arcface_torch/torch2onnx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import onnx 3 | import torch 4 | 5 | 6 | def convert_onnx(net, path_module, output, opset=11, simplify=False): 7 | assert isinstance(net, torch.nn.Module) 8 | img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32) 9 | img = img.astype(np.float) 10 | img = (img / 255. 
- 0.5) / 0.5 # torch style norm 11 | img = img.transpose((2, 0, 1)) 12 | img = torch.from_numpy(img).unsqueeze(0).float() 13 | 14 | weight = torch.load(path_module) 15 | net.load_state_dict(weight, strict=True) 16 | net.eval() 17 | torch.onnx.export(net, img, output, input_names=["data"], keep_initializers_as_inputs=False, verbose=False, opset_version=opset) 18 | model = onnx.load(output) 19 | graph = model.graph 20 | graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None' 21 | if simplify: 22 | from onnxsim import simplify 23 | model, check = simplify(model) 24 | assert check, "Simplified ONNX model could not be validated" 25 | onnx.save(model, output) 26 | 27 | 28 | if __name__ == '__main__': 29 | import os 30 | import argparse 31 | from backbones import get_model 32 | 33 | parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx') 34 | parser.add_argument('input', type=str, help='input backbone.pth file or path') 35 | parser.add_argument('--output', type=str, default=None, help='output onnx path') 36 | parser.add_argument('--network', type=str, default=None, help='backbone network') 37 | parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify') 38 | args = parser.parse_args() 39 | input_file = args.input 40 | if os.path.isdir(input_file): 41 | input_file = os.path.join(input_file, "model.pt") 42 | assert os.path.exists(input_file) 43 | # model_name = os.path.basename(os.path.dirname(input_file)).lower() 44 | # params = model_name.split("_") 45 | # if len(params) >= 3 and params[1] in ('arcface', 'cosface'): 46 | # if args.network is None: 47 | # args.network = params[2] 48 | assert args.network is not None 49 | print(args) 50 | backbone_onnx = get_model(args.network, dropout=0.0, fp16=False, num_features=512) 51 | if args.output is None: 52 | args.output = os.path.join(os.path.dirname(args.input), "model.onnx") 53 | convert_onnx(backbone_onnx, input_file, args.output, simplify=args.simplify) 54 | -------------------------------------------------------------------------------- /src/arcface_torch/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torch import distributed 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from backbones import get_model 12 | from dataset import get_dataloader 13 | from losses import CombinedMarginLoss 14 | from lr_scheduler import PolyScheduler 15 | from partial_fc import PartialFC, PartialFCAdamW 16 | from utils.utils_callbacks import CallBackLogging, CallBackVerification 17 | from utils.utils_config import get_config 18 | from utils.utils_logging import AverageMeter, init_logging 19 | from utils.utils_distributed_sampler import setup_seed 20 | 21 | assert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \ 22 | we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future." 
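# Process-group bootstrap (see the try/except just below): launching through
# `python -m torch.distributed.launch --nproc_per_node=8 ... train.py configs/<cfg>`
# (as run.sh and dist.sh do) or through torchrun exports WORLD_SIZE and RANK, and
# init_process_group("nccl") picks them up from the environment. Running
# `python train.py configs/<cfg>` directly sets neither variable, so the KeyError
# branch falls back to a single-process NCCL group on tcp://127.0.0.1:12584.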
23 | 24 | try: 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | rank = int(os.environ["RANK"]) 27 | distributed.init_process_group("nccl") 28 | except KeyError: 29 | world_size = 1 30 | rank = 0 31 | distributed.init_process_group( 32 | backend="nccl", 33 | init_method="tcp://127.0.0.1:12584", 34 | rank=rank, 35 | world_size=world_size, 36 | ) 37 | 38 | 39 | def main(args): 40 | 41 | # get config 42 | cfg = get_config(args.config) 43 | # global control random seed 44 | setup_seed(seed=cfg.seed, cuda_deterministic=False) 45 | 46 | torch.cuda.set_device(args.local_rank) 47 | 48 | os.makedirs(cfg.output, exist_ok=True) 49 | init_logging(rank, cfg.output) 50 | 51 | summary_writer = ( 52 | SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) 53 | if rank == 0 54 | else None 55 | ) 56 | 57 | train_loader = get_dataloader( 58 | cfg.rec, 59 | args.local_rank, 60 | cfg.batch_size, 61 | cfg.dali, 62 | cfg.seed, 63 | cfg.num_workers 64 | ) 65 | 66 | backbone = get_model( 67 | cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() 68 | 69 | backbone = torch.nn.parallel.DistributedDataParallel( 70 | module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16, 71 | find_unused_parameters=True) 72 | 73 | backbone.train() 74 | # FIXME using gradient checkpoint if there are some unused parameters will cause error 75 | backbone._set_static_graph() 76 | 77 | margin_loss = CombinedMarginLoss( 78 | 64, 79 | cfg.margin_list[0], 80 | cfg.margin_list[1], 81 | cfg.margin_list[2], 82 | cfg.interclass_filtering_threshold 83 | ) 84 | 85 | if cfg.optimizer == "sgd": 86 | module_partial_fc = PartialFC( 87 | margin_loss, cfg.embedding_size, cfg.num_classes, 88 | cfg.sample_rate, cfg.fp16) 89 | module_partial_fc.train().cuda() 90 | # TODO the params of partial fc must be last in the params list 91 | opt = torch.optim.SGD( 92 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 93 | lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) 94 | 95 | elif cfg.optimizer == "adamw": 96 | module_partial_fc = PartialFCAdamW( 97 | margin_loss, cfg.embedding_size, cfg.num_classes, 98 | cfg.sample_rate, cfg.fp16) 99 | module_partial_fc.train().cuda() 100 | opt = torch.optim.AdamW( 101 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 102 | lr=cfg.lr, weight_decay=cfg.weight_decay) 103 | else: 104 | raise 105 | 106 | cfg.total_batch_size = cfg.batch_size * world_size 107 | cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch 108 | cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch 109 | 110 | lr_scheduler = PolyScheduler( 111 | optimizer=opt, 112 | base_lr=cfg.lr, 113 | max_steps=cfg.total_step, 114 | warmup_steps=cfg.warmup_step, 115 | last_epoch=-1 116 | ) 117 | 118 | start_epoch = 0 119 | global_step = 0 120 | if cfg.resume: 121 | dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 122 | start_epoch = dict_checkpoint["epoch"] 123 | global_step = dict_checkpoint["global_step"] 124 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) 125 | module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) 126 | opt.load_state_dict(dict_checkpoint["state_optimizer"]) 127 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) 128 | del dict_checkpoint 129 | 130 | for key, value in cfg.items(): 131 | num_space = 25 - len(key) 132 | logging.info(": " + key + " " * num_space + 
str(value)) 133 | 134 | callback_verification = CallBackVerification( 135 | val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer 136 | ) 137 | callback_logging = CallBackLogging( 138 | frequent=cfg.frequent, 139 | total_step=cfg.total_step, 140 | batch_size=cfg.batch_size, 141 | start_step = global_step, 142 | writer=summary_writer 143 | ) 144 | 145 | loss_am = AverageMeter() 146 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) 147 | 148 | for epoch in range(start_epoch, cfg.num_epoch): 149 | 150 | if isinstance(train_loader, DataLoader): 151 | train_loader.sampler.set_epoch(epoch) 152 | for _, (img, local_labels) in enumerate(train_loader): 153 | global_step += 1 154 | local_embeddings = backbone(img) 155 | loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt) 156 | 157 | if cfg.fp16: 158 | amp.scale(loss).backward() 159 | amp.unscale_(opt) 160 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 161 | amp.step(opt) 162 | amp.update() 163 | else: 164 | loss.backward() 165 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 166 | opt.step() 167 | 168 | opt.zero_grad() 169 | lr_scheduler.step() 170 | 171 | with torch.no_grad(): 172 | loss_am.update(loss.item(), 1) 173 | callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) 174 | 175 | if global_step % cfg.verbose == 0 and global_step > 0: 176 | callback_verification(global_step, backbone) 177 | 178 | if cfg.save_all_states: 179 | checkpoint = { 180 | "epoch": epoch + 1, 181 | "global_step": global_step, 182 | "state_dict_backbone": backbone.module.state_dict(), 183 | "state_dict_softmax_fc": module_partial_fc.state_dict(), 184 | "state_optimizer": opt.state_dict(), 185 | "state_lr_scheduler": lr_scheduler.state_dict() 186 | } 187 | torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 188 | 189 | if rank == 0: 190 | path_module = os.path.join(cfg.output, "model.pt") 191 | torch.save(backbone.module.state_dict(), path_module) 192 | 193 | if cfg.dali: 194 | train_loader.reset() 195 | 196 | if rank == 0: 197 | path_module = os.path.join(cfg.output, "model.pt") 198 | torch.save(backbone.module.state_dict(), path_module) 199 | 200 | from torch2onnx import convert_onnx 201 | convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx")) 202 | 203 | distributed.destroy_process_group() 204 | 205 | 206 | if __name__ == "__main__": 207 | torch.backends.cudnn.benchmark = True 208 | parser = argparse.ArgumentParser( 209 | description="Distributed Arcface Training in Pytorch") 210 | parser.add_argument("config", type=str, help="py config file") 211 | parser.add_argument("--local_rank", type=int, default=0, help="local_rank") 212 | main(parser.parse_args()) 213 | -------------------------------------------------------------------------------- /src/arcface_torch/train_v2.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | 5 | import numpy as np 6 | import torch 7 | from torch import distributed 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from backbones import get_model 12 | from dataset import get_dataloader 13 | from losses import CombinedMarginLoss 14 | from lr_scheduler import PolyScheduler 15 | from partial_fc_v2 import PartialFC_V2 16 | from utils.utils_callbacks import CallBackLogging, CallBackVerification 17 | from 
utils.utils_config import get_config 18 | from utils.utils_logging import AverageMeter, init_logging 19 | from utils.utils_distributed_sampler import setup_seed 20 | 21 | assert torch.__version__ >= "1.9.0", "In order to enjoy the features of the new torch, \ 22 | we have upgraded the torch to 1.9.0. torch before than 1.9.0 may not work in the future." 23 | 24 | try: 25 | world_size = int(os.environ["WORLD_SIZE"]) 26 | rank = int(os.environ["RANK"]) 27 | distributed.init_process_group("nccl") 28 | except KeyError: 29 | world_size = 1 30 | rank = 0 31 | distributed.init_process_group( 32 | backend="nccl", 33 | init_method="tcp://127.0.0.1:12584", 34 | rank=rank, 35 | world_size=world_size, 36 | ) 37 | 38 | 39 | def main(args): 40 | 41 | # get config 42 | cfg = get_config(args.config) 43 | # global control random seed 44 | setup_seed(seed=cfg.seed, cuda_deterministic=False) 45 | 46 | torch.cuda.set_device(args.local_rank) 47 | 48 | os.makedirs(cfg.output, exist_ok=True) 49 | init_logging(rank, cfg.output) 50 | 51 | summary_writer = ( 52 | SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard")) 53 | if rank == 0 54 | else None 55 | ) 56 | 57 | train_loader = get_dataloader( 58 | cfg.rec, 59 | args.local_rank, 60 | cfg.batch_size, 61 | cfg.dali, 62 | cfg.seed, 63 | cfg.num_workers 64 | ) 65 | 66 | backbone = get_model( 67 | cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda() 68 | 69 | backbone = torch.nn.parallel.DistributedDataParallel( 70 | module=backbone, broadcast_buffers=False, device_ids=[args.local_rank], bucket_cap_mb=16, 71 | find_unused_parameters=True) 72 | 73 | backbone.train() 74 | # FIXME using gradient checkpoint if there are some unused parameters will cause error 75 | backbone._set_static_graph() 76 | 77 | margin_loss = CombinedMarginLoss( 78 | 64, 79 | cfg.margin_list[0], 80 | cfg.margin_list[1], 81 | cfg.margin_list[2], 82 | cfg.interclass_filtering_threshold 83 | ) 84 | 85 | if cfg.optimizer == "sgd": 86 | module_partial_fc = PartialFC_V2( 87 | margin_loss, cfg.embedding_size, cfg.num_classes, 88 | cfg.sample_rate, cfg.fp16) 89 | module_partial_fc.train().cuda() 90 | # TODO the params of partial fc must be last in the params list 91 | opt = torch.optim.SGD( 92 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 93 | lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay) 94 | 95 | elif cfg.optimizer == "adamw": 96 | module_partial_fc = PartialFC_V2( 97 | margin_loss, cfg.embedding_size, cfg.num_classes, 98 | cfg.sample_rate, cfg.fp16) 99 | module_partial_fc.train().cuda() 100 | opt = torch.optim.AdamW( 101 | params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}], 102 | lr=cfg.lr, weight_decay=cfg.weight_decay) 103 | else: 104 | raise 105 | 106 | cfg.total_batch_size = cfg.batch_size * world_size 107 | cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch 108 | cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch 109 | 110 | lr_scheduler = PolyScheduler( 111 | optimizer=opt, 112 | base_lr=cfg.lr, 113 | max_steps=cfg.total_step, 114 | warmup_steps=cfg.warmup_step, 115 | last_epoch=-1 116 | ) 117 | 118 | start_epoch = 0 119 | global_step = 0 120 | if cfg.resume: 121 | dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 122 | start_epoch = dict_checkpoint["epoch"] 123 | global_step = dict_checkpoint["global_step"] 124 | backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"]) 125 | 
module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"]) 126 | opt.load_state_dict(dict_checkpoint["state_optimizer"]) 127 | lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"]) 128 | del dict_checkpoint 129 | 130 | for key, value in cfg.items(): 131 | num_space = 25 - len(key) 132 | logging.info(": " + key + " " * num_space + str(value)) 133 | 134 | callback_verification = CallBackVerification( 135 | val_targets=cfg.val_targets, rec_prefix=cfg.rec, summary_writer=summary_writer 136 | ) 137 | callback_logging = CallBackLogging( 138 | frequent=cfg.frequent, 139 | total_step=cfg.total_step, 140 | batch_size=cfg.batch_size, 141 | start_step = global_step, 142 | writer=summary_writer 143 | ) 144 | 145 | loss_am = AverageMeter() 146 | amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100) 147 | 148 | for epoch in range(start_epoch, cfg.num_epoch): 149 | 150 | if isinstance(train_loader, DataLoader): 151 | train_loader.sampler.set_epoch(epoch) 152 | for _, (img, local_labels) in enumerate(train_loader): 153 | global_step += 1 154 | local_embeddings = backbone(img) 155 | loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels) 156 | 157 | if cfg.fp16: 158 | amp.scale(loss).backward() 159 | if global_step % cfg.gradient_acc == 0: 160 | amp.unscale_(opt) 161 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 162 | amp.step(opt) 163 | amp.update() 164 | opt.zero_grad() 165 | else: 166 | loss.backward() 167 | if global_step % cfg.gradient_acc == 0: 168 | torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5) 169 | opt.step() 170 | opt.zero_grad() 171 | lr_scheduler.step() 172 | 173 | with torch.no_grad(): 174 | loss_am.update(loss.item(), 1) 175 | callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp) 176 | 177 | if global_step % cfg.verbose == 0 and global_step > 0: 178 | callback_verification(global_step, backbone) 179 | 180 | if cfg.save_all_states: 181 | checkpoint = { 182 | "epoch": epoch + 1, 183 | "global_step": global_step, 184 | "state_dict_backbone": backbone.module.state_dict(), 185 | "state_dict_softmax_fc": module_partial_fc.state_dict(), 186 | "state_optimizer": opt.state_dict(), 187 | "state_lr_scheduler": lr_scheduler.state_dict() 188 | } 189 | torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt")) 190 | 191 | if rank == 0: 192 | path_module = os.path.join(cfg.output, "model.pt") 193 | torch.save(backbone.module.state_dict(), path_module) 194 | 195 | if cfg.dali: 196 | train_loader.reset() 197 | 198 | if rank == 0: 199 | path_module = os.path.join(cfg.output, "model.pt") 200 | torch.save(backbone.module.state_dict(), path_module) 201 | 202 | 203 | if __name__ == "__main__": 204 | torch.backends.cudnn.benchmark = True 205 | parser = argparse.ArgumentParser( 206 | description="Distributed Arcface Training in Pytorch") 207 | parser.add_argument("config", type=str, help="py config file") 208 | parser.add_argument("--local_rank", type=int, default=0, help="local_rank") 209 | main(parser.parse_args()) 210 | -------------------------------------------------------------------------------- /src/arcface_torch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/src/arcface_torch/utils/__init__.py -------------------------------------------------------------------------------- /src/arcface_torch/utils/plot.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import pandas as pd 7 | from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap 8 | from prettytable import PrettyTable 9 | from sklearn.metrics import roc_curve, auc 10 | 11 | with open(sys.argv[1], "r") as f: 12 | files = f.readlines() 13 | 14 | files = [x.strip() for x in files] 15 | image_path = "/train_tmp/IJB_release/IJBC" 16 | 17 | 18 | def read_template_pair_list(path): 19 | pairs = pd.read_csv(path, sep=' ', header=None).values 20 | t1 = pairs[:, 0].astype(int) 21 | t2 = pairs[:, 1].astype(int) 22 | label = pairs[:, 2].astype(int) 23 | return t1, t2, label 24 | 25 | 26 | p1, p2, label = read_template_pair_list( 27 | os.path.join('%s/meta' % image_path, 28 | '%s_template_pair_label.txt' % 'ijbc')) 29 | 30 | methods = [] 31 | scores = [] 32 | for file in files: 33 | methods.append(file) 34 | scores.append(np.load(file)) 35 | 36 | methods = np.array(methods) 37 | scores = dict(zip(methods, scores)) 38 | colours = dict( 39 | zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2'))) 40 | x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] 41 | tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels]) 42 | fig = plt.figure() 43 | for method in methods: 44 | fpr, tpr, _ = roc_curve(label, scores[method]) 45 | roc_auc = auc(fpr, tpr) 46 | fpr = np.flipud(fpr) 47 | tpr = np.flipud(tpr) # select largest tpr at same fpr 48 | plt.plot(fpr, 49 | tpr, 50 | color=colours[method], 51 | lw=1, 52 | label=('[%s (AUC = %0.4f %%)]' % 53 | (method.split('-')[-1], roc_auc * 100))) 54 | tpr_fpr_row = [] 55 | tpr_fpr_row.append(method) 56 | for fpr_iter in np.arange(len(x_labels)): 57 | _, min_index = min( 58 | list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) 59 | tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100)) 60 | tpr_fpr_table.add_row(tpr_fpr_row) 61 | plt.xlim([10 ** -6, 0.1]) 62 | plt.ylim([0.3, 1.0]) 63 | plt.grid(linestyle='--', linewidth=1) 64 | plt.xticks(x_labels) 65 | plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True)) 66 | plt.xscale('log') 67 | plt.xlabel('False Positive Rate') 68 | plt.ylabel('True Positive Rate') 69 | plt.title('ROC on IJB') 70 | plt.legend(loc="lower right") 71 | print(tpr_fpr_table) 72 | -------------------------------------------------------------------------------- /src/arcface_torch/utils/utils_callbacks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import time 4 | from typing import List 5 | 6 | import torch 7 | 8 | from eval import verification 9 | from utils.utils_logging import AverageMeter 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch import distributed 12 | 13 | 14 | class CallBackVerification(object): 15 | 16 | def __init__(self, val_targets, rec_prefix, summary_writer=None, image_size=(112, 112)): 17 | self.rank: int = distributed.get_rank() 18 | self.highest_acc: float = 0.0 19 | self.highest_acc_list: List[float] = [0.0] * len(val_targets) 20 | self.ver_list: List[object] = [] 21 | self.ver_name_list: List[str] = [] 22 | if self.rank == 0: 23 | self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size) 24 | 25 | self.summary_writer = summary_writer 26 | 27 | def ver_test(self, backbone: torch.nn.Module, global_step: int): 28 | results = [] 29 | for i in
range(len(self.ver_list)): 30 | acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test( 31 | self.ver_list[i], backbone, 10, 10) 32 | logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm)) 33 | logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2)) 34 | 35 | self.summary_writer: SummaryWriter 36 | self.summary_writer.add_scalar(tag=self.ver_name_list[i], scalar_value=acc2, global_step=global_step, ) 37 | 38 | if acc2 > self.highest_acc_list[i]: 39 | self.highest_acc_list[i] = acc2 40 | logging.info( 41 | '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i])) 42 | results.append(acc2) 43 | 44 | def init_dataset(self, val_targets, data_dir, image_size): 45 | for name in val_targets: 46 | path = os.path.join(data_dir, name + ".bin") 47 | if os.path.exists(path): 48 | data_set = verification.load_bin(path, image_size) 49 | self.ver_list.append(data_set) 50 | self.ver_name_list.append(name) 51 | 52 | def __call__(self, num_update, backbone: torch.nn.Module): 53 | if self.rank == 0 and num_update > 0: 54 | backbone.eval() 55 | self.ver_test(backbone, num_update) 56 | backbone.train() 57 | 58 | 59 | class CallBackLogging(object): 60 | def __init__(self, frequent, total_step, batch_size, start_step=0, writer=None): 61 | self.frequent: int = frequent 62 | self.rank: int = distributed.get_rank() 63 | self.world_size: int = distributed.get_world_size() 64 | self.time_start = time.time() 65 | self.total_step: int = total_step 66 | self.start_step: int = start_step 67 | self.batch_size: int = batch_size 68 | self.writer = writer 69 | 70 | self.init = False 71 | self.tic = 0 72 | 73 | def __call__(self, 74 | global_step: int, 75 | loss: AverageMeter, 76 | epoch: int, 77 | fp16: bool, 78 | learning_rate: float, 79 | grad_scaler: torch.cuda.amp.GradScaler): 80 | if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0: 81 | if self.init: 82 | try: 83 | speed: float = self.frequent * self.batch_size / (time.time() - self.tic) 84 | speed_total = speed * self.world_size 85 | except ZeroDivisionError: 86 | speed_total = float('inf') 87 | 88 | #time_now = (time.time() - self.time_start) / 3600 89 | #time_total = time_now / ((global_step + 1) / self.total_step) 90 | #time_for_end = time_total - time_now 91 | time_now = time.time() 92 | time_sec = int(time_now - self.time_start) 93 | time_sec_avg = time_sec / (global_step - self.start_step + 1) 94 | eta_sec = time_sec_avg * (self.total_step - global_step - 1) 95 | time_for_end = eta_sec/3600 96 | if self.writer is not None: 97 | self.writer.add_scalar('time_for_end', time_for_end, global_step) 98 | self.writer.add_scalar('learning_rate', learning_rate, global_step) 99 | self.writer.add_scalar('loss', loss.avg, global_step) 100 | if fp16: 101 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 102 | "Fp16 Grad Scale: %2.f Required: %1.f hours" % ( 103 | speed_total, loss.avg, learning_rate, epoch, global_step, 104 | grad_scaler.get_scale(), time_for_end 105 | ) 106 | else: 107 | msg = "Speed %.2f samples/sec Loss %.4f LearningRate %.6f Epoch: %d Global Step: %d " \ 108 | "Required: %1.f hours" % ( 109 | speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end 110 | ) 111 | logging.info(msg) 112 | loss.reset() 113 | self.tic = time.time() 114 | else: 115 | self.init = True 116 | self.tic = time.time() 117 |
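# ---------------------------------------------------------------------------
# Minimal usage sketch for the two callbacks above, kept entirely in comments
# so it has no effect on the module. It assumes torch.distributed is already
# initialized; `loader`, `loss_meter`, `scaler`, `opt`, `backbone`, `epoch`,
# and every numeric value are illustrative placeholders rather than names from
# this repository; see train_v2.py for the wiring actually used in training.
#
#   logging_cb = CallBackLogging(frequent=100, total_step=100000, batch_size=128)
#   verify_cb = CallBackVerification(val_targets=["lfw", "cfp_fp"], rec_prefix="/path/to/rec")
#   for step, (img, labels) in enumerate(loader, start=1):
#       # ... forward / backward / optimizer step, updating loss_meter ...
#       logging_cb(step, loss_meter, epoch, fp16=False,
#                  learning_rate=opt.param_groups[0]["lr"], grad_scaler=scaler)
#       if step % 2000 == 0:
#           verify_cb(step, backbone)
# ---------------------------------------------------------------------------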
-------------------------------------------------------------------------------- /src/arcface_torch/utils/utils_config.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import os.path as osp 3 | 4 | 5 | def get_config(config_file): 6 | # assert config_file.startswith('configs/'), 'config file setting must start with configs/' 7 | temp_config_name = osp.basename(config_file) 8 | temp_module_name = osp.splitext(temp_config_name)[0] 9 | config = importlib.import_module(".configs.base", package='src.arcface_torch') 10 | cfg = config.config 11 | config = importlib.import_module(".configs.%s" % temp_module_name, package='src.arcface_torch') 12 | job_cfg = config.config 13 | cfg.update(job_cfg) 14 | if cfg.output is None: 15 | cfg.output = osp.join('work_dirs', temp_module_name) 16 | return cfg -------------------------------------------------------------------------------- /src/arcface_torch/utils/utils_distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from torch.utils.data import DistributedSampler as _DistributedSampler 9 | 10 | 11 | def setup_seed(seed, cuda_deterministic=True): 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | np.random.seed(seed) 15 | random.seed(seed) 16 | os.environ["PYTHONHASHSEED"] = str(seed) 17 | if cuda_deterministic: # slower, more reproducible 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | else: # faster, less reproducible 21 | torch.backends.cudnn.deterministic = False 22 | torch.backends.cudnn.benchmark = True 23 | 24 | 25 | def worker_init_fn(worker_id, num_workers, rank, seed): 26 | # The seed of each worker equals to 27 | # num_worker * rank + worker_id + user_seed 28 | worker_seed = num_workers * rank + worker_id + seed 29 | np.random.seed(worker_seed) 30 | random.seed(worker_seed) 31 | torch.manual_seed(worker_seed) 32 | 33 | 34 | def get_dist_info(): 35 | if dist.is_available() and dist.is_initialized(): 36 | rank = dist.get_rank() 37 | world_size = dist.get_world_size() 38 | else: 39 | rank = 0 40 | world_size = 1 41 | 42 | return rank, world_size 43 | 44 | 45 | def sync_random_seed(seed=None, device="cuda"): 46 | """Make sure different ranks share the same seed. 47 | All workers must call this function, otherwise it will deadlock. 48 | This method is generally used in `DistributedSampler`, 49 | because the seed should be identical across all processes 50 | in the distributed group. 51 | In distributed sampling, different ranks should sample non-overlapped 52 | data in the dataset. Therefore, this function is used to make sure that 53 | each rank shuffles the data indices in the same order based 54 | on the same seed. Then different ranks could use different indices 55 | to select non-overlapped data from the same data list. 56 | Args: 57 | seed (int, Optional): The seed. Default to None. 58 | device (str): The device where the seed will be put on. 59 | Default to 'cuda'. 60 | Returns: 61 | int: Seed to be used. 
62 | """ 63 | if seed is None: 64 | seed = np.random.randint(2**31) 65 | assert isinstance(seed, int) 66 | 67 | rank, world_size = get_dist_info() 68 | 69 | if world_size == 1: 70 | return seed 71 | 72 | if rank == 0: 73 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 74 | else: 75 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 76 | 77 | dist.broadcast(random_num, src=0) 78 | 79 | return random_num.item() 80 | 81 | 82 | class DistributedSampler(_DistributedSampler): 83 | def __init__( 84 | self, 85 | dataset, 86 | num_replicas=None, # world_size 87 | rank=None, # local_rank 88 | shuffle=True, 89 | seed=0, 90 | ): 91 | 92 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 93 | 94 | # In distributed sampling, different ranks should sample 95 | # non-overlapped data in the dataset. Therefore, this function 96 | # is used to make sure that each rank shuffles the data indices 97 | # in the same order based on the same seed. Then different ranks 98 | # could use different indices to select non-overlapped data from the 99 | # same data list. 100 | self.seed = sync_random_seed(seed) 101 | 102 | def __iter__(self): 103 | # deterministically shuffle based on epoch 104 | if self.shuffle: 105 | g = torch.Generator() 106 | # When :attr:`shuffle=True`, this ensures all replicas 107 | # use a different random ordering for each epoch. 108 | # Otherwise, the next iteration of this sampler will 109 | # yield the same ordering. 110 | g.manual_seed(self.epoch + self.seed) 111 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 112 | else: 113 | indices = torch.arange(len(self.dataset)).tolist() 114 | 115 | # add extra samples to make it evenly divisible 116 | # in case that indices is shorter than half of total_size 117 | indices = (indices * math.ceil(self.total_size / len(indices)))[ 118 | : self.total_size 119 | ] 120 | assert len(indices) == self.total_size 121 | 122 | # subsample 123 | indices = indices[self.rank : self.total_size : self.num_replicas] 124 | assert len(indices) == self.num_samples 125 | 126 | return iter(indices) 127 | -------------------------------------------------------------------------------- /src/arcface_torch/utils/utils_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | 6 | class AverageMeter(object): 7 | """Computes and stores the average and current value 8 | """ 9 | 10 | def __init__(self): 11 | self.val = None 12 | self.avg = None 13 | self.sum = None 14 | self.count = None 15 | self.reset() 16 | 17 | def reset(self): 18 | self.val = 0 19 | self.avg = 0 20 | self.sum = 0 21 | self.count = 0 22 | 23 | def update(self, val, n=1): 24 | self.val = val 25 | self.sum += val * n 26 | self.count += n 27 | self.avg = self.sum / self.count 28 | 29 | 30 | def init_logging(rank, models_root): 31 | if rank == 0: 32 | log_root = logging.getLogger() 33 | log_root.setLevel(logging.INFO) 34 | formatter = logging.Formatter("Training: %(asctime)s-%(message)s") 35 | handler_file = logging.FileHandler(os.path.join(models_root, "training.log")) 36 | handler_stream = logging.StreamHandler(sys.stdout) 37 | handler_file.setFormatter(formatter) 38 | handler_stream.setFormatter(formatter) 39 | log_root.addHandler(handler_file) 40 | log_root.addHandler(handler_stream) 41 | log_root.info('rank_id: %d' % rank) 42 | -------------------------------------------------------------------------------- /tests/face_swap.sh: 
-------------------------------------------------------------------------------- 1 | CKPT=checkpoints/diffswap.pth 2 | PORT=$(comm -23 <(seq 49152 65535 | sort) <(ss -Htan | awk '{print $4}' | cut -d':' -f2 | sort -u) | shuf | head -n 1) 3 | 4 | PYTHONPATH=./:$PYTHONPATH python3 -m torch.distributed.launch --nproc_per_node=1 --master_port=$PORT tests/faceswap_portrait.py $CKPT --save_img True -------------------------------------------------------------------------------- /tests/faceswap_portrait.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import cv2 4 | from omegaconf import OmegaConf 5 | import torch 6 | from ldm.util import instantiate_from_config 7 | from einops import rearrange, repeat 8 | import torch.nn.functional as F 9 | import pdb 10 | import torchvision 11 | from PIL import Image 12 | import pdb 13 | import numpy as np 14 | import torch.distributed as dist 15 | import builtins 16 | import datetime 17 | from pathlib import Path 18 | from ldm.data.portrait import Portrait 19 | import json 20 | from torchvision import transforms 21 | from tqdm import tqdm 22 | from ldm.models.diffusion.ddim import DDIMSampler 23 | 24 | 25 | import warnings 26 | warnings.filterwarnings("ignore") 27 | 28 | 29 | def setup_for_distributed(is_master): 30 | """ 31 | This function disables printing when not in master process 32 | """ 33 | builtin_print = builtins.print 34 | 35 | def print(*args, **kwargs): 36 | force = kwargs.pop('force', False) 37 | # force = force or (get_world_size() > 8) 38 | if is_master or force: 39 | now = datetime.datetime.now().time() 40 | builtin_print('[{}] '.format(now), end='') # print with time stamp 41 | builtin_print(*args, **kwargs) 42 | 43 | builtins.print = print 44 | 45 | 46 | def is_dist_avail_and_initialized(): 47 | if not dist.is_available(): 48 | return False 49 | if not dist.is_initialized(): 50 | return False 51 | return True 52 | 53 | 54 | def get_world_size(): 55 | if not is_dist_avail_and_initialized(): 56 | return 1 57 | return dist.get_world_size() 58 | 59 | 60 | def get_rank(): 61 | if not is_dist_avail_and_initialized(): 62 | return 0 63 | return dist.get_rank() 64 | 65 | 66 | def is_main_process(): 67 | return get_rank() == 0 68 | 69 | 70 | def save_on_master(*args, **kwargs): 71 | if is_main_process(): 72 | torch.save(*args, **kwargs) 73 | 74 | 75 | def init_distributed_mode(args): 76 | if args.dist_on_itp: 77 | args.rank = int(os.environ['OMPI_COMM_WORLD_RANK']) 78 | args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE']) 79 | args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) 80 | args.dist_url = "tcp://%s:%s" % (os.environ['MASTER_ADDR'], os.environ['MASTER_PORT']) 81 | os.environ['LOCAL_RANK'] = str(args.gpu) 82 | os.environ['RANK'] = str(args.rank) 83 | os.environ['WORLD_SIZE'] = str(args.world_size) 84 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 85 | elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 86 | args.rank = int(os.environ["RANK"]) 87 | args.world_size = int(os.environ['WORLD_SIZE']) 88 | args.gpu = int(os.environ['LOCAL_RANK']) 89 | elif 'SLURM_PROCID' in os.environ: 90 | args.rank = int(os.environ['SLURM_PROCID']) 91 | args.gpu = args.rank % torch.cuda.device_count() 92 | else: 93 | print('Not using distributed mode') 94 | setup_for_distributed(is_master=True) # hack 95 | args.distributed = False 96 | return 97 | args.distributed = True 98 | torch.cuda.set_device(args.gpu) 99 | args.dist_backend = 'nccl' 100 | print('| 
distributed init (rank {}): {}, gpu {}'.format( 101 | args.rank, args.dist_url, args.gpu), flush=True) 102 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,world_size=args.world_size, rank=args.rank) 103 | torch.distributed.barrier() 104 | setup_for_distributed(args.rank == 0) 105 | # above is all about distributed training and shouldn't be modified 106 | 107 | 108 | @torch.no_grad() 109 | def perform_swap(self, batch, ckpt, ddim_sampler = None, ddim_steps=200, ddim_eta=0., **kwargs): 110 | # now we swap the faces 111 | use_ddim = ddim_steps is not None 112 | 113 | log = dict() 114 | 115 | z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, 116 | return_first_stage_outputs=True, 117 | force_c_encode=True, 118 | return_original_cond=True, swap=True) 119 | N = x.size(0) 120 | 121 | b, h, w = z.shape[0], z.shape[2], z.shape[3] # 64 x 64 122 | 123 | for mask_key in ['mask']: 124 | mask = (1 - batch[mask_key].float())[:, None] 125 | mask = F.interpolate(mask, size=(h, w), mode='nearest') 126 | mask[mask > 0] = 1 127 | mask[mask <= 0] = 0 128 | 129 | 130 | with self.ema_scope("Plotting Inpaint"): 131 | if ddim_sampler is None: 132 | ddim_sampler = DDIMSampler(self) 133 | shape = (self.channels, self.image_size, self.image_size) 134 | samples, _ = ddim_sampler.sample(ddim_steps,N,shape,c, eta=ddim_eta, x0=z[:N], mask=mask, verbose=False,**kwargs) 135 | 136 | x_samples = self.decode_first_stage(samples.to(self.device)) 137 | 138 | gen_imgs = torch.clamp((x_samples+1.0)/2.0, min=0.0, max=1.0).cpu() 139 | gen_imgs = np.array(gen_imgs * 255).astype(np.uint8) 140 | gen_imgs = rearrange(gen_imgs, 'b c h w -> b h w c') 141 | 142 | # save swapped images 143 | for j in range(N): 144 | src = batch['src'][j][:-4] 145 | save_root = f'swap_res/{ckpt}_{ddim_sampler.tgt_scale}' 146 | os.makedirs(os.path.join(save_dir, save_root, src), exist_ok=True) 147 | Image.fromarray(gen_imgs[j]).save(os.path.join(save_dir, save_root, src, batch['target'][j])) 148 | 149 | 150 | if __name__ == '__main__': 151 | batch_size = 16 152 | num_workers = 8 153 | root_dir = 'data/portrait' 154 | save_dir = root_dir 155 | 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument('checkpoint', help='checkpoint to load') 158 | parser.add_argument('--save_img', type=bool, default=False) 159 | parser.add_argument('--tgt_scale', type=float, default=0.01) 160 | parser.add_argument('--world_size', default=1, type=int, 161 | help='number of distributed processes') 162 | parser.add_argument('--local_rank', default=-1, type=int) 163 | parser.add_argument('--dist_on_itp', action='store_true') 164 | parser.add_argument('--dist_url', default='env://', 165 | help='url used to set up distributed training') 166 | args = parser.parse_args() 167 | init_distributed_mode(args) 168 | 169 | setup_for_distributed(is_main_process()) 170 | device = get_rank() 171 | 172 | # get config 173 | config_path = 'configs/diffswap/default-project.yaml' 174 | config = OmegaConf.load(config_path) 175 | 176 | print('ready to build model', config.model) 177 | model = instantiate_from_config(config.model) 178 | model.init_from_ckpt(args.checkpoint) 179 | ckpt = os.path.basename(args.checkpoint).rsplit('.', 1)[0] 180 | 181 | model.eval() 182 | model = model.to(device) 183 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) 184 | 185 | print('model built') 186 | model.module.cond_stage_model.affine_crop = True 187 | model.module.cond_stage_model.swap = True 188 | 189 | num_tasks = 
dist.get_world_size() 190 | 191 | dataset = Portrait(root_dir) 192 | sampler_val = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=device, shuffle=True) 193 | dataloader = torch.utils.data.DataLoader( 194 | dataset, sampler=sampler_val, 195 | batch_size=batch_size, 196 | num_workers=num_workers, 197 | pin_memory=True, 198 | drop_last=False, 199 | ) 200 | 201 | ddim_sampler = DDIMSampler(model.module, tgt_scale=args.tgt_scale) 202 | 203 | print('start batch') 204 | for batch_idx, batch in enumerate(tqdm(dataloader)): 205 | for k, v in batch.items(): 206 | if isinstance(v, torch.Tensor): 207 | batch[k] = v.to(device) 208 | 209 | perform_swap(model.module, batch, ckpt, ddim_sampler) 210 | 211 | 212 | 213 | 214 | -------------------------------------------------------------------------------- /utils/blending/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wl-zhao/DiffSwap/8596b4d635e3d97621df688245b365bc4d3ae02a/utils/blending/__init__.py -------------------------------------------------------------------------------- /utils/blending/blending.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Greg Marcil 4 | CS280 UC Berkeley 5 | Practice implementation of Burt and Adelson's "A Multiresolution Spline With 6 | Application to Image Mosaics 7 | """ 8 | 9 | import numpy as np 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import scipy 13 | import scipy.ndimage as ndimage 14 | import cv2 15 | 16 | def subtract(a,b): 17 | return a - b 18 | # Add something weird 19 | def im_reduce(img): 20 | ''' 21 | Apply gaussian filter and drop every other pixel 22 | ''' 23 | filter = 1.0 / 20 * np.array([1, 5, 8, 5, 1]) 24 | lowpass = ndimage.filters.correlate1d(img, filter, 0) 25 | lowpass = ndimage.filters.correlate1d(lowpass, filter, 1) 26 | im_reduced = lowpass[::2, ::2, ...] 27 | return im_reduced 28 | 29 | def add(a, b): 30 | return a + b 31 | def im_expand(img, template): 32 | ''' 33 | Re-expand a reduced image by interpolating according to gaussian kernel 34 | Include template parameter to match size, easy way to avoid off by 1 errors 35 | re-expanding a previous layer that may have had odd or even dimension 36 | ''' 37 | # y_temp, x_temp = template.shape[:2] 38 | # im_expanded = np.zeros((y_temp, x_temp) + template.shape[2:], img.dtype) 39 | im_expanded = np.zeros(template.shape, img.dtype) 40 | im_expanded[::2, ::2, ...] = img 41 | 42 | filter = 1.0 / 10 * np.array([1, 5, 8, 5, 1]) 43 | lowpass = ndimage.filters.correlate1d( 44 | im_expanded, filter, 0, mode="constant") 45 | lowpass = ndimage.filters.correlate1d(lowpass, filter, 1, mode="constant") 46 | return lowpass 47 | 48 | 49 | 50 | def gaussian_pyramid(image, layers=7): 51 | ''' 52 | pyramid of increasingly strongly low-pass filtered images, 53 | shrunk 2x h and w each layer 54 | ''' 55 | pyr = [image] 56 | temp_img = image 57 | for i in range(layers): 58 | temp_img = im_reduce(temp_img) 59 | pyr.append(temp_img) 60 | return pyr 61 | 62 | 63 | def laplacian_pyramid(gaussian_pyramid): 64 | ''' 65 | laplacian pyramid is a band-pass filter pyramid, calculated by the 66 | difference between subsequent gaussian pyramid layers, terminating with top 67 | layer of gaussian. 
Laplacian pyramid can be summed to give back original 68 | image 69 | ''' 70 | pyr = [] 71 | for i in range(len(gaussian_pyramid) - 1): 72 | g_k = gaussian_pyramid[i] 73 | g_k_plus_1 = gaussian_pyramid[i + 1] 74 | g_k_1_expand = im_expand(g_k_plus_1, g_k) 75 | laplacian = g_k - g_k_1_expand 76 | pyr.append(laplacian) 77 | 78 | pyr.append(gaussian_pyramid[-1]) 79 | return pyr 80 | 81 | 82 | def laplacian_collapse(pyr): 83 | ''' 84 | Rejoin all levels of a laplacian pyramid. As the pyramid is a spanning set 85 | of band-pass filter outputs (all frequencies represented once and only 86 | once), joining all levels will give back the original image, modulo 87 | compression loss 88 | ''' 89 | ''' Start with lowest pass data, top of pyramid ''' 90 | partial_img = pyr[-1] 91 | for i in range(len(pyr) - 1): 92 | next_lowest = pyr[-2 - i] 93 | expanded_partial = im_expand(partial_img, next_lowest) 94 | partial_img = expanded_partial + next_lowest 95 | return partial_img 96 | 97 | 98 | def laplacian_pyr_join(pyr1, pyr2): 99 | pyr = [] 100 | for i in range(len(pyr1)): 101 | left = pyr1[i] 102 | right = pyr2[i] 103 | layer = np.zeros(left.shape, left.dtype) 104 | _, x, _ = left.shape 105 | ''' even width ''' 106 | half = x // 2 107 | ''' assign halves ''' 108 | layer[:, :half, ...] = left[:, :half, ...] 109 | layer[:, -half:, ...] = right[:, -half:, ...] 110 | pyr.append(layer) 111 | return pyr 112 | 113 | 114 | 115 | def main(): 116 | plt.ion() 117 | im1 = matplotlib.image.imread('tests/blending/pic/apple.jpg') 118 | im2 = matplotlib.image.imread('tests/blending/pic/orange.jpg') 119 | im1, im2 = np.uint32(im1), np.uint32(im2) 120 | 121 | gp_1, gp_2 = [gaussian_pyramid(im) for im in [im1, im2]] 122 | lp_1, lp_2 = [laplacian_pyramid(gp) for gp in [gp_1, gp_2]] 123 | lp_join = laplacian_pyr_join(lp_1, lp_2) 124 | im_join = laplacian_collapse(lp_join) 125 | 126 | np.clip(im_join, 0, 255, out=im_join) 127 | im_join = np.uint8(im_join) 128 | plt.imsave('tests/blending/pic/orapple.jpg', im_join) 129 | return 0 130 | 131 | 132 | if __name__ == '__main__': 133 | main() -------------------------------------------------------------------------------- /utils/blending/blending_mask.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | Greg Marcil 4 | CS280 UC Berkeley 5 | Practice implementation of Burt and Adelson's "A Multiresolution Spline With 6 | Application to Image Mosaics 7 | """ 8 | 9 | import numpy as np 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | import scipy 13 | import scipy.ndimage as ndimage 14 | import cv2 15 | 16 | def subtract(a,b): 17 | return a - b 18 | # Add something weird 19 | def im_reduce(img): 20 | ''' 21 | Apply gaussian filter and drop every other pixel 22 | ''' 23 | filter = 1.0 / 20 * np.array([1, 5, 8, 5, 1]) 24 | lowpass = ndimage.filters.correlate1d(img, filter, 0) 25 | lowpass = ndimage.filters.correlate1d(lowpass, filter, 1) 26 | im_reduced = lowpass[::2, ::2, ...] 
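# Note: the separable 5-tap kernel [1, 5, 8, 5, 1] / 20 above sums to one and
# matches Burt & Adelson's Gaussian-like generating kernel (a = 0.4); the
# [::2, ::2] slice then keeps every other row and column, so each reduce step
# halves the resolution for the next pyramid level.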
27 | return im_reduced 28 | 29 | def add(a, b): 30 | return a + b 31 | def im_expand(img, template): 32 | ''' 33 | Re-expand a reduced image by interpolating according to gaussian kernel 34 | Include template parameter to match size, easy way to avoid off by 1 errors 35 | re-expanding a previous layer that may have had odd or even dimension 36 | ''' 37 | # y_temp, x_temp = template.shape[:2] 38 | # im_expanded = np.zeros((y_temp, x_temp) + template.shape[2:], img.dtype) 39 | im_expanded = np.zeros(template.shape, img.dtype) 40 | im_expanded[::2, ::2, ...] = img 41 | 42 | filter = 1.0 / 10 * np.array([1, 5, 8, 5, 1]) 43 | lowpass = ndimage.filters.correlate1d( 44 | im_expanded, filter, 0, mode="constant") 45 | lowpass = ndimage.filters.correlate1d(lowpass, filter, 1, mode="constant") 46 | return lowpass 47 | 48 | 49 | 50 | def gaussian_pyramid(image, layers=7): 51 | ''' 52 | pyramid of increasingly strongly low-pass filtered images, 53 | shrunk 2x h and w each layer 54 | ''' 55 | pyr = [image] 56 | temp_img = image 57 | for i in range(layers): 58 | temp_img = im_reduce(temp_img) 59 | pyr.append(temp_img) 60 | return pyr 61 | 62 | 63 | def laplacian_pyramid(gaussian_pyramid): 64 | ''' 65 | laplacian pyramid is a band-pass filter pyramid, calculated by the 66 | difference between subsequent gaussian pyramid layers, terminating with top 67 | layer of gaussian. Laplacian pyramid can be summed to give back original 68 | image 69 | ''' 70 | pyr = [] 71 | for i in range(len(gaussian_pyramid) - 1): 72 | g_k = gaussian_pyramid[i] 73 | g_k_plus_1 = gaussian_pyramid[i + 1] 74 | g_k_1_expand = im_expand(g_k_plus_1, g_k) 75 | laplacian = g_k - g_k_1_expand 76 | pyr.append(laplacian) 77 | 78 | pyr.append(gaussian_pyramid[-1]) 79 | return pyr 80 | 81 | 82 | def laplacian_collapse(pyr): 83 | ''' 84 | Rejoin all levels of a laplacian pyramid. 
As the pyramid is a spanning set 85 | of band-pass filter outputs (all frequencies represented once and only 86 | once), joining all levels will give back the original image, modulo 87 | compression loss 88 | ''' 89 | ''' Start with lowest pass data, top of pyramid ''' 90 | partial_img = pyr[-1] 91 | for i in range(len(pyr) - 1): 92 | next_lowest = pyr[-2 - i] 93 | expanded_partial = im_expand(partial_img, next_lowest) 94 | partial_img = expanded_partial + next_lowest 95 | return partial_img 96 | 97 | 98 | def laplacian_pyr_join(pyr1, pyr2, mask_gp): 99 | pyr = [] 100 | for i in range(len(pyr1)): 101 | mask = np.array([mask_gp[i], mask_gp[i], mask_gp[i]]).transpose(1, 2, 0) 102 | layer = np.multiply(pyr1[i], mask) + np.multiply(pyr2[i], 1 - mask) 103 | pyr.append(layer) 104 | return pyr 105 | 106 | 107 | 108 | def main(): 109 | plt.ion() 110 | im1 = cv2.imread('test.png') 111 | im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB) 112 | im2 = cv2.imread('data/portrait_old/align/target/0303.png') 113 | im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB) 114 | 115 | mask = matplotlib.image.imread('mask.png') 116 | # mask = matplotlib.image.imread('data/portrait_old/mask/Repoparsing_mask/0303.png') 117 | im1, im2 = np.int32(im1), np.int32(im2) 118 | mask = np.uint8(mask) 119 | 120 | gp_1, gp_2 = [gaussian_pyramid(im) for im in [im1, im2]] 121 | mask_gp = [cv2.resize(mask, (gp.shape[1], gp.shape[0])) for gp in gp_1] 122 | 123 | lp_1, lp_2 = [laplacian_pyramid(gp) for gp in [gp_1, gp_2]] 124 | lp_join = laplacian_pyr_join(lp_1, lp_2, mask_gp) 125 | im_join = laplacian_collapse(lp_join) 126 | 127 | np.clip(im_join, 0, 255, out=im_join) 128 | im_join = np.uint8(im_join) 129 | plt.imsave('test7.png', im_join) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() -------------------------------------------------------------------------------- /utils/portrait.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | import pickle 4 | from PIL import Image 5 | import io 6 | from torchvision import transforms 7 | import cv2 8 | import pdb 9 | import PIL 10 | import numpy as np 11 | from einops import rearrange 12 | from scipy.spatial import ConvexHull 13 | import random 14 | import json 15 | from math import floor 16 | import os 17 | import warnings 18 | warnings.filterwarnings("ignore") 19 | 20 | class Portrait(Dataset): 21 | def __init__(self, root, size = 256, base_res = 256, flip=False, 22 | interpolation="bicubic", dilate=False, convex_hull=True): 23 | super().__init__() 24 | self.size = size 25 | self.root = root 26 | self.base_res = base_res 27 | self.error_img = json.load(open(f'{root}/error_img.json')) 28 | 29 | self.lmk_path = f'{root}/landmark/landmark_256.pkl' 30 | self.landmarks = pickle.load(open(self.lmk_path,'rb')) 31 | 32 | self.src_list = os.listdir(f'{root}/source') 33 | self.src_list = [x for x in self.src_list if x not in self.error_img['source']] 34 | self.src_list.sort() 35 | self.tgt_list = os.listdir(f'{root}/target') 36 | self.tgt_list = [x for x in self.tgt_list if x not in self.error_img['target']] 37 | self.tgt_list.sort() 38 | 39 | print(f'len(self.src_list): {len(self.src_list)}') 40 | self.affine_thetas = json.load(open(f'{root}/affine_theta.json')) 41 | 42 | self.interpolation = {"linear": PIL.Image.LINEAR, 43 | "bilinear": PIL.Image.BILINEAR, 44 | "bicubic": PIL.Image.BICUBIC, 45 | "lanczos": PIL.Image.LANCZOS, 46 | }[interpolation] 47 | self.convex_hull = convex_hull 48 | 49 | 50 | 
all_indices = np.arange(0, 68) 51 | self.landmark_indices = { 52 | # 'face': all_indices[:17].tolist() + all_indices[17:27].tolist(), 53 | 'l_eye': all_indices[36:42].tolist(), 54 | 'r_eye': all_indices[42:48].tolist(), 55 | 'nose': all_indices[27:36].tolist(), 56 | 'mouth': all_indices[48:68].tolist(), 57 | } 58 | self.dilate = dilate 59 | if dilate: 60 | self.dilate_kernel = np.ones((11, 11), np.uint8) 61 | 62 | def __len__(self): 63 | return len(self.src_list) * len(self.tgt_list) # 9 * 1039 = 9351 64 | 65 | def __getitem__(self, index): # index: 0 - 9350 66 | src_index = floor(index / len(self.tgt_list)) # 0 - 8 67 | tgt_index = index - len(self.tgt_list) * src_index # 0 - 1038 68 | batch = {} 69 | 70 | for type in ['source', 'target']: 71 | if type == 'source': 72 | image = Image.open(os.path.join(f'{self.root}/align/{type}', self.src_list[src_index])).convert('RGB') 73 | affine_theta = np.array(self.affine_thetas[type][self.src_list[src_index]], dtype=np.float32) 74 | landmark = torch.tensor(self.landmarks[type][self.src_list[src_index]]) / self.base_res 75 | elif type == 'target': 76 | image = Image.open(os.path.join(f'{self.root}/align/{type}', self.tgt_list[tgt_index])).convert('RGB') 77 | affine_theta = np.array(self.affine_thetas[type][self.tgt_list[tgt_index]], dtype=np.float32) 78 | landmark = torch.tensor(self.landmarks[type][self.tgt_list[tgt_index]]) / self.base_res 79 | 80 | image = image.resize((self.size, self.size), resample=self.interpolation) 81 | image = np.array(image).astype(np.uint8) 82 | image = (image / 127.5 - 1.0).astype(np.float32) 83 | 84 | if type == 'source': 85 | batch['mask_organ_src'] = self.mask_organ_src(landmark) 86 | batch['image_src'] = image 87 | # batch['image_src']'s identity 88 | batch['affine_theta_src'] = affine_theta 89 | batch['src'] = self.src_list[src_index] 90 | else: 91 | batch['image'] = image 92 | # batch['image']'s identity 93 | batch['landmark'] = landmark 94 | batch['affine_theta'] = affine_theta 95 | batch['target'] = self.tgt_list[tgt_index] 96 | 97 | if self.convex_hull: 98 | mask_dict = self.extract_convex_hulls(batch['landmark']) 99 | batch.update(mask_dict) 100 | return batch 101 | 102 | def mask_organ_src(self, landmark): 103 | mask_organ = [] 104 | for key, indices in self.landmark_indices.items(): 105 | mask_key = self.extract_convex_hull(landmark[indices]) 106 | if self.dilate: 107 | # mask_key = mask_key[:, :, None] 108 | # mask_key = repeat(mask_key, 'h w -> h w k', k=3) 109 | # print(mask_key.shape, type(mask_key)) 110 | mask_key = mask_key.astype(np.uint8) 111 | mask_key = cv2.dilate(mask_key, self.dilate_kernel, iterations=1) 112 | mask_organ.append(mask_key) 113 | return np.stack(mask_organ) 114 | 115 | 116 | def extract_convex_hulls(self, landmark): 117 | mask_dict = {} 118 | mask_organ = [] 119 | for key, indices in self.landmark_indices.items(): 120 | mask_key = self.extract_convex_hull(landmark[indices]) 121 | if self.dilate: 122 | # mask_key = mask_key[:, :, None] 123 | # mask_key = repeat(mask_key, 'h w -> h w k', k=3) 124 | # print(mask_key.shape, type(mask_key)) 125 | mask_key = mask_key.astype(np.uint8) 126 | mask_key = cv2.dilate(mask_key, self.dilate_kernel, iterations=1) 127 | mask_organ.append(mask_key) 128 | mask_organ = np.stack(mask_organ) # (4, 256, 256) 129 | mask_dict['mask_organ'] = mask_organ 130 | mask_dict['mask'] = self.extract_convex_hull(landmark) 131 | return mask_dict 132 | 133 | def extract_convex_hull(self, landmark): 134 | landmark = landmark * self.size 135 | hull = 
ConvexHull(landmark) 136 | image = np.zeros((self.size, self.size)) 137 | points = [landmark[hull.vertices, :1], landmark[hull.vertices, 1:]] 138 | points = np.concatenate(points, axis=-1).astype('int32') 139 | mask = cv2.fillPoly(image, pts=[points], color=(255,255,255)) 140 | mask = mask > 0 141 | return mask 142 | 143 | def visualize(batch): 144 | n = len(batch['image']) 145 | os.makedirs('ldm/data/debug', exist_ok=True) 146 | for i in range(n): 147 | print(i) 148 | print(batch['src'][i], batch['target'][i]) 149 | image = (batch['image'][i] + 1) / 2 * 255 150 | image = rearrange(image, 'h w c -> c h w') 151 | image = image.numpy().transpose(1, 2, 0).astype('uint8').copy() 152 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 153 | cv2.imwrite(f'ldm/data/debug/{i}_tgt.png', image) 154 | 155 | image_src = (batch['image_src'][i] + 1) / 2 * 255 156 | image_src = rearrange(image_src, 'h w c -> c h w') 157 | image_src = image_src.numpy().transpose(1, 2, 0).astype('uint8').copy() 158 | image_src = cv2.cvtColor(image_src, cv2.COLOR_RGB2BGR) 159 | cv2.imwrite(f'ldm/data/debug/{i}_src.png', image_src) 160 | 161 | lmk = (batch['landmark'][i] * image.shape[0]).numpy().astype('int32') 162 | for k in range(68): 163 | image = cv2.circle(image, (lmk[k, 0], lmk[k, 1]), 3, (255, 0, 255), thickness=-1) 164 | cv2.imwrite(f'ldm/data/debug/{i}_lmk.png', image) 165 | 166 | mask = (batch['mask'][i].numpy() * 255).astype('uint8') #[:, :, None] 167 | mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) 168 | cv2.imwrite(f'ldm/data/debug/{i}_mask.png', mask) 169 | 170 | mask_organ = (batch['mask_organ'][i][0].numpy() * 255).astype('uint8') #[:, :, None] 171 | mask_organ = cv2.cvtColor(mask_organ, cv2.COLOR_GRAY2BGR) 172 | cv2.imwrite(f'ldm/data/debug/{i}_mask_organ.png', mask_organ) 173 | 174 | mask_organ_src = (batch['mask_organ_src'][i][1].numpy() * 255).astype('uint8') #[:, :, None] 175 | mask_organ_src = cv2.cvtColor(mask_organ_src, cv2.COLOR_GRAY2BGR) 176 | cv2.imwrite(f'ldm/data/debug/{i}_mask_organ_src.png', mask_organ_src) 177 | 178 | if __name__ == '__main__': 179 | dataset = Portrait('data/portrait') 180 | dataloader = DataLoader(dataset, batch_size=8, shuffle=True) 181 | 182 | for batch in dataloader: 183 | visualize(batch) 184 | break --------------------------------------------------------------------------------