├── loader
│   ├── __init__.py
│   ├── label_loader.py
│   ├── image_label_loader.py
│   ├── onehot_label_loader.py
│   └── image_loader.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── flgan.py
│   └── networks.py
├── util
│   ├── __init__.py
│   ├── makedirs.py
│   ├── log.py
│   └── confusion_matrix.py
├── .gitignore
├── README.md
├── test
│   ├── test_g1.py
│   ├── test_g2.py
│   └── test.py
└── train.py
/loader/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | checkpoints
2 | datasets
3 | *.pth
4 | *.log
5 | *.pyc
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cross-domain-Human-Parsing-via-Adversarial-Feature-and-Label-Adaptation
2 | This project is not complete yet; this is not the final version.
3 |
--------------------------------------------------------------------------------
/util/makedirs.py:
--------------------------------------------------------------------------------
1 | import os
2 | def mkdirs(paths):
3 | if isinstance(paths, list) and not isinstance(paths, str):
4 | for path in paths:
5 | mkdir(path)
6 | else:
7 | mkdir(paths)
8 |
9 |
10 | def mkdir(path):
11 | if not os.path.exists(path):
12 | os.makedirs(path)
13 |
--------------------------------------------------------------------------------
/util/log.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | class Logger:
5 |
6 | def __init__(self, log_file='log/log.txt', formatter='%(asctime)s\t%(message)s', user='rgh'):
7 | self.user = user
8 | self.log_file = log_file
9 | self.formatter = formatter
10 | self.logger = self.init_logger()
11 |
12 | def init_logger(self):
13 | # create logger with name
14 | # if not specified, it will be root
15 | logger = logging.getLogger(self.user)
16 | logger.setLevel(logging.DEBUG)
17 |
18 | # create a handler, write to log.txt
19 | # logging.FileHandler(self, filename, mode='a', encoding=None, delay=0)
20 | # A handler class which writes formatted logging records to disk files.
21 | fh = logging.FileHandler(self.log_file)
22 | fh.setLevel(logging.DEBUG)
23 |
24 | # create another handler, for stdout in terminal
25 | # A handler class which writes logging records to a stream
26 | sh = logging.StreamHandler()
27 | sh.setLevel(logging.DEBUG)
28 |
29 | # set formatter
30 | # formatter = logging.Formatter('%(asctime)s-%(name)s-%(levelname)s- %(message)s')
31 | formatter = logging.Formatter(self.formatter)
32 | fh.setFormatter(formatter)
33 | sh.setFormatter(formatter)
34 |
35 | # add handler to logger
36 | logger.addHandler(fh)
37 | logger.addHandler(sh)
38 | return logger
39 |
40 | def info(self,message=''):
41 | self.logger.info(message)
42 |
43 | def debug(self,message=''):
44 | self.logger.debug(message)
45 |
--------------------------------------------------------------------------------
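A minimal usage sketch for the Logger above; the log file path is illustrative, and the target directory must already exist (train.py creates its 'log' directory before constructing a Logger):

from util.log import Logger

logger = Logger(log_file='log/example.log')  # hypothetical path; writes to the file and to stdout
logger.info('training started')
logger.debug('batch_size=10')

Since logging.getLogger(self.user) returns the same underlying logger on every call, constructing two Logger objects with the same user attaches duplicate handlers and doubles every message; one instance per process is the safe pattern.
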
/loader/label_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import torch
5 | import torchvision
6 | import numpy as np
7 | import scipy.misc as m
8 | import scipy.io as io
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils import data
13 |
14 | class labelLoader(data.Dataset):
15 | def __init__(self, root, dataName, phase='train', lbl_size=(241,121)):
16 | self.root = root
17 | self.dataName = dataName
18 | self.phase = phase
19 | self.files = collections.defaultdict(list)
20 | self.now_idx = 0
21 | file_list = tuple(open(root + '/' + dataName + '/' + self.phase + '.txt', 'r'))
22 | file_list = [id_.rstrip() for id_ in file_list]
23 | self.files[self.phase] = file_list
24 | # print self.files['train']
25 | def __len__(self):
26 | return len(self.files[self.phase])
27 |
28 | def __getitem__(self, index):
29 | lbl_name = self.files[self.phase][index]
30 | lbl_path = self.root + '/' + self.dataName + '/' + 'label/' + self.phase + '/' + lbl_name + '.npy'
31 |
32 | lbl = np.load(lbl_path)
33 | lbl = lbl.copy()
34 | lbl = torch.from_numpy(lbl).float()
35 |
36 | return lbl
37 |
38 | def getBatch(self, index):
39 | self.now_idx = (self.now_idx + 1) % len(self)
40 | lbl_name = self.files[self.phase][index]
41 | lbl_path = self.root + '/' + self.dataName + '/' + 'label/' + self.phase + '/' + lbl_name + '.npy'
42 |
43 | lbl = np.load(lbl_path)
44 | lbl = lbl.copy()
45 | lbl = torch.from_numpy(lbl).float()
46 |
47 | return lbl
48 |
--------------------------------------------------------------------------------
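labelLoader reads label tensors saved as .npy files (train.py writes one-hot arrays under label/train_onehot/; a matching <phase>.txt name list is assumed to exist alongside the label folders). A sketch of how train.py wraps it in a DataLoader; the dataset name and sizes follow the args dicts used elsewhere in this repo:

from torch.utils import data
from loader.label_loader import labelLoader

label_set = labelLoader('datasets', dataName='Lip', phase='train_onehot')
label_train_loader = data.DataLoader(label_set, batch_size=10,
                                     num_workers=10, shuffle=True)
label_onehot = next(iter(label_train_loader))  # one FloatTensor batch, for illustration
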
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | class BaseModel():
6 | def name(self):
7 | return 'BaseModel'
8 |
9 | def initialize(self, opt):
10 | self.opt = opt
11 | self.gpu_ids = opt['device_ids']
12 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
13 | self.save_dir = os.path.join(opt['checkpoints_dir'], opt['name'])
14 |
15 | def set_input(self, input):
16 | self.input = input
17 |
18 | def forward(self):
19 | pass
20 |
21 | # used in test time, no backprop
22 | def test(self):
23 | pass
24 |
25 | def get_image_paths(self):
26 | pass
27 |
28 | def optimize_parameters(self):
29 | pass
30 |
31 | def get_current_visuals(self):
32 | return self.input
33 |
34 | def get_current_errors(self):
35 | return {}
36 |
37 | def save(self, label):
38 | pass
39 |
40 | # helper saving function that can be used by subclasses
41 | def save_network(self, network, network_label, epoch_label, gpu_ids):
42 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
43 | save_path = os.path.join(self.save_dir, save_filename)
44 | torch.save(network.cpu().state_dict(), save_path)
45 | if len(gpu_ids) and torch.cuda.is_available():
46 | network.cuda(device_id=gpu_ids[0])
47 |
48 | # helper loading function that can be used by subclasses
49 | def load_network(self, network, network_label, epoch_label):
50 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
51 | save_path = os.path.join(self.save_dir, save_filename)
52 | network.load_state_dict(torch.load(save_path))
53 |
54 | def update_learning_rate(self):
55 | pass
56 |
--------------------------------------------------------------------------------
/loader/image_label_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import torch
5 | import torchvision
6 | import numpy as np
7 | import scipy.misc as m
8 | import scipy.io as io
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils import data
13 |
14 | class imageLabelLoader(data.Dataset):
15 | def __init__(self, root, dataName, phase='train', img_size=(241,121)):
16 | self.root = root
17 | self.dataName = dataName
18 | self.phase = phase
19 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
20 | self.mean = np.array([128, 128, 128])
21 | self.files = collections.defaultdict(list)
22 | """
23 | for phase in ['train', 'val', 'train+5light']:
24 | file_list = tuple(open(root +'/' + dataName +'/'+ phase + '.txt', 'r'))
25 | file_list = [id_.rstrip() for id_ in file_list]
26 | self.files[phase] = file_list
27 | """
28 | file_list = tuple(open(root + '/' + dataName + '/' + self.phase + '.txt', 'r'))
29 | file_list = [id_.rstrip() for id_ in file_list]
30 | self.files[self.phase] = file_list
31 | # print self.files['train']
32 | def __len__(self):
33 | return len(self.files[self.phase])
34 |
35 | def __getitem__(self, index):
36 | img_name = self.files[self.phase][index]
37 | img_path = self.root + '/' + self.dataName + '/' + 'image/' + self.phase+'/'+img_name + '.jpg'
38 | lbl_path = self.root+ '/' + self.dataName + '/' + 'label/' + self.phase+'/'+img_name + '.png'
39 |
40 | img = Image.open(img_path)
41 | #if img.shape
42 | img_size = img.size
43 | if self.img_size[1] != img_size[0] or self.img_size[0] != img_size[1]:
44 | img = img.resize((self.img_size[1], self.img_size[0]), resample=Image.BILINEAR)
45 | img = np.array(img, dtype=np.float32)
46 | if not (len(img.shape) == 3 and img.shape[2] == 3):
47 | img = img.reshape(img.shape[0], img.shape[1], 1)
48 | img = img.repeat(3, 2)
49 | img -= self.mean
50 | img = img[:, :, ::-1]# RGB -> BGR
51 | img = img.transpose(2, 0, 1)
52 | img = img.copy()
53 |
54 | lbl = Image.open(lbl_path)
55 | lbl_size = lbl.size
56 |
57 | if self.img_size[1] != lbl_size[0] or self.img_size[0] != lbl_size[1]:
58 | lbl = lbl.resize((self.img_size[1], self.img_size[0]), resample=Image.NEAREST)  # NEAREST: bilinear would blend label ids
59 |
60 | lbl = np.array(lbl)
61 | lbl = lbl.copy()
62 |
63 | img = torch.from_numpy(img).float()
64 | lbl = torch.from_numpy(lbl).long()
65 | return img, lbl
--------------------------------------------------------------------------------
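__getitem__ above resizes to (fineSizeH, fineSizeW), casts to float32, subtracts the (128, 128, 128) mean, flips RGB to BGR, and transposes HWC to CHW. When inspecting batches it helps to undo those steps; a hedged sketch (deprocess is a hypothetical helper, not part of this repo, and expects a CPU tensor produced by this loader):

import numpy as np

def deprocess(img_tensor, mean=(128, 128, 128)):
    """Invert imageLabelLoader preprocessing: CHW BGR mean-subtracted -> HWC RGB uint8."""
    img = img_tensor.numpy().transpose(1, 2, 0)  # CHW -> HWC
    img = img[:, :, ::-1]                        # BGR -> RGB
    img = img + np.array(mean)                   # undo mean subtraction
    return np.clip(img, 0, 255).astype(np.uint8)
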
/loader/onehot_label_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import torch
5 | import torchvision
6 | import numpy as np
7 | import scipy.misc as m
8 | import scipy.io as io
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils import data
13 |
14 | class onehotLabelLoader(data.Dataset):
15 | def __init__(self, root, dataName, phase='train', label_nums=12, lbl_size=(241,121)):
16 | self.root = root
17 | self.dataName = dataName
18 | self.phase = phase
19 | self.label_nums = label_nums
20 | self.lbl_size = lbl_size if isinstance(lbl_size, tuple) else (lbl_size, lbl_size)
21 | self.files = collections.defaultdict(list)
22 | self.now_idx = 0
23 | file_list = tuple(open(root + '/' + dataName + '/' + self.phase + '.txt', 'r'))
24 | file_list = [id_.rstrip() for id_ in file_list]
25 | self.files[self.phase] = file_list
26 | # print self.files['train']
27 | def __len__(self):
28 | return len(self.files[self.phase])
29 |
30 | def __getitem__(self, index):
31 | lbl_name = self.files[self.phase][index]
32 | lbl_path = self.root + '/' + self.dataName + '/' + 'label/' + self.phase + '/' + lbl_name + '.png'
33 |
34 | lbl = Image.open(lbl_path)
35 | lbl_size = lbl.size
36 | if self.lbl_size[1] != lbl_size[0] or self.lbl_size[0] != lbl_size[1]:
37 | lbl = lbl.resize((self.lbl_size[1], self.lbl_size[0]), resample=Image.NEAREST)  # NEAREST: bilinear would blend label ids
38 |
39 | lbl = np.array(lbl, dtype=np.float32)
40 | lbl_onehot = np.zeros((self.label_nums, self.lbl_size[0], self.lbl_size[1]))
41 | for i in range(self.label_nums):
42 | lbl_onehot[i][lbl==i] = 1
43 |
44 | lbl = lbl.copy()
45 |
46 | lbl = torch.from_numpy(lbl).float()
47 | lbl_onehot = torch.from_numpy(lbl_onehot).float()
48 |
49 | return lbl, lbl_onehot
50 |
51 | def getBatch(self, index):
52 | self.now_idx = (self.now_idx + 1) % len(self)
53 | lbl_name = self.files[self.phase][index]
54 | lbl_path = self.root + '/' + self.dataName + '/' + 'label/' + self.phase + '/' + lbl_name + '.png'
55 |
56 | lbl = Image.open(lbl_path)
57 | lbl_size = lbl.size
58 | if self.lbl_size[1] != lbl_size[0] or self.lbl_size[0] != lbl_size[1]:
59 | lbl = lbl.resize((self.lbl_size[1], self.lbl_size[0]), resample=Image.NEAREST)  # NEAREST: bilinear would blend label ids
60 |
61 | lbl = np.array(lbl, dtype=np.float32)
62 | lbl_onehot = np.zeros((self.label_nums, self.lbl_size[0], self.lbl_size[1]))
63 | for i in range(self.label_nums):
64 | lbl_onehot[i][lbl == i] = 1
65 |
66 | lbl = lbl.copy()
67 |
68 | lbl = torch.from_numpy(lbl).float()
69 | lbl_onehot = torch.from_numpy(lbl_onehot).float()
70 |
71 | return lbl, lbl_onehot
72 |
--------------------------------------------------------------------------------
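The per-class loop that fills lbl_onehot makes label_nums passes over the image; NumPy broadcasting does the same in one expression. A sketch of an equivalent construction (onehot is an illustrative helper producing the same (label_nums, H, W) layout):

import numpy as np

def onehot(lbl, label_nums):
    """lbl: (H, W) integer label map -> (label_nums, H, W) float one-hot volume."""
    classes = np.arange(label_nums).reshape(label_nums, 1, 1)
    return (lbl[None, :, :] == classes).astype(np.float32)
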
/loader/image_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import collections
3 | import json
4 | import torch
5 | import torchvision
6 | import numpy as np
7 | import scipy.misc as m
8 | import scipy.io as io
9 | import matplotlib.pyplot as plt
10 | from PIL import Image
11 | from tqdm import tqdm
12 | from torch.utils import data
13 |
14 | class imageLoader(data.Dataset):
15 | def __init__(self, root, dataName, phase='train', img_size=(241,121)):
16 | self.root = root
17 | self.dataName = dataName
18 | self.phase = phase
19 | self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
20 | self.mean = np.array([128, 128, 128])
21 | self.files = collections.defaultdict(list)
22 | self.now_idx = 0
23 | """
24 | for phase in ['train', 'val', 'train+unlabel']:
25 | file_list = tuple(open(root + '/' + dataName + '/' + phase + '.txt', 'r'))
26 | file_list = [id_.rstrip() for id_ in file_list]
27 | self.files[phase] = file_list
28 | """
29 | file_list = tuple(open(root + '/' + dataName + '/' + self.phase + '.txt', 'r'))
30 | file_list = [id_.rstrip() for id_ in file_list]
31 | self.files[self.phase] = file_list
32 | # print self.files['train']
33 | def __len__(self):
34 | return len(self.files[self.phase])
35 |
36 | def __getitem__(self, index):
37 | img_name = self.files[self.phase][index]
38 | img_path = self.root + '/' + self.dataName + '/' + 'image/' + self.phase + '/' + img_name + '.jpg'
39 |
40 | img = Image.open(img_path)
41 | img_size = img.size
42 | if self.img_size[1] != img_size[0] or self.img_size[0] != img_size[1]:
43 | img = img.resize((self.img_size[1], self.img_size[0]), resample=Image.BILINEAR)
44 |
45 | img = np.array(img, dtype=np.float32)
46 | if not (len(img.shape) == 3 and img.shape[2] == 3):
47 | img = img.reshape(img.shape[0], img.shape[1], 1)
48 | img = img.repeat(3, 2)
49 | img -= self.mean
50 | img = img[:, :, ::-1]# RGB -> BGR
51 | img = img.transpose(2, 0, 1)
52 | img = img.copy()
53 |
54 | img = torch.from_numpy(img).float()
55 |
56 | return img
57 |
58 | def getBatch(self, index):
59 | self.now_idx = (self.now_idx + 1) % len(self)
60 | img_name = self.files[self.phase][index]
61 | img_path = self.root + '/' + self.dataName + '/' + 'image/' + self.phase + '/' + img_name + '.jpg'
62 |
63 | img = Image.open(img_path)
64 | img = img.resize((self.img_size[1], self.img_size[0]), resample=Image.BILINEAR)
65 | img = np.array(img, dtype=np.float32)
66 | if not (len(img.shape) == 3 and img.shape[2] == 3):
67 | img = img.reshape(img.shape[0], img.shape[1], 1)
68 | img = img.repeat(3, 2)
69 | img -= self.mean
70 | img = img[:, :, ::-1]# RGB -> BGR
71 | img = img.transpose(2, 0, 1)
72 | img = img.copy()
73 |
74 | img = torch.from_numpy(img).float()
75 |
76 | return img
77 |
--------------------------------------------------------------------------------
/util/confusion_matrix.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | class ConfusionMatrix:
3 | def __init__(self, size=12):
4 | self.size = size
5 | self.diag = np.zeros(self.size)
6 | self.act_sum = np.zeros(self.size)
7 | self.pre_sum = np.zeros(self.size)
8 |
9 | def reset(self):
10 | self.diag = np.zeros(self.size)
11 | self.act_sum = np.zeros(self.size)
12 | self.pre_sum = np.zeros(self.size)
13 |
14 | def update(self, actual, predicted):
15 | for i in range(self.size):
16 | act = actual == i
17 | pre = predicted == i
18 | I = act & pre
19 | self.diag[i] += np.sum(I)
20 | self.act_sum[i] += np.sum(act)
21 | self.pre_sum[i] += np.sum(pre)
22 |
23 | def accuracy(self):
24 | '''overall pixel accuracy'''
25 | diag_sum = np.sum(self.diag)
26 | total_sum = np.sum(self.act_sum)
27 | if total_sum == 0:
28 | return 0
29 | else:
30 | return diag_sum / total_sum
31 |
32 | def fg_accuracy(self):
33 | '''pixel accuracy over foreground classes (label > 0)'''
34 | diag_sum = np.sum(self.diag) - self.diag[0]
35 | total_sum = np.sum(self.act_sum) - self.act_sum[0]
36 | if total_sum == 0:
37 | return 0
38 | else:
39 | return diag_sum / total_sum
40 |
41 | def avg_precision(self):
42 | '''mean per-class precision; classes never predicted are skipped'''
43 | total_precision = 0
44 | count = 0
45 | for i in range(self.size):
46 | if self.pre_sum[i] > 0:
47 | total_precision += self.diag[i] / self.pre_sum[i]
48 | count += 1
49 | if count == 0:
50 | return 0
51 | else:
52 | return total_precision / count
53 |
54 | def avg_recall(self):
55 | '''mean per-class recall; classes absent from the ground truth are skipped'''
56 | total_recall = 0
57 | count = 0
58 | for i in range(self.size):
59 | if self.act_sum[i] > 0:
60 | total_recall += self.diag[i] / self.act_sum[i]
61 | count += 1
62 | if count == 0:
63 | return 0
64 | else:
65 | return total_recall / count
66 |
67 | def avg_f1score(self):
68 | '''mean per-class F1; classes absent from both prediction and ground truth are skipped'''
69 | total_f1score = 0
70 | count = 0
71 | for i in range(self.size):
72 | t = self.pre_sum[i] + self.act_sum[i]
73 | if t > 0:
74 | total_f1score += 2 * self.diag[i] / t
75 | count += 1
76 | if count == 0:
77 | return 0
78 | else:
79 | return total_f1score / count
80 |
81 | def f1score(self):
82 | '''per-class F1 scores; -1 marks classes absent from both prediction and ground truth'''
83 | f1score = []
84 | for i in range(self.size):
85 | t = self.pre_sum[i] + self.act_sum[i]
86 | if t > 0:
87 | f1score.append(2 * self.diag[i] / t)
88 | else:
89 | f1score.append(-1)
90 | return f1score
91 | 
92 | def mean_iou(self):
93 | '''mean IoU; classes absent from both prediction and ground truth are skipped'''
94 | total_iou = 0
95 | count = 0
96 | for i in range(self.size):
97 | I = self.diag[i]
98 | U = self.act_sum[i] + self.pre_sum[i] - I
99 | if U > 0:
100 | total_iou += I / U
101 | count += 1
102 | if count == 0:
103 | return 0
104 | return total_iou / count
105 | 
106 | def all_acc(self):
107 | return {
108 | 'accuracy':self.accuracy(),
109 | 'fg_accuracy':self.fg_accuracy(),
110 | 'avg_precision':self.avg_precision(),
111 | 'avg_recall':self.avg_recall(),
112 | 'avg_f1score':self.avg_f1score(),
113 | 'mean_iou':self.mean_iou(),
114 | }
--------------------------------------------------------------------------------
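A small worked example of the accumulator above, on a pair of 2x2 label maps with three classes (values chosen arbitrarily):

import numpy as np
from util.confusion_matrix import ConfusionMatrix

m = ConfusionMatrix(size=3)
actual    = np.array([[0, 1], [2, 1]])
predicted = np.array([[0, 1], [1, 1]])
m.update(actual, predicted)

print(m.accuracy())     # 3 correct of 4 pixels -> 0.75
print(m.fg_accuracy())  # classes 1 and 2: 2 of 3 -> 0.667
print(m.all_acc())      # dict with all six metrics
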
/test/test_g1.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from loader.image_label_loader import imageLabelLoader
4 | from models.deeplab_gan_s2t_with_refine_4 import deeplabGanS2TWithRefine4
5 | from util.confusion_matrix import ConfusionMatrix
6 | import torch
7 | import numpy as np
8 | import scipy.misc
9 | def color(label):
10 | bg = label == 0
11 | bg = bg.reshape(bg.shape[0], bg.shape[1])
12 | face = label == 1
13 | face = face.reshape(face.shape[0], face.shape[1])
14 | hair = label == 2
15 | hair = hair.reshape(hair.shape[0], hair.shape[1])
16 | Upcloth = label == 3
17 | Upcloth = Upcloth.reshape(Upcloth.shape[0], Upcloth.shape[1])
18 | Larm = label == 4
19 | Larm = Larm.reshape(Larm.shape[0], Larm.shape[1])
20 | Rarm = label == 5
21 | Rarm = Rarm.reshape(Rarm.shape[0], Rarm.shape[1])
22 | pants = label == 6
23 | pants = pants.reshape(pants.shape[0], pants.shape[1])
24 | Lleg = label == 7
25 | Lleg = Lleg.reshape(Lleg.shape[0], Lleg.shape[1])
26 | Rleg = label == 8
27 | Rleg = Rleg.reshape(Rleg.shape[0], Rleg.shape[1])
28 | dress = label == 9
29 | dress = dress.reshape(dress.shape[0], dress.shape[1])
30 | Lshoe = label == 10
31 | Lshoe = Lshoe.reshape(Lshoe.shape[0], Lshoe.shape[1])
32 | Rshoe = label == 11
33 | Rshoe = Rshoe.reshape(Rshoe.shape[0], Rshoe.shape[1])
34 |
35 | # bag = label == 12
36 | # bag = bag.reshape(bag.shape[0], bag.shape[1])
37 |
38 | # repeat 2nd axis to 3
39 | label = label.reshape(bg.shape[0], bg.shape[1], 1)
40 | label = label.repeat(3, 2)
41 | R = label[:, :, 2]
42 | G = label[:, :, 1]
43 | B = label[:, :, 0]
44 | R[bg] = 230
45 | G[bg] = 230
46 | B[bg] = 230
47 |
48 | R[face] = 255
49 | G[face] = 215
50 | B[face] = 0
51 |
52 | R[hair] = 80
53 | G[hair] = 49
54 | B[hair] = 49
55 |
56 | R[Upcloth] = 51
57 | G[Upcloth] = 0
58 | B[Upcloth] = 255
59 |
60 | R[Larm] = 2
61 | G[Larm] = 251
62 | B[Larm] = 49
63 |
64 | R[Rarm] = 141
65 | G[Rarm] = 255
66 | B[Rarm] = 212
67 |
68 | R[pants] = 160
69 | G[pants] = 0
70 | B[pants] = 255
71 |
72 | R[Lleg] = 0
73 | G[Lleg] = 204
74 | B[Lleg] = 255
75 |
76 | R[Rleg] = 191
77 | G[Rleg] = 255
78 | B[Rleg] = 248
79 |
80 | R[dress] = 255
81 | G[dress] = 182
82 | B[dress] = 185
83 |
84 | R[Lshoe] = 180
85 | G[Lshoe] = 122
86 | B[Lshoe] = 121
87 |
88 | R[Rshoe] = 202
89 | G[Rshoe] = 160
90 | B[Rshoe] = 57
91 |
92 | # R[bag] = 255
93 | # G[bag] = 1
94 | # B[bag] = 1
95 | return label
96 | def update_confusion_matrix(matrix, output, target):
97 | values, indices = output.max(1)
98 | output = indices
99 | target = target.cpu().numpy()
100 | output = output.cpu().numpy()
101 | matrix.update(target, output)
102 | return matrix
103 |
104 | def main():
105 | if len(args['device_ids']) > 0:
106 | torch.cuda.set_device(args['device_ids'][0])
107 |
108 | test_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'),
109 | batch_size=args['batch_size'],
110 | num_workers=args['num_workers'], shuffle=False)
111 | gym = deeplabGanS2TWithRefine4()
112 | gym.initialize(args)
113 | gym.load('/home/ben/mathfinder/PROJECT/AAAI2017/our_Method/v3/deeplab_feature_adaptation/checkpoints/Lip_to_July_g1/best_Ori_on_B_model.pth')
114 | gym.eval()
115 | matrix = ConfusionMatrix(args['label_nums'])
116 | for i, (image, label) in enumerate(test_loader):
117 | label = label.cuda(async=True)
118 | target_var = torch.autograd.Variable(label, volatile=True)
119 |
120 | gym.test(False, image)
121 | output = gym.output
122 |
123 | matrix = update_confusion_matrix(matrix, output.data, label)
124 | print(matrix.avg_f1score())
125 | print(matrix.f1score())
126 |
127 |
128 | if __name__ == "__main__":
129 | global args
130 | args = {
131 | 'test_init':False,
132 | 'label_nums':12,
133 | 'l_rate':1e-8,
134 | 'lr_gan': 0.00001,
135 | 'lr_refine': 1e-6,
136 | 'beta1': 0.5,
137 | 'data_path':'datasets',
138 | 'n_epoch':1000,
139 | 'batch_size':10,
140 | 'num_workers':10,
141 | 'print_freq':10,
142 | 'device_ids':[1],
143 | 'domainA': 'Lip',
144 | 'domainB': 'July',
145 | 'weights_pool': 'pretrain_models',
146 | 'pretrain_model': 'deeplab.pth',
147 | 'fineSizeH':241,
148 | 'fineSizeW':121,
149 | 'input_nc':3,
150 | 'name': 'v3_s->t_Refine_4',
151 | 'checkpoints_dir': 'checkpoints',
152 | 'net_D': 'NoBNSinglePathdilationMultOutputNet',
153 | 'use_lsgan': True,
154 | 'resume':None#'checkpoints/v3_1/',
155 | }
156 | main()
--------------------------------------------------------------------------------
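The color() helper above (duplicated in test_g2.py and test.py) spells out one mask block per class. An equivalent lookup-table version is sketched below; it matches the masks above up to dtype (this sketch returns uint8), and takes the channel order as a flag because test_g1.py and test_g2.py write BGR (R into channel 2) while test.py writes RGB:

import numpy as np

# class index -> (R, G, B), copied from the per-class masks above
PALETTE = [
    (230, 230, 230),  # 0 background
    (255, 215, 0),    # 1 face
    (80, 49, 49),     # 2 hair
    (51, 0, 255),     # 3 upper clothes
    (2, 251, 49),     # 4 left arm
    (141, 255, 212),  # 5 right arm
    (160, 0, 255),    # 6 pants
    (0, 204, 255),    # 7 left leg
    (191, 255, 248),  # 8 right leg
    (255, 182, 185),  # 9 dress
    (180, 122, 121),  # 10 left shoe
    (202, 160, 57),   # 11 right shoe
]

def color(label, bgr=True):
    label = label.reshape(label.shape[0], label.shape[1])
    out = np.zeros(label.shape + (3,), dtype=np.uint8)
    for i, rgb in enumerate(PALETTE):
        out[label == i] = rgb[::-1] if bgr else rgb
    return out
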
/test/test_g2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from loader.image_label_loader import imageLabelLoader
4 | from models.deeplab_g2 import deeplabG2
5 | from util.confusion_matrix import ConfusionMatrix
6 | import torch
7 | import numpy as np
8 | import scipy.misc
9 | def color(label):
10 | bg = label == 0
11 | bg = bg.reshape(bg.shape[0], bg.shape[1])
12 | face = label == 1
13 | face = face.reshape(face.shape[0], face.shape[1])
14 | hair = label == 2
15 | hair = hair.reshape(hair.shape[0], hair.shape[1])
16 | Upcloth = label == 3
17 | Upcloth = Upcloth.reshape(Upcloth.shape[0], Upcloth.shape[1])
18 | Larm = label == 4
19 | Larm = Larm.reshape(Larm.shape[0], Larm.shape[1])
20 | Rarm = label == 5
21 | Rarm = Rarm.reshape(Rarm.shape[0], Rarm.shape[1])
22 | pants = label == 6
23 | pants = pants.reshape(pants.shape[0], pants.shape[1])
24 | Lleg = label == 7
25 | Lleg = Lleg.reshape(Lleg.shape[0], Lleg.shape[1])
26 | Rleg = label == 8
27 | Rleg = Rleg.reshape(Rleg.shape[0], Rleg.shape[1])
28 | dress = label == 9
29 | dress = dress.reshape(dress.shape[0], dress.shape[1])
30 | Lshoe = label == 10
31 | Lshoe = Lshoe.reshape(Lshoe.shape[0], Lshoe.shape[1])
32 | Rshoe = label == 11
33 | Rshoe = Rshoe.reshape(Rshoe.shape[0], Rshoe.shape[1])
34 |
35 | # bag = label == 12
36 | # bag = bag.reshape(bag.shape[0], bag.shape[1])
37 |
38 | # repeat 2nd axis to 3
39 | label = label.reshape(bg.shape[0], bg.shape[1], 1)
40 | label = label.repeat(3, 2)
41 | R = label[:, :, 2]
42 | G = label[:, :, 1]
43 | B = label[:, :, 0]
44 | R[bg] = 230
45 | G[bg] = 230
46 | B[bg] = 230
47 |
48 | R[face] = 255
49 | G[face] = 215
50 | B[face] = 0
51 |
52 | R[hair] = 80
53 | G[hair] = 49
54 | B[hair] = 49
55 |
56 | R[Upcloth] = 51
57 | G[Upcloth] = 0
58 | B[Upcloth] = 255
59 |
60 | R[Larm] = 2
61 | G[Larm] = 251
62 | B[Larm] = 49
63 |
64 | R[Rarm] = 141
65 | G[Rarm] = 255
66 | B[Rarm] = 212
67 |
68 | R[pants] = 160
69 | G[pants] = 0
70 | B[pants] = 255
71 |
72 | R[Lleg] = 0
73 | G[Lleg] = 204
74 | B[Lleg] = 255
75 |
76 | R[Rleg] = 191
77 | G[Rleg] = 255
78 | B[Rleg] = 248
79 |
80 | R[dress] = 255
81 | G[dress] = 182
82 | B[dress] = 185
83 |
84 | R[Lshoe] = 180
85 | G[Lshoe] = 122
86 | B[Lshoe] = 121
87 |
88 | R[Rshoe] = 202
89 | G[Rshoe] = 160
90 | B[Rshoe] = 57
91 |
92 | # R[bag] = 255
93 | # G[bag] = 1
94 | # B[bag] = 1
95 | return label
96 | def update_confusion_matrix(matrix, output, target):
97 | values, indices = output.max(1)
98 | output = indices
99 | target = target.cpu().numpy()
100 | output = output.cpu().numpy()
101 | matrix.update(target, output)
102 | return matrix
103 |
104 | def main():
105 | if len(args['device_ids']) > 0:
106 | torch.cuda.set_device(args['device_ids'][0])
107 |
108 | test_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'),
109 | batch_size=args['batch_size'],
110 | num_workers=args['num_workers'], shuffle=False)
111 | gym = deeplabG2()
112 | gym.initialize(args)
113 | gym.load('/home/ben/mathfinder/PROJECT/AAAI2017/our_Method/v3/deeplab_feature_adaptation/checkpoints/g2_lr_gan=0.00000002_interval_G=5_interval_D=5_net_D=lsganMultOutput_D/best_Ori_on_B_model.pth')
114 | gym.eval()
115 | matrix = ConfusionMatrix(args['label_nums'])
116 | for i, (image, label) in enumerate(test_loader):
117 | label = label.cuda(async=True)
118 | target_var = torch.autograd.Variable(label, volatile=True)
119 |
120 | gym.test(image)
121 | output = gym.output
122 |
123 | matrix = update_confusion_matrix(matrix, output.data, label)
124 | print(matrix.avg_f1score())
125 | print(matrix.f1score())
126 |
127 |
128 | if __name__ == "__main__":
129 | global args
130 | args = {
131 | 'test_init':False,
132 | 'label_nums':12,
133 | 'l_rate':1e-8,
134 | 'lr_gan': 0.00000002,
135 | 'beta1': 0.5,
136 | 'interval_G':5,
137 | 'interval_D':5,
138 | 'data_path':'datasets',
139 | 'n_epoch':1000,
140 | 'batch_size':10,
141 | 'num_workers':10,
142 | 'print_freq':100,
143 | 'device_ids':[1],
144 | 'domainA': 'Lip',
145 | 'domainB': 'Indoor',
146 | 'weights_pool': 'pretrain_models',
147 | 'pretrain_model': 'deeplab.pth',
148 | 'fineSizeH':241,
149 | 'fineSizeW':121,
150 | 'input_nc':3,
151 | 'name': 'train_iou0.4_onehot_g2_lr_gan=0.00000002_interval_G=5_interval_D=5_net_D=lsganMultOutput_D',
152 | 'checkpoints_dir': 'checkpoints',
153 | 'net_D': 'lsganMultOutput_D',
154 | 'use_lsgan': True,
155 | 'resume':None,#'checkpoints/g2_lr_gan=0.0000002_interval_G=5_interval_D=10_net_D=lsganMultOutput_D/best_Ori_on_B_model.pth',#'checkpoints/v3_1/',
156 | 'if_adv_train':True,
157 | 'if_adaptive':True,
158 | }
159 | main()
--------------------------------------------------------------------------------
/test/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from loader.image_label_loader import imageLabelLoader
4 | from models.deeplab_g1_g2 import deeplabG1G2
5 | from util.confusion_matrix import ConfusionMatrix
6 | import torch
7 | import numpy as np
8 | import scipy.misc
9 | def color(label):
10 | bg = label == 0
11 | bg = bg.reshape(bg.shape[0], bg.shape[1])
12 | face = label == 1
13 | face = face.reshape(face.shape[0], face.shape[1])
14 | hair = label == 2
15 | hair = hair.reshape(hair.shape[0], hair.shape[1])
16 | Upcloth = label == 3
17 | Upcloth = Upcloth.reshape(Upcloth.shape[0], Upcloth.shape[1])
18 | Larm = label == 4
19 | Larm = Larm.reshape(Larm.shape[0], Larm.shape[1])
20 | Rarm = label == 5
21 | Rarm = Rarm.reshape(Rarm.shape[0], Rarm.shape[1])
22 | pants = label == 6
23 | pants = pants.reshape(pants.shape[0], pants.shape[1])
24 | Lleg = label == 7
25 | Lleg = Lleg.reshape(Lleg.shape[0], Lleg.shape[1])
26 | Rleg = label == 8
27 | Rleg = Rleg.reshape(Rleg.shape[0], Rleg.shape[1])
28 | dress = label == 9
29 | dress = dress.reshape(dress.shape[0], dress.shape[1])
30 | Lshoe = label == 10
31 | Lshoe = Lshoe.reshape(Lshoe.shape[0], Lshoe.shape[1])
32 | Rshoe = label == 11
33 | Rshoe = Rshoe.reshape(Rshoe.shape[0], Rshoe.shape[1])
34 |
35 | # bag = label == 12
36 | # bag = bag.reshape(bag.shape[0], bag.shape[1])
37 |
38 | # repeat 2nd axis to 3
39 | label = label.reshape(bg.shape[0], bg.shape[1], 1)
40 | label = label.repeat(3, 2)
41 | R = label[:, :, 0]
42 | G = label[:, :, 1]
43 | B = label[:, :, 2]
44 | R[bg] = 230
45 | G[bg] = 230
46 | B[bg] = 230
47 |
48 | R[face] = 255
49 | G[face] = 215
50 | B[face] = 0
51 |
52 | R[hair] = 80
53 | G[hair] = 49
54 | B[hair] = 49
55 |
56 | R[Upcloth] = 51
57 | G[Upcloth] = 0
58 | B[Upcloth] = 255
59 |
60 | R[Larm] = 2
61 | G[Larm] = 251
62 | B[Larm] = 49
63 |
64 | R[Rarm] = 141
65 | G[Rarm] = 255
66 | B[Rarm] = 212
67 |
68 | R[pants] = 160
69 | G[pants] = 0
70 | B[pants] = 255
71 |
72 | R[Lleg] = 0
73 | G[Lleg] = 204
74 | B[Lleg] = 255
75 |
76 | R[Rleg] = 191
77 | G[Rleg] = 255
78 | B[Rleg] = 248
79 |
80 | R[dress] = 255
81 | G[dress] = 182
82 | B[dress] = 185
83 |
84 | R[Lshoe] = 180
85 | G[Lshoe] = 122
86 | B[Lshoe] = 121
87 |
88 | R[Rshoe] = 202
89 | G[Rshoe] = 160
90 | B[Rshoe] = 57
91 |
92 | # R[bag] = 255
93 | # G[bag] = 1
94 | # B[bag] = 1
95 | return label
96 | def update_confusion_matrix(matrix, output, target):
97 | values, indices = output.max(1)
98 | output = indices
99 | target = target.cpu().numpy()
100 | output = output.cpu().numpy()
101 | matrix.update(target, output)
102 | return matrix
103 |
104 | def main():
105 | if len(args['device_ids']) > 0:
106 | torch.cuda.set_device(args['device_ids'][0])
107 |
108 | test_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'),
109 | batch_size=args['batch_size'],
110 | num_workers=args['num_workers'], shuffle=False)
111 | gym = deeplabG1G2()
112 | gym.initialize(args)
113 | gym.load('/home/ben/mathfinder/PROJECT/AAAI2017/our_Method/v3/deeplab_feature_adaptation/checkpoints/Lip_to_July_lr_g1=0.00001_lr_g2=0.00000002_interval_g1=5_interval_d1=5_net_D=lsganMultOutput_D_if_adaptive=True_resume_decay=g2/best_Ori_on_B_model.pth')
114 | gym.eval()
115 | matrix = ConfusionMatrix(args['label_nums'])
116 | for i, (image, label) in enumerate(test_loader):
117 | label = label.cuda(async=True)
118 | target_var = torch.autograd.Variable(label, volatile=True)
119 |
120 | gym.test(image)
121 | output = gym.output
122 |
123 | matrix = update_confusion_matrix(matrix, output.data, label)
124 | print(matrix.all_acc())
125 | print(matrix.f1score())
126 |
127 |
128 | if __name__ == "__main__":
129 | global args
130 | args = {
131 | 'test_init':False,
132 | 'label_nums':12,
133 | 'l_rate':1e-8,
134 | 'lr_g1': 0.00001,
135 | 'lr_g2': 0.00000002,
136 | 'beta1': 0.5,
137 | 'interval_g2':5,
138 | 'interval_d2':5,
139 | 'data_path':'datasets',
140 | 'n_epoch':1000,
141 | 'batch_size':10,
142 | 'num_workers':10,
143 | 'print_freq':20,
144 | 'device_ids':[0],
145 | 'domainA': 'Lip',
146 | 'domainB': 'July',
147 | 'weights_pool': 'pretrain_models',
148 | 'pretrain_model': 'deeplab.pth',
149 | 'fineSizeH':241,
150 | 'fineSizeW':121,
151 | 'input_nc':3,
152 | 'name': 'Lip_to_July_lr_g1=0.00001_lr_g2=0.00000002_interval_g1=5_interval_d1=5_net_D=lsganMultOutput_D_if_adaptive=True_resume_decay=g2',
153 | 'checkpoints_dir': 'checkpoints',
154 | 'net_d1': 'NoBNSinglePathdilationMultOutputNet',
155 | 'net_d2': 'lsganMultOutput_D',
156 | 'use_lsgan': True,
157 | 'resume':None,#'checkpoints/g2_lr_gan=0.0000002_interval_G=5_interval_D=10_net_D=lsganMultOutput_D/best_Ori_on_B_model.pth',#'checkpoints/v3_1/',
158 | 'if_adv_train':True,
159 | 'if_adaptive':False,
160 | }
161 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import torch
3 | import time
4 | import torch.nn as nn
5 | from models.flgan import FLGAN
6 | from torch.autograd import Variable
7 | from torch.utils import data
8 | from loader.image_label_loader import imageLabelLoader
9 | from loader.image_loader import imageLoader
10 | from loader.label_loader import labelLoader
11 | from util.confusion_matrix import ConfusionMatrix
12 | import util.makedirs as makedirs
13 | import os
14 | import torchvision.models as models
15 | import matplotlib.pyplot as plt
16 | from util.log import Logger
17 | import numpy as np
18 | from PIL import Image
19 | def save_checkpoint(state, filename):
20 | torch.save(state, filename)
21 |
22 |
23 | def update_confusion_matrix(matrix, output, target):
24 | values, indices = output.max(1)
25 | output = indices
26 | target = target.cpu().numpy()
27 | output = output.cpu().numpy()
28 | matrix.update(target, output)
29 | return matrix
30 |
31 | def validate(val_loader, model, criterion, adaptation):
32 | # switch to evaluate mode
33 | run_time = time.time()
34 | matrix = ConfusionMatrix(args['label_nums'])
35 | loss = 0
36 | for i, (images, labels) in enumerate(val_loader):
37 | labels = labels.cuda(async=True)
38 | target_var = torch.autograd.Variable(labels, volatile=True)
39 |
40 | model.test(images)
41 | output = model.output
42 | loss += criterion(output, target_var)/args['batch_size']
43 | matrix = update_confusion_matrix(matrix, output.data, labels)
44 | loss /= (i+1)
45 | run_time = time.time() - run_time
46 | logger.info('=================================================')
47 | logger.info('val:'
48 | 'loss: {0:.4f}\t'
49 | 'accuracy: {1:.4f}\t'
50 | 'fg_accuracy: {2:.4f}\t'
51 | 'avg_precision: {3:.4f}\t'
52 | 'avg_recall: {4:.4f}\t'
53 | 'avg_f1score: {5:.4f}\t'
54 | 'run_time:{run_time:.2f}\t'
55 | .format(loss.data[0], matrix.accuracy(),
56 | matrix.fg_accuracy(), matrix.avg_precision(), matrix.avg_recall(), matrix.avg_f1score(),run_time=run_time))
57 | logger.info('=================================================')
58 | return matrix.all_acc()
59 |
60 |
61 | def main():
62 |
63 | makedirs.mkdirs(os.path.join(args['checkpoints_dir'], args['name']))
64 | if len(args['device_ids']) > 0:
65 | torch.cuda.set_device(args['device_ids'][0])
66 |
67 | A_train_loader = data.DataLoader(imageLabelLoader(args['data_path'],dataName=args['domainA'], phase='train'), batch_size=args['batch_size'],
68 | num_workers=args['num_workers'], shuffle=True)
69 | label_train_loader = data.DataLoader(labelLoader(args['data_path'], dataName=args['domainA'], phase='train_onehot'),
70 | batch_size=args['batch_size'],
71 | num_workers=args['num_workers'], shuffle=True)
72 |
73 | A_val_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainA'], phase='val'), batch_size=args['batch_size'],
74 | num_workers=args['num_workers'], shuffle=False)
75 |
76 | B_train_loader = data.DataLoader(imageLoader(args['data_path'], dataName=args['domainB'], phase='train+unlabel'),
77 | batch_size=args['batch_size'],
78 | num_workers=args['num_workers'], shuffle=True)
79 | B_val_loader = data.DataLoader(imageLabelLoader(args['data_path'], dataName=args['domainB'], phase='val'),
80 | batch_size=args['batch_size'],
81 | num_workers=args['num_workers'], shuffle=False)
82 | model = FLGAN()
83 | model.initialize(args)
84 |
85 | # multi GPUS
86 | # model = torch.nn.DataParallel(model,device_ids=args['device_ids']).cuda()
87 | Iter = 0
88 | Epoch = 0
89 | best_Ori_on_B = 0
90 | prec_Ori_on_B = 0
91 | if args['resume']:
92 | if os.path.isfile(args['resume']):
93 | logger.info("=> loading checkpoint '{}'".format(args['resume']))
94 | Iter, Epoch, best_Ori_on_B = model.load(args['resume'])
95 | prec_Ori_on_B = best_Ori_on_B
96 | if (args['if_adaptive'] and (Epoch + 1) % 30 == 0) or prec_Ori_on_B > 0.56:
97 | model.update_learning_rate()
98 | else:
99 | print("=> no checkpoint found at '{}'".format(args['resume']))
100 |
101 | model.train()
102 | for epoch in range(Epoch, args['n_epoch']):
103 | # train(A_train_loader, B_train_loader, model, epoch)
104 | # switch to train mode
105 | for i, (A_image, A_label) in enumerate(A_train_loader):
106 | Iter += 1
107 | B_image = next(iter(B_train_loader))  # builds a fresh DataLoader iterator (and worker pool) every step; see the persistent-iterator sketch after this file
108 | if Iter % args['interval_d2'] == 0 and args['if_adv_train']:
109 | label_onehot = next(iter(label_train_loader))
110 | model.set_input({'A': A_image, 'A_label': A_label, 'label_onehot':label_onehot, 'B': B_image})
111 | else:
112 | model.set_input({'A': A_image, 'A_label': A_label, 'B': B_image})
113 |
114 | model.step()
115 | output = model.output
116 | if (i+1) % args['print_freq'] == 0:
117 | matrix = ConfusionMatrix()
118 | update_confusion_matrix(matrix, output.data, A_label)
119 | logger.info('Time: {time}\t'
120 | 'Epoch/Iter: [{epoch}/{Iter}]\t'
121 | 'loss: {loss:.4f}\t'
122 | 'acc: {accuracy:.4f}\t'
123 | 'fg_acc: {fg_accuracy:.4f}\t'
124 | 'avg_prec: {avg_precision:.4f}\t'
125 | 'avg_rec: {avg_recall:.4f}\t'
126 | 'avg_f1: {avg_f1score:.4f}\t'
127 | 'loss_G1: {loss_G1:.4f}\t'
128 | 'loss_D1: {loss_D1:.4f}\t'
129 | 'loss_G2: {loss_G2:.4f}\t'
130 | 'loss_D2: {loss_D2:.4f}\t'
131 | .format(
132 | time=time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()),
133 | epoch=epoch, Iter=Iter, loss=model.loss_P.data[0],
134 | accuracy=matrix.accuracy(),
135 | fg_accuracy=matrix.fg_accuracy(), avg_precision=matrix.avg_precision(),
136 | avg_recall=matrix.avg_recall(), avg_f1score=matrix.avg_f1score(),
137 | loss_G1=model.loss_G1.data[0], loss_D1=model.loss_D1.data[0],
138 | loss_G2=model.loss_G2.data[0], loss_D2=model.loss_D2.data[0]))
139 |
140 | if Iter % 1000 == 0:
141 | model.eval()
142 | acc_Ori_on_A = validate(A_val_loader, model, nn.CrossEntropyLoss(size_average=False), False)
143 | acc_Ori_on_B = validate(B_val_loader, model, nn.CrossEntropyLoss(size_average=False), False)
144 | prec_Ori_on_B = acc_Ori_on_B['avg_f1score']
145 |
146 | is_best = prec_Ori_on_B > best_Ori_on_B
147 | best_Ori_on_B = max(prec_Ori_on_B, best_Ori_on_B)
148 | if is_best:
149 | model.save('best_Ori_on_B', Iter=Iter, epoch=epoch, acc={'acc_Ori_on_A':acc_Ori_on_A, 'acc_Ori_on_B':acc_Ori_on_B})
150 | model.train()
151 | if (args['if_adaptive'] and (epoch+1) % 30 == 0):
152 | model.update_learning_rate()
153 |
154 |
155 |
156 |
157 | if __name__ == '__main__':
158 | global args
159 | args = {
160 | 'test_init':False,
161 | 'label_nums':12,
162 | 'l_rate':1e-8,
163 | 'lr_g1': 0.00001,
164 | 'lr_g2': 0.00000002,
165 | 'beta1': 0.5,
166 | 'interval_g2':5,
167 | 'interval_d2':5,
168 | 'data_path':'datasets',
169 | 'n_epoch':1000,
170 | 'batch_size':10,
171 | 'num_workers':10,
172 | 'print_freq':100,
173 | 'device_ids':[1],
174 | 'domainA': 'Lip',
175 | 'domainB': 'July',
176 | 'weights_pool': 'pretrain_models',
177 | 'pretrain_model': 'deeplab.pth',
178 | 'log':'log',
179 | 'fineSizeH':241,
180 | 'fineSizeW':121,
181 | 'input_nc':3,
182 | 'name': 'Lip_to_July_lr_g1=0.00001_lr_g2=0.00000002_interval_g1=5_interval_d1=5_net_D=lsganMultOutput_D_if_adaptive=True_resume_decay=g2',
183 | 'checkpoints_dir': 'checkpoints',
184 | 'net_d1': 'NoBNSinglePathdilationMultOutputNet',
185 | 'net_d2': 'lsganMultOutput_D',
186 | 'use_lsgan': True,
187 | 'resume':None,#'checkpoints/lr_g1=0.00001_lr_g2=0.00000001_interval_g1=6_interval_d1=6_net_D=lsganMultOutput_D_if_adaptive=True/best_Ori_on_B_model.pth',#'checkpoints/v3_1/',
188 | 'if_adv_train':True,
189 | 'if_adaptive':True,
190 | }
191 | if not os.path.exists(args['checkpoints_dir']):
192 | os.makedirs(args['checkpoints_dir'])
193 | if not os.path.exists(args['log']):
194 | os.makedirs(args['log'])
195 | if not os.path.exists(os.path.join(args['data_path'], args['domainA'], 'label', 'train_onehot')):
196 | print('Create one-hot labels for domainA...')
197 | onehot_label_path = os.path.join(args['data_path'], args['domainA'], 'label', 'train_onehot')
198 | os.makedirs(onehot_label_path)
199 | label_path = os.path.join(args['data_path'], args['domainA'], 'label', 'train')
200 | for name in os.listdir(label_path):
201 | lbl = Image.open(os.path.join(label_path, name))
202 | lbl = np.array(lbl)
203 | lbl_onehot = np.zeros((args['label_nums'], args['fineSizeH'], args['fineSizeW']))
204 | for i in range(args['label_nums']):
205 | lbl_onehot[i][lbl == i] = 1
206 | np.save(os.path.join(onehot_label_path, name.split('.png')[0] + ".npy"), lbl_onehot)
207 | print('Done!')
208 |
209 | logger = Logger(
210 | log_file=args['log'] + '/' + args['name'] + '-' + time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) + '.log')
211 | logger.info('------------ Options -------------\n')
212 | for k, v in args.items():
213 | logger.info('%s: %s' % (str(k), str(v)))
214 | logger.info('-------------- End ----------------\n')
215 | main()
216 |
--------------------------------------------------------------------------------
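As flagged in the inline comment in main(), next(iter(B_train_loader)) constructs a brand-new DataLoader iterator on every step, which with num_workers=10 also spins up a fresh worker pool each time. A sketch of the usual persistent-iterator pattern (batch_stream is a hypothetical helper, not part of this repo):

def batch_stream(loader):
    """Yield batches forever, restarting the DataLoader at the end of each pass."""
    while True:
        for batch in loader:
            yield batch

# inside main(), before the epoch loop:
#   B_stream = batch_stream(B_train_loader)
#   onehot_stream = batch_stream(label_train_loader)
# and inside the training loop:
#   B_image = next(B_stream)
#   label_onehot = next(onehot_stream)
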
/models/flgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from .base_model import BaseModel
5 | from . import networks
6 | import itertools
7 | from torch.autograd import Variable
8 |
9 | def get_parameters(model, parameter_name):
10 | for name, param in model.named_parameters():
11 | if name in [parameter_name]:
12 | return param
13 |
14 | def weights_init(m):
15 | classname = m.__class__.__name__
16 | if classname.find('Conv') != -1:
17 | m.weight.data.normal_(0.0, 0.02)
18 | elif classname.find('BatchNorm2d') != -1:
19 | m.weight.data.normal_(1.0, 0.02)
20 | m.bias.data.fill_(0)
21 | elif classname.find('Linear') != -1:
22 | m.weight.data.normal_(0.0, 0.02)
23 |
24 | def define_D(which_netD, input_nc):
25 | if which_netD == 'NoBNSinglePathdilationMultOutputNet':
26 | return networks.NoBNSinglePathdilationMultOutputNet(input_nc)
27 | elif which_netD == 'lsganMultOutput_D':
28 | return networks.lsganMultOutput_D(input_nc)
29 |
30 |
31 | class FLGAN(BaseModel):
32 | def name(self):
33 | return 'flgan'
34 |
35 | def initialize(self, args):
36 | BaseModel.initialize(self, args)
37 | self.if_adv_train = args['if_adv_train']
38 | self.Iter = 0
39 | self.interval_g2 = args['interval_g2']
40 | self.interval_d2 = args['interval_d2']
41 | self.nb = args['batch_size']
42 | sizeH, sizeW = args['fineSizeH'], args['fineSizeW']
43 |
44 | self.tImageA = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
45 | self.tImageB = self.Tensor(self.nb, args['input_nc'], sizeH, sizeW)
46 | self.tLabelA = torch.cuda.LongTensor(self.nb, 1, sizeH, sizeW)
47 | self.tOnehotLabelA = self.Tensor(self.nb, args['label_nums'], sizeH, sizeW)
48 | self.loss_G = Variable()
49 | self.loss_D = Variable()
50 |
51 | self.netG1 = networks.netG().cuda(device_id=args['device_ids'][0])
52 | self.netD1 = define_D(args['net_d1'],512).cuda(device_id=args['device_ids'][0])
53 | self.netD2 = define_D(args['net_d2'],args['label_nums']).cuda(device_id=args['device_ids'][0])
54 |
55 | self.deeplabPart1 = networks.DeeplabPool1().cuda(device_id=args['device_ids'][0])
56 | self.deeplabPart2 = networks.DeeplabPool12Pool5().cuda(device_id=args['device_ids'][0])
57 | self.deeplabPart3 = networks.DeeplabPool52Fc8_interp(output_nc=args['label_nums']).cuda(device_id=args['device_ids'][0])
58 |
59 | # define loss functions
60 | self.criterionCE = torch.nn.CrossEntropyLoss(size_average=False)
61 | self.criterionAdv = networks.Advloss(use_lsgan=args['use_lsgan'], tensor=self.Tensor)
62 |
63 |
64 | if not args['resume']:
65 | #initialize networks
66 | self.netG1.apply(weights_init)
67 | self.netD1.apply(weights_init)
68 | self.netD2.apply(weights_init)
69 | pretrained_dict = torch.load(args['weights_pool'] + '/' + args['pretrain_model'])
70 | self.deeplabPart1.weights_init(pretrained_dict=pretrained_dict)
71 | self.deeplabPart2.weights_init(pretrained_dict=pretrained_dict)
72 | self.deeplabPart3.weights_init(pretrained_dict=pretrained_dict)
73 |
74 | # initialize optimizers
75 | self.optimizer_G1 = torch.optim.Adam(self.netG1.parameters(),
76 | lr=args['lr_g1'], betas=(args['beta1'], 0.999))
77 | self.optimizer_D1 = torch.optim.Adam(self.netD1.parameters(),
78 | lr=args['lr_g1'], betas=(args['beta1'], 0.999))
79 |
80 | self.optimizer_G2 = torch.optim.Adam([
81 | {'params': self.deeplabPart1.parameters()},
82 | {'params': self.deeplabPart2.parameters()},
83 | {'params': self.deeplabPart3.parameters()}],
84 | lr=args['lr_g2'], betas=(args['beta1'], 0.999))
85 | self.optimizer_D2 = torch.optim.Adam(self.netD2.parameters(),
86 | lr=args['lr_g2'], betas=(args['beta1'], 0.999))
87 |
88 | ignored_params = list(map(id, self.deeplabPart3.fc8_1.parameters()))
89 | ignored_params.extend(list(map(id, self.deeplabPart3.fc8_2.parameters())))
90 | ignored_params.extend(list(map(id, self.deeplabPart3.fc8_3.parameters())))
91 | ignored_params.extend(list(map(id, self.deeplabPart3.fc8_4.parameters())))
92 | base_params = [p for p in self.deeplabPart3.parameters()
93 | if id(p) not in ignored_params]
94 | base_params = base_params + list(self.deeplabPart1.parameters())
95 | base_params = base_params + list(self.deeplabPart2.parameters())
96 |
97 | deeplab_params = [{'params': base_params},
98 | {'params': get_parameters(self.deeplabPart3.fc8_1, 'weight'), 'lr': args['l_rate'] * 10},
99 | {'params': get_parameters(self.deeplabPart3.fc8_2, 'weight'), 'lr': args['l_rate'] * 10},
100 | {'params': get_parameters(self.deeplabPart3.fc8_3, 'weight'), 'lr': args['l_rate'] * 10},
101 | {'params': get_parameters(self.deeplabPart3.fc8_4, 'weight'), 'lr': args['l_rate'] * 10},
102 | {'params': get_parameters(self.deeplabPart3.fc8_1, 'bias'), 'lr': args['l_rate'] * 20},
103 | {'params': get_parameters(self.deeplabPart3.fc8_2, 'bias'), 'lr': args['l_rate'] * 20},
104 | {'params': get_parameters(self.deeplabPart3.fc8_3, 'bias'), 'lr': args['l_rate'] * 20},
105 | {'params': get_parameters(self.deeplabPart3.fc8_4, 'bias'), 'lr': args['l_rate'] * 20},
106 | ]
107 |
108 |
109 | self.optimizer_P = torch.optim.SGD(deeplab_params, lr=args['l_rate'], momentum=0.9, weight_decay=5e-4)
110 |
111 | self.optimizer_R = torch.optim.SGD(deeplab_params, lr=args['l_rate'], momentum=0.9, weight_decay=5e-4)
112 |
113 |
114 | print('---------- Networks initialized -------------')
115 | networks.print_network(self.netG1)
116 | networks.print_network(self.netD1)
117 | networks.print_network(self.netD2)
118 | networks.print_network(self.deeplabPart1)
119 | networks.print_network(self.deeplabPart2)
120 | networks.print_network(self.deeplabPart3)
121 | print('-----------------------------------------------')
122 |
123 |
124 | def set_input(self, input):
125 | self.input = input
126 | tImageA = input['A']
127 | tLabelA = input['A_label']
128 | tImageB = input['B']
129 | self.tImageA.resize_(tImageA.size()).copy_(tImageA)
130 | self.vImageA = Variable(self.tImageA)
131 |
132 | self.tLabelA.resize_(tLabelA.size()).copy_(tLabelA)
133 | self.vLabelA = Variable(self.tLabelA)
134 |
135 | self.tImageB.resize_(tImageB.size()).copy_(tImageB)
136 | self.vImageB = Variable(self.tImageB)
137 |
138 | if 'label_onehot' in input:
139 | tOnehotLabelA = input['label_onehot']
140 | self.tOnehotLabelA.resize_(tOnehotLabelA.size()).copy_(tOnehotLabelA)
141 | self.vOnehotLabelA = Variable(self.tOnehotLabelA)
142 |
143 | # used in test time, no backprop
144 | def test(self, input):
145 | self.tImageA.resize_(input.size()).copy_(input)
146 | self.vImageA = Variable(self.tImageA)
147 | self.output = self.deeplabPart3(self.deeplabPart2(self.deeplabPart1(self.vImageA)))
148 | return self.output
149 |
150 | def step_P(self):
151 | # Maintain pool5_B in this status
152 | self.pool5_B = self.deeplabPart2(self.deeplabPart1(self.vImageB))
153 | self.pool5_B_for_d1 = Variable(self.pool5_B.data)
154 |
155 | self.pool1_A = self.deeplabPart1(self.vImageA)
156 | self.pool5_A = self.deeplabPart2(self.pool1_A)
157 | self.predic_A = self.deeplabPart3(self.pool5_A)
158 | self.output = Variable(self.predic_A.data)
159 |
160 | self.loss_P = self.criterionCE(self.predic_A, self.vLabelA) / self.nb
161 | self.loss_P.backward()
162 |
163 | self.pool1_A = Variable(self.pool1_A.data)
164 | self.pool5_A = Variable(self.pool5_A.data)
165 |
166 |
167 | def step_G1(self):
168 | self.pool5_A = self.pool5_A + self.netG1(self.pool1_A)
169 | pred_fake = self.netD1.forward(self.pool5_A)
170 |
171 | self.loss_G1 = self.criterionAdv(pred_fake, True)
172 | self.loss_G1.backward()
173 |
174 | self.pool5_A = Variable(self.pool5_A.data)
175 |
176 | def step_D1(self):
177 | pred_real = self.netD1.forward(self.pool5_B_for_d1)
178 | loss_D1_real = self.criterionAdv(pred_real, True)
179 |
180 | pred_fake = self.netD1.forward(self.pool5_A)
181 | loss_D1_fake = self.criterionAdv(pred_fake, False)
182 |
183 | self.loss_D1 = (loss_D1_real + loss_D1_fake) * 0.5
184 | self.loss_D1.backward()
185 |
186 | def step_G2(self):
187 | self.predic_B = self.deeplabPart3(self.pool5_B)
188 | pred_fake = self.netD2.forward(self.predic_B)
189 |
190 | self.loss_G2 = self.criterionAdv(pred_fake, True)
191 | self.loss_G2.backward()
192 |
193 | def step_D2(self):
194 | #self.vOnehotLabelA = Variable(self.vOnehotLabelA.data)
195 | pred_real = self.netD2.forward(self.vOnehotLabelA)
196 | loss_D2_real = self.criterionAdv(pred_real, True)
197 |
198 | self.predic_B = Variable(self.predic_B.data)
199 | pred_fake = self.netD2.forward(self.predic_B)
200 | loss_D2_fake = self.criterionAdv(pred_fake, False)
201 |
202 | self.loss_D2 = (loss_D2_real + loss_D2_fake) * 0.5
203 |
204 | self.loss_D2.backward()
205 |
206 | def step_R(self):
207 | pool1 = self.deeplabPart1(self.vImageA)
208 | self.predic_A_R = self.deeplabPart3(self.deeplabPart2(pool1) + self.netG1(pool1))
209 | self.loss_R = self.criterionCE(self.predic_A_R, self.vLabelA) / self.nb
210 |
211 | self.loss_R.backward()
212 |
213 | def step(self):
214 | self.Iter += 1
215 | # deeplab
216 | self.optimizer_P.zero_grad()
217 | self.step_P()
218 | self.optimizer_P.step()
219 |
220 | # G1
221 | self.optimizer_G1.zero_grad()
222 | self.step_G1()
223 | self.optimizer_G1.step()
224 | # D1
225 | self.optimizer_D1.zero_grad()
226 | self.step_D1()
227 | self.optimizer_D1.step()
228 | if self.Iter % self.interval_g2 == 0 and self.if_adv_train:
229 | # G2
230 | self.optimizer_G2.zero_grad()
231 | self.step_G2()
232 | self.optimizer_G2.step()
233 | if self.Iter % self.interval_d2 == 0 and self.if_adv_train:
234 | # D2
235 | self.optimizer_D2.zero_grad()
236 | self.step_D2()
237 | self.optimizer_D2.step()
238 |
239 | # Refine
240 | self.optimizer_R.zero_grad()
241 | self.step_R()
242 | self.optimizer_R.step()
243 |
244 |
245 | def get_current_visuals(self):
246 | return self.input
247 |
248 | def get_current_errors(self):
249 | return {}
250 |
251 |
252 | def save(self, model_name, Iter=None, epoch=None, acc=None):
253 | save_filename = '%s_model.pth' % (model_name)
254 | save_path = os.path.join(self.save_dir, save_filename)
255 | torch.save({
256 | 'name':self.name(),
257 | 'Iter': Iter,
258 | 'epoch': epoch,
259 | 'acc':acc,
260 | 'state_dict_netG1': self.netG1.state_dict(),
261 | 'state_dict_netD1': self.netD1.state_dict(),
262 | 'state_dict_netD2': self.netD2.state_dict(),
263 | 'state_dict_deeplabPart1': self.deeplabPart1.state_dict(),
264 | 'state_dict_deeplabPart2':self.deeplabPart2.state_dict(),
265 | 'state_dict_deeplabPart3': self.deeplabPart3.state_dict(),
266 | 'optimizer_P':self.optimizer_P.state_dict(),
267 | 'optimizer_R': self.optimizer_R.state_dict(),
268 | 'optimizer_G1': self.optimizer_G1.state_dict(),
269 | 'optimizer_D1': self.optimizer_D1.state_dict(),
270 | 'optimizer_G2': self.optimizer_G2.state_dict(),
271 | 'optimizer_D2': self.optimizer_D2.state_dict(),
272 | }, save_path)
273 |
274 | def load(self, load_path):
275 | checkpoint = torch.load(load_path)
276 | self.netG1.load_state_dict(checkpoint['state_dict_netG1'])
277 | self.netD1.load_state_dict(checkpoint['state_dict_netD1'])
278 | self.netD2.load_state_dict(checkpoint['state_dict_netD2'])
279 | self.deeplabPart1.load_state_dict(checkpoint['state_dict_deeplabPart1'])
280 | self.deeplabPart2.load_state_dict(checkpoint['state_dict_deeplabPart2'])
281 | self.deeplabPart3.load_state_dict(checkpoint['state_dict_deeplabPart3'])
282 |
283 | self.optimizer_P.load_state_dict(checkpoint['optimizer_P'])
284 | self.optimizer_G1.load_state_dict(checkpoint['optimizer_G1'])
285 | self.optimizer_D1.load_state_dict(checkpoint['optimizer_D1'])
286 | self.optimizer_G2.load_state_dict(checkpoint['optimizer_G2'])
287 | self.optimizer_D2.load_state_dict(checkpoint['optimizer_D2'])
288 | self.optimizer_R.load_state_dict(checkpoint['optimizer_R'])
289 | for k,v in checkpoint['acc'].items():
290 | print('=================================================')
291 | if k == 'acc_Ori_on_B':
292 | best_f1 = v['avg_f1score']
293 | print('accuracy: {0:.4f}\t'
294 | 'fg_accuracy: {1:.4f}\t'
295 | 'avg_precision: {2:.4f}\t'
296 | 'avg_recall: {3:.4f}\t'
297 | 'avg_f1score: {4:.4f}\t'
298 | .format(v['accuracy'],v['fg_accuracy'],v['avg_precision'], v['avg_recall'], v['avg_f1score']))
299 | print('=================================================')
300 |
301 | return checkpoint['Iter'], checkpoint['epoch'], best_f1
302 |
303 | # helper loading function that can be used by subclasses
304 | def load_network(self, network, network_label, epoch_label):
305 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
306 | save_path = os.path.join(self.save_dir, save_filename)
307 | network.load_state_dict(torch.load(save_path))
308 |
309 | def update_learning_rate(self):
310 | for param_group in self.optimizer_D1.param_groups:
311 | param_group['lr'] = param_group['lr'] * 0.1
312 | for param_group in self.optimizer_G1.param_groups:
313 | param_group['lr'] = param_group['lr'] * 0.1
314 |
315 | for param_group in self.optimizer_D2.param_groups:
316 | param_group['lr'] = param_group['lr'] * 0.1
317 | for param_group in self.optimizer_G2.param_groups:
318 | param_group['lr'] = param_group['lr'] * 0.1
319 |
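    # Note: update_learning_rate above decays only the four GAN optimizers
    # (D1/G1/D2/G2) by a factor of 10; optimizer_P and optimizer_R keep their
    # original learning rates. For example (with `model` an instance of this class):
    #   model.update_learning_rate()   # lr_D1, lr_G1, lr_D2, lr_G2 *= 0.1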
320 | def train(self):
321 | self.deeplabPart1.train()
322 | self.deeplabPart2.train()
323 | self.deeplabPart3.train()
324 | self.netG1.train()
325 | self.netD1.train()
326 | self.netD2.train()
327 |
328 | def eval(self):
329 | self.deeplabPart1.eval()
330 | self.deeplabPart2.eval()
331 | self.deeplabPart3.eval()
332 | self.netG1.eval()
333 | self.netD1.eval()
334 | self.netD2.eval()
335 |
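# A minimal checkpoint round-trip sketch (illustrative only; `m` is a built
# model instance and the metrics dict mirrors the layout load() expects):
#   metrics = {'acc_Ori_on_B': {'accuracy': 0.9, 'fg_accuracy': 0.8,
#                               'avg_precision': 0.7, 'avg_recall': 0.7,
#                               'avg_f1score': 0.7}}
#   m.save('flgan', Iter=1000, epoch=5, acc=metrics)
#   it, ep, best_f1 = m.load(os.path.join(m.save_dir, 'flgan_model.pth'))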
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 | from torch.autograd import Variable
6 | import random
7 | ###############################################################################
8 | # Functions
9 | ###############################################################################
10 | def print_network(net):
11 | num_params = 0
12 | for param in net.parameters():
13 | num_params += param.numel()
14 | print(net)
15 | print('Total number of parameters: %d' % num_params)
16 |
17 | def weights_init(m):
18 | classname = m.__class__.__name__
19 | if classname.find('Conv') != -1:
20 | m.weight.data.normal_(0.0, 0.02)
21 | elif classname.find('BatchNorm2d') != -1:
22 | m.weight.data.normal_(1.0, 0.02)
23 | m.bias.data.fill_(0)
24 |
25 | # Defines the GAN loss, which uses either LSGAN or the regular GAN objective.
26 | # When LSGAN is used, it is basically the same as MSELoss,
27 | # but it abstracts away the need to create the target label tensor
28 | # that has the same size as the input.
29 | class Advloss(nn.Module):
30 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
31 | tensor=torch.FloatTensor):
32 | super(Advloss, self).__init__()
33 | self.real_label = target_real_label
34 | self.fake_label = target_fake_label
35 | self.real_label_var = None
36 | self.fake_label_var = None
37 | self.Tensor = tensor
38 | if use_lsgan:
39 | self.loss = nn.MSELoss()
40 | else:
41 | self.loss = nn.BCELoss()
42 |
43 | def get_target_tensor(self, input, target_is_real):
44 | target_tensor = None
45 | if target_is_real:
46 | create_label = ((self.real_label_var is None) or
47 | (self.real_label_var.numel() != input.numel()))
48 | if create_label:
49 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
50 | self.real_label_var = Variable(real_tensor, requires_grad=False)
51 | target_tensor = self.real_label_var
52 | else:
53 | create_label = ((self.fake_label_var is None) or
54 | (self.fake_label_var.numel() != input.numel()))
55 | if create_label:
56 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
57 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
58 | target_tensor = self.fake_label_var
59 | return target_tensor
60 |
61 | def __call__(self, input, target_is_real):
62 | target_tensor = self.get_target_tensor(input, target_is_real)
63 | return self.loss(input, target_tensor)
64 |
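# A minimal Advloss usage sketch (`pred_real`/`pred_fake` are placeholder
# discriminator outputs): with use_lsgan=True the loss is MSE against a map
# of ones for "real" and zeros for "fake".
#   adv = Advloss(use_lsgan=True)
#   loss_D = 0.5 * (adv(pred_real, True) + adv(pred_fake, False))
#   loss_G = adv(pred_fake, True)   # generator is rewarded for looking real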
65 | class Deeplab(nn.Module):
66 | def __init__(self, size=(241,121)):
67 | super(Deeplab, self).__init__()
68 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
69 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
70 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
71 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
72 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
73 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
74 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
75 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
76 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
77 | self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
78 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
79 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
80 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
81 | self.pool4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
82 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
83 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
84 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
85 | self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
86 |
87 | self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6)
88 | self.fc7_1 = nn.Conv2d(1024, 1024, 1)
89 | self.fc8_1 = nn.Conv2d(1024, 12, 1)
90 |
91 | self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12)
92 | self.fc7_2 = nn.Conv2d(1024, 1024, 1)
93 | self.fc8_2 = nn.Conv2d(1024, 12, 1)
94 |
95 | self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18)
96 | self.fc7_3 = nn.Conv2d(1024, 1024, 1)
97 | self.fc8_3 = nn.Conv2d(1024, 12, 1)
98 |
99 | self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24)
100 | self.fc7_4 = nn.Conv2d(1024, 1024, 1)
101 | self.fc8_4 = nn.Conv2d(1024, 12, 1)
102 |
103 | #self.fc8_interp = nn.Upsample(scale_factor=8,mode='bilinear')
104 | self.dropout = nn.Dropout2d(0.5)
105 | self.relu = nn.ReLU(inplace=True)
106 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear')
107 |
108 | def weights_init(self, pretrained_dict={}):
109 | init.normal(self.fc8_1.weight.data, mean=0, std=0.01)
110 | init.constant(self.fc8_1.bias.data, 0)
111 | init.normal(self.fc8_2.weight.data, mean=0, std=0.01)
112 | init.constant(self.fc8_2.bias.data, 0)
113 | init.normal(self.fc8_3.weight.data, mean=0, std=0.01)
114 | init.constant(self.fc8_3.bias.data, 0)
115 | init.normal(self.fc8_4.weight.data, mean=0, std=0.01)
116 | init.constant(self.fc8_4.bias.data, 0)
117 |
118 | model_dict = self.state_dict()
119 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
120 | model_dict.update(pretrained_dict)
121 | self.load_state_dict(model_dict)
122 |
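    # weights_init implements the usual partial-restore idiom: the fc8_* heads
    # are re-initialized with N(0, 0.01) weights and zero biases, then every
    # backbone key found in `pretrained_dict` is copied in. A hedged example,
    # with a hypothetical converted-caffemodel file name:
    #   net = Deeplab()
    #   net.weights_init(torch.load('vgg16_deeplab_init.pth'))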
123 | def forward(self, x):
124 | x = self.relu(self.conv1_1(x))
125 | x = self.pool1(self.relu(self.conv1_2(x)))
126 | x = self.relu(self.conv2_1(x))
127 | x = self.pool2(self.relu(self.conv2_2(x)))
128 | x = self.relu(self.conv3_1(x))
129 | x = self.relu(self.conv3_2(x))
130 | x = self.pool3(self.relu(self.conv3_3(x)))
131 | x = self.relu(self.conv4_1(x))
132 | x = self.relu(self.conv4_2(x))
133 | x = self.pool4(self.relu(self.conv4_3(x)))
134 | x = self.relu(self.conv5_1(x))
135 | x = self.relu(self.conv5_2(x))
136 | x = self.pool5(self.relu(self.conv5_3(x)))
137 |
138 |         x1 = self.dropout(self.relu(self.fc6_1(x)))
139 |         x1 = self.dropout(self.relu(self.fc7_1(x1)))
140 |         x1 = self.fc8_1(x1)
141 | 
142 |         x2 = self.dropout(self.relu(self.fc6_2(x)))
143 |         x2 = self.dropout(self.relu(self.fc7_2(x2)))
144 |         x2 = self.fc8_2(x2)
145 | 
146 |         x3 = self.dropout(self.relu(self.fc6_3(x)))
147 |         x3 = self.dropout(self.relu(self.fc7_3(x3)))
148 |         x3 = self.fc8_3(x3)
149 | 
150 |         x4 = self.dropout(self.relu(self.fc6_4(x)))
151 |         x4 = self.dropout(self.relu(self.fc7_4(x4)))
152 |         x4 = self.fc8_4(x4)
153 | x = self.fc8_interp(x1 + x2 + x3 + x4)
154 | return x
155 |
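# Shape walkthrough for Deeplab with the default size=(241, 121): the three
# stride-2 poolings shrink the input to 31x16 (pool4/pool5 keep stride 1),
# the four dilated fc6->fc8 branches (rates 6/12/18/24, DeepLab's ASPP
# pattern) each emit a 12-channel score map, and their sum is bilinearly
# upsampled back to the input resolution:
#   out = Deeplab()(Variable(torch.randn(1, 3, 241, 121)))   # (1, 12, 241, 121)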
156 | class DeeplabPool1(nn.Module):
157 | def __init__(self, size=(241,121)):
158 | super(DeeplabPool1, self).__init__()
159 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
160 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
161 | self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
162 | self.relu = nn.ReLU(inplace=True)
163 |
164 | def weights_init(self, pretrained_dict={}):
165 | model_dict = self.state_dict()
166 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
167 | model_dict.update(pretrained_dict)
168 | self.load_state_dict(model_dict)
169 |
170 | def forward(self, x):
171 | x = self.relu(self.conv1_1(x))
172 | x = self.pool1(self.relu(self.conv1_2(x)))
173 | return x
174 |
175 | class DeeplabPool12Conv5_1(nn.Module):
176 | def __init__(self, size=(241,121)):
177 | super(DeeplabPool12Conv5_1, self).__init__()
178 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
179 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
180 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
181 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
182 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
183 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
184 | self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
185 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
186 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
187 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
188 | self.pool4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
189 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
190 |
191 | self.relu = nn.ReLU()
192 |
193 | #self.fc8_interp = nn.Upsample(scale_factor=8,mode='bilinear')
194 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear')
195 |
196 | def weights_init(self, pretrained_dict={}):
197 | model_dict = self.state_dict()
198 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
199 | model_dict.update(pretrained_dict)
200 | self.load_state_dict(model_dict)
201 |
202 | def forward(self, x):
203 | x = self.relu(self.conv2_1(x))
204 | x = self.pool2(self.relu(self.conv2_2(x)))
205 | x = self.relu(self.conv3_1(x))
206 | x = self.relu(self.conv3_2(x))
207 | x = self.pool3(self.relu(self.conv3_3(x)))
208 | x = self.relu(self.conv4_1(x))
209 | x = self.relu(self.conv4_2(x))
210 | x = self.pool4(self.relu(self.conv4_3(x)))
211 | x = self.relu(self.conv5_1(x))
212 |
213 | return x
214 |
215 | class DeeplabConv5_22Fc8_interp(nn.Module):
216 | def __init__(self, size=(241,121)):
217 | super(DeeplabConv5_22Fc8_interp, self).__init__()
218 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
219 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
220 | self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
221 |
222 | self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6)
223 | self.fc7_1 = nn.Conv2d(1024, 1024, 1)
224 | self.fc8_1 = nn.Conv2d(1024, 12, 1)
225 |
226 | self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12)
227 | self.fc7_2 = nn.Conv2d(1024, 1024, 1)
228 | self.fc8_2 = nn.Conv2d(1024, 12, 1)
229 |
230 | self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18)
231 | self.fc7_3 = nn.Conv2d(1024, 1024, 1)
232 | self.fc8_3 = nn.Conv2d(1024, 12, 1)
233 |
234 | self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24)
235 | self.fc7_4 = nn.Conv2d(1024, 1024, 1)
236 | self.fc8_4 = nn.Conv2d(1024, 12, 1)
237 | self.dropout = nn.Dropout2d(0.5)
238 | self.relu = nn.ReLU(inplace=True)
239 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear')
240 |
241 | def weights_init(self, pretrained_dict={}):
242 | init.normal(self.fc8_1.weight.data, mean=0, std=0.01)
243 | init.constant(self.fc8_1.bias.data, 0)
244 | init.normal(self.fc8_2.weight.data, mean=0, std=0.01)
245 | init.constant(self.fc8_2.bias.data, 0)
246 | init.normal(self.fc8_3.weight.data, mean=0, std=0.01)
247 | init.constant(self.fc8_3.bias.data, 0)
248 | init.normal(self.fc8_4.weight.data, mean=0, std=0.01)
249 | init.constant(self.fc8_4.bias.data, 0)
250 |
251 | model_dict = self.state_dict()
252 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
253 | model_dict.update(pretrained_dict)
254 | self.load_state_dict(model_dict)
255 |
256 | def forward(self, x):
257 | x = self.relu(self.conv5_2(x))
258 | x = self.pool5(self.relu(self.conv5_3(x)))
259 |
260 | x1 = self.dropout(self.relu(self.fc6_1(x)))
261 | x1 = self.dropout(self.relu(self.fc7_1(x1)))
262 | x1 = self.fc8_1(x1)
263 |
264 | x2 = self.dropout(self.relu(self.fc6_2(x)))
265 | x2 = self.dropout(self.relu(self.fc7_2(x2)))
266 | x2 = self.fc8_2(x2)
267 |
268 | x3 = self.dropout(self.relu(self.fc6_3(x)))
269 | x3 = self.dropout(self.relu(self.fc7_3(x3)))
270 | x3 = self.fc8_3(x3)
271 |
272 | x4 = self.dropout(self.relu(self.fc6_4(x)))
273 | x4 = self.dropout(self.relu(self.fc7_4(x4)))
274 | x4 = self.fc8_4(x4)
275 | x = self.fc8_interp(x1 + x2 + x3 + x4)
276 | return x
277 |
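# DeeplabPool1, DeeplabPool12Conv5_1 and DeeplabConv5_22Fc8_interp are the
# Deeplab above cut at pool1 and conv5_1, so the adaptation model can tap the
# intermediate feature maps; chaining the three parts reproduces the
# monolithic network:
#   p1, p2, p3 = DeeplabPool1(), DeeplabPool12Conv5_1(), DeeplabConv5_22Fc8_interp()
#   out = p3(p2(p1(x)))   # same output shape as Deeplab()(x)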
278 | class DeeplabPool12Pool5(nn.Module):
279 | def __init__(self, size=(241,121)):
280 | super(DeeplabPool12Pool5, self).__init__()
281 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
282 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
283 | self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
284 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
285 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
286 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
287 | self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
288 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
289 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
290 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
291 | self.pool4 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
292 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
293 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
294 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=2, dilation=2)
295 | self.pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
296 |
297 | self.relu = nn.ReLU()
298 |
299 | #self.fc8_interp = nn.Upsample(scale_factor=8,mode='bilinear')
300 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear')
301 |
302 | def weights_init(self, pretrained_dict={}):
303 | model_dict = self.state_dict()
304 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
305 | model_dict.update(pretrained_dict)
306 | self.load_state_dict(model_dict)
307 |
308 | def forward(self, x):
309 | x = self.relu(self.conv2_1(x))
310 | x = self.pool2(self.relu(self.conv2_2(x)))
311 | x = self.relu(self.conv3_1(x))
312 | x = self.relu(self.conv3_2(x))
313 | x = self.pool3(self.relu(self.conv3_3(x)))
314 | x = self.relu(self.conv4_1(x))
315 | x = self.relu(self.conv4_2(x))
316 | x = self.pool4(self.relu(self.conv4_3(x)))
317 | x = self.relu(self.conv5_1(x))
318 | x = self.relu(self.conv5_2(x))
319 | x = self.pool5(self.relu(self.conv5_3(x)))
320 | return x
321 |
322 | class DeeplabPool52Fc8_interp(nn.Module):
323 | def __init__(self, output_nc, size=(241,121)):
324 | super(DeeplabPool52Fc8_interp, self).__init__()
325 |
326 | self.fc6_1 = nn.Conv2d(512, 1024, 3, padding=6, dilation=6)
327 | self.fc7_1 = nn.Conv2d(1024, 1024, 1)
328 | self.fc8_1 = nn.Conv2d(1024, output_nc, 1)
329 |
330 | self.fc6_2 = nn.Conv2d(512, 1024, 3, padding=12, dilation=12)
331 | self.fc7_2 = nn.Conv2d(1024, 1024, 1)
332 | self.fc8_2 = nn.Conv2d(1024, output_nc, 1)
333 |
334 | self.fc6_3 = nn.Conv2d(512, 1024, 3, padding=18, dilation=18)
335 | self.fc7_3 = nn.Conv2d(1024, 1024, 1)
336 | self.fc8_3 = nn.Conv2d(1024, output_nc, 1)
337 |
338 | self.fc6_4 = nn.Conv2d(512, 1024, 3, padding=24, dilation=24)
339 | self.fc7_4 = nn.Conv2d(1024, 1024, 1)
340 | self.fc8_4 = nn.Conv2d(1024, output_nc, 1)
341 | self.dropout = nn.Dropout2d(0.5)
342 | self.relu = nn.ReLU(inplace=True)
343 | self.fc8_interp = nn.Upsample(size=size, mode='bilinear')
344 |
345 | def weights_init(self, pretrained_dict={}):
346 | init.normal(self.fc8_1.weight.data, mean=0, std=0.01)
347 | init.constant(self.fc8_1.bias.data, 0)
348 | init.normal(self.fc8_2.weight.data, mean=0, std=0.01)
349 | init.constant(self.fc8_2.bias.data, 0)
350 | init.normal(self.fc8_3.weight.data, mean=0, std=0.01)
351 | init.constant(self.fc8_3.bias.data, 0)
352 | init.normal(self.fc8_4.weight.data, mean=0, std=0.01)
353 | init.constant(self.fc8_4.bias.data, 0)
354 |
355 | model_dict = self.state_dict()
356 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
357 | model_dict.update(pretrained_dict)
358 | self.load_state_dict(model_dict)
359 |
360 | def forward(self, x):
361 | x1 = self.dropout(self.relu(self.fc6_1(x)))
362 | x1 = self.dropout(self.relu(self.fc7_1(x1)))
363 | x1 = self.fc8_1(x1)
364 |
365 | x2 = self.dropout(self.relu(self.fc6_2(x)))
366 | x2 = self.dropout(self.relu(self.fc7_2(x2)))
367 | x2 = self.fc8_2(x2)
368 |
369 | x3 = self.dropout(self.relu(self.fc6_3(x)))
370 | x3 = self.dropout(self.relu(self.fc7_3(x3)))
371 | x3 = self.fc8_3(x3)
372 |
373 | x4 = self.dropout(self.relu(self.fc6_4(x)))
374 | x4 = self.dropout(self.relu(self.fc7_4(x4)))
375 | x4 = self.fc8_4(x4)
376 | x = self.fc8_interp(x1 + x2 + x3 + x4)
377 | return x
378 |
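# DeeplabPool12Pool5 and DeeplabPool52Fc8_interp form an alternative split of
# the same backbone at pool5, with output_nc exposing the number of classes;
# composed with DeeplabPool1 they again match the full network:
#   feat = DeeplabPool12Pool5()(DeeplabPool1()(x))
#   seg = DeeplabPool52Fc8_interp(output_nc=12)(feat)   # same shape as Deeplab()(x)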
379 | class netG(nn.Module):
380 | def __init__(self, n_blocks=6):
381 | super(netG, self).__init__()
382 | input_nc = 64
383 | ngf = 128
384 | norm_layer = nn.BatchNorm2d
385 | padding_type = 'reflect'
386 | use_dropout = 0
387 |
388 | mult = 1
389 | model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), norm_layer(ngf), nn.ReLU(True)]
390 |
391 | for i in range(n_blocks):
392 | if (i+1) % 3 == 0:
393 | model += [nn.MaxPool2d(kernel_size=3, stride=1, padding=1), nn.Conv2d(ngf*mult, ngf*mult*2, kernel_size=3, stride=2,padding=1)]
394 | mult *= 2
395 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]
396 |
397 | self.model = nn.Sequential(*model)
398 |
399 |
400 | def forward(self, x):
401 | return self.model(x)
402 |
403 | class netG_structure(nn.Module):
404 | def __init__(self, input_nc=512, output_nc=12, n_blocks=3, size=(241, 121)):
405 | super(netG_structure, self).__init__()
406 | ngf = 128
407 | norm_layer = nn.BatchNorm2d
408 | padding_type = 'reflect'
409 | use_dropout = 0
410 |
411 | model = [nn.Conv2d(input_nc, ngf, kernel_size=7, padding=3), norm_layer(ngf), nn.ReLU(True)]
412 |
413 | for i in range(n_blocks):
414 | model += [
415 | ResnetBlock(ngf, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout)]
416 |
417 | model += [nn.Conv2d(ngf, output_nc, kernel_size=3, padding=1), nn.Upsample(size=size, mode='bilinear')]
418 | self.model = nn.Sequential(*model)
419 |
420 | def forward(self, x):
421 | return self.model(x)
422 |
423 | # Define a resnet block
424 | class ResnetBlock(nn.Module):
425 | def __init__(self, dim, padding_type, norm_layer, use_dropout):
426 | super(ResnetBlock, self).__init__()
427 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout)
428 |
429 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout):
430 | conv_block = []
431 | p = 0
432 | if padding_type == 'reflect':
433 | conv_block += [nn.ReflectionPad2d(1)]
434 | elif padding_type == 'replicate':
435 | conv_block += [nn.ReplicationPad2d(1)]
436 | elif padding_type == 'zero':
437 | p = 1
438 | else:
439 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
440 |
441 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
442 | norm_layer(dim),
443 | nn.ReLU(True)]
444 | if use_dropout:
445 | conv_block += [nn.Dropout(0.5)]
446 |
447 | p = 0
448 | if padding_type == 'reflect':
449 | conv_block += [nn.ReflectionPad2d(1)]
450 | elif padding_type == 'replicate':
451 | conv_block += [nn.ReplicationPad2d(1)]
452 | elif padding_type == 'zero':
453 | p = 1
454 | else:
455 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
456 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
457 | norm_layer(dim)]
458 |
459 | return nn.Sequential(*conv_block)
460 |
461 | def forward(self, x):
462 | out = x + self.conv_block(x)
463 | return out
464 |
465 | class MultPathdilationNet(nn.Module):
466 | def __init__(self):
467 | super(MultPathdilationNet, self).__init__()
468 | input_nc = 512
469 | ngf = 128
470 | norm_layer = nn.InstanceNorm2d
471 | padding_type = 'reflect'
472 | use_dropout = 0
473 | self.relu = nn.ReLU(inplace=True)
474 |
475 | model_1 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
476 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
477 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
478 | model_2 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
479 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
480 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
481 | model_3 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
482 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
483 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
484 | model_4 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
485 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
486 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
487 |
488 | self.model_1 = nn.Sequential(*model_1)
489 | self.model_2 = nn.Sequential(*model_2)
490 | self.model_3 = nn.Sequential(*model_3)
491 | self.model_4 = nn.Sequential(*model_4)
492 |
493 |
494 | def forward(self, x):
495 | return ( self.model_1(x) + self.model_2(x) + self.model_3(x) + self.model_4(x) ) / 4
496 |
497 | class RandomMultPathdilationNet(nn.Module):
498 | def __init__(self):
499 | super(RandomMultPathdilationNet, self).__init__()
500 | input_nc = 512
501 | ngf = 128
502 | norm_layer = nn.InstanceNorm2d
503 | padding_type = 'reflect'
504 | use_dropout = 0
505 | self.relu = nn.ReLU(inplace=True)
506 |
507 | model_1 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
508 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
509 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
510 | model_2 = [nn.Conv2d(512, 1024, 3, padding=12, dilation=12), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
511 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
512 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
513 | model_3 = [nn.Conv2d(512, 1024, 3, padding=18, dilation=18), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
514 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
515 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
516 | model_4 = [nn.Conv2d(512, 1024, 3, padding=24, dilation=24), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
517 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), self.relu, nn.Dropout2d(0.5),
518 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
519 |
520 | self.model_1 = nn.Sequential(*model_1)
521 | self.model_2 = nn.Sequential(*model_2)
522 | self.model_3 = nn.Sequential(*model_3)
523 | self.model_4 = nn.Sequential(*model_4)
524 |
525 |
526 | def forward(self, x):
527 | which_D = random.uniform(0,1)
528 | if which_D < 0.25:
529 | return self.model_1(x)
530 | elif which_D < 0.5:
531 | return self.model_2(x)
532 | elif which_D < 0.75:
533 | return self.model_3(x)
534 | else:
535 | return self.model_4(x)
536 |
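# Where MultPathdilationNet averages its four discriminator heads,
# RandomMultPathdilationNet samples which_D ~ U(0, 1) and routes each forward
# pass through exactly one head (dilations 6/12/18/24), a cheap form of
# discriminator ensembling:
#   d = RandomMultPathdilationNet()
#   score = d(feat)   # feat: (N, 512, 31, 16) -> one (N, 1, 8, 4) score map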
537 | class NoBNMultPathdilationNet(nn.Module):
538 | def __init__(self):
539 | super(NoBNMultPathdilationNet, self).__init__()
540 | input_nc = 512
541 | ngf = 128
542 | norm_layer = nn.InstanceNorm2d
543 | padding_type = 'reflect'
544 | use_dropout = 0
545 | self.relu = nn.ReLU(inplace=True)
546 |
547 | model_1 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), self.relu, nn.Dropout2d(0.5),
548 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), self.relu, nn.Dropout2d(0.5),
549 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
550 | model_2 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), self.relu, nn.Dropout2d(0.5),
551 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), self.relu, nn.Dropout2d(0.5),
552 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
553 | model_3 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), self.relu, nn.Dropout2d(0.5),
554 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), self.relu, nn.Dropout2d(0.5),
555 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
556 | model_4 = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), self.relu, nn.Dropout2d(0.5),
557 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), self.relu, nn.Dropout2d(0.5),
558 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
559 |
560 | self.model_1 = nn.Sequential(*model_1)
561 | self.model_2 = nn.Sequential(*model_2)
562 | self.model_3 = nn.Sequential(*model_3)
563 | self.model_4 = nn.Sequential(*model_4)
564 |
565 |
566 | def forward(self, x):
567 | return ( self.model_1(x) + self.model_2(x) + self.model_3(x) + self.model_4(x) ) / 4
568 |
569 | class FFCFeature(nn.Module):
570 | def __init__(self):
571 | super(FFCFeature, self).__init__()
572 | self.classifier = nn.Sequential(
573 | nn.Dropout(0.5),
574 | nn.Linear(512 * 31 * 16, 512),
575 | nn.BatchNorm1d(512),
576 | nn.ReLU(True),
577 | nn.Dropout(0.5),
578 | nn.Linear(512, 1024),
579 | nn.BatchNorm1d(1024),
580 | nn.ReLU(True),
581 | nn.Dropout(0.5),
582 | nn.Linear(1024, 1),
583 | nn.Sigmoid()
584 | )
585 |
586 | def forward(self, x):
587 | x = x.view(x.size(0), -1)
588 | x = self.classifier(x)
589 | return x
590 |
591 | class SinglePathdilationSingleOutputNet(nn.Module):
592 | def __init__(self):
593 | super(SinglePathdilationSingleOutputNet, self).__init__()
594 | model = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), nn.Dropout2d(0.5),
595 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1),
596 | nn.Dropout2d(0.5), nn.Conv2d(1024, 256, 3, stride=2, padding=1)]
597 | self.model = nn.Sequential(*model)
598 | self.linear = nn.Linear(256 * 8 * 4, 1)
599 |
600 | def forward(self, x):
601 | x = self.model(x)
602 | x = x.view(x.size(0), -1)
603 | x = nn.Sigmoid()(self.linear(x))
604 | return x
605 |
606 | class SinglePathdilationMultOutputNet(nn.Module):
607 | def __init__(self):
608 | super(SinglePathdilationMultOutputNet, self).__init__()
609 | input_nc = 512
610 | ngf = 128
611 | norm_layer = nn.BatchNorm2d
612 | padding_type = 'reflect'
613 | use_dropout = 0
614 |
615 | model = [nn.Conv2d(512, 1024, 3, padding=6, dilation=6), norm_layer(1024), nn.ReLU(inplace=True),nn.Dropout2d(0.5),
616 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), norm_layer(1024), nn.ReLU(inplace=True),nn.Dropout2d(0.5),
617 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
618 |
619 | self.model = nn.Sequential(*model)
620 |
621 |
622 | def forward(self, x):
623 | return self.model(x)
624 |
625 | class NoBNSinglePathdilationMultOutputNet(nn.Module):
626 | def __init__(self, input_nc = 512):
627 | super(NoBNSinglePathdilationMultOutputNet, self).__init__()
628 | padding_type = 'reflect'
629 | use_dropout = 0
630 |
631 | model = [nn.Conv2d(input_nc, 1024, 3, padding=6, dilation=6), nn.ReLU(inplace=True),nn.Dropout2d(0.5),
632 | nn.Conv2d(1024, 1024, 3, stride=2, padding=1), nn.ReLU(inplace=True),nn.Dropout2d(0.5),
633 | nn.Conv2d(1024, 1, 3, stride=2, padding=1)]
634 |
635 | self.model = nn.Sequential(*model)
636 |
637 |
638 | def forward(self, x):
639 | return self.model(x)
640 |
641 | class dcgan_D(nn.Module):
642 |     def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
643 |         super(dcgan_D, self).__init__()
644 |         self.input_nc = input_nc
645 |         self.ngf = ngf
646 |         self.norm_layer = norm_layer
647 |         self.n_layers = n_layers
648 |         self.padding_type = 'reflect'
649 | 
650 |         mult = 1
651 |         model = [nn.Conv2d(input_nc, ngf, 4, stride=2, padding=1), norm_layer(ngf), nn.ReLU(inplace=True)]
652 |         for i in range(self.n_layers-1):
653 |             model = model + [nn.Conv2d(ngf*mult, ngf*mult*2, 4, stride=2, padding=1), norm_layer(ngf*mult*2), nn.ReLU(inplace=True)]
654 |             mult *= 2
655 | 
656 |         model = model + [nn.Conv2d(ngf * mult, ngf, 4)]
657 |         self.model = nn.Sequential(*model)
658 |         self.fc = None  # built lazily on first forward so its weights persist across calls
659 | 
660 |     def forward(self, x):
661 |         x = self.model(x)
662 |         x = x.view(x.size(0), -1)
663 |         if self.fc is None: self.fc = nn.Linear(x.size(1), 1)
664 |         return nn.Sigmoid()(self.fc(x))
665 |
666 | class dcgan_D_multOut(nn.Module):
667 | def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
668 | super(dcgan_D_multOut, self).__init__()
669 | self.input_nc = input_nc
670 | self.ngf = ngf
671 | self.norm_layer = norm_layer
672 | self.n_layers = n_layers
673 | self.padding_type = 'reflect'
674 |
675 | mult = 1
676 | model = [nn.Conv2d(input_nc, ngf, 4, stride=2, padding=1), norm_layer(ngf), nn.ReLU(inplace=True)]
677 | for i in range(self.n_layers-1):
678 | model = model + [nn.Conv2d(ngf*mult, ngf*mult*2, 4, stride=2, padding=1), norm_layer(ngf*mult*2), nn.ReLU(inplace=True)]
679 | mult *= 2
680 |
681 | model = model + [nn.Conv2d(ngf * mult, 1, 4)]
682 | self.model = nn.Sequential(*model)
683 |
684 |
685 | def forward(self, x):
686 |
687 | return self.model(x)
688 |
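# dcgan_D collapses to a single sigmoid score per image, while dcgan_D_multOut
# keeps a spatial grid of patch scores (PatchGAN-style) that Advloss can match
# against a same-sized target map:
#   d_out = dcgan_D_multOut()(seg)   # seg: (N, 12, 241, 121) -> (N, 1, 13, 5)
#   loss = Advloss(use_lsgan=True)(d_out, True)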
689 | class lsgan_D(nn.Module):
690 | def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
691 | super(lsgan_D, self).__init__()
692 | self.input_nc = input_nc
693 | self.ngf = ngf
694 | self.norm_layer = norm_layer
695 | self.n_layers = n_layers
696 |
697 | mult = 1
698 | features = [nn.Conv2d(input_nc, ngf, 5, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
699 | for i in range(self.n_layers-1):
700 | features = features + [nn.Conv2d(ngf*mult, ngf*mult*2, 5, stride=2, padding=1), norm_layer(ngf*mult*2), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
701 | mult *= 2
702 |
703 | self.features = nn.Sequential(*features)
704 |
705 | self.fc = nn.Sequential(nn.Linear(512 * 14 * 6, 1))
706 |
707 |
708 | def forward(self, x):
709 |         x = self.features(x)
710 | x = x.view(x.size(0), -1)
711 | x = self.fc(x)
712 | return x
713 |
714 | class lsganMultOutput_D(nn.Module):
715 | def __init__(self, input_nc=12, ngf=64, norm_layer=nn.BatchNorm2d, n_layers=4):
716 | super(lsganMultOutput_D, self).__init__()
717 | self.input_nc = input_nc
718 | self.ngf = ngf
719 | self.norm_layer = norm_layer
720 | self.n_layers = n_layers
721 |
722 | mult = 1
723 | features = [nn.Conv2d(input_nc, ngf, 5, stride=2, padding=1), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
724 | for i in range(self.n_layers-1):
725 | features = features + [nn.Conv2d(ngf*mult, ngf*mult*2, 5, stride=2, padding=1), norm_layer(ngf*mult*2), nn.LeakyReLU(negative_slope=0.2, inplace=True)]
726 | mult *= 2
727 |
728 | features += [nn.Conv2d(ngf*mult, 1, 5)]
729 | self.features = nn.Sequential(*features)
730 |
731 |
732 |
733 | def forward(self, x):
734 |         return self.features(x)
735 |
736 |
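# A quick shape-check harness (an illustrative sketch; sizes follow the
# defaults above and the pre-0.4 Variable API used throughout this file):
if __name__ == '__main__':
    x = Variable(torch.randn(1, 3, 241, 121))
    feat = DeeplabPool12Conv5_1()(DeeplabPool1()(x))        # (1, 512, 31, 16)
    seg = DeeplabConv5_22Fc8_interp()(feat)                 # (1, 12, 241, 121)
    score = NoBNSinglePathdilationMultOutputNet()(feat)     # (1, 1, 8, 4)
    print(seg.size(), score.size())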
--------------------------------------------------------------------------------