├── .gitignore ├── LICENSE ├── data ├── __init__.py ├── aligned_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── image_folder.py └── single_dataset.py ├── dataset ├── data │ └── test_single │ │ ├── WechatIMG253_2.png │ │ ├── WechatIMG256.png │ │ ├── WechatIMG260.png │ │ ├── WechatIMG261.png │ │ ├── WechatIMG262.png │ │ ├── WechatIMG263.png │ │ ├── brazil-1368806_1920.png │ │ ├── fconrad_Portrait_060414a.png │ │ ├── garden-2768329_1920.png │ │ ├── girl-2099354_1920.png │ │ ├── girl-2099357_1920.png │ │ ├── girl-2122909_1920.png │ │ ├── girl-2122927_1280.png │ │ ├── girl-2128294_1920.png │ │ ├── girl-2132171_1920.png │ │ ├── girl-2143709_1920.png │ │ ├── girl-2164409_1920.png │ │ ├── girl-2177360_1920.png │ │ ├── girl-2720476_1920.png │ │ ├── girl-2999078_1920.png │ │ ├── girl-4024238_1920.png │ │ ├── girl-4024240_1920.png │ │ ├── girl-4024244_1920.png │ │ ├── male-467711_1920.png │ │ ├── mid-autumn-2752710_1920.png │ │ ├── model-2134460_1920.png │ │ ├── model-2911329_1920.png │ │ ├── model-2911332_1920.png │ │ ├── own-2553537_1280.png │ │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.png │ │ ├── pinky-2727846_1920.png │ │ ├── pinky-2727874_1920.png │ │ ├── portrait-2164027_1920.png │ │ ├── portrait-2554431_1920.png │ │ ├── portrait-laughing-businesswoman-long-dark-450w-235195312.png │ │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.png │ │ ├── timg.png │ │ ├── w_sexy_gr.png │ │ └── young-507297_1920.png ├── landmark │ └── ALL │ │ ├── WechatIMG253_2.txt │ │ ├── WechatIMG256.txt │ │ ├── WechatIMG260.txt │ │ ├── WechatIMG261.txt │ │ ├── WechatIMG262.txt │ │ ├── WechatIMG263.txt │ │ ├── brazil-1368806_1920.txt │ │ ├── fconrad_Portrait_060414a.txt │ │ ├── garden-2768329_1920.txt │ │ ├── girl-2099354_1920.txt │ │ ├── girl-2099357_1920.txt │ │ ├── girl-2122909_1920.txt │ │ ├── girl-2122927_1280.txt │ │ ├── girl-2128294_1920.txt │ │ ├── girl-2132171_1920.txt │ │ ├── girl-2143709_1920.txt │ │ ├── girl-2164409_1920.txt │ │ ├── girl-2177360_1920.txt │ │ ├── girl-2720476_1920.txt │ │ ├── girl-2999078_1920.txt │ │ ├── girl-4024238_1920.txt │ │ ├── girl-4024240_1920.txt │ │ ├── girl-4024244_1920.txt │ │ ├── male-467711_1920.txt │ │ ├── mid-autumn-2752710_1920.txt │ │ ├── model-2134460_1920.txt │ │ ├── model-2911329_1920.txt │ │ ├── model-2911332_1920.txt │ │ ├── own-2553537_1280.txt │ │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.txt │ │ ├── pinky-2727846_1920.txt │ │ ├── pinky-2727874_1920.txt │ │ ├── portrait-2164027_1920.txt │ │ ├── portrait-2554431_1920.txt │ │ ├── portrait-laughing-businesswoman-long-dark-450w-235195312.txt │ │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.txt │ │ ├── timg.txt │ │ ├── w_sexy_gr.txt │ │ └── young-507297_1920.txt └── mask │ └── ALL │ ├── WechatIMG253_2.png │ ├── WechatIMG256.png │ ├── WechatIMG260.png │ ├── WechatIMG261.png │ ├── WechatIMG262.png │ ├── WechatIMG263.png │ ├── brazil-1368806_1920.png │ ├── fconrad_Portrait_060414a.png │ ├── garden-2768329_1920.png │ ├── girl-2099354_1920.png │ ├── girl-2099357_1920.png │ ├── girl-2122909_1920.png │ ├── girl-2122927_1280.png │ ├── girl-2128294_1920.png │ ├── girl-2132171_1920.png │ ├── girl-2143709_1920.png │ ├── girl-2164409_1920.png │ ├── girl-2177360_1920.png │ ├── girl-2720476_1920.png │ ├── girl-2999078_1920.png │ ├── girl-4024238_1920.png │ ├── girl-4024240_1920.png │ ├── girl-4024244_1920.png │ ├── male-467711_1920.png │ ├── mid-autumn-2752710_1920.png │ ├── model-2134460_1920.png │ ├── model-2911329_1920.png │ ├── model-2911332_1920.png │ ├── 
own-2553537_1280.png │ ├── passport-picture-businesswoman-brown-hair-450w-250775908.png │ ├── pinky-2727846_1920.png │ ├── pinky-2727874_1920.png │ ├── portrait-2164027_1920.png │ ├── portrait-2554431_1920.png │ ├── portrait-laughing-businesswoman-long-dark-450w-235195312.png │ ├── portrait-smiling-woman-blue-shirt-450w-218101459.png │ ├── timg.png │ ├── w_sexy_gr.png │ └── young-507297_1920.png ├── docs └── tips.md ├── imgs ├── architecture.png └── samples │ ├── img_1673.png │ ├── img_1673_fake_B.png │ ├── img_1682.png │ ├── img_1682_fake_B.png │ ├── img_1696.png │ ├── img_1696_fake_B.png │ ├── img_1701.png │ ├── img_1701_fake_B.png │ ├── img_1794.png │ └── img_1794_fake_B.png ├── models ├── __init__.py ├── apdrawing_gan_model.py ├── base_model.py ├── networks.py └── test_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── preprocess ├── combine_A_and_B.py ├── example │ ├── img_1701.jpg │ ├── img_1701_aligned.png │ ├── img_1701_aligned.txt │ ├── img_1701_aligned_bgmask.png │ └── img_1701_facial5point.mat ├── face_align_512.m └── readme.md ├── readme.md ├── requirements.txt ├── script ├── test.sh ├── test_pretrained.sh ├── test_single.sh ├── train.sh └── train_noinit.sh ├── test.py ├── train.py └── util ├── __init__.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | debug* 3 | checkpoints/ 4 | results/ 5 | build/ 6 | dist/ 7 | torch.egg-info/ 8 | */**/__pycache__ 9 | torch/version.py 10 | torch/csrc/generic/TensorMethods.cpp 11 | torch/lib/*.so* 12 | torch/lib/*.dylib* 13 | torch/lib/*.h 14 | torch/lib/build 15 | torch/lib/tmp_install 16 | torch/lib/include 17 | torch/lib/torch_shm_manager 18 | torch/csrc/cudnn/cuDNN.cpp 19 | torch/csrc/nn/THNN.cwrap 20 | torch/csrc/nn/THNN.cpp 21 | torch/csrc/nn/THCUNN.cwrap 22 | torch/csrc/nn/THCUNN.cpp 23 | torch/csrc/nn/THNN_generic.cwrap 24 | torch/csrc/nn/THNN_generic.cpp 25 | torch/csrc/nn/THNN_generic.h 26 | docs/src/**/* 27 | test/data/legacy_modules.t7 28 | test/data/gpu_tensors.pt 29 | test/htmlcov 30 | test/.coverage 31 | */*.pyc 32 | */**/*.pyc 33 | */**/**/*.pyc 34 | */**/**/**/*.pyc 35 | */**/**/**/**/*.pyc 36 | */*.so* 37 | */**/*.so* 38 | */**/*.dylib* 39 | test/data/legacy_serialized.pt 40 | *~ 41 | .idea 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2019, Ran Yi. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from data.base_dataset import BaseDataset
5 |
6 |
7 | def find_dataset_using_name(dataset_name):
8 |     # Given the option --dataset_mode [datasetname],
9 |     # the file "data/datasetname_dataset.py"
10 |     # will be imported.
11 |     dataset_filename = "data." + dataset_name + "_dataset"
12 |     datasetlib = importlib.import_module(dataset_filename)
13 |
14 |     # In the file, the class called DatasetNameDataset() will
15 |     # be instantiated. It has to be a subclass of BaseDataset,
16 |     # and it is case-insensitive.
17 |     dataset = None
18 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
19 |     for name, cls in datasetlib.__dict__.items():
20 |         if name.lower() == target_dataset_name.lower() \
21 |            and issubclass(cls, BaseDataset):
22 |             dataset = cls
23 |
24 |     if dataset is None:
25 |         print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
26 |         exit(0)
27 |
28 |     return dataset
29 |
30 |
31 | def get_option_setter(dataset_name):
32 |     dataset_class = find_dataset_using_name(dataset_name)
33 |     return dataset_class.modify_commandline_options
34 |
35 |
36 | def create_dataset(opt):
37 |     dataset = find_dataset_using_name(opt.dataset_mode)
38 |     instance = dataset()
39 |     instance.initialize(opt)
40 |     print("dataset [%s] was created" % (instance.name()))
41 |     return instance
42 |
43 |
44 | def CreateDataLoader(opt):
45 |     data_loader = CustomDatasetDataLoader()
46 |     data_loader.initialize(opt)
47 |     return data_loader
48 |
49 |
50 | # Wrapper class of Dataset class that performs
51 | # multi-threaded data loading
52 | class CustomDatasetDataLoader(BaseDataLoader):
53 |     def name(self):
54 |         return 'CustomDatasetDataLoader'
55 |
56 |     def initialize(self, opt):
57 |         BaseDataLoader.initialize(self, opt)
58 |         self.dataset = create_dataset(opt)
59 |         self.dataloader = torch.utils.data.DataLoader(
60 |             self.dataset,
61 |             batch_size=opt.batch_size,
62 |             shuffle=not opt.serial_batches,#in training, serial_batches by default is false, shuffle=true
63 |             num_workers=int(opt.num_threads))
64 |
65 |     def load_data(self):
66 |         return self
67 |
68 |     def __len__(self):
69 |         return min(len(self.dataset), self.opt.max_dataset_size)
70 |
71 |     def __iter__(self):
72 |         for i, data in enumerate(self.dataloader):
73 |             if i * self.opt.batch_size >= self.opt.max_dataset_size:
74 |                 break
75 |             yield data
76 |
--------------------------------------------------------------------------------
/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from data.base_dataset import BaseDataset
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 | import numpy as np
9 | import cv2
10 | import csv
11 |
12 | def getfeats(featpath):
13 |     trans_points = np.empty([5,2],dtype=np.int64)
14 |     with open(featpath, 'r') as csvfile:
15 |         reader = csv.reader(csvfile, delimiter=' ')
16 |         for ind,row in enumerate(reader):
17 |             trans_points[ind,:] = row
18 |     return trans_points
19 |
20 | def tocv2(ts):
21 |     img = (ts.numpy()/2+0.5)*255
22 |     img = img.astype('uint8')
23 |     img = np.transpose(img,(1,2,0))
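# (editorial note, not in the original file) tocv2() maps a normalized [-1, 1]
# CHW tensor back to an HxWx3 uint8 image; the next statement reverses the
# channel order to BGR so that the OpenCV calls in dt() below can consume it.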
24 | img = img[:,:,::-1]#rgb->bgr 25 | return img 26 | 27 | def dt(img): 28 | if(img.shape[2]==3): 29 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 30 | #convert to BW 31 | ret1,thresh1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY) 32 | ret2,thresh2 = cv2.threshold(img,127,255,cv2.THRESH_BINARY_INV) 33 | dt1 = cv2.distanceTransform(thresh1,cv2.DIST_L2,5) 34 | dt2 = cv2.distanceTransform(thresh2,cv2.DIST_L2,5) 35 | dt1 = dt1/dt1.max()#->[0,1] 36 | dt2 = dt2/dt2.max() 37 | return dt1, dt2 38 | 39 | def getSoft(size,xb,yb,boundwidth=5.0): 40 | xarray = np.tile(np.arange(0,size[1]),(size[0],1)) 41 | yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose() 42 | cxdists = [] 43 | cydists = [] 44 | for i in range(len(xb)): 45 | xba = np.tile(xb[i],(size[1],1)).transpose() 46 | yba = np.tile(yb[i],(size[0],1)) 47 | cxdists.append(np.abs(xarray-xba)) 48 | cydists.append(np.abs(yarray-yba)) 49 | xdist = np.minimum.reduce(cxdists) 50 | ydist = np.minimum.reduce(cydists) 51 | manhdist = np.minimum.reduce([xdist,ydist]) 52 | im = (manhdist+1) / (boundwidth+1) * 1.0 53 | im[im>=1.0] = 1.0 54 | return im 55 | 56 | class AlignedDataset(BaseDataset): 57 | @staticmethod 58 | def modify_commandline_options(parser, is_train): 59 | return parser 60 | 61 | def initialize(self, opt): 62 | self.opt = opt 63 | self.root = opt.dataroot 64 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 65 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 66 | assert(opt.resize_or_crop == 'resize_and_crop') 67 | 68 | def __getitem__(self, index): 69 | AB_path = self.AB_paths[index] 70 | AB = Image.open(AB_path).convert('RGB') 71 | w, h = AB.size 72 | w2 = int(w / 2) 73 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 74 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 75 | A = transforms.ToTensor()(A) 76 | B = transforms.ToTensor()(B) 77 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 78 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 79 | 80 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]#C,H,W 81 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 82 | 83 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) 84 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) 85 | 86 | if self.opt.which_direction == 'BtoA': 87 | input_nc = self.opt.output_nc 88 | output_nc = self.opt.input_nc 89 | else: 90 | input_nc = self.opt.input_nc 91 | output_nc = self.opt.output_nc 92 | 93 | flipped = False 94 | if (not self.opt.no_flip) and random.random() < 0.5: 95 | flipped = True 96 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 97 | idx = torch.LongTensor(idx) 98 | A = A.index_select(2, idx) 99 | B = B.index_select(2, idx) 100 | 101 | if input_nc == 1: # RGB to gray 102 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 103 | A = tmp.unsqueeze(0) 104 | 105 | if output_nc == 1: # RGB to gray 106 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 107 | B = tmp.unsqueeze(0) 108 | 109 | item = {'A': A, 'B': B, 110 | 'A_paths': AB_path, 'B_paths': AB_path} 111 | 112 | if self.opt.use_local: 113 | regions = ['eyel','eyer','nose','mouth'] 114 | basen = os.path.basename(AB_path)[:-4]+'.txt' 115 | featdir = self.opt.lm_dir 116 | featpath = os.path.join(featdir,basen) 117 | feats = getfeats(featpath) 118 | if flipped: 119 | for i in range(5): 120 | feats[i,0] = self.opt.fineSize - feats[i,0] - 1 121 | tmp = [feats[0,0],feats[0,1]] 122 | feats[0,:] = [feats[1,0],feats[1,1]] 123 | feats[1,:] = tmp 124 | mouth_x = int((feats[3,0]+feats[4,0])/2.0) 125 | mouth_y = int((feats[3,1]+feats[4,1])/2.0) 126 | ratio = self.opt.fineSize / 256 127 | EYE_H = self.opt.EYE_H * ratio 128 | EYE_W = self.opt.EYE_W * ratio 129 | NOSE_H = self.opt.NOSE_H * ratio 130 | NOSE_W = self.opt.NOSE_W * ratio 131 | MOUTH_H = self.opt.MOUTH_H * ratio 132 | MOUTH_W = self.opt.MOUTH_W * ratio 133 | center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]]) 134 | item['center'] = center 135 | rhs = [EYE_H,EYE_H,NOSE_H,MOUTH_H] 136 | rws = [EYE_W,EYE_W,NOSE_W,MOUTH_W] 137 | if self.opt.soft_border: 138 | soft_border_mask4 = [] 139 | for i in range(4): 140 | xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)] 141 | yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)] 142 | soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb) 143 | soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0)) 144 | item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i] 145 | for i in range(4): 146 | item[regions[i]+'_A'] = A[:,int(center[i,1]-rhs[i]/2):int(center[i,1]+rhs[i]/2),int(center[i,0]-rws[i]/2):int(center[i,0]+rws[i]/2)] 147 | item[regions[i]+'_B'] = B[:,int(center[i,1]-rhs[i]/2):int(center[i,1]+rhs[i]/2),int(center[i,0]-rws[i]/2):int(center[i,0]+rws[i]/2)] 148 | if self.opt.soft_border: 149 | item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(input_nc/output_nc,1,1) 150 | item[regions[i]+'_B'] = item[regions[i]+'_B'] * soft_border_mask4[i] 151 | 152 | mask = torch.ones(B.shape) # mask out eyes, nose, mouth 153 | for i in range(4): 154 | mask[:,int(center[i,1]-rhs[i]/2):int(center[i,1]+rhs[i]/2),int(center[i,0]-rws[i]/2):int(center[i,0]+rws[i]/2)] = 0 155 | if self.opt.soft_border: 156 | imgsize = self.opt.fineSize 157 | maskn = mask[0].numpy() 158 | masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])] 159 | masks[0][1:] = maskn[:-1] 160 | masks[1][:-1] = maskn[1:] 161 | masks[2][:,1:] = maskn[:,:-1] 162 | masks[3][:,:-1] = maskn[:,1:] 163 | masks2 = [maskn-e for e in masks] 164 | bound = np.minimum.reduce(masks2) 165 | bound = -bound 166 | xb = [] 167 | yb = [] 168 | for i in range(4): 169 | xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1] 170 | ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1] 171 | for j in range(2): 172 | maskx = bound[:,xbi[j]] 173 | masky = bound[ybi[j],:] 174 | xb += [(1-maskx)*10000 + maskx*xbi[j]] 175 | yb += [(1-masky)*10000 + masky*ybi[j]] 176 | soft = 1-getSoft([imgsize,imgsize],xb,yb) 177 | soft = torch.Tensor(soft).unsqueeze(0) 178 | mask = (torch.ones(mask.shape)-mask)*soft + mask 179 | 180 | bgdir = self.opt.bg_dir 181 | bgpath = os.path.join(bgdir,basen[:-4]+'.png') 182 | im_bg = Image.open(bgpath) 183 | mask2 = transforms.ToTensor()(im_bg) # mask out background 184 | if flipped: 185 | mask2 = mask2.index_select(2, idx) 186 | mask2 = (mask2 >= 
0.5).float() 187 | 188 | hair_A = (A/2+0.5) * mask.repeat(input_nc//output_nc,1,1) * mask2.repeat(input_nc//output_nc,1,1) * 2 - 1 189 | hair_B = (B/2+0.5) * mask * mask2 * 2 - 1 190 | bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(input_nc//output_nc,1,1) * 2 - 1 191 | bg_B = (B/2+0.5) * (torch.ones(mask2.shape)-mask2) * 2 - 1 192 | item['hair_A'] = hair_A 193 | item['hair_B'] = hair_B 194 | item['bg_A'] = bg_A 195 | item['bg_B'] = bg_B 196 | item['mask'] = mask 197 | item['mask2'] = mask2 198 | 199 | if self.opt.isTrain: 200 | if self.opt.which_direction == 'AtoB': 201 | img = tocv2(B) 202 | else: 203 | img = tocv2(A) 204 | dt1, dt2 = dt(img) 205 | dt1 = torch.from_numpy(dt1) 206 | dt2 = torch.from_numpy(dt2) 207 | dt1 = dt1.unsqueeze(0) 208 | dt2 = dt2.unsqueeze(0) 209 | item['dt1gt'] = dt1 210 | item['dt2gt'] = dt2 211 | 212 | return item 213 | 214 | def __len__(self): 215 | return len(self.AB_paths) 216 | 217 | def name(self): 218 | return 'AlignedDataset' 219 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.fineSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' 
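# (editorial note, not in the original file) the branches above only choose the
# geometric preprocessing for the selected --resize_or_crop mode; the steps
# appended next add the optional horizontal flip and then ToTensor + Normalize
# with mean/std (0.5, 0.5, 0.5), mapping images into [-1, 1] to match the
# normalization hard-coded in AlignedDataset.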
% opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | # just modify the width and height to be multiple of 4 54 | def __adjust(img): 55 | ow, oh = img.size 56 | 57 | # the size needs to be a multiple of this number, 58 | # because going through generator network may change img size 59 | # and eventually cause size mismatch error 60 | mult = 4 61 | if ow % mult == 0 and oh % mult == 0: 62 | return img 63 | w = (ow - 1) // mult 64 | w = (w + 1) * mult 65 | h = (oh - 1) // mult 66 | h = (h + 1) * mult 67 | 68 | if ow != w or oh != h: 69 | __print_size_warning(ow, oh, w, h) 70 | 71 | return img.resize((w, h), Image.BICUBIC) 72 | 73 | 74 | def __scale_width(img, target_width): 75 | ow, oh = img.size 76 | 77 | # the size needs to be a multiple of this number, 78 | # because going through generator network may change img size 79 | # and eventually cause size mismatch error 80 | mult = 4 81 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 82 | if (ow == target_width and oh % mult == 0): 83 | return img 84 | w = target_width 85 | target_height = int(target_width * oh / ow) 86 | m = (target_height - 1) // mult 87 | h = (m + 1) * mult 88 | 89 | if target_height != h: 90 | __print_size_warning(target_width, target_height, w, h) 91 | 92 | return img.resize((w, h), Image.BICUBIC) 93 | 94 | 95 | def __print_size_warning(ow, oh, w, h): 96 | if not hasattr(__print_size_warning, 'has_printed'): 97 | print("The image size needs to be a multiple of 4. " 98 | "The loaded image size was (%d, %d), so it was adjusted to " 99 | "(%d, %d). 
This adjustment will be done to all images " 100 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 101 | __print_size_warning.has_printed = True 102 | 103 | 104 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import torch 3 | import torchvision.transforms as transforms 4 | from data.base_dataset import BaseDataset, get_transform 5 | from data.image_folder import make_dataset 6 | from PIL import Image 7 | import numpy as np 8 | import csv 9 | 10 | def getfeats(featpath): 11 | trans_points = np.empty([5,2],dtype=np.int64) 12 | with open(featpath, 'r') as csvfile: 13 | reader = csv.reader(csvfile, delimiter=' ') 14 | for ind,row in enumerate(reader): 15 | trans_points[ind,:] = row 16 | return trans_points 17 | 18 | def getSoft(size,xb,yb,boundwidth=5.0): 19 | xarray = np.tile(np.arange(0,size[1]),(size[0],1)) 20 | yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose() 21 | cxdists = [] 22 | cydists = [] 23 | for i in range(len(xb)): 24 | xba = np.tile(xb[i],(size[1],1)).transpose() 25 | yba = np.tile(yb[i],(size[0],1)) 26 | cxdists.append(np.abs(xarray-xba)) 27 | cydists.append(np.abs(yarray-yba)) 28 | xdist = np.minimum.reduce(cxdists) 29 | ydist = np.minimum.reduce(cydists) 30 | manhdist = np.minimum.reduce([xdist,ydist]) 31 | im = 
(manhdist+1) / (boundwidth+1) * 1.0 32 | im[im>=1.0] = 1.0 33 | return im 34 | 35 | class SingleDataset(BaseDataset): 36 | @staticmethod 37 | def modify_commandline_options(parser, is_train): 38 | return parser 39 | 40 | def initialize(self, opt): 41 | self.opt = opt 42 | self.root = opt.dataroot 43 | self.dir_A = os.path.join(opt.dataroot) 44 | 45 | self.A_paths = make_dataset(self.dir_A) 46 | 47 | self.A_paths = sorted(self.A_paths) 48 | 49 | self.transform = get_transform(opt) 50 | 51 | def __getitem__(self, index): 52 | A_path = self.A_paths[index] 53 | A_img = Image.open(A_path).convert('RGB') 54 | A = self.transform(A_img) 55 | if self.opt.which_direction == 'BtoA': 56 | input_nc = self.opt.output_nc 57 | output_nc = self.opt.input_nc 58 | else: 59 | input_nc = self.opt.input_nc 60 | output_nc = self.opt.output_nc 61 | 62 | if input_nc == 1: # RGB to gray 63 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 64 | A = tmp.unsqueeze(0) 65 | 66 | item = {'A': A, 'A_paths': A_path} 67 | 68 | if self.opt.use_local: 69 | regions = ['eyel','eyer','nose','mouth'] 70 | basen = os.path.basename(A_path)[:-4]+'.txt' 71 | featdir = self.opt.lm_dir 72 | featpath = os.path.join(featdir,basen) 73 | feats = getfeats(featpath) 74 | mouth_x = int((feats[3,0]+feats[4,0])/2.0) 75 | mouth_y = int((feats[3,1]+feats[4,1])/2.0) 76 | ratio = self.opt.fineSize / 256 77 | EYE_H = self.opt.EYE_H * ratio 78 | EYE_W = self.opt.EYE_W * ratio 79 | NOSE_H = self.opt.NOSE_H * ratio 80 | NOSE_W = self.opt.NOSE_W * ratio 81 | MOUTH_H = self.opt.MOUTH_H * ratio 82 | MOUTH_W = self.opt.MOUTH_W * ratio 83 | center = torch.tensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]]) 84 | item['center'] = center 85 | rhs = [EYE_H,EYE_H,NOSE_H,MOUTH_H] 86 | rws = [EYE_W,EYE_W,NOSE_W,MOUTH_W] 87 | if self.opt.soft_border: 88 | soft_border_mask4 = [] 89 | for i in range(4): 90 | xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)] 91 | yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)] 92 | soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb) 93 | soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0)) 94 | item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i] 95 | for i in range(4): 96 | item[regions[i]+'_A'] = A[:,int(center[i,1]-rhs[i]/2):int(center[i,1]+rhs[i]/2),int(center[i,0]-rws[i]/2):int(center[i,0]+rws[i]/2)] 97 | if self.opt.soft_border: 98 | item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(input_nc/output_nc,1,1) 99 | 100 | mask = torch.ones([output_nc,A.shape[1],A.shape[2]]) # mask out eyes, nose, mouth 101 | for i in range(4): 102 | mask[:,int(center[i,1]-rhs[i]/2):int(center[i,1]+rhs[i]/2),int(center[i,0]-rws[i]/2):int(center[i,0]+rws[i]/2)] = 0 103 | if self.opt.soft_border: 104 | imgsize = self.opt.fineSize 105 | maskn = mask[0].numpy() 106 | masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])] 107 | masks[0][1:] = maskn[:-1] 108 | masks[1][:-1] = maskn[1:] 109 | masks[2][:,1:] = maskn[:,:-1] 110 | masks[3][:,:-1] = maskn[:,1:] 111 | masks2 = [maskn-e for e in masks] 112 | bound = np.minimum.reduce(masks2) 113 | bound = -bound 114 | xb = [] 115 | yb = [] 116 | for i in range(4): 117 | xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1] 118 | ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1] 119 | for j in range(2): 120 | maskx = bound[:,xbi[j]] 121 | masky = bound[ybi[j],:] 122 | xb += [(1-maskx)*10000 + maskx*xbi[j]] 
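# (editorial note, not in the original file) the (1-maskx)*10000 term pushes
# box-edge pixels that do not actually border the cut-out region effectively to
# infinity, so only genuine boundary coordinates influence the minimum distance
# computed inside getSoft(); the resulting ramp feathers the mask edge over
# roughly boundwidth (5) pixels before being blended back into mask below.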
123 | yb += [(1-masky)*10000 + masky*ybi[j]] 124 | soft = 1-getSoft([imgsize,imgsize],xb,yb) 125 | soft = torch.Tensor(soft).unsqueeze(0) 126 | mask = (torch.ones(mask.shape)-mask)*soft + mask 127 | 128 | bgdir = self.opt.bg_dir 129 | bgpath = os.path.join(bgdir,basen[:-4]+'.png') 130 | im_bg = Image.open(bgpath) 131 | mask2 = transforms.ToTensor()(im_bg) # mask out background 132 | mask2 = (mask2 >= 0.5).float() 133 | # hair_A = (A/2+0.5) * mask.repeat(input_nc/output_nc,1,1) * mask2.repeat(input_nc/output_nc,1,1) * 2 - 1 134 | # bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(input_nc/output_nc,1,1) * 2 - 1 135 | hair_A = (A/2+0.5) * mask.repeat(3,1,1) * mask2.repeat(3,1,1) * 2 - 1 136 | bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(3,1,1) * 2 - 1 137 | item['hair_A'] = hair_A 138 | item['bg_A'] = bg_A 139 | item['mask'] = mask 140 | item['mask2'] = mask2 141 | 142 | return item 143 | 144 | def __len__(self): 145 | return len(self.A_paths) 146 | 147 | def name(self): 148 | return 'SingleImageDataset' 149 | -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG253_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG253_2.png -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG256.png -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG260.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG260.png -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG261.png -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG262.png -------------------------------------------------------------------------------- /dataset/data/test_single/WechatIMG263.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/WechatIMG263.png -------------------------------------------------------------------------------- /dataset/data/test_single/brazil-1368806_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/brazil-1368806_1920.png 
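A minimal usage sketch (not part of the repository) of how the data modules above are wired together on the single-image test path. The option fields on opt are assumptions based on how they are read in data/__init__.py and data/single_dataset.py; the full option set lives in options/base_options.py and options/test_options.py, which are not reproduced in this dump.

from data import CreateDataLoader

def run_single_test(opt):
    # opt.dataset_mode == 'single' makes find_dataset_using_name() import
    # data/single_dataset.py and pick up the SingleDataset class shown above.
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()   # CustomDatasetDataLoader is itself iterable
    for i, data in enumerate(dataset):
        photo = data['A']               # normalized photo tensor, shape [batch, channels, H, W]
        paths = data['A_paths']         # source file paths, handy for naming outputs
        # when opt.use_local is set, data also carries the eyel/eyer/nose/mouth
        # crops, hair_A, bg_A and the masks; these are fed to the generator here.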
--------------------------------------------------------------------------------
/dataset/data/test_single/fconrad_Portrait_060414a.png ... portrait-laughing-businesswoman-long-dark-450w-235195312.png:
--------------------------------------------------------------------------------
[binary test photographs, fconrad_Portrait_060414a.png through portrait-laughing-businesswoman-long-dark-450w-235195312.png; as with the neighbouring entries, each file resolves to https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/<filename>]
--------------------------------------------------------------------------------
/dataset/data/test_single/portrait-smiling-woman-blue-shirt-450w-218101459.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/portrait-smiling-woman-blue-shirt-450w-218101459.png -------------------------------------------------------------------------------- /dataset/data/test_single/timg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/timg.png -------------------------------------------------------------------------------- /dataset/data/test_single/w_sexy_gr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/w_sexy_gr.png -------------------------------------------------------------------------------- /dataset/data/test_single/young-507297_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/data/test_single/young-507297_1920.png -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG253_2.txt: -------------------------------------------------------------------------------- 1 | 197 248 2 | 316 247 3 | 254 324 4 | 204 363 5 | 310 365 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG256.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 313 247 3 | 255 317 4 | 207 369 5 | 306 368 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG260.txt: -------------------------------------------------------------------------------- 1 | 195 247 2 | 316 246 3 | 259 314 4 | 212 371 5 | 298 370 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG261.txt: -------------------------------------------------------------------------------- 1 | 196 246 2 | 310 247 3 | 267 316 4 | 204 370 5 | 302 368 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG262.txt: -------------------------------------------------------------------------------- 1 | 191 245 2 | 319 246 3 | 258 330 4 | 208 363 5 | 303 363 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/WechatIMG263.txt: -------------------------------------------------------------------------------- 1 | 199 242 2 | 318 244 3 | 241 329 4 | 217 364 5 | 306 368 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/brazil-1368806_1920.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 313 246 3 | 254 316 4 | 210 369 5 | 303 370 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/fconrad_Portrait_060414a.txt: -------------------------------------------------------------------------------- 1 | 200 250 2 | 309 251 3 | 259 307 4 | 199 368 5 | 312 371 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/garden-2768329_1920.txt: 
-------------------------------------------------------------------------------- 1 | 197 249 2 | 311 248 3 | 269 326 4 | 192 364 5 | 311 360 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2099354_1920.txt: -------------------------------------------------------------------------------- 1 | 200 247 2 | 316 244 3 | 253 313 4 | 213 372 5 | 299 371 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2099357_1920.txt: -------------------------------------------------------------------------------- 1 | 200 243 2 | 312 245 3 | 254 315 4 | 214 372 5 | 299 372 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2122909_1920.txt: -------------------------------------------------------------------------------- 1 | 197 242 2 | 310 243 3 | 267 317 4 | 213 373 5 | 293 372 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2122927_1280.txt: -------------------------------------------------------------------------------- 1 | 202 245 2 | 315 247 3 | 241 316 4 | 213 368 5 | 309 371 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2128294_1920.txt: -------------------------------------------------------------------------------- 1 | 195 246 2 | 311 247 3 | 266 320 4 | 204 367 5 | 303 368 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2132171_1920.txt: -------------------------------------------------------------------------------- 1 | 195 247 2 | 309 247 3 | 275 321 4 | 198 367 5 | 304 365 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2143709_1920.txt: -------------------------------------------------------------------------------- 1 | 201 245 2 | 312 244 3 | 255 308 4 | 215 375 5 | 296 375 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2164409_1920.txt: -------------------------------------------------------------------------------- 1 | 199 245 2 | 313 246 3 | 254 326 4 | 205 364 5 | 309 366 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2177360_1920.txt: -------------------------------------------------------------------------------- 1 | 197 248 2 | 316 249 3 | 253 321 4 | 201 365 5 | 311 363 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2720476_1920.txt: -------------------------------------------------------------------------------- 1 | 194 249 2 | 317 247 3 | 260 326 4 | 201 363 5 | 308 362 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-2999078_1920.txt: -------------------------------------------------------------------------------- 1 | 198 249 2 | 315 251 3 | 250 326 4 | 196 360 5 | 320 361 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-4024238_1920.txt: -------------------------------------------------------------------------------- 1 | 197 251 2 | 310 249 3 | 268 321 4 | 192 363 5 | 313 363 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-4024240_1920.txt: 
-------------------------------------------------------------------------------- 1 | 197 249 2 | 310 248 3 | 270 326 4 | 192 364 5 | 311 360 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/girl-4024244_1920.txt: -------------------------------------------------------------------------------- 1 | 197 250 2 | 313 249 3 | 261 327 4 | 193 361 5 | 315 361 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/male-467711_1920.txt: -------------------------------------------------------------------------------- 1 | 201 249 2 | 313 247 3 | 253 312 4 | 206 369 5 | 308 370 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/mid-autumn-2752710_1920.txt: -------------------------------------------------------------------------------- 1 | 201 247 2 | 316 250 3 | 243 325 4 | 201 362 5 | 319 363 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/model-2134460_1920.txt: -------------------------------------------------------------------------------- 1 | 201 246 2 | 317 246 3 | 242 311 4 | 217 372 5 | 303 372 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/model-2911329_1920.txt: -------------------------------------------------------------------------------- 1 | 199 246 2 | 309 247 3 | 263 316 4 | 203 368 5 | 306 370 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/model-2911332_1920.txt: -------------------------------------------------------------------------------- 1 | 200 244 2 | 310 244 3 | 260 317 4 | 210 369 5 | 301 372 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/own-2553537_1280.txt: -------------------------------------------------------------------------------- 1 | 196 248 2 | 314 250 3 | 257 319 4 | 201 364 5 | 312 366 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/passport-picture-businesswoman-brown-hair-450w-250775908.txt: -------------------------------------------------------------------------------- 1 | 196 251 2 | 315 253 3 | 256 320 4 | 194 361 5 | 319 361 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/pinky-2727846_1920.txt: -------------------------------------------------------------------------------- 1 | 196 249 2 | 317 250 3 | 254 331 4 | 196 359 5 | 317 359 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/pinky-2727874_1920.txt: -------------------------------------------------------------------------------- 1 | 206 244 2 | 316 247 3 | 231 326 4 | 211 365 5 | 316 366 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/portrait-2164027_1920.txt: -------------------------------------------------------------------------------- 1 | 198 250 2 | 315 250 3 | 252 325 4 | 197 361 5 | 318 362 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/portrait-2554431_1920.txt: -------------------------------------------------------------------------------- 1 | 198 250 2 | 315 250 3 | 256 318 4 | 198 366 5 | 313 362 6 | -------------------------------------------------------------------------------- 
/dataset/landmark/ALL/portrait-laughing-businesswoman-long-dark-450w-235195312.txt: -------------------------------------------------------------------------------- 1 | 198 251 2 | 314 251 3 | 256 319 4 | 196 364 5 | 316 363 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/portrait-smiling-woman-blue-shirt-450w-218101459.txt: -------------------------------------------------------------------------------- 1 | 195 245 2 | 315 247 3 | 258 331 4 | 202 362 5 | 309 362 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/timg.txt: -------------------------------------------------------------------------------- 1 | 205 240 2 | 313 241 3 | 240 329 4 | 217 367 5 | 305 370 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/w_sexy_gr.txt: -------------------------------------------------------------------------------- 1 | 200 245 2 | 312 245 3 | 256 316 4 | 210 371 5 | 302 371 6 | -------------------------------------------------------------------------------- /dataset/landmark/ALL/young-507297_1920.txt: -------------------------------------------------------------------------------- 1 | 205 249 2 | 314 246 3 | 244 310 4 | 208 372 5 | 308 370 6 | -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG253_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG253_2.png -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG256.png -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG260.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG260.png -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG261.png -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG262.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG262.png -------------------------------------------------------------------------------- /dataset/mask/ALL/WechatIMG263.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/WechatIMG263.png -------------------------------------------------------------------------------- /dataset/mask/ALL/brazil-1368806_1920.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/brazil-1368806_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/fconrad_Portrait_060414a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/fconrad_Portrait_060414a.png -------------------------------------------------------------------------------- /dataset/mask/ALL/garden-2768329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/garden-2768329_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2099354_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2099354_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2099357_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2099357_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2122909_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2122909_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2122927_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2122927_1280.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2128294_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2128294_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2132171_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2132171_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2143709_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2143709_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2164409_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2164409_1920.png 
-------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2177360_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2177360_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2720476_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2720476_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-2999078_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-2999078_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-4024238_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-4024238_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-4024240_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-4024240_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/girl-4024244_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/girl-4024244_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/male-467711_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/male-467711_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/mid-autumn-2752710_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/mid-autumn-2752710_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/model-2134460_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/model-2134460_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/model-2911329_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/model-2911329_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/model-2911332_1920.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/model-2911332_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/own-2553537_1280.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/own-2553537_1280.png -------------------------------------------------------------------------------- /dataset/mask/ALL/passport-picture-businesswoman-brown-hair-450w-250775908.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/passport-picture-businesswoman-brown-hair-450w-250775908.png -------------------------------------------------------------------------------- /dataset/mask/ALL/pinky-2727846_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/pinky-2727846_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/pinky-2727874_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/pinky-2727874_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/portrait-2164027_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/portrait-2164027_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/portrait-2554431_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/portrait-2554431_1920.png -------------------------------------------------------------------------------- /dataset/mask/ALL/portrait-laughing-businesswoman-long-dark-450w-235195312.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/portrait-laughing-businesswoman-long-dark-450w-235195312.png -------------------------------------------------------------------------------- /dataset/mask/ALL/portrait-smiling-woman-blue-shirt-450w-218101459.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/portrait-smiling-woman-blue-shirt-450w-218101459.png -------------------------------------------------------------------------------- /dataset/mask/ALL/timg.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/timg.png 
-------------------------------------------------------------------------------- /dataset/mask/ALL/w_sexy_gr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/w_sexy_gr.png -------------------------------------------------------------------------------- /dataset/mask/ALL/young-507297_1920.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/dataset/mask/ALL/young-507297_1920.png -------------------------------------------------------------------------------- /docs/tips.md: -------------------------------------------------------------------------------- 1 | ## Training/test Tips 2 | - Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. The default values of these options are sometimes adjusted in the model files. 3 | 4 | - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batch_size 32`) to benefit from multiple GPUs. 5 | 6 | - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and start a server with the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`. 7 | 8 | - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
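- Example: an illustrative sketch of resuming training from a saved checkpoint on CPU without a `visdom` server. The epoch values below are placeholders, and required options such as the data root and experiment name are omitted; see `script/train.sh` for the flags actually used.

```bash
# resume from the checkpoint selected by --which_epoch, start counting epochs at 20,
# run on CPU (--gpu_ids -1), and skip visdom (--display_id -1); HTML results under
# the checkpoints directory are still written unless --no_html is also passed
python train.py --continue_train --which_epoch latest --epoch_count 20 --gpu_ids -1 --display_id -1
```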
9 | -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/architecture.png -------------------------------------------------------------------------------- /imgs/samples/img_1673.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1673.png -------------------------------------------------------------------------------- /imgs/samples/img_1673_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1673_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1682.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1682.png -------------------------------------------------------------------------------- /imgs/samples/img_1682_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1682_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1696.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1696.png -------------------------------------------------------------------------------- /imgs/samples/img_1696_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1696_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1701.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1701.png -------------------------------------------------------------------------------- /imgs/samples/img_1701_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1701_fake_B.png -------------------------------------------------------------------------------- /imgs/samples/img_1794.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1794.png -------------------------------------------------------------------------------- /imgs/samples/img_1794_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/imgs/samples/img_1794_fake_B.png 
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /models/apdrawing_gan_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util.image_pool import ImagePool 3 | from .base_model import BaseModel 4 | from . import networks 5 | 6 | 7 | class APDrawingGANModel(BaseModel): 8 | def name(self): 9 | return 'APDrawingGANModel' 10 | 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train=True): 13 | 14 | # changing the default values 15 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False 16 | parser.set_defaults(dataset_mode='aligned') 17 | 18 | return parser 19 | 20 | def initialize(self, opt): 21 | BaseModel.initialize(self, opt) 22 | self.isTrain = opt.isTrain 23 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 24 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 25 | if self.isTrain and self.opt.no_l1_loss: 26 | self.loss_names = ['G_GAN', 'D_real', 'D_fake'] 27 | if self.isTrain and self.opt.use_local and not self.opt.no_G_local_loss: 28 | self.loss_names.append('G_local') 29 | if self.isTrain and self.opt.discriminator_local: 30 | self.loss_names.append('D_real_local') 31 | self.loss_names.append('D_fake_local') 32 | self.loss_names.append('G_GAN_local') 33 | if self.isTrain: 34 | self.loss_names.append('G_chamfer') 35 | self.loss_names.append('G_chamfer2') 36 | self.loss_names.append('G') 37 | print('loss_names', self.loss_names) 38 | # specify the images you want to save/display. 
The program will call base_model.get_current_visuals 39 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 40 | if self.opt.use_local: 41 | self.visual_names += ['fake_B0', 'fake_B1'] 42 | self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair'] 43 | self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg'] 44 | if self.isTrain: 45 | self.visual_names += ['dt1', 'dt2', 'dt1gt', 'dt2gt'] 46 | if not self.isTrain and self.opt.save2: 47 | self.visual_names = ['real_A', 'fake_B'] 48 | print('visuals', self.visual_names) 49 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 50 | if self.isTrain: 51 | self.model_names = ['G', 'D'] 52 | if self.opt.discriminator_local: 53 | self.model_names += ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 54 | # auxiliary nets for loss calculation 55 | self.auxiliary_model_names = ['DT1', 'DT2', 'Line1', 'Line2'] 56 | else: # during test time, only load Gs 57 | self.model_names = ['G'] 58 | self.auxiliary_model_names = [] 59 | if self.opt.use_local: 60 | self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine'] 61 | print('model_names', self.model_names) 62 | print('auxiliary_model_names', self.auxiliary_model_names) 63 | # define networks (both generator and discriminator) 64 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 65 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 66 | opt.nnG) 67 | print('netG', opt.netG) 68 | 69 | if self.isTrain: 70 | # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 71 | use_sigmoid = opt.no_lsgan 72 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 73 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 74 | print('netD', opt.netD, opt.n_layers_D) 75 | if self.opt.discriminator_local: 76 | self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 77 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 78 | self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 79 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 80 | self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 81 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 82 | self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 83 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 84 | self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 85 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 86 | self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 87 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 88 | 89 | 90 | if self.opt.use_local: 91 | self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 92 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 93 | self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 94 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 95 | self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 
96 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 97 | self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 98 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 99 | self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet2', opt.norm, 100 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 4) 101 | self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet2', opt.norm, 102 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 4) 103 | self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, 'combiner', opt.norm, 104 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2) 105 | 106 | 107 | if self.isTrain: 108 | self.fake_AB_pool = ImagePool(opt.pool_size) 109 | # define loss functions 110 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) 111 | self.criterionL1 = torch.nn.L1Loss() 112 | 113 | # initialize optimizers 114 | self.optimizers = [] 115 | if not self.opt.use_local: 116 | print('G_params 1 components') 117 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 118 | lr=opt.lr, betas=(opt.beta1, 0.999)) 119 | else: 120 | G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(self.netGLEyer.parameters()) + list(self.netGLNose.parameters()) + list(self.netGLMouth.parameters()) + list(self.netGLHair.parameters()) + list(self.netGLBG.parameters()) + list(self.netGCombine.parameters()) 121 | print('G_params 8 components') 122 | self.optimizer_G = torch.optim.Adam(G_params, 123 | lr=opt.lr, betas=(opt.beta1, 0.999)) 124 | if not self.opt.discriminator_local: 125 | print('D_params 1 components') 126 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 127 | lr=opt.lr, betas=(opt.beta1, 0.999)) 128 | else: 129 | D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) +list(self.netDLEyer.parameters()) + list(self.netDLNose.parameters()) + list(self.netDLMouth.parameters()) + list(self.netDLHair.parameters()) + list(self.netDLBG.parameters()) 130 | print('D_params 7 components') 131 | self.optimizer_D = torch.optim.Adam(D_params, 132 | lr=opt.lr, betas=(opt.beta1, 0.999)) 133 | self.optimizers.append(self.optimizer_G) 134 | self.optimizers.append(self.optimizer_D) 135 | 136 | # ==================================auxiliary nets (loaded, parameters fixed)============================= 137 | if self.isTrain: 138 | self.nc = 1 139 | self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm, 140 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 141 | self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm, 142 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 143 | self.set_requires_grad(self.netDT1, False) 144 | self.set_requires_grad(self.netDT2, False) 145 | 146 | self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm, 147 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 148 | self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm, 149 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 150 | self.set_requires_grad(self.netLine1, False) 151 | self.set_requires_grad(self.netLine2, False) 152 | 153 | 154 | def set_input(self, input): 155 | AtoB = self.opt.which_direction == 'AtoB' 156 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 
157 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 158 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 159 | if self.opt.use_local: 160 | self.real_A_eyel = input['eyel_A'].to(self.device) 161 | self.real_A_eyer = input['eyer_A'].to(self.device) 162 | self.real_A_nose = input['nose_A'].to(self.device) 163 | self.real_A_mouth = input['mouth_A'].to(self.device) 164 | self.real_B_eyel = input['eyel_B'].to(self.device) 165 | self.real_B_eyer = input['eyer_B'].to(self.device) 166 | self.real_B_nose = input['nose_B'].to(self.device) 167 | self.real_B_mouth = input['mouth_B'].to(self.device) 168 | self.center = input['center'] 169 | self.real_A_hair = input['hair_A'].to(self.device) 170 | self.real_B_hair = input['hair_B'].to(self.device) 171 | self.real_A_bg = input['bg_A'].to(self.device) 172 | self.real_B_bg = input['bg_B'].to(self.device) 173 | self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth 174 | self.mask2 = input['mask2'].to(self.device) # mask for non-bg 175 | if self.isTrain: 176 | self.dt1gt = input['dt1gt'].to(self.device) 177 | self.dt2gt = input['dt2gt'].to(self.device) 178 | 179 | 180 | def forward(self): 181 | if not self.opt.use_local: 182 | self.fake_B = self.netG(self.real_A) 183 | else: 184 | self.fake_B0 = self.netG(self.real_A) 185 | # EYES, NOSE, MOUTH 186 | fake_B_eyel = self.netGLEyel(self.real_A_eyel) 187 | fake_B_eyer = self.netGLEyer(self.real_A_eyer) 188 | fake_B_nose = self.netGLNose(self.real_A_nose) 189 | fake_B_mouth = self.netGLMouth(self.real_A_mouth) 190 | self.fake_B_nose = fake_B_nose 191 | self.fake_B_eyel = fake_B_eyel 192 | self.fake_B_eyer = fake_B_eyer 193 | self.fake_B_mouth = fake_B_mouth 194 | 195 | # HAIR, BG AND PARTCOMBINE 196 | fake_B_hair = self.netGLHair(self.real_A_hair) 197 | fake_B_bg = self.netGLBG(self.real_A_bg) 198 | self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2) 199 | self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2)) 200 | self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op) 201 | 202 | # FUSION NET 203 | self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1)) 204 | 205 | 206 | 207 | def backward_D(self): 208 | # Fake 209 | # stop backprop to the generator by detaching fake_B 210 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) # we use conditional GANs; we need to feed both input and output to the discriminator 211 | pred_fake = self.netD(fake_AB.detach()) 212 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 213 | if self.opt.discriminator_local: 214 | fake_AB_parts = self.getLocalParts(fake_AB) 215 | local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 216 | self.loss_D_fake_local = 0 217 | for i in range(len(fake_AB_parts)): 218 | net = getattr(self, 'net' + local_names[i]) 219 | pred_fake_tmp = net(fake_AB_parts[i].detach()) 220 | addw = self.getaddw(local_names[i]) 221 | self.loss_D_fake_local = self.loss_D_fake_local + self.criterionGAN(pred_fake_tmp, False) * addw 222 | self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local 223 | 224 | # Real 225 | real_AB = torch.cat((self.real_A, self.real_B), 1) 226 | pred_real = self.netD(real_AB) 227 | self.loss_D_real = self.criterionGAN(pred_real, True) 228 | if self.opt.discriminator_local: 229 | real_AB_parts = self.getLocalParts(real_AB) 230 | local_names = 
['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 231 | self.loss_D_real_local = 0 232 | for i in range(len(real_AB_parts)): 233 | net = getattr(self, 'net' + local_names[i]) 234 | pred_real_tmp = net(real_AB_parts[i]) 235 | addw = self.getaddw(local_names[i]) 236 | self.loss_D_real_local = self.loss_D_real_local + self.criterionGAN(pred_real_tmp, True) * addw 237 | self.loss_D_real = self.loss_D_real + self.loss_D_real_local 238 | 239 | # Combined loss 240 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 241 | 242 | self.loss_D.backward() 243 | 244 | def backward_G(self): 245 | # First, G(A) should fake the discriminator 246 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 247 | pred_fake = self.netD(fake_AB) 248 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 249 | if self.opt.discriminator_local: 250 | fake_AB_parts = self.getLocalParts(fake_AB) 251 | local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 252 | self.loss_G_GAN_local = 0 253 | for i in range(len(fake_AB_parts)): 254 | net = getattr(self, 'net' + local_names[i]) 255 | pred_fake_tmp = net(fake_AB_parts[i]) 256 | addw = self.getaddw(local_names[i]) 257 | self.loss_G_GAN_local = self.loss_G_GAN_local + self.criterionGAN(pred_fake_tmp, True) * addw 258 | if self.opt.gan_loss_strategy == 1: 259 | self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (len(fake_AB_parts) + 1) 260 | elif self.opt.gan_loss_strategy == 2: 261 | self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25 262 | self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local 263 | 264 | # Second, G(A) = B 265 | if not self.opt.no_l1_loss: 266 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 267 | 268 | if self.opt.use_local and not self.opt.no_G_local_loss: 269 | local_names = ['eyel','eyer','nose','mouth','hair','bg'] 270 | self.loss_G_local = 0 271 | for i in range(len(local_names)): 272 | fakeblocal = getattr(self, 'fake_B_' + local_names[i]) 273 | realblocal = getattr(self, 'real_B_' + local_names[i]) 274 | addw = self.getaddw(local_names[i]) 275 | self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal,realblocal) * self.opt.lambda_local * addw 276 | 277 | # Third, distance transform loss (chamfer matching) 278 | if self.fake_B.shape[1] == 3: 279 | tmp = self.fake_B[:,0,...]*0.299+self.fake_B[:,1,...]*0.587+self.fake_B[:,2,...]*0.114 280 | fake_B_gray = tmp.unsqueeze(1) 281 | else: 282 | fake_B_gray = self.fake_B 283 | if self.real_B.shape[1] == 3: 284 | tmp = self.real_B[:,0,...]*0.299+self.real_B[:,1,...]*0.587+self.real_B[:,2,...]*0.114 285 | real_B_gray = tmp.unsqueeze(1) 286 | else: 287 | real_B_gray = self.real_B 288 | 289 | # d_CM(a_i,G(p_i)) 290 | self.dt1 = self.netDT1(fake_B_gray) 291 | self.dt2 = self.netDT2(fake_B_gray) 292 | dt1 = self.dt1/2.0+0.5#[-1,1]->[0,1] 293 | dt2 = self.dt2/2.0+0.5 294 | 295 | bs = real_B_gray.shape[0] 296 | real_B_gray_line1 = self.netLine1(real_B_gray) 297 | real_B_gray_line2 = self.netLine2(real_B_gray) 298 | self.loss_G_chamfer = (dt1[(real_B_gray<0)&(real_B_gray_line1<0)].sum() + dt2[(real_B_gray>=0)&(real_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer 299 | 300 | # d_CM(G(p_i),a_i) 301 | dt1gt = self.dt1gt 302 | dt2gt = self.dt2gt 303 | self.dt1gt = (self.dt1gt-0.5)*2 304 | self.dt2gt = (self.dt2gt-0.5)*2 305 | 306 | fake_B_gray_line1 = self.netLine1(fake_B_gray) 307 | fake_B_gray_line2 = self.netLine2(fake_B_gray) 308 | self.loss_G_chamfer2 = (dt1gt[(fake_B_gray<0)&(fake_B_gray_line1<0)].sum() + 
dt2gt[(fake_B_gray>=0)&(fake_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer2 309 | 310 | 311 | self.loss_G = self.loss_G_GAN 312 | if 'G_L1' in self.loss_names: 313 | self.loss_G = self.loss_G + self.loss_G_L1 314 | if 'G_local' in self.loss_names: 315 | self.loss_G = self.loss_G + self.loss_G_local 316 | if 'G_chamfer' in self.loss_names: 317 | self.loss_G = self.loss_G + self.loss_G_chamfer 318 | if 'G_chamfer2' in self.loss_names: 319 | self.loss_G = self.loss_G + self.loss_G_chamfer2 320 | 321 | self.loss_G.backward() 322 | 323 | def optimize_parameters(self): 324 | self.forward() 325 | # update D 326 | self.set_requires_grad(self.netD, True) # enable backprop for D 327 | if self.opt.discriminator_local: 328 | self.set_requires_grad(self.netDLEyel, True) 329 | self.set_requires_grad(self.netDLEyer, True) 330 | self.set_requires_grad(self.netDLNose, True) 331 | self.set_requires_grad(self.netDLMouth, True) 332 | self.set_requires_grad(self.netDLHair, True) 333 | self.set_requires_grad(self.netDLBG, True) 334 | self.optimizer_D.zero_grad() 335 | self.backward_D() 336 | self.optimizer_D.step() 337 | 338 | # update G 339 | self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G 340 | if self.opt.discriminator_local: 341 | self.set_requires_grad(self.netDLEyel, False) 342 | self.set_requires_grad(self.netDLEyer, False) 343 | self.set_requires_grad(self.netDLNose, False) 344 | self.set_requires_grad(self.netDLMouth, False) 345 | self.set_requires_grad(self.netDLHair, False) 346 | self.set_requires_grad(self.netDLBG, False) 347 | self.optimizer_G.zero_grad() 348 | self.backward_G() 349 | self.optimizer_G.step() 350 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . 
import networks 5 | 6 | 7 | class BaseModel(): 8 | 9 | # modify parser to add command line options, 10 | # and also change the default values if needed 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def name(self): 16 | return 'BaseModel' 17 | 18 | def initialize(self, opt): 19 | self.opt = opt 20 | self.gpu_ids = opt.gpu_ids 21 | self.isTrain = opt.isTrain 22 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 23 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 24 | self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root) 25 | if opt.resize_or_crop != 'scale_width': 26 | torch.backends.cudnn.benchmark = True 27 | self.loss_names = [] 28 | self.model_names = [] 29 | self.visual_names = [] 30 | self.image_paths = [] 31 | 32 | def set_input(self, input): 33 | self.input = input 34 | 35 | def forward(self): 36 | pass 37 | 38 | # load and print networks; create schedulers 39 | def setup(self, opt, parser=None): 40 | if self.isTrain: 41 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 42 | 43 | if not self.isTrain or opt.continue_train: 44 | self.load_networks(opt.which_epoch) 45 | if self.isTrain: 46 | self.load_auxiliary_networks() 47 | self.print_networks(opt.verbose) 48 | 49 | # make models eval mode during test time 50 | def eval(self): 51 | for name in self.model_names: 52 | if isinstance(name, str): 53 | net = getattr(self, 'net' + name) 54 | net.eval() 55 | 56 | # used in test time, wrapping `forward` in no_grad() so we don't save 57 | # intermediate steps for backprop 58 | def test(self): 59 | with torch.no_grad(): 60 | self.forward() 61 | 62 | # get image paths 63 | def get_image_paths(self): 64 | return self.image_paths 65 | 66 | def optimize_parameters(self): 67 | pass 68 | 69 | # update learning rate (called once every epoch) 70 | def update_learning_rate(self): 71 | for scheduler in self.schedulers: 72 | scheduler.step() 73 | lr = self.optimizers[0].param_groups[0]['lr'] 74 | print('learning rate = %.7f' % lr) 75 | 76 | # return visualization images. train.py will display these images, and save the images to a html 77 | def get_current_visuals(self): 78 | visual_ret = OrderedDict() 79 | for name in self.visual_names: 80 | if isinstance(name, str): 81 | visual_ret[name] = getattr(self, name) 82 | return visual_ret 83 | 84 | # return traning losses/errors. train.py will print out these errors as debugging information 85 | def get_current_losses(self): 86 | errors_ret = OrderedDict() 87 | for name in self.loss_names: 88 | if isinstance(name, str): 89 | # float(...) 
works for both scalar tensor and float number 90 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 91 | return errors_ret 92 | 93 | # save models to the disk 94 | def save_networks(self, which_epoch): 95 | for name in self.model_names: 96 | if isinstance(name, str): 97 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 98 | save_path = os.path.join(self.save_dir, save_filename) 99 | net = getattr(self, 'net' + name) 100 | 101 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 102 | torch.save(net.module.cpu().state_dict(), save_path) 103 | net.cuda(self.gpu_ids[0]) 104 | else: 105 | torch.save(net.cpu().state_dict(), save_path) 106 | 107 | # save generators to one file and discriminators to another file 108 | def save_networks2(self, which_epoch): 109 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 110 | dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch)) 111 | dict_gen = {} 112 | dict_dis = {} 113 | for name in self.model_names: 114 | if isinstance(name, str): 115 | net = getattr(self, 'net' + name) 116 | 117 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 118 | state_dict = net.module.cpu().state_dict() 119 | net.cuda(self.gpu_ids[0]) 120 | else: 121 | state_dict = net.cpu().state_dict() 122 | 123 | if name[0] == 'G': 124 | dict_gen[name] = state_dict 125 | elif name[0] == 'D': 126 | dict_dis[name] = state_dict 127 | torch.save(dict_gen, gen_name) 128 | torch.save(dict_dis, dis_name) 129 | 130 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 131 | key = keys[i] 132 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 133 | if module.__class__.__name__.startswith('InstanceNorm') and \ 134 | (key == 'running_mean' or key == 'running_var'): 135 | if getattr(module, key) is None: 136 | state_dict.pop('.'.join(keys)) 137 | if module.__class__.__name__.startswith('InstanceNorm') and \ 138 | (key == 'num_batches_tracked'): 139 | state_dict.pop('.'.join(keys)) 140 | else: 141 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 142 | 143 | # load models from the disk 144 | def load_networks(self, which_epoch): 145 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 146 | if os.path.exists(gen_name): 147 | self.load_networks2(which_epoch) 148 | return 149 | for name in self.model_names: 150 | if isinstance(name, str): 151 | load_filename = '%s_net_%s.pth' % (which_epoch, name) 152 | load_path = os.path.join(self.save_dir, load_filename) 153 | net = getattr(self, 'net' + name) 154 | if isinstance(net, torch.nn.DataParallel): 155 | net = net.module 156 | print('loading the model from %s' % load_path) 157 | # if you are using PyTorch newer than 0.4 (e.g., built from 158 | # GitHub source), you can remove str() on self.device 159 | state_dict = torch.load(load_path, map_location=str(self.device)) 160 | if hasattr(state_dict, '_metadata'): 161 | del state_dict._metadata 162 | 163 | # patch InstanceNorm checkpoints prior to 0.4 164 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 165 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 166 | net.load_state_dict(state_dict) 167 | 168 | def load_networks2(self, which_epoch): 169 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 170 | gen_state_dict = torch.load(gen_name, map_location=str(self.device)) 171 | if self.isTrain: 172 | dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % 
(which_epoch)) 173 | dis_state_dict = torch.load(dis_name, map_location=str(self.device)) 174 | for name in self.model_names: 175 | if isinstance(name, str): 176 | net = getattr(self, 'net' + name) 177 | if isinstance(net, torch.nn.DataParallel): 178 | net = net.module 179 | if name[0] == 'G': 180 | print('loading the model from %s' % gen_name) 181 | state_dict = gen_state_dict[name] 182 | elif name[0] == 'D': 183 | print('loading the model from %s' % dis_name) 184 | state_dict = dis_state_dict[name] 185 | 186 | if hasattr(state_dict, '_metadata'): 187 | del state_dict._metadata 188 | # patch InstanceNorm checkpoints prior to 0.4 189 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 190 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 191 | net.load_state_dict(state_dict) 192 | 193 | # load auxiliary net models from the disk 194 | def load_auxiliary_networks(self): 195 | for name in self.auxiliary_model_names: 196 | if isinstance(name, str): 197 | load_filename = '%s_net_%s.pth' % ('latest', name) 198 | load_path = os.path.join(self.auxiliary_dir, load_filename) 199 | net = getattr(self, 'net' + name) 200 | if isinstance(net, torch.nn.DataParallel): 201 | net = net.module 202 | print('loading the model from %s' % load_path) 203 | # if you are using PyTorch newer than 0.4 (e.g., built from 204 | # GitHub source), you can remove str() on self.device 205 | state_dict = torch.load(load_path, map_location=str(self.device)) 206 | if hasattr(state_dict, '_metadata'): 207 | del state_dict._metadata 208 | 209 | # patch InstanceNorm checkpoints prior to 0.4 210 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 211 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 212 | net.load_state_dict(state_dict) 213 | 214 | # print network information 215 | def print_networks(self, verbose): 216 | print('---------- Networks initialized -------------') 217 | for name in self.model_names: 218 | if isinstance(name, str): 219 | net = getattr(self, 'net' + name) 220 | num_params = 0 221 | for param in net.parameters(): 222 | num_params += param.numel() 223 | if verbose: 224 | print(net) 225 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 226 | print('-----------------------------------------------') 227 | 228 | # set requies_grad=Fasle to avoid computation 229 | def set_requires_grad(self, nets, requires_grad=False): 230 | if not isinstance(nets, list): 231 | nets = [nets] 232 | for net in nets: 233 | if net is not None: 234 | for param in net.parameters(): 235 | param.requires_grad = requires_grad 236 | 237 | # ============================================================================================================= 238 | def inverse_mask(self, mask): 239 | return torch.ones(mask.shape).to(self.device)-mask 240 | 241 | def masked(self, A,mask): 242 | return (A/2+0.5)*mask*2-1 243 | 244 | def add_with_mask(self, A,B,mask): 245 | return ((A/2+0.5)*mask+(B/2+0.5)*(torch.ones(mask.shape).to(self.device)-mask))*2-1 246 | 247 | def addone_with_mask(self, A,mask): 248 | return ((A/2+0.5)*mask+(torch.ones(mask.shape).to(self.device)-mask))*2-1 249 | 250 | def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op = 1): 251 | if comb_op == 0: 252 | # use max pooling, pad black for eyes etc 253 | padvalue = -1 254 | hair = self.masked(hair, mask) 255 | else: 256 | # use min pooling, pad white for eyes etc 257 | padvalue = 1 258 | hair = 
self.addone_with_mask(hair, mask) 259 | IMAGE_SIZE = self.opt.fineSize 260 | ratio = IMAGE_SIZE / 256 261 | EYE_W = self.opt.EYE_W * ratio 262 | EYE_H = self.opt.EYE_H * ratio 263 | NOSE_W = self.opt.NOSE_W * ratio 264 | NOSE_H = self.opt.NOSE_H * ratio 265 | MOUTH_W = self.opt.MOUTH_W * ratio 266 | MOUTH_H = self.opt.MOUTH_H * ratio 267 | bs,nc,_,_ = eyel.shape 268 | eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 269 | eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 270 | nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 271 | mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 272 | for i in range(bs): 273 | center = self.center[i]#x,y 274 | eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(eyel[i]) 275 | eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(eyer[i]) 276 | nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(nose[i]) 277 | mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(mouth[i]) 278 | if comb_op == 0: 279 | # use max pooling 280 | eyes = torch.max(eyel_p, eyer_p) 281 | eye_nose = torch.max(eyes, nose_p) 282 | eye_nose_mouth = torch.max(eye_nose, mouth_p) 283 | result = torch.max(hair,eye_nose_mouth) 284 | else: 285 | # use min pooling 286 | eyes = torch.min(eyel_p, eyer_p) 287 | eye_nose = torch.min(eyes, nose_p) 288 | eye_nose_mouth = torch.min(eye_nose, mouth_p) 289 | result = torch.min(hair,eye_nose_mouth) 290 | return result 291 | 292 | def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op = 1): 293 | if comb_op == 0: 294 | # use max pooling, pad black for eyes etc 295 | padvalue = -1 296 | hair = self.masked(hair, maskh) 297 | bg = self.masked(bg, maskb) 298 | else: 299 | # use min pooling, pad white for eyes etc 300 | padvalue = 1 301 | hair = self.addone_with_mask(hair, maskh) 302 | bg = self.addone_with_mask(bg, maskb) 303 | IMAGE_SIZE = self.opt.fineSize 304 | ratio = IMAGE_SIZE / 256 305 | EYE_W = self.opt.EYE_W * ratio 306 | EYE_H = self.opt.EYE_H * ratio 307 | NOSE_W = self.opt.NOSE_W * ratio 308 | NOSE_H = self.opt.NOSE_H * ratio 309 | MOUTH_W = self.opt.MOUTH_W * ratio 310 | MOUTH_H = self.opt.MOUTH_H * ratio 311 | bs,nc,_,_ = eyel.shape 312 | eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 313 | eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 314 | nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 315 | mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 316 | for i in range(bs): 317 | center = self.center[i]#x,y 318 | eyel_p[i] = torch.nn.ConstantPad2d((int(center[0,0] - EYE_W / 2), int(IMAGE_SIZE - (center[0,0]+EYE_W/2)), int(center[0,1] - EYE_H / 2), int(IMAGE_SIZE - (center[0,1]+EYE_H/2))),padvalue)(eyel[i]) 319 | eyer_p[i] = torch.nn.ConstantPad2d((int(center[1,0] - EYE_W / 2), int(IMAGE_SIZE - (center[1,0]+EYE_W/2)), int(center[1,1] - EYE_H / 2), int(IMAGE_SIZE - (center[1,1]+EYE_H/2))), padvalue)(eyer[i]) 320 | nose_p[i] = torch.nn.ConstantPad2d((int(center[2,0] - NOSE_W / 2), int(IMAGE_SIZE - 
(center[2,0]+NOSE_W/2)), int(center[2,1] - NOSE_H / 2), int(IMAGE_SIZE - (center[2,1]+NOSE_H/2))),padvalue)(nose[i]) 321 | mouth_p[i] = torch.nn.ConstantPad2d((int(center[3,0] - MOUTH_W / 2), int(IMAGE_SIZE - (center[3,0]+MOUTH_W/2)), int(center[3,1] - MOUTH_H / 2), int(IMAGE_SIZE - (center[3,1]+MOUTH_H/2))),padvalue)(mouth[i]) 322 | if comb_op == 0: 323 | eyes = torch.max(eyel_p, eyer_p) 324 | eye_nose = torch.max(eyes, nose_p) 325 | eye_nose_mouth = torch.max(eye_nose, mouth_p) 326 | eye_nose_mouth_hair = torch.max(hair,eye_nose_mouth) 327 | result = torch.max(bg,eye_nose_mouth_hair) 328 | else: 329 | eyes = torch.min(eyel_p, eyer_p) 330 | eye_nose = torch.min(eyes, nose_p) 331 | eye_nose_mouth = torch.min(eye_nose, mouth_p) 332 | eye_nose_mouth_hair = torch.min(hair,eye_nose_mouth) 333 | result = torch.min(bg,eye_nose_mouth_hair) 334 | return result 335 | 336 | def partCombiner3(self, face, hair, maskf, maskh, comb_op = 1): 337 | if comb_op == 0: 338 | # use max pooling, pad black etc 339 | padvalue = -1 340 | face = self.masked(face, maskf) 341 | hair = self.masked(hair, maskh) 342 | else: 343 | # use min pooling, pad white etc 344 | padvalue = 1 345 | face = self.addone_with_mask(face, maskf) 346 | hair = self.addone_with_mask(hair, maskh) 347 | if comb_op == 0: 348 | result = torch.max(face,hair) 349 | else: 350 | result = torch.min(face,hair) 351 | return result 352 | 353 | def getLocalParts(self,fakeAB): 354 | bs,nc,_,_ = fakeAB.shape #dtype torch.float32 355 | ncr = nc // self.opt.output_nc 356 | ratio = self.opt.fineSize // 256 357 | EYE_H = self.opt.EYE_H * ratio 358 | EYE_W = self.opt.EYE_W * ratio 359 | NOSE_H = self.opt.NOSE_H * ratio 360 | NOSE_W = self.opt.NOSE_W * ratio 361 | MOUTH_H = self.opt.MOUTH_H * ratio 362 | MOUTH_W = self.opt.MOUTH_W * ratio 363 | eyel = torch.ones((bs,nc,EYE_H,EYE_W)).to(self.device) 364 | eyer = torch.ones((bs,nc,EYE_H,EYE_W)).to(self.device) 365 | nose = torch.ones((bs,nc,NOSE_H,NOSE_W)).to(self.device) 366 | mouth = torch.ones((bs,nc,MOUTH_H,MOUTH_W)).to(self.device) 367 | for i in range(bs): 368 | center = self.center[i] 369 | eyel[i] = fakeAB[i,:,center[0,1]-EYE_H//2:center[0,1]+EYE_H//2,center[0,0]-EYE_W//2:center[0,0]+EYE_W//2] 370 | eyer[i] = fakeAB[i,:,center[1,1]-EYE_H//2:center[1,1]+EYE_H//2,center[1,0]-EYE_W//2:center[1,0]+EYE_W//2] 371 | nose[i] = fakeAB[i,:,center[2,1]-NOSE_H//2:center[2,1]+NOSE_H//2,center[2,0]-NOSE_W//2:center[2,0]+NOSE_W//2] 372 | mouth[i] = fakeAB[i,:,center[3,1]-MOUTH_H//2:center[3,1]+MOUTH_H//2,center[3,0]-MOUTH_W//2:center[3,0]+MOUTH_W//2] 373 | hair = (fakeAB/2+0.5) * self.mask.repeat(1,ncr,1,1) * self.mask2.repeat(1,ncr,1,1) * 2 - 1 374 | bg = (fakeAB/2+0.5) * (torch.ones(fakeAB.shape).to(self.device)-self.mask2.repeat(1,ncr,1,1)) * 2 - 1 375 | return eyel, eyer, nose, mouth, hair, bg 376 | 377 | def getaddw(self,local_name): 378 | addw = 1 379 | if local_name in ['DLEyel','DLEyer','eyel','eyer']: 380 | addw = self.opt.addw_eye 381 | elif local_name in ['DLNose', 'nose']: 382 | addw = self.opt.addw_nose 383 | elif local_name in ['DLMouth', 'mouth']: 384 | addw = self.opt.addw_mouth 385 | elif local_name in ['DLHair', 'hair']: 386 | addw = self.opt.addw_hair 387 | elif local_name in ['DLBG', 'bg']: 388 | addw = self.opt.addw_bg 389 | return addw 390 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 
import functools 5 | from torch.optim import lr_scheduler 6 | 7 | ############################################################################### 8 | # Helper Functions 9 | ############################################################################### 10 | 11 | 12 | def get_norm_layer(norm_type='instance'): 13 | if norm_type == 'batch': 14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True) 15 | elif norm_type == 'instance': 16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True) 17 | elif norm_type == 'none': 18 | norm_layer = None 19 | else: 20 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type) 21 | return norm_layer 22 | 23 | 24 | def get_scheduler(optimizer, opt): 25 | if opt.lr_policy == 'lambda': 26 | def lambda_rule(epoch): 27 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1) 28 | return lr_l 29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 30 | elif opt.lr_policy == 'step': 31 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 32 | elif opt.lr_policy == 'plateau': 33 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 34 | elif opt.lr_policy == 'cosine': 35 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.niter, eta_min=0) 36 | else: 37 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 38 | return scheduler 39 | 40 | 41 | def init_weights(net, init_type='normal', gain=0.02): 42 | def init_func(m): 43 | classname = m.__class__.__name__ 44 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 45 | if init_type == 'normal': 46 | init.normal_(m.weight.data, 0.0, gain) 47 | elif init_type == 'xavier': 48 | init.xavier_normal_(m.weight.data, gain=gain) 49 | elif init_type == 'kaiming': 50 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 51 | elif init_type == 'orthogonal': 52 | init.orthogonal_(m.weight.data, gain=gain) 53 | else: 54 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 55 | if hasattr(m, 'bias') and m.bias is not None: 56 | init.constant_(m.bias.data, 0.0) 57 | elif classname.find('BatchNorm2d') != -1: 58 | init.normal_(m.weight.data, 1.0, gain) 59 | init.constant_(m.bias.data, 0.0) 60 | 61 | print('initialize network with %s' % init_type) 62 | net.apply(init_func) 63 | 64 | 65 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 66 | if len(gpu_ids) > 0: 67 | assert(torch.cuda.is_available()) 68 | net.to(gpu_ids[0]) 69 | net = torch.nn.DataParallel(net, gpu_ids) 70 | init_weights(net, init_type, gain=init_gain) 71 | return net 72 | 73 | 74 | def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], nnG=9): 75 | net = None 76 | norm_layer = get_norm_layer(norm_type=norm) 77 | 78 | if netG == 'resnet_9blocks': 79 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) 80 | elif netG == 'resnet_6blocks': 81 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) 82 | elif netG == 'resnet_nblocks': 83 | net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=nnG) 84 | elif netG == 'unet_128': 85 | net = UnetGenerator(input_nc, output_nc, 7, ngf, 
norm_layer=norm_layer, use_dropout=use_dropout) 86 | elif netG == 'unet_256':#default for pix2pix 87 | net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 88 | elif netG == 'unet_512': 89 | net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 90 | elif netG == 'unet_ndown': 91 | net = UnetGenerator(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 92 | elif netG == 'partunet': 93 | net = PartUnet(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 94 | elif netG == 'partunet2': 95 | net = PartUnet2(input_nc, output_nc, nnG, ngf, norm_layer=norm_layer, use_dropout=use_dropout) 96 | elif netG == 'combiner': 97 | net = Combiner(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=2) 98 | else: 99 | raise NotImplementedError('Generator model name [%s] is not recognized' % netG) 100 | return init_net(net, init_type, init_gain, gpu_ids) 101 | 102 | 103 | def define_D(input_nc, ndf, netD, 104 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]): 105 | net = None 106 | norm_layer = get_norm_layer(norm_type=norm) 107 | 108 | if netD == 'basic': 109 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 110 | elif netD == 'n_layers': 111 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 112 | elif netD == 'pixel': 113 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid) 114 | else: 115 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % net) 116 | return init_net(net, init_type, init_gain, gpu_ids) 117 | 118 | 119 | ############################################################################## 120 | # Classes 121 | ############################################################################## 122 | 123 | 124 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 125 | # When LSGAN is used, it is basically same as MSELoss, 126 | # but it abstracts away the need to create the target label tensor 127 | # that has the same size as the input 128 | class GANLoss(nn.Module): 129 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0): 130 | super(GANLoss, self).__init__() 131 | self.register_buffer('real_label', torch.tensor(target_real_label)) 132 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 133 | if use_lsgan: 134 | self.loss = nn.MSELoss() 135 | else:#no_lsgan 136 | self.loss = nn.BCELoss() 137 | 138 | def get_target_tensor(self, input, target_is_real): 139 | if target_is_real: 140 | target_tensor = self.real_label 141 | else: 142 | target_tensor = self.fake_label 143 | return target_tensor.expand_as(input) 144 | 145 | def __call__(self, input, target_is_real): 146 | target_tensor = self.get_target_tensor(input, target_is_real) 147 | return self.loss(input, target_tensor) 148 | 149 | # Defines the generator that consists of Resnet blocks between a few 150 | # downsampling/upsampling operations. 151 | # Code and idea originally from Justin Johnson's architecture. 
152 | # https://github.com/jcjohnson/fast-neural-style/ 153 | class ResnetGenerator(nn.Module): 154 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 155 | assert(n_blocks >= 0) 156 | super(ResnetGenerator, self).__init__() 157 | self.input_nc = input_nc 158 | self.output_nc = output_nc 159 | self.ngf = ngf 160 | if type(norm_layer) == functools.partial: 161 | use_bias = norm_layer.func == nn.InstanceNorm2d 162 | else: 163 | use_bias = norm_layer == nn.InstanceNorm2d 164 | 165 | model = [nn.ReflectionPad2d(3), 166 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 167 | bias=use_bias), 168 | norm_layer(ngf), 169 | nn.ReLU(True)] 170 | 171 | n_downsampling = 2 172 | for i in range(n_downsampling): 173 | mult = 2**i 174 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, 175 | stride=2, padding=1, bias=use_bias), 176 | norm_layer(ngf * mult * 2), 177 | nn.ReLU(True)] 178 | 179 | mult = 2**n_downsampling 180 | for i in range(n_blocks): 181 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 182 | 183 | for i in range(n_downsampling): 184 | mult = 2**(n_downsampling - i) 185 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), 186 | kernel_size=3, stride=2, 187 | padding=1, output_padding=1, 188 | bias=use_bias), 189 | norm_layer(int(ngf * mult / 2)), 190 | nn.ReLU(True)] 191 | model += [nn.ReflectionPad2d(3)] 192 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 193 | model += [nn.Tanh()] 194 | 195 | self.model = nn.Sequential(*model) 196 | 197 | def forward(self, input): 198 | return self.model(input) 199 | 200 | class Combiner(nn.Module): 201 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): 202 | assert(n_blocks >= 0) 203 | super(Combiner, self).__init__() 204 | self.input_nc = input_nc 205 | self.output_nc = output_nc 206 | self.ngf = ngf 207 | if type(norm_layer) == functools.partial: 208 | use_bias = norm_layer.func == nn.InstanceNorm2d 209 | else: 210 | use_bias = norm_layer == nn.InstanceNorm2d 211 | 212 | model = [nn.ReflectionPad2d(3), 213 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, 214 | bias=use_bias), 215 | norm_layer(ngf), 216 | nn.ReLU(True)] 217 | 218 | for i in range(n_blocks): 219 | model += [ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] 220 | 221 | model += [nn.ReflectionPad2d(3)] 222 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] 223 | model += [nn.Tanh()] 224 | 225 | self.model = nn.Sequential(*model) 226 | 227 | def forward(self, input): 228 | return self.model(input) 229 | 230 | # Define a resnet block 231 | class ResnetBlock(nn.Module): 232 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): 233 | super(ResnetBlock, self).__init__() 234 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 235 | 236 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 237 | conv_block = [] 238 | p = 0 239 | if padding_type == 'reflect': 240 | conv_block += [nn.ReflectionPad2d(1)] 241 | elif padding_type == 'replicate': 242 | conv_block += [nn.ReplicationPad2d(1)] 243 | elif padding_type == 'zero': 244 | p = 1 245 | else: 246 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 247 | 248 | 
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 249 | norm_layer(dim), 250 | nn.ReLU(True)] 251 | if use_dropout: 252 | conv_block += [nn.Dropout(0.5)] 253 | 254 | p = 0 255 | if padding_type == 'reflect': 256 | conv_block += [nn.ReflectionPad2d(1)] 257 | elif padding_type == 'replicate': 258 | conv_block += [nn.ReplicationPad2d(1)] 259 | elif padding_type == 'zero': 260 | p = 1 261 | else: 262 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 263 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 264 | norm_layer(dim)] 265 | 266 | return nn.Sequential(*conv_block) 267 | 268 | def forward(self, x): 269 | out = x + self.conv_block(x) 270 | return out 271 | 272 | 273 | # Defines the Unet generator. 274 | # |num_downs|: number of downsamplings in UNet. For example, 275 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 276 | # at the bottleneck 277 | class UnetGenerator(nn.Module): 278 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 279 | norm_layer=nn.BatchNorm2d, use_dropout=False): 280 | super(UnetGenerator, self).__init__() 281 | 282 | # construct unet structure 283 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 284 | for i in range(num_downs - 5): 285 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 286 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 287 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 288 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 289 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 290 | 291 | self.model = unet_block 292 | 293 | def forward(self, input): 294 | return self.model(input) 295 | 296 | class PartUnet(nn.Module): 297 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 298 | norm_layer=nn.BatchNorm2d, use_dropout=False): 299 | super(PartUnet, self).__init__() 300 | 301 | # construct unet structure 302 | # 3 downs 303 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 304 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 305 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 306 | 307 | self.model = unet_block 308 | 309 | def forward(self, input): 310 | return self.model(input) 311 | 312 | class PartUnet2(nn.Module): 313 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 314 | norm_layer=nn.BatchNorm2d, use_dropout=False): 315 | super(PartUnet2, self).__init__() 316 | 317 | # construct unet structure 318 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 319 | for i in range(num_downs - 3): 320 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 321 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 322 | unet_block = 
UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 323 | 324 | self.model = unet_block 325 | 326 | def forward(self, input): 327 | return self.model(input) 328 | 329 | 330 | # Defines the submodule with skip connection. 331 | # X -------------------identity---------------------- X 332 | # |-- downsampling -- |submodule| -- upsampling --| 333 | class UnetSkipConnectionBlock(nn.Module): 334 | def __init__(self, outer_nc, inner_nc, input_nc=None, 335 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 336 | super(UnetSkipConnectionBlock, self).__init__() 337 | self.outermost = outermost 338 | if type(norm_layer) == functools.partial: 339 | use_bias = norm_layer.func == nn.InstanceNorm2d 340 | else: 341 | use_bias = norm_layer == nn.InstanceNorm2d 342 | if input_nc is None: 343 | input_nc = outer_nc 344 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 345 | stride=2, padding=1, bias=use_bias) 346 | downrelu = nn.LeakyReLU(0.2, True) 347 | downnorm = norm_layer(inner_nc) 348 | uprelu = nn.ReLU(True) 349 | upnorm = norm_layer(outer_nc) 350 | 351 | if outermost: 352 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 353 | kernel_size=4, stride=2, 354 | padding=1) 355 | down = [downconv] 356 | up = [uprelu, upconv, nn.Tanh()] 357 | model = down + [submodule] + up 358 | elif innermost: 359 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 360 | kernel_size=4, stride=2, 361 | padding=1, bias=use_bias) 362 | down = [downrelu, downconv] 363 | up = [uprelu, upconv, upnorm] 364 | model = down + up 365 | else: 366 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 367 | kernel_size=4, stride=2, 368 | padding=1, bias=use_bias) 369 | down = [downrelu, downconv, downnorm] 370 | up = [uprelu, upconv, upnorm] 371 | 372 | if use_dropout: 373 | model = down + [submodule] + up + [nn.Dropout(0.5)] 374 | else: 375 | model = down + [submodule] + up 376 | 377 | self.model = nn.Sequential(*model) 378 | 379 | def forward(self, x): 380 | if self.outermost: 381 | return self.model(x) 382 | else: 383 | return torch.cat([x, self.model(x)], 1) 384 | 385 | 386 | # Defines the PatchGAN discriminator with the specified arguments. 
387 | class NLayerDiscriminator(nn.Module): 388 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 389 | super(NLayerDiscriminator, self).__init__() 390 | if type(norm_layer) == functools.partial: 391 | use_bias = norm_layer.func == nn.InstanceNorm2d 392 | else: 393 | use_bias = norm_layer == nn.InstanceNorm2d 394 | 395 | kw = 4 396 | padw = 1 397 | sequence = [ 398 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 399 | nn.LeakyReLU(0.2, True) 400 | ] 401 | 402 | nf_mult = 1 403 | nf_mult_prev = 1 404 | for n in range(1, n_layers): 405 | nf_mult_prev = nf_mult 406 | nf_mult = min(2**n, 8) 407 | sequence += [ 408 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 409 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 410 | norm_layer(ndf * nf_mult), 411 | nn.LeakyReLU(0.2, True) 412 | ] 413 | 414 | nf_mult_prev = nf_mult 415 | nf_mult = min(2**n_layers, 8) 416 | sequence += [ 417 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 418 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 419 | norm_layer(ndf * nf_mult), 420 | nn.LeakyReLU(0.2, True) 421 | ] 422 | 423 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 424 | 425 | if use_sigmoid:#no_lsgan, use sigmoid before calculating bceloss(binary cross entropy) 426 | sequence += [nn.Sigmoid()] 427 | 428 | self.model = nn.Sequential(*sequence) 429 | 430 | def forward(self, input): 431 | return self.model(input) 432 | 433 | 434 | class PixelDiscriminator(nn.Module): 435 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 436 | super(PixelDiscriminator, self).__init__() 437 | if type(norm_layer) == functools.partial: 438 | use_bias = norm_layer.func == nn.InstanceNorm2d 439 | else: 440 | use_bias = norm_layer == nn.InstanceNorm2d 441 | 442 | self.net = [ 443 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 444 | nn.LeakyReLU(0.2, True), 445 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 446 | norm_layer(ndf * 2), 447 | nn.LeakyReLU(0.2, True), 448 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 449 | 450 | if use_sigmoid: 451 | self.net.append(nn.Sigmoid()) 452 | 453 | self.net = nn.Sequential(*self.net) 454 | 455 | def forward(self, input): 456 | return self.net(input) 457 | -------------------------------------------------------------------------------- /models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . import networks 3 | import torch 4 | 5 | 6 | class TestModel(BaseModel): 7 | def name(self): 8 | return 'TestModel' 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train=True): 12 | assert not is_train, 'TestModel cannot be used in train mode' 13 | 14 | parser.set_defaults(dataset_mode='single') 15 | 16 | return parser 17 | 18 | def initialize(self, opt): 19 | assert(not opt.isTrain) 20 | BaseModel.initialize(self, opt) 21 | 22 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 23 | self.loss_names = [] 24 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 25 | self.visual_names = ['real_A', 'fake_B'] 26 | # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 27 | self.model_names = ['G'] 28 | self.auxiliary_model_names = [] 29 | if self.opt.use_local: 30 | self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine'] 31 | 32 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 33 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 34 | opt.nnG) 35 | if self.opt.use_local: 36 | self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 37 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 38 | self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 39 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 40 | self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 41 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 42 | self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet', opt.norm, 43 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 3) 44 | self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet2', opt.norm, 45 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 4) 46 | self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 'partunet2', opt.norm, 47 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 4) 48 | self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, 'combiner', opt.norm, 49 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2) 50 | 51 | 52 | def set_input(self, input): 53 | # we need to use single_dataset mode 54 | self.real_A = input['A'].to(self.device) 55 | self.image_paths = input['A_paths'] 56 | if self.opt.use_local: 57 | self.real_A_eyel = input['eyel_A'].to(self.device) 58 | self.real_A_eyer = input['eyer_A'].to(self.device) 59 | self.real_A_nose = input['nose_A'].to(self.device) 60 | self.real_A_mouth = input['mouth_A'].to(self.device) 61 | self.center = input['center'] 62 | self.real_A_hair = input['hair_A'].to(self.device) 63 | self.real_A_bg = input['bg_A'].to(self.device) 64 | self.mask = input['mask'].to(self.device) 65 | self.mask2 = input['mask2'].to(self.device) 66 | 67 | def forward(self): 68 | if not self.opt.use_local: 69 | self.fake_B = self.netG(self.real_A) 70 | else: 71 | self.fake_B0 = self.netG(self.real_A) 72 | # EYES, NOSE, MOUTH 73 | fake_B_eyel = self.netGLEyel(self.real_A_eyel) 74 | fake_B_eyer = self.netGLEyer(self.real_A_eyer) 75 | fake_B_nose = self.netGLNose(self.real_A_nose) 76 | fake_B_mouth = self.netGLMouth(self.real_A_mouth) 77 | 78 | # HAIR, BG AND PARTCOMBINE 79 | fake_B_hair = self.netGLHair(self.real_A_hair) 80 | fake_B_bg = self.netGLBG(self.real_A_bg) 81 | self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2) 82 | self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2)) 83 | self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op) 84 | 85 | # FUSION NET 86 | self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1)) 87 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders train, test etc)') 15 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 16 | parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 17 | parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 18 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 19 | parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels') 20 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 21 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 22 | parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD') 23 | parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG') 24 | parser.add_argument('--nnG', type=int, default=9, help='specify nblock for resnet_nblocks, ndown for unet for unet_ndown') 25 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 28 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [aligned | single]') 29 | parser.add_argument('--model', type=str, default='apdrawing_gan', 30 | help='chooses which model to use. 
[apdrawing_gan | test]') 31 | parser.add_argument('--use_local', action='store_true', help='use local part network') 32 | parser.add_argument('--comb_op', type=int, default=1, help='use min-pooling(1) or max-pooling(0) for overlapping regions') 33 | parser.add_argument('--lm_dir', type=str, default='dataset/landmark/ALL', help='path to facial landmarks') 34 | parser.add_argument('--bg_dir', type=str, default='dataset/mask/ALL', help='path to background masks') 35 | parser.add_argument('--soft_border', type=int, default=0, help='use mask with soft border') 36 | parser.add_argument('--EYE_H', type=int, default=40, help='EYE_H') 37 | parser.add_argument('--EYE_W', type=int, default=56, help='EYE_W') 38 | parser.add_argument('--NOSE_H', type=int, default=48, help='NOSE_H') 39 | parser.add_argument('--NOSE_W', type=int, default=48, help='NOSE_W') 40 | parser.add_argument('--MOUTH_H', type=int, default=40, help='MOUTH_H') 41 | parser.add_argument('--MOUTH_W', type=int, default=64, help='MOUTH_W') 42 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 43 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 44 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 45 | parser.add_argument('--auxiliary_root', type=str, default='auxiliary', help='auxiliary model folder') 46 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization') 47 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 48 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 49 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 50 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 51 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 52 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 53 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 54 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 55 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 56 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 57 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 58 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 59 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 60 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 61 | self.initialized = True 62 | return parser 63 | 64 | def gather_options(self): 65 | # initialize parser with basic options 66 | if not self.initialized: 67 | parser = argparse.ArgumentParser( 68 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 69 | parser = self.initialize(parser) 70 | 71 | # get the basic options 72 | opt, _ = parser.parse_known_args() 73 | 74 | # modify model-related parser options 75 | model_name = opt.model 76 | model_option_setter = models.get_option_setter(model_name) 77 | parser = model_option_setter(parser, self.isTrain) 78 | opt, _ = parser.parse_known_args() # parse again with the new defaults 79 | 80 | # modify dataset-related parser options 81 | dataset_name = opt.dataset_mode 82 | dataset_option_setter = data.get_option_setter(dataset_name) 83 | parser = dataset_option_setter(parser, self.isTrain) 84 | 85 | self.parser = parser 86 | 87 | return parser.parse_args() 88 | 89 | def print_options(self, opt): 90 | message = '' 91 | message += '----------------- Options ---------------\n' 92 | for k, v in sorted(vars(opt).items()): 93 | comment = '' 94 | default = self.parser.get_default(k) 95 | if v != default: 96 | comment = '\t[default: %s]' % str(default) 97 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 98 | message += '----------------- End -------------------' 99 | print(message) 100 | 101 | # save to the disk 102 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 103 | util.mkdirs(expr_dir) 104 | file_name = os.path.join(expr_dir, 'opt.txt') 105 | with open(file_name, 'wt') as opt_file: 106 | opt_file.write(message) 107 | opt_file.write('\n') 108 | 109 | def parse(self): 110 | 111 | opt = self.gather_options() 112 | if opt.use_local: 113 | opt.loadSize = opt.fineSize 114 | opt.isTrain = self.isTrain # train or test 115 | 116 | # process opt.suffix 117 | if opt.suffix: 118 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 119 | opt.name = opt.name + suffix 120 | 121 | self.print_options(opt) 122 | 123 | # set gpu ids 124 | str_ids = opt.gpu_ids.split(',') 125 | opt.gpu_ids = [] 126 | for str_id in str_ids: 127 | id = int(str_id) 128 | if id >= 0: 129 | opt.gpu_ids.append(id) 130 | if len(opt.gpu_ids) > 0: 131 | torch.cuda.set_device(opt.gpu_ids[0]) 132 | 133 | self.opt = opt 134 | return self.opt 135 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def 
initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | parser.add_argument('--how_many', type=int, default=70, help='how many test images to run') 13 | parser.add_argument('--save2', action='store_true', help='only save real_A and fake_B') 14 | 15 | # To avoid cropping, the loadSize should be the same as fineSize 16 | parser.set_defaults(loadSize=parser.get_default('fineSize')) 17 | self.isTrain = False 18 | return parser 19 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 12 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 13 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 14 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 15 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 16 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 17 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 18 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 19 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 20 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 21 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 22 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 23 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 24 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine') 25 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 26 | # ============================================loss========================================================= 27 | # L1 and local 28 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 29 | parser.add_argument('--lambda_local', type=float, default=25.0, help='weight for Local loss') 30 | # chamfer loss 31 | parser.add_argument('--lambda_chamfer', type=float, default=0.1, help='weight for chamfer loss') 32 | parser.add_argument('--lambda_chamfer2', type=float, default=0.1, help='weight for chamfer loss2') 33 | # =====================================auxilary net structure=============================================== 34 | # dt & line net structure 35 | parser.add_argument('--netG_dt', type=str, default='unet_512', help='selects model to use for netG_dt, for chamfer loss') 36 | parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line, for chamfer loss') 37 | # multiple discriminators 38 | parser.add_argument('--discriminator_local', action='store_true', help='use six diffent local discriminator for 6 local regions') 39 | parser.add_argument('--gan_loss_strategy', type=int, default=2, help='specify how to calculate gan loss for g, 1: average global and local discriminators; 2: not change global discriminator weight, 0.25 for local') 40 | parser.add_argument('--addw_eye', type=float, default=1.0, help='additional weight for eye region') 41 | parser.add_argument('--addw_nose', type=float, default=1.0, help='additional weight for nose region') 42 | parser.add_argument('--addw_mouth', type=float, default=1.0, help='additional weight for mouth region') 43 | parser.add_argument('--addw_hair', type=float, default=1.0, help='additional weight for hair region') 44 | parser.add_argument('--addw_bg', type=float, default=1.0, help='additional weight for bg region') 45 | # ==========================================ablation======================================================== 46 | parser.add_argument('--no_l1_loss', action='store_true', help='no l1 loss') 47 | parser.add_argument('--no_G_local_loss', action='store_true', help='not using local transfer loss for local generator output') 48 | 49 | self.isTrain = True 50 | return parser 51 | -------------------------------------------------------------------------------- /preprocess/combine_A_and_B.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 12 | args = parser.parse_args() 13 | 14 | for arg in vars(args): 15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | img_fold_A = os.path.join(args.fold_A, sp) 21 | img_fold_B = os.path.join(args.fold_B, sp) 22 | img_list = os.listdir(img_fold_A) 23 | if args.use_AB: 24 | img_list = [img_path for img_path in img_list if '_A.' in img_path] 25 | 26 | num_imgs = min(args.num_imgs, len(img_list)) 27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 28 | img_fold_AB = os.path.join(args.fold_AB, sp) 29 | if not os.path.isdir(img_fold_AB): 30 | os.makedirs(img_fold_AB) 31 | print('split = %s, number of images = %d' % (sp, num_imgs)) 32 | for n in range(num_imgs): 33 | name_A = img_list[n] 34 | path_A = os.path.join(img_fold_A, name_A) 35 | if args.use_AB: 36 | name_B = name_A.replace('_A.', '_B.') 37 | else: 38 | name_B = name_A 39 | path_B = os.path.join(img_fold_B, name_B) 40 | if os.path.isfile(path_A) and os.path.isfile(path_B): 41 | name_AB = name_A 42 | if args.use_AB: 43 | name_AB = name_AB.replace('_A.', '.') # remove _A 44 | path_AB = os.path.join(img_fold_AB, name_AB) 45 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 46 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 47 | im_AB = np.concatenate([im_A, im_B], 1) 48 | cv2.imwrite(path_AB, im_AB) 49 | -------------------------------------------------------------------------------- /preprocess/example/img_1701.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/preprocess/example/img_1701.jpg -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/preprocess/example/img_1701_aligned.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.txt: -------------------------------------------------------------------------------- 1 | 194 248 2 | 314 249 3 | 261 312 4 | 209 368 5 | 302 371 6 | -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_bgmask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/preprocess/example/img_1701_aligned_bgmask.png 
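The `img_1701_aligned.txt` example above shows the landmark format used throughout the repo: five lines of `x y` integer pixel coordinates (two eyes, nose, two mouth corners) in the 512x512 aligned image, written by `write_5pt` in `face_align_512.m` shown next. A minimal, purely illustrative sketch of parsing such a file on the Python side (the path below is hypothetical; the actual loading is done inside the dataset classes):

```python
import numpy as np

def load_facial5point(path):
    # Each line holds "x y"; face_align_512.m notes the values are read as np.int32.
    pts = np.loadtxt(path, dtype=np.int32)
    assert pts.shape == (5, 2), 'expected 5 landmark rows of (x, y)'
    return pts

# Hypothetical example, following the default --lm_dir convention:
# pts = load_facial5point('dataset/landmark/ALL/img_1701.txt')
```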
-------------------------------------------------------------------------------- /preprocess/example/img_1701_facial5point.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/preprocess/example/img_1701_facial5point.mat -------------------------------------------------------------------------------- /preprocess/face_align_512.m: -------------------------------------------------------------------------------- 1 | function [trans_img,trans_facial5point]=face_align_512(impath,facial5point,savedir) 2 | % align the faces by similarity transformation. 3 | % using 5 facial landmarks: 2 eyes, nose, 2 mouth corners. 4 | % impath: path to image 5 | % facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN 6 | % savedir: savedir for cropped image and transformed facial landmarks 7 | 8 | %% alignment settings 9 | imgSize = [512,512]; 10 | coord5point = [180,230; 11 | 300,230; 12 | 240,301; 13 | 186,365.6; 14 | 294,365.6];%480x480 15 | coord5point = (coord5point-240)/560 * 512 + 256; 16 | 17 | %% face alignment 18 | 19 | % load and align, resize image to imgSize 20 | img = imread(impath); 21 | facial5point = double(facial5point); 22 | transf = cp2tform(facial5point, coord5point, 'similarity'); 23 | trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],... 24 | 'YData', [1 imgSize(1)],... 25 | 'Size', imgSize,... 26 | 'FillValues', [255;255;255]); 27 | trans_facial5point = round(tformfwd(transf,facial5point)); 28 | 29 | 30 | %% save results 31 | if ~exist(savedir,'dir') 32 | mkdir(savedir) 33 | end 34 | [~,name,~] = fileparts(impath); 35 | % save trans_img 36 | imwrite(trans_img, fullfile(savedir,[name,'_aligned.png'])); 37 | fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_aligned.png'])); 38 | % save trans_facial5point 39 | write_5pt(fullfile(savedir, [name, '_aligned.txt']), trans_facial5point); 40 | fprintf('write transformed facial landmark to %s\n',fullfile(savedir,[name,'_aligned.txt'])); 41 | 42 | %% show results 43 | imshow(trans_img); hold on; 44 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'b'); 45 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+'); 46 | 47 | end 48 | 49 | function [] = write_5pt(fn, trans_pt) 50 | fid = fopen(fn, 'w'); 51 | for i = 1:5 52 | fprintf(fid, '%d %d\n', trans_pt(i,1), trans_pt(i,2));%will be read as np.int32 53 | end 54 | fclose(fid); 55 | end -------------------------------------------------------------------------------- /preprocess/readme.md: -------------------------------------------------------------------------------- 1 | ## Preprocessing steps 2 | 3 | Face photos (and paired drawings) need to be aligned and have their background masks detected. Aligned images, facial landmark files (txt) and background masks are needed for training and testing. 4 | 5 | ### 1. Align, resize, crop images to 512x512 and prepare facial landmarks 6 | 7 | All training and testing images in our model are aligned using facial landmarks, and the landmarks after alignment are needed by our code. 8 | 9 | - First, 5 facial landmarks need to be detected for each face photo (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)). 10 | 11 | - Then, we provide a MATLAB function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512. Call this function in MATLAB to align each image.
12 | For example, for `img_1701.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/img_1701_facial5point.mat`. Call the following in MATLAB: 13 | ```bash 14 | load('example/img_1701_facial5point.mat'); 15 | [trans_img,trans_facial5point]=face_align_512('example/img_1701.jpg',facial5point,'example'); 16 | ``` 17 | 18 | This will align the image and output the aligned image and the transformed facial landmarks (in txt format) in the `example` folder. 19 | See `face_align_512.m` for more instructions. 20 | 21 | - The saved transformed facial landmarks need to be copied to `lm_dir` (see [base flags](../options/base_options.py), default is `dataset/landmark/ALL`), and must have the **same filename** as the aligned face photos (e.g. `dataset/data/test_single/31.png` should have landmark file `dataset/landmark/ALL/31.txt`). 22 | 23 | ### 2. Prepare background masks 24 | 25 | Background masks are needed in our code. 26 | 27 | In our work, the background mask is segmented by the method in 28 | "Automatic Portrait Segmentation for Image Stylization" 29 | Xiaoyong Shen, Aaron Hertzmann, Jiaya Jia, Sylvain Paris, Brian Price, Eli Shechtman, Ian Sachs. Computer Graphics Forum, 35(2) (Proc. Eurographics), 2016. 30 | 31 | - We use the code at http://xiaoyongshen.me/webpage_portrait/index.html to detect background masks for face photos. 32 | A sample background mask is shown in `example/img_1701_aligned_bgmask.png`. 33 | 34 | - The background masks need to be copied to `bg_dir` (see [base flags](../options/base_options.py), default is `dataset/mask/ALL`), and must have the **same filename** as the aligned face photos (e.g. `dataset/data/test_single/31.png` should have background mask `dataset/mask/ALL/31.png`). 35 | 36 | 37 | ### 3. (For training) Prepare more training data 38 | 39 | We provide a Python script to generate training data in the form of pairs of images {A,B}, i.e. pairs {face photo, drawing}. This script concatenates each pair of images horizontally into one single image, from which we can learn to translate A to B: 40 | 41 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `test`, etc. In `/path/to/data/A/train`, put training face photos. In `/path/to/data/B/train`, put the corresponding artist drawings. Repeat the same for `test`. 42 | 43 | Corresponding images in a pair {A,B} must both be aligned, of size 512x512, and have the same filename, e.g., `/path/to/data/A/train/1.png` is considered to correspond to `/path/to/data/B/train/1.png`. 44 | 45 | Once the data is formatted this way, call: 46 | ```bash 47 | python preprocess/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data 48 | ``` 49 | 50 | This will combine each pair of images (A,B) into a single image file, ready for training. -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | # APDrawingGAN 3 | 4 | We provide PyTorch implementations for our CVPR 2019 paper "APDrawingGAN: Generating Artistic Portrait Drawings from Face Photos with Hierarchical GANs". 5 | 6 | This project generates artistic portrait drawings from face photos using a GAN-based model. 7 | You may find useful information in [preprocessing steps](preprocess/readme.md) and [training/testing tips](docs/tips.md).
8 | 9 | [[Paper]](http://openaccess.thecvf.com/content_CVPR_2019/html/Yi_APDrawingGAN_Generating_Artistic_Portrait_Drawings_From_Face_Photos_With_Hierarchical_CVPR_2019_paper.html) [[Demo]](https://apdrawing.github.io/) 10 | 11 | [[Jittor implementation]](https://github.com/yiranran/APDrawingGAN-Jittor) 12 | 13 | 14 | ## Our Proposed Framework 15 | 16 | (framework architecture figure: `imgs/architecture.png`) 17 | 18 | ## Sample Results 19 | Up: input, Down: output 20 |

21 | (sample result images: input photos `imgs/samples/img_*.png` paired with output drawings `imgs/samples/img_*_fake_B.png`)

34 | 35 | ## Citation 36 | If you use this code for your research, please cite our paper. 37 | ``` 38 | @inproceedings{YiLLR19, 39 | title = {{APDrawingGAN}: Generating Artistic Portrait Drawings from Face Photos with Hierarchical GANs}, 40 | author = {Yi, Ran and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L}, 41 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition (CVPR '19)}, 42 | pages = {10743--10752}, 43 | year = {2019} 44 | } 45 | ``` 46 | 47 | ## Prerequisites 48 | - Linux or macOS 49 | - Python 2.7 50 | - CPU or NVIDIA GPU + CUDA CuDNN 51 | 52 | 53 | ## Getting Started 54 | ### Installation 55 | - Install PyTorch 0.4+ and torchvision from http://pytorch.org and other dependencies (e.g., [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate)). You can install all the dependencies by 56 | ```bash 57 | pip install -r requirements.txt 58 | ``` 59 | 60 | ### Quick Start (Apply a Pre-trained Model) 61 | 62 | - Download a pre-trained model (using 70 pairs in the training set and augmented data) from https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingGAN-Models1.zip (Model1) and put it in `checkpoints/formal_author`. 63 | 64 | - Then generate artistic portrait drawings for the example photos in `dataset/data/test_single` using 65 | ```bash 66 | python test.py --dataroot dataset/data/test_single --name formal_author --model test --dataset_mode single --norm batch --use_local --which_epoch 300 67 | ``` 68 | The test results will be saved to an HTML file here: `./results/formal_author/test_300/index.html`. 69 | - If you want to test on your own data, please first align your pictures and prepare your data's facial landmarks and background masks according to the tutorial in [preprocessing steps](preprocess/readme.md), then run 70 | ```bash 71 | python test.py --dataroot {path_to_aligned_photos} --name formal_author --model test --dataset_mode single --norm batch --use_local --which_epoch 300 72 | ``` 73 | - We also provide an online demo at https://face.lol (optimized, using 120 pairs for training), which will be easier to use if you want to test more photos. 74 | 75 | ### Train 76 | - Download our [APDrawing dataset](https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingDB.zip) and copy its content to the `dataset` folder 77 | - Download the pre-training and auxiliary network models (for fast distance transform and line detection) from https://cg.cs.tsinghua.edu.cn/people/~Yongjin/APDrawingGAN-Models2.zip (Model2). 78 | - Run `python -m visdom.server` 79 | - Train a model (with pre-training as initialization): 80 | first copy the "pre-training" models into the checkpoints dir of the current experiment (`checkpoints/[name]`, e.g. `checkpoints/formal`), and copy the "auxiliary" models into `checkpoints/auxiliary`. 81 | ```bash 82 | python train.py --dataroot dataset/data --name formal --continue_train --use_local --discriminator_local --niter 300 --niter_decay 0 --save_epoch_freq 25 83 | ``` 84 | - Train a model (without initialization): 85 | first copy the auxiliary network models into `checkpoints/auxiliary`. 86 | ```bash 87 | python train.py --dataroot dataset/data --name formal_noinit --use_local --discriminator_local --niter 300 --niter_decay 0 --save_epoch_freq 25 88 | ``` 89 | - To view training results and loss plots, click the URL http://localhost:8097.
To see more intermediate results, check out `./checkpoints/formal/web/index.html`. 90 | 91 | ### Test 92 | - Test the model on the test set: 93 | ```bash 94 | python test.py --dataroot dataset/data --name formal --use_local --which_epoch 250 95 | ``` 96 | The test results will be saved to an HTML file here: `./results/formal/test_250/index.html`. 97 | - Test the model on images without paired ground truth (please use `--model test`, `--dataset_mode single` and `--norm batch`): 98 | ```bash 99 | python test.py --dataroot dataset/data/test_single --name formal --model test --dataset_mode single --norm batch --use_local --which_epoch 250 100 | ``` 101 | 102 | You can find these scripts in the `script` directory. 103 | 104 | 105 | ## [Preprocessing Steps](preprocess/readme.md) 106 | Preprocessing steps for your own data (either for testing or training). 107 | 108 | 109 | ## [Training/Test Tips](docs/tips.md) 110 | Best practice for training and testing your models. 111 | 112 | You can contact ranyi@sjtu.edu.cn with any questions. 113 | 114 | ## Acknowledgments 115 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 116 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | dominate>=2.3.1 4 | visdom>=0.1.8.3 5 | scipy>=1.1.0 6 | numpy>=1.14.1 7 | Pillow>=5.0.0 8 | opencv-python>=3.4.2 -------------------------------------------------------------------------------- /script/test.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot dataset/data --name formal --use_local --which_epoch 250 -------------------------------------------------------------------------------- /script/test_pretrained.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot dataset/data --name formal_author --use_local --which_epoch 300 -------------------------------------------------------------------------------- /script/test_single.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot dataset/data/test_single --name formal_author --model test --dataset_mode single --norm batch --use_local --which_epoch 300 -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot dataset/data --name formal --continue_train --use_local --discriminator_local --niter 300 --niter_decay 0 --save_epoch_freq 25 -------------------------------------------------------------------------------- /script/train_noinit.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot dataset/data --name formal_noinit --use_local --discriminator_local --niter 300 --niter_decay 0 --save_epoch_freq 25 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | 8 | 9 | if __name__
== '__main__': 10 | opt = TestOptions().parse() 11 | opt.num_threads = 1 # test code only supports num_threads = 1 12 | opt.batch_size = 1 # test code only supports batch_size = 1 13 | opt.serial_batches = True # no shuffle 14 | opt.no_flip = True # no flip 15 | opt.display_id = -1 # no visdom display 16 | data_loader = CreateDataLoader(opt) 17 | dataset = data_loader.load_data() 18 | model = create_model(opt) 19 | model.setup(opt) 20 | # create website 21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 23 | # test 24 | for i, data in enumerate(dataset): 25 | if i >= opt.how_many:#test code only supports batch_size = 1, how_many means how many test images to run 26 | break 27 | model.set_input(data) 28 | model.test() 29 | visuals = model.get_current_visuals()#in test the loadSize is set to the same as fineSize 30 | img_path = model.get_image_paths() 31 | if i % 5 == 0: 32 | print('processing (%04d)-th image... %s' % (i, img_path)) 33 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 34 | 35 | webpage.save() 36 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | if __name__ == '__main__': 8 | start = time.time() 9 | opt = TrainOptions().parse() 10 | data_loader = CreateDataLoader(opt) 11 | dataset = data_loader.load_data() 12 | dataset_size = len(data_loader) 13 | print('#training images = %d' % dataset_size) 14 | 15 | model = create_model(opt) 16 | model.setup(opt) 17 | visualizer = Visualizer(opt) 18 | total_steps = 0 19 | 20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 21 | epoch_start_time = time.time() 22 | iter_data_time = time.time() 23 | epoch_iter = 0 24 | 25 | for i, data in enumerate(dataset): 26 | iter_start_time = time.time() 27 | if total_steps % opt.print_freq == 0: 28 | t_data = iter_start_time - iter_data_time 29 | visualizer.reset() 30 | total_steps += opt.batch_size 31 | epoch_iter += opt.batch_size 32 | model.set_input(data) 33 | model.optimize_parameters() 34 | 35 | if total_steps % opt.display_freq == 0: 36 | save_result = total_steps % opt.update_html_freq == 0 37 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 38 | 39 | if total_steps % opt.print_freq == 0: 40 | losses = model.get_current_losses() 41 | t = (time.time() - iter_start_time) / opt.batch_size 42 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 43 | if opt.display_id > 0: 44 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 45 | 46 | if total_steps % opt.save_latest_freq == 0: 47 | print('saving the latest model (epoch %d, total_steps %d)' % 48 | (epoch, total_steps)) 49 | #model.save_networks('latest') 50 | model.save_networks2('latest') 51 | 52 | iter_data_time = time.time() 53 | if epoch % opt.save_epoch_freq == 0: 54 | print('saving the model at the end of epoch %d, iters %d' % 55 | (epoch, total_steps)) 56 | #model.save_networks('latest') 57 | #model.save_networks(epoch) 58 | model.save_networks2('latest') 59 | model.save_networks2(epoch) 60 | 61 | print('End of 
epoch %d / %d \t Time Taken: %d sec' %
62 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
63 | model.update_learning_rate()
64 | 
65 | print('Total Time Taken: %d sec' % (time.time() - start))
66 | 
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yiranran/APDrawingGAN/38f4319f8e724f6bef5a32c348a8c0967baad773/util/__init__.py
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 | 
5 | 
6 | class HTML:
7 | def __init__(self, web_dir, title, reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 | 
17 | self.doc = dominate.document(title=title)
18 | if reflesh > 0:
19 | with self.doc.head:
20 | meta(http_equiv="refresh", content=str(reflesh))  # "refresh" is the valid http-equiv value
21 | 
22 | def get_image_dir(self):
23 | return self.img_dir
24 | 
25 | def add_header(self, text):
26 | with self.doc:
27 | h3(text)
28 | 
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 | 
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 | 
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 | 
51 | 
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 | 
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 | 
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | 
4 | 
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 | 
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images:
17 | image = torch.unsqueeze(image.data, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size - 1)  # randint is inclusive
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = torch.cat(return_images, 0)
32 | return return_images
33 | 
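A minimal usage sketch for ImagePool (illustrative only, not code from this repository): a GAN training step typically queries the pool with freshly generated images so the discriminator also sees samples from earlier iterations; with pool_size == 0 the query is a pass-through. The names netG, netD, real_A and criterionGAN below are hypothetical stand-ins for the model's own networks and loss.

from util.image_pool import ImagePool

fake_pool = ImagePool(pool_size=50)                  # keeps up to 50 previously generated images
fake_B = netG(real_A)                                # current generator output (hypothetical names)
fake_for_D = fake_pool.query(fake_B.detach())        # roughly 50/50 mix of current and buffered fakes
loss_D_fake = criterionGAN(netD(fake_for_D), False)  # discriminator loss on (possibly older) fakes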
-------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | 22 | def diagnose_network(net, name='network'): 23 | mean = 0.0 24 | count = 0 25 | for param in net.parameters(): 26 | if param.grad is not None: 27 | mean += torch.mean(torch.abs(param.grad.data)) 28 | count += 1 29 | if count > 0: 30 | mean = mean / count 31 | print(name) 32 | print(mean) 33 | 34 | 35 | def save_image(image_numpy, image_path): 36 | image_pil = Image.fromarray(image_numpy) 37 | image_pil.save(image_path) 38 | 39 | 40 | def print_numpy(x, val=True, shp=False): 41 | x = x.astype(np.float64) 42 | if shp: 43 | print('shape,', x.shape) 44 | if val: 45 | x = x.flatten() 46 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 47 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 48 | 49 | 50 | def mkdirs(paths): 51 | if isinstance(paths, list) and not isinstance(paths, str): 52 | for path in paths: 53 | mkdir(path) 54 | else: 55 | mkdir(paths) 56 | 57 | 58 | def mkdir(path): 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . 
import html 7 | #from scipy.misc import imresize 8 | from PIL import Image 9 | 10 | 11 | # save image to the disk 12 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 13 | image_dir = webpage.get_image_dir() 14 | short_path = ntpath.basename(image_path[0]) 15 | name = os.path.splitext(short_path)[0] 16 | 17 | webpage.add_header(name) 18 | ims, txts, links = [], [], [] 19 | 20 | for label, im_data in visuals.items(): 21 | im = util.tensor2im(im_data)#tensor to numpy array [-1,1]->[0,1]->[0,255] 22 | image_name = '%s_%s.png' % (name, label) 23 | save_path = os.path.join(image_dir, image_name) 24 | h, w, _ = im.shape 25 | if aspect_ratio > 1.0: 26 | #im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 27 | im = np.array(Image.fromarray(im).resize((int(w*aspect_ratio),h), Image.BICUBIC)) 28 | if aspect_ratio < 1.0: 29 | #im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 30 | im = np.array(Image.fromarray(im).resize((w,int(h/aspect_ratio)), Image.BICUBIC)) 31 | util.save_image(im, save_path) 32 | 33 | ims.append(image_name) 34 | txts.append(label) 35 | links.append(image_name) 36 | webpage.add_images(ims, txts, links, width=width) 37 | 38 | 39 | class Visualizer(): 40 | def __init__(self, opt): 41 | self.display_id = opt.display_id 42 | self.use_html = opt.isTrain and not opt.no_html 43 | self.win_size = opt.display_winsize 44 | self.name = opt.name 45 | self.opt = opt 46 | self.saved = False 47 | if self.display_id > 0: 48 | import visdom 49 | self.ncols = opt.display_ncols 50 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 51 | 52 | if self.use_html: 53 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 54 | self.img_dir = os.path.join(self.web_dir, 'images') 55 | print('create web directory %s...' % self.web_dir) 56 | util.mkdirs([self.web_dir, self.img_dir]) 57 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 58 | with open(self.log_name, "a") as log_file: 59 | now = time.strftime("%c") 60 | log_file.write('================ Training Loss (%s) ================\n' % now) 61 | 62 | def reset(self): 63 | self.saved = False 64 | 65 | def throw_visdom_connection_error(self): 66 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. 
To install visdom, run\n$ pip install visdom\nand start the server with\n$ python -m visdom.server\n\n')
67 | exit(1)
68 | 
69 | # |visuals|: dictionary of images to display or save
70 | def display_current_results(self, visuals, epoch, save_result):
71 | if self.display_id > 0:  # show images in the browser
72 | ncols = self.ncols
73 | if ncols > 0:
74 | ncols = min(ncols, len(visuals))
75 | h, w = next(iter(visuals.values())).shape[:2]
76 | table_css = """<style>
77 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
78 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
79 | </style>""" % (w, h)
80 | title = self.name
81 | label_html = ''
82 | label_html_row = ''
83 | images = []
84 | idx = 0
85 | for label, image in visuals.items():
86 | image_numpy = util.tensor2im(image)
87 | label_html_row += '<td>%s</td>' % label
88 | images.append(image_numpy.transpose([2, 0, 1]))
89 | idx += 1
90 | if idx % ncols == 0:
91 | label_html += '<tr>%s</tr>' % label_html_row
92 | label_html_row = ''
93 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
94 | while idx % ncols != 0:
95 | images.append(white_image)
96 | label_html_row += '<td></td>'
97 | idx += 1
98 | if label_html_row != '':
99 | label_html += '<tr>%s</tr>' % label_html_row
100 | # pane col = image row
101 | try:
102 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
103 | padding=2, opts=dict(title=title + ' images'))
104 | label_html = '<table>%s</table>' % label_html
105 | self.vis.text(table_css + label_html, win=self.display_id + 2,
106 | opts=dict(title=title + ' labels'))
107 | except ConnectionError:
108 | self.throw_visdom_connection_error()
109 | 
110 | else:
111 | idx = 1
112 | for label, image in visuals.items():
113 | image_numpy = util.tensor2im(image)
114 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
115 | win=self.display_id + idx)
116 | idx += 1
117 | 
118 | if self.use_html and (save_result or not self.saved):  # save images to an HTML file
119 | self.saved = True
120 | for label, image in visuals.items():
121 | image_numpy = util.tensor2im(image)
122 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
123 | util.save_image(image_numpy, img_path)
124 | # update website
125 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
126 | for n in range(epoch, 0, -1):
127 | webpage.add_header('epoch [%d]' % n)
128 | ims, txts, links = [], [], []
129 | 
130 | for label, image in visuals.items():
131 | image_numpy = util.tensor2im(image)
132 | img_path = 'epoch%.3d_%s.png' % (n, label)
133 | ims.append(img_path)
134 | txts.append(label)
135 | links.append(img_path)
136 | webpage.add_images(ims, txts, links, width=self.win_size)
137 | webpage.save()
138 | 
139 | def save_current_results1(self, visuals, epoch, epoch_iter):
140 | if not os.path.exists(self.img_dir + '/detailed'):
141 | os.mkdir(self.img_dir + '/detailed')
142 | for label, image in visuals.items():
143 | image_numpy = util.tensor2im(image)
144 | img_path = os.path.join(self.img_dir, 'detailed', 'epoch%.3d_%.3d_%s.png' % (epoch, epoch_iter, label))
145 | util.save_image(image_numpy, img_path)
146 | 
147 | # losses: dictionary of error labels and values
148 | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
149 | if not hasattr(self, 'plot_data'):
150 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
151 | self.plot_data['X'].append(epoch + counter_ratio)
152 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
153 | try:
154 | self.vis.line(
155 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
156 | Y=np.array(self.plot_data['Y']),
157 | opts={
158 | 'title': self.name + ' loss over time',
159 | 'legend': self.plot_data['legend'],
160 | 'xlabel': 'epoch',
161 | 'ylabel': 'loss'},
162 | win=self.display_id)
163 | except ConnectionError:
164 | self.throw_visdom_connection_error()
165 | 
166 | # losses: same format as |losses| of plot_current_losses
167 | def print_current_losses(self, epoch, i, losses, t, t_data):
168 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
169 | for k, v in losses.items():
170 | message += '%s: %.6f ' % (k, v)
171 | 
172 | print(message)
173 | with open(self.log_name, "a") as log_file:
174 | log_file.write('%s\n' % message)
175 | 
--------------------------------------------------------------------------------
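For reference, a small self-contained sketch (not part of the repository) of the tensor-to-image path the visualizer relies on: util.tensor2im maps a [-1, 1] NCHW tensor to an HxWx3 uint8 array via (x + 1) / 2 * 255, and util.save_image writes it with PIL. The random tensor and output filename below are placeholders.

import torch
from util import util

fake = torch.rand(1, 3, 256, 256) * 2 - 1   # NCHW tensor in [-1, 1], like a tanh generator output
im = util.tensor2im(fake)                   # 256x256x3 uint8 array in [0, 255]
util.save_image(im, 'sample.png')           # written via PIL.Image.fromarray(...).save(...)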