├── .gitignore ├── data ├── __init__.py ├── aligned_dataset.py ├── base_data_loader.py ├── base_dataset.py ├── image_folder.py └── single_dataset.py ├── docs └── tips.md ├── imgs ├── architecture-pami.jpg └── sample │ ├── 140_large-img_1615_fake_B.png │ ├── 140_large-img_1615_real_A.png │ ├── 140_large-img_1616_fake_B.png │ ├── 140_large-img_1616_real_A.png │ ├── 140_large-img_1673_fake_B.png │ ├── 140_large-img_1673_real_A.png │ ├── 140_large-img_1684_fake_B.png │ ├── 140_large-img_1684_real_A.png │ ├── 140_large-img_1696_fake_B.png │ ├── 140_large-img_1696_real_A.png │ ├── 140_large-img_1701_fake_B.png │ └── 140_large-img_1701_real_A.png ├── models ├── __init__.py ├── apdrawingpp_style_model.py ├── base_model.py ├── networks.py └── test_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── preprocess ├── combine_A_and_B.py ├── example │ ├── img_1701.jpg │ ├── img_1701_aligned.png │ ├── img_1701_aligned.txt │ ├── img_1701_aligned_68lm.txt │ ├── img_1701_aligned_bgmask.png │ ├── img_1701_aligned_eyelmask.png │ ├── img_1701_aligned_eyermask.png │ ├── img_1701_aligned_facemask.png │ ├── img_1701_aligned_mouthmask.png │ ├── img_1701_aligned_nosemask.png │ └── img_1701_facial5point.mat ├── face_align_512.m ├── get_partmask.py └── readme.md ├── readme.md ├── requirements.txt ├── script ├── test.sh ├── test_single.sh └── train.sh ├── test.py ├── train.py └── util ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | debug* 3 | datasets/*/*/*.jpg 4 | datasets/*/*/*.txt 5 | dataset/ 6 | checkpoints/ 7 | results/ 8 | build/ 9 | dist/ 10 | torch.egg-info/ 11 | */**/__pycache__ 12 | torch/version.py 13 | torch/csrc/generic/TensorMethods.cpp 14 | torch/lib/*.so* 15 | torch/lib/*.dylib* 16 | torch/lib/*.h 17 | torch/lib/build 18 | torch/lib/tmp_install 19 | torch/lib/include 20 | torch/lib/torch_shm_manager 21 | torch/csrc/cudnn/cuDNN.cpp 22 | torch/csrc/nn/THNN.cwrap 23 | torch/csrc/nn/THNN.cpp 24 | torch/csrc/nn/THCUNN.cwrap 25 | torch/csrc/nn/THCUNN.cpp 26 | torch/csrc/nn/THNN_generic.cwrap 27 | torch/csrc/nn/THNN_generic.cpp 28 | torch/csrc/nn/THNN_generic.h 29 | docs/src/**/* 30 | test/data/legacy_modules.t7 31 | test/data/gpu_tensors.pt 32 | test/htmlcov 33 | test/.coverage 34 | */*.pyc 35 | */**/*.pyc 36 | */**/**/*.pyc 37 | */**/**/**/*.pyc 38 | */**/**/**/**/*.pyc 39 | */*.so* 40 | */**/*.so* 41 | */**/*.dylib* 42 | test/data/legacy_serialized.pt 43 | *~ 44 | .idea 45 | *.zip -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | from data.base_dataset import BaseDataset 5 | 6 | 7 | def find_dataset_using_name(dataset_name): 8 | # Given the option --dataset_mode [datasetname], 9 | # the file "data/datasetname_dataset.py" 10 | # will be imported. 11 | dataset_filename = "data." + dataset_name + "_dataset" 12 | datasetlib = importlib.import_module(dataset_filename) 13 | 14 | # In the file, the class called DatasetNameDataset() will 15 | # be instantiated. It has to be a subclass of BaseDataset, 16 | # and it is case-insensitive. 
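# For example, --dataset_mode aligned imports data/aligned_dataset.py and
# selects the class AlignedDataset, since "aligneddataset" matches the
# lowercased target name; --dataset_mode single resolves the same way to
# data/single_dataset.py and SingleDataset.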
17 | dataset = None 18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 19 | for name, cls in datasetlib.__dict__.items(): 20 | if name.lower() == target_dataset_name.lower() \ 21 | and issubclass(cls, BaseDataset): 22 | dataset = cls 23 | 24 | if dataset is None: 25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 26 | exit(0) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataset(opt): 37 | dataset = find_dataset_using_name(opt.dataset_mode) 38 | instance = dataset() 39 | instance.initialize(opt) 40 | print("dataset [%s] was created" % (instance.name())) 41 | return instance 42 | 43 | 44 | def CreateDataLoader(opt): 45 | data_loader = CustomDatasetDataLoader() 46 | data_loader.initialize(opt) 47 | return data_loader 48 | 49 | 50 | # Wrapper class of Dataset class that performs 51 | # multi-threaded data loading 52 | class CustomDatasetDataLoader(BaseDataLoader): 53 | def name(self): 54 | return 'CustomDatasetDataLoader' 55 | 56 | def initialize(self, opt): 57 | BaseDataLoader.initialize(self, opt) 58 | self.dataset = create_dataset(opt) 59 | self.dataloader = torch.utils.data.DataLoader( 60 | self.dataset, 61 | batch_size=opt.batch_size, 62 | shuffle=not opt.serial_batches,#in training, serial_batches by default is false, shuffle=true 63 | num_workers=int(opt.num_threads)) 64 | 65 | def load_data(self): 66 | return self 67 | 68 | def __len__(self): 69 | return min(len(self.dataset), self.opt.max_dataset_size) 70 | 71 | def __iter__(self): 72 | for i, data in enumerate(self.dataloader): 73 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 74 | break 75 | yield data 76 | -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | import numpy as np 9 | import cv2 10 | import csv 11 | 12 | def getfeats(featpath): 13 | trans_points = np.empty([5,2],dtype=np.int64) 14 | with open(featpath, 'r') as csvfile: 15 | reader = csv.reader(csvfile, delimiter=' ') 16 | for ind,row in enumerate(reader): 17 | trans_points[ind,:] = row 18 | return trans_points 19 | 20 | def tocv2(ts): 21 | img = (ts.numpy()/2+0.5)*255 22 | img = img.astype('uint8') 23 | img = np.transpose(img,(1,2,0)) 24 | img = img[:,:,::-1]#rgb->bgr 25 | return img 26 | 27 | def dt(img): 28 | if(img.shape[2]==3): 29 | img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY) 30 | #convert to BW 31 | ret1,thresh1 = cv2.threshold(img,127,255,cv2.THRESH_BINARY) 32 | ret2,thresh2 = cv2.threshold(img,127,255,cv2.THRESH_BINARY_INV) 33 | dt1 = cv2.distanceTransform(thresh1,cv2.DIST_L2,5) 34 | dt2 = cv2.distanceTransform(thresh2,cv2.DIST_L2,5) 35 | dt1 = dt1/dt1.max()#->[0,1] 36 | dt2 = dt2/dt2.max() 37 | return dt1, dt2 38 | 39 | def getSoft(size,xb,yb,boundwidth=5.0): 40 | xarray = np.tile(np.arange(0,size[1]),(size[0],1)) 41 | yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose() 42 | cxdists = [] 43 | cydists = [] 44 | for i in range(len(xb)): 45 | xba = np.tile(xb[i],(size[1],1)).transpose() 46 | yba = 
np.tile(yb[i],(size[0],1)) 47 | cxdists.append(np.abs(xarray-xba)) 48 | cydists.append(np.abs(yarray-yba)) 49 | xdist = np.minimum.reduce(cxdists) 50 | ydist = np.minimum.reduce(cydists) 51 | manhdist = np.minimum.reduce([xdist,ydist]) 52 | im = (manhdist+1) / (boundwidth+1) * 1.0 53 | im[im>=1.0] = 1.0 54 | return im 55 | 56 | class AlignedDataset(BaseDataset): 57 | @staticmethod 58 | def modify_commandline_options(parser, is_train): 59 | return parser 60 | 61 | def initialize(self, opt): 62 | self.opt = opt 63 | self.root = opt.dataroot 64 | imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot) 65 | if os.path.exists(imglist): 66 | lines = open(imglist, 'r').read().splitlines() 67 | lines = sorted(lines) 68 | self.AB_paths = [line.split()[0] for line in lines] 69 | if len(lines[0].split()) == 2: 70 | self.B_paths = [line.split()[1] for line in lines] 71 | else: 72 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 73 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 74 | assert(opt.resize_or_crop == 'resize_and_crop') 75 | 76 | def __getitem__(self, index): 77 | AB_path = self.AB_paths[index] 78 | AB = Image.open(AB_path).convert('RGB') 79 | w, h = AB.size 80 | if w/h == 2: 81 | w2 = int(w / 2) 82 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 83 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 84 | else: # if w/h != 2, need B_paths 85 | A = AB.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 86 | B = Image.open(self.B_paths[index]).convert('RGB') 87 | B = B.resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 88 | A = transforms.ToTensor()(A) 89 | B = transforms.ToTensor()(B) 90 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 91 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 92 | 93 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]#C,H,W 94 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 95 | 96 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) 97 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) 98 | 99 | if self.opt.which_direction == 'BtoA': 100 | input_nc = self.opt.output_nc 101 | output_nc = self.opt.input_nc 102 | else: 103 | input_nc = self.opt.input_nc 104 | output_nc = self.opt.output_nc 105 | 106 | flipped = False 107 | if (not self.opt.no_flip) and random.random() < 0.5: 108 | flipped = True 109 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 110 | idx = torch.LongTensor(idx) 111 | A = A.index_select(2, idx) 112 | B = B.index_select(2, idx) 113 | 114 | if input_nc == 1: # RGB to gray 115 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 116 | A = tmp.unsqueeze(0) 117 | 118 | if output_nc == 1: # RGB to gray 119 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 120 | B = tmp.unsqueeze(0) 121 | 122 | item = {'A': A, 'B': B, 123 | 'A_paths': AB_path, 'B_paths': AB_path} 124 | 125 | if self.opt.use_local: 126 | regions = ['eyel','eyer','nose','mouth'] 127 | basen = os.path.basename(AB_path)[:-4]+'.txt' 128 | if self.opt.region_enm in [0,1]: 129 | featdir = self.opt.lm_dir 130 | featpath = os.path.join(featdir,basen) 131 | feats = getfeats(featpath) 132 | if flipped: 133 | for i in range(5): 134 | feats[i,0] = self.opt.fineSize - feats[i,0] - 1 135 | tmp = [feats[0,0],feats[0,1]] 136 | feats[0,:] = [feats[1,0],feats[1,1]] 137 | feats[1,:] = tmp 138 | mouth_x = int((feats[3,0]+feats[4,0])/2.0) 139 | mouth_y = int((feats[3,1]+feats[4,1])/2.0) 140 | ratio = self.opt.fineSize / 256 141 | EYE_H = self.opt.EYE_H * ratio 142 | EYE_W = self.opt.EYE_W * ratio 143 | NOSE_H = self.opt.NOSE_H * ratio 144 | NOSE_W = self.opt.NOSE_W * ratio 145 | MOUTH_H = self.opt.MOUTH_H * ratio 146 | MOUTH_W = self.opt.MOUTH_W * ratio 147 | center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]]) 148 | item['center'] = center 149 | rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)] 150 | rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)] 151 | if self.opt.soft_border: 152 | soft_border_mask4 = [] 153 | for i in range(4): 154 | xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)] 155 | yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)] 156 | soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb) 157 | soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0)) 158 | item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i] 159 | for i in range(4): 160 | item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] 161 | item[regions[i]+'_B'] = B[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] 162 | if self.opt.soft_border: 163 | item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1) 164 | item[regions[i]+'_B'] = item[regions[i]+'_B'] * soft_border_mask4[i] 165 | if self.opt.compactmask: 166 | cmasks0 = [] 167 | cmasks = [] 168 | for i in range(4): 169 | if flipped and i in [0,1]: 170 | cmaskpath = os.path.join(self.opt.cmask_dir,regions[1-i],basen[:-4]+'.png') 171 | else: 172 | cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png') 173 | im_cmask = Image.open(cmaskpath) 174 | cmask0 = transforms.ToTensor()(im_cmask) 175 | if flipped: 176 | cmask0 = cmask0.index_select(2, idx) 177 | if output_nc == 1 and cmask0.shape[0] == 3: 178 | tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] 
* 0.114 179 | cmask0 = tmp.unsqueeze(0) 180 | cmask0 = (cmask0 >= 0.5).float() 181 | cmasks0.append(cmask0) 182 | cmask = cmask0.clone() 183 | if self.opt.region_enm in [0,1]: 184 | cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] 185 | elif self.opt.region_enm in [2]: # need to multiply cmask 186 | item[regions[i]+'_A'] = (A/2+0.5) * cmask * 2 - 1 187 | item[regions[i]+'_B'] = (B/2+0.5) * cmask * 2 - 1 188 | cmasks.append(cmask) 189 | item['cmaskel'] = cmasks[0] 190 | item['cmasker'] = cmasks[1] 191 | item['cmask'] = cmasks[2] 192 | item['cmaskmo'] = cmasks[3] 193 | if self.opt.hair_local: 194 | mask = torch.ones(B.shape) 195 | if self.opt.region_enm == 0: 196 | for i in range(4): 197 | mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0 198 | if self.opt.soft_border: 199 | imgsize = self.opt.fineSize 200 | maskn = mask[0].numpy() 201 | masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])] 202 | masks[0][1:] = maskn[:-1] 203 | masks[1][:-1] = maskn[1:] 204 | masks[2][:,1:] = maskn[:,:-1] 205 | masks[3][:,:-1] = maskn[:,1:] 206 | masks2 = [maskn-e for e in masks] 207 | bound = np.minimum.reduce(masks2) 208 | bound = -bound 209 | xb = [] 210 | yb = [] 211 | for i in range(4): 212 | xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1] 213 | ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1] 214 | for j in range(2): 215 | maskx = bound[:,xbi[j]] 216 | masky = bound[ybi[j],:] 217 | tmp_a = torch.from_numpy(maskx)*xbi[j].double() 218 | tmp_b = torch.from_numpy(1-maskx) 219 | xb += [tmp_b*10000 + tmp_a] 220 | 221 | tmp_a = torch.from_numpy(masky)*ybi[j].double() 222 | tmp_b = torch.from_numpy(1-masky) 223 | yb += [tmp_b*10000 + tmp_a] 224 | soft = 1-getSoft([imgsize,imgsize],xb,yb) 225 | soft = torch.Tensor(soft).unsqueeze(0) 226 | mask = (torch.ones(mask.shape)-mask)*soft + mask 227 | elif self.opt.region_enm == 1: 228 | for i in range(4): 229 | cmask0 = cmasks0[i] 230 | rec = torch.zeros(B.shape) 231 | rec[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 1 232 | mask = mask * (torch.ones(B.shape) - cmask0 * rec) 233 | elif self.opt.region_enm == 2: 234 | for i in range(4): 235 | cmask0 = cmasks0[i] 236 | mask = mask * (torch.ones(B.shape) - cmask0) 237 | hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1 238 | hair_B = (B/2+0.5) * mask * 2 - 1 239 | item['hair_A'] = hair_A 240 | item['hair_B'] = hair_B 241 | item['mask'] = mask # mask out eyes, nose, mouth 242 | if self.opt.bg_local: 243 | bgdir = self.opt.bg_dir 244 | bgpath = os.path.join(bgdir,basen[:-4]+'.png') 245 | im_bg = Image.open(bgpath) 246 | mask2 = transforms.ToTensor()(im_bg) # mask out background 247 | if flipped: 248 | mask2 = mask2.index_select(2, idx) 249 | mask2 = (mask2 >= 0.5).float() 250 | hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1 251 | hair_B = (B/2+0.5) * mask * mask2 * 2 - 1 252 | bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1 253 | bg_B = (B/2+0.5) * (torch.ones(mask2.shape)-mask2) * 2 - 1 254 | item['hair_A'] = hair_A 255 | item['hair_B'] = hair_B 256 | item['bg_A'] = bg_A 257 | item['bg_B'] = bg_B 258 | item['mask'] = mask 259 | item['mask2'] = mask2 260 | 261 | if (self.opt.isTrain and self.opt.chamfer_loss): 262 | if self.opt.which_direction == 'AtoB': 263 | img = tocv2(B) 264 | 
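# In both branches the drawing side (B for AtoB, A for BtoA) is converted to an
# OpenCV image and passed to dt() defined above, which thresholds it at 127 and
# returns two L2 distance transforms normalized to [0,1] (distance to the dark
# strokes and distance to the bright background); these are stored below as the
# chamfer-loss ground truths dt1gt/dt2gt.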
else: 265 | img = tocv2(A) 266 | dt1, dt2 = dt(img) 267 | dt1 = torch.from_numpy(dt1) 268 | dt2 = torch.from_numpy(dt2) 269 | dt1 = dt1.unsqueeze(0) 270 | dt2 = dt2.unsqueeze(0) 271 | item['dt1gt'] = dt1 272 | item['dt2gt'] = dt2 273 | 274 | if self.opt.isTrain and self.opt.emphasis_conti_face: 275 | face_mask_path = os.path.join(self.opt.facemask_dir,basen[:-4]+'.png') 276 | face_mask = Image.open(face_mask_path) 277 | face_mask = transforms.ToTensor()(face_mask) # [0,1] 278 | if flipped: 279 | face_mask = face_mask.index_select(2, idx) 280 | item['face_mask'] = face_mask 281 | 282 | return item 283 | 284 | def __len__(self): 285 | return len(self.AB_paths) 286 | 287 | def name(self): 288 | return 'AlignedDataset' 289 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.fineSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' 
% opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | # just modify the width and height to be multiple of 4 54 | def __adjust(img): 55 | ow, oh = img.size 56 | 57 | # the size needs to be a multiple of this number, 58 | # because going through generator network may change img size 59 | # and eventually cause size mismatch error 60 | mult = 4 61 | if ow % mult == 0 and oh % mult == 0: 62 | return img 63 | w = (ow - 1) // mult 64 | w = (w + 1) * mult 65 | h = (oh - 1) // mult 66 | h = (h + 1) * mult 67 | 68 | if ow != w or oh != h: 69 | __print_size_warning(ow, oh, w, h) 70 | 71 | return img.resize((w, h), Image.BICUBIC) 72 | 73 | 74 | def __scale_width(img, target_width): 75 | ow, oh = img.size 76 | 77 | # the size needs to be a multiple of this number, 78 | # because going through generator network may change img size 79 | # and eventually cause size mismatch error 80 | mult = 4 81 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 82 | if (ow == target_width and oh % mult == 0): 83 | return img 84 | w = target_width 85 | target_height = int(target_width * oh / ow) 86 | m = (target_height - 1) // mult 87 | h = (m + 1) * mult 88 | 89 | if target_height != h: 90 | __print_size_warning(target_width, target_height, w, h) 91 | 92 | return img.resize((w, h), Image.BICUBIC) 93 | 94 | 95 | def __print_size_warning(ow, oh, w, h): 96 | if not hasattr(__print_size_warning, 'has_printed'): 97 | print("The image size needs to be a multiple of 4. " 98 | "The loaded image size was (%d, %d), so it was adjusted to " 99 | "(%d, %d). 
This adjustment will be done to all images " 100 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 101 | __print_size_warning.has_printed = True 102 | 103 | 104 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import numpy as np 6 | import csv 7 | import torch 8 | import torchvision.transforms as transforms 9 | 10 | def getfeats(featpath): 11 | trans_points = np.empty([5,2],dtype=np.int64) 12 | with open(featpath, 'r') as csvfile: 13 | reader = csv.reader(csvfile, delimiter=' ') 14 | for ind,row in enumerate(reader): 15 | trans_points[ind,:] = row 16 | return trans_points 17 | 18 | def getSoft(size,xb,yb,boundwidth=5.0): 19 | xarray = np.tile(np.arange(0,size[1]),(size[0],1)) 20 | yarray = np.tile(np.arange(0,size[0]),(size[1],1)).transpose() 21 | cxdists = [] 22 | cydists = [] 23 | for i in range(len(xb)): 24 | xba = np.tile(xb[i],(size[1],1)).transpose() 25 | yba = np.tile(yb[i],(size[0],1)) 26 | cxdists.append(np.abs(xarray-xba)) 27 | cydists.append(np.abs(yarray-yba)) 28 | xdist = np.minimum.reduce(cxdists) 29 | ydist = np.minimum.reduce(cydists) 30 | manhdist = np.minimum.reduce([xdist,ydist]) 31 | im = 
(manhdist+1) / (boundwidth+1) * 1.0 32 | im[im>=1.0] = 1.0 33 | return im 34 | 35 | class SingleDataset(BaseDataset): 36 | @staticmethod 37 | def modify_commandline_options(parser, is_train): 38 | return parser 39 | 40 | def initialize(self, opt): 41 | self.opt = opt 42 | self.root = opt.dataroot 43 | self.dir_A = os.path.join(opt.dataroot) 44 | imglist = 'datasets/apdrawing_list/%s/%s.txt' % (opt.phase, opt.dataroot) 45 | if os.path.exists(imglist): 46 | lines = open(imglist, 'r').read().splitlines() 47 | self.A_paths = sorted(lines) 48 | else: 49 | self.A_paths = make_dataset(self.dir_A) 50 | self.A_paths = sorted(self.A_paths) 51 | self.transform = get_transform(opt) # this function uses NO_FLIP; aligned dataset do not use this, aligned dataset manually transform 52 | 53 | def __getitem__(self, index): 54 | A_path = self.A_paths[index] 55 | A_img = Image.open(A_path).convert('RGB') 56 | A = self.transform(A_img) 57 | if self.opt.which_direction == 'BtoA': 58 | input_nc = self.opt.output_nc 59 | output_nc = self.opt.input_nc 60 | else: 61 | input_nc = self.opt.input_nc 62 | output_nc = self.opt.output_nc 63 | 64 | if input_nc == 1: # RGB to gray 65 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 66 | A = tmp.unsqueeze(0) 67 | 68 | item = {'A': A, 'A_paths': A_path} 69 | 70 | if self.opt.use_local: 71 | regions = ['eyel','eyer','nose','mouth'] 72 | basen = os.path.basename(A_path)[:-4]+'.txt' 73 | featdir = self.opt.lm_dir 74 | featpath = os.path.join(featdir,basen) 75 | feats = getfeats(featpath) 76 | mouth_x = int((feats[3,0]+feats[4,0])/2.0) 77 | mouth_y = int((feats[3,1]+feats[4,1])/2.0) 78 | ratio = self.opt.fineSize / 256 79 | EYE_H = self.opt.EYE_H * ratio 80 | EYE_W = self.opt.EYE_W * ratio 81 | NOSE_H = self.opt.NOSE_H * ratio 82 | NOSE_W = self.opt.NOSE_W * ratio 83 | MOUTH_H = self.opt.MOUTH_H * ratio 84 | MOUTH_W = self.opt.MOUTH_W * ratio 85 | center = torch.IntTensor([[feats[0,0],feats[0,1]-4*ratio],[feats[1,0],feats[1,1]-4*ratio],[feats[2,0],feats[2,1]-NOSE_H/2+16*ratio],[mouth_x,mouth_y]]) 86 | item['center'] = center 87 | rhs = [int(EYE_H),int(EYE_H),int(NOSE_H),int(MOUTH_H)] 88 | rws = [int(EYE_W),int(EYE_W),int(NOSE_W),int(MOUTH_W)] 89 | if self.opt.soft_border: 90 | soft_border_mask4 = [] 91 | for i in range(4): 92 | xb = [np.zeros(rhs[i]),np.ones(rhs[i])*(rws[i]-1)] 93 | yb = [np.zeros(rws[i]),np.ones(rws[i])*(rhs[i]-1)] 94 | soft_border_mask = getSoft([rhs[i],rws[i]],xb,yb) 95 | soft_border_mask4.append(torch.Tensor(soft_border_mask).unsqueeze(0)) 96 | item['soft_'+regions[i]+'_mask'] = soft_border_mask4[i] 97 | for i in range(4): 98 | item[regions[i]+'_A'] = A[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] 99 | if self.opt.soft_border: 100 | item[regions[i]+'_A'] = item[regions[i]+'_A'] * soft_border_mask4[i].repeat(int(input_nc/output_nc),1,1) 101 | if self.opt.compactmask: 102 | cmasks0 = [] 103 | cmasks = [] 104 | for i in range(4): 105 | cmaskpath = os.path.join(self.opt.cmask_dir,regions[i],basen[:-4]+'.png') 106 | im_cmask = Image.open(cmaskpath) 107 | cmask0 = transforms.ToTensor()(im_cmask) 108 | if output_nc == 1 and cmask0.shape[0] == 3: 109 | tmp = cmask0[0, ...] * 0.299 + cmask0[1, ...] * 0.587 + cmask0[2, ...] 
* 0.114 110 | cmask0 = tmp.unsqueeze(0) 111 | cmask0 = (cmask0 >= 0.5).float() 112 | cmasks0.append(cmask0) 113 | cmask = cmask0.clone() 114 | cmask = cmask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] 115 | cmasks.append(cmask) 116 | item['cmaskel'] = cmasks[0] 117 | item['cmasker'] = cmasks[1] 118 | item['cmask'] = cmasks[2] 119 | item['cmaskmo'] = cmasks[3] 120 | if self.opt.hair_local: 121 | output_nc = self.opt.output_nc 122 | mask = torch.ones([output_nc,A.shape[1],A.shape[2]]) 123 | for i in range(4): 124 | mask[:,center[i,1]-rhs[i]/2:center[i,1]+rhs[i]/2,center[i,0]-rws[i]/2:center[i,0]+rws[i]/2] = 0 125 | if self.opt.soft_border: 126 | imgsize = self.opt.fineSize 127 | maskn = mask[0].numpy() 128 | masks = [np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize]),np.ones([imgsize,imgsize])] 129 | masks[0][1:] = maskn[:-1] 130 | masks[1][:-1] = maskn[1:] 131 | masks[2][:,1:] = maskn[:,:-1] 132 | masks[3][:,:-1] = maskn[:,1:] 133 | masks2 = [maskn-e for e in masks] 134 | bound = np.minimum.reduce(masks2) 135 | bound = -bound 136 | xb = [] 137 | yb = [] 138 | for i in range(4): 139 | xbi = [center[i,0]-rws[i]/2, center[i,0]+rws[i]/2-1] 140 | ybi = [center[i,1]-rhs[i]/2, center[i,1]+rhs[i]/2-1] 141 | for j in range(2): 142 | maskx = bound[:,xbi[j]] 143 | masky = bound[ybi[j],:] 144 | tmp_a = torch.from_numpy(maskx)*xbi[j].double() 145 | tmp_b = torch.from_numpy(1-maskx) 146 | xb += [tmp_b*10000 + tmp_a] 147 | 148 | tmp_a = torch.from_numpy(masky)*ybi[j].double() 149 | tmp_b = torch.from_numpy(1-masky) 150 | yb += [tmp_b*10000 + tmp_a] 151 | soft = 1-getSoft([imgsize,imgsize],xb,yb) 152 | soft = torch.Tensor(soft).unsqueeze(0) 153 | mask = (torch.ones(mask.shape)-mask)*soft + mask 154 | hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * 2 - 1 155 | item['hair_A'] = hair_A 156 | item['mask'] = mask 157 | if self.opt.bg_local: 158 | bgdir = self.opt.bg_dir 159 | bgpath = os.path.join(bgdir,basen[:-4]+'.png') 160 | im_bg = Image.open(bgpath) 161 | mask2 = transforms.ToTensor()(im_bg) # mask out background 162 | mask2 = (mask2 >= 0.5).float() 163 | hair_A = (A/2+0.5) * mask.repeat(int(input_nc/output_nc),1,1) * mask2.repeat(int(input_nc/output_nc),1,1) * 2 - 1 164 | bg_A = (A/2+0.5) * (torch.ones(mask2.shape)-mask2).repeat(int(input_nc/output_nc),1,1) * 2 - 1 165 | item['hair_A'] = hair_A 166 | item['bg_A'] = bg_A 167 | item['mask'] = mask 168 | item['mask2'] = mask2 169 | 170 | return item 171 | 172 | def __len__(self): 173 | return len(self.A_paths) 174 | 175 | def name(self): 176 | return 'SingleImageDataset' 177 | -------------------------------------------------------------------------------- /docs/tips.md: -------------------------------------------------------------------------------- 1 | ## Training/test Tips 2 | - Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. The default values of these options are somtimes adjusted in the model files. 3 | 4 | - CPU/GPU (default `--gpu_ids 0`): set`--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batch_size 32`) to benefit from multiple GPUs. 5 | 6 | - Visualization: during training, the current results can be viewed using two methods. 
First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`. 7 | 8 | - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count ` to specify a different starting epoch count. 9 | -------------------------------------------------------------------------------- /imgs/architecture-pami.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/architecture-pami.jpg -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1615_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1615_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1615_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1615_real_A.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1616_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1616_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1616_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1616_real_A.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1673_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1673_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1673_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1673_real_A.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1684_fake_B.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1684_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1684_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1684_real_A.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1696_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1696_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1696_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1696_real_A.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1701_fake_B.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1701_fake_B.png -------------------------------------------------------------------------------- /imgs/sample/140_large-img_1701_real_A.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/imgs/sample/140_large-img_1701_real_A.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /models/apdrawingpp_style_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util.image_pool import ImagePool 3 | from .base_model import BaseModel 4 | from . import networks 5 | import os 6 | import math 7 | 8 | W = 11 9 | aa = int(math.floor(512./W)) 10 | res = 512 - W*aa 11 | 12 | 13 | def padpart(A,part,centers,opt,device): 14 | IMAGE_SIZE = opt.fineSize 15 | bs,nc,_,_ = A.shape 16 | ratio = IMAGE_SIZE / 256 17 | NOSE_W = opt.NOSE_W * ratio 18 | NOSE_H = opt.NOSE_H * ratio 19 | EYE_W = opt.EYE_W * ratio 20 | EYE_H = opt.EYE_H * ratio 21 | MOUTH_W = opt.MOUTH_W * ratio 22 | MOUTH_H = opt.MOUTH_H * ratio 23 | A_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(device) 24 | padvalue = -1 # black 25 | for i in range(bs): 26 | center = centers[i] 27 | if part == 'nose': 28 | A_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(A[i]) 29 | elif part == 'eyel': 30 | A_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(A[i]) 31 | elif part == 'eyer': 32 | A_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(A[i]) 33 | elif part == 'mouth': 34 | A_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(A[i]) 35 | return A_p 36 | 37 | import numpy as np 38 | def nonlinearDt(dt,type='atan',xmax=torch.Tensor([10.0])):#dt in [0,1], first multiply xmax(>1), then remap to [0,1] 39 | if type == 'atan': 40 | nldt = torch.atan(dt*xmax) / torch.atan(xmax) 41 | elif type == 'sigmoid': 42 | nldt = (torch.sigmoid(dt*xmax)-0.5) / (torch.sigmoid(xmax)-0.5) 43 | elif type == 'tanh': 44 | nldt = torch.tanh(dt*xmax) / torch.tanh(xmax) 45 | elif type == 'pow': 46 | nldt = torch.pow(dt*xmax,2) / torch.pow(xmax,2) 47 | elif type == 'exp': 48 | if xmax.item()>1: 49 | xmax = xmax / 3 50 | nldt = (torch.exp(dt*xmax)-1) / (torch.exp(xmax)-1) 51 | #print("remap dt:", type, xmax.item()) 52 | return nldt 53 | 54 | class APDrawingPPStyleModel(BaseModel): 55 | def name(self): 56 | return 'APDrawingPPStyleModel' 57 | 58 | @staticmethod 59 | def modify_commandline_options(parser, is_train=True): 60 | 61 | # changing the default values to match the pix2pix paper 62 | # (https://phillipi.github.io/pix2pix/) 63 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False 64 | parser.set_defaults(dataset_mode='aligned') 65 | parser.set_defaults(auxiliary_root='auxiliaryeye2o') 66 | parser.set_defaults(use_local=True, hair_local=True, bg_local=True) 67 | parser.set_defaults(discriminator_local=True, gan_loss_strategy=2) 68 | parser.set_defaults(chamfer_loss=True, 
dt_nonlinear='exp', lambda_chamfer=0.35, lambda_chamfer2=0.35) 69 | parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56) 70 | parser.set_defaults(soft_border=1, batch_size=1, save_epoch_freq=25) 71 | parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier') 72 | parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator') 73 | parser.add_argument('--regarch', type=int, default=4, help='architecture for netRegressor') 74 | if is_train: 75 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss') 76 | parser.add_argument('--lambda_local', type=float, default=25.0, help='weight for Local loss') 77 | parser.set_defaults(netG_dt='unet_512') 78 | parser.set_defaults(netG_line='unet_512') 79 | 80 | return parser 81 | 82 | def initialize(self, opt): 83 | BaseModel.initialize(self, opt) 84 | self.isTrain = opt.isTrain 85 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 86 | self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] 87 | if self.isTrain and self.opt.no_l1_loss: 88 | self.loss_names = ['G_GAN', 'D_real', 'D_fake'] 89 | if self.isTrain and self.opt.use_local and not self.opt.no_G_local_loss: 90 | self.loss_names.append('G_local') 91 | self.loss_names.append('G_hair_local') 92 | self.loss_names.append('G_bg_local') 93 | if self.isTrain and self.opt.discriminator_local: 94 | self.loss_names.append('D_real_local') 95 | self.loss_names.append('D_fake_local') 96 | self.loss_names.append('G_GAN_local') 97 | if self.isTrain and self.opt.chamfer_loss: 98 | self.loss_names.append('G_chamfer') 99 | self.loss_names.append('G_chamfer2') 100 | if self.isTrain and self.opt.continuity_loss: 101 | self.loss_names.append('G_continuity') 102 | self.loss_names.append('G') 103 | print('loss_names', self.loss_names) 104 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 105 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 106 | if self.opt.use_local: 107 | self.visual_names += ['fake_B0', 'fake_B1'] 108 | self.visual_names += ['fake_B_hair', 'real_B_hair', 'real_A_hair'] 109 | self.visual_names += ['fake_B_bg', 'real_B_bg', 'real_A_bg'] 110 | if self.opt.region_enm in [0,1]: 111 | if self.opt.nose_ae: 112 | self.visual_names += ['fake_B_nose_v','fake_B_nose_v1','fake_B_nose_v2','cmask1no'] 113 | if self.opt.others_ae: 114 | self.visual_names += ['fake_B_eyel_v','fake_B_eyel_v1','fake_B_eyel_v2','cmask1el'] 115 | self.visual_names += ['fake_B_eyer_v','fake_B_eyer_v1','fake_B_eyer_v2','cmask1er'] 116 | self.visual_names += ['fake_B_mouth_v','fake_B_mouth_v1','fake_B_mouth_v2','cmask1mo'] 117 | elif self.opt.region_enm in [2]: 118 | self.visual_names += ['fake_B_nose','fake_B_eyel','fake_B_eyer','fake_B_mouth'] 119 | if self.isTrain and self.opt.chamfer_loss: 120 | self.visual_names += ['dt1', 'dt2'] 121 | self.visual_names += ['dt1gt', 'dt2gt'] 122 | if self.isTrain and self.opt.soft_border: 123 | self.visual_names += ['mask'] 124 | if not self.isTrain and self.opt.save2: 125 | self.visual_names = ['real_A', 'fake_B'] 126 | print('visuals', self.visual_names) 127 | # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 128 | self.auxiliary_model_names = [] 129 | if self.isTrain: 130 | self.model_names = ['G', 'D'] 131 | if self.opt.discriminator_local: 132 | self.model_names += ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 133 | # auxiliary nets for loss calculation 134 | if self.opt.chamfer_loss: 135 | self.auxiliary_model_names += ['DT1', 'DT2'] 136 | self.auxiliary_model_names += ['Line1', 'Line2'] 137 | if self.opt.continuity_loss: 138 | self.auxiliary_model_names += ['Regressor'] 139 | else: # during test time, only load Gs 140 | self.model_names = ['G'] 141 | if self.opt.test_continuity_loss: 142 | self.auxiliary_model_names += ['Regressor'] 143 | if self.opt.use_local: 144 | self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine'] 145 | self.auxiliary_model_names += ['CLm','CLh'] 146 | # auxiliary nets for local output refinement 147 | if self.opt.nose_ae: 148 | self.auxiliary_model_names += ['AE'] 149 | if self.opt.others_ae: 150 | self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack'] 151 | print('model_names', self.model_names) 152 | print('auxiliary_model_names', self.auxiliary_model_names) 153 | # load/define networks 154 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 155 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 156 | opt.nnG) 157 | print('netG', opt.netG) 158 | 159 | if self.isTrain: 160 | use_sigmoid = opt.no_lsgan 161 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 162 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 163 | print('netD', opt.netD, opt.n_layers_D) 164 | if self.opt.discriminator_local: 165 | self.netDLEyel = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 166 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 167 | self.netDLEyer = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 168 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 169 | self.netDLNose = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 170 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 171 | self.netDLMouth = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 172 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 173 | self.netDLHair = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 174 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 175 | self.netDLBG = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, 176 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 177 | 178 | 179 | if self.opt.use_local: 180 | netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks' 181 | netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks' 182 | netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks' 183 | self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 184 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 185 | self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 186 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 187 | self.netGLNose = 
networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 188 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 189 | self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 190 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 191 | self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm, 192 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4, 193 | extra_channel=3) 194 | self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm, 195 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4) 196 | # by default combiner_type is combiner, which uses resnet 197 | print('combiner_type', self.opt.combiner_type) 198 | self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm, 199 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2) 200 | # auxiliary classifiers for mouth and hair 201 | ratio = self.opt.fineSize / 256 202 | self.MOUTH_H = int(self.opt.MOUTH_H * ratio) 203 | self.MOUTH_W = int(self.opt.MOUTH_W * ratio) 204 | self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm, 205 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 206 | nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W) 207 | self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm, 208 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 209 | nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize) 210 | 211 | 212 | if self.isTrain: 213 | self.fake_AB_pool = ImagePool(opt.pool_size) 214 | # define loss functions 215 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) 216 | self.criterionL1 = torch.nn.L1Loss() 217 | 218 | # initialize optimizers 219 | self.optimizers = [] 220 | if not self.opt.use_local: 221 | print('G_params 1 components') 222 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 223 | lr=opt.lr, betas=(opt.beta1, 0.999)) 224 | else: 225 | G_params = list(self.netG.parameters()) + list(self.netGLEyel.parameters()) + list(self.netGLEyer.parameters()) + list(self.netGLNose.parameters()) + list(self.netGLMouth.parameters()) + list(self.netGCombine.parameters()) + list(self.netGLHair.parameters()) + list(self.netGLBG.parameters()) 226 | print('G_params 8 components') 227 | self.optimizer_G = torch.optim.Adam(G_params, 228 | lr=opt.lr, betas=(opt.beta1, 0.999)) 229 | 230 | if not self.opt.discriminator_local: 231 | print('D_params 1 components') 232 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 233 | lr=opt.lr, betas=(opt.beta1, 0.999)) 234 | else:#self.opt.discriminator_local == True 235 | D_params = list(self.netD.parameters()) + list(self.netDLEyel.parameters()) +list(self.netDLEyer.parameters()) + list(self.netDLNose.parameters()) + list(self.netDLMouth.parameters()) + list(self.netDLHair.parameters()) + list(self.netDLBG.parameters()) 236 | print('D_params 7 components') 237 | self.optimizer_D = torch.optim.Adam(D_params, 238 | lr=opt.lr, betas=(opt.beta1, 0.999)) 239 | self.optimizers.append(self.optimizer_G) 240 | self.optimizers.append(self.optimizer_D) 241 | 242 | # ==================================auxiliary nets (loaded, parameters fixed)============================= 243 | if self.opt.use_local and self.opt.nose_ae: 244 | ratio = self.opt.fineSize / 256 245 | NOSE_H = self.opt.NOSE_H * 
ratio 246 | NOSE_W = self.opt.NOSE_W * ratio 247 | self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 248 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 249 | latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W) 250 | self.set_requires_grad(self.netAE, False) 251 | if self.opt.use_local and self.opt.others_ae: 252 | ratio = self.opt.fineSize / 256 253 | EYE_H = self.opt.EYE_H * ratio 254 | EYE_W = self.opt.EYE_W * ratio 255 | MOUTH_H = self.opt.MOUTH_H * ratio 256 | MOUTH_W = self.opt.MOUTH_W * ratio 257 | self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 258 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 259 | latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W) 260 | self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 261 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 262 | latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W) 263 | self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 264 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 265 | latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W) 266 | self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 267 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 268 | latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W) 269 | self.set_requires_grad(self.netAEel, False) 270 | self.set_requires_grad(self.netAEer, False) 271 | self.set_requires_grad(self.netAEmowhite, False) 272 | self.set_requires_grad(self.netAEmoblack, False) 273 | 274 | 275 | if self.isTrain and self.opt.continuity_loss: 276 | self.nc = 1 277 | self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm, 278 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p, 279 | nnG = opt.regarch) 280 | self.set_requires_grad(self.netRegressor, False) 281 | 282 | if self.isTrain and self.opt.chamfer_loss: 283 | self.nc = 1 284 | self.netDT1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm, 285 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p) 286 | self.netDT2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_dt, opt.norm, 287 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p) 288 | self.set_requires_grad(self.netDT1, False) 289 | self.set_requires_grad(self.netDT2, False) 290 | self.netLine1 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm, 291 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p) 292 | self.netLine2 = networks.define_G(self.nc, self.nc, opt.ngf, opt.netG_line, opt.norm, 293 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids_p) 294 | self.set_requires_grad(self.netLine1, False) 295 | self.set_requires_grad(self.netLine2, False) 296 | 297 | # ==================================for test (nets loaded, parameters fixed)============================= 298 | if not self.isTrain and self.opt.test_continuity_loss: 299 | self.nc = 1 300 | self.netRegressor = networks.define_G(self.nc, 1, opt.ngf, 'regressor', opt.norm, 301 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 302 | nnG = opt.regarch) 303 | self.set_requires_grad(self.netRegressor, False) 304 | 305 | 306 | def set_input(self, input): 307 | AtoB = self.opt.which_direction == 'AtoB' 308 
| self.real_A = input['A' if AtoB else 'B'].to(self.device) 309 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 310 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 311 | self.batch_size = len(self.image_paths) 312 | if self.opt.use_local: 313 | self.real_A_eyel = input['eyel_A'].to(self.device) 314 | self.real_A_eyer = input['eyer_A'].to(self.device) 315 | self.real_A_nose = input['nose_A'].to(self.device) 316 | self.real_A_mouth = input['mouth_A'].to(self.device) 317 | self.real_B_eyel = input['eyel_B'].to(self.device) 318 | self.real_B_eyer = input['eyer_B'].to(self.device) 319 | self.real_B_nose = input['nose_B'].to(self.device) 320 | self.real_B_mouth = input['mouth_B'].to(self.device) 321 | if self.opt.region_enm in [0,1]: 322 | self.center = input['center'] 323 | if self.opt.soft_border: 324 | self.softel = input['soft_eyel_mask'].to(self.device) 325 | self.softer = input['soft_eyer_mask'].to(self.device) 326 | self.softno = input['soft_nose_mask'].to(self.device) 327 | self.softmo = input['soft_mouth_mask'].to(self.device) 328 | if self.opt.compactmask: 329 | self.cmask = input['cmask'].to(self.device) 330 | self.cmask1 = self.cmask*2-1#[0,1]->[-1,1] 331 | self.cmaskel = input['cmaskel'].to(self.device) 332 | self.cmask1el = self.cmaskel*2-1 333 | self.cmasker = input['cmasker'].to(self.device) 334 | self.cmask1er = self.cmasker*2-1 335 | self.cmaskmo = input['cmaskmo'].to(self.device) 336 | self.cmask1mo = self.cmaskmo*2-1 337 | self.real_A_hair = input['hair_A'].to(self.device) 338 | self.real_B_hair = input['hair_B'].to(self.device) 339 | self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth 340 | self.mask2 = input['mask2'].to(self.device) # mask for non-bg 341 | self.real_A_bg = input['bg_A'].to(self.device) 342 | self.real_B_bg = input['bg_B'].to(self.device) 343 | if (self.isTrain and self.opt.chamfer_loss): 344 | self.dt1gt = input['dt1gt'].to(self.device) 345 | self.dt2gt = input['dt2gt'].to(self.device) 346 | if self.isTrain and self.opt.emphasis_conti_face: 347 | self.face_mask = input['face_mask'].cuda(self.gpu_ids_p[0]) 348 | 349 | def getonehot(self,outputs,classes): 350 | [maxv,index] = torch.max(outputs,1) 351 | y = torch.unsqueeze(index,1) 352 | onehot = torch.FloatTensor(self.batch_size,classes).to(self.device) 353 | onehot.zero_() 354 | onehot.scatter_(1,y,1) 355 | return onehot 356 | 357 | def forward(self): 358 | if not self.opt.use_local: 359 | self.fake_B = self.netG(self.real_A) 360 | else: 361 | self.fake_B0 = self.netG(self.real_A) 362 | # EYES, MOUTH 363 | outputs1 = self.netCLm(self.real_A_mouth) 364 | onehot1 = self.getonehot(outputs1,2) 365 | 366 | if not self.opt.others_ae: 367 | fake_B_eyel = self.netGLEyel(self.real_A_eyel) 368 | fake_B_eyer = self.netGLEyer(self.real_A_eyer) 369 | fake_B_mouth = self.netGLMouth(self.real_A_mouth) 370 | else: # use AE that only constains compact region, need cmask! 
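# Each local generator output (fake_B_*1) is refined by a fixed, pretrained
# autoencoder: netAEel/netAEer for the eyes and, for the mouth, either
# netAEmowhite or netAEmoblack depending on the netCLm classification above.
# add_with_mask (presumably) keeps the AE reconstruction inside the compact
# mask (cmaskel/cmasker/cmaskmo) and the raw generator output elsewhere.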
371 | self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel) 372 | self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer) 373 | self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth) 374 | self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1) 375 | self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1) 376 | # USE 2 AEs 377 | self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device) 378 | for i in range(self.batch_size): 379 | if onehot1[i][0] == 1: 380 | self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0)) 381 | #print('AEmowhite') 382 | elif onehot1[i][1] == 1: 383 | self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0)) 384 | #print('AEmoblack') 385 | fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel) 386 | fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker) 387 | fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo) 388 | # NOSE 389 | if not self.opt.nose_ae: 390 | fake_B_nose = self.netGLNose(self.real_A_nose) 391 | else: # use AE that only constains compact region, need cmask! 392 | self.fake_B_nose1 = self.netGLNose(self.real_A_nose) 393 | self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1) 394 | fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask) 395 | 396 | # for visuals and later local loss 397 | if self.opt.region_enm in [0,1]: 398 | self.fake_B_nose = fake_B_nose 399 | self.fake_B_eyel = fake_B_eyel 400 | self.fake_B_eyer = fake_B_eyer 401 | self.fake_B_mouth = fake_B_mouth 402 | # for soft border of 4 rectangle facial feature 403 | if self.opt.region_enm == 0 and self.opt.soft_border: 404 | self.fake_B_nose = self.masked(fake_B_nose, self.softno) 405 | self.fake_B_eyel = self.masked(fake_B_eyel, self.softel) 406 | self.fake_B_eyer = self.masked(fake_B_eyer, self.softer) 407 | self.fake_B_mouth = self.masked(fake_B_mouth, self.softmo) 408 | elif self.opt.region_enm in [2]: # need to multiply cmask 409 | self.fake_B_nose = self.masked(fake_B_nose,self.cmask) 410 | self.fake_B_eyel = self.masked(fake_B_eyel,self.cmaskel) 411 | self.fake_B_eyer = self.masked(fake_B_eyer,self.cmasker) 412 | self.fake_B_mouth = self.masked(fake_B_mouth,self.cmaskmo) 413 | 414 | # HAIR, BG AND PARTCOMBINE 415 | outputs2 = self.netCLh(self.real_A_hair) 416 | onehot2 = self.getonehot(outputs2,3) 417 | 418 | if not self.isTrain: 419 | opt = self.opt 420 | if opt.imagefolder == 'images': 421 | file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'styleonehot.txt') 422 | else: 423 | file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), opt.imagefolder, 'styleonehot.txt') 424 | message = '%s [%d %d] [%d %d %d]' % (self.image_paths[0], onehot1[0][0], onehot1[0][1], 425 | onehot2[0][0], onehot2[0][1], onehot2[0][2]) 426 | with open(file_name, 'a+') as s_file: 427 | s_file.write(message) 428 | s_file.write('\n') 429 | 430 | fake_B_hair = self.netGLHair(self.real_A_hair,onehot2) 431 | fake_B_bg = self.netGLBG(self.real_A_bg) 432 | self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2) 433 | self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2)) 434 | if not self.opt.compactmask: 435 | self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op) 436 | else: 437 | 
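In the branch above, each AE output is merged back with the raw local result through add_with_mask: inside the compact mask the refined pixels win, outside it the original local output is kept, and everything stays in the [-1, 1] range used throughout the model. A small sketch of that blend under those assumptions (illustrative, not the file's code):

import torch

def blend_with_mask(refined, raw, mask):
    # refined, raw in [-1, 1]; mask is binary: take refined where mask == 1, raw elsewhere
    a = refined / 2 + 0.5              # [-1, 1] -> [0, 1]
    b = raw / 2 + 0.5
    out = a * mask + b * (1 - mask)    # mask-controlled blend in [0, 1]
    return out * 2 - 1                 # back to [-1, 1]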
self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo) 438 | 439 | self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1)) 440 | 441 | # for AE visuals 442 | if self.opt.region_enm in [0,1]: 443 | if self.opt.nose_ae: 444 | self.fake_B_nose_v = padpart(self.fake_B_nose, 'nose', self.center, self.opt, self.device) 445 | self.fake_B_nose_v1 = padpart(self.fake_B_nose1, 'nose', self.center, self.opt, self.device) 446 | self.fake_B_nose_v2 = padpart(self.fake_B_nose2, 'nose', self.center, self.opt, self.device) 447 | self.cmask1no = padpart(self.cmask1, 'nose', self.center, self.opt, self.device) 448 | if self.opt.others_ae: 449 | self.fake_B_eyel_v = padpart(self.fake_B_eyel, 'eyel', self.center, self.opt, self.device) 450 | self.fake_B_eyel_v1 = padpart(self.fake_B_eyel1, 'eyel', self.center, self.opt, self.device) 451 | self.fake_B_eyel_v2 = padpart(self.fake_B_eyel2, 'eyel', self.center, self.opt, self.device) 452 | self.cmask1el = padpart(self.cmask1el, 'eyel', self.center, self.opt, self.device) 453 | self.fake_B_eyer_v = padpart(self.fake_B_eyer, 'eyer', self.center, self.opt, self.device) 454 | self.fake_B_eyer_v1 = padpart(self.fake_B_eyer1, 'eyer', self.center, self.opt, self.device) 455 | self.fake_B_eyer_v2 = padpart(self.fake_B_eyer2, 'eyer', self.center, self.opt, self.device) 456 | self.cmask1er = padpart(self.cmask1er, 'eyer', self.center, self.opt, self.device) 457 | self.fake_B_mouth_v = padpart(self.fake_B_mouth, 'mouth', self.center, self.opt, self.device) 458 | self.fake_B_mouth_v1 = padpart(self.fake_B_mouth1, 'mouth', self.center, self.opt, self.device) 459 | self.fake_B_mouth_v2 = padpart(self.fake_B_mouth2, 'mouth', self.center, self.opt, self.device) 460 | self.cmask1mo = padpart(self.cmask1mo, 'mouth', self.center, self.opt, self.device) 461 | 462 | if not self.isTrain and self.opt.test_continuity_loss: 463 | self.ContinuityForTest(real=1) 464 | 465 | 466 | def backward_D(self): 467 | # Fake 468 | # stop backprop to the generator by detaching fake_B 469 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) 470 | #print('fake_AB', fake_AB.shape) # (1,4,512,512) 471 | pred_fake = self.netD(fake_AB.detach())# by detach, not affect G's gradient 472 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 473 | if self.opt.discriminator_local: 474 | fake_AB_parts = self.getLocalParts(fake_AB) 475 | local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 476 | self.loss_D_fake_local = 0 477 | for i in range(len(fake_AB_parts)): 478 | net = getattr(self, 'net' + local_names[i]) 479 | pred_fake_tmp = net(fake_AB_parts[i].detach()) 480 | addw = self.getaddw(local_names[i]) 481 | self.loss_D_fake_local = self.loss_D_fake_local + self.criterionGAN(pred_fake_tmp, False) * addw 482 | self.loss_D_fake = self.loss_D_fake + self.loss_D_fake_local 483 | 484 | # Real 485 | real_AB = torch.cat((self.real_A, self.real_B), 1) 486 | pred_real = self.netD(real_AB) 487 | self.loss_D_real = self.criterionGAN(pred_real, True) 488 | if self.opt.discriminator_local: 489 | real_AB_parts = self.getLocalParts(real_AB) 490 | local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 491 | self.loss_D_real_local = 0 492 | for i in range(len(real_AB_parts)): 493 | net = getattr(self, 'net' + local_names[i]) 494 | pred_real_tmp = 
net(real_AB_parts[i]) 495 | addw = self.getaddw(local_names[i]) 496 | self.loss_D_real_local = self.loss_D_real_local + self.criterionGAN(pred_real_tmp, True) * addw 497 | self.loss_D_real = self.loss_D_real + self.loss_D_real_local 498 | 499 | # Combined loss 500 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 501 | 502 | self.loss_D.backward() 503 | 504 | def backward_G(self): 505 | # First, G(A) should fake the discriminator 506 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 507 | pred_fake = self.netD(fake_AB) # (1,4,512,512)->(1,1,30,30) 508 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 509 | if self.opt.discriminator_local: 510 | fake_AB_parts = self.getLocalParts(fake_AB) 511 | local_names = ['DLEyel','DLEyer','DLNose','DLMouth','DLHair','DLBG'] 512 | self.loss_G_GAN_local = 0 # G_GAN_local is then added into G_GAN 513 | for i in range(len(fake_AB_parts)): 514 | net = getattr(self, 'net' + local_names[i]) 515 | pred_fake_tmp = net(fake_AB_parts[i]) 516 | addw = self.getaddw(local_names[i]) 517 | self.loss_G_GAN_local = self.loss_G_GAN_local + self.criterionGAN(pred_fake_tmp, True) * addw 518 | if self.opt.gan_loss_strategy == 1: 519 | self.loss_G_GAN = (self.loss_G_GAN + self.loss_G_GAN_local) / (len(fake_AB_parts) + 1) 520 | elif self.opt.gan_loss_strategy == 2: 521 | self.loss_G_GAN_local = self.loss_G_GAN_local * 0.25 522 | self.loss_G_GAN = self.loss_G_GAN + self.loss_G_GAN_local 523 | 524 | # Second, G(A) = B 525 | if not self.opt.no_l1_loss: 526 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 527 | 528 | if self.opt.use_local and not self.opt.no_G_local_loss: 529 | local_names = ['eyel','eyer','nose','mouth'] 530 | self.loss_G_local = 0 531 | for i in range(len(local_names)): 532 | fakeblocal = getattr(self, 'fake_B_' + local_names[i]) 533 | realblocal = getattr(self, 'real_B_' + local_names[i]) 534 | addw = self.getaddw(local_names[i]) 535 | self.loss_G_local = self.loss_G_local + self.criterionL1(fakeblocal,realblocal) * self.opt.lambda_local * addw 536 | self.loss_G_hair_local = self.criterionL1(self.fake_B_hair, self.real_B_hair) * self.opt.lambda_local * self.opt.addw_hair 537 | self.loss_G_bg_local = self.criterionL1(self.fake_B_bg, self.real_B_bg) * self.opt.lambda_local * self.opt.addw_bg 538 | 539 | # Third, chamfer matching (assume chamfer_2way and chamfer_only_line is true) 540 | if self.opt.chamfer_loss: 541 | if self.fake_B.shape[1] == 3: 542 | tmp = self.fake_B[:,0,...]*0.299+self.fake_B[:,1,...]*0.587+self.fake_B[:,2,...]*0.114 543 | fake_B_gray = tmp.unsqueeze(1) 544 | else: 545 | fake_B_gray = self.fake_B 546 | if self.real_B.shape[1] == 3: 547 | tmp = self.real_B[:,0,...]*0.299+self.real_B[:,1,...]*0.587+self.real_B[:,2,...]*0.114 548 | real_B_gray = tmp.unsqueeze(1) 549 | else: 550 | real_B_gray = self.real_B 551 | 552 | gpu_p = self.opt.gpu_ids_p[0] 553 | gpu = self.opt.gpu_ids[0] 554 | if gpu_p != gpu: 555 | fake_B_gray = fake_B_gray.cuda(gpu_p) 556 | real_B_gray = real_B_gray.cuda(gpu_p) 557 | 558 | # d_CM(a_i,G(p_i)) 559 | self.dt1 = self.netDT1(fake_B_gray) 560 | self.dt2 = self.netDT2(fake_B_gray) 561 | dt1 = self.dt1/2.0+0.5#[-1,1]->[0,1] 562 | dt2 = self.dt2/2.0+0.5 563 | if self.opt.dt_nonlinear != '': 564 | dt_xmax = torch.Tensor([self.opt.dt_xmax]).cuda(gpu_p) 565 | dt1 = nonlinearDt(dt1, self.opt.dt_nonlinear, dt_xmax) 566 | dt2 = nonlinearDt(dt2, self.opt.dt_nonlinear, dt_xmax) 567 | 
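For the chamfer term above, 3-channel outputs are first reduced to a single gray channel with the standard BT.601 luma weights (0.299, 0.587, 0.114) before being fed to the distance-transform and line nets. A sketch of that reduction, assuming an input of shape [N, 3, H, W] in [-1, 1]:

import torch

def to_gray(x):
    # weighted sum over the channel dimension, keeping a singleton channel
    if x.shape[1] == 3:
        return (0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]).unsqueeze(1)
    return x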
#print('dt1dt2',torch.min(dt1).item(),torch.max(dt1).item(),torch.min(dt2).item(),torch.max(dt2).item()) 568 | 569 | bs = real_B_gray.shape[0] 570 | real_B_gray_line1 = self.netLine1(real_B_gray) 571 | real_B_gray_line2 = self.netLine2(real_B_gray) 572 | self.loss_G_chamfer = (dt1[(real_B_gray<0)&(real_B_gray_line1<0)].sum() + dt2[(real_B_gray>=0)&(real_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer 573 | if gpu_p != gpu: 574 | self.loss_G_chamfer = self.loss_G_chamfer.cuda(gpu) 575 | 576 | # d_CM(G(p_i),a_i) 577 | if gpu_p != gpu: 578 | dt1gt = self.dt1gt.cuda(gpu_p) 579 | dt2gt = self.dt2gt.cuda(gpu_p) 580 | else: 581 | dt1gt = self.dt1gt 582 | dt2gt = self.dt2gt 583 | if self.opt.dt_nonlinear != '': 584 | dt1gt = nonlinearDt(dt1gt, self.opt.dt_nonlinear, dt_xmax) 585 | dt2gt = nonlinearDt(dt2gt, self.opt.dt_nonlinear, dt_xmax) 586 | #print('dt1gtdt2gt',torch.min(dt1gt).item(),torch.max(dt1gt).item(),torch.min(dt2gt).item(),torch.max(dt2gt).item()) 587 | self.dt1gt = (self.dt1gt-0.5)*2 588 | self.dt2gt = (self.dt2gt-0.5)*2 589 | 590 | fake_B_gray_line1 = self.netLine1(fake_B_gray) 591 | fake_B_gray_line2 = self.netLine2(fake_B_gray) 592 | self.loss_G_chamfer2 = (dt1gt[(fake_B_gray<0)&(fake_B_gray_line1<0)].sum() + dt2gt[(fake_B_gray>=0)&(fake_B_gray_line2>=0)].sum()) / bs * self.opt.lambda_chamfer2 593 | if gpu_p != gpu: 594 | self.loss_G_chamfer2 = self.loss_G_chamfer2.cuda(gpu) 595 | 596 | # Fourth, line continuity loss, constrained on synthesized drawing 597 | if self.opt.continuity_loss: 598 | # Patch-based 599 | self.get_patches() 600 | self.outputs = self.netRegressor(self.fake_B_patches) 601 | if not self.opt.emphasis_conti_face: 602 | self.loss_G_continuity = (1.0-torch.mean(self.outputs)).cuda(gpu) * self.opt.lambda_continuity 603 | else: 604 | self.loss_G_continuity = torch.mean((1.0-self.outputs)*self.conti_weights).cuda(gpu) * self.opt.lambda_continuity 605 | 606 | 607 | 608 | self.loss_G = self.loss_G_GAN 609 | if 'G_L1' in self.loss_names: 610 | self.loss_G = self.loss_G + self.loss_G_L1 611 | if 'G_local' in self.loss_names: 612 | self.loss_G = self.loss_G + self.loss_G_local 613 | if 'G_hair_local' in self.loss_names: 614 | self.loss_G = self.loss_G + self.loss_G_hair_local 615 | if 'G_bg_local' in self.loss_names: 616 | self.loss_G = self.loss_G + self.loss_G_bg_local 617 | if 'G_chamfer' in self.loss_names: 618 | self.loss_G = self.loss_G + self.loss_G_chamfer 619 | if 'G_chamfer2' in self.loss_names: 620 | self.loss_G = self.loss_G + self.loss_G_chamfer2 621 | if 'G_continuity' in self.loss_names: 622 | self.loss_G = self.loss_G + self.loss_G_continuity 623 | 624 | self.loss_G.backward() 625 | 626 | def optimize_parameters(self): 627 | self.forward() 628 | # update D 629 | self.set_requires_grad(self.netD, True) 630 | 631 | if self.opt.discriminator_local: 632 | self.set_requires_grad(self.netDLEyel, True) 633 | self.set_requires_grad(self.netDLEyer, True) 634 | self.set_requires_grad(self.netDLNose, True) 635 | self.set_requires_grad(self.netDLMouth, True) 636 | self.set_requires_grad(self.netDLHair, True) 637 | self.set_requires_grad(self.netDLBG, True) 638 | self.optimizer_D.zero_grad() 639 | self.backward_D() 640 | self.optimizer_D.step() 641 | 642 | # update G 643 | self.set_requires_grad(self.netD, False) 644 | if self.opt.discriminator_local: 645 | self.set_requires_grad(self.netDLEyel, False) 646 | self.set_requires_grad(self.netDLEyer, False) 647 | self.set_requires_grad(self.netDLNose, False) 648 | self.set_requires_grad(self.netDLMouth, False) 649 | 
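The update in optimize_parameters follows the usual pix2pix alternation: one discriminator step on the current fake (with the generator output detached), then one generator step with the discriminator left unchanged. Schematically, with placeholder names rather than the model's actual attributes:

import torch

def gan_step(G, D, opt_G, opt_D, real_A, real_B, criterionGAN):
    fake_B = G(real_A)
    # discriminator step: detach fake_B so no gradient reaches G
    opt_D.zero_grad()
    loss_D = 0.5 * (criterionGAN(D(torch.cat((real_A, fake_B.detach()), 1)), False)
                    + criterionGAN(D(torch.cat((real_A, real_B), 1)), True))
    loss_D.backward()
    opt_D.step()
    # generator step: D is used in the loss but its parameters are not updated here
    opt_G.zero_grad()
    loss_G = criterionGAN(D(torch.cat((real_A, fake_B), 1)), True)
    loss_G.backward()
    opt_G.step()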
self.set_requires_grad(self.netDLHair, False) 650 | self.set_requires_grad(self.netDLBG, False) 651 | self.optimizer_G.zero_grad() 652 | self.backward_G() 653 | self.optimizer_G.step() 654 | 655 | def get_patches(self): 656 | gpu_p = self.opt.gpu_ids_p[0] 657 | gpu = self.opt.gpu_ids[0] 658 | if gpu_p != gpu: 659 | self.fake_B = self.fake_B.cuda(gpu_p) 660 | # [1,1,512,512]->[bs,1,11,11] 661 | patches = [] 662 | if self.isTrain and self.opt.emphasis_conti_face: 663 | weights = [] 664 | W2 = int(W/2) 665 | t = np.random.randint(res,size=2) 666 | for i in range(aa): 667 | for j in range(aa): 668 | p = self.fake_B[:,:,t[0]+i*W:t[0]+(i+1)*W,t[1]+j*W:t[1]+(j+1)*W] 669 | whitenum = torch.sum(p>=0.0) 670 | #if whitenum < 5 or whitenum > W*W-5: 671 | if whitenum < 1 or whitenum > W*W-1: 672 | continue 673 | patches.append(p) 674 | if self.isTrain and self.opt.emphasis_conti_face: 675 | weights.append(self.face_mask[:,:,t[0]+i*W+W2,t[1]+j*W+W2]) 676 | self.fake_B_patches = torch.cat(patches, dim=0) 677 | if self.isTrain and self.opt.emphasis_conti_face: 678 | self.conti_weights = torch.cat(weights, dim=0)+1 #0->1,1->2 679 | 680 | def get_patches_real(self): 681 | # [1,1,512,512]->[bs,1,11,11] 682 | patches = [] 683 | t = np.random.randint(res,size=2) 684 | for i in range(aa): 685 | for j in range(aa): 686 | p = self.real_B[:,:,t[0]+i*W:t[0]+(i+1)*W,t[1]+j*W:t[1]+(j+1)*W] 687 | whitenum = torch.sum(p>=0.0) 688 | #if whitenum < 5 or whitenum > W*W-5: 689 | if whitenum < 1 or whitenum > W*W-1: 690 | continue 691 | patches.append(p) 692 | self.real_B_patches = torch.cat(patches, dim=0) -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . 
import networks 5 | 6 | 7 | class BaseModel(): 8 | 9 | # modify parser to add command line options, 10 | # and also change the default values if needed 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def name(self): 16 | return 'BaseModel' 17 | 18 | def initialize(self, opt): 19 | self.opt = opt 20 | self.gpu_ids = opt.gpu_ids 21 | self.gpu_ids_p = opt.gpu_ids_p 22 | self.isTrain = opt.isTrain 23 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 24 | self.device_p = torch.device('cuda:{}'.format(self.gpu_ids_p[0])) if self.gpu_ids else torch.device('cpu') 25 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 26 | self.auxiliary_dir = os.path.join(opt.checkpoints_dir, opt.auxiliary_root) 27 | if opt.resize_or_crop != 'scale_width': 28 | torch.backends.cudnn.benchmark = True 29 | self.loss_names = [] 30 | self.model_names = [] 31 | self.visual_names = [] 32 | self.image_paths = [] 33 | 34 | def set_input(self, input): 35 | self.input = input 36 | 37 | def forward(self): 38 | pass 39 | 40 | # load and print networks; create schedulers 41 | def setup(self, opt, parser=None): 42 | if self.isTrain: 43 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 44 | 45 | if not self.isTrain or opt.continue_train: 46 | self.load_networks(opt.which_epoch) 47 | if len(self.auxiliary_model_names) > 0: 48 | self.load_auxiliary_networks() 49 | self.print_networks(opt.verbose) 50 | 51 | # make models eval mode during test time 52 | def eval(self): 53 | for name in self.model_names: 54 | if isinstance(name, str): 55 | net = getattr(self, 'net' + name) 56 | net.eval() 57 | 58 | # used in test time, wrapping `forward` in no_grad() so we don't save 59 | # intermediate steps for backprop 60 | def test(self): 61 | with torch.no_grad(): 62 | self.forward() 63 | 64 | # get image paths 65 | def get_image_paths(self): 66 | return self.image_paths 67 | 68 | def optimize_parameters(self): 69 | pass 70 | 71 | # update learning rate (called once every epoch) 72 | def update_learning_rate(self): 73 | for scheduler in self.schedulers: 74 | scheduler.step() 75 | lr = self.optimizers[0].param_groups[0]['lr'] 76 | print('learning rate = %.7f' % lr) 77 | 78 | # return visualization images. train.py will display these images, and save the images to a html 79 | def get_current_visuals(self): 80 | visual_ret = OrderedDict() 81 | for name in self.visual_names: 82 | if isinstance(name, str): 83 | visual_ret[name] = getattr(self, name) 84 | return visual_ret 85 | 86 | # return traning losses/errors. train.py will print out these errors as debugging information 87 | def get_current_losses(self): 88 | errors_ret = OrderedDict() 89 | for name in self.loss_names: 90 | if isinstance(name, str): 91 | # float(...) 
works for both scalar tensor and float number 92 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 93 | return errors_ret 94 | 95 | # save models to the disk 96 | def save_networks(self, which_epoch): 97 | for name in self.model_names: 98 | if isinstance(name, str): 99 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 100 | save_path = os.path.join(self.save_dir, save_filename) 101 | net = getattr(self, 'net' + name) 102 | 103 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 104 | torch.save(net.module.cpu().state_dict(), save_path) 105 | net.cuda(self.gpu_ids[0]) 106 | else: 107 | torch.save(net.cpu().state_dict(), save_path) 108 | 109 | def save_networks2(self, which_epoch): 110 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 111 | dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch)) 112 | dict_gen = {} 113 | dict_dis = {} 114 | for name in self.model_names: 115 | if isinstance(name, str): 116 | net = getattr(self, 'net' + name) 117 | 118 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 119 | state_dict = net.module.cpu().state_dict() 120 | net.cuda(self.gpu_ids[0]) 121 | else: 122 | state_dict = net.cpu().state_dict() 123 | 124 | if name[0] == 'G': 125 | dict_gen[name] = state_dict 126 | elif name[0] == 'D': 127 | dict_dis[name] = state_dict 128 | else: 129 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 130 | save_path = os.path.join(self.save_dir, save_filename) 131 | torch.save(state_dict, save_path) 132 | if dict_gen: 133 | torch.save(dict_gen, gen_name) 134 | if dict_dis: 135 | torch.save(dict_dis, dis_name) 136 | 137 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 138 | key = keys[i] 139 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 140 | if module.__class__.__name__.startswith('InstanceNorm') and \ 141 | (key == 'running_mean' or key == 'running_var'): 142 | if getattr(module, key) is None: 143 | state_dict.pop('.'.join(keys)) 144 | if module.__class__.__name__.startswith('InstanceNorm') and \ 145 | (key == 'num_batches_tracked'): 146 | state_dict.pop('.'.join(keys)) 147 | else: 148 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 149 | 150 | # load models from the disk 151 | def load_networks(self, which_epoch): 152 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 153 | if os.path.exists(gen_name): 154 | self.load_networks2(which_epoch) 155 | return 156 | for name in self.model_names: 157 | if isinstance(name, str): 158 | load_filename = '%s_net_%s.pth' % (which_epoch, name) 159 | load_path = os.path.join(self.save_dir, load_filename) 160 | net = getattr(self, 'net' + name) 161 | if isinstance(net, torch.nn.DataParallel): 162 | net = net.module 163 | print('loading the model from %s' % load_path) 164 | # if you are using PyTorch newer than 0.4 (e.g., built from 165 | # GitHub source), you can remove str() on self.device 166 | state_dict = torch.load(load_path, map_location=str(self.device)) 167 | if hasattr(state_dict, '_metadata'): 168 | del state_dict._metadata 169 | 170 | # patch InstanceNorm checkpoints prior to 0.4 171 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 172 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 173 | net.load_state_dict(state_dict) 174 | 175 | def load_networks2(self, which_epoch): 176 | gen_name = os.path.join(self.save_dir, '%s_net_gen.pt' % (which_epoch)) 177 | gen_state_dict = 
torch.load(gen_name, map_location=str(self.device)) 178 | if self.isTrain and self.opt.model != 'apdrawing_style_nogan': 179 | dis_name = os.path.join(self.save_dir, '%s_net_dis.pt' % (which_epoch)) 180 | dis_state_dict = torch.load(dis_name, map_location=str(self.device)) 181 | for name in self.model_names: 182 | if isinstance(name, str): 183 | net = getattr(self, 'net' + name) 184 | if isinstance(net, torch.nn.DataParallel): 185 | net = net.module 186 | if name[0] == 'G': 187 | print('loading the model %s from %s' % (name,gen_name)) 188 | state_dict = gen_state_dict[name] 189 | elif name[0] == 'D': 190 | print('loading the model %s from %s' % (name,gen_name)) 191 | state_dict = dis_state_dict[name] 192 | 193 | if hasattr(state_dict, '_metadata'): 194 | del state_dict._metadata 195 | # patch InstanceNorm checkpoints prior to 0.4 196 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 197 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 198 | net.load_state_dict(state_dict) 199 | 200 | # load auxiliary net models from the disk 201 | def load_auxiliary_networks(self): 202 | for name in self.auxiliary_model_names: 203 | if isinstance(name, str): 204 | if 'AE' in name and self.opt.ae_small: 205 | load_filename = '%s_net_%s_small.pth' % ('latest', name) 206 | elif 'Regressor' in name: 207 | load_filename = '%s_net_%s%d.pth' % ('latest', name, self.opt.regarch) 208 | else: 209 | load_filename = '%s_net_%s.pth' % ('latest', name) 210 | load_path = os.path.join(self.auxiliary_dir, load_filename) 211 | net = getattr(self, 'net' + name) 212 | if isinstance(net, torch.nn.DataParallel): 213 | net = net.module 214 | print('loading the model from %s' % load_path) 215 | # if you are using PyTorch newer than 0.4 (e.g., built from 216 | # GitHub source), you can remove str() on self.device 217 | if name in ['DT1', 'DT2', 'Line1', 'Line2', 'Continuity1', 'Continuity2', 'Regressor', 'Regressorhair', 'Regressorface']: 218 | state_dict = torch.load(load_path, map_location=str(self.device_p)) 219 | else: 220 | state_dict = torch.load(load_path, map_location=str(self.device)) 221 | if hasattr(state_dict, '_metadata'): 222 | del state_dict._metadata 223 | 224 | # patch InstanceNorm checkpoints prior to 0.4 225 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 226 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 227 | net.load_state_dict(state_dict) 228 | 229 | # print network information 230 | def print_networks(self, verbose): 231 | print('---------- Networks initialized -------------') 232 | for name in self.model_names: 233 | if isinstance(name, str): 234 | net = getattr(self, 'net' + name) 235 | num_params = 0 236 | for param in net.parameters(): 237 | num_params += param.numel() 238 | if verbose: 239 | print(net) 240 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 241 | print('-----------------------------------------------') 242 | 243 | # set requies_grad=Fasle to avoid computation 244 | def set_requires_grad(self, nets, requires_grad=False): 245 | if not isinstance(nets, list): 246 | nets = [nets] 247 | for net in nets: 248 | if net is not None: 249 | for param in net.parameters(): 250 | param.requires_grad = requires_grad 251 | 252 | # ============================================================================================================= 253 | def inverse_mask(self, mask): 254 | return torch.ones(mask.shape).to(self.device)-mask 255 | 256 
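print_networks above reports each network's size by summing numel() over its parameters and dividing by 1e6. The same check in isolation (nn.Linear is only a stand-in module for illustration):

import torch.nn as nn

def count_parameters_m(net):
    # total learnable parameters, in millions, as reported by print_networks
    return sum(p.numel() for p in net.parameters()) / 1e6

# e.g. count_parameters_m(nn.Linear(512, 512)) -> about 0.263 (512*512 weights + 512 biases)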
| def masked(self, A,mask): 257 | return (A/2+0.5)*mask*2-1 258 | 259 | def add_with_mask(self, A,B,mask): 260 | return ((A/2+0.5)*mask+(B/2+0.5)*(torch.ones(mask.shape).to(self.device)-mask))*2-1 261 | 262 | def addone_with_mask(self, A,mask): 263 | return ((A/2+0.5)*mask+(torch.ones(mask.shape).to(self.device)-mask))*2-1 264 | 265 | def partCombiner(self, eyel, eyer, nose, mouth, average_pos=False, comb_op = 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None): 266 | ''' 267 | x y 268 | 100.571 123.429 269 | 155.429 123.429 270 | 128.000 155.886 271 | 103.314 185.417 272 | 152.686 185.417 273 | this is the mean locaiton of 5 landmarks (for 256x256) 274 | Pad2d Left,Right,Top,Down 275 | ''' 276 | if comb_op == 0: 277 | # use max pooling, pad black for eyes etc 278 | padvalue = -1 279 | if region_enm in [1,2]: 280 | eyel = eyel * cmaskel 281 | eyer = eyer * cmasker 282 | nose = nose * cmaskno 283 | mouth = mouth * cmaskmo 284 | else: 285 | # use min pooling, pad white for eyes etc 286 | padvalue = 1 287 | if region_enm in [1,2]: 288 | eyel = self.addone_with_mask(eyel, cmaskel) 289 | eyer = self.addone_with_mask(eyer, cmasker) 290 | nose = self.addone_with_mask(nose, cmaskno) 291 | mouth = self.addone_with_mask(mouth, cmaskmo) 292 | if region_enm in [0,1]: # need to pad 293 | IMAGE_SIZE = self.opt.fineSize 294 | ratio = IMAGE_SIZE / 256 295 | EYE_W = self.opt.EYE_W * ratio 296 | EYE_H = self.opt.EYE_H * ratio 297 | NOSE_W = self.opt.NOSE_W * ratio 298 | NOSE_H = self.opt.NOSE_H * ratio 299 | MOUTH_W = self.opt.MOUTH_W * ratio 300 | MOUTH_H = self.opt.MOUTH_H * ratio 301 | bs,nc,_,_ = eyel.shape 302 | eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 303 | eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 304 | nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 305 | mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 306 | for i in range(bs): 307 | if not average_pos: 308 | center = self.center[i]#x,y 309 | else:# if average_pos = True 310 | center = torch.tensor([[101,123-4],[155,123-4],[128,156-NOSE_H/2+16],[128,185]]) 311 | eyel_p[i] = torch.nn.ConstantPad2d((int(center[0,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[0,0]+EYE_W/2-1)), int(center[0,1] - EYE_H / 2 - 1),int(IMAGE_SIZE - (center[0,1]+EYE_H/2 - 1))),-1)(eyel[i]) 312 | eyer_p[i] = torch.nn.ConstantPad2d((int(center[1,0] - EYE_W / 2 - 1), int(IMAGE_SIZE - (center[1,0]+EYE_W/2-1)), int(center[1,1] - EYE_H / 2 - 1), int(IMAGE_SIZE - (center[1,1]+EYE_H/2 - 1))),-1)(eyer[i]) 313 | nose_p[i] = torch.nn.ConstantPad2d((int(center[2,0] - NOSE_W / 2 - 1), int(IMAGE_SIZE - (center[2,0]+NOSE_W/2-1)), int(center[2,1] - NOSE_H / 2 - 1), int(IMAGE_SIZE - (center[2,1]+NOSE_H/2 - 1))),-1)(nose[i]) 314 | mouth_p[i] = torch.nn.ConstantPad2d((int(center[3,0] - MOUTH_W / 2 - 1), int(IMAGE_SIZE - (center[3,0]+MOUTH_W/2-1)), int(center[3,1] - MOUTH_H / 2 - 1), int(IMAGE_SIZE - (center[3,1]+MOUTH_H/2 - 1))),-1)(mouth[i]) 315 | elif region_enm in [2]: 316 | eyel_p = eyel 317 | eyer_p = eyer 318 | nose_p = nose 319 | mouth_p = mouth 320 | if comb_op == 0: 321 | # use max pooling 322 | eyes = torch.max(eyel_p, eyer_p) 323 | eye_nose = torch.max(eyes, nose_p) 324 | result = torch.max(eye_nose, mouth_p) 325 | else: 326 | # use min pooling 327 | eyes = torch.min(eyel_p, eyer_p) 328 | eye_nose = torch.min(eyes, nose_p) 329 | result = torch.min(eye_nose, mouth_p) 330 | return result 331 | 332 | def partCombiner2(self, eyel, eyer, nose, mouth, hair, mask, comb_op 
= 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None): 333 | if comb_op == 0: 334 | # use max pooling, pad black for eyes etc 335 | padvalue = -1 336 | hair = self.masked(hair, mask) 337 | if region_enm in [1,2]: 338 | eyel = eyel * cmaskel 339 | eyer = eyer * cmasker 340 | nose = nose * cmaskno 341 | mouth = mouth * cmaskmo 342 | else: 343 | # use min pooling, pad white for eyes etc 344 | padvalue = 1 345 | hair = self.addone_with_mask(hair, mask) 346 | if region_enm in [1,2]: 347 | eyel = self.addone_with_mask(eyel, cmaskel) 348 | eyer = self.addone_with_mask(eyer, cmasker) 349 | nose = self.addone_with_mask(nose, cmaskno) 350 | mouth = self.addone_with_mask(mouth, cmaskmo) 351 | if region_enm in [0,1]: # need to pad 352 | IMAGE_SIZE = self.opt.fineSize 353 | ratio = IMAGE_SIZE / 256 354 | EYE_W = self.opt.EYE_W * ratio 355 | EYE_H = self.opt.EYE_H * ratio 356 | NOSE_W = self.opt.NOSE_W * ratio 357 | NOSE_H = self.opt.NOSE_H * ratio 358 | MOUTH_W = self.opt.MOUTH_W * ratio 359 | MOUTH_H = self.opt.MOUTH_H * ratio 360 | bs,nc,_,_ = eyel.shape 361 | eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 362 | eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 363 | nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 364 | mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 365 | for i in range(bs): 366 | center = self.center[i]#x,y 367 | eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(eyel[i]) 368 | eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(eyer[i]) 369 | nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(nose[i]) 370 | mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(mouth[i]) 371 | elif region_enm in [2]: 372 | eyel_p = eyel 373 | eyer_p = eyer 374 | nose_p = nose 375 | mouth_p = mouth 376 | if comb_op == 0: 377 | # use max pooling 378 | eyes = torch.max(eyel_p, eyer_p) 379 | eye_nose = torch.max(eyes, nose_p) 380 | eye_nose_mouth = torch.max(eye_nose, mouth_p) 381 | result = torch.max(hair,eye_nose_mouth) 382 | else: 383 | # use min pooling 384 | eyes = torch.min(eyel_p, eyer_p) 385 | eye_nose = torch.min(eyes, nose_p) 386 | eye_nose_mouth = torch.min(eye_nose, mouth_p) 387 | result = torch.min(hair,eye_nose_mouth) 388 | return result 389 | 390 | def partCombiner2_bg(self, eyel, eyer, nose, mouth, hair, bg, maskh, maskb, comb_op = 1, region_enm = 0, cmaskel = None, cmasker = None, cmaskno = None, cmaskmo = None): 391 | if comb_op == 0: 392 | # use max pooling, pad black for eyes etc 393 | padvalue = -1 394 | hair = self.masked(hair, maskh) 395 | bg = self.masked(bg, maskb) 396 | if region_enm in [1,2]: 397 | eyel = eyel * cmaskel 398 | eyer = eyer * cmasker 399 | nose = nose * cmaskno 400 | mouth = mouth * cmaskmo 401 | else: 402 | # use min pooling, pad white for eyes etc 403 | padvalue = 1 404 | hair = self.addone_with_mask(hair, maskh) 405 | bg = self.addone_with_mask(bg, maskb) 406 | if region_enm in [1,2]: 407 | eyel = self.addone_with_mask(eyel, cmaskel) 408 | eyer = self.addone_with_mask(eyer, cmasker) 409 
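The partCombiner variants paste each local patch back onto a full-size canvas by padding it around its landmark centre with ConstantPad2d, then fuse the canvases element-wise with min (white padding, comb_op=1) or max (black padding, comb_op=0). A sketch of the placement step for a single part, assuming integer centre coordinates (illustrative only):

import torch

def place_part(part, cx, cy, image_size, pad_value=1):
    # part: [C, h, w] in [-1, 1]; returns a [C, image_size, image_size] canvas with the
    # part centred at (cx, cy) and pad_value (white, for min-pooling) everywhere else
    _, h, w = part.shape
    left, top = int(cx - w / 2), int(cy - h / 2)
    pad = torch.nn.ConstantPad2d((left, image_size - left - w,
                                  top, image_size - top - h), pad_value)
    return pad(part)

# canvases built this way are combined with torch.min(...) when comb_op == 1,
# or torch.max(...) when comb_op == 0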
| nose = self.addone_with_mask(nose, cmaskno) 410 | mouth = self.addone_with_mask(mouth, cmaskmo) 411 | if region_enm in [0,1]: # need to pad to full size 412 | IMAGE_SIZE = self.opt.fineSize 413 | ratio = IMAGE_SIZE / 256 414 | EYE_W = self.opt.EYE_W * ratio 415 | EYE_H = self.opt.EYE_H * ratio 416 | NOSE_W = self.opt.NOSE_W * ratio 417 | NOSE_H = self.opt.NOSE_H * ratio 418 | MOUTH_W = self.opt.MOUTH_W * ratio 419 | MOUTH_H = self.opt.MOUTH_H * ratio 420 | bs,nc,_,_ = eyel.shape 421 | eyel_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 422 | eyer_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 423 | nose_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 424 | mouth_p = torch.ones((bs,nc,IMAGE_SIZE,IMAGE_SIZE)).to(self.device) 425 | for i in range(bs): 426 | center = self.center[i]#x,y 427 | eyel_p[i] = torch.nn.ConstantPad2d((center[0,0] - EYE_W / 2, IMAGE_SIZE - (center[0,0]+EYE_W/2), center[0,1] - EYE_H / 2, IMAGE_SIZE - (center[0,1]+EYE_H/2)),padvalue)(eyel[i]) 428 | eyer_p[i] = torch.nn.ConstantPad2d((center[1,0] - EYE_W / 2, IMAGE_SIZE - (center[1,0]+EYE_W/2), center[1,1] - EYE_H / 2, IMAGE_SIZE - (center[1,1]+EYE_H/2)),padvalue)(eyer[i]) 429 | nose_p[i] = torch.nn.ConstantPad2d((center[2,0] - NOSE_W / 2, IMAGE_SIZE - (center[2,0]+NOSE_W/2), center[2,1] - NOSE_H / 2, IMAGE_SIZE - (center[2,1]+NOSE_H/2)),padvalue)(nose[i]) 430 | mouth_p[i] = torch.nn.ConstantPad2d((center[3,0] - MOUTH_W / 2, IMAGE_SIZE - (center[3,0]+MOUTH_W/2), center[3,1] - MOUTH_H / 2, IMAGE_SIZE - (center[3,1]+MOUTH_H/2)),padvalue)(mouth[i]) 431 | elif region_enm in [2]: 432 | eyel_p = eyel 433 | eyer_p = eyer 434 | nose_p = nose 435 | mouth_p = mouth 436 | if comb_op == 0: 437 | eyes = torch.max(eyel_p, eyer_p) 438 | eye_nose = torch.max(eyes, nose_p) 439 | eye_nose_mouth = torch.max(eye_nose, mouth_p) 440 | eye_nose_mouth_hair = torch.max(hair,eye_nose_mouth) 441 | result = torch.max(bg,eye_nose_mouth_hair) 442 | else: 443 | eyes = torch.min(eyel_p, eyer_p) 444 | eye_nose = torch.min(eyes, nose_p) 445 | eye_nose_mouth = torch.min(eye_nose, mouth_p) 446 | eye_nose_mouth_hair = torch.min(hair,eye_nose_mouth) 447 | result = torch.min(bg,eye_nose_mouth_hair) 448 | return result 449 | 450 | def partCombiner3(self, face, hair, maskf, maskh, comb_op = 1): 451 | if comb_op == 0: 452 | # use max pooling, pad black etc 453 | padvalue = -1 454 | face = self.masked(face, maskf) 455 | hair = self.masked(hair, maskh) 456 | else: 457 | # use min pooling, pad white etc 458 | padvalue = 1 459 | face = self.addone_with_mask(face, maskf) 460 | hair = self.addone_with_mask(hair, maskh) 461 | if comb_op == 0: 462 | result = torch.max(face,hair) 463 | else: 464 | result = torch.min(face,hair) 465 | return result 466 | 467 | 468 | def tocv2(ts): 469 | img = (ts.numpy()/2+0.5)*255 470 | img = img.astype('uint8') 471 | img = np.transpose(img,(1,2,0)) 472 | img = img[:,:,::-1]#rgb->bgr 473 | return img 474 | 475 | def totor(img): 476 | img = img[:,:,::-1] 477 | tor = transforms.ToTensor()(img) 478 | tor = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor) 479 | return tor 480 | 481 | 482 | def ContinuityForTest(self, real = 0): 483 | # Patch-based 484 | self.get_patches() 485 | self.outputs = self.netRegressor(self.fake_B_patches) 486 | line_continuity = torch.mean(self.outputs) 487 | opt = self.opt 488 | file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt') 489 | message = '%s %.04f' % (self.image_paths[0], line_continuity) 490 | 
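tocv2 and totor above convert between the model's [-1, 1] CHW tensors and OpenCV-style uint8 HWC BGR images. A round-trip sketch of the same conversions (the .copy() avoids negative strides after the channel flip and is an addition for the sketch, not taken from the file):

import numpy as np
import torchvision.transforms as transforms

def tensor_to_cv2(ts):
    # [C, H, W] tensor in [-1, 1] -> H x W x C uint8 BGR image
    img = ((ts.numpy() / 2 + 0.5) * 255).astype('uint8')
    img = np.transpose(img, (1, 2, 0))   # CHW -> HWC
    return img[:, :, ::-1]               # RGB -> BGR

def cv2_to_tensor(img):
    # H x W x C uint8 BGR image -> [C, H, W] tensor in [-1, 1]
    rgb = img[:, :, ::-1].copy()
    tor = transforms.ToTensor()(rgb)     # uint8 HWC -> float CHW in [0, 1]
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(tor)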
with open(file_name, 'a+') as c_file: 491 | c_file.write(message) 492 | c_file.write('\n') 493 | if real == 1: 494 | self.get_patches_real() 495 | self.outputs2 = self.netRegressor(self.real_B_patches) 496 | line_continuity2 = torch.mean(self.outputs2) 497 | file_name = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt') 498 | message = '%s %.04f' % (self.image_paths[0], line_continuity2) 499 | with open(file_name, 'a+') as c_file: 500 | c_file.write(message) 501 | c_file.write('\n') 502 | 503 | def getLocalParts(self,fakeAB): 504 | bs,nc,_,_ = fakeAB.shape #dtype torch.float32 505 | ncr = int(nc / self.opt.output_nc) 506 | if self.opt.region_enm in [0,1]: 507 | ratio = self.opt.fineSize / 256 508 | EYE_H = self.opt.EYE_H * ratio 509 | EYE_W = self.opt.EYE_W * ratio 510 | NOSE_H = self.opt.NOSE_H * ratio 511 | NOSE_W = self.opt.NOSE_W * ratio 512 | MOUTH_H = self.opt.MOUTH_H * ratio 513 | MOUTH_W = self.opt.MOUTH_W * ratio 514 | eyel = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device) 515 | eyer = torch.ones((bs,nc,int(EYE_H),int(EYE_W))).to(self.device) 516 | nose = torch.ones((bs,nc,int(NOSE_H),int(NOSE_W))).to(self.device) 517 | mouth = torch.ones((bs,nc,int(MOUTH_H),int(MOUTH_W))).to(self.device) 518 | for i in range(bs): 519 | center = self.center[i] 520 | eyel[i] = fakeAB[i,:,center[0,1]-EYE_H/2:center[0,1]+EYE_H/2,center[0,0]-EYE_W/2:center[0,0]+EYE_W/2] 521 | eyer[i] = fakeAB[i,:,center[1,1]-EYE_H/2:center[1,1]+EYE_H/2,center[1,0]-EYE_W/2:center[1,0]+EYE_W/2] 522 | nose[i] = fakeAB[i,:,center[2,1]-NOSE_H/2:center[2,1]+NOSE_H/2,center[2,0]-NOSE_W/2:center[2,0]+NOSE_W/2] 523 | mouth[i] = fakeAB[i,:,center[3,1]-MOUTH_H/2:center[3,1]+MOUTH_H/2,center[3,0]-MOUTH_W/2:center[3,0]+MOUTH_W/2] 524 | elif self.opt.region_enm in [2]: 525 | eyel = (fakeAB/2+0.5) * self.cmaskel.repeat(1,ncr,1,1) * 2 - 1 526 | eyer = (fakeAB/2+0.5) * self.cmasker.repeat(1,ncr,1,1) * 2 - 1 527 | nose = (fakeAB/2+0.5) * self.cmask.repeat(1,ncr,1,1) * 2 - 1 528 | mouth = (fakeAB/2+0.5) * self.cmaskmo.repeat(1,ncr,1,1) * 2 - 1 529 | hair = (fakeAB/2+0.5) * self.mask.repeat(1,ncr,1,1) * self.mask2.repeat(1,ncr,1,1) * 2 - 1 530 | bg = (fakeAB/2+0.5) * (torch.ones(fakeAB.shape).to(self.device)-self.mask2.repeat(1,ncr,1,1)) * 2 - 1 531 | return eyel, eyer, nose, mouth, hair, bg 532 | 533 | def getaddw(self,local_name): 534 | addw = 1 535 | if local_name in ['DLEyel','DLEyer','eyel','eyer','DLFace','face']: 536 | addw = self.opt.addw_eye 537 | elif local_name in ['DLNose', 'nose']: 538 | addw = self.opt.addw_nose 539 | elif local_name in ['DLMouth', 'mouth']: 540 | addw = self.opt.addw_mouth 541 | elif local_name in ['DLHair', 'hair']: 542 | addw = self.opt.addw_hair 543 | elif local_name in ['DLBG', 'bg']: 544 | addw = self.opt.addw_bg 545 | return addw -------------------------------------------------------------------------------- /models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . 
import networks 3 | import torch 4 | 5 | 6 | class TestModel(BaseModel): 7 | def name(self): 8 | return 'TestModel' 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train=True): 12 | assert not is_train, 'TestModel cannot be used in train mode' 13 | # uncomment because default CycleGAN did not use dropout ( parser.set_defaults(no_dropout=True) ) 14 | # parser = CycleGANModel.modify_commandline_options(parser, is_train=False) 15 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')# no_lsgan=True, use_lsgan=False 16 | parser.set_defaults(dataset_mode='single') 17 | parser.set_defaults(auxiliary_root='auxiliaryeye2o') 18 | parser.set_defaults(use_local=True, hair_local=True, bg_local=True) 19 | parser.set_defaults(nose_ae=True, others_ae=True, compactmask=True, MOUTH_H=56) 20 | parser.set_defaults(soft_border=1) 21 | parser.add_argument('--nnG_hairc', type=int, default=6, help='nnG for hair classifier') 22 | parser.add_argument('--use_resnet', action='store_true', help='use resnet for generator') 23 | 24 | parser.add_argument('--model_suffix', type=str, default='', 25 | help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will' 26 | ' be loaded as the generator of TestModel') 27 | 28 | return parser 29 | 30 | def initialize(self, opt): 31 | assert(not opt.isTrain) 32 | BaseModel.initialize(self, opt) 33 | 34 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 35 | self.loss_names = [] 36 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 37 | self.visual_names = ['real_A', 'fake_B'] 38 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 39 | self.model_names = ['G' + opt.model_suffix] 40 | self.auxiliary_model_names = [] 41 | if self.opt.use_local: 42 | self.model_names += ['GLEyel','GLEyer','GLNose','GLMouth','GLHair','GLBG','GCombine'] 43 | self.auxiliary_model_names += ['CLm','CLh'] 44 | # auxiliary nets for local output refinement 45 | if self.opt.nose_ae: 46 | self.auxiliary_model_names += ['AE'] 47 | if self.opt.others_ae: 48 | self.auxiliary_model_names += ['AEel','AEer','AEmowhite','AEmoblack'] 49 | print('model_names', self.model_names) 50 | print('auxiliary_model_names', self.auxiliary_model_names) 51 | 52 | # load/define networks 53 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm, 54 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 55 | opt.nnG) 56 | print('netG', opt.netG) 57 | if self.opt.use_local: 58 | netlocal1 = 'partunet' if self.opt.use_resnet == 0 else 'resnet_nblocks' 59 | netlocal2 = 'partunet2' if self.opt.use_resnet == 0 else 'resnet_6blocks' 60 | netlocal2_style = 'partunet2style' if self.opt.use_resnet == 0 else 'resnet_style2_6blocks' 61 | self.netGLEyel = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 62 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 63 | self.netGLEyer = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 64 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 65 | self.netGLNose = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 66 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=3) 67 | self.netGLMouth = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal1, opt.norm, 68 | not opt.no_dropout, opt.init_type, 
opt.init_gain, self.gpu_ids, nnG=3) 69 | self.netGLHair = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2_style, opt.norm, 70 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4, 71 | extra_channel=3) 72 | self.netGLBG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, netlocal2, opt.norm, 73 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, nnG=4) 74 | # by default combiner_type is combiner, which uses resnet 75 | print('combiner_type', self.opt.combiner_type) 76 | self.netGCombine = networks.define_G(2*opt.output_nc, opt.output_nc, opt.ngf, self.opt.combiner_type, opt.norm, 77 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 2) 78 | # auxiliary classifiers for mouth and hair 79 | ratio = self.opt.fineSize / 256 80 | self.MOUTH_H = int(self.opt.MOUTH_H * ratio) 81 | self.MOUTH_W = int(self.opt.MOUTH_W * ratio) 82 | self.netCLm = networks.define_G(opt.input_nc, 2, opt.ngf, 'classifier', opt.norm, 83 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 84 | nnG = 3, ae_h = self.MOUTH_H, ae_w = self.MOUTH_W) 85 | self.netCLh = networks.define_G(opt.input_nc, 3, opt.ngf, 'classifier', opt.norm, 86 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 87 | nnG = opt.nnG_hairc, ae_h = opt.fineSize, ae_w = opt.fineSize) 88 | # ==================================auxiliary nets (loaded, parameters fixed)============================= 89 | if self.opt.use_local and self.opt.nose_ae: 90 | ratio = self.opt.fineSize / 256 91 | NOSE_H = self.opt.NOSE_H * ratio 92 | NOSE_W = self.opt.NOSE_W * ratio 93 | self.netAE = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 94 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 95 | latent_dim=self.opt.ae_latentno, ae_h=NOSE_H, ae_w=NOSE_W) 96 | self.set_requires_grad(self.netAE, False) 97 | if self.opt.use_local and self.opt.others_ae: 98 | ratio = self.opt.fineSize / 256 99 | EYE_H = self.opt.EYE_H * ratio 100 | EYE_W = self.opt.EYE_W * ratio 101 | MOUTH_H = self.opt.MOUTH_H * ratio 102 | MOUTH_W = self.opt.MOUTH_W * ratio 103 | self.netAEel = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 104 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 105 | latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W) 106 | self.netAEer = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 107 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 108 | latent_dim=self.opt.ae_latenteye, ae_h=EYE_H, ae_w=EYE_W) 109 | self.netAEmowhite = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 110 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 111 | latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W) 112 | self.netAEmoblack = networks.define_G(opt.output_nc, opt.output_nc, opt.ngf, self.opt.nose_ae_net, 'batch', 113 | not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, 114 | latent_dim=self.opt.ae_latentmo, ae_h=MOUTH_H, ae_w=MOUTH_W) 115 | self.set_requires_grad(self.netAEel, False) 116 | self.set_requires_grad(self.netAEer, False) 117 | self.set_requires_grad(self.netAEmowhite, False) 118 | self.set_requires_grad(self.netAEmoblack, False) 119 | 120 | # assigns the model to self.netG_[suffix] so that it can be loaded 121 | # please see BaseModel.load_networks 122 | setattr(self, 'netG' + opt.model_suffix, self.netG) 123 | 124 | def set_input(self, 
input): 125 | # we need to use single_dataset mode 126 | self.real_A = input['A'].to(self.device) 127 | self.image_paths = input['A_paths'] 128 | self.batch_size = len(self.image_paths) 129 | if self.opt.use_local: 130 | self.real_A_eyel = input['eyel_A'].to(self.device) 131 | self.real_A_eyer = input['eyer_A'].to(self.device) 132 | self.real_A_nose = input['nose_A'].to(self.device) 133 | self.real_A_mouth = input['mouth_A'].to(self.device) 134 | self.center = input['center'] 135 | if self.opt.soft_border: 136 | self.softel = input['soft_eyel_mask'].to(self.device) 137 | self.softer = input['soft_eyer_mask'].to(self.device) 138 | self.softno = input['soft_nose_mask'].to(self.device) 139 | self.softmo = input['soft_mouth_mask'].to(self.device) 140 | if self.opt.compactmask: 141 | self.cmask = input['cmask'].to(self.device) 142 | self.cmask1 = self.cmask*2-1#[0,1]->[-1,1] 143 | self.cmaskel = input['cmaskel'].to(self.device) 144 | self.cmask1el = self.cmaskel*2-1 145 | self.cmasker = input['cmasker'].to(self.device) 146 | self.cmask1er = self.cmasker*2-1 147 | self.cmaskmo = input['cmaskmo'].to(self.device) 148 | self.cmask1mo = self.cmaskmo*2-1 149 | self.real_A_hair = input['hair_A'].to(self.device) 150 | self.mask = input['mask'].to(self.device) # mask for non-eyes,nose,mouth 151 | self.mask2 = input['mask2'].to(self.device) # mask for non-bg 152 | self.real_A_bg = input['bg_A'].to(self.device) 153 | 154 | def getonehot(self,outputs,classes): 155 | [maxv,index] = torch.max(outputs,1) 156 | y = torch.unsqueeze(index,1) 157 | onehot = torch.FloatTensor(self.batch_size,classes).to(self.device) 158 | onehot.zero_() 159 | onehot.scatter_(1,y,1) 160 | return onehot 161 | 162 | def forward(self): 163 | if not self.opt.use_local: 164 | self.fake_B = self.netG(self.real_A) 165 | else: 166 | self.fake_B0 = self.netG(self.real_A) 167 | # EYES, MOUTH 168 | outputs1 = self.netCLm(self.real_A_mouth) 169 | onehot1 = self.getonehot(outputs1,2) 170 | 171 | if not self.opt.others_ae: 172 | fake_B_eyel = self.netGLEyel(self.real_A_eyel) 173 | fake_B_eyer = self.netGLEyer(self.real_A_eyer) 174 | fake_B_mouth = self.netGLMouth(self.real_A_mouth) 175 | else: # use AE that only constains compact region, need cmask! 176 | self.fake_B_eyel1 = self.netGLEyel(self.real_A_eyel) 177 | self.fake_B_eyer1 = self.netGLEyer(self.real_A_eyer) 178 | self.fake_B_mouth1 = self.netGLMouth(self.real_A_mouth) 179 | self.fake_B_eyel2,_ = self.netAEel(self.fake_B_eyel1) 180 | self.fake_B_eyer2,_ = self.netAEer(self.fake_B_eyer1) 181 | # USE 2 AEs 182 | self.fake_B_mouth2 = torch.FloatTensor(self.batch_size,self.opt.output_nc,self.MOUTH_H,self.MOUTH_W).to(self.device) 183 | for i in range(self.batch_size): 184 | if onehot1[i][0] == 1: 185 | self.fake_B_mouth2[i],_ = self.netAEmowhite(self.fake_B_mouth1[i].unsqueeze(0)) 186 | #print('AEmowhite') 187 | elif onehot1[i][1] == 1: 188 | self.fake_B_mouth2[i],_ = self.netAEmoblack(self.fake_B_mouth1[i].unsqueeze(0)) 189 | #print('AEmoblack') 190 | fake_B_eyel = self.add_with_mask(self.fake_B_eyel2,self.fake_B_eyel1,self.cmaskel) 191 | fake_B_eyer = self.add_with_mask(self.fake_B_eyer2,self.fake_B_eyer1,self.cmasker) 192 | fake_B_mouth = self.add_with_mask(self.fake_B_mouth2,self.fake_B_mouth1,self.cmaskmo) 193 | # NOSE 194 | if not self.opt.nose_ae: 195 | fake_B_nose = self.netGLNose(self.real_A_nose) 196 | else: # use AE that only constains compact region, need cmask! 
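At test time these branches are run purely for inference: BaseModel.eval() (shown earlier) switches every net to eval mode, and BaseModel.test() wraps forward in torch.no_grad(), so no autograd state is kept across images. The same pattern in miniature, with a generic net standing in for the model:

import torch

def run_inference(net, x):
    net.eval()              # normalization / dropout layers switch to eval behaviour
    with torch.no_grad():   # no gradient bookkeeping during the forward pass
        return net(x)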
197 | self.fake_B_nose1 = self.netGLNose(self.real_A_nose) 198 | self.fake_B_nose2,_ = self.netAE(self.fake_B_nose1) 199 | fake_B_nose = self.add_with_mask(self.fake_B_nose2,self.fake_B_nose1,self.cmask) 200 | 201 | # HAIR, BG AND PARTCOMBINE 202 | outputs2 = self.netCLh(self.real_A_hair) 203 | onehot2 = self.getonehot(outputs2,3) 204 | 205 | fake_B_hair = self.netGLHair(self.real_A_hair,onehot2) 206 | fake_B_bg = self.netGLBG(self.real_A_bg) 207 | self.fake_B_hair = self.masked(fake_B_hair,self.mask*self.mask2) 208 | self.fake_B_bg = self.masked(fake_B_bg,self.inverse_mask(self.mask2)) 209 | if not self.opt.compactmask: 210 | self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op) 211 | else: 212 | self.fake_B1 = self.partCombiner2_bg(fake_B_eyel,fake_B_eyer,fake_B_nose,fake_B_mouth,fake_B_hair,fake_B_bg,self.mask*self.mask2,self.inverse_mask(self.mask2),self.opt.comb_op,self.opt.region_enm,self.cmaskel,self.cmasker,self.cmask,self.cmaskmo) 213 | 214 | self.fake_B = self.netGCombine(torch.cat([self.fake_B0,self.fake_B1],1)) 215 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/options/__init__.py -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 15 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 16 | parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 17 | parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 18 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 19 | parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels') 20 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 21 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 22 | parser.add_argument('--netD', type=str, default='basic', help='selects model to use for netD') 23 | parser.add_argument('--netG', type=str, default='unet_256', help='selects model to use for netG') 24 | parser.add_argument('--nnG', type=int, default=9, help='specify nblock for resnet_nblocks, ndown for unet for unet_ndown') 25 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 26 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 27 | parser.add_argument('--gpu_ids_p', type=str, default='0', help='gpu ids for pretrained auxiliary models: e.g. 0 0,1,2, 0,2. use -1 for CPU') 28 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. 
It decides where to store samples and models') 29 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]') 30 | parser.add_argument('--model', type=str, default='apdrawing', 31 | help='chooses which model to use. cycle_gan, pix2pix, test, autoencoder') 32 | parser.add_argument('--use_local', action='store_true', help='use local part network') 33 | parser.add_argument('--lm_dir', type=str, default='dataset/landmark/', help='path to facial landmarks') 34 | parser.add_argument('--nose_ae', action='store_true', help='use nose autoencoder') 35 | parser.add_argument('--others_ae', action='store_true', help='use autoencoder for eyes and mouth too') 36 | parser.add_argument('--nose_ae_net', type=str, default='autoencoderfc', help='net for nose autoencoder [autoencoder | autoencoderfc]') 37 | parser.add_argument('--comb_op', type=int, default=1, help='use min-pooling(1) or max-pooling(0) for overlapping regions') 38 | parser.add_argument('--hair_local', action='store_true', help='add hair part') 39 | parser.add_argument('--bg_local', action='store_true', help='use background mask to seperate background') 40 | parser.add_argument('--bg_dir', default='dataset/mask/bg/', type=str, help='choose bg_dir') 41 | parser.add_argument('--region_enm', type=int, default=0, help='region type for eyes nose mouth: 0 for rectangle, 1 for campact mask in rectangle, 2 for mask no rectangle (1,2 must have compactmask, 0 use compactmask for AE)') 42 | parser.add_argument('--soft_border', type=int, default=0, help='use mask with soft border') 43 | parser.add_argument('--EYE_H', type=int, default=40, help='EYE_H') 44 | parser.add_argument('--EYE_W', type=int, default=56, help='EYE_W') 45 | parser.add_argument('--NOSE_H', type=int, default=48, help='NOSE_H') 46 | parser.add_argument('--NOSE_W', type=int, default=48, help='NOSE_W') 47 | parser.add_argument('--MOUTH_H', type=int, default=40, help='MOUTH_H') 48 | parser.add_argument('--MOUTH_W', type=int, default=64, help='MOUTH_W') 49 | parser.add_argument('--average_pos', action='store_true', help='use avg pos in partCombiner') 50 | parser.add_argument('--combiner_type', type=str, default='combiner', help='choose combiner type') 51 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 52 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 53 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 54 | parser.add_argument('--auxiliary_root', type=str, default='auxiliary', help='auxiliary model folder') 55 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 56 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 57 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 58 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 59 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 60 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 61 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 62 | parser.add_argument('--no_dropout', 
action='store_true', help='no dropout for the generator') 63 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 64 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 65 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 66 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 67 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 68 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 69 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{loadSize}') 70 | # compact mask 71 | parser.add_argument('--compactmask', action='store_true', help='use compact mask as input and apply to loss')# "when you calculate the (ae) loss, you should also restrict to nose pixels" 72 | parser.add_argument('--cmask_dir', type=str, default='dataset/mask/', help='compact mask directory') 73 | parser.add_argument('--ae_latentno', type=int, default=1024 ,help='latent space dim for pretrained NOSE AEwithfc') 74 | parser.add_argument('--ae_latentmo', type=int, default=1024 ,help='latent space dim for pretrained MOUTH AEwithfc') 75 | parser.add_argument('--ae_latenteye', type=int, default=1024 ,help='latent space dim for pretrained EYEL/EYER AEwithfc') 76 | parser.add_argument('--ae_small', type=int, default=0 ,help='use latent dim smaller than default 1024 in 4 AEs') 77 | # below for autoencoder 78 | parser.add_argument('--ae_latent', type=int, default=1024 ,help='latent space dim for autoencoderfc') 79 | parser.add_argument('--ae_multiple', type=float, default=2 ,help='filter number change in ae encoder') 80 | parser.add_argument('--ae_h', type=int, default=96 ,help='ae input h') 81 | parser.add_argument('--ae_w', type=int, default=96 ,help='ae input w') 82 | parser.add_argument('--ae_region', type=str, default='nose' ,help='autoencoder for which region') 83 | parser.add_argument('--no_ae', action='store_true', help='no ae') 84 | self.initialized = True 85 | return parser 86 | 87 | def gather_options(self): 88 | # initialize parser with basic options 89 | if not self.initialized: 90 | parser = argparse.ArgumentParser( 91 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 92 | parser = self.initialize(parser) 93 | 94 | # get the basic options 95 | opt, _ = parser.parse_known_args() 96 | 97 | # modify model-related parser options 98 | model_name = opt.model 99 | model_option_setter = models.get_option_setter(model_name) 100 | parser = model_option_setter(parser, self.isTrain) 101 | opt, _ = parser.parse_known_args() # parse again with the new defaults 102 | 103 | # modify dataset-related parser options 104 | dataset_name = opt.dataset_mode 105 | dataset_option_setter = data.get_option_setter(dataset_name) 106 | parser = dataset_option_setter(parser, self.isTrain) 107 | 108 | self.parser = parser 109 | 110 | return parser.parse_args() 111 | 112 | def print_options(self, opt): 113 | message = '' 114 | message += '----------------- Options ---------------\n' 115 | 
for k, v in sorted(vars(opt).items()): 116 | comment = '' 117 | default = self.parser.get_default(k) 118 | if v != default: 119 | comment = '\t[default: %s]' % str(default) 120 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 121 | message += '----------------- End -------------------' 122 | print(message) 123 | 124 | # save to the disk 125 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 126 | util.mkdirs(expr_dir) 127 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 128 | with open(file_name, 'wt') as opt_file: 129 | opt_file.write(message) 130 | opt_file.write('\n') 131 | 132 | def parse(self, print=True): 133 | 134 | opt = self.gather_options() 135 | if opt.use_local: 136 | opt.loadSize = opt.fineSize 137 | if opt.region_enm in [1,2]: 138 | opt.compactmask = True 139 | if opt.nose_ae or opt.others_ae: 140 | opt.compactmask = True 141 | if opt.ae_latentno < 1024 and opt.ae_latentmo < 1024 and opt.ae_latenteye < 1024: 142 | opt.ae_small = 1 143 | opt.isTrain = self.isTrain # train or test 144 | 145 | # process opt.suffix 146 | if opt.suffix: 147 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 148 | opt.name = opt.name + suffix 149 | 150 | if self.isTrain and opt.pretrain: 151 | opt.nose_ae = False 152 | opt.others_ae = False 153 | opt.compactmask = False 154 | opt.chamfer_loss = False 155 | if not self.isTrain and opt.pretrain: 156 | opt.nose_ae = False 157 | opt.others_ae = False 158 | opt.compactmask = False 159 | if opt.no_ae: 160 | opt.nose_ae = False 161 | opt.others_ae = False 162 | opt.compactmask = False 163 | if self.isTrain and opt.no_dtremap: 164 | opt.dt_nonlinear = '' 165 | opt.lambda_chamfer = 0.1 166 | opt.lambda_chamfer2 = 0.1 167 | if self.isTrain and opt.no_dt: 168 | opt.chamfer_loss = False 169 | 170 | if print: 171 | self.print_options(opt) 172 | 173 | # set gpu ids 174 | str_ids = opt.gpu_ids.split(',') 175 | opt.gpu_ids = [] 176 | for str_id in str_ids: 177 | id = int(str_id) 178 | if id >= 0: 179 | opt.gpu_ids.append(id) 180 | if len(opt.gpu_ids) > 0: 181 | torch.cuda.set_device(opt.gpu_ids[0]) 182 | 183 | # set gpu ids 184 | str_ids = opt.gpu_ids_p.split(',') 185 | opt.gpu_ids_p = [] 186 | for str_id in str_ids: 187 | id = int(str_id) 188 | if id >= 0: 189 | opt.gpu_ids_p.append(id) 190 | 191 | self.opt = opt 192 | return self.opt 193 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 12 | parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 13 | parser.add_argument('--test_continuity_loss', action='store_true', help='get continuity value in test') 14 | parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line') 15 | parser.add_argument('--save2', action='store_true', help='only save real_A and fake_B') 16 | parser.add_argument('--imagefolder', type=str, default='images', help='subfolder to save images') 17 | parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae') 18 | 19 | parser.set_defaults(model='test') 20 | # To avoid cropping, the loadSize should be the same as fineSize 21 | parser.set_defaults(loadSize=parser.get_default('fineSize')) 22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 12 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 13 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 14 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 15 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 16 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model')
17 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
18 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
19 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
20 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
21 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
22 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
23 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
24 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
25 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
26 | # ============================================loss=========================================================
27 | # chamfer loss
28 | parser.add_argument('--chamfer_loss', action='store_true', help='use chamfer loss')
29 | parser.add_argument('--chamfer_2way', action='store_true', help='use chamfer loss 2 way')
30 | parser.add_argument('--chamfer_only_line', action='store_true', help='use chamfer only on lines')
31 | parser.add_argument('--lambda_chamfer', type=float, default=0.1, help='weight for chamfer loss')
32 | parser.add_argument('--lambda_chamfer2', type=float, default=0.1, help='weight for chamfer loss2')
33 | parser.add_argument('--dt_nonlinear', type=str, default='', help='nonlinear remap on dt [atan | sigmoid | tanh]')
34 | parser.add_argument('--dt_xmax', type=float, default=10, help='first multiply dt to range [0,xmax], then use atan/sigmoid/tanh etc, to have more nonlinearity (not much nonlinearity in range [0,1])')
35 | # line continuity loss
36 | parser.add_argument('--continuity_loss', action='store_true', help='use line continuity loss')
37 | parser.add_argument('--lambda_continuity', type=float, default=10.0, help='weight for continuity loss')
38 | parser.add_argument('--emphasis_conti_face', action='store_true', help='constrain conti loss to pixels in original lines (avoid applying to background etc)')
39 | parser.add_argument('--facemask_dir', type=str, default='dataset/mask/face/', help='mask folder to constrain conti loss to pixels in original lines')
40 | # =====================================auxiliary net structure===============================================
41 | # dt & line net structure
42 | parser.add_argument('--netG_dt', type=str, default='unet_512', help='selects model to use for netG_dt, for chamfer loss')
43 | parser.add_argument('--netG_line', type=str, default='unet_512', help='selects model to use for netG_line, for chamfer loss')
44 | # multiple discriminators
45 | parser.add_argument('--discriminator_local', action='store_true', help='use six different local discriminators for 6 local regions')
46 | parser.add_argument('--gan_loss_strategy', type=int, default=2, help='specify how to calculate gan loss for g, 1: average global and local discriminators; 2: not change global discriminator weight, 0.25 for local')
47 | parser.add_argument('--addw_eye', type=float, default=1.0, help='additional weight for eye region')
48 | 
parser.add_argument('--addw_nose', type=float, default=1.0, help='additional weight for nose region') 49 | parser.add_argument('--addw_mouth', type=float, default=1.0, help='additional weight for mouth region') 50 | parser.add_argument('--addw_hair', type=float, default=1.0, help='additional weight for hair region') 51 | parser.add_argument('--addw_bg', type=float, default=1.0, help='additional weight for bg region') 52 | # ==========================================ablation======================================================== 53 | parser.add_argument('--no_l1_loss', action='store_true', help='no l1 loss') 54 | parser.add_argument('--no_G_local_loss', action='store_true', help='not using local transfer loss for local generator output') 55 | parser.add_argument('--no_dtremap', action='store_true', help='no dt remap') 56 | parser.add_argument('--no_dt', action='store_true', help='no dt') 57 | 58 | parser.add_argument('--pretrain', action='store_true', help='pretrain stage, no dt loss, no ae') 59 | 60 | 61 | self.isTrain = True 62 | return parser 63 | -------------------------------------------------------------------------------- /preprocess/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true') 12 | args = parser.parse_args() 13 | 14 | for arg in vars(args): 15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | img_fold_A = os.path.join(args.fold_A, sp) 21 | img_fold_B = os.path.join(args.fold_B, sp) 22 | img_list = os.listdir(img_fold_A) 23 | if args.use_AB: 24 | img_list = [img_path for img_path in img_list if '_A.' 
in img_path] 25 | 26 | num_imgs = min(args.num_imgs, len(img_list)) 27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 28 | img_fold_AB = os.path.join(args.fold_AB, sp) 29 | if not os.path.isdir(img_fold_AB): 30 | os.makedirs(img_fold_AB) 31 | print('split = %s, number of images = %d' % (sp, num_imgs)) 32 | for n in range(num_imgs): 33 | name_A = img_list[n] 34 | path_A = os.path.join(img_fold_A, name_A) 35 | if args.use_AB: 36 | name_B = name_A.replace('_A.', '_B.') 37 | else: 38 | name_B = name_A 39 | path_B = os.path.join(img_fold_B, name_B) 40 | if os.path.isfile(path_A) and os.path.isfile(path_B): 41 | name_AB = name_A 42 | if args.use_AB: 43 | name_AB = name_AB.replace('_A.', '.') # remove _A 44 | path_AB = os.path.join(img_fold_AB, name_AB) 45 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) 46 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 47 | im_AB = np.concatenate([im_A, im_B], 1) 48 | cv2.imwrite(path_AB, im_AB) 49 | -------------------------------------------------------------------------------- /preprocess/example/img_1701.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701.jpg -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned.txt: -------------------------------------------------------------------------------- 1 | 194 248 2 | 314 249 3 | 261 312 4 | 209 368 5 | 302 371 6 | -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_68lm.txt: -------------------------------------------------------------------------------- 1 | 120 261 2 | 124 294 3 | 129 326 4 | 133 358 5 | 142 388 6 | 162 412 7 | 190 430 8 | 220 445 9 | 253 449 10 | 287 447 11 | 317 432 12 | 344 411 13 | 362 385 14 | 370 354 15 | 375 322 16 | 382 291 17 | 385 258 18 | 142 225 19 | 161 209 20 | 188 204 21 | 215 208 22 | 242 218 23 | 269 218 24 | 296 208 25 | 324 206 26 | 351 213 27 | 369 231 28 | 256 244 29 | 256 264 30 | 256 284 31 | 256 305 32 | 232 324 33 | 244 328 34 | 256 332 35 | 267 329 36 | 277 325 37 | 172 252 38 | 186 243 39 | 203 243 40 | 218 253 41 | 203 257 42 | 186 257 43 | 290 254 44 | 305 244 45 | 322 246 46 | 336 255 47 | 322 260 48 | 305 259 49 | 210 368 50 | 229 358 51 | 245 352 52 | 256 354 53 | 267 352 54 | 283 358 55 | 300 368 56 | 284 382 57 | 268 388 58 | 255 389 59 | 244 388 60 | 228 381 61 | 220 368 62 | 245 363 63 | 256 364 64 | 267 364 65 | 290 368 66 | 267 370 67 | 255 372 68 | 244 371 69 | -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_bgmask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_bgmask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_eyelmask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_eyelmask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_eyermask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_eyermask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_facemask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_facemask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_mouthmask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_mouthmask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_aligned_nosemask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_aligned_nosemask.png -------------------------------------------------------------------------------- /preprocess/example/img_1701_facial5point.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/preprocess/example/img_1701_facial5point.mat -------------------------------------------------------------------------------- /preprocess/face_align_512.m: -------------------------------------------------------------------------------- 1 | function [trans_img,trans_facial5point]=face_align_512(impath,facial5point,savedir) 2 | % align the faces by similarity transformation. 3 | % using 5 facial landmarks: 2 eyes, nose, 2 mouth corners. 4 | % impath: path to image 5 | % facial5point: 5x2 size, 5 facial landmark positions, detected by MTCNN 6 | % savedir: savedir for cropped image and transformed facial landmarks 7 | 8 | %% alignment settings 9 | imgSize = [512,512]; 10 | coord5point = [180,230; 11 | 300,230; 12 | 240,301; 13 | 186,365.6; 14 | 294,365.6];%480x480 15 | coord5point = (coord5point-240)/560 * 512 + 256; 16 | 17 | %% face alignment 18 | 19 | % load and align, resize image to imgSize 20 | img = imread(impath); 21 | facial5point = double(facial5point); 22 | transf = cp2tform(facial5point, coord5point, 'similarity'); 23 | trans_img = imtransform(img, transf, 'XData', [1 imgSize(2)],... 24 | 'YData', [1 imgSize(1)],... 25 | 'Size', imgSize,... 
26 | 'FillValues', [255;255;255]); 27 | trans_facial5point = round(tformfwd(transf,facial5point)); 28 | 29 | 30 | %% save results 31 | if ~exist(savedir,'dir') 32 | mkdir(savedir) 33 | end 34 | [~,name,~] = fileparts(impath); 35 | % save trans_img 36 | imwrite(trans_img, fullfile(savedir,[name,'_aligned.png'])); 37 | fprintf('write aligned image to %s\n',fullfile(savedir,[name,'_aligned.png'])); 38 | % save trans_facial5point 39 | write_5pt(fullfile(savedir, [name, '_aligned.txt']), trans_facial5point); 40 | fprintf('write transformed facial landmark to %s\n',fullfile(savedir,[name,'_aligned.txt'])); 41 | 42 | %% show results 43 | imshow(trans_img); hold on; 44 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'b'); 45 | plot(trans_facial5point(:,1),trans_facial5point(:,2),'r+'); 46 | 47 | end 48 | 49 | function [] = write_5pt(fn, trans_pt) 50 | fid = fopen(fn, 'w'); 51 | for i = 1:5 52 | fprintf(fid, '%d %d\n', trans_pt(i,1), trans_pt(i,2));%will be read as np.int32 53 | end 54 | fclose(fid); 55 | end -------------------------------------------------------------------------------- /preprocess/get_partmask.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os, glob, csv, shutil 3 | import numpy as np 4 | import dlib 5 | import math 6 | from shapely.geometry import Point 7 | from shapely.geometry import Polygon 8 | import sys 9 | 10 | detector = dlib.get_frontal_face_detector() 11 | predictor = dlib.shape_predictor('../checkpoints/shape_predictor_68_face_landmarks.dat') 12 | 13 | def getfeats(featpath): 14 | trans_points = np.empty([68,2],dtype=np.int64) 15 | with open(featpath, 'r') as csvfile: 16 | reader = csv.reader(csvfile, delimiter=' ') 17 | for ind,row in enumerate(reader): 18 | trans_points[ind,:] = row 19 | return trans_points 20 | 21 | def getinternal(lm1,lm2): 22 | lminternal = [] 23 | if abs(lm1[1]-lm2[1]) > abs(lm1[0]-lm2[0]): 24 | if lm1[1] > lm2[1]: 25 | tmp = lm1 26 | lm1 = lm2 27 | lm2 = tmp 28 | for y in range(lm1[1]+1,lm2[1]): 29 | x = int(round(float(y-lm1[1])/(lm2[1]-lm1[1])*(lm2[0]-lm1[0])+lm1[0])) 30 | lminternal.append((x,y)) 31 | else: 32 | if lm1[0] > lm2[0]: 33 | tmp = lm1 34 | lm1 = lm2 35 | lm2 = tmp 36 | for x in range(lm1[0]+1,lm2[0]): 37 | y = int(round(float(x-lm1[0])/(lm2[0]-lm1[0])*(lm2[1]-lm1[1])+lm1[1])) 38 | lminternal.append((x,y)) 39 | return lminternal 40 | 41 | def mulcross(p,x_1,x):#p-x_1,x-x_1 42 | vp = [p[0]-x_1[0],p[1]-x_1[1]] 43 | vq = [x[0]-x_1[0],x[1]-x_1[1]] 44 | return vp[0]*vq[1]-vp[1]*vq[0] 45 | 46 | def shape_to_np(shape, dtype="int"): 47 | # initialize the list of (x, y)-coordinates 48 | coords = np.zeros((shape.num_parts, 2), dtype=dtype) 49 | # loop over all facial landmarks and convert them 50 | # to a 2-tuple of (x, y)-coordinates 51 | for i in range(0, shape.num_parts): 52 | coords[i] = (shape.part(i).x, shape.part(i).y) 53 | # return the list of (x, y)-coordinates 54 | return coords 55 | 56 | def get_68lm(imgfile,savepath): 57 | image = cv2.imread(imgfile) 58 | rgbImg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 59 | rects = detector(rgbImg, 1) 60 | for (i, rect) in enumerate(rects): 61 | landmarks = predictor(rgbImg, rect) 62 | landmarks = shape_to_np(landmarks) 63 | f = open(savepath,'w') 64 | for i in range(len(landmarks)): 65 | lm = landmarks[i] 66 | print(lm[0], lm[1], file=f) 67 | f.close() 68 | 69 | def get_partmask(imgfile,part,lmpath,savefile): 70 | img = cv2.imread(imgfile) 71 | mask = np.zeros(img.shape, np.uint8) 72 | lms = getfeats(lmpath) 73 | 74 | if 
os.path.exists(savefile): 75 | return 76 | 77 | if part == 'nose': 78 | # 27,31....,35 -> up, left, right, lower5 -- eight points 79 | up = [int(round(1.2*lms[27][0]-0.2*lms[33][0])),int(round(1.2*lms[27][1]-0.2*lms[33][1]))] 80 | lower5 = [[0,0]]*5 81 | for i in range(31,36): 82 | lower5[i-31] = [int(round(1.1*lms[i][0]-0.1*lms[27][0])),int(round(1.1*lms[i][1]-0.1*lms[27][1]))] 83 | ratio = 2.5 84 | left = [int(round(ratio*lower5[0][0]-(ratio-1)*lower5[1][0])),int(round(ratio*lower5[0][1]-(ratio-1)*lower5[1][1]))] 85 | right = [int(round(ratio*lower5[4][0]-(ratio-1)*lower5[3][0])),int(round(ratio*lower5[4][1]-(ratio-1)*lower5[3][1]))] 86 | loop = [up,left,lower5[0],lower5[1],lower5[2],lower5[3],lower5[4],right] 87 | elif part == 'eyel': 88 | height = max(lms[41][1]-lms[37][1],lms[40][1]-lms[38][1]) 89 | width = lms[39][0]-lms[36][0] 90 | ratio = 0.1 91 | gap = int(math.ceil(width*ratio)) 92 | ratio2 = 0.6 93 | gaph = int(math.ceil(height*ratio2)) 94 | ratio3 = 1.5 95 | gaph2 = int(math.ceil(height*ratio3)) 96 | upper = [[lms[17][0]-2*gap,lms[17][1]],[lms[17][0]-2*gap,lms[17][1]-gaph],[lms[18][0],lms[18][1]-gaph],[lms[19][0],lms[19][1]-gaph],[lms[20][0],lms[20][1]-gaph],[lms[21][0]+gap*2,lms[21][1]-gaph]] 97 | lower = [[lms[39][0]+gap,lms[40][1]+gaph2],[lms[40][0],lms[40][1]+gaph2],[lms[41][0],lms[41][1]+gaph2],[lms[36][0]-2*gap,lms[41][1]+gaph2]] 98 | loop = upper + lower 99 | loop.reverse() 100 | elif part == 'eyer': 101 | height = max(lms[47][1]-lms[43][1],lms[46][1]-lms[44][1]) 102 | width = lms[45][0]-lms[42][0] 103 | ratio = 0.1 104 | gap = int(math.ceil(width*ratio)) 105 | ratio2 = 0.6 106 | gaph = int(math.ceil(height*ratio2)) 107 | ratio3 = 1.5 108 | gaph2 = int(math.ceil(height*ratio3)) 109 | upper = [[lms[22][0]-2*gap,lms[22][1]],[lms[22][0]-2*gap,lms[22][1]-gaph],[lms[23][0],lms[23][1]-gaph],[lms[24][0],lms[24][1]-gaph],[lms[25][0],lms[25][1]-gaph],[lms[26][0]+gap*2,lms[26][1]-gaph]] 110 | lower = [[lms[45][0]+2*gap,lms[46][1]+gaph2],[lms[46][0],lms[46][1]+gaph2],[lms[47][0],lms[47][1]+gaph2],[lms[42][0]-gap,lms[42][1]+gaph2]] 111 | loop = upper + lower 112 | loop.reverse() 113 | elif part == 'mouth': 114 | height = lms[62][1]-lms[51][1] 115 | width = lms[54][0]-lms[48][0] 116 | ratio = 1 117 | ratio2 = 0.2#0.1 118 | gaph = int(math.ceil(ratio*height)) 119 | gapw = int(math.ceil(ratio2*width)) 120 | left = [(lms[48][0]-gapw,lms[48][1])] 121 | upper = [(lms[i][0], lms[i][1]-gaph) for i in range(48,55)] 122 | right = [(lms[54][0]+gapw,lms[54][1])] 123 | lower = [(lms[i][0], lms[i][1]+gaph) for i in list(range(54,60))+[48]] 124 | loop = left + upper + right + lower 125 | loop.reverse() 126 | pl = Polygon(loop) 127 | 128 | for i in range(mask.shape[0]): 129 | for j in range(mask.shape[1]): 130 | if part != 'mouth' and part != 'jaw': 131 | p = [j,i] 132 | flag = 1 133 | for k in range(len(loop)): 134 | if mulcross(p,loop[k],loop[(k+1)%len(loop)]) < 0:#y downside... 
>0 represents counter-clockwise, <0 clockwise
135 | flag = 0
136 | break
137 | else:
138 | p = Point(j,i)
139 | flag = pl.contains(p)
140 | if flag:
141 | mask[i,j] = [255,255,255]
142 | if not os.path.exists(os.path.dirname(savefile)):
143 | os.mkdir(os.path.dirname(savefile))
144 | cv2.imwrite(savefile,mask)
145 | 
146 | if __name__ == '__main__':
147 | imgfile = 'example/img_1701_aligned.png'
148 | lmfile = 'example/img_1701_aligned_68lm.txt'
149 | get_68lm(imgfile,lmfile)
150 | for part in ['eyel','eyer','nose','mouth']:
151 | savepath = 'example/img_1701_aligned_'+part+'mask.png'
152 | get_partmask(imgfile,part,lmfile,savepath)
153 | 
-------------------------------------------------------------------------------- /preprocess/readme.md: --------------------------------------------------------------------------------
1 | ## Preprocessing steps
2 | 
3 | Both training and testing images need:
4 | 
5 | - alignment to 512x512
6 | - facial landmarks
7 | - masks for eyes, nose, mouth, background
8 | 
9 | Training images additionally need:
10 | 
11 | - a mask for the face region
12 | 
13 | 
14 | ### 1. Align, resize, crop images to 512x512, and get facial landmarks
15 | 
16 | All training and testing images in our model are aligned using facial landmarks, and the landmarks after alignment are needed by our code.
17 | 
18 | - First, 5 facial landmarks for a face photo need to be detected (we detect them using [MTCNN](https://github.com/kpzhang93/MTCNN_face_detection_alignment) (MTCNNv1)).
19 | 
20 | - Then, we provide a MATLAB function in `face_align_512.m` to align, resize and crop face photos (and corresponding drawings) to 512x512.
21 | For example, for `img_1701.jpg` in the `example` dir, the 5 detected facial landmarks are saved in `example/img_1701_facial5point.mat`. Call the following in MATLAB:
22 | ```bash
23 | load('example/img_1701_facial5point.mat');
24 | [trans_img,trans_facial5point]=face_align_512('example/img_1701.jpg',facial5point,'example');
25 | ```
26 | 
27 | This will align the image and output the aligned image and the transformed facial landmarks (in txt format) to the `example` folder.
28 | See `face_align_512.m` for more instructions.
29 | 
30 | The saved transformed facial landmarks need to be copied to `dataset/landmark/` and must have the **same filename** as the aligned face photo (e.g. `dataset/data/test_single/31.png` should have landmark file `dataset/landmark/31.txt`).
31 | 
32 | ### 2. Prepare background masks
33 | 
34 | In our work, the background mask is segmented using the method in
35 | "Automatic Portrait Segmentation for Image Stylization",
36 | Xiaoyong Shen, Aaron Hertzmann, Jiaya Jia, Sylvain Paris, Brian Price, Eli Shechtman, Ian Sachs. Computer Graphics Forum, 35(2) (Proc. Eurographics), 2016.
37 | 
38 | We use the code at http://xiaoyongshen.me/webpage_portrait/index.html to detect background masks for aligned face photos.
39 | An example background mask is shown in `example/img_1701_aligned_bgmask.png`.
40 | 
41 | The background masks need to be copied to `dataset/mask/bg/` and must have the **same filename** as the aligned face photo (e.g. `dataset/data/test_single/31.png` should have background mask `dataset/mask/bg/31.png`).
42 | 
43 | ### 3. Prepare eyes/nose/mouth masks
44 | 
45 | We use dlib to extract 68 landmarks for aligned face photos, and use these landmarks to get masks for the local regions.
46 | See the example in `get_partmask.py`: the eyes, nose and mouth masks for `example/img_1701_aligned.png` are `example/img_1701_aligned_[part]mask.png`, where part is one of [eyel,eyer,nose,mouth].
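For your own aligned photos, the two helper functions in `get_partmask.py` can be called the same way as in its `__main__` block. Below is a minimal sketch (the photo and output names are placeholders; it assumes you run it from the `preprocess/` directory so that dlib's `shape_predictor_68_face_landmarks.dat` is found at `../checkpoints/`, as hard-coded in `get_partmask.py`):

```python
# Minimal sketch: 68-landmark detection + part masks for one aligned 512x512 photo.
# Assumes: run from preprocess/, with ../checkpoints/shape_predictor_68_face_landmarks.dat present.
from get_partmask import get_68lm, get_partmask

imgfile = 'example/my_photo_aligned.png'       # placeholder: your aligned face photo
lmfile = 'example/my_photo_aligned_68lm.txt'   # the 68 landmarks will be written here

get_68lm(imgfile, lmfile)                      # detect 68 landmarks with dlib and save them as txt
for part in ['eyel', 'eyer', 'nose', 'mouth']:
    savefile = 'example/my_photo_aligned_%smask.png' % part
    get_partmask(imgfile, part, lmfile, savefile)  # write a white-on-black mask for this region
```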
47 | 
48 | The part masks need to be copied to `dataset/mask/[part]/` and must have the **same filename** as the aligned face photos.
49 | 
50 | ### 4. (For training) Prepare face masks
51 | 
52 | We use the face parsing net in https://github.com/cientgu/Mask_Guided_Portrait_Editing to detect the face region.
53 | The face parsing net labels each face image into 11 classes: 0 is background, 10 is hair, and 1~9 are face regions.
54 | An example face mask is shown in `example/img_1701_aligned_facemask.png`.
55 | 
56 | The face masks need to be copied to `dataset/mask/face/` and must have the **same filename** as the aligned face photos.
57 | 
58 | ### 5. (For training) Combine A and B
59 | 
60 | We provide a Python script to generate training data in the form of image pairs {A,B}, i.e. pairs {face photo, drawing}. The script concatenates each pair of images horizontally into a single image, so that we can learn to translate A to B:
61 | 
62 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `test`, etc. In `/path/to/data/A/train`, put training face photos. In `/path/to/data/B/train`, put the corresponding artist drawings. Repeat the same for `test`.
63 | 
64 | Corresponding images in a pair {A,B} must both be aligned images of size 512x512 and have the same filename, e.g., `/path/to/data/A/train/1.png` is considered to correspond to `/path/to/data/B/train/1.png`.
65 | 
66 | Once the data is formatted this way, call:
67 | ```bash
68 | python preprocess/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
69 | ```
70 | 
71 | This will combine each pair of images (A,B) into a single image file, ready for training.
-------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
1 | 
2 | # APDrawingGAN++
3 | 
4 | We provide PyTorch implementations for our TPAMI paper "Line Drawings for Face Portraits from Photos using Global and Local Structure based GANs".
5 | It is a journal extension of our previous CVPR 2019 work [APDrawingGAN](https://github.com/yiranran/APDrawingGAN).
6 | 
7 | This project generates artistic portrait drawings from face photos using a GAN-based model.
8 | You may find useful information in [preprocessing steps](preprocess/readme.md) and [training/testing tips](docs/tips.md).
9 | 
10 | [[Jittor implementation]](https://github.com/yiranran/APDrawingGAN2-Jittor)
11 | 
12 | ## Our Proposed Framework
13 | 
14 | 
15 | 
16 | ## Sample Results
17 | Up: input, Down: output
18 | 

34 | 
35 | ## Citation
36 | If you use this code for your research, please cite our paper.
37 | ```
38 | @article{YiXLLR20,
39 | title = {Line Drawings for Face Portraits from Photos using Global and Local Structure based {GAN}s},
40 | author = {Yi, Ran and Xia, Mengfei and Liu, Yong-Jin and Lai, Yu-Kun and Rosin, Paul L},
41 | journal = {{IEEE} Transactions on Pattern Analysis and Machine Intelligence (TPAMI)},
42 | doi = {10.1109/TPAMI.2020.2987931},
43 | year = {2020}
44 | }
45 | ```
46 | 
47 | ## Prerequisites
48 | - Linux or macOS
49 | - Python 2 or 3
50 | - CPU or NVIDIA GPU + CUDA CuDNN
51 | 
52 | 
53 | ## Getting Started
54 | ### 1. Installation
55 | ```bash
56 | pip install -r requirements.txt
57 | ```
58 | 
59 | ### 2. Quick Start (Apply a Pre-trained Model)
60 | - Download the APDrawing dataset from [BaiduYun](https://pan.baidu.com/s/1cN5gEYJ2tnE9WboLA79Z5g) (extract code: 0zuv) or [YandexDrive](https://yadi.sk/d/4vWhi8-ZQj_nRw), and extract it to `dataset`.
61 | 
62 | - Download pre-trained models and auxiliary nets from [BaiduYun](https://pan.baidu.com/s/1nrtCHQmgcwbSGxWuAVzWhA) (extract code: imqp) or [YandexDrive](https://yadi.sk/d/DS4271lbEPhGVQ), and extract them to `checkpoints`.
63 | 
64 | - Generate artistic portrait drawings for the example photos in `dataset/test_single` using:
65 | ```bash
66 | python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single
67 | ```
68 | The test results will be saved to an HTML file here: `./results/apdrawinggan++_author/test_150/index-single.html`.
69 | 
70 | - If you want to test on your own data, please first align your pictures and prepare your data's facial landmarks and masks according to the tutorial in [preprocessing steps](preprocess/readme.md), then change the --dataroot flag above to your directory of aligned photos.
71 | 
72 | ### 3. Train
73 | - Run `python -m visdom.server`
74 | - Train a model (with pre-training as initialization):
75 | first copy the "pre2" models into the checkpoints dir of the current experiment, e.g. `checkpoints/apdrawinggan++_1`.
76 | ```bash
77 | mkdir checkpoints/apdrawinggan++_1/
78 | cp checkpoints/pre2/*.pt checkpoints/apdrawinggan++_1/
79 | python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o
80 | ```
81 | - To view training results and loss plots, click the URL http://localhost:8097. To see more intermediate results, check out `./checkpoints/apdrawinggan++_1/web/index.html`.
82 | 
83 | ### 4. Test
84 | - To test the model on the test set:
85 | ```bash
86 | python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70
87 | ```
88 | The test results will be saved to an HTML file: `./results/apdrawinggan++_author/test_150/index-apd70.html`.
89 | 
90 | - To test the model on images without paired ground truth, follow the same procedure as in 2. Quick Start (Apply a Pre-trained Model).
91 | 
92 | You can find these commands as scripts in the `script` directory.
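Before running `test.py` on your own photos, it can help to check that every aligned photo has the auxiliary files the default options expect. The sketch below is only illustrative: the folder layout follows the defaults in `options/base_options.py` (`--lm_dir dataset/landmark/`, `--bg_dir dataset/mask/bg/`) and the `dataset/mask/[part]/` convention from the preprocessing readme, and `dataroot` is a placeholder for your folder of aligned photos.

```python
import os

# Illustrative check: each aligned photo should have a landmark txt and bg/part masks
# in the locations assumed by the default options (the paths below are those defaults).
dataroot = 'dataset/test_single'     # placeholder: folder of aligned 512x512 photos
lm_dir = 'dataset/landmark'          # default --lm_dir
mask_root = 'dataset/mask'           # bg and part masks live in subfolders here
parts = ['bg', 'eyel', 'eyer', 'nose', 'mouth']

for fname in sorted(os.listdir(dataroot)):
    if not fname.lower().endswith('.png'):
        continue
    base = os.path.splitext(fname)[0]
    missing = []
    if not os.path.isfile(os.path.join(lm_dir, base + '.txt')):
        missing.append('landmark')
    missing += [p + ' mask' for p in parts
                if not os.path.isfile(os.path.join(mask_root, p, fname))]
    print(fname, 'OK' if not missing else 'missing: ' + ', '.join(missing))
```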
93 | 94 | 95 | ## [Preprocessing Steps](preprocess/readme.md) 96 | Preprocessing steps for your own data (either for testing or training). 97 | 98 | 99 | ## [Training/Test Tips](docs/tips.md) 100 | Best practice for training and testing your models. 101 | 102 | You can contact email yr16@mails.tsinghua.edu.cn for any questions. 103 | 104 | ## Acknowledgments 105 | Our code is inspired by [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.1.0 2 | torchvision==0.4.0 3 | dominate==2.4.0 4 | visdom==0.1.8.9 5 | scipy==1.1.0 6 | numpy==1.16.4 7 | Pillow==4.3.0 8 | opencv-python==4.1.0.25 9 | dlib==19.18.0 10 | shapely==1.7.0 -------------------------------------------------------------------------------- /script/test.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_author --model apdrawingpp_style --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-apd70 3 | -------------------------------------------------------------------------------- /script/test_single.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot dataset/test_single --name apdrawinggan++_author --model test --use_resnet --netG resnet_9blocks --which_epoch 150 --how_many 1000 --gpu_ids 0 --gpu_ids_p 0 --imagefolder images-single -------------------------------------------------------------------------------- /script/train.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot dataset/AB_140_aug3_H_hm2 --name apdrawinggan++_1 --model apdrawingpp_style --use_resnet --netG resnet_9blocks --continue_train --continuity_loss --lambda_continuity 40.0 --gpu_ids 0 --gpu_ids_p 1 --display_env apdrawinggan++_1 --niter 200 --niter_decay 0 --lr 0.0001 --batch_size 1 --emphasis_conti_face --auxiliary_root auxiliaryeye2o 3 | 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | 8 | 9 | if __name__ == '__main__': 10 | opt = TestOptions().parse() 11 | opt.num_threads = 1 # test code only supports num_threads = 1 12 | opt.batch_size = 1 # test code only supports batch_size = 1 13 | opt.serial_batches = True # no shuffle 14 | opt.no_flip = True # no flip 15 | opt.display_id = -1 # no visdom display 16 | data_loader = CreateDataLoader(opt) 17 | dataset = data_loader.load_data() 18 | model = create_model(opt) 19 | model.setup(opt) 20 | # create website 21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 22 | #webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch),reflesh=0, folder=opt.imagefolder) 24 | if opt.test_continuity_loss: 25 | file_name = 
os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity.txt') 26 | file_name1 = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch), 'continuity-r.txt') 27 | if os.path.exists(file_name): 28 | os.remove(file_name) 29 | if os.path.exists(file_name1): 30 | os.remove(file_name1) 31 | # test 32 | #model.eval() 33 | for i, data in enumerate(dataset): 34 | if i >= opt.how_many:#test code only supports batch_size = 1, how_many means how many test images to run 35 | break 36 | model.set_input(data) 37 | model.test() 38 | visuals = model.get_current_visuals()#in test the loadSize is set to the same as fineSize 39 | img_path = model.get_image_paths() 40 | #if i % 5 == 0: 41 | # print('processing (%04d)-th image... %s' % (i, img_path)) 42 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 43 | 44 | webpage.save() 45 | if opt.model == 'regressor': 46 | print(model.cnt) 47 | print(model.value/model.cnt) 48 | print(model.minval) 49 | print(model.avg/model.cnt) 50 | print(model.max) 51 | html = os.path.join(web_dir,'cindex'+opt.imagefolder[6:]+'.html') 52 | f=open(html,'w') 53 | print('',file=f,end='') 54 | print('',file=f,end='') 55 | print('',file=f,end='') 56 | print('',file=f,end='') 57 | print('',file=f,end='') 58 | print('',file=f,end='') 59 | print('',file=f,end='') 60 | for info in model.info: 61 | basen = os.path.basename(info[0])[:-4] 62 | print('',file=f,end='') 63 | print(''%basen,file=f,end='') 64 | print(''%(opt.imagefolder,basen),file=f,end='') 65 | print(''%info[1],file=f,end='') 66 | print(''%info[2],file=f,end='') 67 | print('',file=f,end='') 68 | print('
image namerealArealBfakeB
%s%.4f%.4f
',file=f,end='') 69 | f.close() 70 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | if __name__ == '__main__': 8 | start = time.time() 9 | opt = TrainOptions().parse() 10 | data_loader = CreateDataLoader(opt) 11 | dataset = data_loader.load_data() 12 | dataset_size = len(data_loader) 13 | print('#training images = %d' % dataset_size) 14 | 15 | model = create_model(opt) 16 | model.setup(opt) 17 | visualizer = Visualizer(opt) 18 | total_steps = 0 19 | model.save_networks2(opt.which_epoch) 20 | 21 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 22 | epoch_start_time = time.time() 23 | iter_data_time = time.time() 24 | epoch_iter = 0 25 | 26 | for i, data in enumerate(dataset): 27 | iter_start_time = time.time() 28 | if total_steps % opt.print_freq == 0: 29 | t_data = iter_start_time - iter_data_time 30 | visualizer.reset() 31 | total_steps += opt.batch_size 32 | epoch_iter += opt.batch_size 33 | model.set_input(data) 34 | model.optimize_parameters() 35 | 36 | if total_steps % opt.display_freq == 0: 37 | save_result = total_steps % opt.update_html_freq == 0 38 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 39 | #print('display',total_steps) 40 | 41 | if total_steps % opt.print_freq == 0:#print freq 100 42 | losses = model.get_current_losses() 43 | t = (time.time() - iter_start_time) / opt.batch_size 44 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 45 | if opt.display_id > 0: 46 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 47 | 48 | if total_steps % opt.save_latest_freq == 0: 49 | print('saving the latest model (epoch %d, total_steps %d)' % 50 | (epoch, total_steps)) 51 | #model.save_networks('latest') 52 | model.save_networks2('latest') 53 | 54 | iter_data_time = time.time() 55 | if epoch % opt.save_epoch_freq == 0: 56 | print('saving the model at the end of epoch %d, iters %d' % 57 | (epoch, total_steps)) 58 | #model.save_networks('latest') 59 | #model.save_networks(epoch) 60 | model.save_networks2('latest') 61 | model.save_networks2(epoch) 62 | 63 | print('End of epoch %d / %d \t Time Taken: %d sec' % 64 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 65 | model.update_learning_rate() 66 | 67 | print('Total Time Taken: %d sec' % (time.time() - start)) 68 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yiranran/APDrawingGAN2/7d04e32dc59a5675843f7885c13e830bf5ff56f6/util/__init__.py -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 
19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. 
Voiding Download.".format( 109 | save_path_full)) 110 | else: 111 | self._print('Downloading Data...') 112 | url = "{0}/{1}".format(self.url, selected_dataset) 113 | self._download_data(url, save_path=save_path) 114 | 115 | return abspath(save_path_full) 116 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0, folder='images'): 8 | self.title = title 9 | self.web_dir = web_dir 10 | #self.img_dir = os.path.join(self.web_dir, 'images') 11 | self.img_dir = os.path.join(self.web_dir, folder) 12 | self.folder = folder 13 | if not os.path.exists(self.web_dir): 14 | os.makedirs(self.web_dir) 15 | if not os.path.exists(self.img_dir): 16 | os.makedirs(self.img_dir) 17 | # print(self.img_dir) 18 | 19 | self.doc = dominate.document(title=title) 20 | if reflesh > 0: 21 | with self.doc.head: 22 | meta(http_equiv="reflesh", content=str(reflesh)) 23 | 24 | def get_image_dir(self): 25 | return self.img_dir 26 | 27 | def add_header(self, str): 28 | with self.doc: 29 | h3(str) 30 | 31 | def add_table(self, border=1): 32 | self.t = table(border=border, style="table-layout: fixed;") 33 | self.doc.add(self.t) 34 | 35 | def add_images(self, ims, txts, links, width=400): 36 | self.add_table() 37 | with self.t: 38 | with tr(): 39 | for im, txt, link in zip(ims, txts, links): 40 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 41 | with p(): 42 | with a(href=os.path.join('images', link)): 43 | #img(style="width:%dpx" % width, src=os.path.join('images', im)) 44 | img(style="width:%dpx" % width, src=os.path.join(self.folder, im)) 45 | br() 46 | p(txt) 47 | 48 | def save(self): 49 | #html_file = '%s/index.html' % self.web_dir 50 | html_file = '%s/index%s.html' % (self.web_dir, self.folder[6:]) 51 | f = open(html_file, 'wt') 52 | f.write(self.doc.render()) 53 | f.close() 54 | 55 | 56 | if __name__ == '__main__': 57 | html = HTML('web/', 'test_html') 58 | html.add_header('hello world') 59 | 60 | ims = [] 61 | txts = [] 62 | links = [] 63 | for n in range(4): 64 | ims.append('image_%d.png' % n) 65 | txts.append('text_%d' % n) 66 | links.append('image_%d.png' % n) 67 | html.add_images(ims, txts, links) 68 | html.save() 69 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | def __init__(self, pool_size): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images: 17 | image = torch.unsqueeze(image.data, 0) 18 | if self.num_imgs < self.pool_size: 19 | self.num_imgs = self.num_imgs + 1 20 | self.images.append(image) 21 | return_images.append(image) 22 | else: 23 | p = random.uniform(0, 1) 24 | if p > 0.5: 25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 26 | tmp = self.images[random_id].clone() 27 | self.images[random_id] = image 28 | return_images.append(tmp) 29 | else: 30 | return_images.append(image) 31 | return_images = torch.cat(return_images, 0) 32 | return return_images 33 | 
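The `ImagePool` above is the standard image-history buffer from pix2pix/CycleGAN: while the pool is filling, `query` returns the queried images unchanged, and once it holds `pool_size` images each new query is, with probability 0.5, swapped for a randomly stored older fake. A minimal usage sketch (the tensor shapes and the commented discriminator call are illustrative placeholders, not code from this repository):

```python
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)          # 50 matches the --pool_size default in train_options.py

# Inside a discriminator update step (shapes are placeholders):
fake_B = torch.randn(1, 1, 512, 512)    # stand-in for a generator output
fake_for_D = pool.query(fake_B)         # either fake_B itself or an older sample from the pool
# loss_D_fake = criterionGAN(netD(fake_for_D.detach()), False)   # placeholder update
```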
-------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | 22 | def diagnose_network(net, name='network'): 23 | mean = 0.0 24 | count = 0 25 | for param in net.parameters(): 26 | if param.grad is not None: 27 | mean += torch.mean(torch.abs(param.grad.data)) 28 | count += 1 29 | if count > 0: 30 | mean = mean / count 31 | print(name) 32 | print(mean) 33 | 34 | 35 | def save_image(image_numpy, image_path): 36 | image_pil = Image.fromarray(image_numpy) 37 | image_pil.save(image_path) 38 | 39 | 40 | def print_numpy(x, val=True, shp=False): 41 | x = x.astype(np.float64) 42 | if shp: 43 | print('shape,', x.shape) 44 | if val: 45 | x = x.flatten() 46 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 47 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 48 | 49 | 50 | def mkdirs(paths): 51 | if isinstance(paths, list) and not isinstance(paths, str): 52 | for path in paths: 53 | mkdir(path) 54 | else: 55 | mkdir(paths) 56 | 57 | 58 | def mkdir(path): 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . 
import html 7 | from scipy.misc import imresize 8 | 9 | 10 | # save image to the disk 11 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 12 | image_dir = webpage.get_image_dir() 13 | short_path = ntpath.basename(image_path[0]) 14 | name = os.path.splitext(short_path)[0] 15 | 16 | webpage.add_header(name) 17 | ims, txts, links = [], [], [] 18 | 19 | for label, im_data in visuals.items(): 20 | im = util.tensor2im(im_data)#tensor to numpy array [-1,1]->[0,1]->[0,255] 21 | image_name = '%s_%s.png' % (name, label) 22 | save_path = os.path.join(image_dir, image_name) 23 | h, w, _ = im.shape 24 | if aspect_ratio > 1.0: 25 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 26 | if aspect_ratio < 1.0: 27 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 28 | util.save_image(im, save_path) 29 | 30 | ims.append(image_name) 31 | txts.append(label) 32 | links.append(image_name) 33 | webpage.add_images(ims, txts, links, width=width) 34 | 35 | 36 | class Visualizer(): 37 | def __init__(self, opt): 38 | self.display_id = opt.display_id 39 | self.use_html = opt.isTrain and not opt.no_html 40 | self.win_size = opt.display_winsize 41 | self.name = opt.name 42 | self.opt = opt 43 | self.saved = False 44 | if self.display_id > 0: 45 | import visdom 46 | self.ncols = opt.display_ncols 47 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 48 | 49 | if self.use_html: 50 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 51 | self.img_dir = os.path.join(self.web_dir, 'images') 52 | print('create web directory %s...' % self.web_dir) 53 | util.mkdirs([self.web_dir, self.img_dir]) 54 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 55 | with open(self.log_name, "a") as log_file: 56 | now = time.strftime("%c") 57 | log_file.write('================ Training Loss (%s) ================\n' % now) 58 | 59 | def reset(self): 60 | self.saved = False 61 | 62 | def throw_visdom_connection_error(self): 63 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n') 64 | exit(1) 65 | 66 | # |visuals|: dictionary of images to display or save 67 | def display_current_results(self, visuals, epoch, save_result): 68 | if self.display_id > 0: # show images in the browser 69 | ncols = self.ncols 70 | if ncols > 0: 71 | ncols = min(ncols, len(visuals)) 72 | h, w = next(iter(visuals.values())).shape[:2] 73 | table_css = """""" % (w, h) 77 | title = self.name 78 | label_html = '' 79 | label_html_row = '' 80 | images = [] 81 | idx = 0 82 | for label, image in visuals.items(): 83 | image_numpy = util.tensor2im(image) 84 | label_html_row += '%s' % label 85 | images.append(image_numpy.transpose([2, 0, 1])) 86 | idx += 1 87 | if idx % ncols == 0: 88 | label_html += '%s' % label_html_row 89 | label_html_row = '' 90 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 91 | while idx % ncols != 0: 92 | images.append(white_image) 93 | label_html_row += '' 94 | idx += 1 95 | if label_html_row != '': 96 | label_html += '%s' % label_html_row 97 | # pane col = image row 98 | try: 99 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 100 | padding=2, opts=dict(title=title + ' images')) 101 | label_html = '%s
' % label_html 102 | self.vis.text(table_css + label_html, win=self.display_id + 2, 103 | opts=dict(title=title + ' labels')) 104 | except ConnectionError: 105 | self.throw_visdom_connection_error() 106 | 107 | else: 108 | idx = 1 109 | for label, image in visuals.items(): 110 | image_numpy = util.tensor2im(image) 111 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 112 | win=self.display_id + idx) 113 | idx += 1 114 | 115 | if self.use_html and (save_result or not self.saved): # save images to a html file 116 | self.saved = True 117 | for label, image in visuals.items(): 118 | image_numpy = util.tensor2im(image) 119 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 120 | util.save_image(image_numpy, img_path) 121 | # update website 122 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1) 123 | for n in range(epoch, 0, -1): 124 | webpage.add_header('epoch [%d]' % n) 125 | ims, txts, links = [], [], [] 126 | 127 | for label, image_numpy in visuals.items(): 128 | image_numpy = util.tensor2im(image) 129 | img_path = 'epoch%.3d_%s.png' % (n, label) 130 | ims.append(img_path) 131 | txts.append(label) 132 | links.append(img_path) 133 | webpage.add_images(ims, txts, links, width=self.win_size) 134 | webpage.save() 135 | 136 | def save_current_results1(self, visuals, epoch, epoch_iter): 137 | if not os.path.exists(self.img_dir+'/detailed'): 138 | os.mkdir(self.img_dir+'/detailed') 139 | for label, image in visuals.items(): 140 | image_numpy = util.tensor2im(image) 141 | img_path = os.path.join(self.img_dir, 'detailed', 'epoch%.3d_%.3d_%s.png' % (epoch, epoch_iter, label)) 142 | util.save_image(image_numpy, img_path) 143 | 144 | # losses: dictionary of error labels and values 145 | def plot_current_losses(self, epoch, counter_ratio, opt, losses): 146 | if not hasattr(self, 'plot_data'): 147 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 148 | self.plot_data['X'].append(epoch + counter_ratio) 149 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 150 | try: 151 | self.vis.line( 152 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 153 | Y=np.array(self.plot_data['Y']), 154 | opts={ 155 | 'title': self.name + ' loss over time', 156 | 'legend': self.plot_data['legend'], 157 | 'xlabel': 'epoch', 158 | 'ylabel': 'loss'}, 159 | win=self.display_id) 160 | except ConnectionError: 161 | self.throw_visdom_connection_error() 162 | 163 | # losses: same format as |losses| of plot_current_losses 164 | def print_current_losses(self, epoch, i, losses, t, t_data): 165 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data) 166 | for k, v in losses.items(): 167 | message += '%s: %.6f ' % (k, v) 168 | 169 | print(message) 170 | with open(self.log_name, "a") as log_file: 171 | log_file.write('%s\n' % message) 172 | --------------------------------------------------------------------------------
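As a small closing example of the utilities above, `util.tensor2im` converts a network output in [-1,1] to an HxWx3 uint8 array (single-channel outputs are tiled to three channels) and `util.save_image` writes it with PIL; this is the same path `save_images` and the `Visualizer` use. A minimal sketch with a random placeholder tensor standing in for a generator output:

```python
import torch
from util import util

fake_B = torch.rand(1, 1, 512, 512) * 2 - 1   # placeholder for a generator output in [-1, 1]
im = util.tensor2im(fake_B)                   # 512x512x3 uint8 numpy array
util.save_image(im, 'fake_B.png')             # written with PIL
```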