├── loader ├── __init__.py ├── label_loader.py ├── image_label_loader.py ├── onehot_label_loader.py └── image_loader.py ├── models ├── __init__.py ├── base_model.py ├── flgan.py └── networks.py ├── util ├── __init__.py ├── makedirs.py ├── log.py └── confusion_matrix.py ├── .gitignore ├── README.md ├── .idea ├── vcs.xml ├── misc.xml ├── modules.xml ├── FL-GAN.iml └── workspace.xml ├── test ├── test_g1.py ├── test_g2.py └── test.py └── train.py /loader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints 2 | datasets 3 | *.pth 4 | *.log 5 | *.pyc 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cross-domain-Human-Parsing-via-Adversarial-Feature-and-Label-Adaptation 2 | The project is not completed yet, this is not the final version. 
class Logger:
    """Small wrapper around a named ``logging`` logger with file + console output.

    Every message is written both to ``log_file`` and to the default stream
    of ``logging.StreamHandler``, using the same format string.

    Args:
        log_file: Path of the log file. Missing parent directories are created.
        formatter: ``logging.Formatter`` format string applied to both handlers.
        user: Name passed to ``logging.getLogger``; instances created with the
            same name share one underlying logger.
    """

    def __init__(self, log_file='log/log.txt', formatter='%(asctime)s\t%(message)s', user='rgh'):
        self.user = user
        self.log_file = log_file
        self.formatter = formatter
        self.logger = self.init_logger()

    def init_logger(self):
        """Return the configured logger for ``self.user``.

        Handlers are attached only when an equivalent one is not already
        present. ``logging.getLogger`` returns the same object for the same
        name, so unconditionally adding handlers made every message appear
        once per ``Logger`` instance. An already attached handler keeps the
        formatter it was created with.
        """
        import os  # local import: the module itself only imports ``logging``

        logger = logging.getLogger(self.user)
        logger.setLevel(logging.DEBUG)

        # FileHandler stores the absolute path in ``baseFilename``; compare like with like.
        log_path = os.path.abspath(self.log_file)
        log_dir = os.path.dirname(log_path)
        if not os.path.isdir(log_dir):
            try:
                os.makedirs(log_dir)
            except OSError:
                # Another process may have created it in the meantime.
                if not os.path.isdir(log_dir):
                    raise

        formatter = logging.Formatter(self.formatter)

        has_file_handler = any(
            isinstance(handler, logging.FileHandler) and handler.baseFilename == log_path
            for handler in logger.handlers
        )
        if not has_file_handler:
            # A handler class which writes formatted logging records to disk files.
            fh = logging.FileHandler(self.log_file)
            fh.setLevel(logging.DEBUG)
            fh.setFormatter(formatter)
            logger.addHandler(fh)

        # FileHandler subclasses StreamHandler, so exclude it explicitly.
        has_stream_handler = any(
            isinstance(handler, logging.StreamHandler)
            and not isinstance(handler, logging.FileHandler)
            for handler in logger.handlers
        )
        if not has_stream_handler:
            # A handler class which writes logging records to a stream (terminal).
            sh = logging.StreamHandler()
            sh.setLevel(logging.DEBUG)
            sh.setFormatter(formatter)
            logger.addHandler(sh)

        return logger

    def info(self, message=''):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def debug(self, message=''):
        """Log ``message`` at DEBUG level."""
        self.logger.debug(message)
class BaseModel():
    """Base class for models: stores options and offers checkpoint helpers.

    Subclasses are expected to override the no-op hooks (``forward``,
    ``test``, ``optimize_parameters``, ...). ``initialize`` must be called
    before ``save_network`` / ``load_network`` because it sets ``save_dir``.
    """

    def name(self):
        """Return the model's name."""
        return 'BaseModel'

    def initialize(self, opt):
        """Store ``opt`` and derive GPU ids, tensor type and checkpoint directory.

        ``opt`` is indexed with the keys ``'device_ids'``, ``'checkpoints_dir'``
        and ``'name'``.
        """
        self.opt = opt
        self.gpu_ids = opt['device_ids']
        # CUDA tensor type only when at least one device id was given.
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt['checkpoints_dir'], opt['name'])

    def set_input(self, input):
        """Remember ``input`` for later use."""
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        """Return the stored input."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict; subclasses report their losses here."""
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        """Save ``network``'s state dict as ``<epoch_label>_net_<network_label>.pth``.

        The network is moved to the CPU for saving and, if ``gpu_ids`` is
        non-empty and CUDA is available, moved back to ``gpu_ids[0]``.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        if not os.path.isdir(self.save_dir):
            try:
                os.makedirs(self.save_dir)
            except OSError:
                # Tolerate a concurrent creation of the same directory.
                if not os.path.isdir(self.save_dir):
                    raise
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            # Positional on purpose: the keyword was ``device_id`` in old
            # PyTorch releases and ``device`` in later ones.
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<epoch_label>_net_<network_label>.pth`` from ``save_dir`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        pass
class onehotLabelLoader(data.Dataset):
    """Dataset yielding a label map and its one-hot encoding.

    Sample names are read from ``<root>/<dataName>/<phase>.txt`` (one per
    line); the label of sample ``name`` is loaded from
    ``<root>/<dataName>/label/<phase>/<name>.png``.

    Args:
        root: Dataset root directory.
        dataName: Sub-directory of ``root`` holding this dataset.
        phase: Split name; selects both the list file and the label folder.
        label_nums: Number of classes, i.e. channels of the one-hot output.
        lbl_size: Target ``(height, width)``; a single int means a square.
    """

    def __init__(self, root, dataName, phase='train', label_nums=12, lbl_size=(241,121)):
        self.root = root
        self.dataName = dataName
        self.phase = phase
        self.label_nums = label_nums
        self.lbl_size = lbl_size if isinstance(lbl_size, tuple) else (lbl_size, lbl_size)
        self.files = collections.defaultdict(list)
        self.now_idx = 0
        list_path = root + '/' + dataName + '/' + self.phase + '.txt'
        # ``with`` closes the list file; the original left the handle open.
        with open(list_path, 'r') as list_file:
            file_list = [id_.rstrip() for id_ in list_file]
        self.files[self.phase] = file_list

    def __len__(self):
        """Return the number of samples listed for the current phase."""
        return len(self.files[self.phase])

    @staticmethod
    def _to_onehot(lbl, label_nums):
        """Return a float32 array of shape ``(label_nums,) + lbl.shape``.

        Channel ``i`` is 1 where ``lbl == i`` and 0 elsewhere; values outside
        ``0 .. label_nums - 1`` are 0 in every channel.
        """
        class_ids = np.arange(label_nums, dtype=lbl.dtype)
        class_ids = class_ids.reshape((label_nums,) + (1,) * lbl.ndim)
        return (lbl[np.newaxis] == class_ids).astype(np.float32)

    def __getitem__(self, index):
        """Return ``(lbl, lbl_onehot)`` float tensors for sample ``index``."""
        lbl_name = self.files[self.phase][index]
        lbl_path = self.root + '/' + self.dataName + '/' + 'label/' + self.phase + '/' + lbl_name + '.png'

        lbl = Image.open(lbl_path)
        lbl_size = lbl.size  # PIL reports (width, height)
        if self.lbl_size[1] != lbl_size[0] or self.lbl_size[0] != lbl_size[1]:
            # Nearest-neighbour keeps every pixel a valid class id; bilinear
            # interpolation blends neighbouring ids into unrelated classes.
            lbl = lbl.resize((self.lbl_size[1], self.lbl_size[0]), resample=Image.NEAREST)

        lbl = np.array(lbl, dtype=np.float32)
        lbl_onehot = self._to_onehot(lbl, self.label_nums)

        lbl = torch.from_numpy(lbl).float()
        lbl_onehot = torch.from_numpy(lbl_onehot).float()

        return lbl, lbl_onehot

    def getBatch(self, index):
        """Advance ``now_idx`` cyclically and return the sample at ``index``.

        The original called ``len()`` without an argument, which always
        raised ``TypeError``; the rest of its body duplicated ``__getitem__``.
        """
        count = len(self)
        if count:
            self.now_idx = (self.now_idx + 1) % count
        return self[index]
matplotlib.pyplot as plt 10 | import Image 11 | from tqdm import tqdm 12 | from torch.utils import data 13 | 14 | class imageLoader(data.Dataset): 15 | def __init__(self, root, dataName, phase='train', img_size=(241,121)): 16 | self.root = root 17 | self.dataName = dataName 18 | self.phase = phase 19 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) 20 | self.mean = np.array([128, 128, 128]) 21 | self.files = collections.defaultdict(list) 22 | self.now_idx = 0 23 | """ 24 | for phase in ['train', 'val', 'train+unlabel']: 25 | file_list = tuple(open(root + '/' + dataName + '/' + phase + '.txt', 'r')) 26 | file_list = [id_.rstrip() for id_ in file_list] 27 | self.files[phase] = file_list 28 | """ 29 | file_list = tuple(open(root + '/' + dataName + '/' + self.phase + '.txt', 'r')) 30 | file_list = [id_.rstrip() for id_ in file_list] 31 | self.files[self.phase] = file_list 32 | # print self.files['train'] 33 | def __len__(self): 34 | return len(self.files[self.phase]) 35 | 36 | def __getitem__(self, index): 37 | img_name = self.files[self.phase][index] 38 | img_path = self.root + '/' + self.dataName + '/' + 'image/' + self.phase + '/' + img_name + '.jpg' 39 | 40 | img = Image.open(img_path) 41 | img_size = img.size 42 | if self.img_size[1] != img_size[0] or self.img_size[0] != img_size[1]: 43 | img = img.resize((self.img_size[1], self.img_size[0]), resample=Image.BILINEAR) 44 | 45 | img = np.array(img, dtype=np.float32) 46 | if not (len(img.shape) == 3 and img.shape[2] == 3): 47 | img = img.reshape(img.shape[0], img.shape[1], 1) 48 | img = img.repeat(3, 2) 49 | img -= self.mean 50 | img = img[:, :, ::-1]# RGB -> BGR 51 | img = img.transpose(2, 0, 1) 52 | img = img.copy() 53 | 54 | img = torch.from_numpy(img).float() 55 | 56 | return img 57 | 58 | def getBatch(self, index): 59 | self.now_idx = (self.now_idx + 1)%len() 60 | img_name = self.files[self.phase][index] 61 | img_path = self.root + '/' + self.dataName + '/' + 'image/' + 
class ConfusionMatrix:
    """Running per-class counts for segmentation metrics.

    Only three vectors of length ``size`` are kept: ``diag`` (pixels where
    prediction and ground truth are both class ``i``), ``act_sum`` (ground
    truth pixels of class ``i``) and ``pre_sum`` (predicted pixels of class
    ``i``). All metrics return 0 when there is nothing to average.
    """

    def __init__(self, size=12):
        self.size = size
        self.reset()

    def reset(self):
        """Zero all accumulated counts."""
        self.diag = np.zeros(self.size)
        self.act_sum = np.zeros(self.size)
        self.pre_sum = np.zeros(self.size)

    def update(self, actual, predicted):
        """Accumulate counts from ground truth ``actual`` and ``predicted`` labels.

        Both arguments are array-likes of class ids with identical shape.
        """
        # asarray so plain lists compare element-wise instead of as a whole.
        actual = np.asarray(actual)
        predicted = np.asarray(predicted)
        for i in range(self.size):
            act = actual == i
            pre = predicted == i
            self.diag[i] += np.sum(act & pre)
            self.act_sum[i] += np.sum(act)
            self.pre_sum[i] += np.sum(pre)

    def accuracy(self):
        """Return correctly labelled pixels divided by all ground truth pixels."""
        diag_sum = np.sum(self.diag)
        total_sum = np.sum(self.act_sum)
        if total_sum == 0:
            return 0
        return diag_sum / total_sum

    def fg_accuracy(self):
        """Return the accuracy computed without class 0."""
        diag_sum = np.sum(self.diag) - self.diag[0]
        total_sum = np.sum(self.act_sum) - self.act_sum[0]
        if total_sum == 0:
            return 0
        return diag_sum / total_sum

    def avg_precision(self):
        """Return precision averaged over the classes that were predicted at least once."""
        total_precision = 0
        count = 0
        for i in range(self.size):
            if self.pre_sum[i] > 0:
                total_precision += self.diag[i] / self.pre_sum[i]
                count += 1
        if count == 0:
            return 0
        return total_precision / count

    def avg_recall(self):
        """Return recall averaged over the classes present in the ground truth."""
        total_recall = 0
        count = 0
        for i in range(self.size):
            if self.act_sum[i] > 0:
                total_recall += self.diag[i] / self.act_sum[i]
                count += 1
        if count == 0:
            return 0
        return total_recall / count

    def avg_f1score(self):
        """Return F1 averaged over classes seen in ground truth or prediction."""
        total_f1score = 0
        count = 0
        for i in range(self.size):
            t = self.pre_sum[i] + self.act_sum[i]
            if t > 0:
                total_f1score += 2 * self.diag[i] / t
                count += 1
        if count == 0:
            return 0
        return total_f1score / count

    def f1score(self):
        """Return the per-class F1 list; -1 marks a class never seen nor predicted."""
        f1score = []
        for i in range(self.size):
            t = self.pre_sum[i] + self.act_sum[i]
            if t > 0:
                f1score.append(2 * self.diag[i] / t)
            else:
                f1score.append(-1)
        return f1score

    def mean_iou(self):
        """Return IoU averaged over classes seen in ground truth or prediction."""
        total_iou = 0
        count = 0
        for i in range(self.size):
            I = self.diag[i]
            U = self.act_sum[i] + self.pre_sum[i] - I
            if U > 0:
                total_iou += I / U
                count += 1
        # Guard like the other metrics; previously an empty matrix raised
        # ZeroDivisionError here (and therefore in ``all_acc``).
        if count == 0:
            return 0
        return total_iou / count

    def all_acc(self):
        """Return every scalar metric in one dict."""
        return {
            'accuracy':self.accuracy(),
            'fg_accuracy':self.fg_accuracy(),
            'avg_precision':self.avg_precision(),
            'avg_recall':self.avg_recall(),
            'avg_f1score':self.avg_f1score(),
            'mean_iou':self.mean_iou(),
        }
def color(label):
    """Turn a class-id map into a 3-channel colour image.

    ``label`` is a numpy array whose first two dimensions are height and
    width. The result has shape ``(height, width, 3)``; each pixel of class
    0-11 receives that class's palette colour, with red written to channel
    2, green to channel 1 and blue to channel 0. Pixels holding any other
    value keep that value in all three channels. ``label`` is not modified.
    """
    # (red, green, blue) for class ids 0..11
    palette = (
        (230, 230, 230),  # bg
        (255, 215, 0),    # face
        (80, 49, 49),     # hair
        (51, 0, 255),     # Upcloth
        (2, 251, 49),     # Larm
        (141, 255, 212),  # Rarm
        (160, 0, 255),    # pants
        (0, 204, 255),    # Lleg
        (191, 255, 248),  # Rleg
        (255, 182, 185),  # dress
        (180, 122, 121),  # Lshoe
        (202, 160, 57),   # Rshoe
        # bag (class 12) is intentionally not coloured
    )
    height, width = label.shape[0], label.shape[1]

    # All masks are taken from the untouched ids before any colour is written.
    masks = [(label == class_id).reshape(height, width)
             for class_id in range(len(palette))]

    # repeat 2nd axis to 3
    canvas = label.reshape(height, width, 1).repeat(3, 2)
    for mask, (red, green, blue) in zip(masks, palette):
        canvas[mask, 2] = red
        canvas[mask, 1] = green
        canvas[mask, 0] = blue
    return canvas
'pretrain_model': 'deeplab.pth', 147 | 'fineSizeH':241, 148 | 'fineSizeW':121, 149 | 'input_nc':3, 150 | 'name': 'v3_s->t_Refine_4', 151 | 'checkpoints_dir': 'checkpoints', 152 | 'net_D': 'NoBNSinglePathdilationMultOutputNet', 153 | 'use_lsgan': True, 154 | 'resume':None#'checkpoints/v3_1/', 155 | } 156 | main() -------------------------------------------------------------------------------- /test/test_g2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from loader.image_label_loader import imageLabelLoader 4 | from models.deeplab_g2 import deeplabG2 5 | from util.confusion_matrix import ConfusionMatrix 6 | import torch 7 | import numpy as np 8 | import scipy.misc 9 | def color(label): 10 | bg = label == 0 11 | bg = bg.reshape(bg.shape[0], bg.shape[1]) 12 | face = label == 1 13 | face = face.reshape(face.shape[0], face.shape[1]) 14 | hair = label == 2 15 | hair = hair.reshape(hair.shape[0], hair.shape[1]) 16 | Upcloth = label == 3 17 | Upcloth = Upcloth.reshape(Upcloth.shape[0], Upcloth.shape[1]) 18 | Larm = label == 4 19 | Larm = Larm.reshape(Larm.shape[0], Larm.shape[1]) 20 | Rarm = label == 5 21 | Rarm = Rarm.reshape(Rarm.shape[0], Rarm.shape[1]) 22 | pants = label == 6 23 | pants = pants.reshape(pants.shape[0], pants.shape[1]) 24 | Lleg = label == 7 25 | Lleg = Lleg.reshape(Lleg.shape[0], Lleg.shape[1]) 26 | Rleg = label == 8 27 | Rleg = Rleg.reshape(Rleg.shape[0], Rleg.shape[1]) 28 | dress = label == 9 29 | dress = dress.reshape(dress.shape[0], dress.shape[1]) 30 | Lshoe = label == 10 31 | Lshoe = Lshoe.reshape(Lshoe.shape[0], Lshoe.shape[1]) 32 | Rshoe = label == 11 33 | Rshoe = Rshoe.reshape(Rshoe.shape[0], Rshoe.shape[1]) 34 | 35 | # bag = label == 12 36 | # bag = bag.reshape(bag.shape[0], bag.shape[1]) 37 | 38 | # repeat 2nd axis to 3 39 | label = label.reshape(bg.shape[0], bg.shape[1], 1) 40 | label = label.repeat(3, 2) 41 | R = label[:, :, 2] 42 | G = label[:, :, 1] 43 
def update_confusion_matrix(matrix, output, target):
    """Feed one batch into ``matrix`` and return ``matrix``.

    The predicted class of every element is the index of the maximum of
    ``output`` along dimension 1. ``target`` and the predictions are moved
    to the CPU and converted to numpy before ``matrix.update`` is called
    with ``(target, predictions)``.
    """
    predicted = output.max(1)[1]  # max() yields (values, indices); keep indices
    matrix.update(target.cpu().numpy(), predicted.cpu().numpy())
    return matrix
def color(label):
    """Render a class-id map as a ``(height, width, 3)`` colour array.

    ``label`` is a numpy array whose first two dimensions are height and
    width. Pixels of class 0-11 are painted with the class colour, red in
    channel 0, green in channel 1 and blue in channel 2; pixels with any
    other value keep that value in all channels. The input is left untouched.
    """
    class_colors = [
        ('bg', (230, 230, 230)),
        ('face', (255, 215, 0)),
        ('hair', (80, 49, 49)),
        ('Upcloth', (51, 0, 255)),
        ('Larm', (2, 251, 49)),
        ('Rarm', (141, 255, 212)),
        ('pants', (160, 0, 255)),
        ('Lleg', (0, 204, 255)),
        ('Rleg', (191, 255, 248)),
        ('dress', (255, 182, 185)),
        ('Lshoe', (180, 122, 121)),
        ('Rshoe', (202, 160, 57)),
        # 'bag' (class 12) is deliberately left out
    ]
    rows, cols = label.shape[0], label.shape[1]

    # Build every mask first so later colour writes cannot affect them.
    regions = []
    for class_id, (_, rgb) in enumerate(class_colors):
        regions.append(((label == class_id).reshape(rows, cols), rgb))

    # repeat 2nd axis to 3
    painted = label.reshape(rows, cols, 1)
    painted = painted.repeat(3, 2)
    for region, rgb in regions:
        painted[region] = rgb  # one assignment fills channels 0, 1, 2 = R, G, B
    return painted
test_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'), 109 | batch_size=args['batch_size'], 110 | num_workers=args['num_workers'], shuffle=False) 111 | gym = deeplabG1G2() 112 | gym.initialize(args) 113 | gym.load('/home/ben/mathfinder/PROJECT/AAAI2017/our_Method/v3/deeplab_feature_adaptation/checkpoints/Lip_to_July_lr_g1=0.00001_lr_g2=0.00000002_interval_g1=5_interval_d1=5_net_D=lsganMultOutput_D_if_adaptive=True_resume_decay=g2/best_Ori_on_B_model.pth') 114 | gym.eval() 115 | matrix = ConfusionMatrix(args['label_nums']) 116 | for i, (image, label) in enumerate(test_loader): 117 | label = label.cuda(async=True) 118 | target_var = torch.autograd.Variable(label, volatile=True) 119 | 120 | gym.test(image) 121 | output = gym.output 122 | 123 | matrix = update_confusion_matrix(matrix, output.data, label) 124 | print(matrix.all_acc()) 125 | print(matrix.f1score()) 126 | 127 | 128 | if __name__ == "__main__": 129 | global args 130 | args = { 131 | 'test_init':False, 132 | 'test_init':False, 133 | 'label_nums':12, 134 | 'l_rate':1e-8, 135 | 'lr_g1': 0.00001, 136 | 'lr_g2': 0.00000002, 137 | 'beta1': 0.5, 138 | 'interval_g2':5, 139 | 'interval_d2':5, 140 | 'data_path':'datasets', 141 | 'n_epoch':1000, 142 | 'batch_size':10, 143 | 'num_workers':10, 144 | 'print_freq':20, 145 | 'device_ids':[0], 146 | 'domainA': 'Lip', 147 | 'domainB': 'July', 148 | 'weigths_pool': 'pretrain_models', 149 | 'pretrain_model': 'deeplab.pth', 150 | 'fineSizeH':241, 151 | 'fineSizeW':121, 152 | 'input_nc':3, 153 | 'name': 'Lip_to_July_lr_g1=0.00001_lr_g2=0.00000002_interval_g1=5_interval_d1=5_net_D=lsganMultOutput_D_if_adaptive=True_resume_decay=g2', 154 | 'checkpoints_dir': 'checkpoints', 155 | 'net_d1': 'NoBNSinglePathdilationMultOutputNet', 156 | 'net_d2': 'lsganMultOutput_D', 157 | 'use_lsgan': True, 158 | 
def update_confusion_matrix(matrix, output, target):
    """Fold one batch of predictions into ``matrix`` and return it.

    ``output`` holds per-class scores with the class axis at dim 1; the
    arg-max along that axis is used as the predicted label.  Ground truth
    and predictions are both moved to host memory and passed to
    ``matrix.update(target, prediction)`` as NumPy arrays.
    """
    # ``max(1)`` yields (values, indices); only the indices are needed.
    predicted = output.max(1)[1]
    matrix.update(target.cpu().numpy(), predicted.cpu().numpy())
    return matrix
def _endless_batches(loader):
    """Yield batches from ``loader`` forever, restarting it after each pass."""
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            # Without this guard an empty loader would spin forever.
            raise ValueError('cannot draw batches from an empty data loader')


def _loss_value(model, attr):
    """Return ``model.<attr>`` as a scalar, or NaN if that loss was never computed."""
    loss = getattr(model, attr, None)
    if loss is None:
        return float('nan')
    return loss.data[0]


def main():
    """Train FLGAN on labelled domain A and unlabelled domain B.

    All settings come from the module-level ``args`` dict and progress is
    written through the module-level ``logger``.  Every 1000 iterations the
    model is validated on both domains, and the checkpoint with the best
    domain-B ``avg_f1score`` is saved as ``best_Ori_on_B``.

    NOTE(review): the nesting of the validation and learning-rate blocks
    was reconstructed from a whitespace-flattened copy of this file --
    confirm against the original indentation.
    """
    makedirs.mkdirs(os.path.join(args['checkpoints_dir'], args['name']))
    if len(args['device_ids']) > 0:
        torch.cuda.set_device(args['device_ids'][0])

    def make_loader(dataset, shuffle):
        return data.DataLoader(dataset, batch_size=args['batch_size'],
                               num_workers=args['num_workers'], shuffle=shuffle)

    A_train_loader = make_loader(
        imageLabelLoader(args['data_path'], dataName=args['domainA'], phase='train'), True)
    label_train_loader = make_loader(
        labelLoader(args['data_path'], dataName=args['domainA'], phase='train_onehot'), True)
    A_val_loader = make_loader(
        imageLabelLoader(args['data_path'], dataName=args['domainA'], phase='val'), False)
    B_train_loader = make_loader(
        imageLoader(args['data_path'], dataName=args['domainB'], phase='train+unlabel'), True)
    B_val_loader = make_loader(
        imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'), False)

    model = FLGAN()
    model.initialize(args)

    # multi GPUS
    # model = torch.nn.DataParallel(model,device_ids=args['device_ids']).cuda()
    Iter = 0
    Epoch = 0
    best_Ori_on_B = 0
    prec_Ori_on_B = 0
    if args['resume']:
        if os.path.isfile(args['resume']):
            logger.info("=> loading checkpoint '{}'".format(args['resume']))
            Iter, Epoch, best_Ori_on_B = model.load(args['resume'])
            prec_Ori_on_B = best_Ori_on_B
            if (args['if_adaptive'] and (Epoch + 1) % 30 == 0) or prec_Ori_on_B > 0.56:
                model.update_learning_rate()
        else:
            # Report through the logger like every other message in this file.
            logger.info("=> no checkpoint found at '{}'".format(args['resume']))

    # Persistent iterators: calling next(iter(loader)) on every step would
    # rebuild the loader's iterator (and its worker pool) each time and
    # throw away everything it prefetched beyond the first batch.
    B_batches = _endless_batches(B_train_loader)
    onehot_batches = _endless_batches(label_train_loader)
    criterion = nn.CrossEntropyLoss(size_average=False)

    model.train()
    for epoch in range(Epoch, args['n_epoch']):
        for i, (A_image, A_label) in enumerate(A_train_loader):
            Iter += 1
            batch = {'A': A_image, 'A_label': A_label, 'B': next(B_batches)}
            # One-hot labels are only consumed on iterations where D2 is trained.
            if Iter % args['interval_d2'] == 0 and args['if_adv_train']:
                batch['label_onehot'] = next(onehot_batches)
            model.set_input(batch)

            model.step()
            if (i + 1) % args['print_freq'] == 0:
                output = model.output
                # Size the matrix explicitly, as validate() does.
                matrix = ConfusionMatrix(args['label_nums'])
                update_confusion_matrix(matrix, output.data, A_label)
                logger.info('Time: {time}\t'
                            'Epoch/Iter: [{epoch}/{Iter}]\t'
                            'loss: {loss:.4f}\t'
                            'acc: {accuracy:.4f}\t'
                            'fg_acc: {fg_accuracy:.4f}\t'
                            'avg_prec: {avg_precision:.4f}\t'
                            'avg_rec: {avg_recall:.4f}\t'
                            'avg_f1: {avg_f1core:.4f}\t'
                            'loss_G1: {loss_G1:.4f}\t'
                            'loss_D1: {loss_D1:.4f}\t'
                            'loss_G2: {loss_G2:.4f}\t'
                            'loss_D2: {loss_D2:.4f}\t'
                    .format(
                    time=time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()),
                    epoch=epoch, Iter=Iter, loss=model.loss_P.data[0],
                    accuracy=matrix.accuracy(),
                    fg_accuracy=matrix.fg_accuracy(), avg_precision=matrix.avg_precision(),
                    avg_recall=matrix.avg_recall(), avg_f1core=matrix.avg_f1score(),
                    loss_G1=model.loss_G1.data[0], loss_D1=model.loss_D1.data[0],
                    # G2/D2 losses only exist once their interval has fired
                    # (and never when 'if_adv_train' is off); log NaN until then.
                    loss_G2=_loss_value(model, 'loss_G2'),
                    loss_D2=_loss_value(model, 'loss_D2')))

            if Iter % 1000 == 0:
                model.eval()
                acc_Ori_on_A = validate(A_val_loader, model, criterion, False)
                acc_Ori_on_B = validate(B_val_loader, model, criterion, False)
                prec_Ori_on_B = acc_Ori_on_B['avg_f1score']

                is_best = prec_Ori_on_B > best_Ori_on_B
                best_Ori_on_B = max(prec_Ori_on_B, best_Ori_on_B)
                if is_best:
                    model.save('best_Ori_on_B', Iter=Iter, epoch=epoch,
                               acc={'acc_Ori_on_A': acc_Ori_on_A, 'acc_Ori_on_B': acc_Ori_on_B})
                model.train()

        # Once per epoch: running this per batch would apply the x0.1 decay
        # on every step of each 30th epoch.
        if args['if_adaptive'] and (epoch + 1) % 30 == 0:
            model.update_learning_rate()
os.makedirs(args['log']) 195 | if not os.path.exists(os.path.join(args['data_path'], args['domainA'], 'label', 'train_onehot')): 196 | print('Creat onehot label from domainA...') 197 | onehot_label_path = os.path.join(args['data_path'], args['domainA'], 'label', 'train_onehot') 198 | os.makedirs(onehot_label_path) 199 | label_path = os.path.join(args['data_path'], args['domainA'], 'label', 'train') 200 | for name in os.listdir(label_path): 201 | lbl = Image.open(os.path.join(label_path, name)) 202 | lbl = np.array(lbl) 203 | lbl_onehot = np.zeros((12, 241, 121)) 204 | for i in range(args['label_nums']): 205 | lbl_onehot[i][lbl == i] = 1 206 | np.save(os.path.join(onehot_label_path, name.split('.png')[0] + ".npy"), lbl_onehot) 207 | print('Done!') 208 | 209 | logger = Logger( 210 | log_file=args['log'] + '/' + args['name'] + '-' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log') 211 | logger.info('------------ Options -------------\n') 212 | for k, v in args.items(): 213 | logger.info('%s: %s' % (str(k), str(v))) 214 | logger.info('-------------- End ----------------\n') 215 | main() 216 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 28 | 29 | 30 | 32 | 33 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 62 | 63 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 82 | 83 | 84 | 85 | 102 | 103 | 114 | 115 | 133 | 134 | 149 | 150 | 151 | 153 | 154 | 155 | 156 | 1505531067661 157 | 161 | 162 | 1505531567073 163 | 168 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 197 | 198 | 199 | 201 | 202 | 203 | 204 | 205 | 
class FLGAN(BaseModel):
    """Adversarial feature- and label-adaptation model around a split DeepLab.

    The segmentation network is split into ``deeplabPart1``/``2``/``3``.
    ``netD1`` discriminates domain-B ``deeplabPart2`` features from
    domain-A features refined by ``netG1``; ``netD2`` discriminates one-hot
    domain-A labels from domain-B predictions.

    ``self.Tensor`` and ``self.save_dir`` are not defined in this class;
    presumably ``BaseModel.initialize`` provides them -- verify there.
    """

    def name(self):
        """Return the identifier stored in checkpoints."""
        return 'flgan'

    def initialize(self, args):
        """Build buffers, networks, losses and optimizers from the ``args`` dict."""
        BaseModel.initialize(self, args)
        self.if_adv_train = args['if_adv_train']
        self.Iter = 0
        self.interval_g2 = args['interval_g2']
        self.interval_d2 = args['interval_d2']
        self.nb = args['batch_size']
        sizeH, sizeW = args['fineSizeH'], args['fineSizeW']
        gpu = args['device_ids'][0]

        # Reusable input buffers; set_input()/test() resize them per batch.
        self.tImageA = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
        self.tImageB = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
        self.tLabelA = torch.cuda.LongTensor(self.nb, 1, sizeH, sizeW)
        self.tOnehotLabelA = self.Tensor(self.nb, args['label_nums'], sizeH, sizeW)
        self.loss_G = Variable()
        self.loss_D = Variable()
        # G2/D2 only run every ``interval_*`` iterations (or never, when
        # 'if_adv_train' is off).  Zero placeholders keep ``loss_G2`` and
        # ``loss_D2`` readable before their first real value exists.
        self.loss_G2 = Variable(self.Tensor(1).zero_())
        self.loss_D2 = Variable(self.Tensor(1).zero_())

        # The device index is passed positionally: the keyword was renamed
        # from ``device_id`` to ``device`` between PyTorch releases.
        self.netG1 = networks.netG().cuda(gpu)
        self.netD1 = define_D(args['net_d1'], 512).cuda(gpu)
        self.netD2 = define_D(args['net_d2'], args['label_nums']).cuda(gpu)

        self.deeplabPart1 = networks.DeeplabPool1().cuda(gpu)
        self.deeplabPart2 = networks.DeeplabPool12Pool5().cuda(gpu)
        self.deeplabPart3 = networks.DeeplabPool52Fc8_interp(output_nc=args['label_nums']).cuda(gpu)

        # define loss functions
        self.criterionCE = torch.nn.CrossEntropyLoss(size_average=False)
        self.criterionAdv = networks.Advloss(use_lsgan=args['use_lsgan'], tensor=self.Tensor)

        if not args['resume']:
            # Fresh run: random init for the GAN nets, pretrained DeepLab weights.
            self.netG1.apply(weights_init)
            self.netD1.apply(weights_init)
            self.netD2.apply(weights_init)
            pretrained_dict = torch.load(args['weigths_pool'] + '/' + args['pretrain_model'])
            self.deeplabPart1.weights_init(pretrained_dict=pretrained_dict)
            self.deeplabPart2.weights_init(pretrained_dict=pretrained_dict)
            self.deeplabPart3.weights_init(pretrained_dict=pretrained_dict)

        # initialize optimizers
        self.optimizer_G1 = torch.optim.Adam(self.netG1.parameters(),
                                             lr=args['lr_g1'], betas=(args['beta1'], 0.999))
        self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
                                             lr=args['lr_g1'], betas=(args['beta1'], 0.999))

        self.optimizer_G2 = torch.optim.Adam([
            {'params': self.deeplabPart1.parameters()},
            {'params': self.deeplabPart2.parameters()},
            {'params': self.deeplabPart3.parameters()}],
            lr=args['lr_g2'], betas=(args['beta1'], 0.999))
        self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(),
                                             lr=args['lr_g2'], betas=(args['beta1'], 0.999))

        head_layers = (self.deeplabPart3.fc8_1, self.deeplabPart3.fc8_2,
                       self.deeplabPart3.fc8_3, self.deeplabPart3.fc8_4)
        head_ids = set()
        for layer in head_layers:
            head_ids.update(id(p) for p in layer.parameters())

        def deeplab_param_groups():
            # Plain lists instead of ``filter(...) + filter(...)``, which
            # only works where ``filter`` returns a list (Python 2).
            base_params = [p for p in self.deeplabPart3.parameters()
                           if id(p) not in head_ids]
            base_params += list(self.deeplabPart1.parameters())
            base_params += list(self.deeplabPart2.parameters())
            # Group order is base, four head weights (x10 lr), four head
            # biases (x20 lr) -- saved optimizer state depends on it.
            groups = [{'params': base_params}]
            groups += [{'params': get_parameters(layer, 'weight'), 'lr': args['l_rate'] * 10}
                       for layer in head_layers]
            groups += [{'params': get_parameters(layer, 'bias'), 'lr': args['l_rate'] * 20}
                       for layer in head_layers]
            return groups

        # Each optimizer gets its own group dicts so that a later change to
        # one optimizer's groups cannot silently leak into the other.
        self.optimizer_P = torch.optim.SGD(deeplab_param_groups(), lr=args['l_rate'],
                                           momentum=0.9, weight_decay=5e-4)
        self.optimizer_R = torch.optim.SGD(deeplab_param_groups(), lr=args['l_rate'],
                                           momentum=0.9, weight_decay=5e-4)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG1)
        networks.print_network(self.netD1)
        networks.print_network(self.netD2)
        networks.print_network(self.deeplabPart1)
        networks.print_network(self.deeplabPart2)
        networks.print_network(self.deeplabPart3)
        print('-----------------------------------------------')

    def set_input(self, input):
        """Copy one training batch into the model's buffers.

        ``input`` must provide 'A', 'A_label' and 'B'; 'label_onehot' is
        optional and only needed on iterations where D2 is trained.
        """
        self.input = input
        tImageA = input['A']
        tLabelA = input['A_label']
        tImageB = input['B']
        self.tImageA.resize_(tImageA.size()).copy_(tImageA)
        self.vImageA = Variable(self.tImageA)

        self.tLabelA.resize_(tLabelA.size()).copy_(tLabelA)
        self.vLabelA = Variable(self.tLabelA)

        self.tImageB.resize_(tImageB.size()).copy_(tImageB)
        self.vImageB = Variable(self.tImageB)

        # ``in`` works on Python 2 and 3; ``dict.has_key`` is Python 2 only.
        if 'label_onehot' in input:
            tOnehotLabelA = input['label_onehot']
            self.tOnehotLabelA.resize_(tOnehotLabelA.size()).copy_(tOnehotLabelA)
            self.vOnehotLabelA = Variable(self.tOnehotLabelA)

    # used in test time, no backprop
    def test(self, input):
        """Run the three DeepLab parts on ``input``; store and return the scores."""
        self.tImageA.resize_(input.size()).copy_(input)
        self.vImageA = Variable(self.tImageA)
        self.output = self.deeplabPart3(self.deeplabPart2(self.deeplabPart1(self.vImageA)))
        return self.output

    def step_P(self):
        """Supervised segmentation loss on domain A; also caches domain-B features."""
        # pool5_B keeps its graph for step_G2; pool5_B_for_d1 is a detached
        # copy handed to D1 as the "real" sample.
        self.pool5_B = self.deeplabPart2(self.deeplabPart1(self.vImageB))
        self.pool5_B_for_d1 = Variable(self.pool5_B.data)

        self.pool1_A = self.deeplabPart1(self.vImageA)
        self.pool5_A = self.deeplabPart2(self.pool1_A)
        self.predic_A = self.deeplabPart3(self.pool5_A)
        self.output = Variable(self.predic_A.data)

        self.loss_P = self.criterionCE(self.predic_A, self.vLabelA) / self.nb
        self.loss_P.backward()

        # Detach the domain-A features so G1/D1 do not backprop into DeepLab.
        self.pool1_A = Variable(self.pool1_A.data)
        self.pool5_A = Variable(self.pool5_A.data)

    def step_G1(self):
        """Train G1: its residual on pool5_A should be scored as real by D1."""
        self.pool5_A = self.pool5_A + self.netG1(self.pool1_A)
        pred_fake = self.netD1.forward(self.pool5_A)

        self.loss_G1 = self.criterionAdv(pred_fake, True)
        self.loss_G1.backward()

        # Detach so the following D1 step does not reach back into G1.
        self.pool5_A = Variable(self.pool5_A.data)

    def step_D1(self):
        """Train D1: domain-B features are real, refined domain-A features fake."""
        pred_real = self.netD1.forward(self.pool5_B_for_d1)
        loss_D1_real = self.criterionAdv(pred_real, True)

        pred_fake = self.netD1.forward(self.pool5_A)
        loss_D1_fake = self.criterionAdv(pred_fake, False)

        self.loss_D1 = (loss_D1_real + loss_D1_fake) * 0.5
        self.loss_D1.backward()

    def step_G2(self):
        """Train DeepLab so that D2 scores its domain-B prediction as real."""
        # NOTE(review): pool5_B was computed in step_P, before optimizer_P
        # updated the DeepLab weights; newer autograd versions may reject
        # backward through it -- confirm on the targeted PyTorch release.
        self.predic_B = self.deeplabPart3(self.pool5_B)
        pred_fake = self.netD2.forward(self.predic_B)

        self.loss_G2 = self.criterionAdv(pred_fake, True)
        self.loss_G2.backward()

    def step_D2(self):
        """Train D2: one-hot domain-A labels are real, domain-B predictions fake.

        Relies on ``self.predic_B`` from the most recent step_G2 and on
        ``self.vOnehotLabelA`` from a set_input call that carried
        'label_onehot'.
        """
        pred_real = self.netD2.forward(self.vOnehotLabelA)
        loss_D2_real = self.criterionAdv(pred_real, True)

        self.predic_B = Variable(self.predic_B.data)
        pred_fake = self.netD2.forward(self.predic_B)
        loss_D2_fake = self.criterionAdv(pred_fake, False)

        self.loss_D2 = (loss_D2_real + loss_D2_fake) * 0.5

        self.loss_D2.backward()

    def step_R(self):
        """Refinement: supervised loss with G1's residual added to the features."""
        pool1 = self.deeplabPart1(self.vImageA)
        self.predic_A_R = self.deeplabPart3(self.deeplabPart2(pool1) + self.netG1(pool1))
        self.loss_R = self.criterionCE(self.predic_A_R, self.vLabelA) / self.nb

        self.loss_R.backward()

    def step(self):
        """Run one training iteration: P, G1, D1, optionally G2 and D2, then R."""
        self.Iter += 1
        # deeplab
        self.optimizer_P.zero_grad()
        self.step_P()
        self.optimizer_P.step()

        # G1
        self.optimizer_G1.zero_grad()
        self.step_G1()
        self.optimizer_G1.step()
        # D1
        self.optimizer_D1.zero_grad()
        self.step_D1()
        self.optimizer_D1.step()
        if self.Iter % self.interval_g2 == 0 and self.if_adv_train:
            # G2
            self.optimizer_G2.zero_grad()
            self.step_G2()
            self.optimizer_G2.step()
        if self.Iter % self.interval_d2 == 0 and self.if_adv_train:
            # D2
            self.optimizer_D2.zero_grad()
            self.step_D2()
            self.optimizer_D2.step()

        # Refine
        self.optimizer_R.zero_grad()
        self.step_R()
        self.optimizer_R.step()

    def get_current_visuals(self):
        """Return the dict most recently passed to set_input."""
        return self.input

    def get_current_errors(self):
        """Return an empty dict (no error summary is tracked)."""
        return {}

    def save(self, model_name, Iter=None, epoch=None, acc=None):
        """Write all network and optimizer states to ``<save_dir>/<model_name>_model.pth``.

        ``acc`` should map result names to metric dicts (see load); when
        omitted an empty dict is stored so that load can still iterate it.
        """
        if acc is None:
            acc = {}
        save_filename = '%s_model.pth' % (model_name)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save({
            'name': self.name(),
            'Iter': Iter,
            'epoch': epoch,
            'acc': acc,
            'state_dict_netG1': self.netG1.state_dict(),
            'state_dict_netD1': self.netD1.state_dict(),
            'state_dict_netD2': self.netD2.state_dict(),
            'state_dict_deeplabPart1': self.deeplabPart1.state_dict(),
            'state_dict_deeplabPart2': self.deeplabPart2.state_dict(),
            'state_dict_deeplabPart3': self.deeplabPart3.state_dict(),
            'optimizer_P': self.optimizer_P.state_dict(),
            'optimizer_R': self.optimizer_R.state_dict(),
            'optimizer_G1': self.optimizer_G1.state_dict(),
            'optimizer_D1': self.optimizer_D1.state_dict(),
            'optimizer_G2': self.optimizer_G2.state_dict(),
            'optimizer_D2': self.optimizer_D2.state_dict(),
        }, save_path)

    def load(self, load_path):
        """Restore a checkpoint written by save.

        Returns ``(Iter, epoch, best_f1)`` where ``best_f1`` is the
        'avg_f1score' stored under 'acc_Ori_on_B', or 0 when the checkpoint
        carries no such entry.
        """
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(load_path)
        self.netG1.load_state_dict(checkpoint['state_dict_netG1'])
        self.netD1.load_state_dict(checkpoint['state_dict_netD1'])
        self.netD2.load_state_dict(checkpoint['state_dict_netD2'])
        self.deeplabPart1.load_state_dict(checkpoint['state_dict_deeplabPart1'])
        self.deeplabPart2.load_state_dict(checkpoint['state_dict_deeplabPart2'])
        self.deeplabPart3.load_state_dict(checkpoint['state_dict_deeplabPart3'])

        self.optimizer_P.load_state_dict(checkpoint['optimizer_P'])
        self.optimizer_G1.load_state_dict(checkpoint['optimizer_G1'])
        self.optimizer_D1.load_state_dict(checkpoint['optimizer_D1'])
        self.optimizer_G2.load_state_dict(checkpoint['optimizer_G2'])
        self.optimizer_D2.load_state_dict(checkpoint['optimizer_D2'])
        self.optimizer_R.load_state_dict(checkpoint['optimizer_R'])

        # Default keeps the return statement valid when the checkpoint has
        # no 'acc_Ori_on_B' entry (previously an unbound local).
        best_f1 = 0
        for k, v in checkpoint['acc'].items():
            print('=================================================')
            if k == 'acc_Ori_on_B':
                best_f1 = v['avg_f1score']
            print('accuracy: {0:.4f}\t'
                  'fg_accuracy: {1:.4f}\t'
                  'avg_precision: {2:.4f}\t'
                  'avg_recall: {3:.4f}\t'
                  'avg_f1score: {4:.4f}\t'
                  .format(v['accuracy'], v['fg_accuracy'], v['avg_precision'], v['avg_recall'], v['avg_f1score']))
            print('=================================================')

        return checkpoint['Iter'], checkpoint['epoch'], best_f1

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        """Load ``<save_dir>/<epoch_label>_net_<network_label>.pth`` into ``network``."""
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate(self):
        """Multiply the learning rate of the four adversarial optimizers by 0.1.

        The SGD optimizers ``optimizer_P`` and ``optimizer_R`` are left untouched.
        """
        for optimizer in (self.optimizer_D1, self.optimizer_G1,
                          self.optimizer_D2, self.optimizer_G2):
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.1

    def train(self):
        """Put every sub-network into training mode."""
        self.deeplabPart1.train()
        self.deeplabPart2.train()
        self.deeplabPart3.train()
        self.netG1.train()
        self.netD1.train()
        self.netD2.train()

    def eval(self):
        """Put every sub-network into evaluation mode."""
        self.deeplabPart1.eval()
        self.deeplabPart2.eval()
        self.deeplabPart3.eval()
        self.netG1.eval()
        self.netD1.eval()
        self.netD2.eval()
def print_network(net):
    """Print ``net`` followed by the total number of elements in its parameters."""
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)
class Deeplab(nn.Module):
    """Single-module DeepLab-style segmentation network with four dilated heads.

    A convolutional trunk (``conv1_1`` .. ``pool5``) feeds four parallel
    ``fc6 -> fc7 -> fc8`` branches whose 3x3 ``fc6`` convolutions use
    dilation rates 6, 12, 18 and 24.  The four 12-channel score maps are
    summed and bilinearly resized to ``size``.

    Layer attribute names double as state-dict keys for pretrained
    weights, so they must not be renamed.
    """

    def __init__(self, size=(241, 121)):
        """Create all layers; ``size`` is the (height, width) of the output map."""
        super(Deeplab, self).__init__()
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6)
        self.fc7_1 = nn.Conv2d(1024, 1024, 1)
        self.fc8_1 = nn.Conv2d(1024, 12, 1)

        self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12)
        self.fc7_2 = nn.Conv2d(1024, 1024, 1)
        self.fc8_2 = nn.Conv2d(1024, 12, 1)

        self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18)
        self.fc7_3 = nn.Conv2d(1024, 1024, 1)
        self.fc8_3 = nn.Conv2d(1024, 12, 1)

        self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24)
        self.fc7_4 = nn.Conv2d(1024, 1024, 1)
        self.fc8_4 = nn.Conv2d(1024, 12, 1)

        self.dropout = nn.Dropout2d(0.5)
        self.relu = nn.ReLU(inplace=True)
        self.fc8_interp = nn.Upsample(size=size, mode='bilinear')

    def weights_init(self, pretrained_dict=None):
        """Re-initialise the four ``fc8`` heads, then load matching pretrained weights.

        Head weights are drawn from N(0, 0.01) and head biases set to zero.
        Entries of ``pretrained_dict`` whose key exists in this module's
        state dict are then copied in (including over the freshly
        initialised heads, if present); unknown keys are ignored.
        """
        # In-place tensor methods, as the module-level weights_init in this
        # file does; avoids the deprecated ``init.normal``/``init.constant``.
        for head in (self.fc8_1, self.fc8_2, self.fc8_3, self.fc8_4):
            head.weight.data.normal_(0, 0.01)
            head.bias.data.fill_(0)

        model_dict = self.state_dict()
        if pretrained_dict:
            model_dict.update({k: v for k, v in pretrained_dict.items() if k in model_dict})
        self.load_state_dict(model_dict)

    def _branch(self, x, fc6, fc7, fc8):
        """Apply one ``fc6 -> fc7 -> fc8`` head with ReLU and dropout after fc6/fc7."""
        x = self.dropout(self.relu(fc6(x)))
        x = self.dropout(self.relu(fc7(x)))
        return fc8(x)

    def forward(self, x):
        """Return the summed head scores resized to the configured output size."""
        x = self.relu(self.conv1_1(x))
        x = self.pool1(self.relu(self.conv1_2(x)))
        x = self.relu(self.conv2_1(x))
        x = self.pool2(self.relu(self.conv2_2(x)))
        x = self.relu(self.conv3_1(x))
        x = self.relu(self.conv3_2(x))
        x = self.pool3(self.relu(self.conv3_3(x)))
        x = self.relu(self.conv4_1(x))
        x = self.relu(self.conv4_2(x))
        x = self.pool4(self.relu(self.conv4_3(x)))
        x = self.relu(self.conv5_1(x))
        x = self.relu(self.conv5_2(x))
        x = self.pool5(self.relu(self.conv5_3(x)))

        # ``self.dropout`` is already a Dropout2d(0.5) module and is applied
        # directly; the previous ``self.dropout(0.5)(...)`` passed a float
        # to the module as if it were the input tensor.
        x1 = self._branch(x, self.fc6_1, self.fc7_1, self.fc8_1)
        x2 = self._branch(x, self.fc6_2, self.fc7_2, self.fc8_2)
        x3 = self._branch(x, self.fc6_3, self.fc7_3, self.fc8_3)
        x4 = self._branch(x, self.fc6_4, self.fc7_4, self.fc8_4)
        return self.fc8_interp(x1 + x2 + x3 + x4)
nn.Upsample(size=size, mode='bilinear') 195 | 196 | def weights_init(self, pretrained_dict={}): 197 | model_dict = self.state_dict() 198 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 199 | model_dict.update(pretrained_dict) 200 | self.load_state_dict(model_dict) 201 | 202 | def forward(self, x): 203 | x = self.relu(self.conv2_1(x)) 204 | x = self.pool2(self.relu(self.conv2_2(x))) 205 | x = self.relu(self.conv3_1(x)) 206 | x = self.relu(self.conv3_2(x)) 207 | x = self.pool3(self.relu(self.conv3_3(x))) 208 | x = self.relu(self.conv4_1(x)) 209 | x = self.relu(self.conv4_2(x)) 210 | x = self.pool4(self.relu(self.conv4_3(x))) 211 | x = self.relu(self.conv5_1(x)) 212 | 213 | return x 214 | 215 | class DeeplabConv5_22Fc8_interp(nn.Module): 216 | def __init__(self, size=(241,121)): 217 | super(DeeplabConv5_22Fc8_interp, self).__init__() 218 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2) 219 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2) 220 | self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) 221 | 222 | self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6) 223 | self.fc7_1 = nn.Conv2d(1024, 1024, 1) 224 | self.fc8_1 = nn.Conv2d(1024, 12, 1) 225 | 226 | self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12) 227 | self.fc7_2 = nn.Conv2d(1024, 1024, 1) 228 | self.fc8_2 = nn.Conv2d(1024, 12, 1) 229 | 230 | self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18) 231 | self.fc7_3 = nn.Conv2d(1024, 1024, 1) 232 | self.fc8_3 = nn.Conv2d(1024, 12, 1) 233 | 234 | self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24) 235 | self.fc7_4 = nn.Conv2d(1024, 1024, 1) 236 | self.fc8_4 = nn.Conv2d(1024, 12, 1) 237 | self.dropout = nn.Dropout2d(0.5) 238 | self.relu = nn.ReLU(inplace=True) 239 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear') 240 | 241 | def weights_init(self, pretrained_dict={}): 242 | init.normal(self.fc8_1.weight.data, mean=0, std=0.01) 243 | 
class DeeplabPool12Pool5(nn.Module):
    """Layers conv2_1 through pool5.

    Takes a 64-channel feature map and produces 512 channels. ``pool2`` and
    ``pool3`` use stride 2; ``pool4`` and ``pool5`` use stride 1; the three
    conv5 layers are dilated (dilation=2).
    """

    def __init__(self, size=(241, 121)):
        """Build the layers.

        Args:
            size: Target size handed to ``fc8_interp``. ``forward`` does not
                call ``fc8_interp``; the attribute is kept so the module
                structure is unchanged.
        """
        super(DeeplabPool12Pool5, self).__init__()
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()

        self.fc8_interp = nn.Upsample(size=size, mode='bilinear')

    def weights_init(self, pretrained_dict=None):
        """Load the entries of ``pretrained_dict`` whose keys exist in this module.

        Args:
            pretrained_dict: Mapping from state-dict key to tensor. Keys this
                module does not have are ignored. ``None`` (the default) or an
                empty mapping leaves the current weights as they are.
        """
        model_dict = self.state_dict()
        if pretrained_dict:
            model_dict.update(
                {k: v for k, v in pretrained_dict.items() if k in model_dict})
        self.load_state_dict(model_dict)

    def forward(self, x):
        """Return the 512-channel output of pool5 for ``x``."""
        x = self.relu(self.conv2_1(x))
        x = self.pool2(self.relu(self.conv2_2(x)))
        x = self.relu(self.conv3_1(x))
        x = self.relu(self.conv3_2(x))
        x = self.pool3(self.relu(self.conv3_3(x)))
        x = self.relu(self.conv4_1(x))
        x = self.relu(self.conv4_2(x))
        x = self.pool4(self.relu(self.conv4_3(x)))
        x = self.relu(self.conv5_1(x))
        x = self.relu(self.conv5_2(x))
        x = self.pool5(self.relu(self.conv5_3(x)))
        return x
class DeeplabPool52Fc8_interp(nn.Module):
    """Classification head: four dilated fc6/fc7/fc8 branches plus resizing.

    Each branch is a 3x3 convolution with dilation 6, 12, 18 or 24 followed by
    two 1x1 convolutions that end in ``output_nc`` channels. The four branch
    outputs are summed and resized to ``size``.
    """

    def __init__(self, output_nc, size=(241, 121)):
        """Build the layers.

        Args:
            output_nc: Number of output channels of every ``fc8_*`` layer.
            size: Output size passed to the final ``nn.Upsample``.
        """
        super(DeeplabPool52Fc8_interp, self).__init__()

        self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6)
        self.fc7_1 = nn.Conv2d(1024, 1024, 1)
        self.fc8_1 = nn.Conv2d(1024, output_nc, 1)

        self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12)
        self.fc7_2 = nn.Conv2d(1024, 1024, 1)
        self.fc8_2 = nn.Conv2d(1024, output_nc, 1)

        self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18)
        self.fc7_3 = nn.Conv2d(1024, 1024, 1)
        self.fc8_3 = nn.Conv2d(1024, output_nc, 1)

        self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24)
        self.fc7_4 = nn.Conv2d(1024, 1024, 1)
        self.fc8_4 = nn.Conv2d(1024, output_nc, 1)

        self.dropout = nn.Dropout2d(0.5)
        self.relu = nn.ReLU(inplace=True)
        self.fc8_interp = nn.Upsample(size=size, mode='bilinear')

    def weights_init(self, pretrained_dict=None):
        """Re-initialise the fc8 heads, then load matching pretrained entries.

        The four ``fc8_*`` layers get N(0, 0.01) weights and zero biases.
        Afterwards every entry of ``pretrained_dict`` whose key exists in this
        module is loaded, so a pretrained ``fc8_*`` entry overrides the random
        initialisation.

        Args:
            pretrained_dict: Mapping from state-dict key to tensor; unknown
                keys are ignored. ``None`` (the default) loads nothing.
        """
        # Use the in-place initialisers; the un-suffixed init.normal /
        # init.constant spellings are deprecated aliases.
        for head in (self.fc8_1, self.fc8_2, self.fc8_3, self.fc8_4):
            init.normal_(head.weight.data, mean=0, std=0.01)
            init.constant_(head.bias.data, 0)

        model_dict = self.state_dict()
        if pretrained_dict:
            model_dict.update(
                {k: v for k, v in pretrained_dict.items() if k in model_dict})
        self.load_state_dict(model_dict)

    def _branch(self, x, fc6, fc7, fc8):
        """Run one branch: fc6 -> ReLU -> dropout -> fc7 -> ReLU -> dropout -> fc8."""
        x = self.dropout(self.relu(fc6(x)))
        x = self.dropout(self.relu(fc7(x)))
        return fc8(x)

    def forward(self, x):
        """Return the summed branch outputs resized to ``size``."""
        # Branches are evaluated in order 1..4, as before, so dropout consumes
        # random numbers in the same sequence.
        x1 = self._branch(x, self.fc6_1, self.fc7_1, self.fc8_1)
        x2 = self._branch(x, self.fc6_2, self.fc7_2, self.fc8_2)
        x3 = self._branch(x, self.fc6_3, self.fc7_3, self.fc8_3)
        x4 = self._branch(x, self.fc6_4, self.fc7_4, self.fc8_4)
        return self.fc8_interp(x1 + x2 + x3 + x4)
class netG(nn.Module):
    """Residual generator operating on 64-channel feature maps.

    A 7x7 convolution (64 -> 128 channels) with batch norm and ReLU is followed
    by ``n_blocks`` residual blocks. Before every third block a stride-1
    max-pool and a stride-2 3x3 convolution are inserted, and that convolution
    doubles the channel count.
    """

    def __init__(self, n_blocks=6):
        """Assemble the layer stack into ``self.model``.

        Args:
            n_blocks: Number of ``ResnetBlock`` modules to append.
        """
        super(netG, self).__init__()
        width = 128  # channel count, doubled at each down-sampling step

        layers = [
            nn.Conv2d(64, width, kernel_size=7, padding=3),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
        ]

        for block_no in range(1, n_blocks + 1):
            if block_no % 3 == 0:
                layers.append(nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
                layers.append(
                    nn.Conv2d(width, width * 2, kernel_size=3, stride=2, padding=1))
                width *= 2
            layers.append(
                ResnetBlock(width, padding_type='reflect',
                            norm_layer=nn.BatchNorm2d, use_dropout=0))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the assembled layer stack to ``x``."""
        return self.model(x)
class netG_structure(nn.Module):
    """Residual network that maps feature maps to resized ``output_nc`` maps.

    Layout: 7x7 convolution to 128 channels with batch norm and ReLU,
    ``n_blocks`` residual blocks, a 3x3 convolution to ``output_nc`` channels
    and a bilinear ``nn.Upsample`` to ``size``.
    """

    def __init__(self, input_nc=512, output_nc=12, n_blocks=3, size=(241, 121)):
        """Assemble the layer stack into ``self.model``.

        Args:
            input_nc: Number of input channels.
            output_nc: Number of output channels.
            n_blocks: Number of ``ResnetBlock`` modules.
            size: Output size passed to the final ``nn.Upsample``.
        """
        super(netG_structure, self).__init__()
        ngf = 128
        norm_layer = nn.BatchNorm2d
        padding_type = 'reflect'
        use_dropout = 0

        model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        for _ in range(n_blocks):
            model.append(ResnetBlock(ngf, padding_type=padding_type,
                                     norm_layer=norm_layer,
                                     use_dropout=use_dropout))

        model += [nn.Conv2d(ngf, output_nc, kernel_size=3, padding=1),
                  nn.Upsample(size=size, mode='bilinear')]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the assembled layer stack to ``x``."""
        return self.model(x)


# Define a resnet block
class ResnetBlock(nn.Module):
    """Residual block: ``x + conv_block(x)`` with two padded 3x3 convolutions."""

    def __init__(self, dim, padding_type, norm_layer, use_dropout):
        """Build ``self.conv_block``.

        Args:
            dim: Channel count of both the input and the output.
            padding_type: ``'reflect'``, ``'replicate'`` or ``'zero'``.
            norm_layer: Callable taking a channel count and returning a
                normalisation module (for example ``nn.BatchNorm2d``).
            use_dropout: Truthy to insert ``nn.Dropout(0.5)`` after the first
                ReLU.

        Raises:
            NotImplementedError: If ``padding_type`` is not supported.
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)

    @staticmethod
    def _padded_conv(dim, padding_type):
        """Return the layers of one 3x3 convolution that keeps the spatial size."""
        if padding_type == 'reflect':
            return [nn.ReflectionPad2d(1),
                    nn.Conv2d(dim, dim, kernel_size=3, padding=0)]
        if padding_type == 'replicate':
            return [nn.ReplicationPad2d(1),
                    nn.Conv2d(dim, dim, kernel_size=3, padding=0)]
        if padding_type == 'zero':
            return [nn.Conv2d(dim, dim, kernel_size=3, padding=1)]
        raise NotImplementedError('padding [%s] is not implemented' % padding_type)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
        """Return the residual branch as an ``nn.Sequential``.

        Layer order: [pad] conv norm ReLU [dropout] [pad] conv norm. The
        explicit pad layers are present only for 'reflect' and 'replicate'.
        """
        conv_block = self._padded_conv(dim, padding_type)
        conv_block += [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        conv_block += self._padded_conv(dim, padding_type)
        conv_block += [norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        out = x + self.conv_block(x)
        return out
class MultPathdilationNet(nn.Module):
    """Four parallel convolutional paths whose single-channel outputs are averaged.

    Every path is: dilated 3x3 convolution, instance norm, ReLU, dropout,
    stride-2 3x3 convolution, instance norm, ReLU, dropout, and a stride-2 3x3
    convolution to one channel.
    """

    def __init__(self, input_nc=512, hidden_nc=1024):
        """Build the four paths ``model_1`` .. ``model_4``.

        Args:
            input_nc: Number of input channels (default 512, the value that
                used to be hard-coded).
            hidden_nc: Channel width inside every path (default 1024, the
                value that used to be hard-coded).
        """
        super(MultPathdilationNet, self).__init__()
        # One in-place ReLU module is shared by all paths, as before.
        self.relu = nn.ReLU(inplace=True)

        # NOTE(review): all four paths use dilation 6, whereas the sibling
        # RandomMultPathdilationNet uses 6/12/18/24 - confirm this is intended.
        self.model_1 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_2 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_3 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_4 = self._build_path(input_nc, hidden_nc, dilation=6)

    def _build_path(self, input_nc, hidden_nc, dilation):
        """Return one path; layer indices 0..8 match the original layout."""
        norm_layer = nn.InstanceNorm2d
        return nn.Sequential(
            nn.Conv2d(input_nc, hidden_nc, 3, padding=dilation, dilation=dilation),
            norm_layer(hidden_nc), self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, hidden_nc, 3, stride=2, padding=1),
            norm_layer(hidden_nc), self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, 1, 3, stride=2, padding=1),
        )

    def forward(self, x):
        """Return the mean of the four path outputs."""
        return (self.model_1(x) + self.model_2(x) + self.model_3(x) + self.model_4(x)) / 4
class RandomMultPathdilationNet(nn.Module):
    """Four convolutional paths (dilation 6/12/18/24); one is picked per call.

    Every path is: dilated 3x3 convolution, instance norm, ReLU, dropout,
    stride-2 3x3 convolution, instance norm, ReLU, dropout, and a stride-2 3x3
    convolution to one channel. ``forward`` draws one uniform random number
    and runs exactly one of the paths.
    """

    def __init__(self, input_nc=512, hidden_nc=1024):
        """Build the four paths ``model_1`` .. ``model_4``.

        Args:
            input_nc: Number of input channels (default 512, the value that
                used to be hard-coded).
            hidden_nc: Channel width inside every path (default 1024, the
                value that used to be hard-coded).
        """
        super(RandomMultPathdilationNet, self).__init__()
        # One in-place ReLU module is shared by all paths, as before.
        self.relu = nn.ReLU(inplace=True)

        self.model_1 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_2 = self._build_path(input_nc, hidden_nc, dilation=12)
        self.model_3 = self._build_path(input_nc, hidden_nc, dilation=18)
        self.model_4 = self._build_path(input_nc, hidden_nc, dilation=24)

    def _build_path(self, input_nc, hidden_nc, dilation):
        """Return one path; layer indices 0..8 match the original layout."""
        norm_layer = nn.InstanceNorm2d
        return nn.Sequential(
            nn.Conv2d(input_nc, hidden_nc, 3, padding=dilation, dilation=dilation),
            norm_layer(hidden_nc), self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, hidden_nc, 3, stride=2, padding=1),
            norm_layer(hidden_nc), self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, 1, 3, stride=2, padding=1),
        )

    def forward(self, x):
        """Run ``x`` through one path chosen with equal probability.

        The choice uses Python's ``random`` module (one ``random.uniform``
        call per forward pass), so it is controlled by ``random.seed``.
        """
        which_D = random.uniform(0, 1)
        if which_D < 0.25:
            return self.model_1(x)
        elif which_D < 0.5:
            return self.model_2(x)
        elif which_D < 0.75:
            return self.model_3(x)
        else:
            return self.model_4(x)
class NoBNMultPathdilationNet(nn.Module):
    """Four parallel paths without normalisation; their outputs are averaged.

    Every path is: dilated 3x3 convolution, ReLU, dropout, stride-2 3x3
    convolution, ReLU, dropout, and a stride-2 3x3 convolution to one channel.
    """

    def __init__(self, input_nc=512, hidden_nc=1024):
        """Build the four paths ``model_1`` .. ``model_4``.

        Args:
            input_nc: Number of input channels (default 512, the value that
                used to be hard-coded).
            hidden_nc: Channel width inside every path (default 1024, the
                value that used to be hard-coded).
        """
        super(NoBNMultPathdilationNet, self).__init__()
        # One in-place ReLU module is shared by all paths, as before.
        self.relu = nn.ReLU(inplace=True)

        # NOTE(review): all four paths use dilation 6, whereas the sibling
        # RandomMultPathdilationNet uses 6/12/18/24 - confirm this is intended.
        self.model_1 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_2 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_3 = self._build_path(input_nc, hidden_nc, dilation=6)
        self.model_4 = self._build_path(input_nc, hidden_nc, dilation=6)

    def _build_path(self, input_nc, hidden_nc, dilation):
        """Return one path; layer indices 0..6 match the original layout."""
        return nn.Sequential(
            nn.Conv2d(input_nc, hidden_nc, 3, padding=dilation, dilation=dilation),
            self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, hidden_nc, 3, stride=2, padding=1),
            self.relu, nn.Dropout2d(0.5),
            nn.Conv2d(hidden_nc, 1, 3, stride=2, padding=1),
        )

    def forward(self, x):
        """Return the mean of the four path outputs."""
        return (self.model_1(x) + self.model_2(x) + self.model_3(x) + self.model_4(x)) / 4


class FFCFeature(nn.Module):
    """Fully connected classifier that ends in a sigmoid.

    The input is flattened per sample and passed through three linear layers
    (``in_features`` -> 512 -> 1024 -> 1) with dropout before each of them and
    batch norm plus ReLU after the first two.
    """

    def __init__(self, in_features=512 * 31 * 16):
        """Build ``self.classifier``.

        Args:
            in_features: Number of values per sample after flattening. The
                default, 512 * 31 * 16, is the value that used to be
                hard-coded.
        """
        super(FFCFeature, self).__init__()
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Flatten ``x`` to (batch, -1) and return one sigmoid value per sample."""
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
class SinglePathdilationSingleOutputNet(nn.Module):
    """Single convolutional path followed by a linear layer and a sigmoid.

    The path is: dilated 3x3 convolution (512 -> 1024), dropout, stride-2 3x3
    convolution (1024 -> 1024), dropout, stride-2 3x3 convolution
    (1024 -> 256). The result is flattened and fed to
    ``nn.Linear(256 * 8 * 4, 1)``, so the path output must be 256 x 8 x 4 per
    sample.
    """

    def __init__(self):
        super(SinglePathdilationSingleOutputNet, self).__init__()
        model = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), nn.Dropout2d(0.5),
                 nn.Conv2d(1024, 1024, 3, stride=2, padding=1),
                 nn.Dropout2d(0.5), nn.Conv2d(1024, 256, 3, stride=2, padding=1)]
        self.model = nn.Sequential(*model)
        self.linear = nn.Linear(256 * 8 * 4, 1)

    def forward(self, x):
        """Return one sigmoid value per sample, shape (batch, 1)."""
        x = self.model(x)
        x = x.view(x.size(0), -1)
        # Tensor.sigmoid() avoids building a new nn.Sigmoid module per call.
        return self.linear(x).sigmoid()


class SinglePathdilationMultOutputNet(nn.Module):
    """Single convolutional path with batch norm producing a one-channel map.

    Layout: dilated 3x3 convolution, batch norm, ReLU, dropout, stride-2 3x3
    convolution, batch norm, ReLU, dropout, stride-2 3x3 convolution to one
    channel.
    """

    def __init__(self, input_nc=512):
        """Build ``self.model``.

        Args:
            input_nc: Number of input channels (default 512, the value that
                used to be hard-coded).
        """
        super(SinglePathdilationMultOutputNet, self).__init__()
        norm_layer = nn.BatchNorm2d

        model = [nn.Conv2d(input_nc, 1024, 3, padding=6, dilation=6),
                 norm_layer(1024), nn.ReLU(inplace=True), nn.Dropout2d(0.5),
                 nn.Conv2d(1024, 1024, 3, stride=2, padding=1),
                 norm_layer(1024), nn.ReLU(inplace=True), nn.Dropout2d(0.5),
                 nn.Conv2d(1024, 1, 3, stride=2, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the path to ``x``."""
        return self.model(x)


class NoBNSinglePathdilationMultOutputNet(nn.Module):
    """Single convolutional path without normalisation; one-channel output map.

    Layout: dilated 3x3 convolution, ReLU, dropout, stride-2 3x3 convolution,
    ReLU, dropout, stride-2 3x3 convolution to one channel.
    """

    def __init__(self, input_nc=512):
        """Build ``self.model``.

        Args:
            input_nc: Number of input channels.
        """
        super(NoBNSinglePathdilationMultOutputNet, self).__init__()

        model = [nn.Conv2d(input_nc, 1024, 3, padding=6, dilation=6),
                 nn.ReLU(inplace=True), nn.Dropout2d(0.5),
                 nn.Conv2d(1024, 1024, 3, stride=2, padding=1),
                 nn.ReLU(inplace=True), nn.Dropout2d(0.5),
                 nn.Conv2d(1024, 1, 3, stride=2, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the path to ``x``."""
        return self.model(x)
class dcgan_D(nn.Module):
    """Strided-convolution discriminator that outputs one sigmoid value per sample.

    ``n_layers`` stride-2 4x4 convolutions (each followed by ``norm_layer``
    and ReLU) double the channel count from ``ngf`` upwards; a final 4x4
    convolution maps back to ``ngf`` channels. The result is flattened and
    passed through a linear layer and a sigmoid.
    """

    def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
        """Build the convolutional stack.

        Args:
            input_nc: Number of input channels.
            ngf: Channel count after the first convolution.
            norm_layer: Callable taking a channel count and returning a
                normalisation module.
            n_layers: Number of stride-2 convolutions.
        """
        # Fixed: this used to be super(input_nc, self), which raises TypeError.
        super(dcgan_D, self).__init__()
        self.input_nc = input_nc
        self.ngf = ngf
        self.norm_layer = norm_layer
        self.n_layers = n_layers
        self.padding_type = 'reflect'

        mult = 1
        model = [nn.Conv2d(input_nc, ngf, 4, stride=2, padding=1),
                 norm_layer(ngf), nn.ReLU(inplace=True)]
        # Fixed: the loop used self.layers, an attribute that is never set.
        for _ in range(self.n_layers - 1):
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, 4, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.ReLU(inplace=True)]
            mult *= 2

        model += [nn.Conv2d(ngf * mult, ngf, 4)]
        self.model = nn.Sequential(*model)

        # The flattened feature size depends on the input resolution, so the
        # linear head is created on the first forward pass and then reused.
        self.fc = None

    def forward(self, x):
        """Return sigmoid scores of shape (batch, 1).

        The linear head is built lazily on the first call and registered as
        ``self.fc``. Previously a new, randomly initialised ``nn.Linear`` was
        created on every call, so it could never be trained. Run one forward
        pass before creating an optimizer or loading a state dict that
        contains ``fc.*`` entries.
        """
        x = self.model(x)
        x = x.view(x.size(0), -1)
        if self.fc is None:
            self.fc = nn.Linear(x.size(1), 1).to(device=x.device, dtype=x.dtype)
        return self.fc(x).sigmoid()


class dcgan_D_multOut(nn.Module):
    """Strided-convolution discriminator that outputs a one-channel score map.

    Same convolutional stack as ``dcgan_D`` except that the final 4x4
    convolution maps to a single channel and no linear layer or sigmoid
    follows.
    """

    def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
        """Build the convolutional stack.

        Args:
            input_nc: Number of input channels.
            ngf: Channel count after the first convolution.
            norm_layer: Callable taking a channel count and returning a
                normalisation module.
            n_layers: Number of stride-2 convolutions.
        """
        super(dcgan_D_multOut, self).__init__()
        self.input_nc = input_nc
        self.ngf = ngf
        self.norm_layer = norm_layer
        self.n_layers = n_layers
        self.padding_type = 'reflect'

        mult = 1
        model = [nn.Conv2d(input_nc, ngf, 4, stride=2, padding=1),
                 norm_layer(ngf), nn.ReLU(inplace=True)]
        for _ in range(self.n_layers - 1):
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, 4, stride=2, padding=1),
                      norm_layer(ngf * mult * 2), nn.ReLU(inplace=True)]
            mult *= 2

        model += [nn.Conv2d(ngf * mult, 1, 4)]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Apply the convolutional stack to ``x``."""
        return self.model(x)


class lsgan_D(nn.Module):
    """Discriminator with a linear (no sigmoid) output, one value per sample.

    ``n_layers`` stride-2 5x5 convolutions with LeakyReLU(0.2) are followed by
    a linear layer. The first convolution has no normalisation; the later ones
    are followed by ``norm_layer``.
    """

    def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
        """Build ``self.features`` and ``self.fc``.

        Args:
            input_nc: Number of input channels.
            ngf: Channel count after the first convolution.
            norm_layer: Callable taking a channel count and returning a
                normalisation module.
            n_layers: Number of stride-2 convolutions.
        """
        super(lsgan_D, self).__init__()
        self.input_nc = input_nc
        self.ngf = ngf
        self.norm_layer = norm_layer
        self.n_layers = n_layers

        mult = 1
        features = [nn.Conv2d(input_nc, ngf, 5, stride=2, padding=1),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for _ in range(self.n_layers - 1):
            features += [nn.Conv2d(ngf * mult, ngf * mult * 2, 5, stride=2, padding=1),
                         norm_layer(ngf * mult * 2),
                         nn.LeakyReLU(negative_slope=0.2, inplace=True)]
            mult *= 2

        self.features = nn.Sequential(*features)

        # The head expects a 14 x 6 feature map. Its channel count follows the
        # last convolution (ngf * mult); this equals the previously hard-coded
        # 512 for the defaults ngf=64, n_layers=4.
        self.fc = nn.Sequential(nn.Linear(ngf * mult * 14 * 6, 1))

    def forward(self, x):
        """Return raw scores of shape (batch, 1)."""
        # Call the module rather than .forward() so registered hooks run.
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
class lsganMultOutput_D(nn.Module):
    """Fully convolutional discriminator that outputs a one-channel score map.

    ``n_layers`` stride-2 5x5 convolutions with LeakyReLU(0.2) are followed by
    an unpadded 5x5 convolution to one channel. The first convolution has no
    normalisation; the later stride-2 ones are followed by ``norm_layer``.
    """

    def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
        """Build ``self.features``.

        Args:
            input_nc: Number of input channels.
            ngf: Channel count after the first convolution.
            norm_layer: Callable taking a channel count and returning a
                normalisation module.
            n_layers: Number of stride-2 convolutions.
        """
        super(lsganMultOutput_D, self).__init__()
        self.input_nc = input_nc
        self.ngf = ngf
        self.norm_layer = norm_layer
        self.n_layers = n_layers

        mult = 1
        features = [nn.Conv2d(input_nc, ngf, 5, stride=2, padding=1),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True)]
        for _ in range(self.n_layers - 1):
            features += [nn.Conv2d(ngf * mult, ngf * mult * 2, 5, stride=2, padding=1),
                         norm_layer(ngf * mult * 2),
                         nn.LeakyReLU(negative_slope=0.2, inplace=True)]
            mult *= 2

        features += [nn.Conv2d(ngf * mult, 1, 5)]
        self.features = nn.Sequential(*features)

    def forward(self, x):
        """Return the score map for ``x``."""
        # Call the module rather than .forward() so hooks registered on
        # self.features are executed.
        return self.features(x)