├── .gitignore ├── LICENSE ├── README.md ├── cal_orientation.py ├── checkpoints └── MichiGAN │ ├── download_model_G.sh │ ├── download_model_IG.sh │ └── download_model_SIG.sh ├── data ├── __init__.py ├── ab_count.npy ├── base_dataset.py ├── clear_ab_count.npy ├── custom_dataset.py ├── image_folder.py ├── pix2pix_dataset.py ├── selected_img_names_10000_29999.txt ├── single_result.png ├── special_img_names_10000_29999.txt ├── teaser.jpg └── val_image_list.txt ├── datasets ├── FFHQ_demo │ ├── images │ │ ├── 59144.jpg │ │ ├── 60429.jpg │ │ └── 67172.jpg │ ├── images_recon │ │ └── 67172.jpg │ ├── labels │ │ ├── 59144.png │ │ ├── 60429.png │ │ └── 67172.png │ └── orients │ │ ├── 59144_orient_dense.png │ │ ├── 60429_orient_dense.png │ │ └── 67172_orient_dense.png └── FFHQ_single │ ├── val_dense_orients │ └── 67172_orient_dense.png │ ├── val_images │ └── 67172.jpg │ └── val_labels │ └── 67172.png ├── demo.py ├── inference.py ├── inference_samples └── inpaint_fake_image.jpg ├── models ├── __init__.py ├── networks │ ├── MaskGAN_networks.py │ ├── __init__.py │ ├── architecture.py │ ├── base_network.py │ ├── discriminator.py │ ├── encoder.py │ ├── generator.py │ ├── loss.py │ ├── normalization.py │ ├── partialconv2d.py │ └── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── batchnorm_reimpl.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py └── pix2pix_model.py ├── options ├── __init__.py ├── base_options.py ├── demo_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── train.py ├── trainers ├── __init__.py └── pix2pix_trainer.py ├── ui ├── __init__.py ├── mouse_event.py ├── ui4.py ├── ui_buttons.py └── ui_palette.py ├── ui_util ├── __init__.py ├── cal_orient_stroke.py └── config.py └── util ├── __init__.py ├── coco.py ├── html.py ├── iter_counter.py ├── iter_counter_ms.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | # datasets/ 3 | *__pycache__ 4 | .cache/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zhentao Tan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # MichiGAN: Multi-Input-Conditioned Hair Image Generation for Portrait Editing 3 | ![Teaser](data/teaser.jpg) 4 | ### [Paper](https://mlchai.com/files/tan2020michigan.pdf) 5 | 6 | Zhentao Tan, [Menglei Chai](https://mlchai.com/), [Dongdong Chen](http://www.dongdongchen.bid/), [Jing Liao](https://liaojing.github.io/html/index.html), [Qi Chu](https://scholar.google.com/citations?user=JZjOMdsAAAAJ&hl=en), Lu Yuan, [Sergey Tulyakov](http://www.stulyakov.com/), [Nenghai Yu](https://scholar.google.com/citations?user=7620QAMAAAAJ&hl=zh-CN) 7 | 8 | ## Abstract 9 | >Despite the recent success of face image generation with GANs, conditional hair editing remains challenging due to the under-explored complexity of its geometry and appearance. In this paper, we present MichiGAN (Multi-Input-Conditioned Hair Image GAN), a novel conditional image generation method for interactive portrait hair manipulation. To provide user control over every major hair visual factor, we explicitly disentangle hair into four orthogonal attributes, including shape, structure, appearance, and background. For each of them, we design a corresponding condition module to represent, process, and convert user inputs, and modulate the image generation pipeline in ways that respect the natures of different visual attributes. All these condition modules are integrated with the backbone generator to form the final end-to-end network, which allows fully-conditioned hair generation from multiple user inputs. Upon it, we also build an interactive portrait hair editing system that enables straightforward manipulation of hair by projecting intuitive and high-level user inputs such as painted masks, guiding strokes, or reference photos to well-defined condition representations. Through extensive experiments and evaluations, we demonstrate the superiority of our method regarding both result quality and user controllability. 10 | 11 | 12 | ## Installation 13 | 14 | Clone this repo. 15 | ```bash 16 | git clone https://github.com/tzt101/MichiGAN.git 17 | cd MichiGAN/ 18 | ``` 19 | 20 | This code requires PyTorch 1.0 and Python 3+. Please install the dependencies with 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | 25 | Please download the Synchronized-BatchNorm-PyTorch repo. 26 | ``` 27 | cd models/networks/ 28 | git clone https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 29 | cp -rf Synchronized-BatchNorm-PyTorch/sync_batchnorm . 30 | cd ../../ 31 | ``` 32 | 33 | ## Dataset Preparation 34 | 35 | The FFHQ dataset can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1jI0EThBSgVRB_bgPype8pg) with the extraction code `ichc`, or from [OneDrive (RAR)](https://mailustceducn-my.sharepoint.com/:u:/g/personal/tzt_mail_ustc_edu_cn/ES2Ig_Nmmh1Jglv_T1VJzBgBbbxgdAnjDVVhJU1SzqIugA) or [OneDrive (ZIP)](https://mailustceducn-my.sharepoint.com/:u:/g/personal/tzt_mail_ustc_edu_cn/ES_I8Z09JZVJocoFvo-1aKEB3Ah7uI9C56JuMPMZhpXNqQ?e=9n61Lx). You should specify the dataset root through `--data_dir`. Please follow the [license](https://github.com/NVlabs/ffhq-dataset) when you use the FFHQ dataset. 36 | 37 | ## Generating Images Using Pretrained Model 38 | 39 | Once the dataset is ready, the result images can be generated using the pretrained models. 40 | 41 | 1. 
Download the pretrained models from the [Google Drive Folder](https://drive.google.com/open?id=1Vxilcb82ax1Zlwy9wqHRu5-DCJuZFc_C) and save them in `checkpoints/MichiGAN/`. You can also download the pretrained models with the following commands: 42 | ```bash 43 | cd checkpoints/MichiGAN/ 44 | bash download_model_G.sh 45 | bash download_model_IG.sh 46 | bash download_model_SIG.sh 47 | ``` 48 | 49 | 2. Generate a single image using the pretrained model. 50 | ```bash 51 | python inference.py --name MichiGAN --gpu_ids 0 --inference_ref_name 67172 --inference_tag_name 67172 --inference_orient_name 67172 --netG spadeb --which_epoch 50 --use_encoder --noise_background --expand_mask_be --expand_th 5 --use_ig --load_size 512 --crop_size 512 --add_feat_zeros --data_dir [path_to_dataset] 52 | ``` 53 | 3. The output images are stored in `./inference_samples/` by default. If you just want to test this single image without downloading the whole dataset, please set `--data_dir ./datasets/FFHQ_single/`. We provide a sample image (67172) there. 54 | 55 | ## Training New Models 56 | 57 | New models can be trained with the following command. 58 | 59 | ```bash 60 | python train.py --name [name_experiment] --batchSize 8 --no_confidence_loss --gpu_ids 0,1,2,3,4,5,6,7 --no_style_loss --no_rgb_loss --no_content_loss --use_encoder --wide_edge 2 --no_background_loss --noise_background --random_expand_mask --use_ig --load_size 568 --crop_size 512 --data_dir [path_to_dataset] --checkpoints_dir ./checkpoints 61 | ``` 62 | `[name_experiment]` is the name of the directory in which the checkpoints are saved. If you want to train the model with the orientation inpainting model (with the option `--use_ig`), please first download the pretrained inpainting models from the [Google Drive Folder](https://drive.google.com/open?id=1Vxilcb82ax1Zlwy9wqHRu5-DCJuZFc_C) and save them in `./checkpoints/[name_experiment]/`. 63 | 64 | ## UI 65 | 66 | You can directly run demo.py to use the interactive system. The UI code borrows from [MaskGAN](https://github.com/switchablenorms/CelebAMask-HQ.git). 67 | 68 | ## Orientation for New Dataset 69 | 70 | Once an image and its corresponding hair mask are provided, you can use the following command to extract the dense hair orientation map. 71 | ```bash 72 | python cal_orientation.py --image_path [your image path] --hairmask_path [your hair mask path] --orientation_root [save root] 73 | ``` 74 | For ease of use, we have rewritten the original C++ code in Python. The results of this code are slightly different from the C++ version, but this does not affect usage. 75 | 76 | ## Code Structure 77 | 78 | - `train.py`, `inference.py`: the entry points for training and inference. 79 | - `trainers/pix2pix_trainer.py`: harnesses and reports the progress of training. 80 | - `models/pix2pix_model.py`: creates the networks and computes the losses. 81 | - `models/networks/`: defines the architecture of all models. 82 | - `options/`: creates option lists using the `argparse` package. More options are dynamically added in other files as well. Please see the section below. 83 | - `data/`: defines the classes for loading data. 84 | 85 | ## Citation 86 | If you use this code for your research, please cite our paper. 
87 | ``` 88 | @article{tan2020michigan, 89 | title={MichiGAN: Multi-Input-Conditioned Hair Image Generation for Portrait Editing}, 90 | author={Zhentao Tan, Menglei Chai, Dongdong Chen, Jing Liao, Qi Chu, Lu Yuan, Sergey Tulyakov and Nenghai Yu}, 91 | journal={ACM Transactions on Graphics (TOG)}, 92 | volume={39}, 93 | number={4}, 94 | pages={1--13}, 95 | year={2020}, 96 | publisher={ACM New York, NY, USA} 97 | } 98 | ``` 99 | 100 | ## Acknowledgments 101 | This code borrows heavily from [SPADE](https://github.com/NVlabs/SPADE.git). We thank Jiayuan Mao for his [Synchronized Batch Normalization](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) code. 102 | -------------------------------------------------------------------------------- /cal_orientation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from PIL import Image 6 | import torchvision.transforms as transforms 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import os 10 | import cv2 11 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 12 | 13 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 14 | parser.add_argument('--image_path', type=str, default='56000.jpg', help='Path to image') 15 | parser.add_argument('--hairmask_path',type=str, default='56000.png', help='Path to hair mask') 16 | parser.add_argument('--orientation_root', type=str, default='./', help='Root to save hair orientation map') 17 | 18 | def DoG_fn(kernel_size, channel_in, channel_out, theta): 19 | # params 20 | sigma_h = nn.Parameter(torch.ones(channel_out) * 1.0, requires_grad=False) 21 | sigma_l = nn.Parameter(torch.ones(channel_out) * 2.0, requires_grad=False) 22 | sigma_y = nn.Parameter(torch.ones(channel_out) * 2.0, requires_grad=False) 23 | 24 | # Bounding box 25 | xmax = kernel_size // 2 26 | ymax = kernel_size // 2 27 | xmin = -xmax 28 | ymin = -ymax 29 | ksize = xmax - xmin + 1 30 | y_0 = torch.arange(ymin, ymax+1) 31 | y = y_0.view(1, -1).repeat(channel_out, channel_in, ksize, 1).float() 32 | x_0 = torch.arange(xmin, xmax+1) 33 | x = x_0.view(-1, 1).repeat(channel_out, channel_in, 1, ksize).float() # [channel_out, channelin, kernel, kernel] 34 | 35 | # Rotation 36 | # don't need to expand, use broadcasting, [64, 1, 1, 1] + [64, 3, 7, 7] 37 | x_theta = x * torch.cos(theta.view(-1, 1, 1, 1)) + y * torch.sin(theta.view(-1, 1, 1, 1)) 38 | y_theta = -x * torch.sin(theta.view(-1, 1, 1, 1)) + y * torch.cos(theta.view(-1, 1, 1, 1)) 39 | 40 | gb = (torch.exp(-.5 * (x_theta ** 2 / sigma_h.view(-1, 1, 1, 1) ** 2 + y_theta ** 2 / sigma_y.view(-1, 1, 1, 1) ** 2))/sigma_h \ 41 | - torch.exp(-.5 * (x_theta ** 2 / sigma_l.view(-1, 1, 1, 1) ** 2 + y_theta ** 2 / sigma_y.view(-1, 1, 1, 1) ** 2))/sigma_l) \ 42 | / (1.0/sigma_h - 1.0/sigma_l) 43 | 44 | return gb 45 | 46 | # L1 loss of orientation map 47 | class orient(nn.Module): 48 | def __init__(self, channel_in=1, channel_out=1, stride=1, padding=8): 49 | super(orient, self).__init__() 50 | self.criterion = nn.L1Loss() 51 | self.channel_in = channel_in 52 | self.channel_out = channel_out 53 | self.stride = stride 54 | self.padding = padding 55 | self.filter = DoG_fn 56 | 57 | self.numKernels = 32 58 | self.kernel_size = 17 59 | 60 | def calOrientation(self, image, mask=None): 61 | resArray = [] 62 | # filter the image with different orientations 63 | for iOrient in range(self.numKernels): 64 | theta = 
nn.Parameter(torch.ones(self.channel_out)*(math.pi*iOrient/self.numKernels), requires_grad=False) 65 | filterKernel = self.filter(self.kernel_size, self.channel_in, self.channel_out, theta) 66 | filterKernel = filterKernel.float() 67 | response = F.conv2d(image, filterKernel, stride=self.stride, padding=self.padding) 68 | resArray.append(response.clone()) 69 | 70 | resTensor = resArray[0] 71 | for iOrient in range(1, self.numKernels): 72 | resTensor = torch.cat([resTensor, resArray[iOrient]], dim=1) 73 | 74 | # argmax the response 75 | resTensor[resTensor < 0] = 0 76 | maxResTensor = torch.argmax(resTensor, dim=1).float() # range from 0 to 31 77 | confidenceTensor = torch.max(resTensor, dim=1)[0] 78 | confidenceTensor = torch.unsqueeze(confidenceTensor, 1) 79 | 80 | return maxResTensor, confidenceTensor 81 | 82 | if __name__ == '__main__': 83 | args = parser.parse_args() 84 | # mkdir orientation root 85 | if not os.path.exists(args.orientation_root): 86 | os.mkdir(args.orientation_root) 87 | 88 | # Get structure 89 | image = Image.open(args.image_path) 90 | mask = np.array(Image.open(args.hairmask_path)) 91 | if np.max(mask) > 1: 92 | mask = (mask > 130) * 1 93 | trans_image = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) 94 | image_tensor = trans_image(image) 95 | image_tensor = torch.unsqueeze(image_tensor, 0) 96 | cal_orient = orient() 97 | fake_image = (image_tensor + 1) / 2.0 * 255 98 | gray = 0.299 * fake_image[:, 0, :, :] + 0.587 * fake_image[:, 1, :, :] + 0.144 * fake_image[:, 2, :, :] 99 | gray = torch.unsqueeze(gray, 1) 100 | orient_tensor, confidence_tensor = cal_orient.calOrientation(gray) 101 | orient_tensor = orient_tensor * math.pi / 31 * 2 102 | mask_tensor = torch.from_numpy(mask).float() 103 | flow_x = torch.cos(orient_tensor) * confidence_tensor * mask_tensor 104 | flow_y = torch.sin(orient_tensor) * confidence_tensor * mask_tensor 105 | flow_x = torch.from_numpy(cv2.GaussianBlur(flow_x.numpy().squeeze(), (0, 0), 4)) 106 | flow_y = torch.from_numpy(cv2.GaussianBlur(flow_y.numpy().squeeze(), (0, 0), 4)) 107 | orient_tensor = torch.atan2(flow_y, flow_x) * 0.5 108 | orient_tensor[orient_tensor < 0] += math.pi 109 | orient_np = orient_tensor.numpy().squeeze() * 255. / math.pi * mask 110 | orient_save = Image.fromarray(np.uint8(orient_np)) 111 | orient_save.save(os.path.join(args.orientation_root, args.image_path.split('/')[-1][:-4]+'.png')) 112 | # cv2.imwrite(args.orientation_root, orient_tensor.numpy().squeeze() * 255. 
/ math.pi) 113 | -------------------------------------------------------------------------------- /checkpoints/MichiGAN/download_model_G.sh: -------------------------------------------------------------------------------- 1 | curl "https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/download.aspx?UniqueId=df8cad7f"%"2Dc09a"%"2D4251"%"2D8d0f"%"2D959923a0225b" -H "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:69.0) Gecko/20100101 Firefox/69.0" -H "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" -H "Accept-Language: zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2" --compressed -H "Referer: https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/onedrive.aspx?id="%"2Fpersonal"%"2Ftzt"%"5Fmail"%"5Fustc"%"5Fedu"%"5Fcn"%"2FDocuments"%"2F2020"%"5Fresearch"%"2FMichiGAN"%"2Fmodels&originalPath=aHR0cHM6Ly9tYWlsdXN0Y2VkdWNuLW15LnNoYXJlcG9pbnQuY29tLzpmOi9nL3BlcnNvbmFsL3R6dF9tYWlsX3VzdGNfZWR1X2NuL0V0d1ZBZFBOT2VwSmx0Wi05MmhNbDRBQm9HanFiN2o1WWw3LTR4dldxa2EwOFE_cnRpbWU9VGFobXpBeVUyRWc" -H "Connection: keep-alive" -H "Cookie: WordWacDataCenter=GUK1; WacDataCenter=GUK1; rtFa=gpN64dS/EnHq5Ooq84bqfxWopR1z2p2cg5x9bOiaxrsmNUE5M0Y3QTgtODYzRi00QUVCLTkwQTAtMTdENkEwOUU3QTM1KXkFNn2OoCfEFc/jn+Xx+UDab5hOD2PO68xJBxKzt7KlmPTuCKJX283AV52oRmpEI10IEo6s8bCxqunVbF1NXk6oKHTFd590hAqzPgDLuvfTb31Iz7Flc2aJQazulUdn6T5BGrSNzKh4xdLCu2Z3BHflE5LM+peX1kofaBO0b1d5U/urYkdHMPCqGhXjFCzzVUjEnIvwBxBZJLCwamzy46VEqj8pAcHj4RcDm24wLnRD1EwpdqRp9RsIdO/VTO+jlss8V9KbSaXL0o7uhp2rZ4atwPisRzCKfdV5ws6h84yAw+yM6QAb2QXzZCvCi0Nk95tGGI57Q+PMjI8butvzUkUAAAA=; FedAuth=77u/PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz48U1A+VjgsMGguZnxtZW1iZXJzaGlwfDEwMDMyMDAwY2U2NmRmZmRAbGl2ZS5jb20sMCMuZnxtZW1iZXJzaGlwfHR6dEBtYWlsLnVzdGMuZWR1LmNuLDEzMjQxNTgxOTM4MDAwMDAwMCwxMzIzODc3OTcwNjAwMDAwMDAsMTMyNTE0NjE4NzcxNDY4MDE3LDE4LjE0MC4xMTUuMTAwLDMsNWE5M2Y3YTgtODYzZi00YWViLTkwYTAtMTdkNmEwOWU3YTM1LCw0MWEwYWE0Zi0wYzFiLTRlNWQtYTY0NC1kZjQ3OTI3ZDcyMGEsOWE1NThmOWYtNDA1Yi1iMDAwLTc0MjUtY2Y3YTI3OWQyMzk1LDQzMDY5MjlmLWUwMGUtYjAwMC05Y2NmLTNhMmZmNTAwNThiMSwsMCwxMzI1MTExNjI3NzE0NjgwMTcsMTMyNTEyODkwNzcxNDY4MDE3LCwsZXlKNGJYTmZZMk1pT2lKYlhDSkRVREZjSWwwaUxDSjRiWE5mYzNOdElqb2lNU0o5LDI2NTA0Njc3NDM5OTk5OTk5OTksMTMyNTAzMDc3NjQwMDAwMDAwLDQ4ZDhkMTFkLWVhOGEtNGVmNS1hYjc4LTQzZmIzOWE1MmZmMSxLaENEZ2JBN1h3bG5TSlZaeEJzeERLUEJPWUN0ZEUwUVNobkRjdU91NDU0QjR0S3FpWnFKbDh4bVJMS24vTDdESU91SktqUktBRlQ0MFhacnVqS1NiYmJ4UjF0M1hTL28vNUJYSnBTNkd6dGVSS2Jra1d4MkMwNnQ4SmpYMkg5akpuTXNkYy93NFhVejZCcDZXRzFQMENkTDFyKzIwNjhVeGRidzhhaTJqYlFBbUI4UDdlczVQMFdKeHN5Wk1sR0k1Q0o4YVY1SkJHL0sxb2RkaFhFd0NjYUcxSVV6YUZLbDJjeExLODgveVAxMXlQVjY5M0lmY3hybjBzUjZxR2pHWUV3SStodXhoQVRoUm1sZ1NPaWZNUVBLb2ZoTEhpQ0Vabms0RTg4bGVpczJDeW9EUWVVbWExRzJ2LzllbnhyQzJnTmkrNTBZZHlIbXhiOVVMU1V2U3c9PTwvU1A+; odbn=1; cucg=1" -H "Upgrade-Insecure-Requests: 1" -o "50_net_G.pth" -------------------------------------------------------------------------------- /checkpoints/MichiGAN/download_model_IG.sh: -------------------------------------------------------------------------------- 1 | curl "https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/download.aspx?UniqueId=25888552"%"2D8be3"%"2D4427"%"2Da642"%"2D657faf2d1c56" -H "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:69.0) Gecko/20100101 Firefox/69.0" -H "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" -H "Accept-Language: zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2" --compressed -H "Referer: 
https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/onedrive.aspx?id="%"2Fpersonal"%"2Ftzt"%"5Fmail"%"5Fustc"%"5Fedu"%"5Fcn"%"2FDocuments"%"2F2020"%"5Fresearch"%"2FMichiGAN"%"2Fmodels&originalPath=aHR0cHM6Ly9tYWlsdXN0Y2VkdWNuLW15LnNoYXJlcG9pbnQuY29tLzpmOi9nL3BlcnNvbmFsL3R6dF9tYWlsX3VzdGNfZWR1X2NuL0V0d1ZBZFBOT2VwSmx0Wi05MmhNbDRBQm9HanFiN2o1WWw3LTR4dldxa2EwOFE_cnRpbWU9VGFobXpBeVUyRWc" -H "Connection: keep-alive" -H "Cookie: WordWacDataCenter=GUK1; WacDataCenter=GUK1; rtFa=gpN64dS/EnHq5Ooq84bqfxWopR1z2p2cg5x9bOiaxrsmNUE5M0Y3QTgtODYzRi00QUVCLTkwQTAtMTdENkEwOUU3QTM1KXkFNn2OoCfEFc/jn+Xx+UDab5hOD2PO68xJBxKzt7KlmPTuCKJX283AV52oRmpEI10IEo6s8bCxqunVbF1NXk6oKHTFd590hAqzPgDLuvfTb31Iz7Flc2aJQazulUdn6T5BGrSNzKh4xdLCu2Z3BHflE5LM+peX1kofaBO0b1d5U/urYkdHMPCqGhXjFCzzVUjEnIvwBxBZJLCwamzy46VEqj8pAcHj4RcDm24wLnRD1EwpdqRp9RsIdO/VTO+jlss8V9KbSaXL0o7uhp2rZ4atwPisRzCKfdV5ws6h84yAw+yM6QAb2QXzZCvCi0Nk95tGGI57Q+PMjI8butvzUkUAAAA=; FedAuth=77u/PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz48U1A+VjgsMGguZnxtZW1iZXJzaGlwfDEwMDMyMDAwY2U2NmRmZmRAbGl2ZS5jb20sMCMuZnxtZW1iZXJzaGlwfHR6dEBtYWlsLnVzdGMuZWR1LmNuLDEzMjQxNTgxOTM4MDAwMDAwMCwxMzIzODc3OTcwNjAwMDAwMDAsMTMyNTE0NjE4NzcxNDY4MDE3LDE4LjE0MC4xMTUuMTAwLDMsNWE5M2Y3YTgtODYzZi00YWViLTkwYTAtMTdkNmEwOWU3YTM1LCw0MWEwYWE0Zi0wYzFiLTRlNWQtYTY0NC1kZjQ3OTI3ZDcyMGEsOWE1NThmOWYtNDA1Yi1iMDAwLTc0MjUtY2Y3YTI3OWQyMzk1LDQzMDY5MjlmLWUwMGUtYjAwMC05Y2NmLTNhMmZmNTAwNThiMSwsMCwxMzI1MTExNjI3NzE0NjgwMTcsMTMyNTEyODkwNzcxNDY4MDE3LCwsZXlKNGJYTmZZMk1pT2lKYlhDSkRVREZjSWwwaUxDSjRiWE5mYzNOdElqb2lNU0o5LDI2NTA0Njc3NDM5OTk5OTk5OTksMTMyNTAzMDc3NjQwMDAwMDAwLDQ4ZDhkMTFkLWVhOGEtNGVmNS1hYjc4LTQzZmIzOWE1MmZmMSxLaENEZ2JBN1h3bG5TSlZaeEJzeERLUEJPWUN0ZEUwUVNobkRjdU91NDU0QjR0S3FpWnFKbDh4bVJMS24vTDdESU91SktqUktBRlQ0MFhacnVqS1NiYmJ4UjF0M1hTL28vNUJYSnBTNkd6dGVSS2Jra1d4MkMwNnQ4SmpYMkg5akpuTXNkYy93NFhVejZCcDZXRzFQMENkTDFyKzIwNjhVeGRidzhhaTJqYlFBbUI4UDdlczVQMFdKeHN5Wk1sR0k1Q0o4YVY1SkJHL0sxb2RkaFhFd0NjYUcxSVV6YUZLbDJjeExLODgveVAxMXlQVjY5M0lmY3hybjBzUjZxR2pHWUV3SStodXhoQVRoUm1sZ1NPaWZNUVBLb2ZoTEhpQ0Vabms0RTg4bGVpczJDeW9EUWVVbWExRzJ2LzllbnhyQzJnTmkrNTBZZHlIbXhiOVVMU1V2U3c9PTwvU1A+; odbn=1; cucg=1" -H "Upgrade-Insecure-Requests: 1" -o "InpaintingModel_gen.pth" -------------------------------------------------------------------------------- /checkpoints/MichiGAN/download_model_SIG.sh: -------------------------------------------------------------------------------- 1 | curl "https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/download.aspx?UniqueId=c540b02a"%"2D3155"%"2D4180"%"2Da820"%"2Dd47d1e42345c" -H "User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:69.0) Gecko/20100101 Firefox/69.0" -H "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8" -H "Accept-Language: zh-CN,zh;q=0.8,zh-TW;q=0.7,zh-HK;q=0.5,en-US;q=0.3,en;q=0.2" --compressed -H "Referer: https://mailustceducn-my.sharepoint.com/personal/tzt_mail_ustc_edu_cn/_layouts/15/onedrive.aspx?id="%"2Fpersonal"%"2Ftzt"%"5Fmail"%"5Fustc"%"5Fedu"%"5Fcn"%"2FDocuments"%"2F2020"%"5Fresearch"%"2FMichiGAN"%"2Fmodels&originalPath=aHR0cHM6Ly9tYWlsdXN0Y2VkdWNuLW15LnNoYXJlcG9pbnQuY29tLzpmOi9nL3BlcnNvbmFsL3R6dF9tYWlsX3VzdGNfZWR1X2NuL0V0d1ZBZFBOT2VwSmx0Wi05MmhNbDRBQm9HanFiN2o1WWw3LTR4dldxa2EwOFE_cnRpbWU9VGFobXpBeVUyRWc" -H "Connection: keep-alive" -H "Cookie: WordWacDataCenter=GUK1; WacDataCenter=GUK1; 
rtFa=gpN64dS/EnHq5Ooq84bqfxWopR1z2p2cg5x9bOiaxrsmNUE5M0Y3QTgtODYzRi00QUVCLTkwQTAtMTdENkEwOUU3QTM1KXkFNn2OoCfEFc/jn+Xx+UDab5hOD2PO68xJBxKzt7KlmPTuCKJX283AV52oRmpEI10IEo6s8bCxqunVbF1NXk6oKHTFd590hAqzPgDLuvfTb31Iz7Flc2aJQazulUdn6T5BGrSNzKh4xdLCu2Z3BHflE5LM+peX1kofaBO0b1d5U/urYkdHMPCqGhXjFCzzVUjEnIvwBxBZJLCwamzy46VEqj8pAcHj4RcDm24wLnRD1EwpdqRp9RsIdO/VTO+jlss8V9KbSaXL0o7uhp2rZ4atwPisRzCKfdV5ws6h84yAw+yM6QAb2QXzZCvCi0Nk95tGGI57Q+PMjI8butvzUkUAAAA=; FedAuth=77u/PD94bWwgdmVyc2lvbj0iMS4wIiBlbmNvZGluZz0idXRmLTgiPz48U1A+VjgsMGguZnxtZW1iZXJzaGlwfDEwMDMyMDAwY2U2NmRmZmRAbGl2ZS5jb20sMCMuZnxtZW1iZXJzaGlwfHR6dEBtYWlsLnVzdGMuZWR1LmNuLDEzMjQxNTgxOTM4MDAwMDAwMCwxMzIzODc3OTcwNjAwMDAwMDAsMTMyNTE0NjE4NzcxNDY4MDE3LDE4LjE0MC4xMTUuMTAwLDMsNWE5M2Y3YTgtODYzZi00YWViLTkwYTAtMTdkNmEwOWU3YTM1LCw0MWEwYWE0Zi0wYzFiLTRlNWQtYTY0NC1kZjQ3OTI3ZDcyMGEsOWE1NThmOWYtNDA1Yi1iMDAwLTc0MjUtY2Y3YTI3OWQyMzk1LDQzMDY5MjlmLWUwMGUtYjAwMC05Y2NmLTNhMmZmNTAwNThiMSwsMCwxMzI1MTExNjI3NzE0NjgwMTcsMTMyNTEyODkwNzcxNDY4MDE3LCwsZXlKNGJYTmZZMk1pT2lKYlhDSkRVREZjSWwwaUxDSjRiWE5mYzNOdElqb2lNU0o5LDI2NTA0Njc3NDM5OTk5OTk5OTksMTMyNTAzMDc3NjQwMDAwMDAwLDQ4ZDhkMTFkLWVhOGEtNGVmNS1hYjc4LTQzZmIzOWE1MmZmMSxLaENEZ2JBN1h3bG5TSlZaeEJzeERLUEJPWUN0ZEUwUVNobkRjdU91NDU0QjR0S3FpWnFKbDh4bVJMS24vTDdESU91SktqUktBRlQ0MFhacnVqS1NiYmJ4UjF0M1hTL28vNUJYSnBTNkd6dGVSS2Jra1d4MkMwNnQ4SmpYMkg5akpuTXNkYy93NFhVejZCcDZXRzFQMENkTDFyKzIwNjhVeGRidzhhaTJqYlFBbUI4UDdlczVQMFdKeHN5Wk1sR0k1Q0o4YVY1SkJHL0sxb2RkaFhFd0NjYUcxSVV6YUZLbDJjeExLODgveVAxMXlQVjY5M0lmY3hybjBzUjZxR2pHWUV3SStodXhoQVRoUm1sZ1NPaWZNUVBLb2ZoTEhpQ0Vabms0RTg4bGVpczJDeW9EUWVVbWExRzJ2LzllbnhyQzJnTmkrNTBZZHlIbXhiOVVMU1V2U3c9PTwvU1A+; odbn=1; cucg=1" -H "Upgrade-Insecure-Requests: 1" -H "TE: Trailers" -o "SInpaintingModel_gen.pth" -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import importlib 7 | import torch.utils.data 8 | from data.base_dataset import BaseDataset 9 | 10 | 11 | def find_dataset_using_name(dataset_name): 12 | # Given the option --dataset [datasetname], 13 | # the file "datasets/datasetname_dataset.py" 14 | # will be imported. 15 | dataset_filename = "data." + dataset_name + "_dataset" 16 | datasetlib = importlib.import_module(dataset_filename) 17 | 18 | # In the file, the class called DatasetNameDataset() will 19 | # be instantiated. It has to be a subclass of BaseDataset, 20 | # and it is case-insensitive. 21 | dataset = None 22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 23 | for name, cls in datasetlib.__dict__.items(): 24 | if name.lower() == target_dataset_name.lower() \ 25 | and issubclass(cls, BaseDataset): 26 | dataset = cls 27 | 28 | if dataset is None: 29 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 30 | "with class name that matches %s in lowercase." 
% 31 | (dataset_filename, target_dataset_name)) 32 | 33 | return dataset 34 | 35 | 36 | def get_option_setter(dataset_name): 37 | dataset_class = find_dataset_using_name(dataset_name) 38 | return dataset_class.modify_commandline_options 39 | 40 | 41 | def create_dataloader(opt, step=1): 42 | dataset = find_dataset_using_name(opt.dataset_mode) 43 | instance = dataset() 44 | 45 | if 'custom' in opt.dataset_mode: 46 | instance.initialize(opt, step) 47 | else: 48 | instance.initialize(opt) 49 | print("dataset [%s] of size %d was created" % 50 | (type(instance).__name__, len(instance))) 51 | dataloader = torch.utils.data.DataLoader( 52 | instance, 53 | batch_size=opt.batchSize, 54 | shuffle=not opt.serial_batches, 55 | num_workers=int(opt.nThreads), 56 | drop_last=opt.isTrain 57 | ) 58 | return dataloader 59 | 60 | def create_dataset_ms(opt, step=1): 61 | dataset = find_dataset_using_name(opt.dataset_mode) 62 | instance = dataset() 63 | 64 | if 'custom' in opt.dataset_mode: 65 | instance.initialize(opt, step) 66 | else: 67 | instance.initialize(opt) 68 | print("dataset [%s] of size %d was created" % 69 | (type(instance).__name__, len(instance))) 70 | 71 | return instance 72 | -------------------------------------------------------------------------------- /data/ab_count.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/data/ab_count.npy -------------------------------------------------------------------------------- /data/clear_ab_count.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/data/clear_ab_count.npy -------------------------------------------------------------------------------- /data/custom_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | from data.pix2pix_dataset import Pix2pixDataset 7 | from data.image_folder import make_dataset 8 | import random 9 | import os 10 | 11 | 12 | class CustomDataset(Pix2pixDataset): 13 | """ Dataset that loads images from directories 14 | Use option --label_dir, --image_dir, --instance_dir to specify the directories. 15 | The images in the directories are sorted in alphabetical order and paired in order. 16 | """ 17 | 18 | @staticmethod 19 | def modify_commandline_options(parser, is_train): 20 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 21 | parser.set_defaults(preprocess_mode='resize_and_crop') 22 | # parser.set_defaults(load_size=load_size) 23 | # parser.set_defaults(crop_size=256) 24 | # parser.set_defaults(display_winsize=256) 25 | parser.set_defaults(label_nc=2) 26 | parser.set_defaults(contain_dontcare_label=False) 27 | 28 | # parser.add_argument('--data_dir', type=str, default='/mnt/lvdisk1/tzt/HairSynthesis/SPADE-master/datasets/FFHQ', 29 | # help='path to the directory that contains training & val data') 30 | parser.add_argument('--label_dir', type=str, default='train_labels', 31 | help='path to the directory that contains label images') 32 | parser.add_argument('--image_dir', type=str, default='train_images', 33 | help='path to the directory that contains photo images') 34 | parser.add_argument('--instance_dir', type=str, default='', 35 | help='path to the directory that contains instance maps. 
Leave black if not exists') 36 | parser.add_argument('--orient_dir', type=str, default='train_dense_orients', 37 | help='path to the directory that contains orientation mask') 38 | parser.add_argument('--clear', type=str, default='', 39 | help='[ |clear_], clear_ means use the selected training data') 40 | 41 | return parser 42 | 43 | def get_paths(self, opt): 44 | 45 | # combine data_dir and others 46 | label_dir = os.path.join(opt.data_dir, opt.clear+opt.label_dir) 47 | image_dir = os.path.join(opt.data_dir, opt.clear+opt.image_dir) 48 | orient_dir = os.path.join(opt.data_dir, opt.clear+opt.orient_dir) 49 | 50 | # label_dir = opt.label_dir 51 | label_paths = make_dataset(label_dir, recursive=False, read_cache=True) 52 | 53 | # image_dir = opt.image_dir 54 | image_paths = make_dataset(image_dir, recursive=False, read_cache=True) 55 | 56 | if len(opt.instance_dir) > 0: 57 | instance_dir = opt.instance_dir 58 | instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) 59 | else: 60 | instance_paths = [] 61 | 62 | if len(opt.orient_dir) > 0: 63 | # orient_dir = opt.orient_dir 64 | orient_paths = make_dataset(orient_dir, recursive=False, read_cache=True) 65 | else: 66 | orient_paths = [] 67 | 68 | assert len(label_paths) == len(image_paths), "The #images in %s and %s do not match. Is there something wrong?" 69 | 70 | return label_paths, image_paths, instance_paths, orient_paths 71 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | ############################################################################### 7 | # Code from 8 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 9 | # Modified the original code so that it also loads images from the current 10 | # directory as well as the subdirectories 11 | ############################################################################### 12 | import torch.utils.data as data 13 | from PIL import Image 14 | import os 15 | 16 | IMG_EXTENSIONS = [ 17 | '.jpg', '.JPG', '.jpeg', '.JPEG', 18 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp' 19 | ] 20 | 21 | 22 | def is_image_file(filename): 23 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 24 | 25 | 26 | def make_dataset_rec(dir, images): 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | 35 | 36 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 37 | images = [] 38 | 39 | if read_cache: 40 | possible_filelist = os.path.join(dir, 'files.list') 41 | if os.path.isfile(possible_filelist): 42 | with open(possible_filelist, 'r') as f: 43 | images = f.read().splitlines() 44 | return images 45 | 46 | if recursive: 47 | make_dataset_rec(dir, images) 48 | else: 49 | assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 50 | 51 | for root, dnames, fnames in sorted(os.walk(dir)): 52 | for fname in fnames: 53 | if is_image_file(fname): 54 | path = os.path.join(root, fname) 55 | images.append(path) 56 | 57 | if write_cache: 58 | filelist_cache = os.path.join(dir, 'files.list') 59 | with 
open(filelist_cache, 'w') as f: 60 | for path in images: 61 | f.write("%s\n" % path) 62 | print('wrote filelist cache at %s' % filelist_cache) 63 | 64 | return images 65 | 66 | 67 | def default_loader(path): 68 | return Image.open(path).convert('RGB') 69 | 70 | 71 | class ImageFolder(data.Dataset): 72 | 73 | def __init__(self, root, transform=None, return_paths=False, 74 | loader=default_loader): 75 | imgs = make_dataset(root) 76 | if len(imgs) == 0: 77 | raise(RuntimeError("Found 0 images in: " + root + "\n" 78 | "Supported image extensions are: " + 79 | ",".join(IMG_EXTENSIONS))) 80 | 81 | self.root = root 82 | self.imgs = imgs 83 | self.transform = transform 84 | self.return_paths = return_paths 85 | self.loader = loader 86 | 87 | def __getitem__(self, index): 88 | path = self.imgs[index] 89 | img = self.loader(path) 90 | if self.transform is not None: 91 | img = self.transform(img) 92 | if self.return_paths: 93 | return img, path 94 | else: 95 | return img 96 | 97 | def __len__(self): 98 | return len(self.imgs) 99 | -------------------------------------------------------------------------------- /data/pix2pix_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | from data.base_dataset import BaseDataset, get_params, get_transform, generate_hole, trans_orient_to_rgb, generate_noise, show_training_data 7 | from PIL import Image 8 | import util.util as util 9 | import os 10 | import numpy as np 11 | import torch 12 | import random 13 | 14 | 15 | 16 | class Pix2pixDataset(BaseDataset): 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | parser.add_argument('--no_pairing_check', action='store_true', 20 | help='If specified, skip sanity check of correct label-image file pairing') 21 | return parser 22 | 23 | def initialize(self, opt, step=1): 24 | self.opt = opt 25 | self.step = step 26 | 27 | label_paths, image_paths, instance_paths, orient_paths = self.get_paths(opt) 28 | 29 | util.natural_sort(label_paths) 30 | util.natural_sort(image_paths) 31 | if not opt.no_instance: 32 | util.natural_sort(instance_paths) 33 | if not opt.no_orientation: 34 | util.natural_sort(orient_paths) 35 | 36 | label_paths = label_paths[:opt.max_dataset_size] 37 | image_paths = image_paths[:opt.max_dataset_size] 38 | instance_paths = instance_paths[:opt.max_dataset_size] 39 | orient_paths = orient_paths[:opt.max_dataset_size] 40 | 41 | if not opt.no_pairing_check: 42 | for path1, path2 in zip(label_paths, image_paths): 43 | assert self.paths_match(path1, path2), \ 44 | "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." 
% (path1, path2) 45 | 46 | self.label_paths = label_paths 47 | self.image_paths = image_paths 48 | self.instance_paths = instance_paths 49 | self.orient_paths = orient_paths 50 | 51 | size = len(self.label_paths) 52 | self.dataset_size = size 53 | 54 | def get_paths(self, opt): 55 | label_paths = [] 56 | image_paths = [] 57 | instance_paths = [] 58 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" 59 | return label_paths, image_paths, instance_paths 60 | 61 | def paths_match(self, path1, path2): 62 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] 63 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] 64 | return filename1_without_ext == filename2_without_ext 65 | 66 | def __getitem__(self, index): 67 | # tag Label 68 | label_path = self.label_paths[index] 69 | label = Image.open(label_path) 70 | params = get_params(self.opt, label.size) 71 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 72 | label_tensor = transform_label(label) * 255.0 73 | label_tensor[label_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 74 | 75 | # reference Label 76 | if self.step == 1: 77 | index_ref = index 78 | else: 79 | index_ref = random.randint(0, len(self.label_paths)-1) 80 | label_path_ref = self.label_paths[index_ref] 81 | label_ref = Image.open(label_path_ref) 82 | label_tensor_ref = transform_label(label_ref) * 255.0 83 | label_tensor_ref[label_tensor_ref == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 84 | 85 | # input tag image (real images) 86 | image_path = self.image_paths[index] 87 | assert self.paths_match(label_path, image_path), \ 88 | "The label_path %s and image_path %s don't match." % \ 89 | (label_path, image_path) 90 | image = Image.open(image_path) 91 | image = image.convert('RGB') 92 | transform_image = get_transform(self.opt, params) 93 | image_tensor = transform_image(image) 94 | 95 | # input reference image (style images) 96 | image_path_ref = self.image_paths[index_ref] 97 | image_ref = Image.open(image_path_ref).convert('RGB') 98 | if self.opt.color_jitter: 99 | transform_image = get_transform(self.opt, params, color=True) 100 | image_tensor_ref = transform_image(image_ref) 101 | 102 | 103 | # if using instance maps 104 | if self.opt.no_instance: 105 | instance_tensor = 0 106 | else: 107 | instance_path = self.instance_paths[index] 108 | instance = Image.open(instance_path) 109 | if instance.mode == 'L': 110 | instance_tensor = transform_label(instance) * 255 111 | instance_tensor = instance_tensor.long() 112 | else: 113 | instance_tensor = transform_label(instance) 114 | 115 | # if using orientation maps 116 | if self.opt.no_orientation: 117 | orient_tensor = 0 118 | else: 119 | orient_path = self.orient_paths[index] 120 | orient = Image.open(orient_path) 121 | orient_tensor = transform_label(orient)*255 122 | 123 | # rgb orientation maps 124 | index_orient_ref = random.randint(0, len(self.label_paths) - 1) 125 | orient_rgb = Image.open(self.orient_paths[index_orient_ref]) 126 | orient_mask = Image.open(self.label_paths[index_orient_ref]) 127 | orient_random_param = random.random() 128 | orient_random_th = 2 129 | if self.opt.use_ig and not self.opt.no_orientation: 130 | transform_orient_rgb = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 131 | if orient_random_param < orient_random_th: 132 | # use the target orient with erasure 133 | orient_rgb = trans_orient_to_rgb(np.array(orient), np.array(label)) 134 | 
orient_rgb_tensor = transform_orient_rgb(orient_rgb) * label_tensor 135 | else: 136 | # use the reference orient that not match the reference image, this is the other random orient 137 | # print('index of sample', index, index_ref, index_orient_ref) 138 | orient_rgb = trans_orient_to_rgb(np.array(orient_rgb), np.array(label), np.array(orient_mask)) 139 | orient_rgb_tensor = transform_orient_rgb(orient_rgb) 140 | orient_rgb_tensor = orient_rgb_tensor * label_tensor 141 | else: 142 | orient_rgb_tensor = torch.tensor(0) 143 | 144 | # process orient mask 145 | orient_mask_tensor = transform_label(orient_mask) * 255.0 146 | orient_mask_tensor[orient_mask_tensor == 255] = self.opt.label_nc # 'unknown' is opt.label_nc 147 | 148 | # hole mask 149 | if self.opt.use_ig: 150 | if orient_random_param < orient_random_th: 151 | hole = np.array(label) 152 | hole = generate_hole(hole, np.array(orient_mask)) 153 | hole_tensor = transform_label(hole) * 255.0 154 | else: 155 | hole_tensor = label_tensor - orient_mask_tensor * label_tensor 156 | else: 157 | hole_tensor = 0 158 | 159 | # generate noise 160 | noise = generate_noise(self.opt.crop_size, self.opt.crop_size) 161 | noise_tensor = torch.tensor(noise).permute(2, 0, 1) 162 | 163 | # # random: the reference label and image are the same with reference or not 164 | # if self.opt.only_tag: 165 | # if not self.opt.use_blender: 166 | # image_tensor_ref = image_tensor.clone() 167 | # label_tensor_ref = label_tensor.clone() 168 | # else: 169 | # if random.random() < 2: 170 | # image_tensor_ref = image_tensor.clone() 171 | # label_tensor_ref = label_tensor.clone() 172 | # else: 173 | # if random.random() < 0.2: 174 | # image_tensor_ref = image_tensor.clone() 175 | # label_tensor_ref = label_tensor.clone() 176 | 177 | 178 | input_dict = {'label_tag': label_tensor, 179 | 'label_ref': label_tensor_ref, 180 | 'instance': instance_tensor, 181 | 'image_tag': image_tensor, 182 | 'image_ref': image_tensor_ref, 183 | 'path': image_path_ref, 184 | 'orient': orient_tensor, 185 | 'hole': hole_tensor, 186 | 'orient_rgb': orient_rgb_tensor, 187 | 'noise': noise_tensor 188 | } 189 | # # show and debug 190 | # show_training_data(input_dict) 191 | # Give subclasses a chance to modify the final output 192 | self.postprocess(input_dict) 193 | 194 | return input_dict 195 | 196 | def postprocess(self, input_dict): 197 | return input_dict 198 | 199 | def __len__(self): 200 | return self.dataset_size 201 | -------------------------------------------------------------------------------- /data/single_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/data/single_result.png -------------------------------------------------------------------------------- /data/special_img_names_10000_29999.txt: -------------------------------------------------------------------------------- 1 | 10165.jpg 2 | 10176.jpg 3 | 10179.jpg 4 | 10322.jpg 5 | 10354.jpg 6 | 10622.jpg 7 | 10642.jpg 8 | 10647.jpg 9 | 10794.jpg 10 | 10827.jpg 11 | 10845.jpg 12 | 10910.jpg 13 | 10919.jpg 14 | 11008.jpg 15 | 11073.jpg 16 | 11107.jpg 17 | 11108.jpg 18 | 11136.jpg 19 | 11137.jpg 20 | 11148.jpg 21 | 11304.jpg 22 | 11431.jpg 23 | 11598.jpg 24 | 11702.jpg 25 | 11802.jpg 26 | 11872.jpg 27 | 11924.jpg 28 | 11952.jpg 29 | 11975.jpg 30 | 11990.jpg 31 | 12022.jpg 32 | 12039.jpg 33 | 12041.jpg 34 | 12147.jpg 35 | 12163.jpg 36 | 12227.jpg 37 | 12251.jpg 38 | 12431.jpg 39 | 12470.jpg 40 | 12474.jpg 41 
| 12530.jpg 42 | 12559.jpg 43 | 12602.jpg 44 | 12732.jpg 45 | 12845.jpg 46 | 13007.jpg 47 | 13066.jpg 48 | 13170.jpg 49 | 13256.jpg 50 | 13328.jpg 51 | 13505.jpg 52 | 13537.jpg 53 | 13648.jpg 54 | 13757.jpg 55 | 13806.jpg 56 | 13807.jpg 57 | 13851.jpg 58 | 13894.jpg 59 | 14104.jpg 60 | 14194.jpg 61 | 14352.jpg 62 | 14465.jpg 63 | 14482.jpg 64 | 14564.jpg 65 | 14622.jpg 66 | 14761.jpg 67 | 14808.jpg 68 | 15138.jpg 69 | 15212.jpg 70 | 15252.jpg 71 | 15577.jpg 72 | 15894.jpg 73 | 15895.jpg 74 | 15928.jpg 75 | 16041.jpg 76 | 16190.jpg 77 | 16411.jpg 78 | 16419.jpg 79 | 16426.jpg 80 | 16437.jpg 81 | 16821.jpg 82 | 16964.jpg 83 | 17123.jpg 84 | 17256.jpg 85 | 17291.jpg 86 | 17305.jpg 87 | 17450.jpg 88 | 17719.jpg 89 | 17953.jpg 90 | 18064.jpg 91 | 18069.jpg 92 | 18124.jpg 93 | 18176.jpg 94 | 18196.jpg 95 | 18240.jpg 96 | 18340.jpg 97 | 18359.jpg 98 | 18647.jpg 99 | 18840.jpg 100 | 18858.jpg 101 | 18930.jpg 102 | 19016.jpg 103 | 19072.jpg 104 | 19079.jpg 105 | 19177.jpg 106 | 19209.jpg 107 | 19305.jpg 108 | 19409.jpg 109 | 19419.jpg 110 | 19527.jpg 111 | 19631.jpg 112 | 19651.jpg 113 | 19688.jpg 114 | 19834.jpg 115 | 20073.jpg 116 | 20300.jpg 117 | 20503.jpg 118 | 20624.jpg 119 | 20641.jpg 120 | 20771.jpg 121 | 21400.jpg 122 | 21668.jpg 123 | 21934.jpg 124 | 21940.jpg 125 | 22311.jpg 126 | 22321.jpg 127 | 22466.jpg 128 | 22476.jpg 129 | 22586.jpg 130 | 22591.jpg 131 | 22602.jpg 132 | 22708.jpg 133 | 22850.jpg 134 | 22891.jpg 135 | 22911.jpg 136 | 22956.jpg 137 | 22988.jpg 138 | 23037.jpg 139 | 23239.jpg 140 | 23277.jpg 141 | 23443.jpg 142 | 23648.jpg 143 | 23845.jpg 144 | 23856.jpg 145 | 24012.jpg 146 | 24211.jpg 147 | 24246.jpg 148 | 24274.jpg 149 | 24379.jpg 150 | 24514.jpg 151 | 24708.jpg 152 | 24760.jpg 153 | 24888.jpg 154 | 24915.jpg 155 | 24960.jpg 156 | 25332.jpg 157 | 25352.jpg 158 | 25765.jpg 159 | 25801.jpg 160 | 26104.jpg 161 | 26197.jpg 162 | 26198.jpg 163 | 26613.jpg 164 | 26629.jpg 165 | 26689.jpg 166 | 26721.jpg 167 | 26733.jpg 168 | 26891.jpg 169 | 26953.jpg 170 | 26987.jpg 171 | 27002.jpg 172 | 27042.jpg 173 | 27055.jpg 174 | 27168.jpg 175 | 27371.jpg 176 | 27567.jpg 177 | 27614.jpg 178 | 27727.jpg 179 | 28052.jpg 180 | 28156.jpg 181 | 28314.jpg 182 | 28339.jpg 183 | 28391.jpg 184 | 28405.jpg 185 | 28489.jpg 186 | 28546.jpg 187 | 28563.jpg 188 | 28760.jpg 189 | 28963.jpg 190 | 29093.jpg 191 | 29150.jpg 192 | 29329.jpg 193 | 29442.jpg 194 | 29526.jpg 195 | 29630.jpg 196 | 29750.jpg 197 | 29800.jpg 198 | 29936.jpg 199 | -------------------------------------------------------------------------------- /data/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/data/teaser.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_demo/images/59144.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/images/59144.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_demo/images/60429.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/images/60429.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_demo/images/67172.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/images/67172.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_demo/images_recon/67172.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/images_recon/67172.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_demo/labels/59144.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/labels/59144.png -------------------------------------------------------------------------------- /datasets/FFHQ_demo/labels/60429.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/labels/60429.png -------------------------------------------------------------------------------- /datasets/FFHQ_demo/labels/67172.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/labels/67172.png -------------------------------------------------------------------------------- /datasets/FFHQ_demo/orients/59144_orient_dense.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/orients/59144_orient_dense.png -------------------------------------------------------------------------------- /datasets/FFHQ_demo/orients/60429_orient_dense.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/orients/60429_orient_dense.png -------------------------------------------------------------------------------- /datasets/FFHQ_demo/orients/67172_orient_dense.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_demo/orients/67172_orient_dense.png -------------------------------------------------------------------------------- /datasets/FFHQ_single/val_dense_orients/67172_orient_dense.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_single/val_dense_orients/67172_orient_dense.png -------------------------------------------------------------------------------- /datasets/FFHQ_single/val_images/67172.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_single/val_images/67172.jpg -------------------------------------------------------------------------------- /datasets/FFHQ_single/val_labels/67172.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/datasets/FFHQ_single/val_labels/67172.png -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import os 7 | from collections import OrderedDict 8 | 9 | import data 10 | from options.test_options import TestOptions 11 | from models.pix2pix_model import Pix2PixModel 12 | from util.visualizer import Visualizer 13 | from util.util import tensor2im, tensor2label, blend_image 14 | from util import html 15 | from data.base_dataset import single_inference_dataLoad 16 | from PIL import Image 17 | import torch 18 | import math 19 | import numpy as np 20 | import torch.nn as nn 21 | import cv2 22 | 23 | opt = TestOptions().parse() 24 | 25 | model = Pix2PixModel(opt) 26 | model.eval() 27 | 28 | visualizer = Visualizer(opt) 29 | 30 | criterionRGBL1 = nn.L1Loss() 31 | criterionRGBL2 = nn.MSELoss() 32 | 33 | # read data 34 | data = single_inference_dataLoad(opt) 35 | # forward 36 | generated = model(data, mode='inference') 37 | img_path = data['path'] 38 | print('process image... %s' % img_path) 39 | 40 | # remove background 41 | if opt.remove_background: 42 | generated = generated * data['label_tag'].float() + data['image_tag'] *(1 - data['label_tag'].float()) 43 | fake_image = tensor2im(generated[0]) 44 | if opt.add_feat_zeros or opt.add_zeros: 45 | th = opt.add_th 46 | H, W = opt.crop_size, opt.crop_size 47 | fake_image_tmp = fake_image[int(th/2):int(th/2)+H,int(th/2):int(th/2)+W,:] 48 | fake_image = fake_image_tmp 49 | 50 | fake_image_np = fake_image.copy() 51 | fake_image = Image.fromarray(np.uint8(fake_image)) 52 | 53 | if opt.use_ig: 54 | fake_image.save('./inference_samples/inpaint_fake_image.jpg') 55 | else: 56 | fake_image.save('./inference_samples/fake_image.jpg') 57 | -------------------------------------------------------------------------------- /inference_samples/inpaint_fake_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/inference_samples/inpaint_fake_image.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import importlib 7 | import torch 8 | 9 | 10 | def find_model_using_name(model_name): 11 | # Given the option --model [modelname], 12 | # the file "models/modelname_model.py" 13 | # will be imported. 14 | model_filename = "models." + model_name + "_model" 15 | modellib = importlib.import_module(model_filename) 16 | 17 | # In the file, the class called ModelNameModel() will 18 | # be instantiated. It has to be a subclass of torch.nn.Module, 19 | # and it is case-insensitive. 20 | model = None 21 | target_model_name = model_name.replace('_', '') + 'model' 22 | for name, cls in modellib.__dict__.items(): 23 | if name.lower() == target_model_name.lower() \ 24 | and issubclass(cls, torch.nn.Module): 25 | model = cls 26 | 27 | if model is None: 28 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 29 | exit(0) 30 | 31 | return model 32 | 33 | 34 | def get_option_setter(model_name): 35 | model_class = find_model_using_name(model_name) 36 | return model_class.modify_commandline_options 37 | 38 | 39 | def create_model(opt): 40 | model = find_model_using_name(opt.model) 41 | instance = model(opt) 42 | print("model [%s] was created" % (type(instance).__name__)) 43 | 44 | return instance 45 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import torch 7 | from models.networks.base_network import BaseNetwork 8 | from models.networks.loss import * 9 | from models.networks.discriminator import * 10 | from models.networks.generator import * 11 | from models.networks.encoder import * 12 | from models.networks.MaskGAN_networks import Encoder as Feat_Encoder 13 | import util.util as util 14 | 15 | 16 | def find_network_using_name(target_network_name, filename): 17 | target_class_name = target_network_name + filename 18 | module_name = 'models.networks.' + filename 19 | network = util.find_class_in_module(target_class_name, module_name) 20 | 21 | assert issubclass(network, BaseNetwork), \ 22 | "Class %s should be a subclass of BaseNetwork" % network 23 | 24 | return network 25 | 26 | 27 | def modify_commandline_options(parser, is_train): 28 | opt, _ = parser.parse_known_args() 29 | 30 | netG_cls = find_network_using_name(opt.netG, 'generator') 31 | parser = netG_cls.modify_commandline_options(parser, is_train) 32 | if is_train: 33 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 34 | parser = netD_cls.modify_commandline_options(parser, is_train) 35 | netE_cls = find_network_using_name('conv', 'encoder') 36 | parser = netE_cls.modify_commandline_options(parser, is_train) 37 | 38 | return parser 39 | 40 | 41 | def create_network(cls, opt): 42 | net = cls(opt) 43 | net.print_network() 44 | if len(opt.gpu_ids) > 0: 45 | assert(torch.cuda.is_available()) 46 | net.cuda() 47 | net.init_weights(opt.init_type, opt.init_variance) 48 | return net 49 | 50 | 51 | def define_G(opt): 52 | netG_cls = find_network_using_name(opt.netG, 'generator') 53 | return create_network(netG_cls, opt) 54 | 55 | def define_B(opt): 56 | netB_cls = find_network_using_name(opt.netB, 'generator') 57 | return create_network(netB_cls, opt) 58 | 59 | 60 | def define_D(opt): 61 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 62 | return create_network(netD_cls, opt) 63 | 64 | 65 | def define_E(opt): 66 | # there exists only one encoder type 67 | netE_cls = find_network_using_name('conv', 'encoder') 68 | return create_network(netE_cls, opt) 69 | 70 | def define_IG(opt): 71 | netIG_cls = find_network_using_name(opt.netIG, 'generator') 72 | return create_network(netIG_cls, opt) 73 | 74 | def define_SIG(opt): 75 | netIG_cls = find_network_using_name(opt.netSIG, 'generator') 76 | return create_network(netIG_cls, opt) 77 | 78 | def define_FE(opt): 79 | net = Feat_Encoder(opt.output_nc, opt.feat_num, 16, 4) 80 | net.print_network() 81 | if len(opt.gpu_ids) > 0: 82 | assert(torch.cuda.is_available()) 83 | net.cuda() 84 | net.init_weights(opt.init_type, opt.init_variance) 85 | return net -------------------------------------------------------------------------------- /models/networks/architecture.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import torchvision 10 | import torch.nn.utils.spectral_norm as spectral_norm 11 | from models.networks.normalization import SPADE, SPADEImage 12 | import torch.nn.utils.weight_norm as weight_norm_0 13 | from models.networks.normalization import weight_norm as weight_norm_1 14 | 15 | 16 | # ResNet block that uses SPADE. 17 | # It differs from the ResNet block of pix2pixHD in that 18 | # it takes in the segmentation map as input, learns the skip connection if necessary, 19 | # and applies normalization first and then convolution. 20 | # This architecture seemed like a standard architecture for unconditional or 21 | # class-conditional GAN architecture using residual block. 22 | # The code was inspired from https://github.com/LMescheder/GAN_stability. 23 | class SPADEResnetBlock(nn.Module): 24 | def __init__(self, fin, fout, opt): 25 | super().__init__() 26 | # Attributes 27 | self.learned_shortcut = (fin != fout) 28 | fmiddle = min(fin, fout) 29 | 30 | # create conv layers 31 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) 32 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 33 | if self.learned_shortcut: 34 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 35 | 36 | if opt.weight_norm_G == False: 37 | # apply spectral norm if specified 38 | if 'spectral' in opt.norm_G: 39 | self.conv_0 = spectral_norm(self.conv_0) 40 | self.conv_1 = spectral_norm(self.conv_1) 41 | if self.learned_shortcut: 42 | self.conv_s = spectral_norm(self.conv_s) 43 | else: 44 | if opt.weight_norm_g == 0: 45 | # g is learnable 46 | self.conv_0 = weight_norm_0(self.conv_0) 47 | self.conv_1 = weight_norm_0(self.conv_1) 48 | if self.learned_shortcut: 49 | self.conv_s = weight_norm_0(self.conv_s) 50 | elif opt.weight_norm_g == 1: 51 | # g == 1 52 | self.conv_0 = weight_norm_1(self.conv_0) 53 | self.conv_1 = weight_norm_1(self.conv_1) 54 | if self.learned_shortcut: 55 | self.conv_s = weight_norm_1(self.conv_s) 56 | 57 | # define normalization layers 58 | norm_nc = opt.label_nc + (opt.orient_nc if not opt.no_orientation else 0) + (opt.feat_num if opt.use_instance_feat else 0) + (3 if 'spadebase' in opt.netG else 0) 59 | spade_config_str = opt.norm_G.replace('spectral', '') 60 | self.norm_0 = SPADE(spade_config_str, fin, norm_nc, opt.weight_norm_G) 61 | self.norm_1 = SPADE(spade_config_str, fmiddle, norm_nc, opt.weight_norm_G) 62 | if self.learned_shortcut: 63 | self.norm_s = SPADE(spade_config_str, fin, norm_nc, opt.weight_norm_G) 64 | 65 | # note the resnet block with SPADE also takes in |seg|, 66 | # the semantic segmentation map as input 67 | def forward(self, x, seg): 68 | x_s = self.shortcut(x, seg) 69 | 70 | dx = self.conv_0(self.actvn(self.norm_0(x, seg))) 71 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg))) 72 | 73 | out = x_s + dx 74 | 75 | return out 76 | 77 | def shortcut(self, x, seg): 78 | if self.learned_shortcut: 79 | x_s = self.conv_s(self.norm_s(x, seg)) 80 | else: 81 | x_s = x 82 | return x_s 83 | 84 | def actvn(self, x): 85 | return F.leaky_relu(x, 2e-1) 86 | 87 | # MaskGAN architecture that use spade model 88 | class SPADEImageBlock(nn.Module): 89 | def __init__(self, fin, fout, opt, downsample_n): 90 | super().__init__() 91 | # Attributes 92 | self.learned_shortcut = (fin 
!= fout) 93 | fmiddle = min(fin, fout) 94 | 95 | # create conv layers 96 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1) 97 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1) 98 | if self.learned_shortcut: 99 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 100 | 101 | # apply spectral norm if specified 102 | if 'spectral' in opt.norm_G: 103 | self.conv_0 = spectral_norm(self.conv_0) 104 | self.conv_1 = spectral_norm(self.conv_1) 105 | if self.learned_shortcut: 106 | self.conv_s = spectral_norm(self.conv_s) 107 | 108 | # define normalization layers 109 | spade_config_str = opt.norm_G.replace('spectral', '') 110 | self.norm_0 = SPADEImage(spade_config_str, fin, 3, downsample_n) 111 | self.norm_1 = SPADEImage(spade_config_str, fmiddle, 3, downsample_n) 112 | if self.learned_shortcut: 113 | self.norm_s = SPADEImage(spade_config_str, fin, 3, downsample_n) 114 | 115 | # note the resnet block with SPADE also takes in |seg|, 116 | # the semantic segmentation map as input 117 | def forward(self, x, image): 118 | x_s = self.shortcut(x, image) 119 | 120 | dx = self.conv_0(self.actvn(self.norm_0(x, image))) 121 | dx = self.conv_1(self.actvn(self.norm_1(dx, image))) 122 | 123 | out = x_s + dx 124 | 125 | return out 126 | 127 | def shortcut(self, x, image): 128 | if self.learned_shortcut: 129 | x_s = self.conv_s(self.norm_s(x, image)) 130 | else: 131 | x_s = x 132 | return x_s 133 | 134 | def actvn(self, x): 135 | return F.leaky_relu(x, 2e-1) 136 | 137 | 138 | # ResNet block used in pix2pixHD 139 | # We keep the same architecture as pix2pixHD. 140 | class ResnetBlock(nn.Module): 141 | def __init__(self, dim, norm_layer, activation=nn.ReLU(False), kernel_size=3): 142 | super().__init__() 143 | 144 | pw = (kernel_size - 1) // 2 145 | self.conv_block = nn.Sequential( 146 | nn.ReflectionPad2d(pw), 147 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)), 148 | activation, 149 | nn.ReflectionPad2d(pw), 150 | norm_layer(nn.Conv2d(dim, dim, kernel_size=kernel_size)) 151 | ) 152 | 153 | def forward(self, x): 154 | y = self.conv_block(x) 155 | out = x + y 156 | return out 157 | 158 | 159 | # VGG architecter, used for the perceptual loss using a pretrained VGG network 160 | class VGG19(torch.nn.Module): 161 | def __init__(self, requires_grad=False): 162 | super().__init__() 163 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 164 | self.slice1 = torch.nn.Sequential() 165 | self.slice2 = torch.nn.Sequential() 166 | self.slice3 = torch.nn.Sequential() 167 | self.slice4 = torch.nn.Sequential() 168 | self.slice5 = torch.nn.Sequential() 169 | for x in range(2): 170 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 171 | for x in range(2, 7): 172 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 173 | for x in range(7, 12): 174 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 175 | for x in range(12, 21): 176 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 177 | for x in range(21, 30): 178 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 179 | if not requires_grad: 180 | for param in self.parameters(): 181 | param.requires_grad = False 182 | 183 | def forward(self, X): 184 | h_relu1 = self.slice1(X) 185 | h_relu2 = self.slice2(h_relu1) 186 | h_relu3 = self.slice3(h_relu2) 187 | h_relu4 = self.slice4(h_relu3) 188 | h_relu5 = self.slice5(h_relu4) 189 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 190 | return out 191 | 192 | 193 | 
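# --- Editor's note (illustrative sketch, not part of the repository source) --------
# Minimal usage of the SPADEResnetBlock defined above. The option values are
# assumptions that mirror the defaults in options/base_options.py (label_nc=2,
# orient_nc=2, orientation enabled, no instance features, netG='spadeb'); norm_G must
# be a SPADE-style string for these blocks, e.g. 'spectralspadeinstance3x3'. With
# these values the conditioning map carries label_nc + orient_nc = 4 channels.
from argparse import Namespace

import torch

from models.networks.architecture import SPADEResnetBlock

opt = Namespace(norm_G='spectralspadeinstance3x3', weight_norm_G=False,
                label_nc=2, orient_nc=2, no_orientation=False,
                use_instance_feat=False, feat_num=3, netG='spadeb')

block = SPADEResnetBlock(1024, 512, opt)   # fin != fout, so a learned 1x1 shortcut is added
x = torch.randn(1, 1024, 8, 8)             # activations from the previous generator stage
seg = torch.randn(1, 4, 256, 256)          # label (2ch) + orientation (2ch) condition map
out = block(x, seg)                        # SPADE resizes seg to x internally -> [1, 512, 8, 8]
# ------------------------------------------------------------------------------------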
-------------------------------------------------------------------------------- /models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | 10 | class BaseNetwork(nn.Module): 11 | def __init__(self): 12 | super(BaseNetwork, self).__init__() 13 | 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | return parser 17 | 18 | def print_network(self): 19 | if isinstance(self, list): 20 | self = self[0] 21 | num_params = 0 22 | for param in self.parameters(): 23 | num_params += param.numel() 24 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 25 | 'To see the architecture, do print(network).' 26 | % (type(self).__name__, num_params / 1000000)) 27 | 28 | def init_weights(self, init_type='normal', gain=0.02): 29 | def init_func(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm2d') != -1: 32 | if hasattr(m, 'weight') and m.weight is not None: 33 | init.normal_(m.weight.data, 1.0, gain) 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | init.constant_(m.bias.data, 0.0) 36 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 37 | if init_type == 'normal': 38 | init.normal_(m.weight.data, 0.0, gain) 39 | elif init_type == 'xavier': 40 | init.xavier_normal_(m.weight.data, gain=gain) 41 | elif init_type == 'xavier_uniform': 42 | init.xavier_uniform_(m.weight.data, gain=1.0) 43 | elif init_type == 'kaiming': 44 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 45 | elif init_type == 'orthogonal': 46 | init.orthogonal_(m.weight.data, gain=gain) 47 | elif init_type == 'none': # uses pytorch's default init method 48 | m.reset_parameters() 49 | else: 50 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 51 | if hasattr(m, 'bias') and m.bias is not None: 52 | init.constant_(m.bias.data, 0.0) 53 | 54 | self.apply(init_func) 55 | 56 | # propagate to children 57 | for m in self.children(): 58 | if hasattr(m, 'init_weights'): 59 | m.init_weights(init_type, gain) 60 | -------------------------------------------------------------------------------- /models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 
4 | """ 5 | 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | import util.util as util 12 | 13 | 14 | class MultiscaleDiscriminator(BaseNetwork): 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train): 17 | parser.add_argument('--netD_subarch', type=str, default='n_layer', 18 | help='architecture of each discriminator') 19 | parser.add_argument('--num_D', type=int, default=2, 20 | help='number of discriminators to be used in multiscale') 21 | opt, _ = parser.parse_known_args() 22 | 23 | # define properties of each discriminator of the multiscale discriminator 24 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', 25 | 'models.networks.discriminator') 26 | subnetD.modify_commandline_options(parser, is_train) 27 | 28 | return parser 29 | 30 | def __init__(self, opt): 31 | super().__init__() 32 | self.opt = opt 33 | 34 | for i in range(opt.num_D): 35 | subnetD = self.create_single_discriminator(opt) 36 | self.add_module('discriminator_%d' % i, subnetD) 37 | 38 | def create_single_discriminator(self, opt): 39 | subarch = opt.netD_subarch 40 | if subarch == 'n_layer': 41 | netD = NLayerDiscriminator(opt) 42 | else: 43 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch) 44 | return netD 45 | 46 | def downsample(self, input): 47 | return F.avg_pool2d(input, kernel_size=3, 48 | stride=2, padding=[1, 1], 49 | count_include_pad=False) 50 | 51 | # Returns list of lists of discriminator outputs. 52 | # The final result is of size opt.num_D x opt.n_layers_D 53 | def forward(self, input): 54 | result = [] 55 | get_intermediate_features = not self.opt.no_ganFeat_loss 56 | for name, D in self.named_children(): 57 | out = D(input) 58 | if not get_intermediate_features: 59 | out = [out] 60 | result.append(out) 61 | input = self.downsample(input) 62 | 63 | return result 64 | 65 | 66 | # Defines the PatchGAN discriminator with the specified arguments. 
67 | class NLayerDiscriminator(BaseNetwork): 68 | @staticmethod 69 | def modify_commandline_options(parser, is_train): 70 | parser.add_argument('--n_layers_D', type=int, default=4, 71 | help='# layers in each discriminator') 72 | return parser 73 | 74 | def __init__(self, opt): 75 | super().__init__() 76 | self.opt = opt 77 | 78 | kw = 4 79 | padw = int(np.ceil((kw - 1.0) / 2)) 80 | nf = opt.ndf 81 | input_nc = self.compute_D_input_nc(opt) 82 | 83 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D) 84 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 85 | nn.LeakyReLU(0.2, False)]] 86 | 87 | for n in range(1, opt.n_layers_D): 88 | nf_prev = nf 89 | nf = min(nf * 2, 512) 90 | stride = 1 if n == opt.n_layers_D - 1 else 2 91 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, 92 | stride=stride, padding=padw)), 93 | nn.LeakyReLU(0.2, False) 94 | ]] 95 | 96 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 97 | 98 | # We divide the layers into groups to extract intermediate layer outputs 99 | for n in range(len(sequence)): 100 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 101 | 102 | def compute_D_input_nc(self, opt): 103 | input_nc = opt.label_nc + opt.output_nc + opt.orient_nc 104 | if opt.contain_dontcare_label: 105 | input_nc += 1 106 | if not opt.no_instance: 107 | input_nc += 1 108 | return input_nc 109 | 110 | def forward(self, input): 111 | results = [input] 112 | for submodel in self.children(): 113 | intermediate_output = submodel(results[-1]) 114 | results.append(intermediate_output) 115 | 116 | get_intermediate_features = not self.opt.no_ganFeat_loss 117 | if get_intermediate_features: 118 | return results[1:] 119 | else: 120 | return results[-1] 121 | -------------------------------------------------------------------------------- /models/networks/encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 
4 | """ 5 | 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.normalization import get_nonspade_norm_layer 11 | from models.networks.MaskGAN_networks import ConvBlock 12 | import torch 13 | from models.networks.partialconv2d import PartialConv2d 14 | import random 15 | 16 | 17 | class ConvEncoder(BaseNetwork): 18 | """ Same architecture as the image discriminator """ 19 | 20 | def __init__(self, opt): 21 | super().__init__() 22 | 23 | kw = 3 24 | pw = int(np.ceil((kw - 1.0) / 2)) 25 | ndf = opt.ngf 26 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) 27 | self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) 28 | self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) 29 | self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) 30 | self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) 31 | self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) 32 | if opt.crop_size >= 256: 33 | self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) 34 | 35 | self.so = s0 = 4 36 | self.fc_mu = nn.Linear(ndf * 8 * s0 * s0, 256) 37 | self.fc_var = nn.Linear(ndf * 8 * s0 * s0, 256) 38 | 39 | self.actvn = nn.LeakyReLU(0.2, False) 40 | self.opt = opt 41 | 42 | def forward(self, x): 43 | if x.size(2) != 256 or x.size(3) != 256: 44 | x = F.interpolate(x, size=(256, 256), mode='bilinear') 45 | 46 | x = self.layer1(x) 47 | x = self.layer2(self.actvn(x)) 48 | x = self.layer3(self.actvn(x)) 49 | x = self.layer4(self.actvn(x)) 50 | x = self.layer5(self.actvn(x)) 51 | if self.opt.crop_size >= 256: 52 | x = self.layer6(self.actvn(x)) 53 | x = self.actvn(x) 54 | 55 | x = x.view(x.size(0), -1) 56 | mu = self.fc_mu(x) 57 | logvar = self.fc_var(x) 58 | 59 | return mu, logvar 60 | 61 | class ImageEncoder(BaseNetwork): 62 | """ Same architecture as the image discriminator """ 63 | 64 | def __init__(self, opt, sw, sh): 65 | super().__init__() 66 | 67 | kw = 3 68 | pw = int(np.ceil((kw - 1.0) / 2)) 69 | ndf = opt.ngf # 64 70 | self.sw = sw 71 | self.sh = sh 72 | self.opt = opt 73 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) 74 | self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) 75 | self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) 76 | self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) 77 | self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) 78 | self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 16, kw, stride=2, padding=pw)) 79 | if opt.crop_size >= 256: 80 | self.layer6 = norm_layer(nn.Conv2d(ndf * 8, ndf * 8, kw, stride=2, padding=pw)) 81 | 82 | self.so = s0 = 4 83 | self.adaptivepool = nn.AdaptiveAvgPool2d(1) 84 | self.fc = self.fc = nn.Conv2d(ndf*16, ndf*16*self.sw*self.sh, 1, 1, 0) 85 | 86 | self.actvn = nn.LeakyReLU(0.2, False) 87 | self.opt = opt 88 | 89 | def forward(self, x, label_ref=None, label_tag=None): 90 | if x.size(2) != 256 or x.size(3) != 256: 91 | x = F.interpolate(x, size=(256, 256), mode='bilinear') 92 | 93 | x = self.layer1(x) 94 | x = self.layer2(self.actvn(x)) 95 | x = self.layer3(self.actvn(x)) 96 | x = self.layer4(self.actvn(x)) 97 | x = self.layer5(self.actvn(x)) 98 | # if self.opt.crop_size >= 256: 99 | # x = self.layer6(self.actvn(x)) 100 | x = self.actvn(x) 101 | x = self.adaptivepool(x) 102 | x = self.fc(x) 103 | x = 
x.view(x.size()[0], self.opt.ngf*16, self.sh, self.sw) 104 | 105 | return x 106 | 107 | class ImageEncoder2(BaseNetwork): 108 | """ Same architecture as the image discriminator """ 109 | 110 | def __init__(self, opt, sw, sh): 111 | super().__init__() 112 | 113 | kw = 3 114 | pw = int(np.ceil((kw - 1.0) / 2)) 115 | ndf = opt.ngf # 64 116 | self.sw = sw 117 | self.sh = sh 118 | self.opt = opt 119 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) 120 | self.layer1 = norm_layer(nn.Conv2d(3, ndf, kw, stride=2, padding=pw)) 121 | self.layer2 = norm_layer(nn.Conv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw)) 122 | self.layer3 = norm_layer(nn.Conv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw)) 123 | self.layer4 = norm_layer(nn.Conv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw)) 124 | self.layer5 = norm_layer(nn.Conv2d(ndf * 8, ndf * 16, kw, stride=2, padding=pw)) 125 | 126 | self.actvn = nn.LeakyReLU(0.2, False) 127 | self.opt = opt 128 | 129 | def forward(self, x, label_ref, label_tag): 130 | 131 | x = self.layer1(x) 132 | x = self.layer2(self.actvn(x)) 133 | x = self.layer3(self.actvn(x)) 134 | x = self.layer4(self.actvn(x)) 135 | x = self.layer5(self.actvn(x)) 136 | # if self.opt.crop_size >= 256: 137 | # x = self.layer6(self.actvn(x)) 138 | x = self.actvn(x) 139 | # resize the label 140 | _,_,xh,xw = x.size() 141 | label_ref = F.interpolate(label_ref, size=(xh, xw), mode='nearest') 142 | label_tag = F.interpolate(label_tag, size=(xh, xw), mode='nearest') 143 | # instance_wise(hair) average pool 144 | outputs_mean = x.clone() 145 | for b in range(x.size()[0]): 146 | if self.opt.ref_global_pool: 147 | tmps = x[b,...] 148 | tmps = torch.mean(torch.mean(tmps, dim=1, keepdim=True), dim=2, keepdim=True) 149 | else: 150 | tmps = x[b, ...] * label_ref[b, ...] 151 | tmps = torch.sum(torch.sum(tmps, dim=1, keepdim=True), dim=2, keepdim=True) / max(torch.sum(label_ref[b, ...]), 1) 152 | tmps = tmps.expand_as(x[b, ...]) 153 | outputs_mean[b, ...] 
= tmps * label_tag[b] 154 | # resize 155 | if self.sh != xh: 156 | outputs_mean = F.interpolate(outputs_mean, size=(self.sh, self.sw), mode='nearest') 157 | 158 | return outputs_mean 159 | 160 | class ImageEncoder3(BaseNetwork): 161 | """ Same architecture as the image discriminator (Partial Convolution)""" 162 | 163 | def __init__(self, opt, sw, sh): 164 | super().__init__() 165 | 166 | kw = 3 167 | pw = int(np.ceil((kw - 1.0) / 2)) 168 | ndf = opt.ngf # 64 169 | self.sw = sw 170 | self.sh = sh 171 | self.opt = opt 172 | self.layer1 = PartialConv2d(3, ndf, kw, stride=2, padding=pw, return_mask=True) 173 | self.norm1 = nn.InstanceNorm2d(ndf, affine=False) 174 | self.layer2 = PartialConv2d(ndf * 1, ndf * 2, kw, stride=2, padding=pw, return_mask=True) 175 | self.norm2 = nn.InstanceNorm2d(ndf * 2, affine=False) 176 | self.layer3 = PartialConv2d(ndf * 2, ndf * 4, kw, stride=2, padding=pw, return_mask=True) 177 | self.norm3 = nn.InstanceNorm2d(ndf * 4, affine=False) 178 | self.layer4 = PartialConv2d(ndf * 4, ndf * 8, kw, stride=2, padding=pw, return_mask=True) 179 | self.norm4 = nn.InstanceNorm2d(ndf * 8, affine=False) 180 | self.layer5 = PartialConv2d(ndf * 8, ndf * 16, kw, stride=2, padding=pw, return_mask=True) 181 | self.norm5 = nn.InstanceNorm2d(ndf * 16, affine=False) 182 | 183 | self.actvn = nn.LeakyReLU(0.2, False) 184 | self.opt = opt 185 | 186 | def forward(self, x, label_ref0, label_tag0): 187 | 188 | if 'instance' in self.opt.norm_ref_encode: 189 | 190 | x, mask = self.layer1(x, label_ref0) 191 | x = self.norm1(x) 192 | x, mask = self.layer2(self.actvn(x), mask) 193 | x = self.norm2(x) 194 | x, mask = self.layer3(self.actvn(x), mask) 195 | x = self.norm3(x) 196 | x, mask = self.layer4(self.actvn(x), mask) 197 | x = self.norm4(x) 198 | x, mask = self.layer5(self.actvn(x), mask) 199 | x = self.norm5(x) 200 | elif 'none' in self.opt.norm_ref_encode: 201 | x, mask = self.layer1(x, label_ref0) 202 | x, mask = self.layer2(self.actvn(x), mask) 203 | x, mask = self.layer3(self.actvn(x), mask) 204 | x, mask = self.layer4(self.actvn(x), mask) 205 | x, mask = self.layer5(self.actvn(x), mask) 206 | 207 | x = self.actvn(x) 208 | # print('save feat') 209 | # self.show_feature_map(x, min_channel=0, max_channel=25) 210 | # resize the label 211 | _,_,xh,xw = x.size() 212 | label_ref = F.interpolate(label_ref0, size=(xh, xw), mode='nearest') 213 | label_tag = F.interpolate(label_tag0, size=(xh, xw), mode='nearest') 214 | # instance_wise(hair) average pool 215 | outputs_mean = x.clone() 216 | for b in range(x.size()[0]): 217 | tmps = x[b, ...] * label_ref[b, ...] 218 | tmps = torch.sum(torch.sum(tmps, dim=1, keepdim=True), dim=2, keepdim=True) / max(torch.sum(label_ref[b, ...]), 1) 219 | tmps = tmps.expand_as(x[b, ...]) 220 | outputs_mean[b, ...] 
= tmps * label_tag[b] 221 | # resize 222 | if self.sh != xh: 223 | outputs_mean = F.interpolate(outputs_mean, size=(self.sh, self.sw), mode='bilinear') 224 | 225 | return outputs_mean 226 | 227 | class BackgroundEncode(BaseNetwork): 228 | def __init__(self, opt): 229 | super().__init__() 230 | self.opt = opt 231 | self.ngf = opt.ngf 232 | self.conv1 = ConvBlock(3, self.ngf, 7, 1, 3, norm='none', activation='relu', pad_type='reflect') 233 | self.layer1 = ConvBlock(self.ngf, 2 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 234 | self.layer2 = ConvBlock(2 * self.ngf, 4 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 235 | self.layer3 = ConvBlock(4 * self.ngf, 8 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 236 | self.layer4 = ConvBlock(8 * self.ngf, 16 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 237 | 238 | def forward(self, image, mask): 239 | x0 = self.conv1(image) # 64 240 | x1 = self.layer1(x0) # 1/2 64*2 241 | x2 = self.layer2(x1) # 1/4 64*4 242 | x3 = self.layer3(x2) # 1/8 64*8 243 | x4 = self.layer4(x3) # 1/16 64*16 244 | 245 | back_mask = torch.unsqueeze(mask[:,0,:,:], 1) 246 | _,_,sh,sw = back_mask.size() 247 | back_mask1 = F.interpolate(back_mask, size=(int(sh/2), int(sw/2)), mode='nearest') 248 | back_mask2 = F.interpolate(back_mask, size=(int(sh / 4), int(sw / 4)), mode='nearest') 249 | back_mask3 = F.interpolate(back_mask, size=(int(sh / 8), int(sw / 8)), mode='nearest') 250 | back_mask4 = F.interpolate(back_mask, size=(int(sh / 16), int(sw / 16)), mode='nearest') 251 | 252 | 253 | return [x0, x1, x2, x3, x4], [back_mask, back_mask1, back_mask2, back_mask3, back_mask4] 254 | 255 | def save_image(image, name): 256 | from PIL import Image 257 | image_numpy = image[0,...].cpu().numpy() 258 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) #[h,w,3] 259 | image_numpy = (image_numpy + 1) / 2 * 255.0 260 | image_pil = Image.fromarray(np.uint8(image_numpy)) 261 | image_pil.save('./inference_samples/'+name) 262 | 263 | def save_mask(mask, name): 264 | from PIL import Image 265 | image_numpy = mask[0,0,...].cpu().numpy() 266 | image_numpy = image_numpy * 255.0 267 | image_pil = Image.fromarray(np.uint8(image_numpy)) 268 | image_pil.save('./inference_samples/'+name) 269 | 270 | 271 | class BackgroundEncode2(BaseNetwork): 272 | def __init__(self, opt): 273 | super().__init__() 274 | self.opt = opt 275 | self.ngf = opt.ngf 276 | if opt.num_upsampling_layers == 'most': 277 | self.conv0 = ConvBlock(3, self.ngf // 2, 7, 1, 3, norm='none', activation='relu', pad_type='reflect') 278 | self.layer0 = ConvBlock(self.ngf // 2, self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 279 | else: 280 | self.conv1 = ConvBlock(3, self.ngf, 7, 1, 3, norm='none', activation='relu', pad_type='reflect') 281 | self.layer1 = ConvBlock(self.ngf, 2 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 282 | self.layer2 = ConvBlock(2 * self.ngf, 4 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 283 | self.layer3 = ConvBlock(4 * self.ngf, 8 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 284 | self.layer4 = ConvBlock(8 * self.ngf, 16 * self.ngf, 4, 2, 1, norm='none', activation='relu', pad_type='reflect') 285 | 286 | def forward(self, image, mask, noise): 287 | 288 | if self.opt.isTrain: 289 | if self.opt.random_expand_mask: 290 | hair_mask = torch.unsqueeze(mask[:, 1, :, :], 1) 291 | _,_,mh,mw = hair_mask.shape 292 | th = 
int(mh * self.opt.random_expand_th) 293 | th = th if th % 2 == 1 else th+1 294 | k = random.choice([max(th-4,1),max(th-2,1),th,th+2,th+4]) 295 | p = int(k / 2) 296 | expand_hair_mask = F.max_pool2d(hair_mask, kernel_size=k, stride=1, padding=p) 297 | back_mask = 1 - expand_hair_mask 298 | else: 299 | back_mask = torch.unsqueeze(mask[:, 0, :, :], 1) 300 | else: 301 | if self.opt.expand_mask_be: 302 | hair_mask = torch.unsqueeze(mask[:, 1, :, :], 1) 303 | k = self.opt.expand_th 304 | p = int(k / 2) 305 | if self.opt.add_feat_zeros: 306 | th = self.opt.add_th 307 | H, W = self.opt.crop_size, self.opt.crop_size 308 | expand_hair_mask = hair_mask * 0 309 | hair_no_pad = hair_mask[:,:,int(th/2):int(th/2)+H,int(th/2):int(th/2)+W] 310 | expand_hair_no_pad = F.max_pool2d(hair_no_pad, kernel_size=k, stride=1, padding=p) 311 | expand_hair_mask[:,:,int(th/2):int(th/2)+H,int(th/2):int(th/2)+W] = expand_hair_no_pad 312 | else: 313 | expand_hair_mask = F.max_pool2d(hair_mask, kernel_size=k, stride=1, padding=p) 314 | back_mask = 1 - expand_hair_mask 315 | else: 316 | back_mask = torch.unsqueeze(mask[:, 0, :, :], 1) 317 | 318 | if self.opt.random_noise_background: 319 | input = noise 320 | else: 321 | input = image * back_mask + noise * (1 - back_mask) 322 | 323 | if self.opt.num_upsampling_layers == 'most': 324 | x00 = self.conv0(input) # 64 *0.5 325 | x0 = self.layer0(x00) # 64 1/2 326 | else: 327 | x0 = self.conv1(input) # 64 328 | x1 = self.layer1(x0) # 1/2 64*2 329 | x2 = self.layer2(x1) # 1/4 64*4 330 | x3 = self.layer3(x2) # 1/8 64*8 331 | 332 | _,_,sh,sw = back_mask.size() 333 | back_mask1 = F.interpolate(back_mask, size=(int(sh/2), int(sw/2)), mode='nearest') 334 | back_mask2 = F.interpolate(back_mask, size=(int(sh / 4), int(sw / 4)), mode='nearest') 335 | back_mask3 = F.interpolate(back_mask, size=(int(sh / 8), int(sw / 8)), mode='nearest') 336 | back_mask4 = F.interpolate(back_mask, size=(int(sh / 16), int(sw / 16)), mode='nearest') 337 | 338 | if self.opt.num_upsampling_layers == 'most': 339 | return [x3, x2, x1, x0, x00], [back_mask4, back_mask3, back_mask2, back_mask1, back_mask] 340 | else: 341 | return [x3, x2, x1, x0], [back_mask3, back_mask2, back_mask1, back_mask] 342 | 343 | 344 | 345 | -------------------------------------------------------------------------------- /models/networks/normalization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 
4 | """ 5 | 6 | import re 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 11 | import torch.nn.utils.spectral_norm as spectral_norm 12 | import math 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | # Returns a function that creates a normalization function 17 | # that does not condition on semantic map 18 | def get_nonspade_norm_layer(opt, norm_type='instance'): 19 | # helper function to get # output channels of the previous layer 20 | def get_out_channel(layer): 21 | if hasattr(layer, 'out_channels'): 22 | return getattr(layer, 'out_channels') 23 | return layer.weight.size(0) 24 | 25 | # this function will be returned 26 | def add_norm_layer(layer): 27 | nonlocal norm_type 28 | if norm_type.startswith('spectral'): 29 | layer = spectral_norm(layer) 30 | subnorm_type = norm_type[len('spectral'):] 31 | else: 32 | subnorm_type = norm_type 33 | 34 | if subnorm_type == 'none' or len(subnorm_type) == 0: 35 | return layer 36 | 37 | # remove bias in the previous layer, which is meaningless 38 | # since it has no effect after normalization 39 | if getattr(layer, 'bias', None) is not None: 40 | delattr(layer, 'bias') 41 | layer.register_parameter('bias', None) 42 | 43 | if subnorm_type == 'batch': 44 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 45 | elif subnorm_type == 'sync_batch': 46 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 47 | elif subnorm_type == 'instance': 48 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 49 | else: 50 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 51 | 52 | return nn.Sequential(layer, norm_layer) 53 | 54 | return add_norm_layer 55 | 56 | 57 | # Creates SPADE normalization layer based on the given configuration 58 | # SPADE consists of two steps. First, it normalizes the activations using 59 | # your favorite normalization method, such as Batch Norm or Instance Norm. 60 | # Second, it applies scale and bias to the normalized output, conditioned on 61 | # the segmentation map. 62 | # The format of |config_text| is spade(norm)(ks), where 63 | # (norm) specifies the type of parameter-free normalization. 64 | # (e.g. syncbatch, batch, instance) 65 | # (ks) specifies the size of kernel in the SPADE module (e.g. 3x3) 66 | # Example |config_text| will be spadesyncbatch3x3, or spadeinstance5x5. 67 | # Also, the other arguments are 68 | # |norm_nc|: the #channels of the normalized activations, hence the output dim of SPADE 69 | # |label_nc|: the #channels of the input semantic map, hence the input dim of SPADE 70 | class SPADE(nn.Module): 71 | def __init__(self, config_text, norm_nc, label_nc, use_weight_norm=False): 72 | super().__init__() 73 | 74 | assert config_text.startswith('spade') 75 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 76 | param_free_norm_type = str(parsed.group(1)) 77 | ks = int(parsed.group(2)) 78 | 79 | if param_free_norm_type == 'instance': 80 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 81 | elif param_free_norm_type == 'syncbatch': 82 | self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 83 | elif param_free_norm_type == 'batch': 84 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 85 | else: 86 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 87 | % param_free_norm_type) 88 | 89 | # The dimension of the intermediate embedding space. Yes, hardcoded. 
90 | nhidden = 128 91 | 92 | pw = ks // 2 93 | self.mlp_shared = nn.Sequential( 94 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 95 | nn.ReLU() 96 | ) 97 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 98 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 99 | self.use_weight_norm = use_weight_norm 100 | 101 | def forward(self, x, segmap): 102 | 103 | if self.use_weight_norm == False: 104 | # Part 1. generate parameter-free normalized activations 105 | normalized = self.param_free_norm(x) 106 | else: 107 | normalized = x 108 | 109 | # Part 2. produce scaling and bias conditioned on semantic map 110 | segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 111 | actv = self.mlp_shared(segmap) 112 | gamma = self.mlp_gamma(actv) 113 | beta = self.mlp_beta(actv) 114 | 115 | # apply scale and bias 116 | out = normalized * (1 + gamma) + beta 117 | 118 | return out 119 | 120 | # for weight normalization 121 | from torch.nn.parameter import Parameter 122 | class Weight_norm(object): 123 | def __init__(self, name): 124 | self.name = name 125 | self.norm_weight = None 126 | def compute_weight(self, module): 127 | w = getattr(module, self.name) 128 | return Parameter(w.data) 129 | @staticmethod 130 | def apply(module, name, thea): 131 | fn = Weight_norm(name) 132 | weight = getattr(module, name) 133 | # weight_data = weight.data 134 | # weight_norm_data = weight_data / (torch.norm(weight_data)+thea) 135 | # remove w from parameter list 136 | del module._parameters[name] 137 | p = Parameter(weight / (torch.norm(weight)+thea)) 138 | module.register_parameter(name, p) 139 | setattr(module, name, fn.compute_weight(module)) 140 | # recompute weight before every forward() 141 | module.register_forward_pre_hook(fn) 142 | 143 | return fn 144 | 145 | def __call__(self, module, inputs): 146 | setattr(module, self.name, self.compute_weight(module)) 147 | 148 | def weight_norm(module, name='weight', thea = 1e-10): 149 | Weight_norm.apply(module,name,thea) 150 | return module 151 | 152 | class SPADEImage(nn.Module): 153 | def __init__(self, config_text, norm_nc, image_nc, downsample_n): 154 | super().__init__() 155 | 156 | assert config_text.startswith('spade') 157 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 158 | param_free_norm_type = str(parsed.group(1)) 159 | ks = int(parsed.group(2)) 160 | 161 | if param_free_norm_type == 'instance': 162 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 163 | elif param_free_norm_type == 'syncbatch': 164 | self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 165 | elif param_free_norm_type == 'batch': 166 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False) 167 | else: 168 | raise ValueError('%s is not a recognized param-free norm type in SPADE' 169 | % param_free_norm_type) 170 | 171 | # The dimension of the intermediate embedding space. Yes, hardcoded. 
172 | nhidden = 128 173 | 174 | pw = ks // 2 175 | self.mlp_shared = nn.Sequential( 176 | nn.Conv2d(image_nc, nhidden, kernel_size=ks, padding=pw), 177 | nn.ReLU() 178 | ) 179 | self.middle = [] 180 | for i in range(downsample_n): 181 | self.middle += [nn.Conv2d(nhidden, nhidden, kernel_size=3, padding=pw, stride=2)] 182 | self.middle += [nn.ReLU()] 183 | self.middle = nn.Sequential(*self.middle) 184 | 185 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 186 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 187 | 188 | def forward(self, x, segmap): 189 | 190 | # Part 1. generate parameter-free normalized activations 191 | normalized = self.param_free_norm(x) 192 | 193 | # Part 2. produce scaling and bias conditioned on semantic map 194 | # segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest') 195 | actv = self.mlp_shared(segmap) 196 | actv = self.middle(actv) 197 | gamma = self.mlp_gamma(actv) 198 | beta = self.mlp_beta(actv) 199 | 200 | # apply scale and bias 201 | out = normalized * (1 + gamma) + beta 202 | 203 | return out 204 | -------------------------------------------------------------------------------- /models/networks/partialconv2d.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # BSD 3-Clause License 3 | # 4 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 5 | # 6 | # Author & Contact: Guilin Liu (guilinl@nvidia.com) 7 | ############################################################################### 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn, cuda 12 | from torch.autograd import Variable 13 | 14 | 15 | class PartialConv2d(nn.Conv2d): 16 | def __init__(self, *args, **kwargs): 17 | 18 | # whether the mask is multi-channel or not 19 | if 'multi_channel' in kwargs: 20 | self.multi_channel = kwargs['multi_channel'] 21 | kwargs.pop('multi_channel') 22 | else: 23 | self.multi_channel = False 24 | 25 | if 'return_mask' in kwargs: 26 | self.return_mask = kwargs['return_mask'] 27 | kwargs.pop('return_mask') 28 | else: 29 | self.return_mask = False 30 | 31 | super(PartialConv2d, self).__init__(*args, **kwargs) 32 | 33 | if self.multi_channel: 34 | self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], 35 | self.kernel_size[1]) 36 | else: 37 | self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1]) 38 | 39 | self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * \ 40 | self.weight_maskUpdater.shape[3] 41 | 42 | self.last_size = (None, None, None, None) 43 | self.update_mask = None 44 | self.mask_ratio = None 45 | 46 | def forward(self, input, mask_in=None): 47 | assert len(input.shape) == 4 48 | if mask_in is not None or self.last_size != tuple(input.shape): 49 | self.last_size = tuple(input.shape) 50 | 51 | with torch.no_grad(): 52 | if self.weight_maskUpdater.type() != input.type(): 53 | self.weight_maskUpdater = self.weight_maskUpdater.to(input) 54 | 55 | if mask_in is None: 56 | # if mask is not provided, create a mask 57 | if self.multi_channel: 58 | mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], 59 | input.data.shape[3]).to(input) 60 | else: 61 | mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input) 62 | else: 63 | mask = mask_in 64 | 65 | self.update_mask = F.conv2d(mask, 
self.weight_maskUpdater, bias=None, stride=self.stride, 66 | padding=self.padding, dilation=self.dilation, groups=1) 67 | 68 | # for mixed precision training, change 1e-8 to 1e-6 69 | self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8) 70 | # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8) 71 | self.update_mask = torch.clamp(self.update_mask, 0, 1) 72 | self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask) 73 | 74 | raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input) 75 | 76 | if self.bias is not None: 77 | bias_view = self.bias.view(1, self.out_channels, 1, 1) 78 | output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view 79 | output = torch.mul(output, self.update_mask) 80 | else: 81 | output = torch.mul(raw_out, self.mask_ratio) 82 | 83 | if self.return_mask: 84 | return output, self.update_mask 85 | else: 86 | return output -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .batchnorm import patch_sync_batchnorm, convert_model 13 | from .replicate import DataParallelWithCallback, patch_replication_callback 14 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 
22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 
58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register an slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually is the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages were first collected from each devices (including the master device), and then 106 | an callback will be invoked to compute the message to be sent back to each devices 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master want to send to itself. This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belongs to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphism, we assign each sub-module with a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import sys 7 | import argparse 8 | import os 9 | from util import util 10 | import torch 11 | import models 12 | import data 13 | import pickle 14 | 15 | 16 | class BaseOptions(): 17 | def __init__(self): 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | # experiment specifics 22 | parser.add_argument('--name', type=str, default='MichiGAN', help='name of the experiment. It decides where to store samples and models') 23 | 24 | parser.add_argument('--gpu_ids', type=str, default='0,1,2,3,4,5,6,7', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 25 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 26 | parser.add_argument('--model', type=str, default='pix2pix', help='which model to use') 27 | parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization') 28 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') 29 | parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization') 30 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 31 | parser.add_argument('--weight_norm_G', action='store_true', help='if specified, use weight normalization to replace feature norm in spade.') 32 | parser.add_argument('--weight_norm_g', type=int, default=0, help='0 means use the function by Pytorch, 1 means use ours.') 33 | # input/output sizes 34 | parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 35 | parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none")) 36 | parser.add_argument('--load_size', type=int, default=512, help='Scale images to this size. The final image will be cropped to --crop_size.') 37 | parser.add_argument('--crop_size', type=int, default=512, help='Crop to the width of crop_size (after initially scaling the images to load_size.)') 38 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. 
The final height of the load image will be crop_size/aspect_ratio') 39 | parser.add_argument('--label_nc', type=int, default=2, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.') 40 | parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)') 41 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 42 | parser.add_argument('--orient_nc', type=int, default=2, help='# of orientation map channels') 43 | parser.add_argument('--add_noise_to_image', action='store_true', help='if specified, add noise to the image that remove hair') 44 | parser.add_argument('--use_original_image', action='store_true', help='if specified, use real image to generator') 45 | # parser.add_argument('--only_tag', action='store_true', help='if specified, not reference input') 46 | 47 | # for setting inputs 48 | parser.add_argument('--data_dir', type=str, default='/mnt/lvdisk1/tzt/HairSynthesis/SPADE-master/datasets/FFHQ', 49 | help='path to the directory that contains training & val data') 50 | parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/') 51 | parser.add_argument('--dataset_mode', type=str, default='custom') 52 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 53 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 54 | parser.add_argument('--nThreads', default=8, type=int, help='# threads for loading data') 55 | parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 56 | parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default') 57 | parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster') 58 | parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache') 59 | parser.add_argument('--color_jitter', action='store_true', help='if specified, use color jitter to ref image.') 60 | parser.add_argument('--orient_random_disturb', action='store_true', help='if specified, random disturb the edges of orient for netG') 61 | parser.add_argument('--hair_random_disturb', action='store_true', help='if specified, random disturb the edges of hair for blend') 62 | 63 | # for displays 64 | parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 65 | 66 | # for generator 67 | parser.add_argument('--netG', type=str, default='spadeb', help='selects model to use for netG (pix2pixhd | spade | spadeb)') 68 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 69 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') 70 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') 71 | parser.add_argument('--z_dim', type=int, default=256, 72 | help="dimension of the latent z vector") 73 | parser.add_argument('--netIG', type=str, default='inpaint') 74 | parser.add_argument('--use_ig', action='store_true', help='which use inpainting generator to inpaint orientation') 75 | parser.add_argument('--ig_model_name', type=str, default='InpaintingModel_gen.pth', help='pretrained inpainting generator model') 76 | parser.add_argument('--norm_model', type=str, default='instance', help='normalization model [Batch | instance] in spaderesidualunet') 77 | parser.add_argument('--fix_netG', action='store_true', 78 | help='if specified, do not update the weight of generator.') 79 | parser.add_argument('--num_upsampling_layers', 80 | choices=('normal', 'more', 'most'), default='more', 81 | help="If 'more', adds upsampling layer between the two middle resnet blocks. 
If 'most', also add one more upsampling + resnet layer at the end of the generator") 82 | parser.add_argument('--ms_step', type=int, default=0, help='The resolution of input for netG, [0,1,2,3,4].') 83 | parser.add_argument('--batch_sizes', type=str, default='32,32,32,16,8', help='The batch sizes for progressive training.') 84 | parser.add_argument('--alpha_value', type=float, default=-1, help='the alpha for progressive training.') 85 | parser.add_argument('--show_feat_maps', action='store_true', help='if specified, save the feature maps from the generator in ./artifacts/.') 86 | 87 | # for image feature encoder like pix2pixHD 88 | parser.add_argument('--use_instance_feat', action='store_true', help='if specified, use feature encoder') 89 | parser.add_argument('--feat_num', type=int, default=3, help='the feature channels of feature encoder') 90 | parser.add_argument('--feat_input_nc', type=int, default=3, help='the input channels of feature encoder') 91 | 92 | # for reference image feature encoder 93 | parser.add_argument('--use_encoder', action='store_true', help='enable training with an image encoder.') 94 | parser.add_argument('--Image_encoder_mode', type=str, default='partialconv', help='encoder network [normal|instance|partialconv]') 95 | parser.add_argument('--norm_ref_encode', type=str, default='instance', help='[instance|none], none means no norm in ref encoder.') 96 | parser.add_argument('--ref_global_pool', action='store_true', help='if specified, use global pooling in the reference encoder.') 97 | 98 | # for blend network 99 | parser.add_argument('--use_blender', action='store_true', help='use the blender to preserve the background.') 100 | parser.add_argument('--netB', type=str, default='blend2', help='select model to use for netB [blend | blend2]') 101 | parser.add_argument('--only_blend', action='store_true', help='only use the blend network for training or testing') 102 | 103 | # for instance-wise features 104 | parser.add_argument('--no_instance', default=True, help='if specified, do *not* add instance map as input') 105 | parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 106 | parser.add_argument('--use_vae', action='store_true', help='enable training with an image encoder.') 107 | 108 | # for background image input for spadeb 109 | parser.add_argument('--noise_background', action='store_true', help='if specified, add noise to the target image hair region.') 110 | parser.add_argument('--random_expand_mask', action='store_true', help='if specified, expand the hair mask in the background encoder during training.') 111 | parser.add_argument('--random_expand_th', type=float, default=0.05, help='the threshold for random expansion.') 112 | parser.add_argument('--bf_direct_add', action='store_true', help='if specified, directly add the background feature without the mask.') 113 | parser.add_argument('--random_noise_background', action='store_true', help='if specified, add noise to the background encoder.') 114 | 115 | parser.add_argument('--no_orientation', default=False, help='if specified, do *not* add orientation map as input') 116 | 117 | # for stroke orient inpainting 118 | parser.add_argument('--use_stroke', action='store_true', help='if specified, use the stroke inpainting network.') 119 | parser.add_argument('--inpaint_mode', type=str, default='ref', choices=['ref', 'stroke'], 120 | help='specify which inpainting network is used.') 121 | parser.add_argument('--netSIG', type=str, default='sinpaint') 122 | parser.add_argument('--sig_model_name', type=str, default='SInpaintingModel_gen.pth',
123 | help='pretrained stroke inpainting generator model.') 124 | 125 | # for image padding zeros 126 | parser.add_argument('--add_zeros', action='store_true', help='if specified, add zeros to input data.') 127 | parser.add_argument('--add_feat_zeros', action='store_true', help='if specified, add zeros to the tensor before netG.') 128 | parser.add_argument('--add_th', type=int, default=64, help='total threshold of zeros padding.') 129 | # for clip the features 130 | parser.add_argument('--clip_th', type=float, default=300, help='the clip threshold for generator features.') 131 | parser.add_argument('--use_clip', action='store_true', help='if specified, clip the features in generator.') 132 | 133 | self.initialized = True 134 | return parser 135 | 136 | def gather_options(self): 137 | # initialize parser with basic options 138 | if not self.initialized: 139 | parser = argparse.ArgumentParser( 140 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 141 | parser = self.initialize(parser) 142 | 143 | # get the basic options 144 | opt, unknown = parser.parse_known_args() 145 | 146 | # modify model-related parser options 147 | model_name = opt.model 148 | model_option_setter = models.get_option_setter(model_name) 149 | parser = model_option_setter(parser, self.isTrain) 150 | 151 | # modify dataset-related parser options 152 | dataset_mode = opt.dataset_mode 153 | dataset_option_setter = data.get_option_setter(dataset_mode) 154 | parser = dataset_option_setter(parser, self.isTrain) 155 | 156 | opt, unknown = parser.parse_known_args() 157 | 158 | # if there is opt_file, load it. 159 | # The previous default options will be overwritten 160 | if opt.load_from_opt_file: 161 | parser = self.update_options_from_file(parser, opt) 162 | 163 | opt = parser.parse_args() 164 | self.parser = parser 165 | return opt 166 | 167 | def print_options(self, opt): 168 | message = '' 169 | message += '----------------- Options ---------------\n' 170 | for k, v in sorted(vars(opt).items()): 171 | comment = '' 172 | default = self.parser.get_default(k) 173 | if v != default: 174 | comment = '\t[default: %s]' % str(default) 175 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 176 | message += '----------------- End -------------------' 177 | print(message) 178 | 179 | def option_file_path(self, opt, makedir=False): 180 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 181 | if makedir: 182 | util.mkdirs(expr_dir) 183 | file_name = os.path.join(expr_dir, 'opt') 184 | return file_name 185 | 186 | def save_options(self, opt): 187 | file_name = self.option_file_path(opt, makedir=True) 188 | with open(file_name + '.txt', 'wt') as opt_file: 189 | for k, v in sorted(vars(opt).items()): 190 | comment = '' 191 | default = self.parser.get_default(k) 192 | if v != default: 193 | comment = '\t[default: %s]' % str(default) 194 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) 195 | 196 | with open(file_name + '.pkl', 'wb') as opt_file: 197 | pickle.dump(opt, opt_file) 198 | 199 | def update_options_from_file(self, parser, opt): 200 | new_opt = self.load_options(opt) 201 | for k, v in sorted(vars(opt).items()): 202 | if hasattr(new_opt, k) and v != getattr(new_opt, k): 203 | new_val = getattr(new_opt, k) 204 | parser.set_defaults(**{k: new_val}) 205 | return parser 206 | 207 | def load_options(self, opt): 208 | file_name = self.option_file_path(opt, makedir=False) 209 | new_opt = pickle.load(open(file_name + '.pkl', 'rb')) 210 | return new_opt 211 | 212 | def parse(self, save=False): 
213 | 214 | opt = self.gather_options() 215 | opt.isTrain = self.isTrain # train or test 216 | 217 | self.print_options(opt) 218 | if opt.isTrain: 219 | self.save_options(opt) 220 | 221 | # Set semantic_nc based on the option. 222 | # This will be convenient in many places 223 | opt.semantic_nc = opt.label_nc + \ 224 | (1 if opt.contain_dontcare_label else 0) + \ 225 | (0 if opt.no_instance else 1) 226 | 227 | # set gpu ids 228 | str_ids = opt.gpu_ids.split(',') 229 | opt.gpu_ids = [] 230 | for str_id in str_ids: 231 | id = int(str_id) 232 | if id >= 0: 233 | opt.gpu_ids.append(id) 234 | if len(opt.gpu_ids) > 0: 235 | torch.cuda.set_device(opt.gpu_ids[0]) 236 | 237 | assert len(opt.gpu_ids) == 0 or opt.batchSize % len(opt.gpu_ids) == 0, \ 238 | "Batch size %d is wrong. It must be a multiple of # GPUs %d." \ 239 | % (opt.batchSize, len(opt.gpu_ids)) 240 | 241 | self.opt = opt 242 | return self.opt 243 | -------------------------------------------------------------------------------- /options/demo_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class DemoOptions(BaseOptions): 5 | def initialize(self, parser): 6 | BaseOptions.initialize(self, parser) 7 | 8 | parser.add_argument('--which_epoch', type=str, default='50', help='which epoch to load? set to latest to use latest cached model') 9 | parser.add_argument('--expand_th', type=int, default=5, help='threshold for expanding the tag hair mask for the background encoder.') 10 | parser.add_argument('--expand_mask_be', action='store_true', help='if specified, expand the tag hair mask for the background encoder.') 11 | 12 | parser.set_defaults(preprocess_mode='scale_width_and_crop') 13 | parser.set_defaults(serial_batches=True) 14 | parser.set_defaults(no_flip=True) 15 | parser.set_defaults(gpu_ids='0') 16 | parser.set_defaults(netG='spadeb') 17 | parser.set_defaults(use_encoder=True) 18 | parser.set_defaults(use_ig=True) 19 | parser.set_defaults(noise_background=True) 20 | parser.set_defaults(load_size=512) 21 | parser.set_defaults(crop_size=512) 22 | parser.set_defaults(use_stroke=True) 23 | 24 | parser.set_defaults(name='MichiGAN') 25 | parser.set_defaults(expand_mask_be=True) 26 | parser.set_defaults(which_epoch='50') 27 | parser.set_defaults(add_feat_zeros=True) 28 | 29 | parser.set_defaults(phase='test') 30 | parser.set_defaults(batchSize=1, gpu_ids='1') 31 | parser.add_argument('--demo_data_dir', type=str, default='./datasets/FFHQ_demo/') 32 | parser.add_argument('--results_dir', type=str, 33 | default='/mnt/lvdisk1/tzt/HairSynthesis/SPADE-master/results/SPADEBEncodeInpaint5B/interactive_results/', 34 | help='saves results here.') 35 | # parser.add_argument('--inference_ref_name', type=str, default='56001', help='which reference sample to inference') 36 | # parser.add_argument('--inference_tag_name', type=str, default='56001', help='which target sample to inference') 37 | # parser.add_argument('--inference_orient_name', type=str, default='56001', help='which orient sample to inference, if not specified, means use reference orient') 38 | # parser.add_argument('--remove_background', action='store_true', help='if specified, remove background when output the fake image') 39 | # parser.add_argument('--subset', type=str, default='val', help='which subset to test [val | train]') 40 | parser.add_argument('--expand_tag_mask', action='store_true', help='if specified, expand the tag hair mask before input.') 41 | 42 | 43 | self.isTrain = False 44 | return parser
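The defaults above fully configure the interactive demo (the `spadeb` generator with the inpainting and stroke networks enabled, at 512x512). Below is a minimal, hedged sketch of how these resolved options could be inspected from Python; it assumes the repository root is the working directory and that `demo.py` consumes `DemoOptions` the same way `train.py` consumes `TrainOptions`. The `--gpu_ids -1` override is only there so the sketch also runs on a CPU-only machine.

```python
# Hedged usage sketch (not part of the repo): resolve and print the demo defaults.
# gather_options() pulls in the registered model/dataset option setters, so the
# repo's data/ and models/ packages must be importable from the current directory.
import sys
from options.demo_options import DemoOptions

sys.argv = ['demo.py', '--gpu_ids', '-1']   # pure defaults, forced onto CPU
opt = DemoOptions().parse()                 # prints the resolved option table
print(opt.netG, opt.use_ig, opt.crop_size)  # expected: spadeb True 512
```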
-------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TestOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | parser.add_argument('--results_dir', type=str, default='/mnt/lvdisk1/tzt/HairSynthesis/SPADE-master/results/', help='saves results here.') 13 | parser.add_argument('--which_epoch', type=str, default='13', help='which epoch to load? set to latest to use latest cached model') 14 | parser.add_argument('--how_many', type=int, default=5000, help='how many test images to run') 15 | 16 | parser.set_defaults(preprocess_mode='scale_width_and_crop') 17 | parser.set_defaults(serial_batches=True) 18 | parser.set_defaults(no_flip=True) 19 | parser.set_defaults(phase='test') 20 | parser.set_defaults(batchSize=1, gpu_ids='1') 21 | parser.set_defaults(gpu_ids='2') 22 | 23 | parser.add_argument('--source_dir', type=str, default='/mnt/lvdisk1/tzt/HairSynthesis/SPADE-master/results/SPADEBEncodeInpaint5B/') 24 | parser.add_argument('--source_file', type=str, default='comparison') 25 | parser.add_argument('--four_image_show', action='store_true', help='if specified, save images that contain the ref/tag/ori images.') 26 | parser.add_argument('--which_settings', type=str, default='spadeb512', help='which settings to test.') 27 | parser.add_argument('--which_random', type=str, default='orient', help='randomize one of the inputs.') 28 | parser.add_argument('--input_relation', type=str, default='ref=tag!=ori', help='the relationship of the three input frames.') 29 | parser.add_argument('--val_list_dir', type=str, default='data/val_image_list.txt', help='the text file that contains the image names for validation.') 30 | 31 | parser.add_argument('--inference_ref_name', type=str, default='57541', help='which reference sample to run inference on') 32 | parser.add_argument('--inference_tag_name', type=str, default='56001', help='which target sample to run inference on') 33 | parser.add_argument('--inference_orient_name', type=str, default='56001', help='which orientation sample to run inference on; if not specified, the reference orientation is used') 34 | parser.add_argument('--remove_background', action='store_true', help='if specified, remove the background when outputting the fake image') 35 | parser.add_argument('--subset', type=str, default='val', help='which subset to test [val | train]') 36 | parser.add_argument('--expand_tag_mask', action='store_true', help='if specified, expand the tag hair mask before input.') 37 | parser.add_argument('--expand_th', type=int, default=11, help='threshold for expanding the tag hair mask for the background encoder.') 38 | parser.add_argument('--expand_mask_be', action='store_true', help='if specified, expand the tag hair mask for the background encoder.') 39 | 40 | self.isTrain = False 41 | return parser 42 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License.
4 | """ 5 | 6 | from .base_options import BaseOptions 7 | 8 | 9 | class TrainOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | # for displays 13 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 18 | parser.add_argument('--debug', action='store_true', help='only do one epoch and display at each iteration') 19 | parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed') 20 | 21 | # for training 22 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 23 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 24 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay') 25 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero') 26 | parser.add_argument('--optimizer', type=str, default='adam') 27 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 28 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') 29 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 30 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.') 31 | parser.add_argument('--G_steps_per_D', type=int, default=1, help='number of generator iterations per discriminator iteration.') 32 | 33 | # for progressive training 34 | parser.add_argument('--smooth', action='store_true', help='if specified, smooth the training between each resolution.') 35 | parser.add_argument('--epoch_each_step', type=int, default=10, help='number of epochs for each resolution.') 36 | 37 | # add unpair training 38 | parser.add_argument('--unpairTrain', action='store_true', help='if specified, use the unpaired training strategy.') 39 | parser.add_argument('--curr_step', type=int, default=1, help='specify the step [1|2]: 1 means the paired training stage and 2 means the unpaired training stage.') 40 | parser.add_argument('--same_netD_model', action='store_true', help='if specified, use the same model to init netD and netD2.') 41 | parser.add_argument('--lambda_hairavglab', type=float, default=1.0, help='weight for hair avg lab l1 loss') 42 | 43 | 44 | # for discriminators 45 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 46 | parser.add_argument('--lambda_feat', type=float, default=1.0, help='weight for feature matching loss') 47 | parser.add_argument('--lambda_vgg', type=float, default=1.0, help='weight for vgg loss') 48 | parser.add_argument('--lambda_orient', type=float, default=10.0, help='weight for orientation
loss') 49 | parser.add_argument('--lambda_confidence', type=float, default=100.0, help='weight for confidence loss') 50 | parser.add_argument('--lambda_content', type=float, default=1.0, help='weight for content loss') 51 | parser.add_argument('--lambda_style', type=float, default=1.0, help='weight for style loss') 52 | parser.add_argument('--lambda_background', type=float, default=1.0, help='weight for background loss') 53 | parser.add_argument('--lambda_rgb', type=float, default=1.0, help='weight for rgb l1 loss') 54 | parser.add_argument('--lambda_lab', type=float, default=1.0, help='weight for lab l1 loss') 55 | parser.add_argument('--no_gan_loss', action='store_true', help='if specified, do *not* use GAN loss') 56 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 57 | parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 58 | parser.add_argument('--no_background_loss', action='store_true', help='if specified, do *not* use background loss') 59 | parser.add_argument('--no_rgb_loss', action='store_true', help='if specified, do *not* use rgb l1 loss') 60 | parser.add_argument('--no_lab_loss', action='store_true', help='if specified, do *not* use lab l1 loss') 61 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 62 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)') 63 | parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme') 64 | parser.add_argument('--lambda_kld', type=float, default=0.05) 65 | parser.add_argument('--no_orient_loss', action='store_true', help='if specified, do *not* use orient constraint loss') 66 | parser.add_argument('--no_confidence_loss', action='store_true', 67 | help='if specified, do *not* use confidence constraint loss') 68 | parser.add_argument('--no_content_loss', action='store_true', help='if specified, do *not* use content loss') 69 | parser.add_argument('--no_style_loss', action='store_true', help='if specified, do *not* use style loss') 70 | parser.add_argument('--remove_background', action='store_true', help='if specified, remove the background when calculating the loss') 71 | parser.add_argument('--orient_filter', type=str, default='gabor', help='which filter is used to compute orientation [gabor|dog]') 72 | parser.add_argument('--wide_edge', type=float, default=1.0, help='if the value is bigger than 1, increase the wide-edge weight when calculating the GAN loss') 73 | parser.add_argument('--no_discriminator', action='store_true', help='if specified, do *not* use a discriminator') 74 | 75 | # Lab balance 76 | parser.add_argument('--balance_Lab', action='store_true', help='if specified, add weights when calculating the Lab loss') 77 | parser.add_argument('--weight_dir', type=str, default='./data/ab_count.npy', help='weight file dir') 78 | parser.add_argument('--Lab_weight_th', type=float, default=10.0, help='The max weight value') 79 | 80 | self.isTrain = True 81 | # self.only_tag = True 82 | return parser 83 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.0.0 2 | torchvision 3 | dominate>=2.3.1 4 | dill 5 | scikit-image 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 |
""" 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | import torch 6 | 7 | if not torch.set_flush_denormal(True): 8 | print("Unable to set flush denormal") 9 | print("Pytorch compiled without advanced CPU") 10 | print("at: https://github.com/pytorch/pytorch/blob/84b275b70f73d5fd311f62614bccc405f3d5bfa3/aten/src/ATen/cpu/FlushDenormal.cpp#L13") 11 | import sys 12 | from collections import OrderedDict 13 | from options.train_options import TrainOptions 14 | import data 15 | from util.iter_counter import IterationCounter 16 | from util.visualizer import Visualizer 17 | from trainers.pix2pix_trainer import Pix2PixTrainer 18 | 19 | 20 | # parse options 21 | opt = TrainOptions().parse() 22 | 23 | # print options to help debugging 24 | print(' '.join(sys.argv)) 25 | 26 | # load the dataset 27 | dataloader = data.create_dataloader(opt) 28 | if opt.unpairTrain: 29 | dataloader2 = data.create_dataloader(opt, 2) 30 | 31 | # create trainer for our model 32 | trainer = Pix2PixTrainer(opt) 33 | 34 | # create tool for counting iterations 35 | iter_counter = IterationCounter(opt, len(dataloader)) 36 | data_size = len(dataloader) 37 | 38 | # create tool for visualization 39 | visualizer = Visualizer(opt) 40 | 41 | for epoch in iter_counter.training_epochs(): 42 | # for unpair training 43 | if opt.unpairTrain: 44 | # dataloader2 = data.create_dataloader(opt, 2) 45 | iter_counter.record_epoch_start(epoch) 46 | opt.curr_step = 2 47 | trainer.init_losses() 48 | for i, data_i in enumerate(dataloader2, start=iter_counter.epoch_iter): 49 | iter_counter.record_one_iteration() 50 | 51 | # Training 52 | # train generator 53 | if i % opt.D_steps_per_G == 0: 54 | trainer.run_generator_one_step(data_i) 55 | 56 | # train discriminator 57 | if i % opt.G_steps_per_D == 0 and not opt.no_discriminator: 58 | trainer.run_discriminator_one_step(data_i) 59 | 60 | # Visualizations 61 | if iter_counter.needs_printing(): 62 | losses = trainer.get_latest_losses() 63 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 64 | losses, iter_counter.time_per_iter) 65 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 66 | 67 | if iter_counter.needs_displaying(): 68 | visuals = OrderedDict([('input_ref', data_i['label_ref']), 69 | ('input_tag', data_i['label_tag']), 70 | ('synthesized_image', trainer.get_latest_generated()), 71 | ('image_ref', data_i['image_ref']), 72 | ('image_tag', data_i['image_tag'])]) 73 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 74 | 75 | if iter_counter.needs_saving(): 76 | print('saving the latest model (epoch %d, total_steps %d)' % 77 | (epoch, iter_counter.total_steps_so_far)) 78 | trainer.save('latest') 79 | iter_counter.record_current_iter() 80 | 81 | 82 | trainer.update_learning_rate(epoch) 83 | iter_counter.record_epoch_end() 84 | 85 | # if epoch % opt.save_epoch_freq == 0 or \ 86 | # epoch == iter_counter.total_epochs: 87 | # print('saving the model at the end of epoch %d, iters %d' % 88 | # (epoch, iter_counter.total_steps_so_far)) 89 | # trainer.save('latest') 90 | # trainer.save(epoch) 91 | # for step 1 training 92 | # dataloader = data.create_dataloader(opt) 93 | iter_counter.record_epoch_start(epoch) 94 | opt.curr_step = 1 95 | trainer.init_losses() 96 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 97 | iter_counter.record_one_iteration() 98 | 99 | # Training 100 | # train generator 101 | if i % opt.D_steps_per_G == 0: 102 | 
trainer.run_generator_one_step(data_i) 103 | 104 | # train discriminator 105 | if i % opt.G_steps_per_D == 0 and not opt.no_discriminator: 106 | trainer.run_discriminator_one_step(data_i) 107 | 108 | # Visualizations 109 | if iter_counter.needs_printing(): 110 | losses = trainer.get_latest_losses() 111 | visualizer.print_current_errors(epoch, iter_counter.epoch_iter, 112 | losses, iter_counter.time_per_iter) 113 | visualizer.plot_current_errors(losses, iter_counter.total_steps_so_far) 114 | 115 | if iter_counter.needs_displaying(): 116 | visuals = OrderedDict([('input_ref', data_i['label_ref']), 117 | ('input_tag', data_i['label_tag']), 118 | ('synthesized_image', trainer.get_latest_generated()), 119 | ('image_ref', data_i['image_ref']), 120 | ('image_tag', data_i['image_tag'])]) 121 | visualizer.display_current_results(visuals, epoch, iter_counter.total_steps_so_far) 122 | 123 | if iter_counter.needs_saving(): 124 | print('saving the latest model (epoch %d, total_steps %d)' % 125 | (epoch, iter_counter.total_steps_so_far)) 126 | trainer.save('latest') 127 | iter_counter.record_current_iter() 128 | 129 | # if (i + 1) * opt.batchSize >= 28000 and opt.unpairTrain: 130 | # break 131 | 132 | trainer.update_learning_rate(epoch) 133 | iter_counter.record_epoch_end() 134 | 135 | if epoch % opt.save_epoch_freq == 0 or \ 136 | epoch == iter_counter.total_epochs: 137 | print('saving the model at the end of epoch %d, iters %d' % 138 | (epoch, iter_counter.total_steps_so_far)) 139 | trainer.save('latest') 140 | trainer.save(epoch) 141 | 142 | 143 | print('Training was successfully finished.') 144 | -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | -------------------------------------------------------------------------------- /trainers/pix2pix_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | from models.networks.sync_batchnorm import DataParallelWithCallback 7 | from models.pix2pix_model import Pix2PixModel 8 | import pdb 9 | 10 | 11 | class Pix2PixTrainer(): 12 | """ 13 | Trainer creates the model and optimizers, and uses them to 14 | updates the weights of the network while reporting losses 15 | and the latest visuals to visualize the progress in training. 
16 | """ 17 | 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.pix2pix_model = Pix2PixModel(opt) 21 | if len(opt.gpu_ids) > 0: 22 | self.pix2pix_model = DataParallelWithCallback(self.pix2pix_model, 23 | device_ids=opt.gpu_ids) 24 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module 25 | else: 26 | self.pix2pix_model_on_one_gpu = self.pix2pix_model 27 | 28 | self.generated = None 29 | if opt.isTrain: 30 | if not opt.unpairTrain: 31 | self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(opt) 32 | else: 33 | self.optimizer_G, self.optimizer_D, self.optimizer_D2 = self.pix2pix_model_on_one_gpu.create_optimizers(opt) 34 | self.old_lr = opt.lr 35 | 36 | self.d_losses = {} 37 | self.nanCount = 0 38 | 39 | def run_generator_one_step(self, data): 40 | self.optimizer_G.zero_grad() 41 | g_losses, generated = self.pix2pix_model(data, mode='generator') 42 | g_loss = sum(g_losses.values()).mean() 43 | g_loss.backward() 44 | 45 | # flag = False 46 | # for n, p in self.pix2pix_model.module.netB.named_parameters(): 47 | # g = p.grad 48 | # if (g != g).sum() > 0: 49 | # flag = True 50 | # self.nanCount = self.nanCount + 1 51 | # break 52 | # if self.nanCount > 100: 53 | # pdb.set_trace() 54 | # if not flag: 55 | # self.optimizer_G.step() 56 | # print('count:', self.nanCount) 57 | self.optimizer_G.step() 58 | self.g_losses = g_losses 59 | self.generated = generated 60 | 61 | def run_discriminator_one_step(self, data): 62 | if self.opt.curr_step == 1: 63 | # print('step1') 64 | self.optimizer_D.zero_grad() 65 | d_losses = self.pix2pix_model(data, mode='discriminator') 66 | d_loss = sum(d_losses.values()).mean() 67 | d_loss.backward() 68 | self.optimizer_D.step() 69 | self.d_losses = d_losses 70 | else: 71 | # print('step2') 72 | self.optimizer_D2.zero_grad() 73 | d_losses = self.pix2pix_model(data, mode='discriminator') 74 | d_loss = sum(d_losses.values()).mean() 75 | d_loss.backward() 76 | self.optimizer_D2.step() 77 | self.d_losses = d_losses 78 | 79 | def get_latest_losses(self): 80 | return {**self.g_losses, **self.d_losses} 81 | 82 | def get_latest_generated(self): 83 | return self.generated 84 | 85 | def update_learning_rate(self, epoch): 86 | self.update_learning_rate(epoch) 87 | 88 | def save(self, epoch): 89 | self.pix2pix_model_on_one_gpu.save(epoch) 90 | 91 | def init_losses(self): 92 | self.g_losses = {} 93 | self.d_losses = {} 94 | 95 | ################################################################## 96 | # Helper functions 97 | ################################################################## 98 | 99 | def update_learning_rate(self, epoch): 100 | if epoch > self.opt.niter: 101 | lrd = self.opt.lr / self.opt.niter_decay 102 | new_lr = self.old_lr - lrd 103 | else: 104 | new_lr = self.old_lr 105 | 106 | if new_lr != self.old_lr: 107 | if self.opt.no_TTUR: 108 | new_lr_G = new_lr 109 | new_lr_D = new_lr 110 | else: 111 | new_lr_G = new_lr / 2 112 | new_lr_D = new_lr * 2 113 | 114 | for param_group in self.optimizer_D.param_groups: 115 | param_group['lr'] = new_lr_D 116 | for param_group in self.optimizer_G.param_groups: 117 | param_group['lr'] = new_lr_G 118 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr)) 119 | self.old_lr = new_lr 120 | -------------------------------------------------------------------------------- /ui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/ui/__init__.py 
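One detail worth calling out from `Pix2PixTrainer.update_learning_rate` above: the decay is linear, it only starts after `--niter` epochs, and when TTUR is active (the default, since `--no_TTUR` is off) the generator and discriminator run at `new_lr / 2` and `new_lr * 2` respectively. A small self-contained sketch of that rule, using an illustrative `niter_decay=10` (the repo default is 0, i.e. no decay at all):

```python
# Standalone sketch (not repo code) of the Pix2PixTrainer learning-rate schedule.
def decayed_lrs(epoch, old_lr, lr=0.0002, niter=50, niter_decay=10, no_TTUR=False):
    """Return (new_lr, lr_G, lr_D) at the end of the given epoch."""
    new_lr = old_lr - lr / niter_decay if epoch > niter else old_lr
    if no_TTUR:
        lr_G, lr_D = new_lr, new_lr          # single shared rate
    else:
        lr_G, lr_D = new_lr / 2, new_lr * 2  # TTUR: slower generator, faster discriminator
    return new_lr, lr_G, lr_D

old_lr = 0.0002
for epoch in (50, 51, 52):
    old_lr, lr_G, lr_D = decayed_lrs(epoch, old_lr)
    print(epoch, round(old_lr, 6), round(lr_G, 6), round(lr_D, 6))
# 50 0.0002 0.0001 0.0004      (epoch <= niter: no decay yet)
# 51 0.00018 9e-05 0.00036
# 52 0.00016 8e-05 0.00032
```

Note that the trainer only rewrites the optimizer learning rates when `new_lr` differs from `old_lr`, so the initial TTUR split presumably happens when the optimizers are first created, not in this method.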
-------------------------------------------------------------------------------- /ui/mouse_event.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from PyQt5.QtCore import * 4 | from PyQt5.QtGui import * 5 | from PyQt5.QtWidgets import * 6 | import numpy as np 7 | 8 | color_list = [QColor(0, 0, 0), QColor(255, 255, 255), QColor(76, 153, 0), QColor(204, 204, 0), QColor(204, 0, 204), QColor(51, 51, 255), QColor(0, 255, 255), QColor(51, 255, 255), QColor(255, 0, 0), QColor(102, 51, 0), QColor(102, 204, 0), QColor(255, 255, 0), QColor(0, 0, 153), QColor(0, 0, 204), QColor(255, 51, 153), QColor(0, 204, 204), QColor(0, 51, 0), QColor(255, 153, 51), QColor(0, 204, 0)] 9 | 10 | class GraphicsScene(QGraphicsScene): 11 | def __init__(self, mode, size, parent=None): 12 | QGraphicsScene.__init__(self, parent) 13 | self.mode = mode 14 | self.size = size 15 | self.mouse_clicked = False 16 | self.prev_pt = None 17 | 18 | # self.masked_image = None 19 | 20 | # save the points 21 | self.mask_points = [] 22 | for i in range(len(color_list)): 23 | self.mask_points.append([]) 24 | 25 | # save the size of points 26 | self.size_points = [] 27 | for i in range(len(color_list)): 28 | self.size_points.append([]) 29 | 30 | # save the history of edit 31 | self.history = [] 32 | 33 | def reset(self): 34 | # save the points 35 | self.mask_points = [] 36 | for i in range(len(color_list)): 37 | self.mask_points.append([]) 38 | # save the size of points 39 | self.size_points = [] 40 | for i in range(len(color_list)): 41 | self.size_points.append([]) 42 | # save the history of edit 43 | self.history = [] 44 | 45 | self.mode = 0 46 | self.prev_pt = None 47 | 48 | def mousePressEvent(self, event): 49 | self.mouse_clicked = True 50 | 51 | def mouseReleaseEvent(self, event): 52 | self.prev_pt = None 53 | self.mouse_clicked = False 54 | 55 | def mouseMoveEvent(self, event): # drawing 56 | if self.mouse_clicked: 57 | if self.prev_pt: 58 | self.drawMask(self.prev_pt, event.scenePos(), color_list[self.mode], self.size) 59 | pts = {} 60 | pts['prev'] = (int(self.prev_pt.x()),int(self.prev_pt.y())) 61 | pts['curr'] = (int(event.scenePos().x()),int(event.scenePos().y())) 62 | 63 | self.size_points[self.mode].append(self.size) 64 | self.mask_points[self.mode].append(pts) 65 | self.history.append(self.mode) 66 | self.prev_pt = event.scenePos() 67 | else: 68 | self.prev_pt = event.scenePos() 69 | 70 | def drawMask(self, prev_pt, curr_pt, color, size): 71 | lineItem = QGraphicsLineItem(QLineF(prev_pt, curr_pt)) 72 | lineItem.setPen(QPen(color, size, Qt.SolidLine)) # rect 73 | self.addItem(lineItem) 74 | 75 | def erase_prev_pt(self): 76 | self.prev_pt = None 77 | 78 | def reset_items(self): 79 | for i in range(len(self.items())): 80 | item = self.items()[0] 81 | self.removeItem(item) 82 | 83 | def undo(self): 84 | if len(self.items())>1: 85 | if len(self.items())>=9: 86 | for i in range(8): 87 | item = self.items()[0] 88 | self.removeItem(item) 89 | if self.history[-1] == self.mode: 90 | self.mask_points[self.mode].pop() 91 | self.size_points[self.mode].pop() 92 | self.history.pop() 93 | else: 94 | for i in range(len(self.items())-1): 95 | item = self.items()[0] 96 | self.removeItem(item) 97 | if self.history[-1] == self.mode: 98 | self.mask_points[self.mode].pop() 99 | self.size_points[self.mode].pop() 100 | self.history.pop() 101 | -------------------------------------------------------------------------------- /ui/ui4.py: 
-------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import * 2 | from PyQt5.QtGui import * 3 | from PyQt5.QtWidgets import * 4 | import numpy as np 5 | import sys 6 | import os 7 | 8 | class Ui_Form(object): 9 | def setupUi(self, Form): 10 | Form.setObjectName("Form") 11 | 12 | # for graph 13 | self.graphicsView = QGraphicsView(Form) # tag_mask 14 | self.graphicsView.setObjectName("graphicsView") 15 | self.graphicsView.setFixedSize(512, 512) 16 | self.graphicsView_2 = QGraphicsView(Form) # orient map 17 | self.graphicsView_2.setObjectName("graphicsView_2") 18 | self.graphicsView_2.setFixedSize(512, 512) 19 | self.graphicsView_3 = QGraphicsView(Form) # result 20 | self.graphicsView_3.setObjectName("graphicsView_3") 21 | self.graphicsView_3.setFixedSize(512, 512) 22 | self.graphicsView_4 = QGraphicsView(Form) # tag image 23 | self.graphicsView_4.setFixedSize(256, 256) 24 | self.graphicsView_4.setObjectName("graphicsView_4") 25 | self.graphicsView_5 = QGraphicsView(Form) # ref image 26 | self.graphicsView_5.setFixedSize(256, 256) 27 | self.graphicsView_5.setObjectName("graphicsView_5") 28 | 29 | # self.grid1 = QGridLayout() 30 | # self.grid1.addWidget(self.graphicsView, 0,0,1,1) 31 | # self.grid1.addWidget(self.graphicsView_2, 0, 1, 1, 1) 32 | # self.grid1.addWidget(self.graphicsView_3, 0, 3, 1, 1) 33 | 34 | # self.grid0 = QHBoxLayout() 35 | # self.grid0.addWidget(self.graphicsView_5) 36 | # self.grid0.addWidget(self.graphicsView_4) 37 | 38 | # for Buttons 39 | self.button_W = 107 40 | self.button_H = 37 41 | 42 | self.pushButton0 = QPushButton(Form) 43 | self.pushButton0.setObjectName("Save") 44 | self.pushButton0.setFixedSize(self.button_W, self.button_H) 45 | self.pushButton = QPushButton(Form) 46 | self.pushButton.setObjectName("Edit") 47 | self.pushButton.setFixedSize(self.button_W, self.button_H) 48 | self.pushButton_2 = QPushButton(Form) 49 | self.pushButton_2.setFixedSize(self.button_W, self.button_H) 50 | self.pushButton_2.setObjectName("open_ref_img") 51 | self.pushButton_3 = QPushButton(Form) 52 | self.pushButton_3.setFixedSize(self.button_W, self.button_H) 53 | self.pushButton_3.setObjectName("open_tag_img") 54 | self.pushButton_4 = QPushButton(Form) 55 | self.pushButton_4.setFixedSize(self.button_W, self.button_H) 56 | self.pushButton_4.setObjectName("open_mask") 57 | self.pushButton_5 = QPushButton(Form) 58 | self.pushButton_5.setFixedSize(self.button_W, self.button_H) 59 | self.pushButton_5.setObjectName("open_orient") 60 | self.pushButton_6 = QPushButton(Form) 61 | self.pushButton_6.setFixedSize(self.button_W, self.button_H) 62 | self.pushButton_6.setObjectName("hair") 63 | self.pushButton_7 = QPushButton(Form) 64 | self.pushButton_7.setFixedSize(self.button_W, self.button_H) 65 | self.pushButton_7.setObjectName("background") 66 | self.pushButton_8 = QPushButton(Form) 67 | self.pushButton_8.setFixedSize(self.button_W, self.button_H) 68 | self.pushButton_8.setObjectName("mask_+") 69 | self.pushButton_9 = QPushButton(Form) 70 | self.pushButton_9.setFixedSize(self.button_W, self.button_H) 71 | self.pushButton_9.setObjectName("mask_-") 72 | self.pushButton_10 = QPushButton(Form) 73 | self.pushButton_10.setFixedSize(self.button_W, self.button_H) 74 | self.pushButton_10.setObjectName("clear") 75 | self.pushButton_11 = QPushButton(Form) 76 | self.pushButton_11.setFixedSize(self.button_W, self.button_H) 77 | self.pushButton_11.setObjectName("brush") 78 | # self.pushButton_12 = QPushButton(Form) 79 | # 
self.pushButton_12.setFixedSize(self.button_W, self.button_H) 80 | # self.pushButton_12.setObjectName("background") 81 | self.pushButton_13 = QPushButton(Form) 82 | self.pushButton_13.setFixedSize(self.button_W, self.button_H) 83 | self.pushButton_13.setObjectName("orient_+") 84 | self.pushButton_14 = QPushButton(Form) 85 | self.pushButton_14.setFixedSize(self.button_W, self.button_H) 86 | self.pushButton_14.setObjectName("orient_-") 87 | # self.pushButton_15 = QPushButton(Form) 88 | # self.pushButton_15.setFixedSize(self.button_W, self.button_H) 89 | # self.pushButton_15.setObjectName("erase") 90 | 91 | _translate = QCoreApplication.translate 92 | self.pushButton0.setText(_translate("Form", "Save")) 93 | self.pushButton.setText(_translate("Form", "Edit")) 94 | self.pushButton_2.setText(_translate("Form", "Open Ref")) 95 | self.pushButton_3.setText(_translate("Form", "Open Tag")) 96 | self.pushButton_4.setText(_translate("Form", "Open Mask")) 97 | self.pushButton_5.setText(_translate("Form", "Open Orient")) 98 | self.pushButton_6.setText(_translate("Form", "Hair")) 99 | self.pushButton_7.setText(_translate("Form", "BackGround")) 100 | self.pushButton_8.setText(_translate("Form", "+")) 101 | self.pushButton_9.setText(_translate("Form", "-")) 102 | self.pushButton_10.setText(_translate("Form", "Clear")) 103 | self.pushButton_11.setText(_translate("Form", "Brush")) 104 | # self.pushButton_12.setText(_translate("Form", "Orient Edit")) 105 | self.pushButton_13.setText(_translate("Form", "+")) 106 | self.pushButton_14.setText(_translate("Form", "-")) 107 | # self.pushButton_15.setText(_translate("Form", "Erase")) 108 | 109 | self.pushButton0.clicked.connect(Form.save) 110 | self.pushButton.clicked.connect(Form.edit) 111 | self.pushButton_2.clicked.connect(Form.open_ref) 112 | self.pushButton_3.clicked.connect(Form.open_tag) 113 | self.pushButton_4.clicked.connect(Form.open_mask) 114 | self.pushButton_5.clicked.connect(Form.open_orient) 115 | self.pushButton_6.clicked.connect(Form.hair_mode) 116 | self.pushButton_7.clicked.connect(Form.bg_mode) 117 | self.pushButton_8.clicked.connect(Form.increase) 118 | self.pushButton_9.clicked.connect(Form.decrease) 119 | self.pushButton_10.clicked.connect(Form.clear) 120 | self.pushButton_11.clicked.connect(Form.orient_mode) 121 | # self.pushButton_12.clicked.connect(Form.orient_edit) 122 | self.pushButton_13.clicked.connect(Form.orient_increase) 123 | self.pushButton_14.clicked.connect(Form.orient_decrease) 124 | # self.pushButton_15.clicked.connect(Form.erase_mode) 125 | 126 | self.grid2 = QGridLayout() 127 | self.grid2.addWidget(self.pushButton0,0,1,1,1) 128 | self.grid2.addWidget(self.pushButton,0,0,1,1) 129 | self.grid2.addWidget(self.pushButton_2,1,0,1,1) 130 | self.grid2.addWidget(self.pushButton_3,1,1,1,1) 131 | # self.grid2.addWidget(self.pushButton_4) 132 | # self.grid2.addWidget(self.pushButton_5) 133 | 134 | self.grid3 = QGridLayout() 135 | self.grid3.addWidget(self.pushButton_4,0,0,1,1) 136 | self.grid3.addWidget(self.pushButton_6,0,1,1,1) 137 | self.grid3.addWidget(self.pushButton_7,1,1,1,1) 138 | self.grid3.addWidget(self.pushButton_8,0,2,1,1) 139 | self.grid3.addWidget(self.pushButton_9,1,2,1,1) 140 | self.grid3.addWidget(self.pushButton_10,1,0,1,1) 141 | 142 | self.grid4 = QGridLayout() 143 | self.grid4.addWidget(self.pushButton_5,0,0,1,1) 144 | self.grid4.addWidget(self.pushButton_11,0,1,1,1) 145 | # self.grid4.addWidget(self.pushButton_15) 146 | self.grid4.addWidget(self.pushButton_13,1,0,1,1) 147 | 
self.grid4.addWidget(self.pushButton_14,1,1,1,1) 148 | # self.grid4.addWidget(self.pushButton_12) 149 | 150 | 151 | # for radioButton 152 | self.clickButtion1 = QRadioButton(Form) 153 | self.clickButtion1.setText('Reference') 154 | self.clickButtion1.setChecked(True) 155 | # self.clickButtion1.clicked.connect(Form.selectM) 156 | self.clickButtion2 = QRadioButton(Form) 157 | self.clickButtion2.setText('Edited') 158 | # self.clickButtion2.clicked.connect(Form.selectM) 159 | 160 | self.grid6 = QHBoxLayout() 161 | self.grid6_1 = QGridLayout() 162 | self.grid6_1.addWidget(self.clickButtion1,0,0,1,1) 163 | self.grid6_1.addWidget(self.clickButtion2,1,0,1,1) 164 | self.grid6.addLayout(self.AddLayout(self.grid6_1, 'Hair Mask')) 165 | 166 | self.clickButtion3 = QRadioButton(Form) 167 | self.clickButtion3.setText('Reference') 168 | self.clickButtion3.setChecked(True) 169 | # self.clickButtion3.clicked.connect(Form.selectO) 170 | self.clickButtion4 = QRadioButton(Form) 171 | self.clickButtion4.setText('Edited') 172 | # self.clickButtion4.clicked.connect(Form.selectO) 173 | self.grid6_2 = QGridLayout() 174 | self.grid6_2.addWidget(self.clickButtion3,0,0,1,1) 175 | self.grid6_2.addWidget(self.clickButtion4,1,0,1,1) 176 | 177 | self.grid6.addLayout(self.AddLayout(self.grid6_2, 'Hair Orientation')) 178 | 179 | # for Layout setting 180 | mainLayout = QVBoxLayout() 181 | Form.setLayout(mainLayout) 182 | Form.resize(1616, 808) 183 | 184 | subLayout = QHBoxLayout() 185 | subLayout.addLayout(self.AddWidgt(self.graphicsView, 'Hair Mask')) 186 | subLayout.addLayout(self.AddWidgt(self.graphicsView_2, 'Hair Orientation')) 187 | subLayout.addLayout(self.AddWidgt(self.graphicsView_3, 'Result')) 188 | 189 | subLayout2_1 = QHBoxLayout() 190 | # subLayout2_1.addLayout(self.AddLayout(self.grid2, 'Main Buttons')) 191 | subLayout2_1.addLayout(self.AddLayout(self.grid3, 'Mask Edit')) 192 | subLayout2_2 = QVBoxLayout() 193 | subLayout2_2.addLayout(self.AddLayout(self.grid6, 'State')) 194 | subLayout2_2.addLayout(subLayout2_1) 195 | 196 | subLayout2_3 = QVBoxLayout() 197 | subLayout2_3.addLayout(self.AddLayout(self.grid2, 'Main Buttons')) 198 | subLayout2_3.addLayout(self.AddLayout(self.grid4, 'Orient Edit')) 199 | 200 | subLayout2 = QHBoxLayout() 201 | subLayout2.addLayout(self.AddWidgt(self.graphicsView_4, 'Tagert Image')) 202 | subLayout2.addLayout(self.AddWidgt(self.graphicsView_5, 'Reference Image')) 203 | subLayout2.addLayout(subLayout2_2) 204 | subLayout2.addLayout(subLayout2_3) 205 | 206 | mainLayout.addLayout(subLayout) 207 | mainLayout.addLayout(subLayout2) 208 | 209 | def setButtonColor(self, button, path, H, W): 210 | pixmap = QPixmap(path) 211 | fitPixmap = pixmap.scaled(W, H, Qt.IgnoreAspectRatio, Qt.SmoothTransformation) 212 | icon = QIcon(fitPixmap) 213 | button.setIcon(icon) 214 | # button.setIconSize(W,H) 215 | 216 | def AddLayout(self, widget, title=''): 217 | widgetLayout = QVBoxLayout() 218 | widgetBox = QGroupBox() 219 | if title != '': 220 | widgetBox.setTitle(title) 221 | widgetBox.setAlignment(Qt.AlignCenter) 222 | widgetBox.setLayout(widget) 223 | widgetLayout.addWidget(widgetBox) 224 | 225 | return widgetLayout 226 | 227 | 228 | def AddWidgt(self, widget, title): 229 | widgetLayout = QVBoxLayout() 230 | widgetBox = QGroupBox() 231 | widgetBox.setTitle(title) 232 | widgetBox.setAlignment(Qt.AlignCenter) 233 | vbox_t = QGridLayout() 234 | vbox_t.addWidget(widget,0,0,1,1) 235 | widgetBox.setLayout(vbox_t) 236 | widgetLayout.addWidget(widgetBox) 237 | 238 | return widgetLayout 239 | 240 | 241 | if 
__name__=="__main__": 242 | app=QApplication(sys.argv) 243 | Form = QWidget() 244 | ui = Ui_Form() 245 | ui.setupUi(Form) 246 | Form.show() 247 | sys.exit(app.exec_()) 248 | 249 | -------------------------------------------------------------------------------- /ui/ui_buttons.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import * 2 | from PyQt5.QtGui import * 3 | from PyQt5.QtWidgets import * 4 | import numpy as np 5 | import sys 6 | 7 | 8 | class GUIButton(QWidget): 9 | def __init__(self): 10 | QWidget.__init__(self) 11 | 12 | self.pushButton = QPushButton() 13 | self.pushButton.setObjectName("Edit") 14 | self.pushButton_2 = QPushButton() 15 | self.pushButton_2.setObjectName("open_ref_img") 16 | self.pushButton_3 = QPushButton() 17 | self.pushButton_3.setObjectName("open_tag_img") 18 | self.pushButton_4 = QPushButton() 19 | self.pushButton_4.setObjectName("open_mask") 20 | self.pushButton_5 = QPushButton() 21 | self.pushButton_5.setObjectName("open_orient") 22 | self.pushButton_6 = QPushButton() 23 | self.pushButton_6.setObjectName("hair") 24 | self.pushButton_7 = QPushButton() 25 | self.pushButton_7.setObjectName("background") 26 | self.pushButton_8 = QPushButton() 27 | self.pushButton_8.setObjectName("mask_+") 28 | self.pushButton_9 = QPushButton() 29 | self.pushButton_9.setObjectName("mask_-") 30 | self.pushButton_10 = QPushButton() 31 | self.pushButton_10.setObjectName("brush") 32 | self.pushButton_11 = QPushButton() 33 | self.pushButton_11.setObjectName("background") 34 | self.pushButton_12 = QPushButton() 35 | self.pushButton_12.setObjectName("orient_+") 36 | self.pushButton_13 = QPushButton() 37 | self.pushButton_13.setObjectName("orient_-") 38 | 39 | self.pushButton.clicked.connect(self.edit) 40 | self.pushButton_2.clicked.connect(self.open_ref) 41 | self.pushButton_3.clicked.connect(open_tag) 42 | self.pushButton_4.clicked.connect(Form.open_mask) 43 | self.pushButton_5.clicked.connect(Form.open_orient) 44 | self.pushButton_6.clicked.connect(Form.save_img) 45 | self.pushButton_7.clicked.connect(Form.bg_mode) 46 | self.pushButton_8.clicked.connect(Form.hair_mode) 47 | self.pushButton_9.clicked.connect(Form.clear) 48 | self.pushButton_10.clicked.connect(Form.increase) 49 | self.pushButton_11.clicked.connect(Form.decrease) 50 | 51 | 52 | self.grid1 = QGridLayout() 53 | self.setLayout(self.grid) 54 | self.resize(60, 100) 55 | self.grid1.addWidget(self.pushButton_2, 0,0,1,1) 56 | self.grid1.addWidget(self.pushButton_3, 1, 0, 1, 1) 57 | self.grid1.addWidget(self.pushButton_4, 1, 0, 1, 1) 58 | self.grid1.addWidget(self.pushButton_5, 1, 1, 1, 1) 59 | 60 | 61 | if __name__=="__main__": 62 | app=QApplication(sys.argv) 63 | win=GUIButton() 64 | win.show() 65 | sys.exit(app.exec_()) -------------------------------------------------------------------------------- /ui/ui_palette.py: -------------------------------------------------------------------------------- 1 | from PyQt5.QtCore import * 2 | from PyQt5.QtGui import * 3 | from PyQt5.QtWidgets import * 4 | import numpy as np 5 | 6 | 7 | class GUIPalette(QWidget): 8 | def __init__(self, grid_sz=(6, 3)): 9 | QWidget.__init__(self) 10 | self.color_width = 25 11 | self.border = 6 12 | self.win_width = grid_sz[0] * self.color_width + (grid_sz[0] + 1) * self.border 13 | self.win_height = grid_sz[1] * self.color_width + (grid_sz[1] + 1) * self.border 14 | self.setFixedSize(self.win_width, self.win_height) 15 | self.num_colors = grid_sz[0] * grid_sz[1] 16 | self.grid_sz = grid_sz 17 | 
self.colors = None 18 | self.color_id = -1 19 | self.reset() 20 | 21 | def set_colors(self, colors): 22 | if colors is not None: 23 | self.colors = (colors[:min(colors.shape[0], self.num_colors), :] * 255).astype(np.uint8) 24 | self.color_id = -1 25 | self.update() 26 | 27 | def paintEvent(self, event): 28 | painter = QPainter() 29 | painter.begin(self) 30 | painter.setRenderHint(QPainter.Antialiasing) 31 | painter.fillRect(event.rect(), Qt.white) 32 | if self.colors is not None: 33 | for n, c in enumerate(self.colors): 34 | ca = QColor(c[0], c[1], c[2], 255) 35 | painter.setPen(QPen(Qt.black, 1)) 36 | painter.setBrush(ca) 37 | grid_x = n % self.grid_sz[0] 38 | grid_y = (n - grid_x) // self.grid_sz[0] 39 | x = grid_x * (self.color_width + self.border) + self.border 40 | y = grid_y * (self.color_width + self.border) + self.border 41 | 42 | if n == self.color_id: 43 | painter.drawEllipse(x, y, self.color_width, self.color_width) 44 | else: 45 | painter.drawRoundedRect(x, y, self.color_width, self.color_width, 2, 2) 46 | 47 | painter.end() 48 | 49 | def sizeHint(self): 50 | return QSize(self.win_width, self.win_height) 51 | 52 | def reset(self): 53 | self.colors = None 54 | self.mouseClicked = False 55 | self.color_id = -1 56 | self.update() 57 | 58 | def selected_color(self, pos): 59 | width = self.color_width + self.border 60 | dx = pos.x() % width 61 | dy = pos.y() % width 62 | if dx >= self.border and dy >= self.border: 63 | x_id = (pos.x() - dx) // width 64 | y_id = (pos.y() - dy) // width 65 | color_id = x_id + y_id * self.grid_sz[0] 66 | return int(color_id) 67 | else: 68 | return -1 69 | 70 | def update_ui(self, color_id): 71 | self.color_id = int(color_id) 72 | self.update() 73 | if color_id >= 0: 74 | print('choose color (%d) type (%s)' % (color_id, type(color_id))) 75 | color = self.colors[color_id] 76 | self.emit(SIGNAL('update_color'), color) 77 | self.update() 78 | 79 | def mousePressEvent(self, event): 80 | if event.button() == Qt.LeftButton: # click the point 81 | color_id = self.selected_color(event.pos()) 82 | self.update_ui(color_id) 83 | self.mouseClicked = True 84 | 85 | def mouseMoveEvent(self, event): 86 | if self.mouseClicked: 87 | color_id = self.selected_color(event.pos()) 88 | self.update_ui(color_id) 89 | 90 | def mouseReleaseEvent(self, event): 91 | self.mouseClicked = False 92 | -------------------------------------------------------------------------------- /ui_util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tzt101/MichiGAN/31596104075a8a64fd94ec7fb8bc41504c62304e/ui_util/__init__.py -------------------------------------------------------------------------------- /ui_util/cal_orient_stroke.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | import PIL 6 | import numpy as np 7 | from torchvision import transforms 8 | from PIL import Image 9 | import cv2 10 | import math 11 | import matplotlib.pyplot as plt 12 | 13 | def gabor_fn(kernel_size, channel_in, channel_out, theta): 14 | # sigma_x = sigma 15 | # sigma_y = sigma.float() / gamma 16 | sigma_x = nn.Parameter(torch.ones(channel_out) * 2.0, requires_grad=False).cuda() 17 | sigma_y = nn.Parameter(torch.ones(channel_out) * 3.0, requires_grad=False).cuda() 18 | Lambda = nn.Parameter(torch.ones(channel_out) * 4.0, requires_grad=False).cuda() 19 | psi = nn.Parameter(torch.ones(channel_out) * 0.0, 
requires_grad=False).cuda() 20 | 21 | # Bounding box 22 | xmax = kernel_size // 2 23 | ymax = kernel_size // 2 24 | xmin = -xmax 25 | ymin = -ymax 26 | ksize = xmax - xmin + 1 27 | y_0 = torch.arange(ymin, ymax+1).cuda() 28 | y = y_0.view(1, -1).repeat(channel_out, channel_in, ksize, 1).float() 29 | x_0 = torch.arange(xmin, xmax+1).cuda() 30 | x = x_0.view(-1, 1).repeat(channel_out, channel_in, 1, ksize).float() # [channel_out, channelin, kernel, kernel] 31 | 32 | # Rotation 33 | # don't need to expand, use broadcasting, [64, 1, 1, 1] + [64, 3, 7, 7] 34 | x_theta = x * torch.cos(theta.view(-1, 1, 1, 1)) + y * torch.sin(theta.view(-1, 1, 1, 1)) 35 | y_theta = -x * torch.sin(theta.view(-1, 1, 1, 1)) + y * torch.cos(theta.view(-1, 1, 1, 1)) 36 | 37 | # [channel_out, channel_in, kernel, kernel] 38 | gb = torch.exp(-.5 * (x_theta ** 2 / sigma_x.view(-1, 1, 1, 1) ** 2 + y_theta ** 2 / sigma_y.view(-1, 1, 1, 1) ** 2)) \ 39 | * torch.cos(2 * math.pi / Lambda.view(-1, 1, 1, 1) * x_theta + psi.view(-1, 1, 1, 1)) 40 | 41 | return gb 42 | 43 | def DoG_fn(kernel_size, channel_in, channel_out, theta): 44 | # params 45 | sigma_h = nn.Parameter(torch.ones(channel_out) * 1.0, requires_grad=False).cuda() 46 | sigma_l = nn.Parameter(torch.ones(channel_out) * 2.0, requires_grad=False).cuda() 47 | sigma_y = nn.Parameter(torch.ones(channel_out) * 2.0, requires_grad=False).cuda() 48 | 49 | # Bounding box 50 | xmax = kernel_size // 2 51 | ymax = kernel_size // 2 52 | xmin = -xmax 53 | ymin = -ymax 54 | ksize = xmax - xmin + 1 55 | y_0 = torch.arange(ymin, ymax+1).cuda() 56 | y = y_0.view(1, -1).repeat(channel_out, channel_in, ksize, 1).float() 57 | x_0 = torch.arange(xmin, xmax+1).cuda() 58 | x = x_0.view(-1, 1).repeat(channel_out, channel_in, 1, ksize).float() # [channel_out, channelin, kernel, kernel] 59 | 60 | # Rotation 61 | # don't need to expand, use broadcasting, [64, 1, 1, 1] + [64, 3, 7, 7] 62 | x_theta = x * torch.cos(theta.view(-1, 1, 1, 1)) + y * torch.sin(theta.view(-1, 1, 1, 1)) 63 | y_theta = -x * torch.sin(theta.view(-1, 1, 1, 1)) + y * torch.cos(theta.view(-1, 1, 1, 1)) 64 | 65 | gb = (torch.exp(-.5 * (x_theta ** 2 / sigma_h.view(-1, 1, 1, 1) ** 2 + y_theta ** 2 / sigma_y.view(-1, 1, 1, 1) ** 2))/sigma_h \ 66 | - torch.exp(-.5 * (x_theta ** 2 / sigma_l.view(-1, 1, 1, 1) ** 2 + y_theta ** 2 / sigma_y.view(-1, 1, 1, 1) ** 2))/sigma_l) \ 67 | / (1.0/sigma_h - 1.0/sigma_l) 68 | 69 | return gb 70 | 71 | # L1 loss of orientation map 72 | class orient(nn.Module): 73 | def __init__(self, channel_in=1, channel_out=1, stride=1, padding=8, mode='dog'): 74 | super(orient, self).__init__() 75 | self.criterion = nn.L1Loss() 76 | self.channel_in = channel_in 77 | self.channel_out = channel_out 78 | self.stride = stride 79 | self.padding = padding 80 | self.filter = gabor_fn if mode == 'gabor' else DoG_fn 81 | 82 | self.numKernels = 32 83 | self.kernel_size = 17 84 | 85 | def calOrientation(self, image, mask=None): 86 | resArray = [] 87 | # filter the image with different orientations 88 | for iOrient in range(self.numKernels): 89 | theta = nn.Parameter(torch.ones(self.channel_out) * (math.pi * iOrient / self.numKernels), 90 | requires_grad=False).cuda() 91 | filterKernel = self.filter(self.kernel_size, self.channel_in, self.channel_out, theta) 92 | filterKernel = filterKernel.float() 93 | response = F.conv2d(image, filterKernel, stride=self.stride, padding=self.padding) 94 | resArray.append(response.clone()) 95 | 96 | resTensor = resArray[0] 97 | for iOrient in range(1, self.numKernels): 98 | resTensor = 
torch.cat([resTensor, resArray[iOrient]], dim=1) 99 | 100 | # argmax the response 101 | resTensor[resTensor < 0] = 0 102 | maxResTensor = torch.argmax(resTensor, dim=1).float() 103 | confidenceTensor = torch.max(resTensor, dim=1)[0] 104 | # confidenceTensor = (torch.tanh(confidenceTensor)+1)/2.0 # [0, 1] 105 | # confidenceTensor = confidenceTensor / torch.max(confidenceTensor) 106 | confidenceTensor = torch.unsqueeze(confidenceTensor, 1) 107 | # print(torch.unique(confidenceTensor)) 108 | # th = 0.4 109 | # 110 | # confidenceTensor[confidenceTensor >= th] = 1 111 | # confidenceTensor[confidenceTensor < th] = 0 112 | # print(torch.unique(confidenceTensor)) 113 | # print(torch.sum(confidenceTensor)) 114 | # confidenceTensor = torch.unsqueeze(confidenceTensor, 1) / torch.max(confidenceTensor) 115 | 116 | # cal the angle a 117 | orientTensor = maxResTensor * math.pi / self.numKernels 118 | orientTensor = torch.unsqueeze(orientTensor, 1) 119 | # cal the sin2a and cos2a 120 | orientTwoChannel = torch.cat([torch.sin(2 * orientTensor), torch.cos(2 * orientTensor)], dim=1) 121 | return orientTwoChannel, confidenceTensor 122 | 123 | def convert_orient_to_RGB_test(self, input, label): 124 | import torch 125 | label = label.float() 126 | input = input * label 127 | out_r = torch.unsqueeze(input[1, :, :] * label[0, ...] + (1 - label[0, ...]) * -1, 0) 128 | out_g = torch.unsqueeze(input[0, :, :] * label[0, ...] + (1 - label[0, ...]) * -1, 0) 129 | out_b = torch.unsqueeze(input[0, :, :] * 0 * label[0, ...] + (1 - label[0, ...]) * -1, 0) 130 | # print(out_b.shape) 131 | return torch.cat([out_r, out_g, out_b], dim=0) 132 | 133 | def stroke_to_orient(self, stroke_mask): 134 | ''' 135 | :param stroke_mask: type: np.array, shape: 512*512, range: {0, 1} 136 | :return: type: np.array, shape: 512*512, range: [0,255] 137 | ''' 138 | stroke_mask_img = Image.fromarray(np.uint8(stroke_mask*255)) 139 | trans_label = transforms.Compose([transforms.ToTensor()]) 140 | 141 | stroke_mask_tensor = trans_label(stroke_mask_img) 142 | 143 | stroke_mask_tensor = torch.unsqueeze(stroke_mask_tensor, 0).cuda() 144 | 145 | orient_tensor, confidence_tensor = self.calOrientation(stroke_mask_tensor) 146 | orient_tensor = orient_tensor * stroke_mask_tensor 147 | # vis 148 | orient_rgb = self.convert_orient_to_RGB_test(orient_tensor[0, ...], stroke_mask_tensor[0, ...]) # [3, h, w] 149 | orient_numpy = (np.transpose(orient_rgb.cpu().numpy(), (1, 2, 0)) + 1) / 2.0 * 255.0 150 | 151 | return orient_numpy -------------------------------------------------------------------------------- /ui_util/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | import logging 5 | 6 | logger = logging.getLogger() 7 | 8 | class Config(object): 9 | def __init__(self, filename=None): 10 | assert os.path.exists(filename), "ERROR: Config File doesn't exist." 
11 | try: 12 | with open(filename, 'r') as f: 13 | self._cfg_dict = yaml.load(f) 14 | # parent of IOError, OSError *and* WindowsError where available 15 | except EnvironmentError: 16 | logger.error('Please check the file with name of "%s"', filename) 17 | logger.info(' APP CONFIG '.center(80, '-')) 18 | logger.info(''.center(80, '-')) 19 | 20 | def __getattr__(self, name): 21 | value = self._cfg_dict[name] 22 | if isinstance(value, dict): 23 | value = DictAsMember(value) 24 | return value 25 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | -------------------------------------------------------------------------------- /util/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | 7 | def id2label(id): 8 | if id == 182: 9 | id = 0 10 | else: 11 | id = id + 1 12 | labelmap = \ 13 | {0: 'unlabeled', 14 | 1: 'person', 15 | 2: 'bicycle', 16 | 3: 'car', 17 | 4: 'motorcycle', 18 | 5: 'airplane', 19 | 6: 'bus', 20 | 7: 'train', 21 | 8: 'truck', 22 | 9: 'boat', 23 | 10: 'traffic light', 24 | 11: 'fire hydrant', 25 | 12: 'street sign', 26 | 13: 'stop sign', 27 | 14: 'parking meter', 28 | 15: 'bench', 29 | 16: 'bird', 30 | 17: 'cat', 31 | 18: 'dog', 32 | 19: 'horse', 33 | 20: 'sheep', 34 | 21: 'cow', 35 | 22: 'elephant', 36 | 23: 'bear', 37 | 24: 'zebra', 38 | 25: 'giraffe', 39 | 26: 'hat', 40 | 27: 'backpack', 41 | 28: 'umbrella', 42 | 29: 'shoe', 43 | 30: 'eye glasses', 44 | 31: 'handbag', 45 | 32: 'tie', 46 | 33: 'suitcase', 47 | 34: 'frisbee', 48 | 35: 'skis', 49 | 36: 'snowboard', 50 | 37: 'sports ball', 51 | 38: 'kite', 52 | 39: 'baseball bat', 53 | 40: 'baseball glove', 54 | 41: 'skateboard', 55 | 42: 'surfboard', 56 | 43: 'tennis racket', 57 | 44: 'bottle', 58 | 45: 'plate', 59 | 46: 'wine glass', 60 | 47: 'cup', 61 | 48: 'fork', 62 | 49: 'knife', 63 | 50: 'spoon', 64 | 51: 'bowl', 65 | 52: 'banana', 66 | 53: 'apple', 67 | 54: 'sandwich', 68 | 55: 'orange', 69 | 56: 'broccoli', 70 | 57: 'carrot', 71 | 58: 'hot dog', 72 | 59: 'pizza', 73 | 60: 'donut', 74 | 61: 'cake', 75 | 62: 'chair', 76 | 63: 'couch', 77 | 64: 'potted plant', 78 | 65: 'bed', 79 | 66: 'mirror', 80 | 67: 'dining table', 81 | 68: 'window', 82 | 69: 'desk', 83 | 70: 'toilet', 84 | 71: 'door', 85 | 72: 'tv', 86 | 73: 'laptop', 87 | 74: 'mouse', 88 | 75: 'remote', 89 | 76: 'keyboard', 90 | 77: 'cell phone', 91 | 78: 'microwave', 92 | 79: 'oven', 93 | 80: 'toaster', 94 | 81: 'sink', 95 | 82: 'refrigerator', 96 | 83: 'blender', 97 | 84: 'book', 98 | 85: 'clock', 99 | 86: 'vase', 100 | 87: 'scissors', 101 | 88: 'teddy bear', 102 | 89: 'hair drier', 103 | 90: 'toothbrush', 104 | 91: 'hair brush', # Last class of Thing 105 | 92: 'banner', # Beginning of Stuff 106 | 93: 'blanket', 107 | 94: 'branch', 108 | 95: 'bridge', 109 | 96: 'building-other', 110 | 97: 'bush', 111 | 98: 'cabinet', 112 | 99: 'cage', 113 | 100: 'cardboard', 114 | 101: 'carpet', 115 | 102: 'ceiling-other', 116 | 103: 'ceiling-tile', 117 | 104: 'cloth', 118 | 105: 'clothes', 119 | 106: 'clouds', 120 | 107: 'counter', 121 | 108: 'cupboard', 122 | 109: 'curtain', 123 | 110: 'desk-stuff', 124 | 111: 'dirt', 125 | 112: 'door-stuff', 126 | 113: 'fence', 127 | 114: 
'floor-marble', 128 | 115: 'floor-other', 129 | 116: 'floor-stone', 130 | 117: 'floor-tile', 131 | 118: 'floor-wood', 132 | 119: 'flower', 133 | 120: 'fog', 134 | 121: 'food-other', 135 | 122: 'fruit', 136 | 123: 'furniture-other', 137 | 124: 'grass', 138 | 125: 'gravel', 139 | 126: 'ground-other', 140 | 127: 'hill', 141 | 128: 'house', 142 | 129: 'leaves', 143 | 130: 'light', 144 | 131: 'mat', 145 | 132: 'metal', 146 | 133: 'mirror-stuff', 147 | 134: 'moss', 148 | 135: 'mountain', 149 | 136: 'mud', 150 | 137: 'napkin', 151 | 138: 'net', 152 | 139: 'paper', 153 | 140: 'pavement', 154 | 141: 'pillow', 155 | 142: 'plant-other', 156 | 143: 'plastic', 157 | 144: 'platform', 158 | 145: 'playingfield', 159 | 146: 'railing', 160 | 147: 'railroad', 161 | 148: 'river', 162 | 149: 'road', 163 | 150: 'rock', 164 | 151: 'roof', 165 | 152: 'rug', 166 | 153: 'salad', 167 | 154: 'sand', 168 | 155: 'sea', 169 | 156: 'shelf', 170 | 157: 'sky-other', 171 | 158: 'skyscraper', 172 | 159: 'snow', 173 | 160: 'solid-other', 174 | 161: 'stairs', 175 | 162: 'stone', 176 | 163: 'straw', 177 | 164: 'structural-other', 178 | 165: 'table', 179 | 166: 'tent', 180 | 167: 'textile-other', 181 | 168: 'towel', 182 | 169: 'tree', 183 | 170: 'vegetable', 184 | 171: 'wall-brick', 185 | 172: 'wall-concrete', 186 | 173: 'wall-other', 187 | 174: 'wall-panel', 188 | 175: 'wall-stone', 189 | 176: 'wall-tile', 190 | 177: 'wall-wood', 191 | 178: 'water-other', 192 | 179: 'waterdrops', 193 | 180: 'window-blind', 194 | 181: 'window-other', 195 | 182: 'wood'} 196 | if id in labelmap: 197 | return labelmap[id] 198 | else: 199 | return 'unknown' 200 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 
4 | """ 5 | 6 | import datetime 7 | import dominate 8 | from dominate.tags import * 9 | import os 10 | 11 | 12 | class HTML: 13 | def __init__(self, web_dir, title, refresh=0): 14 | if web_dir.endswith('.html'): 15 | web_dir, html_name = os.path.split(web_dir) 16 | else: 17 | web_dir, html_name = web_dir, 'index.html' 18 | self.title = title 19 | self.web_dir = web_dir 20 | self.html_name = html_name 21 | self.img_dir = os.path.join(self.web_dir, 'images') 22 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 23 | os.makedirs(self.web_dir) 24 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 25 | os.makedirs(self.img_dir) 26 | 27 | self.doc = dominate.document(title=title) 28 | with self.doc: 29 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 30 | if refresh > 0: 31 | with self.doc.head: 32 | meta(http_equiv="refresh", content=str(refresh)) 33 | 34 | def get_image_dir(self): 35 | return self.img_dir 36 | 37 | def add_header(self, str): 38 | with self.doc: 39 | h3(str) 40 | 41 | def add_table(self, border=1): 42 | self.t = table(border=border, style="table-layout: fixed;") 43 | self.doc.add(self.t) 44 | 45 | def add_images(self, ims, txts, links, width=512): 46 | self.add_table() 47 | with self.t: 48 | with tr(): 49 | for im, txt, link in zip(ims, txts, links): 50 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 51 | with p(): 52 | with a(href=os.path.join('images', link)): 53 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 54 | br() 55 | p(txt.encode('utf-8')) 56 | 57 | def save(self): 58 | html_file = os.path.join(self.web_dir, self.html_name) 59 | f = open(html_file, 'wt') 60 | f.write(self.doc.render()) 61 | f.close() 62 | 63 | 64 | if __name__ == '__main__': 65 | html = HTML('web/', 'test_html') 66 | html.add_header('hello world') 67 | 68 | ims = [] 69 | txts = [] 70 | links = [] 71 | for n in range(4): 72 | ims.append('image_%d.jpg' % n) 73 | txts.append('text_%d' % n) 74 | links.append('image_%d.jpg' % n) 75 | html.add_images(ims, txts, links) 76 | html.save() 77 | -------------------------------------------------------------------------------- /util/iter_counter.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | 10 | 11 | # Helper class that keeps track of training iterations 12 | class IterationCounter(): 13 | def __init__(self, opt, dataset_size): 14 | self.opt = opt 15 | self.dataset_size = dataset_size 16 | 17 | self.first_epoch = 1 18 | self.total_epochs = opt.niter + opt.niter_decay # 50 19 | self.epoch_iter = 0 # iter number within each epoch 20 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 21 | if opt.isTrain and opt.continue_train: 22 | try: 23 | self.first_epoch, self.epoch_iter = np.loadtxt( 24 | self.iter_record_path, delimiter=',', dtype=int) 25 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 26 | except: 27 | print('Could not load iteration record at %s. Starting from beginning.' 
% 28 | self.iter_record_path) 29 | 30 | self.total_steps_so_far = (self.first_epoch - 1) * dataset_size + self.epoch_iter 31 | 32 | # return the iterator of epochs for the training 33 | def training_epochs(self): 34 | return range(self.first_epoch, self.total_epochs + 1) 35 | 36 | def record_epoch_start(self, epoch): 37 | self.epoch_start_time = time.time() 38 | self.epoch_iter = 0 39 | self.last_iter_time = time.time() 40 | self.current_epoch = epoch 41 | 42 | def record_one_iteration(self): 43 | current_time = time.time() 44 | 45 | # the last remaining batch is dropped (see data/__init__.py), 46 | # so we can assume batch size is always opt.batchSize 47 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 48 | self.last_iter_time = current_time 49 | self.total_steps_so_far += self.opt.batchSize 50 | self.epoch_iter += self.opt.batchSize 51 | 52 | def record_epoch_end(self): 53 | current_time = time.time() 54 | self.time_per_epoch = current_time - self.epoch_start_time 55 | print('End of epoch %d / %d \t Time Taken: %d sec' % 56 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 57 | if self.current_epoch % self.opt.save_epoch_freq == 0: 58 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 59 | delimiter=',', fmt='%d') 60 | print('Saved current iteration count at %s.' % self.iter_record_path) 61 | 62 | def record_current_iter(self): 63 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 64 | delimiter=',', fmt='%d') 65 | print('Saved current iteration count at %s.' % self.iter_record_path) 66 | 67 | def needs_saving(self): 68 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 69 | 70 | def needs_printing(self): 71 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 72 | 73 | def needs_displaying(self): 74 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 75 | -------------------------------------------------------------------------------- /util/iter_counter_ms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import os 7 | import time 8 | import numpy as np 9 | 10 | 11 | # Helper class that keeps track of training iterations 12 | class IterationCounter(): 13 | def __init__(self, opt): 14 | self.opt = opt 15 | 16 | self.first_epoch = 1 17 | self.total_epochs = opt.niter + opt.niter_decay # 50 18 | self.epoch_iter = 0 # iter number within each epoch 19 | self.total_steps_so_far = 0 20 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 21 | if opt.isTrain and opt.continue_train: 22 | try: 23 | self.first_epoch, self.epoch_iter, self.total_steps_so_far = np.loadtxt( 24 | self.iter_record_path, delimiter=',', dtype=int) 25 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 26 | except: 27 | print('Could not load iteration record at %s. Starting from beginning.' 
% 28 | self.iter_record_path) 29 | 30 | # return the iterator of epochs for the training 31 | def training_epochs(self): 32 | return range(self.first_epoch, self.total_epochs + 1) 33 | 34 | def record_epoch_start(self, epoch): 35 | self.epoch_start_time = time.time() 36 | self.epoch_iter = 0 37 | self.last_iter_time = time.time() 38 | self.current_epoch = epoch 39 | 40 | def record_one_iteration(self): 41 | current_time = time.time() 42 | 43 | # the last remaining batch is dropped (see data/__init__.py), 44 | # so we can assume batch size is always opt.batchSize 45 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 46 | self.last_iter_time = current_time 47 | self.total_steps_so_far += self.opt.batchSize 48 | self.epoch_iter += self.opt.batchSize 49 | # print(self.opt.batchSize) 50 | 51 | def record_epoch_end(self): 52 | current_time = time.time() 53 | self.time_per_epoch = current_time - self.epoch_start_time 54 | print('End of epoch %d / %d \t Time Taken: %d sec' % 55 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 56 | if self.current_epoch % self.opt.save_epoch_freq == 0: 57 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0, self.total_steps_so_far), 58 | delimiter=',', fmt='%d') 59 | print('Saved current iteration count at %s.' % self.iter_record_path) 60 | 61 | def record_current_iter(self): 62 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter, self.total_steps_so_far), 63 | delimiter=',', fmt='%d') 64 | print('Saved current iteration count at %s.' % self.iter_record_path) 65 | 66 | def needs_saving(self): 67 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 68 | 69 | def needs_printing(self): 70 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 71 | 72 | def needs_displaying(self): 73 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 74 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 
4 | """ 5 | 6 | import re 7 | import importlib 8 | import torch 9 | from argparse import Namespace 10 | import numpy as np 11 | from PIL import Image 12 | import os 13 | import argparse 14 | import dill as pickle 15 | import util.coco 16 | import cv2 17 | 18 | def save_obj(obj, name): 19 | with open(name, 'wb') as f: 20 | pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL) 21 | 22 | 23 | def load_obj(name): 24 | with open(name, 'rb') as f: 25 | return pickle.load(f) 26 | 27 | # returns a configuration for creating a generator 28 | # |default_opt| should be the opt of the current experiment 29 | # |**kwargs|: if any configuration should be overriden, it can be specified here 30 | 31 | 32 | def copyconf(default_opt, **kwargs): 33 | conf = argparse.Namespace(**vars(default_opt)) 34 | for key in kwargs: 35 | print(key, kwargs[key]) 36 | setattr(conf, key, kwargs[key]) 37 | return conf 38 | 39 | 40 | def tile_images(imgs, picturesPerRow=4): 41 | """ Code borrowed from 42 | https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format/26521997 43 | """ 44 | 45 | # Padding 46 | if imgs.shape[0] % picturesPerRow == 0: 47 | rowPadding = 0 48 | else: 49 | rowPadding = picturesPerRow - imgs.shape[0] % picturesPerRow 50 | if rowPadding > 0: 51 | imgs = np.concatenate([imgs, np.zeros((rowPadding, *imgs.shape[1:]), dtype=imgs.dtype)], axis=0) 52 | 53 | # Tiling Loop (The conditionals are not necessary anymore) 54 | tiled = [] 55 | for i in range(0, imgs.shape[0], picturesPerRow): 56 | tiled.append(np.concatenate([imgs[j] for j in range(i, i + picturesPerRow)], axis=1)) 57 | 58 | tiled = np.concatenate(tiled, axis=0) 59 | return tiled 60 | 61 | 62 | # Converts a Tensor into a Numpy array 63 | # |imtype|: the desired type of the converted numpy array 64 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True, tile=False): 65 | if isinstance(image_tensor, list): 66 | image_numpy = [] 67 | for i in range(len(image_tensor)): 68 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 69 | return image_numpy 70 | 71 | if image_tensor.dim() == 4: 72 | # transform each image in the batch 73 | images_np = [] 74 | for b in range(image_tensor.size(0)): 75 | one_image = image_tensor[b] 76 | one_image_np = tensor2im(one_image) 77 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 78 | images_np = np.concatenate(images_np, axis=0) 79 | if tile: 80 | images_tiled = tile_images(images_np) 81 | return images_tiled 82 | else: 83 | return images_np 84 | 85 | if image_tensor.dim() == 2: 86 | image_tensor = image_tensor.unsqueeze(0) 87 | image_numpy = image_tensor.detach().cpu().float().numpy() 88 | if normalize: 89 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 90 | else: 91 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 92 | image_numpy = np.clip(image_numpy, 0, 255) 93 | if image_numpy.shape[2] == 1: 94 | image_numpy = image_numpy[:, :, 0] 95 | return image_numpy.astype(imtype) 96 | 97 | 98 | # Converts a one-hot tensor into a colorful label map 99 | def tensor2label(label_tensor, n_label, imtype=np.uint8, tile=False): 100 | if label_tensor.dim() == 4: 101 | # transform each image in the batch 102 | images_np = [] 103 | for b in range(label_tensor.size(0)): 104 | one_image = label_tensor[b] 105 | one_image_np = tensor2label(one_image, n_label, imtype) 106 | images_np.append(one_image_np.reshape(1, *one_image_np.shape)) 107 | images_np = np.concatenate(images_np, axis=0) 108 | if tile: 109 | images_tiled = 
tile_images(images_np) 110 | return images_tiled 111 | else: 112 | images_np = images_np[0] 113 | return images_np 114 | 115 | if label_tensor.dim() == 1: 116 | return np.zeros((64, 64, 3), dtype=np.uint8) 117 | if n_label == 0: 118 | return tensor2im(label_tensor, imtype) 119 | label_tensor = label_tensor.cpu().float() 120 | if label_tensor.size()[0] > 1: 121 | label_tensor = label_tensor.max(0, keepdim=True)[1] 122 | label_tensor = Colorize(n_label)(label_tensor) 123 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 124 | result = label_numpy.astype(imtype) 125 | return result 126 | 127 | 128 | def save_image(image_numpy, image_path, create_dir=False): 129 | if create_dir: 130 | os.makedirs(os.path.dirname(image_path), exist_ok=True) 131 | if len(image_numpy.shape) == 2: 132 | image_numpy = np.expand_dims(image_numpy, axis=2) 133 | if image_numpy.shape[2] == 1: 134 | image_numpy = np.repeat(image_numpy, 3, 2) 135 | image_pil = Image.fromarray(image_numpy) 136 | 137 | # always save as PNG (a '.jpg' suffix in the path is replaced) 138 | image_pil.save(image_path.replace('.jpg', '.png')) 139 | 140 | 141 | def mkdirs(paths): 142 | if isinstance(paths, list) and not isinstance(paths, str): 143 | for path in paths: 144 | mkdir(path) 145 | else: 146 | mkdir(paths) 147 | 148 | 149 | def mkdir(path): 150 | if not os.path.exists(path): 151 | os.makedirs(path) 152 | 153 | 154 | def atoi(text): 155 | return int(text) if text.isdigit() else text 156 | 157 | 158 | def natural_keys(text): 159 | ''' 160 | alist.sort(key=natural_keys) sorts in human order 161 | http://nedbatchelder.com/blog/200712/human_sorting.html 162 | (See Toothy's implementation in the comments) 163 | ''' 164 | return [atoi(c) for c in re.split(r'(\d+)', text)] 165 | 166 | 167 | def natural_sort(items): 168 | items.sort(key=natural_keys) 169 | 170 | 171 | def str2bool(v): 172 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 173 | return True 174 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 175 | return False 176 | else: 177 | raise argparse.ArgumentTypeError('Boolean value expected.') 178 | 179 | 180 | def find_class_in_module(target_cls_name, module): 181 | target_cls_name = target_cls_name.replace('_', '').lower() 182 | clslib = importlib.import_module(module) 183 | cls = None 184 | for name, clsobj in clslib.__dict__.items(): 185 | if name.lower() == target_cls_name: 186 | cls = clsobj 187 | 188 | if cls is None: 189 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 190 | exit(0) 191 | 192 | return cls 193 | 194 | 195 | def save_network(net, label, epoch, opt): 196 | save_filename = '%s_net_%s.pth' % (epoch, label) 197 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 198 | torch.save(net.cpu().state_dict(), save_path) 199 | if len(opt.gpu_ids) and torch.cuda.is_available(): 200 | net.cuda() 201 | 202 | def load_weights(cnn_model, weights): 203 | """ 204 | args: 205 | :param cnn_model: the CNN model whose weights need to be loaded 206 | :param weights: the pretrained weights 207 | :return: None 208 | """ 209 | from torch.nn.parameter import Parameter 210 | pre_dict = cnn_model.state_dict() 211 | for key, val in weights.items(): 212 | if key[0:7] == 'module.': # the pretrained network was trained on multiple GPUs 213 | key = key[7:] # remove 'module.' from the key 214 | if key in pre_dict.keys(): 215 | if isinstance(val, Parameter): 216 | val = val.data 217 | pre_dict[key].copy_(val) 218 | cnn_model.load_state_dict(pre_dict) 219 | 220 | 221 | def load_network(net, label, epoch, opt): 222 | if 'D' in label and not opt.same_netD_model and opt.unpairTrain: 223 | save_filename = '%s_net_%s2.pth' % (epoch, label) 224 | else: 225 | save_filename = '%s_net_%s.pth' % (epoch, label) 226 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 227 | save_path = os.path.join(save_dir, save_filename) 228 | weights = torch.load(save_path) 229 | # net.load_state_dict(weights) 230 | load_weights(net, weights) 231 | return net 232 | 233 | def load_blend_network(net, label, epoch, opt): 234 | save_filename = '%s_net_%s.pth' % (epoch, label) 235 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 236 | save_path = os.path.join(save_dir, save_filename) 237 | if os.path.exists(save_path): 238 | weights = torch.load(save_path) 239 | net.load_state_dict(weights) 240 | return net 241 | else: 242 | return net 243 | 244 | 245 | def load_inpainting_network(net, opt): 246 | save_filename = opt.ig_model_name 247 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 248 | save_path = os.path.join(save_dir, save_filename) 249 | if torch.cuda.is_available(): 250 | data = torch.load(save_path) 251 | else: 252 | data = torch.load(save_path, map_location=lambda storage, loc: storage) 253 | if len(opt.gpu_ids) > 1: 254 | net.load_state_dict(data['generator']) 255 | else: 256 | net.load_state_dict(data['generator']) 257 | return net 258 | 259 | def load_sinpainting_network(net, opt): 260 | save_filename = opt.sig_model_name 261 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 262 | save_path = os.path.join(save_dir, save_filename) 263 | print(save_path) 264 | if torch.cuda.is_available(): 265 | data = torch.load(save_path) 266 | else: 267 | data = torch.load(save_path, map_location=lambda storage, loc: storage) 268 | if len(opt.gpu_ids) > 1: 269 | net.load_state_dict(data['generator']) 270 | else: 271 | net.load_state_dict(data['generator']) 272 | return net 273 | 274 | def blend_image(fake, tag, mask): 275 | fake_mask = np.ones(fake.shape, fake.dtype) 276 | fake_mask = fake_mask * mask[...,np.newaxis] * 255 277 | fake_mask = fake_mask.astype(fake.dtype) 278 | coord = np.where(mask == 1) 279 | h_min = np.min(coord[0]) 280 | h_max = np.max(coord[0]) 281 | w_min = np.min(coord[1]) 282 | w_max = np.max(coord[1]) 283 | center = (int((w_min+w_max)/2), int((h_min+h_max)/2)) 284 | out = cv2.seamlessClone(fake, tag, fake_mask, center, cv2.MIXED_CLONE) 285 | return out 286 | 287 | 288 | ############################################################################### 289 | # Code from 290 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 291 | # Modified so it complies with the Cityscapes label map colors 292 | ############################################################################### 293 | def uint82bin(n, count=8): 294 | """returns the binary string of integer n; count refers to the number of bits""" 295 | return ''.join([str((n >> y) & 1) for y in range(count - 1, -1, -1)]) 296 | 297 | 298 | def labelcolormap(N): 299 | if N == 35: # Cityscapes 300 | cmap = np.array([(0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81), 301 | (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153), 302 | (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 
153), (250, 170, 30), (220, 220, 0), 303 | (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70), 304 | (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142)], 305 | dtype=np.uint8) 306 | else: 307 | cmap = np.zeros((N, 3), dtype=np.uint8) 308 | for i in range(N): 309 | r, g, b = 0, 0, 0 310 | id = i + 1 # let's give 0 a color 311 | for j in range(7): 312 | str_id = uint82bin(id) 313 | r = r ^ (np.uint8(str_id[-1]) << (7 - j)) 314 | g = g ^ (np.uint8(str_id[-2]) << (7 - j)) 315 | b = b ^ (np.uint8(str_id[-3]) << (7 - j)) 316 | id = id >> 3 317 | cmap[i, 0] = r 318 | cmap[i, 1] = g 319 | cmap[i, 2] = b 320 | 321 | if N == 182: # COCO 322 | important_colors = { 323 | 'sea': (54, 62, 167), 324 | 'sky-other': (95, 219, 255), 325 | 'tree': (140, 104, 47), 326 | 'clouds': (170, 170, 170), 327 | 'grass': (29, 195, 49) 328 | } 329 | for i in range(N): 330 | name = util.coco.id2label(i) 331 | if name in important_colors: 332 | color = important_colors[name] 333 | cmap[i] = np.array(list(color)) 334 | 335 | return cmap 336 | 337 | 338 | class Colorize(object): 339 | def __init__(self, n=35): 340 | self.cmap = labelcolormap(n) 341 | self.cmap = torch.from_numpy(self.cmap[:n]) 342 | 343 | def __call__(self, gray_image): 344 | size = gray_image.size() 345 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 346 | 347 | for label in range(0, len(self.cmap)): 348 | mask = (label == gray_image[0]).cpu() 349 | color_image[0][mask] = self.cmap[label][0] 350 | color_image[1][mask] = self.cmap[label][1] 351 | color_image[2][mask] = self.cmap[label][2] 352 | 353 | return color_image 354 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) University of Science and Technology of China. 3 | Licensed under the MIT License. 4 | """ 5 | 6 | import os 7 | import ntpath 8 | import time 9 | from . import util 10 | from . import html 11 | import scipy.misc 12 | try: 13 | from StringIO import StringIO # Python 2.7 14 | except ImportError: 15 | from io import BytesIO # Python 3.x 16 | 17 | class Visualizer(): 18 | def __init__(self, opt): 19 | self.opt = opt 20 | self.tf_log = opt.isTrain and opt.tf_log 21 | self.use_html = opt.isTrain and not opt.no_html 22 | self.win_size = opt.display_winsize 23 | self.name = opt.name 24 | if self.tf_log: 25 | import tensorflow as tf 26 | self.tf = tf 27 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 28 | self.writer = tf.summary.FileWriter(self.log_dir) 29 | 30 | if self.use_html: 31 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 32 | self.img_dir = os.path.join(self.web_dir, 'images') 33 | print('create web directory %s...' 
% self.web_dir) 34 | util.mkdirs([self.web_dir, self.img_dir]) 35 | if opt.isTrain: 36 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 37 | with open(self.log_name, "a") as log_file: 38 | now = time.strftime("%c") 39 | log_file.write('================ Training Loss (%s) ================\n' % now) 40 | 41 | # |visuals|: dictionary of images to display or save 42 | def display_current_results(self, visuals, epoch, step): 43 | 44 | ## convert tensors to numpy arrays 45 | visuals = self.convert_visuals_to_numpy(visuals) 46 | 47 | if self.tf_log: # show images in tensorboard output 48 | img_summaries = [] 49 | for label, image_numpy in visuals.items(): 50 | # Write the image to a string 51 | try: 52 | s = StringIO() 53 | except: 54 | s = BytesIO() 55 | if len(image_numpy.shape) >= 4: 56 | image_numpy = image_numpy[0] 57 | scipy.misc.toimage(image_numpy).save(s, format="jpeg") 58 | # Create an Image object 59 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 60 | # Create a Summary value 61 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 62 | 63 | # Create and write Summary 64 | summary = self.tf.Summary(value=img_summaries) 65 | self.writer.add_summary(summary, step) 66 | 67 | if self.use_html: # save images to a html file 68 | for label, image_numpy in visuals.items(): 69 | if isinstance(image_numpy, list): 70 | for i in range(len(image_numpy)): 71 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s_%d.png' % (epoch, step, label, i)) 72 | util.save_image(image_numpy[i], img_path) 73 | else: 74 | img_path = os.path.join(self.img_dir, 'epoch%.3d_iter%.3d_%s.png' % (epoch, step, label)) 75 | if len(image_numpy.shape) >= 4: 76 | image_numpy = image_numpy[0] 77 | util.save_image(image_numpy, img_path) 78 | 79 | # update website 80 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=5) 81 | for n in range(epoch, 0, -1): 82 | webpage.add_header('epoch [%d]' % n) 83 | ims = [] 84 | txts = [] 85 | links = [] 86 | 87 | for label, image_numpy in visuals.items(): 88 | if isinstance(image_numpy, list): 89 | for i in range(len(image_numpy)): 90 | img_path = 'epoch%.3d_iter%.3d_%s_%d.png' % (n, step, label, i) 91 | ims.append(img_path) 92 | txts.append(label+str(i)) 93 | links.append(img_path) 94 | else: 95 | img_path = 'epoch%.3d_iter%.3d_%s.png' % (n, step, label) 96 | ims.append(img_path) 97 | txts.append(label) 98 | links.append(img_path) 99 | if len(ims) < 10: 100 | webpage.add_images(ims, txts, links, width=self.win_size) 101 | else: 102 | num = int(round(len(ims)/2.0)) 103 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 104 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 105 | webpage.save() 106 | 107 | # errors: dictionary of error labels and values 108 | def plot_current_errors(self, errors, step): 109 | if self.tf_log: 110 | for tag, value in errors.items(): 111 | value = value.mean().float() 112 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 113 | self.writer.add_summary(summary, step) 114 | 115 | # errors: same format as |errors| of plotCurrentErrors 116 | def print_current_errors(self, epoch, i, errors, t, curr_size=None): 117 | import time 118 | tt = time.asctime( time.localtime(time.time()) ) 119 | if curr_size is not None: 120 | message = '(step: %d, epoch: %d, iters: %d, time: %.3f, size: %d) ' % (self.opt.curr_step, epoch, i, 
t, curr_size) 121 | else: 122 | message = '(step: %d, epoch: %d, iters: %d, time: %.3f) ' % (self.opt.curr_step, epoch, i, t) 123 | message = str(tt) + message 124 | for k, v in errors.items(): 125 | #print(v) 126 | #if v != 0: 127 | v = v.mean().float() 128 | message += '%s: %.3f ' % (k, v) 129 | 130 | print(message) 131 | with open(self.log_name, "a") as log_file: 132 | log_file.write('%s\n' % message) 133 | 134 | def convert_visuals_to_numpy(self, visuals): 135 | for key, t in visuals.items(): 136 | tile = self.opt.batchSize > 8 137 | if 'input' in key: 138 | t = util.tensor2label(t, self.opt.label_nc + 2, tile=tile) 139 | else: 140 | t = util.tensor2im(t, tile=tile) 141 | visuals[key] = t 142 | return visuals 143 | 144 | def convert_orient_to_RGB_test(self, input, label): 145 | import torch 146 | label = label.float() 147 | input = input * label 148 | out_r = torch.unsqueeze(input[1,:,:]*label[0,...]+(1-label[0,...])*-1, 0) 149 | out_g = torch.unsqueeze(input[0,:,:]*label[0,...]+(1-label[0,...])*-1, 0) 150 | out_b = torch.unsqueeze(input[0,:,:]*0*label[0,...]+(1-label[0,...])*-1, 0) 151 | # print(out_b.shape) 152 | return torch.cat([out_r, out_g, out_b], dim=0) 153 | 154 | # save image to the disk 155 | def save_images(self, webpage, visuals, image_path): 156 | visuals = self.convert_visuals_to_numpy(visuals) 157 | 158 | image_dir = webpage.get_image_dir() 159 | short_path = ntpath.basename(image_path[0]) 160 | name = os.path.splitext(short_path)[0] 161 | 162 | webpage.add_header(name) 163 | ims = [] 164 | txts = [] 165 | links = [] 166 | 167 | for label, image_numpy in visuals.items(): 168 | image_name = os.path.join(label, '%s.png' % (name)) 169 | save_path = os.path.join(image_dir, image_name) 170 | util.save_image(image_numpy, save_path, create_dir=True) 171 | 172 | ims.append(image_name) 173 | txts.append(label) 174 | links.append(image_name) 175 | webpage.add_images(ims, txts, links, width=self.win_size) 176 | --------------------------------------------------------------------------------
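
The conversion helpers in `util/util.py` above (`tensor2im`, `save_image`, `mkdir`) are what move generator outputs back into image files. As a minimal usage sketch (not part of the repository; the tensor shapes and file names below are made up for illustration, and it assumes the code is run from the project root so that `util.util` and its dependencies such as `cv2` and `dill` import cleanly):

```python
# Minimal sketch, assuming util/util.py is importable from the project root.
import torch
from util.util import tensor2im, save_image, mkdir

# A stand-in for a generator output: a batch of 2 RGB images with values in [-1, 1],
# as produced by a tanh output layer (shapes chosen only for illustration).
fake_images = torch.rand(2, 3, 256, 256) * 2 - 1

# tensor2im maps a [-1, 1] tensor to uint8 HxWxC arrays; for a 4D input it returns
# one array per batch element (or a single tiled grid when tile=True).
images_np = tensor2im(fake_images)

mkdir('inference_samples')
for idx, img in enumerate(images_np):
    # save_image always writes a PNG: a '.jpg' suffix in the path is replaced by '.png'.
    save_image(img, 'inference_samples/sample_%d.jpg' % idx)
```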