.
├── MECNet
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── config.cpython-37.pyc
│   │   ├── dataset.cpython-37.pyc
│   │   ├── edge_connect.cpython-37.pyc
│   │   ├── loss.cpython-36.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── metrics.cpython-37.pyc
│   │   ├── models.cpython-36.pyc
│   │   ├── models.cpython-37.pyc
│   │   ├── network3.cpython-36.pyc
│   │   ├── network3.cpython-37.pyc
│   │   ├── networks.cpython-36.pyc
│   │   ├── networks.cpython-37.pyc
│   │   ├── networks2.cpython-36.pyc
│   │   ├── networks2.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── config.py
│   ├── dataset.py
│   ├── edge_connect.py
│   ├── loss.py
│   ├── metrics.py
│   ├── models.py
│   ├── networks2.py
│   └── utils.py
├── README.md
├── config.yml
├── data
│   ├── __pycache__
│   │   ├── basicFunction.cpython-36.pyc
│   │   ├── basicFunction.cpython-37.pyc
│   │   ├── dataloader.cpython-36.pyc
│   │   ├── dataloader.cpython-37.pyc
│   │   ├── dataloader_canny.cpython-36.pyc
│   │   └── dataloader_canny.cpython-37.pyc
│   ├── basicFunction.py
│   ├── dataloader.py
│   └── dataloader_canny.py
├── examples
│   ├── GT28-1.png
│   ├── MEDFE28-1.png
│   ├── ec28-1.png
│   ├── edge_mecnet(s)_1.png
│   ├── edge_mecnet_1.png
│   ├── gc28-1.png
│   ├── gl28-1.png
│   ├── input1.png
│   ├── input28-1.png
│   ├── ours28-1.png
│   └── pconv28-1.png
├── loss
│   ├── InpaintingLoss.py
│   └── __pycache__
│       ├── InpaintingLoss.cpython-36.pyc
│       └── InpaintingLoss.cpython-37.pyc
├── models
│   ├── ActivationFunction.py
│   ├── EdgeAttentionLayer.py
│   ├── LBAMModel.py
│   ├── __pycache__
│   │   ├── ActivationFunction.cpython-36.pyc
│   │   ├── ActivationFunction.cpython-37.pyc
│   │   ├── EdgeAttentionLayer.cpython-36.pyc
│   │   ├── EdgeAttentionLayer.cpython-37.pyc
│   │   ├── LBAMModel.cpython-36.pyc
│   │   ├── LBAMModel.cpython-37.pyc
│   │   ├── discriminator.cpython-36.pyc
│   │   ├── discriminator.cpython-37.pyc
│   │   ├── forwardAttentionLayer.cpython-36.pyc
│   │   ├── forwardAttentionLayer.cpython-37.pyc
│   │   ├── reverseAttentionLayer.cpython-36.pyc
│   │   ├── reverseAttentionLayer.cpython-37.pyc
│   │   ├── weightInitial.cpython-36.pyc
│   │   └── weightInitial.cpython-37.pyc
│   ├── discriminator.py
│   ├── forwardAttentionLayer.py
│   ├── reverseAttentionLayer.py
│   └── weightInitial.py
├── pytorch_ssim
│   ├── __init__.py
│   └── __pycache__
│       ├── __init__.cpython-36.pyc
│       └── __init__.cpython-37.pyc
├── test_random_batch.py
└── train.py
/MECNet/__init__.py:
--------------------------------------------------------------------------------
1 | # empty
--------------------------------------------------------------------------------
/MECNet/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/config.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/config.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/edge_connect.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/edge_connect.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/loss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/loss.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/metrics.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/metrics.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/models.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/models.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/network3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/network3.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/network3.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/network3.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/networks2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/networks2.cpython-36.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/networks2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/networks2.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/MECNet/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/MECNet/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 | class Config(dict):
5 | def __init__(self, config_path):
6 | with open(config_path, 'r') as f:
7 | self._yaml = f.read()
8 |             self._dict = yaml.safe_load(self._yaml)  # safe_load: yaml.load without an explicit Loader is deprecated and unsafe
9 |             self._dict['PATH'] = "~/LBAM_GRU_version2/checkpoints/psv"  # hard-coded override of any PATH set in the yaml
10 |
11 | def __getattr__(self, name):
12 | if self._dict.get(name) is not None:
13 | return self._dict[name]
14 |
15 | if DEFAULT_CONFIG.get(name) is not None:
16 | return DEFAULT_CONFIG[name]
17 |
18 | return None
19 |
20 | def print(self):
21 | print('Model configurations:')
22 | print('---------------------------------')
23 | print(self._yaml)
24 | print('')
25 | print('---------------------------------')
26 | print('')
27 |
28 |
29 | DEFAULT_CONFIG = {
30 | 'MODE': 2, # 1: train, 2: test, 3: eval
31 | 'MODEL': 1, # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
32 | 'MASK': 3, # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
33 | 'EDGE': 1, # 1: canny, 2: external
34 | 'NMS': 1, # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
35 | 'SEED': 10, # random seed
36 | 'GPU': [0], # list of gpu ids
37 | 'DEBUG': 0, # turns on debugging mode
38 | 'VERBOSE': 0, # turns on verbose mode in the output console
39 |
40 | 'LR': 0.0001, # learning rate
41 | 'D2G_LR': 0.1, # discriminator/generator learning rate ratio
42 | 'BETA1': 0.0, # adam optimizer beta1
43 | 'BETA2': 0.9, # adam optimizer beta2
44 | 'BATCH_SIZE': 8, # input batch size for training
45 |     'INPUT_SIZE': 256,              # input image size for training (0 for original size)
46 | 'SIGMA': 2, # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
47 | 'MAX_ITERS': 2e6, # maximum number of iterations to train the model
48 |
49 | 'EDGE_THRESHOLD': 0.5, # edge detection threshold
50 | 'L1_LOSS_WEIGHT': 1, # l1 loss weight
51 | 'FM_LOSS_WEIGHT': 10, # feature-matching loss weight
52 | 'STYLE_LOSS_WEIGHT': 1, # style loss weight
53 | 'CONTENT_LOSS_WEIGHT': 1, # perceptual loss weight
54 | 'INPAINT_ADV_LOSS_WEIGHT': 0.01,# adversarial loss weight
55 |
56 | 'GAN_LOSS': 'nsgan', # nsgan | lsgan | hinge
57 | 'GAN_POOL_SIZE': 0, # fake images pool size
58 |
59 | 'SAVE_INTERVAL': 1000, # how many iterations to wait before saving model (0: never)
60 | 'SAMPLE_INTERVAL': 1000, # how many iterations to wait before sampling (0: never)
61 | 'SAMPLE_SIZE': 12, # number of images to sample
62 | 'EVAL_INTERVAL': 0, # how many iterations to wait before model evaluation (0: never)
63 | 'LOG_INTERVAL': 10, # how many iterations to wait before logging training status (0: never)
64 | }
65 |
--------------------------------------------------------------------------------
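
A minimal usage sketch for Config (the yaml path below is hypothetical). Keys
missing from the loaded file fall back to DEFAULT_CONFIG via __getattr__, and
unknown keys resolve to None:

    from MECNet.config import Config

    config = Config('checkpoints/psv/config.yml')  # hypothetical path
    config.print()                   # dumps the raw yaml
    print(config.LR)                 # from the yaml, else DEFAULT_CONFIG['LR']
    print(config.NO_SUCH_KEY)        # None
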
/MECNet/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import scipy
4 | import torch
5 | import random
6 | import numpy as np
7 | import torchvision.transforms.functional as F
8 | from torch.utils.data import DataLoader
9 | from PIL import Image
10 | from scipy.misc import imread  # removed in SciPy >= 1.2; use imageio.imread on newer installs
11 | from skimage.feature import canny
12 | from skimage.color import rgb2gray, gray2rgb
13 | from .utils import create_mask
14 | import matplotlib.pyplot as plt
15 |
16 | class Dataset(torch.utils.data.Dataset):
17 | def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True):
18 | super(Dataset, self).__init__()
19 | self.augment = augment
20 | self.training = training
21 | self.data = self.load_flist(flist)
22 | self.edge_data = self.load_flist(edge_flist)
23 | self.mask_data = self.load_flist(mask_flist)
24 |
25 | self.input_size = config.INPUT_SIZE
26 | self.sigma = 2
27 | self.edge = config.EDGE
28 | self.mask = config.MASK
29 | self.nms = config.NMS
30 |
31 |         # in test mode, there's a one-to-one relationship between mask and image;
32 |         # masks are loaded non-randomly
33 | if config.MODE == 2:
34 | self.mask = 6
35 |
36 | def __len__(self):
37 | return len(self.data)
38 |
39 | def __getitem__(self, index):
40 | try:
41 | item = self.load_item(index)
42 |         except Exception:  # fall back to the first item on any loading failure
43 |             print('loading error: ' + self.data[index])
44 | item = self.load_item(0)
45 |
46 | return item
47 |
48 | def load_name(self, index):
49 | name = self.data[index]
50 | return os.path.basename(name)
51 |
52 | def load_item(self, index):
53 |
54 | size = self.input_size
55 |
56 | # load image
57 | img = imread(self.data[index])
58 |
59 | # gray to rgb
60 | if len(img.shape) < 3:
61 | img = gray2rgb(img)
62 |
63 | # resize/crop if needed
64 | if size != 0:
65 | img = self.resize(img, size, size)
66 |
67 | # create grayscale image
68 | img_gray = rgb2gray(img)
69 |
70 | # load mask
71 | mask = self.load_mask(img, index)
72 |         # plt.imshow(mask, cmap=plt.cm.gray)  # debug-only: blocks on every item if left enabled
73 |         # plt.show()
74 | # load edge
75 | edge = self.load_edge(img_gray, index, mask)
76 | # print(img.shape,img_gray.shape,mask.shape,edge.shape)
77 | # augment data
78 | if self.augment and np.random.binomial(1, 0.5) > 0:
79 | img = img[:, ::-1, ...]
80 | img_gray = img_gray[:, ::-1, ...]
81 | edge = edge[:, ::-1, ...]
82 | mask = mask[:, ::-1, ...]
83 |
84 | return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask)
85 |
86 | def load_edge(self, img, index, mask):
87 | sigma = self.sigma
88 |
89 |         # in test mode images come with masked regions;
90 |         # passing 'mask' prevents canny from detecting edges inside the masked regions
91 |         mask = None if self.training else (1 - mask / 255).astype(bool)
92 |
93 | # canny
94 | if self.edge == 1:
95 | # no edge
96 | if sigma == -1:
97 |                 return np.zeros(img.shape).astype(float)
98 |
99 | # random sigma
100 | if sigma == 0:
101 | sigma = random.randint(1, 4)
102 |
103 |             return canny(img, sigma=sigma, mask=mask).astype(float)
104 |
105 | # external
106 | else:
107 | imgh, imgw = img.shape[0:2]
108 | edge = imread(self.edge_data[index])
109 | edge = self.resize(edge, imgh, imgw)
110 |
111 | # non-max suppression
112 | if self.nms == 1:
113 | edge = edge * canny(img, sigma=sigma, mask=mask)
114 |
115 | return edge
116 |
117 | def load_mask(self, img, index):
118 | imgh, imgw = img.shape[0:2]
119 | mask_type = self.mask
120 |
121 | # external + random block
122 | if mask_type == 4:
123 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3
124 |
125 | # external + random block + half
126 | elif mask_type == 5:
127 | mask_type = np.random.randint(1, 4)
128 |
129 | # random block
130 | if mask_type == 1:
131 | return create_mask(imgw, imgh, imgw // 2, imgh // 2)
132 |
133 | # half
134 | if mask_type == 2:
135 | # randomly choose right or left
136 | return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
137 |
138 | # external
139 | if mask_type == 3:
140 | mask_index = random.randint(0, len(self.mask_data) - 1)
141 | mask = imread(self.mask_data[mask_index])
142 | mask = self.resize(mask, imgh, imgw)
143 | mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation
144 | return mask
145 |
146 | # test mode: load mask non random
147 | if mask_type == 6:
148 | mask = imread(self.mask_data[index])
149 | mask = self.resize(mask, imgh, imgw, centerCrop=False)
150 | mask = rgb2gray(mask)
151 | mask = (mask > 0).astype(np.uint8) * 255
152 | return mask
153 |
154 | def to_tensor(self, img):
155 | img = Image.fromarray(img)
156 | img_t = F.to_tensor(img).float()
157 | return img_t
158 |
159 | def resize(self, img, height, width, centerCrop=True):
160 | imgh, imgw = img.shape[0:2]
161 |
162 | if centerCrop and imgh != imgw:
163 | # center crop
164 | side = np.minimum(imgh, imgw)
165 | j = (imgh - side) // 2
166 | i = (imgw - side) // 2
167 | img = img[j:j + side, i:i + side, ...]
168 |
169 |         img = scipy.misc.imresize(img, [height, width])  # removed in SciPy >= 1.2; Image.fromarray(img).resize(...) is the modern equivalent
170 |
171 | return img
172 |
173 | def load_flist(self, flist):
174 | if isinstance(flist, list):
175 | return flist
176 |
177 | # flist: image file path, image directory path, text file flist path
178 | if isinstance(flist, str):
179 | if os.path.isdir(flist):
180 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
181 | flist.sort()
182 | return flist
183 |
184 | if os.path.isfile(flist):
185 | try:
186 |                     return np.genfromtxt(flist, dtype=str, encoding='utf-8')
187 | except:
188 | return [flist]
189 |
190 | return []
191 |
192 | def create_iterator(self, batch_size):
193 | while True:
194 | sample_loader = DataLoader(
195 | dataset=self,
196 | batch_size=batch_size,
197 | drop_last=True
198 | )
199 |
200 | for item in sample_loader:
201 | yield item
202 |
--------------------------------------------------------------------------------
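
A sketch of how Dataset is typically driven (all paths and the config are
hypothetical). With EDGE=1 (canny) the edge flist is unused; each item is an
(image, grayscale, edge, mask) tuple of float tensors in [0, 1]:

    from torch.utils.data import DataLoader
    from MECNet.config import Config
    from MECNet.dataset import Dataset

    config = Config('checkpoints/psv/config.yml')      # hypothetical path
    dataset = Dataset(config,
                      'flists/train_images.flist',     # images
                      'flists/train_edges.flist',      # external edges (unused for EDGE=1)
                      'flists/train_masks.flist',      # mask images
                      augment=True, training=True)
    loader = DataLoader(dataset, batch_size=8, shuffle=True, drop_last=True)
    images, images_gray, edges, masks = next(iter(loader))  # N x C x 256 x 256
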
/MECNet/edge_connect.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from .dataset import Dataset
6 | from .models import EdgeModel, InpaintingModel
7 | from .utils import Progbar, create_dir, stitch_images, imsave
8 | from .metrics import PSNR, EdgeAccuracy
9 | import matplotlib.pyplot as plt
10 | import cv2
11 |
12 | class EdgeConnect():
13 | def __init__(self, config):
14 | self.config = config
15 |
16 | if config.MODEL == 1:
17 | model_name = 'edge'
18 | elif config.MODEL == 2:
19 | model_name = 'inpaint'
20 | elif config.MODEL == 3:
21 | model_name = 'edge_inpaint'
22 | elif config.MODEL == 4:
23 | model_name = 'joint'
24 |
25 | self.debug = False
26 | self.model_name = model_name
27 | self.edge_model = EdgeModel(config).to(config.DEVICE)
28 | self.inpaint_model = InpaintingModel(config).to(config.DEVICE)
29 |
30 | self.psnr = PSNR(255.0).to(config.DEVICE)
31 | self.edgeacc = EdgeAccuracy(config.EDGE_THRESHOLD).to(config.DEVICE)
32 |
33 | # test mode
34 | if self.config.MODE == 2:
35 | print(config.TEST_FLIST)
36 | print(config.TEST_EDGE_FLIST)
37 | print(config.TEST_MASK_FLIST)
38 | self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, config.TEST_MASK_FLIST, augment=False, training=False)
39 | else:
40 | self.train_dataset = Dataset(config, config.TRAIN_FLIST, config.TRAIN_EDGE_FLIST, config.TRAIN_MASK_FLIST, augment=True, training=True)
41 | self.val_dataset = Dataset(config, config.VAL_FLIST, config.VAL_EDGE_FLIST, config.VAL_MASK_FLIST, augment=False, training=True)
42 | self.sample_iterator = self.val_dataset.create_iterator(config.SAMPLE_SIZE)
43 |
44 | self.samples_path = os.path.join(config.PATH, 'samples')
45 | self.results_path = os.path.join(config.PATH, 'results')
46 |
47 | if config.RESULTS is not None:
48 | self.results_path = os.path.join(config.RESULTS)
49 |
50 | if config.DEBUG is not None and config.DEBUG != 0:
51 | self.debug = True
52 |
53 | self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat')
54 |
55 | def load(self):
56 | if self.config.MODEL == 1:
57 | self.edge_model.load()
58 |
59 | elif self.config.MODEL == 2:
60 | self.inpaint_model.load()
61 |
62 | else:
63 | self.edge_model.load()
64 | self.inpaint_model.load()
65 |
66 | def save(self):
67 | if self.config.MODEL == 1:
68 | self.edge_model.save()
69 |
70 | elif self.config.MODEL == 2 or self.config.MODEL == 3:
71 | self.inpaint_model.save()
72 |
73 | else:
74 | self.edge_model.save()
75 | self.inpaint_model.save()
76 |
77 | def train(self):
78 | train_loader = DataLoader(
79 | dataset=self.train_dataset,
80 | batch_size=self.config.BATCH_SIZE,
81 | num_workers=4,
82 | drop_last=True,
83 | shuffle=True
84 | )
85 |
86 | epoch = 0
87 | keep_training = True
88 | model = self.config.MODEL
89 | max_iteration = int(float((self.config.MAX_ITERS)))
90 | total = len(self.train_dataset)
91 |
92 | if total == 0:
93 | print('No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.')
94 | return
95 |
96 | while(keep_training):
97 | epoch += 1
98 | print('\n\nTraining epoch: %d' % epoch)
99 |
100 | progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter'])
101 |
102 | for items in train_loader:
103 | self.edge_model.train()
104 | self.inpaint_model.train()
105 |
106 | images, images_gray, edges, masks = self.cuda(*items)
107 |
108 | # edge model
109 | if model == 1:
110 | # train
111 | outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks)
112 |
113 | # metrics
114 | precision, recall = self.edgeacc(edges * masks, outputs * masks)
115 | logs.append(('precision', precision.item()))
116 | logs.append(('recall', recall.item()))
117 |
118 | # backward
119 | self.edge_model.backward(gen_loss, dis_loss)
120 | iteration = self.edge_model.iteration
121 |
122 |
123 | # inpaint model
124 | elif model == 2:
125 | # train
126 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks)
127 | outputs_merged = (outputs * masks) + (images * (1 - masks))
128 |
129 | # metrics
130 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
131 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
132 | logs.append(('psnr', psnr.item()))
133 | logs.append(('mae', mae.item()))
134 |
135 | # backward
136 | self.inpaint_model.backward(gen_loss, dis_loss)
137 | iteration = self.inpaint_model.iteration
138 |
139 |
140 | # inpaint with edge model
141 | elif model == 3:
142 | # train
143 |                 if True or np.random.binomial(1, 0.5) > 0:  # 'True or ...' short-circuits; the edge-model branch is always taken
144 | outputs = self.edge_model(images_gray, edges, masks)
145 | outputs = outputs * masks + edges * (1 - masks)
146 | else:
147 | outputs = edges
148 |
149 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, outputs.detach(), masks)
150 | outputs_merged = (outputs * masks) + (images * (1 - masks))
151 |
152 | # metrics
153 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
154 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
155 | logs.append(('psnr', psnr.item()))
156 | logs.append(('mae', mae.item()))
157 |
158 | # backward
159 | self.inpaint_model.backward(gen_loss, dis_loss)
160 | iteration = self.inpaint_model.iteration
161 |
162 |
163 | # joint model
164 | else:
165 | # train
166 | e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, edges, masks)
167 | e_outputs = e_outputs * masks + edges * (1 - masks)
168 | i_outputs, i_gen_loss, i_dis_loss, i_logs = self.inpaint_model.process(images, e_outputs, masks)
169 | outputs_merged = (i_outputs * masks) + (images * (1 - masks))
170 |
171 | # metrics
172 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
173 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
174 | precision, recall = self.edgeacc(edges * masks, e_outputs * masks)
175 | e_logs.append(('pre', precision.item()))
176 | e_logs.append(('rec', recall.item()))
177 | i_logs.append(('psnr', psnr.item()))
178 | i_logs.append(('mae', mae.item()))
179 | logs = e_logs + i_logs
180 |
181 | # backward
182 | self.inpaint_model.backward(i_gen_loss, i_dis_loss)
183 | self.edge_model.backward(e_gen_loss, e_dis_loss)
184 | iteration = self.inpaint_model.iteration
185 |
186 |
187 | if iteration >= max_iteration:
188 | keep_training = False
189 | break
190 |
191 | logs = [
192 | ("epoch", epoch),
193 | ("iter", iteration),
194 | ] + logs
195 |
196 | progbar.add(len(images), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')])
197 |
198 | # log model at checkpoints
199 | if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0:
200 | self.log(logs)
201 |
202 | # sample model at checkpoints
203 | if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0:
204 | self.sample()
205 |
206 | # evaluate model at checkpoints
207 | if self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0:
208 | print('\nstart eval...\n')
209 | self.eval()
210 |
211 | # save model at checkpoints
212 | if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0:
213 | self.save()
214 |
215 | print('\nEnd training....')
216 |
217 | def eval(self):
218 | val_loader = DataLoader(
219 | dataset=self.val_dataset,
220 | batch_size=self.config.BATCH_SIZE,
221 | drop_last=True,
222 | shuffle=True
223 | )
224 |
225 | model = self.config.MODEL
226 | total = len(self.val_dataset)
227 |
228 | self.edge_model.eval()
229 | self.inpaint_model.eval()
230 |
231 | progbar = Progbar(total, width=20, stateful_metrics=['it'])
232 | iteration = 0
233 |
234 | for items in val_loader:
235 | iteration += 1
236 | images, images_gray, edges, masks = self.cuda(*items)
237 |
238 | # edge model
239 | if model == 1:
240 | # eval
241 | outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks)
242 |
243 | # metrics
244 | precision, recall = self.edgeacc(edges * masks, outputs * masks)
245 | logs.append(('precision', precision.item()))
246 | logs.append(('recall', recall.item()))
247 |
248 |
249 | # inpaint model
250 | elif model == 2:
251 | # eval
252 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks)
253 | outputs_merged = (outputs * masks) + (images * (1 - masks))
254 |
255 | # metrics
256 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
257 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
258 | logs.append(('psnr', psnr.item()))
259 | logs.append(('mae', mae.item()))
260 |
261 |
262 | # inpaint with edge model
263 | elif model == 3:
264 | # eval
265 | outputs = self.edge_model(images_gray, edges, masks)
266 | outputs = outputs * masks + edges * (1 - masks)
267 |
268 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, outputs.detach(), masks)
269 | outputs_merged = (outputs * masks) + (images * (1 - masks))
270 |
271 | # metrics
272 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
273 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
274 | logs.append(('psnr', psnr.item()))
275 | logs.append(('mae', mae.item()))
276 |
277 |
278 | # joint model
279 | else:
280 | # eval
281 | e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, edges, masks)
282 | e_outputs = e_outputs * masks + edges * (1 - masks)
283 | i_outputs, i_gen_loss, i_dis_loss, i_logs = self.inpaint_model.process(images, e_outputs, masks)
284 | outputs_merged = (i_outputs * masks) + (images * (1 - masks))
285 |
286 | # metrics
287 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged))
288 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float()
289 | precision, recall = self.edgeacc(edges * masks, e_outputs * masks)
290 | e_logs.append(('pre', precision.item()))
291 | e_logs.append(('rec', recall.item()))
292 | i_logs.append(('psnr', psnr.item()))
293 | i_logs.append(('mae', mae.item()))
294 | logs = e_logs + i_logs
295 |
296 |
297 | logs = [("it", iteration), ] + logs
298 | progbar.add(len(images), values=logs)
299 |
300 | def test(self):
301 | self.edge_model.eval()
302 | self.inpaint_model.eval()
303 |
304 | model = self.config.MODEL
305 | create_dir(self.results_path)
306 |
307 | test_loader = DataLoader(
308 | dataset=self.test_dataset,
309 | batch_size=1,
310 | )
311 |
312 | index = 0
313 | for items in test_loader:
314 | name = self.test_dataset.load_name(index)
315 | images, images_gray, edges, masks = items
316 |             # tmp = images[0, :, :, :]          # debug-only; feeds the commented plt calls below
317 |             # tmp = np.transpose(tmp, (1, 2, 0))
318 | # plt.imshow(tmp)
319 | # plt.show()
320 | # tmp = images_gray[0, :, :, :]
321 | # tmp = np.transpose(tmp, (1, 2, 0))
322 | # plt.imshow(tmp[:,:,0], cmap=plt.cm.gray)
323 | # plt.show()
324 | # tmp = edges[0, :, :, :]
325 | # tmp = np.transpose(tmp, (1, 2, 0))
326 | # plt.imshow(tmp[:,:,0], cmap=plt.cm.gray)
327 | # plt.show()
328 |
329 | # plt.imshow(edges)
330 | # plt.show()
331 |
332 | images, images_gray, edges, masks = self.cuda(*items)
333 |
334 | index += 1
335 |
336 | # edge model
337 | if model == 1:
338 | outputs = self.edge_model(images_gray, edges, masks)
339 | outputs_merged = (outputs * masks) + (edges * (1 - masks))
340 |
341 | # inpaint model
342 | elif model == 2:
343 | outputs = self.inpaint_model(images, edges, masks)
344 | outputs_merged = (outputs * masks) + (images * (1 - masks))
345 |
346 | # inpaint with edge model / joint model
347 | else:
348 | edges = self.edge_model(images_gray, edges, masks).detach()
349 | outputs = self.inpaint_model(images, edges, masks)
350 | outputs_merged = (outputs * masks) + (images * (1 - masks))
351 |
352 | output = self.postprocess(outputs_merged)[0]
353 | path = os.path.join(self.results_path, name)
354 | print(index, name)
355 |
356 | imsave(output, path)
357 |
358 | if self.debug:
359 | edges = self.postprocess(1 - edges)[0]
360 | masked = self.postprocess(images * (1 - masks) + masks)[0]
361 | fname, fext = name.split('.')
362 |
363 | imsave(edges, os.path.join(self.results_path, fname + '_edge.' + fext))
364 | imsave(masked, os.path.join(self.results_path, fname + '_masked.' + fext))
365 |
366 | print('\nEnd test....')
367 |
368 | def sample(self, it=None):
369 | # do not sample when validation set is empty
370 | if len(self.val_dataset) == 0:
371 | return
372 |
373 | self.edge_model.eval()
374 | self.inpaint_model.eval()
375 |
376 | model = self.config.MODEL
377 | items = next(self.sample_iterator)
378 | images, images_gray, edges, masks = self.cuda(*items)
379 |
380 |
381 | # edge model
382 | if model == 1:
383 | iteration = self.edge_model.iteration
384 | inputs = (images_gray * (1 - masks)) + masks
385 | outputs = self.edge_model(images_gray, edges, masks)
386 | outputs_merged = (outputs * masks) + (edges * (1 - masks))
387 |
388 | # inpaint model
389 | elif model == 2:
390 | iteration = self.inpaint_model.iteration
391 | inputs = (images * (1 - masks)) + masks
392 | outputs = self.inpaint_model(images, edges, masks)
393 | outputs_merged = (outputs * masks) + (images * (1 - masks))
394 |
395 | # inpaint with edge model / joint model
396 | else:
397 | iteration = self.inpaint_model.iteration
398 | inputs = (images * (1 - masks)) + masks
399 | outputs = self.edge_model(images_gray, edges, masks).detach()
400 | edges = (outputs * masks + edges * (1 - masks)).detach()
401 | outputs = self.inpaint_model(images, edges, masks)
402 | outputs_merged = (outputs * masks) + (images * (1 - masks))
403 |
404 | if it is not None:
405 | iteration = it
406 |
407 | image_per_row = 2
408 | if self.config.SAMPLE_SIZE <= 6:
409 | image_per_row = 1
410 |
411 | images = stitch_images(
412 | self.postprocess(images),
413 | self.postprocess(inputs),
414 | self.postprocess(edges),
415 | self.postprocess(outputs),
416 | self.postprocess(outputs_merged),
417 | img_per_row = image_per_row
418 | )
419 |
420 |
421 | path = os.path.join(self.samples_path, self.model_name)
422 | name = os.path.join(path, str(iteration).zfill(5) + ".png")
423 | create_dir(path)
424 | print('\nsaving sample ' + name)
425 | images.save(name)
426 |
427 | def log(self, logs):
428 | with open(self.log_file, 'a') as f:
429 | f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))
430 |
431 | def cuda(self, *args):
432 | return (item.to(self.config.DEVICE) for item in args)
433 |
434 | def postprocess(self, img):
435 | # [0, 1] => [0, 255]
436 | img = img * 255.0
437 | img = img.permute(0, 2, 3, 1)
438 | return img.int()
439 |
--------------------------------------------------------------------------------
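
Throughout EdgeConnect, masks are 1 inside the hole and 0 on known pixels, and
every generator output is composited back onto the ground truth before metrics
are computed. A minimal sketch of that convention:

    import torch

    images = torch.rand(1, 3, 256, 256)        # ground truth in [0, 1]
    outputs = torch.rand(1, 3, 256, 256)        # generator prediction
    masks = torch.zeros(1, 1, 256, 256)
    masks[..., 64:192, 64:192] = 1              # 1 = missing region

    outputs_merged = outputs * masks + images * (1 - masks)
    # known pixels are untouched; only the hole comes from the generator
    assert torch.equal(outputs_merged * (1 - masks), images * (1 - masks))
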
/MECNet/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 |
5 |
6 | class AdversarialLoss(nn.Module):
7 | r"""
8 | Adversarial loss
9 | https://arxiv.org/abs/1711.10337
10 | """
11 |
12 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
13 | r"""
14 | type = nsgan | lsgan | hinge
15 | """
16 | super(AdversarialLoss, self).__init__()
17 |
18 | self.type = type
19 | self.register_buffer('real_label', torch.tensor(target_real_label))
20 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
21 |
22 | if type == 'nsgan':
23 | self.criterion = nn.BCELoss()
24 |
25 | elif type == 'lsgan':
26 | self.criterion = nn.MSELoss()
27 |
28 | elif type == 'hinge':
29 | self.criterion = nn.ReLU()
30 |
31 | def __call__(self, outputs, is_real, is_disc=None):
32 | if self.type == 'hinge':
33 | if is_disc:
34 | if is_real:
35 | outputs = -outputs
36 | return self.criterion(1 + outputs).mean()
37 | else:
38 | return (-outputs).mean()
39 |
40 | else:
41 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
42 | loss = self.criterion(outputs, labels)
43 | return loss
44 |
45 |
46 | class StyleLoss(nn.Module):
47 | r"""
48 |     Style loss, VGG-based (Gram-matrix matching)
49 | https://arxiv.org/abs/1603.08155
50 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
51 | """
52 |
53 | def __init__(self):
54 | super(StyleLoss, self).__init__()
55 | self.add_module('vgg', VGG19())
56 | self.criterion = torch.nn.L1Loss()
57 |
58 | def compute_gram(self, x):
59 | b, ch, h, w = x.size()
60 | f = x.view(b, ch, w * h)
61 | f_T = f.transpose(1, 2)
62 | G = f.bmm(f_T) / (h * w * ch)
63 |
64 | return G
65 |
66 | def __call__(self, x, y):
67 | # Compute features
68 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
69 |
70 | # Compute loss
71 | style_loss = 0.0
72 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
73 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
74 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
75 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
76 |
77 | return style_loss
78 |
79 |
80 |
81 | class PerceptualLoss(nn.Module):
82 | r"""
83 | Perceptual loss, VGG-based
84 | https://arxiv.org/abs/1603.08155
85 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
86 | """
87 |
88 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
89 | super(PerceptualLoss, self).__init__()
90 | self.add_module('vgg', VGG19())
91 | self.criterion = torch.nn.L1Loss()
92 | self.weights = weights
93 |
94 | def __call__(self, x, y):
95 | # Compute features
96 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
97 |
98 | content_loss = 0.0
99 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
100 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
101 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
102 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
103 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
104 |
105 |
106 | return content_loss
107 |
108 |
109 |
110 | class VGG19(torch.nn.Module):
111 | def __init__(self):
112 | super(VGG19, self).__init__()
113 | features = models.vgg19(pretrained=True).features
114 | self.relu1_1 = torch.nn.Sequential()
115 | self.relu1_2 = torch.nn.Sequential()
116 |
117 | self.relu2_1 = torch.nn.Sequential()
118 | self.relu2_2 = torch.nn.Sequential()
119 |
120 | self.relu3_1 = torch.nn.Sequential()
121 | self.relu3_2 = torch.nn.Sequential()
122 | self.relu3_3 = torch.nn.Sequential()
123 | self.relu3_4 = torch.nn.Sequential()
124 |
125 | self.relu4_1 = torch.nn.Sequential()
126 | self.relu4_2 = torch.nn.Sequential()
127 | self.relu4_3 = torch.nn.Sequential()
128 | self.relu4_4 = torch.nn.Sequential()
129 |
130 | self.relu5_1 = torch.nn.Sequential()
131 | self.relu5_2 = torch.nn.Sequential()
132 | self.relu5_3 = torch.nn.Sequential()
133 | self.relu5_4 = torch.nn.Sequential()
134 |
135 | for x in range(2):
136 | self.relu1_1.add_module(str(x), features[x])
137 |
138 | for x in range(2, 4):
139 | self.relu1_2.add_module(str(x), features[x])
140 |
141 | for x in range(4, 7):
142 | self.relu2_1.add_module(str(x), features[x])
143 |
144 | for x in range(7, 9):
145 | self.relu2_2.add_module(str(x), features[x])
146 |
147 | for x in range(9, 12):
148 | self.relu3_1.add_module(str(x), features[x])
149 |
150 | for x in range(12, 14):
151 | self.relu3_2.add_module(str(x), features[x])
152 |
153 | for x in range(14, 16):
154 | self.relu3_3.add_module(str(x), features[x])
155 |
156 | for x in range(16, 18):
157 | self.relu3_4.add_module(str(x), features[x])
158 |
159 | for x in range(18, 21):
160 | self.relu4_1.add_module(str(x), features[x])
161 |
162 | for x in range(21, 23):
163 | self.relu4_2.add_module(str(x), features[x])
164 |
165 | for x in range(23, 25):
166 | self.relu4_3.add_module(str(x), features[x])
167 |
168 | for x in range(25, 27):
169 | self.relu4_4.add_module(str(x), features[x])
170 |
171 | for x in range(27, 30):
172 | self.relu5_1.add_module(str(x), features[x])
173 |
174 | for x in range(30, 32):
175 | self.relu5_2.add_module(str(x), features[x])
176 |
177 | for x in range(32, 34):
178 | self.relu5_3.add_module(str(x), features[x])
179 |
180 | for x in range(34, 36):
181 | self.relu5_4.add_module(str(x), features[x])
182 |
183 | # don't need the gradients, just want the features
184 | for param in self.parameters():
185 | param.requires_grad = False
186 |
187 | def forward(self, x):
188 | relu1_1 = self.relu1_1(x)
189 | relu1_2 = self.relu1_2(relu1_1)
190 |
191 | relu2_1 = self.relu2_1(relu1_2)
192 | relu2_2 = self.relu2_2(relu2_1)
193 |
194 | relu3_1 = self.relu3_1(relu2_2)
195 | relu3_2 = self.relu3_2(relu3_1)
196 | relu3_3 = self.relu3_3(relu3_2)
197 | relu3_4 = self.relu3_4(relu3_3)
198 |
199 | relu4_1 = self.relu4_1(relu3_4)
200 | relu4_2 = self.relu4_2(relu4_1)
201 | relu4_3 = self.relu4_3(relu4_2)
202 | relu4_4 = self.relu4_4(relu4_3)
203 |
204 | relu5_1 = self.relu5_1(relu4_4)
205 | relu5_2 = self.relu5_2(relu5_1)
206 | relu5_3 = self.relu5_3(relu5_2)
207 | relu5_4 = self.relu5_4(relu5_3)
208 |
209 | out = {
210 | 'relu1_1': relu1_1,
211 | 'relu1_2': relu1_2,
212 |
213 | 'relu2_1': relu2_1,
214 | 'relu2_2': relu2_2,
215 |
216 | 'relu3_1': relu3_1,
217 | 'relu3_2': relu3_2,
218 | 'relu3_3': relu3_3,
219 | 'relu3_4': relu3_4,
220 |
221 | 'relu4_1': relu4_1,
222 | 'relu4_2': relu4_2,
223 | 'relu4_3': relu4_3,
224 | 'relu4_4': relu4_4,
225 |
226 | 'relu5_1': relu5_1,
227 | 'relu5_2': relu5_2,
228 | 'relu5_3': relu5_3,
229 | 'relu5_4': relu5_4,
230 | }
231 | return out
232 |
--------------------------------------------------------------------------------
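
The style term compares Gram matrices of VGG19 features: for a feature map of
shape (b, ch, h, w), compute_gram flattens it to (b, ch, h*w) and returns
f f^T / (ch * h * w). A self-contained shape check mirroring that method:

    import torch

    def compute_gram(x):                  # mirrors StyleLoss.compute_gram
        b, ch, h, w = x.size()
        f = x.view(b, ch, w * h)
        return f.bmm(f.transpose(1, 2)) / (h * w * ch)

    feat = torch.rand(2, 64, 32, 32)
    print(compute_gram(feat).shape)       # torch.Size([2, 64, 64])
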
/MECNet/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class EdgeAccuracy(nn.Module):
6 | """
7 | Measures the accuracy of the edge map
8 | """
9 | def __init__(self, threshold=0.5):
10 | super(EdgeAccuracy, self).__init__()
11 | self.threshold = threshold
12 |
13 | def __call__(self, inputs, outputs):
14 | labels = (inputs > self.threshold)
15 | outputs = (outputs > self.threshold)
16 |
17 | relevant = torch.sum(labels.float())
18 | selected = torch.sum(outputs.float())
19 |
20 | if relevant == 0 and selected == 0:
21 | return torch.tensor(1), torch.tensor(1)
22 |
23 | true_positive = ((outputs == labels) * labels).float()
24 | recall = torch.sum(true_positive) / (relevant + 1e-8)
25 | precision = torch.sum(true_positive) / (selected + 1e-8)
26 |
27 | return precision, recall
28 |
29 |
30 | class PSNR(nn.Module):
31 | def __init__(self, max_val):
32 | super(PSNR, self).__init__()
33 |
34 | base10 = torch.log(torch.tensor(10.0))
35 | max_val = torch.tensor(max_val).float()
36 |
37 | self.register_buffer('base10', base10)
38 | self.register_buffer('max_val', 20 * torch.log(max_val) / base10)
39 |
40 | def __call__(self, a, b):
41 | mse = torch.mean((a.float() - b.float()) ** 2)
42 |
43 | if mse == 0:
44 | return torch.tensor(0)
45 |
46 | return self.max_val - 10 * torch.log(mse) / self.base10
47 |
--------------------------------------------------------------------------------
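
PSNR registers 20*log10(max_val) as a buffer and returns
20*log10(max_val) - 10*log10(MSE). A quick numeric check:

    import torch
    from MECNet.metrics import PSNR

    psnr = PSNR(255.0)
    a = torch.full((1, 3, 8, 8), 100.0)
    b = torch.full((1, 3, 8, 8), 110.0)   # MSE = 100
    print(psnr(a, b))                     # 20*log10(255) - 10*log10(100) ~= 28.13 dB
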
/MECNet/models.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from .networks2 import InpaintGenerator, EdgeGenerator, Discriminator
6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss
7 |
8 |
9 | class BaseModel(nn.Module):
10 | def __init__(self , name, config):
11 | super(BaseModel, self).__init__()
12 |
13 | self.name = name
14 | self.config = config
15 | self.iteration = 0
16 | self.device = 0
17 | self.GPU = [0]
18 | # self.gen_weights_path ="/home/wds/First_Project/LBAM_GRU_version2/checkpoints/psv/EdgeModel_gen.pth"
19 |         self.gen_weights_path = "/home/wds/First_Project/edge-connect_psp/checkpoints/psv/EdgeModel_gen_400.pth"  # hard-coded absolute path; bypasses config.PATH
20 | # self.gen_weights_path = "/home/wds/First_Project/edge-connect_psp/checkpoints/psv/EdgeModel_gen_400.pth"
21 | # self.gen_weights_path = '/home/wds/LBAM_version7/EdgeModel_gen.pth'
22 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth')
23 |
24 | def load(self):
25 | print(os.path.exists(self.gen_weights_path))
26 | if os.path.exists(self.gen_weights_path):
27 | print('Loading %s generator...' % self.name)
28 |
29 | if torch.cuda.is_available():
30 | data = torch.load(self.gen_weights_path)
31 | else:
32 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage)
33 | self.generator.load_state_dict(data['generator'])
34 | self.iteration = data['iteration']
35 | # load discriminator only when training
36 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path):
37 | print('Loading %s discriminator...' % self.name)
38 |
39 | if torch.cuda.is_available():
40 | data = torch.load(self.dis_weights_path)
41 | else:
42 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage)
43 |
44 | self.discriminator.load_state_dict(data['discriminator'])
45 |
46 | def save(self):
47 | print('\nsaving %s...\n' % self.name)
48 | torch.save({
49 | 'iteration': self.iteration,
50 | 'generator': self.generator.state_dict()
51 | }, self.gen_weights_path)
52 |
53 | torch.save({
54 | 'discriminator': self.discriminator.state_dict()
55 | }, self.dis_weights_path)
56 |
57 |
58 | class EdgeModel(BaseModel):
59 | def __init__(self, config):
60 | super(EdgeModel, self).__init__('EdgeModel', config)
61 |
62 | # generator input: [grayscale(1) + edge(1) + mask(1)]
63 | # discriminator input: (grayscale(1) + edge(1))
64 | generator = EdgeGenerator(use_spectral_norm=True)
65 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge')
66 | self.device = config.DEVICE
67 | self.GPU = config.GPU
68 | # if len(config.GPU) > 1:
69 | # generator = nn.DataParallel(generator, config.GPU)
70 | # discriminator = nn.DataParallel(discriminator, config.GPU)
71 | l1_loss = nn.L1Loss()
72 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
73 |
74 | self.add_module('generator', generator)
75 | self.add_module('discriminator', discriminator)
76 |
77 | self.add_module('l1_loss', l1_loss)
78 | self.add_module('adversarial_loss', adversarial_loss)
79 |
80 | self.gen_optimizer = optim.Adam(
81 | params=generator.parameters(),
82 | lr=float(config.LR),
83 | betas=(config.BETA1, config.BETA2)
84 | )
85 |
86 | self.dis_optimizer = optim.Adam(
87 | params=discriminator.parameters(),
88 | lr=float(config.LR) * float(config.D2G_LR),
89 | betas=(config.BETA1, config.BETA2)
90 | )
91 |
92 | def process(self, images, edges, masks):
93 | self.iteration += 1
94 |
95 |
96 | # zero optimizers
97 | self.gen_optimizer.zero_grad()
98 | self.dis_optimizer.zero_grad()
99 |
100 |
101 | # process outputs
102 | outputs = self(images, edges, masks)
103 | gen_loss = 0
104 | dis_loss = 0
105 |
106 |
107 | # discriminator loss
108 | dis_input_real = torch.cat((images, edges), dim=1)
109 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1)
110 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1))
111 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1))
112 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
113 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
114 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
115 |
116 |
117 | # generator adversarial loss
118 | gen_input_fake = torch.cat((images, outputs), dim=1)
119 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1))
120 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
121 | gen_loss += gen_gan_loss
122 |
123 |
124 | # generator feature matching loss
125 | gen_fm_loss = 0
126 | for i in range(len(dis_real_feat)):
127 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach())
128 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT
129 | gen_loss += gen_fm_loss
130 |
131 |
132 | # create logs
133 | logs = [
134 | ("l_d1", dis_loss.item()),
135 | ("l_g1", gen_gan_loss.item()),
136 | ("l_fm", gen_fm_loss.item()),
137 | ]
138 |
139 | return outputs, gen_loss, dis_loss, logs
140 |
141 | def forward(self, images, edges, masks):
142 | edges_masked = (edges * (1 - masks))
143 | images_masked = (images * (1 - masks)) + masks
144 | inputs = torch.cat((images_masked, edges_masked, masks), dim=1)
145 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)]
146 | return outputs
147 |
148 | def backward(self, gen_loss=None, dis_loss=None):
149 | if dis_loss is not None:
150 | dis_loss.backward()
151 | self.dis_optimizer.step()
152 |
153 | if gen_loss is not None:
154 | gen_loss.backward()
155 | self.gen_optimizer.step()
156 |
157 |
158 | class InpaintingModel(BaseModel):
159 | def __init__(self, config):
160 | super(InpaintingModel, self).__init__('InpaintingModel', config)
161 |
162 | # generator input: [rgb(3) + edge(1)]
163 | # discriminator input: [rgb(3)]
164 | generator = InpaintGenerator()
165 | discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge')
166 | if len(config.GPU) > 1:
167 | generator = nn.DataParallel(generator, config.GPU)
168 | discriminator = nn.DataParallel(discriminator , config.GPU)
169 |
170 | l1_loss = nn.L1Loss()
171 | perceptual_loss = PerceptualLoss()
172 | style_loss = StyleLoss()
173 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS)
174 |
175 | self.add_module('generator', generator)
176 | self.add_module('discriminator', discriminator)
177 |
178 | self.add_module('l1_loss', l1_loss)
179 | self.add_module('perceptual_loss', perceptual_loss)
180 | self.add_module('style_loss', style_loss)
181 | self.add_module('adversarial_loss', adversarial_loss)
182 |
183 | self.gen_optimizer = optim.Adam(
184 | params=generator.parameters(),
185 | lr=float(config.LR),
186 | betas=(config.BETA1, config.BETA2)
187 | )
188 |
189 | self.dis_optimizer = optim.Adam(
190 | params=discriminator.parameters(),
191 | lr=float(config.LR) * float(config.D2G_LR),
192 | betas=(config.BETA1, config.BETA2)
193 | )
194 |
195 | def process(self, images, edges, masks):
196 | self.iteration += 1
197 |
198 | # zero optimizers
199 | self.gen_optimizer.zero_grad()
200 | self.dis_optimizer.zero_grad()
201 |
202 |
203 | # process outputs
204 | outputs = self(images, edges, masks)
205 | gen_loss = 0
206 | dis_loss = 0
207 |
208 |
209 | # discriminator loss
210 | dis_input_real = images
211 | dis_input_fake = outputs.detach()
212 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)]
213 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)]
214 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
215 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
216 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
217 |
218 |
219 | # generator adversarial loss
220 | gen_input_fake = outputs
221 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)]
222 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT
223 | gen_loss += gen_gan_loss
224 |
225 |
226 | # generator l1 loss
227 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks)
228 | gen_loss += gen_l1_loss
229 |
230 |
231 | # generator perceptual loss
232 | gen_content_loss = self.perceptual_loss(outputs, images)
233 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT
234 | gen_loss += gen_content_loss
235 |
236 |
237 | # generator style loss
238 | gen_style_loss = self.style_loss(outputs * masks, images * masks)
239 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT
240 | gen_loss += gen_style_loss
241 |
242 |
243 | # create logs
244 | logs = [
245 | ("l_d2", dis_loss.item()),
246 | ("l_g2", gen_gan_loss.item()),
247 | ("l_l1", gen_l1_loss.item()),
248 | ("l_per", gen_content_loss.item()),
249 | ("l_sty", gen_style_loss.item()),
250 | ]
251 |
252 | return outputs, gen_loss, dis_loss, logs
253 |
254 | def forward(self, images, edges, masks):
255 | images_masked = (images * (1 - masks).float()) + masks
256 | inputs = torch.cat((images_masked, edges), dim=1)
257 | outputs = self.generator(inputs) # in: [rgb(3) + edge(1)]
258 | return outputs
259 |
260 |     def backward(self, gen_loss=None, dis_loss=None):
261 |         # guard against None losses, matching EdgeModel.backward
262 |         if dis_loss is not None:
263 |             dis_loss.backward()
264 |             self.dis_optimizer.step()
265 | 
266 |         if gen_loss is not None:
267 |             gen_loss.backward()
268 |             self.gen_optimizer.step()
269 | 
--------------------------------------------------------------------------------
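
Each model pairs a process step (one forward pass plus weighted loss assembly)
with a backward step that updates the discriminator before the generator; the
discriminator loss is computed on outputs.detach(), so its update never reaches
the generator's graph. A sketch of one inpainting iteration (the config path
and tensors are hypothetical; instantiating the model downloads VGG19 weights):

    import torch
    from MECNet.config import Config
    from MECNet.models import InpaintingModel

    config = Config('checkpoints/psv/config.yml')        # hypothetical path
    model = InpaintingModel(config)

    images = torch.rand(8, 3, 256, 256)                  # ground truth
    edges = torch.rand(8, 1, 256, 256)                   # edge maps
    masks = (torch.rand(8, 1, 256, 256) > 0.5).float()   # 1 = hole
    outputs, gen_loss, dis_loss, logs = model.process(images, edges, masks)
    model.backward(gen_loss, dis_loss)                   # D step, then G step
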
/MECNet/networks2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BaseNetwork(nn.Module):
6 | def __init__(self):
7 | super(BaseNetwork, self).__init__()
8 |
9 | def init_weights(self, init_type='normal', gain=0.02):
10 | '''
11 | initialize network's weights
12 | init_type: normal | xavier | kaiming | orthogonal
13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
14 | '''
15 |
16 | def init_func(m):
17 | classname = m.__class__.__name__
18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
19 | if init_type == 'normal':
20 | nn.init.normal_(m.weight.data, 0.0, gain)
21 | elif init_type == 'xavier':
22 | nn.init.xavier_normal_(m.weight.data, gain=gain)
23 | elif init_type == 'kaiming':
24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
25 | elif init_type == 'orthogonal':
26 | nn.init.orthogonal_(m.weight.data, gain=gain)
27 |
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | nn.init.constant_(m.bias.data, 0.0)
30 |
31 | elif classname.find('BatchNorm2d') != -1:
32 | nn.init.normal_(m.weight.data, 1.0, gain)
33 | nn.init.constant_(m.bias.data, 0.0)
34 |
35 | self.apply(init_func)
36 |
37 |
38 | class InpaintGenerator(BaseNetwork):
39 | def __init__(self, residual_blocks=8, init_weights=True):
40 | super(InpaintGenerator, self).__init__()
41 |
42 | self.encoder = nn.Sequential(
43 | nn.ReflectionPad2d(3),
44 | nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0),
45 | nn.InstanceNorm2d(64, track_running_stats=False),
46 | nn.ReLU(True),
47 |
48 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
49 | nn.InstanceNorm2d(128, track_running_stats=False),
50 | nn.ReLU(True),
51 |
52 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
53 | nn.InstanceNorm2d(256, track_running_stats=False),
54 | nn.ReLU(True)
55 | )
56 |
57 | blocks = []
58 | for _ in range(residual_blocks):
59 | block = ResnetBlock(256, 2)
60 | blocks.append(block)
61 |
62 | self.middle = nn.Sequential(*blocks)
63 |
64 | self.decoder = nn.Sequential(
65 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
66 | nn.InstanceNorm2d(128, track_running_stats=False),
67 | nn.ReLU(True),
68 |
69 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
70 | nn.InstanceNorm2d(64, track_running_stats=False),
71 | nn.ReLU(True),
72 |
73 | nn.ReflectionPad2d(3),
74 | nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0),
75 | )
76 |
77 | if init_weights:
78 | self.init_weights()
79 |
80 | def forward(self, x):
81 | x = self.encoder(x)
82 | x = self.middle(x)
83 | x = self.decoder(x)
84 | x = (torch.tanh(x) + 1) / 2
85 |
86 | return x
87 |
88 |
89 | class EdgeGenerator(BaseNetwork):
90 | def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
91 | super(EdgeGenerator, self).__init__()
92 |
93 | self.encoder = nn.Sequential(
94 | nn.ReflectionPad2d(3),
95 | spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
96 | nn.InstanceNorm2d(64, track_running_stats=False),
97 | nn.ReLU(True),
98 |
99 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
100 | nn.InstanceNorm2d(128, track_running_stats=False),
101 | nn.ReLU(True),
102 |
103 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
104 | nn.InstanceNorm2d(256, track_running_stats=False),
105 | nn.ReLU(True)
106 | )
107 | self.pool1 = nn.AdaptiveAvgPool2d(8)
108 | self.pool2 = nn.AdaptiveAvgPool2d(16)
109 | self.pool3 = nn.AdaptiveAvgPool2d(32)
110 | self.upsample1 = nn.Sequential(
111 | nn.ConvTranspose2d(in_channels=256, out_channels=256,kernel_size=4,stride=2,padding=1),
112 |             nn.InstanceNorm2d(256, track_running_stats=False),  # was 128; the preceding ConvTranspose2d outputs 256 channels
113 | nn.ReLU(True)
114 | )
115 | self.upsample2 = nn.Sequential(
116 | nn.ConvTranspose2d(in_channels=512,out_channels=256,kernel_size=4,stride=2,padding=1),
117 | nn.InstanceNorm2d(256,track_running_stats=False),
118 | nn.ReLU(True)
119 | )
120 | self.upsample3 = nn.Sequential(
121 | nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
122 | nn.InstanceNorm2d(256, track_running_stats=False),
123 | nn.ReLU(True)
124 | )
125 |
126 |
127 | # self.conv2
128 | blocks = []
129 | for _ in range(residual_blocks):
130 | block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
131 | blocks.append(block)
132 |
133 | self.middle = nn.Sequential(*blocks)
134 |
135 | self.decoder = nn.Sequential(
136 | spectral_norm(nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
137 | nn.InstanceNorm2d(128, track_running_stats=False),
138 | nn.ReLU(True),
139 |
140 | spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
141 | nn.InstanceNorm2d(64, track_running_stats=False),
142 | nn.ReLU(True),
143 |
144 | nn.ReflectionPad2d(3),
145 | nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
146 | )
147 |
148 | if init_weights:
149 | self.init_weights()
150 |
151 | def forward(self, x):
152 |         # debug-only shape traces disabled; run the encoder in one call
153 |         x = self.encoder(x)
154 |         # for i in range(len(self.encoder)):
155 |         #     x = self.encoder[i](x)
156 |         #     print(i, 'x.shape:\n\t', x.shape)
157 | 
158 |         x_8 = self.pool1(x)
159 |         x_16 = self.pool2(x)
160 |         x_32 = self.pool3(x)
161 |         # print(x_8.shape)
162 | x = self.middle(x)
163 | x_8 = self.middle(x_8)
164 | x_16 = self.middle(x_16)
165 | x_32 = self.middle(x_32)
166 | x_8 = self.upsample1(x_8)
167 | x_16 = torch.cat((x_16,x_8),dim=1)
168 | x_16 = self.upsample2(x_16)
169 | x_32 = torch.cat((x_32,x_16),dim=1)
170 | x_32 = self.upsample3(x_32)
171 | x = torch.cat((x,x_32),dim=1)
172 | x = self.decoder(x)
173 | x = torch.sigmoid(x)
174 | return x
175 |
176 |
177 | class Discriminator(BaseNetwork):
178 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
179 | super(Discriminator, self).__init__()
180 | self.use_sigmoid = use_sigmoid
181 |
182 | self.conv1 = self.features = nn.Sequential(
183 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
184 | nn.LeakyReLU(0.2, inplace=True),
185 | )
186 |
187 | self.conv2 = nn.Sequential(
188 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
189 | nn.LeakyReLU(0.2, inplace=True),
190 | )
191 |
192 | self.conv3 = nn.Sequential(
193 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
194 | nn.LeakyReLU(0.2, inplace=True),
195 | )
196 |
197 | self.conv4 = nn.Sequential(
198 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
199 | nn.LeakyReLU(0.2, inplace=True),
200 | )
201 |
202 | self.conv5 = nn.Sequential(
203 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
204 | )
205 |
206 | if init_weights:
207 | self.init_weights()
208 |
209 | def forward(self, x):
210 | conv1 = self.conv1(x)
211 | conv2 = self.conv2(conv1)
212 | conv3 = self.conv3(conv2)
213 | conv4 = self.conv4(conv3)
214 | conv5 = self.conv5(conv4)
215 |
216 | outputs = conv5
217 | if self.use_sigmoid:
218 | outputs = torch.sigmoid(conv5)
219 |
220 | return outputs, [conv1, conv2, conv3, conv4, conv5]
221 |
222 |
223 | class ResnetBlock(nn.Module):
224 | def __init__(self, dim, dilation=1, use_spectral_norm=False):
225 | super(ResnetBlock, self).__init__()
226 | self.conv_block = nn.Sequential(
227 | nn.ReflectionPad2d(dilation),
228 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
229 | nn.InstanceNorm2d(dim, track_running_stats=False),
230 | nn.ReLU(True),
231 |
232 | nn.ReflectionPad2d(1),
233 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
234 | nn.InstanceNorm2d(dim, track_running_stats=False),
235 | )
236 |
237 | def forward(self, x):
238 | out = x + self.conv_block(x)
239 |
240 | # Remove ReLU at the end of the residual block
241 | # http://torch.ch/blog/2016/02/04/resnets.html
242 |
243 | return out
244 |
245 |
246 | def spectral_norm(module, mode=True):
247 | if mode:
248 | return nn.utils.spectral_norm(module)
249 |
250 | return module
251 |
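252 | 
253 | # A minimal smoke test, not part of the original file: it sketches how the
254 | # multi-scale EdgeGenerator and the PatchGAN-style Discriminator above fit
255 | # together, assuming edge-connect style constructor defaults (an assumption;
256 | # adjust the arguments if the signatures differ). A 256x256 input reaches the
257 | # shared residual blocks at 64x64 and is pooled to the 8/16/32 branches.
258 | if __name__ == '__main__':
259 |     netG = EdgeGenerator()                      # assumed defaults: residual_blocks=8, use_spectral_norm=True
260 |     netD = Discriminator(in_channels=2)         # grayscale image + predicted edge
261 |     x = torch.randn(2, 3, 256, 256)             # e.g. gray image, edge, mask channels
262 |     edge = netG(x)                              # (2, 1, 256, 256), sigmoid output
263 |     score, feats = netD(torch.cat((x[:, :1], edge), dim=1))
264 |     print(edge.shape, score.shape, len(feats))  # feats holds conv1..conv5 activations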
--------------------------------------------------------------------------------
/MECNet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 |
9 |
10 | def create_dir(dir):
11 | if not os.path.exists(dir):
12 | os.makedirs(dir)
13 |
14 |
15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
16 | mask = np.zeros((height, width))
17 | mask_x = x if x is not None else random.randint(0, width - mask_width)
18 | mask_y = y if y is not None else random.randint(0, height - mask_height)
19 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
20 | return mask
21 |
22 |
23 | def stitch_images(inputs, *outputs, img_per_row=2):
24 | gap = 5
25 | columns = len(outputs) + 1
26 |
27 | width, height = inputs[0][:, :, 0].shape
28 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row)))
29 | images = [inputs, *outputs]
30 |
31 | for ix in range(len(inputs)):
32 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap
33 | yoffset = int(ix / img_per_row) * height
34 |
35 | for cat in range(len(images)):
36 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
37 | im = Image.fromarray(im)
38 | img.paste(im, (xoffset + cat * width, yoffset))
39 |
40 | return img
41 |
42 |
43 | def imshow(img, title=''):
44 | fig = plt.gcf()
45 | fig.canvas.set_window_title(title)
46 | plt.axis('off')
47 | plt.imshow(img, interpolation='none')
48 | plt.show()
49 |
50 |
51 | def imsave(img, path):
52 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze())
53 | im.save(path)
54 |
55 |
56 | class Progbar(object):
57 | """Displays a progress bar.
58 |
59 | Arguments:
60 | target: Total number of steps expected, None if unknown.
61 | width: Progress bar width on screen.
62 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
63 | stateful_metrics: Iterable of string names of metrics that
64 | should *not* be averaged over time. Metrics in this list
65 | will be displayed as-is. All others will be averaged
66 | by the progbar before display.
67 | interval: Minimum visual progress update interval (in seconds).
68 | """
69 |
70 | def __init__(self, target, width=25, verbose=1, interval=0.05,
71 | stateful_metrics=None):
72 | self.target = target
73 | self.width = width
74 | self.verbose = verbose
75 | self.interval = interval
76 | if stateful_metrics:
77 | self.stateful_metrics = set(stateful_metrics)
78 | else:
79 | self.stateful_metrics = set()
80 |
81 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
82 | sys.stdout.isatty()) or
83 | 'ipykernel' in sys.modules or
84 | 'posix' in sys.modules)
85 | self._total_width = 0
86 | self._seen_so_far = 0
87 | # We use a dict + list to avoid garbage collection
88 | # issues found in OrderedDict
89 | self._values = {}
90 | self._values_order = []
91 | self._start = time.time()
92 | self._last_update = 0
93 |
94 | def update(self, current, values=None):
95 | """Updates the progress bar.
96 |
97 | Arguments:
98 | current: Index of current step.
99 | values: List of tuples:
100 | `(name, value_for_last_step)`.
101 | If `name` is in `stateful_metrics`,
102 | `value_for_last_step` will be displayed as-is.
103 | Else, an average of the metric over time will be displayed.
104 | """
105 | values = values or []
106 | for k, v in values:
107 | if k not in self._values_order:
108 | self._values_order.append(k)
109 | if k not in self.stateful_metrics:
110 | if k not in self._values:
111 | self._values[k] = [v * (current - self._seen_so_far),
112 | current - self._seen_so_far]
113 | else:
114 | self._values[k][0] += v * (current - self._seen_so_far)
115 | self._values[k][1] += (current - self._seen_so_far)
116 | else:
117 | self._values[k] = v
118 | self._seen_so_far = current
119 |
120 | now = time.time()
121 | info = ' - %.0fs' % (now - self._start)
122 | if self.verbose == 1:
123 | if (now - self._last_update < self.interval and
124 | self.target is not None and current < self.target):
125 | return
126 |
127 | prev_total_width = self._total_width
128 | if self._dynamic_display:
129 | sys.stdout.write('\b' * prev_total_width)
130 | sys.stdout.write('\r')
131 | else:
132 | sys.stdout.write('\n')
133 |
134 | if self.target is not None:
135 | numdigits = int(np.floor(np.log10(self.target))) + 1
136 | barstr = '%%%dd/%d [' % (numdigits, self.target)
137 | bar = barstr % current
138 | prog = float(current) / self.target
139 | prog_width = int(self.width * prog)
140 | if prog_width > 0:
141 | bar += ('=' * (prog_width - 1))
142 | if current < self.target:
143 | bar += '>'
144 | else:
145 | bar += '='
146 | bar += ('.' * (self.width - prog_width))
147 | bar += ']'
148 | else:
149 | bar = '%7d/Unknown' % current
150 |
151 | self._total_width = len(bar)
152 | sys.stdout.write(bar)
153 |
154 | if current:
155 | time_per_unit = (now - self._start) / current
156 | else:
157 | time_per_unit = 0
158 | if self.target is not None and current < self.target:
159 | eta = time_per_unit * (self.target - current)
160 | if eta > 3600:
161 | eta_format = '%d:%02d:%02d' % (eta // 3600,
162 | (eta % 3600) // 60,
163 | eta % 60)
164 | elif eta > 60:
165 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
166 | else:
167 | eta_format = '%ds' % eta
168 |
169 | info = ' - ETA: %s' % eta_format
170 | else:
171 | if time_per_unit >= 1:
172 | info += ' %.0fs/step' % time_per_unit
173 | elif time_per_unit >= 1e-3:
174 | info += ' %.0fms/step' % (time_per_unit * 1e3)
175 | else:
176 | info += ' %.0fus/step' % (time_per_unit * 1e6)
177 |
178 | for k in self._values_order:
179 | info += ' - %s:' % k
180 | if isinstance(self._values[k], list):
181 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
182 | if abs(avg) > 1e-3:
183 | info += ' %.4f' % avg
184 | else:
185 | info += ' %.4e' % avg
186 | else:
187 | info += ' %s' % self._values[k]
188 |
189 | self._total_width += len(info)
190 | if prev_total_width > self._total_width:
191 | info += (' ' * (prev_total_width - self._total_width))
192 |
193 | if self.target is not None and current >= self.target:
194 | info += '\n'
195 |
196 | sys.stdout.write(info)
197 | sys.stdout.flush()
198 |
199 | elif self.verbose == 2:
200 | if self.target is None or current >= self.target:
201 | for k in self._values_order:
202 | info += ' - %s:' % k
203 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
204 | if avg > 1e-3:
205 | info += ' %.4f' % avg
206 | else:
207 | info += ' %.4e' % avg
208 | info += '\n'
209 |
210 | sys.stdout.write(info)
211 | sys.stdout.flush()
212 |
213 | self._last_update = now
214 |
215 | def add(self, n, values=None):
216 | self.update(self._seen_so_far + n, values)
217 |
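218 | 
219 | # Illustrative usage, not part of the original file (values hypothetical):
220 | # create_mask builds a binary map whose mask_width x mask_height hole is 1,
221 | # and Progbar averages the reported metrics over steps.
222 | #
223 | #   mask = create_mask(256, 256, 128, 128)       # random 128x128 hole in a 256x256 map
224 | #   bar = Progbar(target=100, width=25)
225 | #   for step in range(1, 101):
226 | #       bar.update(step, values=[('l1', 0.05)])  # 'l1' is averaged unless stateful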
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Edge-LBAM
2 | PyTorch implementation of the paper "Image Inpainting with Edge-guided Learnable Bidirectional Attention Maps".
3 |
4 | ## Description
5 |
6 | This paper extends our previous work. In comparison to [LBAM](https://openaccess.thecvf.com/content_ICCV_2019/papers/Xie_Image_Inpainting_With_Learnable_Bidirectional_Attention_Maps_ICCV_2019_paper.pdf), we utilize both the mask of holes
7 | and the predicted edge map for mask updating, resulting in our Edge-LBAM method. Moreover, we introduce a multi-scale
8 | edge completion network (MECNet) for effective prediction of coherent edges.
9 |
10 | ## Prerequisites
11 |
12 | - Python 3.6
13 | - PyTorch 1.1.0
14 | - CPU or NVIDIA GPU + CUDA + cuDNN
15 |
16 | ## Training
17 |
18 |
19 | To train the Edge-LBAM model:
20 |
21 | ```
22 | python train.py --batchSize numOf_batch_size --dataRoot your_image_path \
23 |     --maskRoot your_mask_root --modelsSavePath path_to_save_your_model \
24 |     --logPath path_to_save_tensorboard_log [--pretrain pretrained_model_path]
25 | ```
26 |
27 | ## Testing
28 |
29 | To test with random batch with random masks:
30 |
31 | ```
32 | python test_random_batch.py --dataRoot your_image_path \
33 |     --maskRoot your_mask_path --batchSize numOf_batch_size --pretrain pretrained_model_path
34 | ```
35 |
36 | ## Pretrained Models
37 |
38 | The pretrained models can be found at [Google Drive](https://drive.google.com/drive/folders/1iilIU0U7fOYjYlRB7bZjN5oLNCeLoW-R?usp=sharing). We will later release models with batch normalization removed from Edge-LBAM, which may perform better. You can also train the models yourself.
39 |
40 | ## Results
41 |
42 | ### Inpainting
43 | From left to right: input, and the results of Global&Local, PConv, and DeepFillv2.
44 |
45 |    
46 |
47 | From left to right: the results of Edge-Connect, MEDFE, ours, and the ground truth (GT).
48 |
49 |    
50 |
51 | ### MECNet
52 | From left to right: input, and the edge completion results of the single-scale and multi-scale networks.
53 |
54 | 
_1.png)
55 |
--------------------------------------------------------------------------------
/config.yml:
--------------------------------------------------------------------------------
1 | MODE: 1 # 1: train, 2: test, 3: eval
2 | MODEL: 1 # 1: edge model
3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
4 | EDGE: 1 # 1: canny, 2: external
5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
6 | SEED: 10 # random seed
7 | GPU: [1] # list of gpu ids
8 | DEBUG: 0 # turns on debugging mode
9 | VERBOSE: 0 # turns on verbose mode in the output console
10 |
11 | TRAIN_FLIST: ./datasets/places2_train.flist
12 | VAL_FLIST: ./datasets/places2_val.flist
13 | TEST_FLIST: ./datasets/places2_test.flist
14 |
15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist
16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist
17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist
18 |
19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist
20 | VAL_MASK_FLIST: ./datasets/masks_val.flist
21 | TEST_MASK_FLIST: ./datasets/masks_test.flist
22 |
23 | LR: 0.0001 # learning rate
24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio
25 | BETA1: 0.0 # adam optimizer beta1
26 | BETA2: 0.9 # adam optimizer beta2
27 | BATCH_SIZE: 8 # input batch size for training
28 | INPUT_SIZE: 256               # input image size for training (0 for original size)
29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model
31 |
32 | EDGE_THRESHOLD: 0.5 # edge detection threshold
33 | L1_LOSS_WEIGHT: 1 # l1 loss weight
34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight
35 | STYLE_LOSS_WEIGHT: 250 # style loss weight
36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight
37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight
38 |
39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge
40 | GAN_POOL_SIZE: 0 # fake images pool size
41 |
42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never)
43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never)
44 | SAMPLE_SIZE: 12 # number of images to sample
45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never)
46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never)
47 |
--------------------------------------------------------------------------------
/data/__pycache__/basicFunction.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/basicFunction.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/basicFunction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/basicFunction.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader.cpython-37.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader_canny.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader_canny.cpython-36.pyc
--------------------------------------------------------------------------------
/data/__pycache__/dataloader_canny.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader_canny.cpython-37.pyc
--------------------------------------------------------------------------------
/data/basicFunction.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip
3 |
4 | def CheckImageFile(filename):
5 | return any(filename.endswith(extention) for extention in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP'])
6 |
7 | def ImageTransform(loadSize, cropSize):
8 | return Compose([
9 | Resize(size=loadSize, interpolation=Image.BICUBIC),
10 | RandomHorizontalFlip(p=0.5),
11 | RandomCrop(size=cropSize),
12 | ToTensor(),
13 | ])
14 |
15 | def MaskTransform(cropSize):
16 | return Compose([
17 | Resize(size=cropSize, interpolation=Image.NEAREST),
18 | ToTensor(),
19 | ])
20 |
21 | # Transform for paired image and mask: the damaged image and its mask come in pairs,
22 | # and the input image already contains the damaged area (filled with ones or zeros).
23 | # We suggest resizing the input image with NEAREST rather than BICUBIC (or another) interpolation;
24 | # it is not guaranteed, but with other resize methods the damaged portion may bleed outside the mask region.
25 | def PairedImageTransform(cropSize):
26 | return Compose([
27 | Resize(size=cropSize, interpolation=Image.NEAREST),
28 | ToTensor(),
29 | ])
--------------------------------------------------------------------------------
/data/dataloader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from os import listdir, walk
4 | from os.path import join
5 | from random import randint
6 | from data.basicFunction import CheckImageFile, ImageTransform, MaskTransform
7 | import numpy as np
8 | import torchvision.transforms.functional as F
9 | import random
10 | from skimage.feature import canny
11 | from skimage.color import rgb2gray
12 | from shutil import copyfile
13 | 
14 | import matplotlib.pyplot as plt
15 | from torch.utils.data import Dataset
16 |
17 |
18 | class GetData(Dataset):
19 | def __init__(self, dataRoot, maskRoot, loadSize, cropSize):
20 | super(GetData, self).__init__()
21 |
22 | self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \
23 | for files in filenames if CheckImageFile(files)]
24 | self.masks = [join(dataRootK, files) for dataRootK, dn, filenames in walk(maskRoot) \
25 | for files in filenames if CheckImageFile(files)]
26 | self.numOfMasks = len(self.masks)
27 | self.loadSize = loadSize
28 | self.cropSize = cropSize
29 | self.ImgTrans = ImageTransform(loadSize, cropSize)
30 | self.maskTrans = MaskTransform(cropSize)
31 | self.sigma = 1.5
32 |
33 | def __getitem__(self, index):
34 | img = Image.open(self.imageFiles[index])
35 | randnum = randint(0, self.numOfMasks - 1)
36 | # mask = Image.open(self.imageFiles[index].replace("GT","mask"))
37 | mask = Image.open(self.masks[randnum])
38 | groundTruth = self.ImgTrans(img.convert('RGB'))
39 | mask = self.maskTrans(mask.convert('RGB'))
40 |         # we add this threshold to force the input mask to be binary 0/1 values;
41 |         # the threshold is adjustable, but 0.5 works well in practice
42 |         threshold = 0.5
43 |         ones = mask >= threshold
44 |         zeros = mask < threshold
45 | 
46 |         mask.masked_fill_(ones, 1.0)
47 |         mask.masked_fill_(zeros, 0.0)
48 |
49 |         # here, white values (ones) denote the area to be inpainted,
50 |         # and dark values (zeros) denote the pixels to keep.
51 |         # Therefore we reverse the mask: mask = 1 - mask, and input = groundTruth * mask.
52 | edge_mask = np.transpose(mask, (1, 2, 0))
53 | mask = 1 - mask
54 | inputImage = groundTruth * mask
55 | edge_mask = edge_mask.numpy()
56 |
57 | edge_mask = rgb2gray(edge_mask)
58 | edge_mask2 = (edge_mask > 0).astype(np.uint8) * 255 # threshold due to interpolation
59 |
60 | tmp = np.transpose(groundTruth, (1, 2, 0))
61 | tmp = tmp.numpy()
62 | img_gray = rgb2gray(tmp)
63 |
64 |         edge = self.load_edge(img_gray, np.array(1 - edge_mask2 / 255).astype(bool))
65 | img_gray = torch.from_numpy(img_gray.reshape((1, 256, 256)))
66 | edge_mask = torch.from_numpy((edge_mask).reshape((1,256,256)))
67 |
68 | edge = torch.from_numpy(edge.reshape((1, 256, 256))).float()
69 | inputImage = torch.cat((inputImage, mask[0].view(1, 256, 256)), 0)
70 |
71 | return inputImage,groundTruth, mask, img_gray, edge, edge_mask.float(),self.imageFiles[index]
72 |
73 | def __len__(self):
74 | return len(self.imageFiles)
75 |
76 | def to_tensor(self, img):
77 | img = Image.fromarray(img)
78 | img_t = F.to_tensor(img).float()
79 | return img_t
80 |
81 | def load_edge(self, img, mask):
82 | sigma = self.sigma
83 |         # in test mode images are masked (contain missing regions);
84 |         # passing the 'mask' parameter prevents Canny from detecting edges in the masked regions
85 |
86 | # canny
87 |
88 | # no edge
89 | if sigma == -1:
90 |             return np.zeros(img.shape).astype(float)
91 |
92 | # random sigma
93 | if sigma == 0:
94 | sigma = random.randint(1, 4)
95 |
96 |         return canny(img, sigma=sigma, mask=mask).astype(float)
97 |
98 |
99 |
100 |
101 |
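102 | # Illustrative wiring, not part of the original file (paths hypothetical):
103 | # each sample yields the 4-channel masked input (RGB + mask channel), the
104 | # ground truth, the mask, the gray image, Canny edges of the known regions,
105 | # the hole mask for the edge branch, and the image path.
106 | #
107 | #   from torch.utils.data import DataLoader
108 | #   dataset = GetData('./images', './masks', loadSize=(350, 350), cropSize=(256, 256))
109 | #   loader = DataLoader(dataset, batch_size=8, shuffle=True)
110 | #   inputImage, gt, mask, gray, edge, edge_mask, path = next(iter(loader))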
--------------------------------------------------------------------------------
/data/dataloader_canny.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from PIL import Image
4 | from os import listdir, walk
5 | from os.path import join
6 | from random import randint
7 | from skimage.feature import canny
8 | from skimage.color import rgb2gray
9 | import numpy as np
10 | from data.basicFunction import CheckImageFile, ImageTransform, MaskTransform
11 | import matplotlib.pyplot as plt
12 | class GetData(Dataset):
13 | def __init__(self, dataRoot, maskRoot, loadSize, cropSize):
14 | super(GetData, self).__init__()
15 |
16 | self.imageFiles = [join (dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \
17 | for files in filenames if CheckImageFile(files)]
18 | self.masks = [join (dataRootK, files) for dataRootK, dn, filenames in walk(maskRoot) \
19 | for files in filenames if CheckImageFile(files)]
20 | self.numOfMasks = len(self.masks)
21 | self.loadSize = loadSize
22 | self.cropSize = cropSize
23 | self.ImgTrans = ImageTransform(loadSize, cropSize)
24 | self.maskTrans = MaskTransform(cropSize)
25 |
26 | def __getitem__(self, index):
27 | img = Image.open(self.imageFiles[index])
28 | mask = Image.open(self.masks[randint(0, self.numOfMasks - 1)])
29 |
30 | groundTruth = self.ImgTrans(img.convert('RGB'))
31 | mask = self.maskTrans(mask.convert('RGB'))
32 |         # we add this threshold to force the input mask to be binary 0/1 values;
33 |         # the threshold is adjustable, but 0.5 works well in practice
34 |         threshold = 0.5
35 |         ones = mask >= threshold
36 |         zeros = mask < threshold
37 | 
38 |         mask.masked_fill_(ones, 1.0)
39 |         mask.masked_fill_(zeros, 0.0)
40 |
41 |         # here, white values (ones) denote the area to be inpainted,
42 |         # and dark values (zeros) denote the pixels to keep.
43 |         # Therefore we reverse the mask: mask = 1 - mask, and input = groundTruth * mask.
44 | mask = 1 - mask
45 | inputImage = groundTruth * mask
46 | tmp = np.transpose(groundTruth, (1, 2, 0))
47 | tmp = tmp.numpy()
48 | tmp = rgb2gray(tmp)
49 | edge = canny(tmp, sigma=1.5).astype(np.float32)
50 | edge = torch.from_numpy(edge.reshape((1, 256, 256))).float()
51 | inputImage = torch.cat((inputImage, mask[0].view(1, self.cropSize[0], self.cropSize[1])), 0)
52 |
53 | return inputImage, groundTruth, mask, edge
54 |
55 | def __len__(self):
56 | return len(self.imageFiles)
--------------------------------------------------------------------------------
/examples/GT28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/GT28-1.png
--------------------------------------------------------------------------------
/examples/MEDFE28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/MEDFE28-1.png
--------------------------------------------------------------------------------
/examples/ec28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/ec28-1.png
--------------------------------------------------------------------------------
/examples/edge_mecnet(s)_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/edge_mecnet(s)_1.png
--------------------------------------------------------------------------------
/examples/edge_mecnet_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/edge_mecnet_1.png
--------------------------------------------------------------------------------
/examples/gc28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/gc28-1.png
--------------------------------------------------------------------------------
/examples/gl28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/gl28-1.png
--------------------------------------------------------------------------------
/examples/input1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/input1.png
--------------------------------------------------------------------------------
/examples/input28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/input28-1.png
--------------------------------------------------------------------------------
/examples/ours28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/ours28-1.png
--------------------------------------------------------------------------------
/examples/pconv28-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/pconv28-1.png
--------------------------------------------------------------------------------
/loss/InpaintingLoss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch import autograd
4 | from tensorboardX import SummaryWriter
5 | from models.discriminator import DiscriminatorDoubleColumn
6 |
7 | # modified from WGAN-GP
8 | def calc_gradient_penalty(netD, real_data, fake_data, masks, cuda, Lambda):
9 | BATCH_SIZE = real_data.size()[0]
10 | DIM = real_data.size()[2]
11 | alpha = torch.rand(BATCH_SIZE, 1)
12 | alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous()
13 | alpha = alpha.view(BATCH_SIZE, 3, DIM, DIM)
14 | if cuda:
15 | alpha = alpha.cuda()
16 |
17 | fake_data = fake_data.view(BATCH_SIZE, 3, DIM, DIM)
18 | interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach())
19 |
20 | if cuda:
21 | interpolates = interpolates.cuda()
22 | interpolates.requires_grad_(True)
23 |
24 | disc_interpolates = netD(interpolates, masks)
25 |
26 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
27 | grad_outputs=torch.ones(disc_interpolates.size()).cuda() if cuda else torch.ones(disc_interpolates.size()),
28 | create_graph=True, retain_graph=True, only_inputs=True)[0]
29 |
30 | gradients = gradients.view(gradients.size(0), -1)
31 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * Lambda
32 | return gradient_penalty.sum().mean()
33 |
34 |
35 | def gram_matrix(feat):
36 | # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py
37 | (b, ch, h, w) = feat.size()
38 | feat = feat.view(b, ch, h * w)
39 | feat_t = feat.transpose(1, 2)
40 | gram = torch.bmm(feat, feat_t) / (ch * h * w)
41 | return gram
42 |
43 |
44 | #tv loss
45 | def total_variation_loss(image):
46 | # shift one pixel and get difference (for both x and y direction)
47 | loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
48 | torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
49 | return loss
50 |
51 |
52 |
53 | class InpaintingLossWithGAN(nn.Module):
54 | def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)):
55 | super(InpaintingLossWithGAN, self).__init__()
56 | self.l1 = nn.L1Loss()
57 | self.extractor = extractor
58 | self.discriminator = DiscriminatorDoubleColumn(3)
59 | self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit)
60 | self.cudaAvailable = torch.cuda.is_available()
61 | self.numOfGPUs = torch.cuda.device_count()
62 | """ if (self.numOfGPUs > 1):
63 | self.discriminator = self.discriminator.cuda()
64 | self.discriminator = nn.DataParallel(self.discriminator, device_ids=range(self.numOfGPUs)) """
65 | self.lamda = Lamda
66 | self.writer = SummaryWriter(logPath)
67 |
68 | def forward(self, input, mask, output, gt, count, epoch):
69 | self.discriminator.zero_grad()
70 |         D_real = self.discriminator(gt, mask)
71 |         D_real = -D_real.mean().sum()
72 |         D_fake = self.discriminator(output, mask)
73 |         D_fake = D_fake.mean().sum()
74 | gp = calc_gradient_penalty(self.discriminator, gt, output, mask, self.cudaAvailable, self.lamda)
75 | D_loss = D_fake - D_real + gp
76 | self.D_optimizer.zero_grad()
77 | D_loss.backward(retain_graph=True)
78 | self.D_optimizer.step()
79 |
80 |         self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count)
81 |
82 | output_comp = mask * input + (1 - mask) * output
83 |
84 | holeLoss = 6 * self.l1((1 - mask) * output, (1 - mask) * gt)
85 | validAreaLoss = self.l1(mask * output, mask * gt)
86 |
87 | if output.shape[1] == 3:
88 | feat_output_comp = self.extractor(output_comp)
89 | feat_output = self.extractor(output)
90 | feat_gt = self.extractor(gt)
91 | elif output.shape[1] == 1:
92 | feat_output_comp = self.extractor(torch.cat([output_comp]*3, 1))
93 | feat_output = self.extractor(torch.cat([output]*3, 1))
94 | feat_gt = self.extractor(torch.cat([gt]*3, 1))
95 | else:
96 |             raise ValueError('output must have 1 (gray) or 3 (RGB) channels')
97 |
98 | prcLoss = 0.0
99 | for i in range(3):
100 | prcLoss += 0.005 * self.l1(feat_output[i], feat_gt[i])
101 | prcLoss += 0.005 * self.l1(feat_output_comp[i], feat_gt[i])
102 |
103 | styleLoss = 0.0
104 | for i in range(3):
105 | styleLoss += 120 * self.l1(gram_matrix(feat_output[i]),
106 | gram_matrix(feat_gt[i]))
107 | styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]),
108 | gram_matrix(feat_gt[i]))
109 |
110 | """ if self.numOfGPUs > 1:
111 | holeLoss = holeLoss.sum() / self.numOfGPUs
112 | validAreaLoss = validAreaLoss.sum() / self.numOfGPUs
113 | prcLoss = prcLoss.sum() / self.numOfGPUs
114 | styleLoss = styleLoss.sum() / self.numOfGPUs """
115 | self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count)
116 | self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count)
117 | self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count)
118 | self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count)
119 |
120 | GLoss = holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake
121 | self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count)
122 | return GLoss.sum()
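123 | 
124 | # For reference: calc_gradient_penalty follows WGAN-GP. For interpolates
125 | # x_hat = alpha * real + (1 - alpha) * fake it adds
126 | # Lambda * E[(||grad_{x_hat} D(x_hat)||_2 - 1)^2] to the critic objective.
127 | # Illustrative call with dummy tensors (shapes hypothetical):
128 | #
129 | #   netD = DiscriminatorDoubleColumn(3)
130 | #   real, fake = torch.randn(4, 3, 256, 256), torch.randn(4, 3, 256, 256)
131 | #   masks = torch.ones(4, 3, 256, 256)
132 | #   gp = calc_gradient_penalty(netD, real, fake, masks, cuda=False, Lambda=10.0)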
--------------------------------------------------------------------------------
/loss/__pycache__/InpaintingLoss.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/loss/__pycache__/InpaintingLoss.cpython-36.pyc
--------------------------------------------------------------------------------
/loss/__pycache__/InpaintingLoss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/loss/__pycache__/InpaintingLoss.cpython-37.pyc
--------------------------------------------------------------------------------
/models/ActivationFunction.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.nn.parameter import Parameter
4 | from torch import nn
5 | from torchvision import models
6 |
7 | # asymmetric gaussian shaped activation function g_A
8 | class GaussActivation(nn.Module):
9 | def __init__(self, a, mu, sigma1, sigma2):
10 | super(GaussActivation, self).__init__()
11 |
12 | self.a = Parameter(torch.tensor(a, dtype=torch.float32))
13 | self.mu = Parameter(torch.tensor(mu, dtype=torch.float32))
14 | self.sigma1 = Parameter(torch.tensor(sigma1, dtype=torch.float32))
15 | self.sigma2 = Parameter(torch.tensor(sigma2, dtype=torch.float32))
16 |
17 |
18 | def forward(self, inputFeatures):
19 |
20 | self.a.data = torch.clamp(self.a.data, 1.01, 6.0)
21 | self.mu.data = torch.clamp(self.mu.data, 0.1, 3.0)
22 | self.sigma1.data = torch.clamp(self.sigma1.data, 0.5, 2.0)
23 | self.sigma2.data = torch.clamp(self.sigma2.data, 0.5, 2.0)
24 |
25 | lowerThanMu = inputFeatures < self.mu
26 | largerThanMu = inputFeatures >= self.mu
27 |
28 | leftValuesActiv = self.a * torch.exp(- self.sigma1 * ( (inputFeatures - self.mu) ** 2 ) )
29 | leftValuesActiv.masked_fill_(largerThanMu, 0.0)
30 |
31 | rightValueActiv = 1 + (self.a - 1) * torch.exp(- self.sigma2 * ( (inputFeatures - self.mu) ** 2 ) )
32 | rightValueActiv.masked_fill_(lowerThanMu, 0.0)
33 |
34 | output = leftValuesActiv + rightValueActiv
35 |
36 | return output
37 |
38 | # mask updating function; we recommend using an alpha larger than 0 and lower than 1.0
39 | class MaskUpdate(nn.Module):
40 | def __init__(self, alpha):
41 | super(MaskUpdate, self).__init__()
42 |
43 | self.updateFunc = nn.ReLU(True)
44 | #self.alpha = Parameter(torch.tensor(alpha, dtype=torch.float32))
45 | self.alpha = alpha
46 | def forward(self, inputMaskMap):
47 | """ self.alpha.data = torch.clamp(self.alpha.data, 0.6, 0.8)
48 | print(self.alpha) """
49 |
50 | return torch.pow(self.updateFunc(inputMaskMap), self.alpha)
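49 | 
50 | # For reference, the asymmetric Gaussian activation above computes
51 | #   g_A(x) = a * exp(-sigma1 * (x - mu)^2)             for x < mu
52 | #   g_A(x) = 1 + (a - 1) * exp(-sigma2 * (x - mu)^2)   for x >= mu
53 | # and MaskUpdate computes ReLU(x) ** alpha. Illustrative sanity check:
54 | #
55 | #   g = GaussActivation(1.1, 2.0, 1.0, 1.0)
56 | #   print(g(torch.tensor([0.0, 2.0, 4.0])))            # peaks at mu with value a
57 | #   print(MaskUpdate(0.8)(torch.tensor([-1.0, 0.5, 2.0])))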
--------------------------------------------------------------------------------
/models/EdgeAttentionLayer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.weightInitial import weights_init
3 | from torch import nn
4 | import torch.nn.functional as F
5 | class ForwardEdgeAttention(nn.Module):
6 | def __init__(self, channels,outchannels):
7 | super(ForwardEdgeAttention, self).__init__()
8 | self.maskconv = nn.Sequential(
9 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1,bias=False),
10 | nn.BatchNorm2d(channels),
11 | nn.LeakyReLU(0.2,False)
12 | )
13 | self.edgegradient = nn.Sequential(
14 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1,bias=False),
15 | nn.BatchNorm2d(channels),
16 | nn.LeakyReLU(0.2,False)
17 | )
18 | self.edgemaskcoincide = nn.Sequential(
19 | nn.Conv2d(channels*2,1,kernel_size=1,padding=0,stride=1,bias=False),
20 | nn.BatchNorm2d(1),
21 | nn.LeakyReLU(0.2,False)
22 | )
23 | self.edgeconv = nn.Sequential(
24 | nn.Conv2d(channels,outchannels,kernel_size=3,padding=1,stride = 1,bias=False),
25 | nn.BatchNorm2d(outchannels),
26 | nn.LeakyReLU(0.2,False)
27 | )
28 | self.maskconv.apply(weights_init())
29 | self.edgegradient.apply(weights_init())
30 | self.edgemaskcoincide.apply(weights_init())
31 | self.edgeconv.apply(weights_init())
32 | def forward(self, mask, edge):
33 |         # resize the edge map to the mask's spatial resolution
34 |         output2 = F.interpolate(edge, size=[mask.shape[2], mask.shape[3]])
35 | maskout = self.maskconv(mask)
36 | edge_gradient=self.edgegradient(output2)
37 | edge_mask_concat = torch.cat((edge_gradient,mask),1)
38 | edgeout = self.edgemaskcoincide(edge_mask_concat)
39 | maskmulti = maskout*edgeout
40 | output1 = maskmulti+mask
41 | output2 = self.edgeconv(output2)
42 | return output1,output2,maskmulti
43 |
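44 | # Illustrative shapes, not part of the original file (dummy tensors): the
45 | # layer rescales the edge map to the mask's resolution, gates the mask
46 | # features with an edge-mask coincidence map, and projects the edge map to
47 | # the next stage's channel count.
48 | #
49 | #   layer = ForwardEdgeAttention(3, 64)
50 | #   mask, edge = torch.randn(2, 3, 128, 128), torch.randn(2, 3, 256, 256)
51 | #   new_mask, new_edge, gated = layer(mask, edge)
52 | #   # new_mask: (2, 3, 128, 128), new_edge: (2, 64, 128, 128), gated: (2, 3, 128, 128)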
--------------------------------------------------------------------------------
/models/LBAMModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchvision import models
4 | from models.forwardAttentionLayer import ForwardAttention
5 | from models.reverseAttentionLayer import ReverseAttention, ReverseMaskConv
6 | from models.weightInitial import weights_init
7 | from models.EdgeAttentionLayer import ForwardEdgeAttention
8 | from models.weightInitial import weights_init
9 | #VGG16 feature extract
10 | class VGG16FeatureExtractor(nn.Module):
11 | def __init__(self):
12 | super(VGG16FeatureExtractor, self).__init__()
13 | vgg16 = models.vgg16(pretrained=True)
14 | # vgg16.load_state_dict(torch.load('../vgg16-397923af.pth'))
15 | self.enc_1 = nn.Sequential(*vgg16.features[:5])
16 | self.enc_2 = nn.Sequential(*vgg16.features[5:10])
17 | self.enc_3 = nn.Sequential(*vgg16.features[10:17])
18 | # fix the encoder
19 | for i in range(3):
20 | for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
21 | param.requires_grad = False
22 |
23 | def forward(self, image):
24 | results = [image]
25 | for i in range(3):
26 | func = getattr(self, 'enc_{:d}'.format(i + 1))
27 | results.append(func(results[-1]))
28 | return results[1:]
29 |
30 | class LBAMModel(nn.Module):
31 | def __init__(self, inputChannels, outputChannels):
32 | super(LBAMModel, self).__init__()
33 | self.maskconv1 = nn.Sequential(
34 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
35 | nn.BatchNorm2d(3),
36 | nn.LeakyReLU(0.2,False)
37 | )
38 | self.maskconv2 = nn.Sequential(
39 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
40 | nn.BatchNorm2d(3),
41 | nn.LeakyReLU(0.2,False)
42 | )
43 | self.edgeconv1 = nn.Sequential(
44 | nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False),
45 | nn.BatchNorm2d(3),
46 | nn.LeakyReLU(0.2,False)
47 | )
48 | self.edgeconv2 = nn.Sequential(
49 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
50 | nn.BatchNorm2d(3),
51 | nn.LeakyReLU(0.2,False)
52 | )
53 | self.maskconv3 = nn.Sequential(
54 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
55 | nn.BatchNorm2d(3),
56 | nn.LeakyReLU(0.2,False)
57 | )
58 | self.maskconv4 = nn.Sequential(
59 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
60 | nn.BatchNorm2d(3),
61 | nn.LeakyReLU(0.2,False)
62 | )
63 | self.edgeconv3 = nn.Sequential(
64 | nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False),
65 | nn.BatchNorm2d(3),
66 | nn.LeakyReLU(0.2,False)
67 | )
68 | self.edgeconv4 = nn.Sequential(
69 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False),
70 | nn.BatchNorm2d(3),
71 | nn.LeakyReLU(0.2,False)
72 | )
73 | self.maskconv1.apply(weights_init())
74 | self.maskconv2.apply(weights_init())
75 | self.maskconv3.apply(weights_init())
76 | self.maskconv4.apply(weights_init())
77 | self.edgeconv1.apply(weights_init())
78 | self.edgeconv2.apply(weights_init())
79 | self.edgeconv3.apply(weights_init())
80 |         self.edgeconv4.apply(weights_init())
81 |         # the default kernel is of size 4x4 with stride 2 and padding 1,
82 |         # and biases are disabled by default in the ReverseAttention class.
83 | self.ec1 = ForwardAttention(5, 64, bn=False)
84 | self.ec2 = ForwardAttention(64, 128)
85 | self.ec3 = ForwardAttention(128, 256)
86 | self.ec4 = ForwardAttention(256, 512)
87 | self.edge1 = ForwardEdgeAttention(3,64)
88 | self.edge2 = ForwardEdgeAttention(64,128)
89 | self.edge3 = ForwardEdgeAttention(128,256)
90 | self.edge4 = ForwardEdgeAttention(256,512)
91 | for i in range(5, 8):
92 | name = 'ec{:d}'.format(i)
93 | setattr(self, name, ForwardAttention(512, 512))
94 | name2 = 'edge{:d}'.format(i)
95 | setattr(self,name2,ForwardEdgeAttention(512,512))
96 |
97 | # reverse mask conv
98 | self.reverseConv1 = ReverseMaskConv(3, 64)
99 | self.reverseConv2 = ReverseMaskConv(64, 128)
100 | self.reverseConv3 = ReverseMaskConv(128, 256)
101 | self.reverseConv4 = ReverseMaskConv(256, 512)
102 | self.reverseConv5 = ReverseMaskConv(512, 512)
103 | self.reverseConv6 = ReverseMaskConv(512, 512)
104 | self.reverseedge1 = ForwardEdgeAttention(3,64)
105 | self.reverseedge2 = ForwardEdgeAttention(64,128)
106 | self.reverseedge3 = ForwardEdgeAttention(128,256)
107 | self.reverseedge4 = ForwardEdgeAttention(256,512)
108 | self.reverseedge5 = ForwardEdgeAttention(512, 512)
109 | self.reverseedge6 = ForwardEdgeAttention(512, 512)
110 | self.dc1 = ReverseAttention(512, 512, bnChannels=1024)
111 | self.dc2 = ReverseAttention(512 * 2, 512, bnChannels=1024)
112 | self.dc3 = ReverseAttention(512 * 2, 512, bnChannels=1024)
113 | self.dc4 = ReverseAttention(512 * 2, 256, bnChannels=512)
114 | self.dc5 = ReverseAttention(256 * 2, 128, bnChannels=256)
115 | self.dc6 = ReverseAttention(128 * 2, 64, bnChannels=128)
116 | self.dc7 = nn.ConvTranspose2d(64 * 2, outputChannels, kernel_size=4, stride=2, padding=1, bias=False)
117 |
118 | self.tanh = nn.Tanh()
119 |
120 | def forward(self, inputImgs, masks,edge):
121 | mask1 = self.maskconv1(masks)
122 | mask1 =self.maskconv2(mask1)
123 | edge1 = self.edgeconv1(edge)
124 | edge1 = self.edgeconv2(edge1)
125 | maskoutput1,edgeoutput,feature1 = self.edge1(mask1,edge1)
126 | ef, mu1, skipConnect1, forwardMap1 = self.ec1(inputImgs, maskoutput1)
127 | maskoutput,edgeoutput,feature2 = self.edge2(mu1,edgeoutput)
128 | ef, mu2, skipConnect2, forwardMap2 = self.ec2(ef, maskoutput)
129 | maskoutput3,edgeoutput,feature3 = self.edge3(mu2,edgeoutput)
130 | ef, mu3, skipConnect3, forwardMap3 = self.ec3(ef, maskoutput3)
131 | maskoutput,edgeoutput,_ = self.edge4(mu3,edgeoutput)
132 | ef, mu, skipConnect4, forwardMap4 = self.ec4(ef, maskoutput)
133 | maskoutput,edgeoutput,_ = self.edge5(mu,edgeoutput)
134 | ef, mu, skipConnect5, forwardMap5 = self.ec5(ef, maskoutput)
135 | maskoutput,edgeoutput,_ = self.edge6(mu,edgeoutput)
136 | ef, mu, skipConnect6, forwardMap6 = self.ec6(ef, maskoutput)
137 | maskoutput, edgeoutput,_ = self.edge7(mu, edgeoutput)
138 | ef, _, _, _ = self.ec7(ef, maskoutput)
139 |
140 | mask2 = self.maskconv3(1-masks)
141 | mask2 = self.maskconv4(mask2)
142 | edge2 = self.edgeconv3(edge)
143 | edge2 = self.edgeconv4(edge2)
144 | maskoutput1,edgeoutput,feature1 = self.reverseedge1(mask2,edge2)
145 | reverseMap1, revMu = self.reverseConv1(maskoutput1)
146 | maskoutput2,edgeoutput,feature2 = self.reverseedge2(revMu,edgeoutput)
147 | reverseMap2, revMu = self.reverseConv2(maskoutput2)
148 | maskoutput3, edgeoutput,feature3 = self.reverseedge3(revMu, edgeoutput)
149 | reverseMap3, revMu = self.reverseConv3(maskoutput3)
150 | maskoutput, edgeoutput,_ = self.reverseedge4(revMu, edgeoutput)
151 | reverseMap4, revMu = self.reverseConv4(maskoutput)
152 | maskoutput, edgeoutput,_ = self.reverseedge5(revMu, edgeoutput)
153 | reverseMap5, revMu = self.reverseConv5(maskoutput)
154 | maskoutput, edgeoutput,_ = self.reverseedge6(revMu, edgeoutput)
155 | reverseMap6, _ = self.reverseConv6(maskoutput)
156 |
157 | concatMap6 = torch.cat((forwardMap6, reverseMap6), 1)
158 | dcFeatures1 = self.dc1(skipConnect6, ef, concatMap6)
159 |
160 | concatMap5 = torch.cat((forwardMap5, reverseMap5), 1)
161 | dcFeatures2 = self.dc2(skipConnect5, dcFeatures1, concatMap5)
162 |
163 | concatMap4 = torch.cat((forwardMap4, reverseMap4), 1)
164 | dcFeatures3 = self.dc3(skipConnect4, dcFeatures2, concatMap4)
165 |
166 | concatMap3 = torch.cat((forwardMap3, reverseMap3), 1)
167 | dcFeatures4 = self.dc4(skipConnect3, dcFeatures3, concatMap3)
168 |
169 | concatMap2 = torch.cat((forwardMap2, reverseMap2), 1)
170 | dcFeatures5 = self.dc5(skipConnect2, dcFeatures4, concatMap2)
171 |
172 | concatMap1 = torch.cat((forwardMap1, reverseMap1), 1)
173 | dcFeatures6 = self.dc6(skipConnect1, dcFeatures5, concatMap1)
174 |
175 | dcFeatures7 = self.dc7(dcFeatures6)
176 |
177 | output = torch.abs(self.tanh(dcFeatures7))
178 |
179 | return output,forwardMap1,forwardMap2,forwardMap3, reverseMap1,reverseMap2,reverseMap3
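180 | 
181 | # Illustrative forward pass, not part of the original file (dummy tensors;
182 | # sizes hypothetical). As wired in ec1, the model expects a 5-channel input,
183 | # a 3-channel binary mask and a 1-channel edge map, and returns the inpainted
184 | # image plus the first three forward and reverse attention maps.
185 | #
186 | #   model = LBAMModel(5, 3)
187 | #   out, *attn_maps = model(torch.randn(2, 5, 256, 256),
188 | #                           torch.ones(2, 3, 256, 256),
189 | #                           torch.rand(2, 1, 256, 256))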
--------------------------------------------------------------------------------
/models/__pycache__/ActivationFunction.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/ActivationFunction.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/ActivationFunction.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/ActivationFunction.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/EdgeAttentionLayer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/EdgeAttentionLayer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/EdgeAttentionLayer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/EdgeAttentionLayer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LBAMModel.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/LBAMModel.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/LBAMModel.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/LBAMModel.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/discriminator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/discriminator.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/discriminator.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/discriminator.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/forwardAttentionLayer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/forwardAttentionLayer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/forwardAttentionLayer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/forwardAttentionLayer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/reverseAttentionLayer.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/reverseAttentionLayer.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/reverseAttentionLayer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/reverseAttentionLayer.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/weightInitial.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/weightInitial.cpython-36.pyc
--------------------------------------------------------------------------------
/models/__pycache__/weightInitial.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/weightInitial.cpython-37.pyc
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | ##discriminator
5 | # two column discriminator
6 | class DiscriminatorDoubleColumn(nn.Module):
7 | def __init__(self, inputChannels):
8 | super(DiscriminatorDoubleColumn, self).__init__()
9 |
10 | self.globalConv = nn.Sequential(
11 | nn.Conv2d(inputChannels, 64, kernel_size=4, stride=2, padding=1),
12 | nn.LeakyReLU(0.2, inplace=True),
13 |
14 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
15 | nn.BatchNorm2d(128),
16 | nn.LeakyReLU(0.2, inplace=True),
17 |
18 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
19 | nn.BatchNorm2d(256),
20 | nn.LeakyReLU(0.2, inplace=True),
21 |
22 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
23 | nn.BatchNorm2d(512),
24 | nn.LeakyReLU(0.2 , inplace=True),
25 |
26 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
27 | nn.BatchNorm2d(512),
28 | nn.LeakyReLU(0.2, inplace=True),
29 |
30 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
31 | nn.BatchNorm2d(512),
32 | nn.LeakyReLU(0.2, inplace=True),
33 |
34 | )
35 |
36 | self.localConv = nn.Sequential(
37 | nn.Conv2d(inputChannels, 64, kernel_size=4, stride=2, padding=1),
38 | nn.LeakyReLU(0.2, inplace=True),
39 |
40 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
41 | nn.BatchNorm2d(128),
42 | nn.LeakyReLU(0.2, inplace=True),
43 |
44 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
45 | nn.BatchNorm2d(256),
46 | nn.LeakyReLU(0.2, inplace=True),
47 |
48 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
49 | nn.BatchNorm2d(512),
50 | nn.LeakyReLU(0.2 , inplace=True),
51 |
52 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
53 | nn.BatchNorm2d(512),
54 | nn.LeakyReLU(0.2, inplace=True),
55 |
56 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
57 | nn.BatchNorm2d(512),
58 | nn.LeakyReLU(0.2, inplace=True),
59 | )
60 |
61 | self.fusionLayer = nn.Sequential(
62 | nn.Conv2d(1024, 1, kernel_size=4),
63 | nn.Sigmoid()
64 | )
65 |
66 | def forward(self, batches, masks):
67 | globalFt = self.globalConv(batches * masks)
68 | localFt = self.localConv(batches * (1 - masks))
69 |
70 | concatFt = torch.cat((globalFt, localFt), 1)
71 |
72 | return self.fusionLayer(concatFt).view(batches.size()[0], -1)
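73 | 
74 | # Illustrative usage, not part of the original file (dummy tensors): the
75 | # global column sees the known region (batches * masks) and the local column
76 | # the holes (batches * (1 - masks)); both 4x4 feature maps are fused into a
77 | # single score per image.
78 | #
79 | #   netD = DiscriminatorDoubleColumn(3)
80 | #   scores = netD(torch.randn(4, 3, 256, 256), torch.ones(4, 3, 256, 256))
81 | #   print(scores.shape)   # torch.Size([4, 1])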
--------------------------------------------------------------------------------
/models/forwardAttentionLayer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from models.ActivationFunction import GaussActivation, MaskUpdate
5 | from models.weightInitial import weights_init
6 |
7 | # learnable forward attention conv layer
8 | class ForwardAttentionLayer(nn.Module):
9 | def __init__(self, inputChannels, outputChannels, kernelSize, stride,
10 | padding, dilation=1, groups=1, bias=False):
11 | super(ForwardAttentionLayer, self).__init__()
12 |
13 | self.conv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, dilation, \
14 | groups, bias)
15 |
16 | if inputChannels == 5:
17 | self.maskConv = nn.Conv2d(3, outputChannels, kernelSize, stride, padding, dilation, \
18 | groups, bias)
19 | else:
20 | self.maskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \
21 | dilation, groups, bias)
22 |
23 | self.conv.apply(weights_init())
24 | self.maskConv.apply(weights_init())
25 |
26 | self.activationFuncG_A = GaussActivation(1.1, 2.0, 1.0, 1.0)
27 | self.updateMask = MaskUpdate(0.8)
28 |
29 | def forward(self, inputFeatures, inputMasks):
30 | convFeatures = self.conv(inputFeatures)
31 | maskFeatures = self.maskConv(inputMasks)
32 | #convFeatures_skip = convFeatures.clone()
33 |
34 | maskActiv = self.activationFuncG_A(maskFeatures)
35 | convOut = convFeatures * maskActiv
36 |
37 | maskUpdate = self.updateMask(maskFeatures)
38 |
39 | return convOut, maskUpdate, convFeatures, maskActiv
40 |
41 | # forward attention layer, with feature activation and batchnorm
42 | class ForwardAttention(nn.Module):
43 | def __init__(self, inputChannels, outputChannels, bn=True, sample='down-4', \
44 | activ='leaky', convBias=False):
45 | super(ForwardAttention, self).__init__()
46 |
47 | if sample == 'down-4':
48 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 4, 2, 1, bias=convBias)
49 | elif sample == 'down-5':
50 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 5, 2, 2, bias=convBias)
51 | elif sample == 'down-7':
52 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 7, 2, 3, bias=convBias)
53 | elif sample == 'down-3':
54 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 2, 1, bias=convBias)
55 | else:
56 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 1, 1, bias=convBias)
57 |
58 | if bn:
59 | self.bn = nn.BatchNorm2d(outputChannels)
60 |
61 | if activ == 'leaky':
62 | self.activ = nn.LeakyReLU(0.2, False)
63 | elif activ == 'relu':
64 | self.activ = nn.ReLU()
65 | elif activ == 'sigmoid':
66 | self.activ = nn.Sigmoid()
67 | elif activ == 'tanh':
68 | self.activ = nn.Tanh()
69 | elif activ == 'prelu':
70 | self.activ = nn.PReLU()
71 | else:
72 | pass
73 |
74 | def forward(self, inputFeatures, inputMasks):
75 | features, maskUpdated, convPreF, maskActiv = self.conv(inputFeatures, inputMasks)
76 |
77 | if hasattr(self, 'bn'):
78 | features = self.bn(features)
79 | if hasattr(self, 'activ'):
80 | features = self.activ(features)
81 |
82 | return features, maskUpdated, convPreF, maskActiv
--------------------------------------------------------------------------------
/models/reverseAttentionLayer.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from models.ActivationFunction import GaussActivation, MaskUpdate
5 | from models.weightInitial import weights_init
6 |
7 |
8 | # learnable reverse attention conv
9 | class ReverseMaskConv(nn.Module):
10 | def __init__(self, inputChannels, outputChannels, kernelSize=4, stride=2,
11 | padding=1, dilation=1, groups=1, convBias=False):
12 | super(ReverseMaskConv, self).__init__()
13 |
14 | self.reverseMaskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \
15 | dilation, groups, bias=convBias)
16 |
17 | self.reverseMaskConv.apply(weights_init())
18 |
19 | self.activationFuncG_A = GaussActivation(1.1, 1.0, 0.5, 0.5)
20 | self.updateMask = MaskUpdate(0.8)
21 |
22 | def forward(self, inputMasks):
23 | maskFeatures = self.reverseMaskConv(inputMasks)
24 |
25 | maskActiv = self.activationFuncG_A(maskFeatures)
26 |
27 | maskUpdate = self.updateMask(maskFeatures)
28 |
29 | return maskActiv, maskUpdate
30 |
31 | # learnable reverse attention layer, including features activation and batchnorm
32 | class ReverseAttention(nn.Module):
33 |     def __init__(self, inputChannels, outputChannels, bn=True, activ='leaky', \
34 |         kernelSize=4, stride=2, padding=1, outPadding=0, dilation=1, groups=1, convBias=False, bnChannels=512):
35 | super(ReverseAttention, self).__init__()
36 |
37 |         self.conv = nn.ConvTranspose2d(inputChannels, outputChannels, kernel_size=kernelSize, \
38 |             stride=stride, padding=padding, output_padding=outPadding, dilation=dilation, groups=groups, bias=convBias)
39 |
40 | self.conv.apply(weights_init())
41 |
42 | if bn:
43 | self.bn = nn.BatchNorm2d(bnChannels)
44 |
45 | if activ == 'leaky':
46 | self.activ = nn.LeakyReLU(0.2, False)
47 | elif activ == 'relu':
48 | self.activ = nn.ReLU()
49 | elif activ == 'sigmoid':
50 | self.activ = nn.Sigmoid()
51 | elif activ == 'tanh':
52 | self.activ = nn.Tanh()
53 | elif activ == 'prelu':
54 | self.activ = nn.PReLU()
55 | else:
56 | pass
57 |
58 | def forward(self, ecFeaturesSkip, dcFeatures, maskFeaturesForAttention):
59 | nextDcFeatures = self.conv(dcFeatures)
60 |
61 |         # note that the encoder features come first: the forward attention map must precede the
62 |         # reverse attention map when concatenating; this is done in the LBAM model's forward function
63 | concatFeatures = torch.cat((ecFeaturesSkip, nextDcFeatures), 1)
64 |
65 | outputFeatures = concatFeatures * maskFeaturesForAttention
66 |
67 | if hasattr(self, 'bn'):
68 | outputFeatures = self.bn(outputFeatures)
69 | if hasattr(self, 'activ'):
70 | outputFeatures = self.activ(outputFeatures)
71 |
72 | return outputFeatures
--------------------------------------------------------------------------------
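A minimal sketch of the two reverse-attention modules above (illustrative only; the channel widths mirror a typical encoder-decoder stage and are assumptions, not values fixed by the repository):

    import torch
    from models.reverseAttentionLayer import ReverseMaskConv, ReverseAttention

    # mask branch: downsamples the reversed mask and emits attention maps
    revMaskConv = ReverseMaskConv(3, 64)
    maskActiv, maskUpdate = revMaskConv(torch.ones(2, 3, 256, 256))  # -> [2, 64, 128, 128]

    # decoder step: upsample 512-ch decoder features and fuse a 256-ch encoder skip,
    # so bnChannels must equal the concatenated width 256 + 256 = 512
    layer = ReverseAttention(512, 256, bnChannels=512)
    dcFeatures = torch.randn(2, 512, 16, 16)
    ecSkip = torch.randn(2, 256, 32, 32)
    attention = torch.rand(2, 512, 32, 32)   # must match the concatenated width
    out = layer(ecSkip, dcFeatures, attention)
    print(out.shape)  # torch.Size([2, 512, 32, 32])

--------------------------------------------------------------------------------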
/models/weightInitial.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch.nn as nn
3 | 
4 | # weight initialization strategies; math is needed for the sqrt gains below
5 | def weights_init(init_type='gaussian'):
6 | def init_fun(m):
7 | classname = m.__class__.__name__
8 |
9 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0 ) and hasattr(m, 'weight'):
10 | if (init_type == 'gaussian'):
11 | nn.init.normal_(m.weight, 0.0, 0.02)
12 | elif (init_type == 'xavier'):
13 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2))
14 | elif (init_type == 'kaiming'):
15 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
16 | elif (init_type == 'orthogonal'):
17 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2))
18 | elif (init_type == 'default'):
19 | pass
20 | else:
21 |                 raise ValueError('Unsupported initialization: {}'.format(init_type))
22 | if hasattr(m, 'bias') and m.bias is not None:
23 | nn.init.constant_(m.bias, 0.0)
24 |
25 | return init_fun
--------------------------------------------------------------------------------
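A short usage sketch for weights_init (illustrative; the toy network is an assumption). Only modules whose class name starts with 'Conv' or 'Linear' are touched, so ConvTranspose2d is covered while BatchNorm2d keeps its default initialization; the 'xavier' and 'orthogonal' branches are why the math import above is required:

    import torch.nn as nn
    from models.weightInitial import weights_init

    net = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1),      # initialized (classname starts with 'Conv')
        nn.BatchNorm2d(64),                  # skipped
        nn.ConvTranspose2d(64, 3, 4, 2, 1),  # initialized
    )
    net.apply(weights_init('gaussian'))  # or 'xavier', 'kaiming', 'orthogonal', 'default'

--------------------------------------------------------------------------------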
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import torch
3 | import torch.nn.functional as F
4 | import lpips
5 | 
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
9 | return gauss / gauss.sum()
10 |
11 |
12 | def create_window(window_size, channel):
13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()  # Variable wrapper no longer needed
16 | return window
17 |
18 |
19 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
20 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
21 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
22 |
23 | mu1_sq = mu1.pow(2)
24 | mu2_sq = mu2.pow(2)
25 | mu1_mu2 = mu1 * mu2
26 |
27 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
28 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
29 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
30 |
31 |     C1 = 0.01 ** 2  # stability constants; these values assume inputs in [0, 1]
32 |     C2 = 0.03 ** 2
33 |
34 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
35 |
36 | if size_average:
37 | return ssim_map.mean()
38 | else:
39 | return ssim_map.mean(1).mean(1).mean(1)
40 |
41 |
42 | class SSIM(torch.nn.Module):
43 | def __init__(self, window_size=11, size_average=True):
44 | super(SSIM, self).__init__()
45 | self.window_size = window_size
46 | self.size_average = size_average
47 | self.channel = 1
48 | self.window = create_window(window_size, self.channel)
49 |
50 | def forward(self, img1, img2):
51 | (_, channel, _, _) = img1.size()
52 |
53 | if channel == self.channel and self.window.data.type() == img1.data.type():
54 | window = self.window
55 | else:
56 | window = create_window(self.window_size, channel)
57 |
58 | if img1.is_cuda:
59 | window = window.cuda(img1.get_device())
60 | window = window.type_as(img1)
61 |
62 | self.window = window
63 | self.channel = channel
64 |
65 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
66 |
67 |
68 | def ssim(img1, img2, window_size=11, size_average=True):
69 | (_, channel, _, _) = img1.size()
70 | window = create_window(window_size, channel)
71 |
72 | if img1.is_cuda:
73 | window = window.cuda(img1.get_device())
74 | window = window.type_as(img1)
75 |
76 | return _ssim(img1, img2, window, window_size, channel, size_average)
77 |
78 | def caculatelpips(img1, img2):
79 |     # note: builds a new AlexNet-based LPIPS model on every call; for repeated
80 |     # evaluation, construct lpips.LPIPS(net='alex') once and reuse it
81 |     loss_fn_alex = lpips.LPIPS(net='alex')
82 |     d = loss_fn_alex(img1, img2)
83 |     return d
--------------------------------------------------------------------------------
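A quick sanity-check sketch for the SSIM utilities above (illustrative; the random tensors are placeholders for real image batches in [0, 1], the range the C1/C2 constants assume):

    import torch
    import pytorch_ssim

    a = torch.rand(1, 3, 256, 256)
    b = (a + 0.05 * torch.randn_like(a)).clamp(0, 1)

    print(pytorch_ssim.ssim(a, a).item())  # 1.0 for identical images
    print(pytorch_ssim.ssim(a, b).item())  # below 1.0 under noise

    # the module form caches the Gaussian window across calls
    ssim_module = pytorch_ssim.SSIM(window_size=11)
    print(ssim_module(a, b).item())

--------------------------------------------------------------------------------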
/pytorch_ssim/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/pytorch_ssim/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_ssim/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/pytorch_ssim/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/test_random_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | from torchvision.utils import save_image
8 | from torch.utils.data import DataLoader
9 | from data.dataloader import GetData
10 | from models.LBAMModel import LBAMModel
11 | import pytorch_ssim
12 | import random
13 | import numpy as np
14 | from MECNet.models import EdgeModel
15 | from MECNet.config import Config
16 | import time
17 | 
18 |
19 | torch.manual_seed(0)
20 | torch.cuda.manual_seed_all(0)
21 | random.seed(0)
22 | np.random.seed(0)
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--numOfWorkers', type=int, default=4,
25 | help='workers for dataloader')
26 | parser.add_argument('--local_rank',type=int,default=0)
27 | parser.add_argument('--pretrained', type=str, default='', help='pretrained models')
28 | parser.add_argument('--batchSize', type=int, default=16)
29 | parser.add_argument('--loadSize', type=int, default=350,
30 | help='image loading size')
31 | parser.add_argument('--cropSize', type=int, default=256,
32 | help='image training size')
33 | parser.add_argument('--dataRoot', type=str,
34 | default='')
35 | parser.add_argument('--maskRoot', type=str,
36 | default='')
37 | parser.add_argument('--savePath', type=str, default='./results')
38 | args = parser.parse_args()
39 |
40 | cuda = torch.cuda.is_available()
41 | if cuda:
42 | print('Cuda is available!')
43 | cudnn.benchmark = True
44 | os.makedirs(os.path.join(args.savePath,"GT"), exist_ok=True)
45 | os.makedirs(os.path.join(args.savePath,"damaged"), exist_ok=True)
46 | os.makedirs(os.path.join(args.savePath,"ours"), exist_ok=True)
47 | os.makedirs(os.path.join(args.savePath,"input"), exist_ok=True)
48 | os.makedirs(os.path.join(args.savePath,"masks"), exist_ok=True)
49 | os.makedirs(os.path.join(args.savePath,"edge"), exist_ok=True)
50 |
51 |
52 | batchSize = args.batchSize
53 | loadSize = (args.loadSize, args.loadSize)
54 | cropSize = (args.cropSize, args.cropSize)
55 | dataRoot = args.dataRoot
56 | maskRoot = args.maskRoot
57 | savePath = args.savePath
58 |
59 | if not os.path.exists(savePath):
60 | os.makedirs(savePath)
61 |
62 | config = Config("config.yml")
63 | edge_model = EdgeModel(config).to(config.DEVICE)
64 | edge_model.load()
65 | edge_model.cuda()
66 | edge_model = nn.DataParallel(edge_model, device_ids=[0,1])  # note: assumes at least 2 GPUs
67 | imgData = GetData(dataRoot, maskRoot, loadSize, cropSize)
68 | data_loader = DataLoader(imgData, batch_size=batchSize, shuffle=False, num_workers=args.numOfWorkers, drop_last=False)
69 |
70 | num_epochs = 100
71 |
72 | netG = LBAMModel(5, 3)
73 |
74 | if args.pretrained != '':
75 | netG.load_state_dict(torch.load(args.pretrained))
76 | else:
77 | print('No pretrained model provided!')
78 |
79 | #
80 | if cuda:
81 | netG = netG.cuda()
82 |
83 | for param in netG.parameters():
84 | param.requires_grad = False
85 |
86 | print('OK!')
87 |
88 |
89 | sum_psnr = 0
90 | sum_ssim = 0
91 | count = 0
92 | sum_time = 0.0
93 | l1_loss = 0
94 |
95 | 
96 | start = time.time()
97 | for i in range(1, num_epochs + 1):
98 | netG.eval()
99 |     for inputImgs, GT, masks, img_gray, edge, masks_over in data_loader:
100 |         if count >= 60:  # evaluate at most 60 batches
101 | break
102 | if cuda:
103 |             inputImgs = inputImgs.cuda()
104 |             img_gray = img_gray.cuda()
105 |             GT = GT.cuda()
106 |             masks = masks.cuda()
107 |             edge = edge.cuda()
108 |             masks_over = masks_over.cuda()
109 | outputs_2 = edge_model(img_gray, edge, masks_over)
110 | outputs_merged = (outputs_2 * masks_over) + (edge * (1 - masks_over))
111 | inputImgs2 = torch.cat((inputImgs, outputs_merged), 1)
112 |         # run the inpainting generator on the masked image plus the completed edge map
113 |         fake_images = netG(inputImgs2, masks, outputs_merged)
114 |
115 | g_image = fake_images.data.cpu()
116 | GT = GT.data.cpu()
117 | mask = masks.data.cpu()
118 | damaged = GT * mask
119 |         generatedImage = GT * mask + g_image * (1 - mask)
120 |         groundTruth = GT
121 |         masksT = mask
122 | 
123 | 
124 |         count += 1
125 |         batch_mse = ((groundTruth - generatedImage) ** 2).mean().item()
126 |         psnr = 10 * math.log10(1 / batch_mse)  # PSNR for images scaled to [0, 1]
127 |         sum_psnr += psnr
128 |         print(count, ' psnr:', psnr)
129 |         ssim = pytorch_ssim.ssim(groundTruth, generatedImage).item()  # note: C1/C2 assume a [0, 1] range, so no *255 here
130 |         sum_ssim += ssim
131 |         print(count, ' ssim:', ssim)
132 |         l1_loss += nn.L1Loss()(generatedImage, groundTruth).item()
133 |
134 |         outputs = torch.Tensor(5 * GT.size()[0], GT.size()[1], cropSize[0], cropSize[1])
135 |         for k in range(GT.size()[0]):  # k, not i: i is the epoch counter of the outer loop
136 |             outputs[5 * k] = masksT[k]
137 |             outputs[5 * k + 1] = damaged[k]
138 |             #outputs[5 * k + 2] = GT[k] * masksT[k]  # dead assignment: overwritten on the next line
139 |             outputs[5 * k + 2] = generatedImage[k]
140 |             outputs[5 * k + 3] = GT[k]
141 |             outputs[5 * k + 4] = outputs_merged[k]
142 |             #outputs[5 * k + 4] = 1 - masksT[k]
143 | # save_image(outputs, os.path.join(savePath, 'results-{}'.format(count) + '.png'))
144 |
145 |         # save masks, GT, results, inputs, and damaged images into their subdirectories
146 |         damaged = GT * mask + (1 - mask)  # holes filled with white for the input/ folder
147 |
148 | for j in range(GT.size()[0]):
149 | save_image(outputs[5 * j + 1], savePath + '/damaged/damaged{}-{}.png'.format(count, j))
150 | outputs[5 * j + 1] = damaged[j]
151 |
152 | for j in range(GT.size()[0]):
153 |             outputs[5 * j] = 1 - masksT[j]
154 | save_image(outputs[5 * j], savePath + '/masks/mask{}-{}.png'.format(count, j))
155 | save_image(outputs[5 * j + 1], savePath + '/input/input{}-{}.png'.format(count, j))
156 | save_image(outputs[5 * j + 2], savePath + '/ours/ours{}-{}.png'.format(count, j))
157 | save_image(outputs[5 * j + 3], savePath + '/GT/GT{}-{}.png'.format(count, j))
158 | save_image(outputs[5 * j + 4], savePath + '/edge/edge{}-{}.png'.format(count, j))
159 |
160 |
161 |
162 |     end = time.time()
163 |     sum_time += (end - start) / batchSize  # note: start is never reset, so this accumulates total elapsed time
164 |
165 |
166 | print('avg l1 loss:', l1_loss / count)
167 | print('average psnr:', sum_psnr / count)
168 | print('average ssim:', sum_ssim / count)
169 | print('average time cost:', sum_time / count)
170 |
--------------------------------------------------------------------------------
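An invocation sketch for the evaluation script above (paths are placeholders; all flags come from the script's own argparse definitions). The script additionally expects config.yml and pretrained MECNet edge-model weights, loaded via edge_model.load():

    python test_random_batch.py \
        --pretrained ./checkpoints/LBAM_50.pth \
        --dataRoot /path/to/test/images \
        --maskRoot /path/to/test/masks \
        --batchSize 16 --loadSize 350 --cropSize 256 \
        --savePath ./results

--------------------------------------------------------------------------------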
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.backends.cudnn as cudnn
8 | from PIL import Image
9 | from torch.autograd import Variable
10 | from torchvision.utils import save_image
11 | from torchvision import datasets
12 | from torch.utils.data import DataLoader
13 | from torchvision import utils
14 | from data.dataloader_canny import GetData
15 | from loss.InpaintingLoss import InpaintingLossWithGAN
16 | from models.LBAMModel import LBAMModel, VGG16FeatureExtractor
17 | from MECNet.models import EdgeModel
18 | from MECNet.config import Config
19 |
20 | torch.set_num_threads(6)
21 |
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--numOfWorkers', type=int, default=4,
25 | help='workers for dataloader')
26 | parser.add_argument('--modelsSavePath', type=str, default='',
27 | help='path for saving models')
28 | parser.add_argument('--logPath', type=str,
29 | default='')
30 | parser.add_argument('--batchSize', type=int, default=16)
31 | parser.add_argument('--loadSize', type=int, default=256,
32 | help='image loading size')
33 | parser.add_argument('--cropSize', type=int, default=256,
34 | help='image training size')
35 | parser.add_argument('--dataRoot', type=str,
36 | default='')
37 | parser.add_argument('--maskRoot', type=str,
38 | default='')
39 | parser.add_argument('--pretrained',type=str, default='', help='pretrained models for finetuning')
40 | parser.add_argument('--train_epochs', type=int, default=500, help='training epochs')
41 | args = parser.parse_args()
42 |
43 |
44 |
45 | cuda = torch.cuda.is_available()
46 | if cuda:
47 | print('Cuda is available!')
48 |     cudnn.enabled = True
49 | cudnn.benchmark = True
50 |
51 |
52 | batchSize = args.batchSize
53 | loadSize = (args.loadSize, args.loadSize)
54 | cropSize = (args.cropSize, args.cropSize)
55 |
56 | if not os.path.exists(args.modelsSavePath):
57 | os.makedirs(args.modelsSavePath)
58 |
59 | config = Config("config.yml")
60 | edge_model = EdgeModel(config).to(config.DEVICE)
61 | edge_model.load()
62 | edge_model.cuda()
63 | edge_model = nn.DataParallel(edge_model, device_ids=[0,1,2,3])  # note: assumes at least 4 GPUs
64 | dataRoot = args.dataRoot
65 | maskRoot = args.maskRoot
66 |
67 |
68 | imgData = GetData(dataRoot, maskRoot, loadSize, cropSize)
69 | data_loader = DataLoader(imgData, batch_size=batchSize,
70 | shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)
71 |
72 | num_epochs = args.train_epochs
73 |
74 | netG = LBAMModel(5, 3)
75 | if args.pretrained != '':
76 | netG.load_state_dict(torch.load(args.pretrained))
77 |
78 |
79 |
80 | numOfGPUs = torch.cuda.device_count()
81 |
82 | if cuda:
83 | netG = netG.cuda()
84 | if numOfGPUs > 1:
85 | netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))
86 |
87 | count = 1
88 |
89 |
90 | G_optimizer = optim.Adam(netG.parameters(), lr=0.000025, betas=(0.5, 0.9))
91 |
92 |
93 | criterion = InpaintingLossWithGAN(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)
94 |
95 | if cuda:
96 | criterion = criterion.cuda()
97 |
98 | if numOfGPUs > 1:
99 | criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))
100 |
101 | print('OK!')
102 |
103 | for i in range(1, num_epochs + 1):
104 | netG.train()
105 |
106 |     for inputImgs, GT, masks, img_gray, edge, masks_over in data_loader:
107 |
108 | if cuda:
109 |             inputImgs = inputImgs.cuda()
110 |             GT = GT.cuda()
111 |             masks = masks.cuda()
112 |             edge = edge.cuda()
113 |             img_gray, masks_over = img_gray.cuda(), masks_over.cuda()  # img_gray too: the edge model consumes it below
114 | netG.zero_grad()
115 | outputs = edge_model(img_gray, edge, masks_over)
116 |
117 | outputs_merged = (outputs * masks_over) + (edge * (1 - masks_over))
118 | inputImgs = torch.cat((inputImgs, outputs_merged), 1)
119 |         # print(inputImgs.shape)
120 |         fake_images = netG(inputImgs, masks, outputs_merged)
121 | G_loss = criterion(inputImgs[:, 0:3, :, :], masks, fake_images, GT, count, i)
122 | G_loss = G_loss.sum()
123 | G_optimizer.zero_grad()
124 | G_loss.backward()
125 | G_optimizer.step()
126 |
127 |         with open('/home/wangdongsheng/LBAM_version6/loss2.txt', 'a') as file:  # note: hardcoded log path; adjust before running elsewhere
128 | file.write('Generator Loss of epoch{} is {}\n'.format(i, G_loss.item()))
129 |
130 |
131 | count += 1
132 |
133 | """ if (count % 4000 == 0):
134 | torch.save(netG.module.state_dict(), args.modelsSavePath +
135 | '/Places_{}.pth'.format(i)) """
136 |
137 |     if i % 10 == 0:  # save every 10 epochs; the filename cycles via i % 50
138 |         if numOfGPUs > 1:
139 |             torch.save(netG.module.state_dict(), args.modelsSavePath +
140 |                 '/LBAM_{}.pth'.format(i % 50))
141 |         else:
142 |             torch.save(netG.state_dict(), args.modelsSavePath +
143 |                 '/LBAM_{}.pth'.format(i % 50))
144 |
--------------------------------------------------------------------------------
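An invocation sketch for the training script above (paths are placeholders; the flags come from the script's argparse). Note that the script also reads config.yml, loads pretrained edge-model weights via edge_model.load(), and wraps the edge model in DataParallel over GPUs 0-3, so a 4-GPU machine is assumed as written:

    python train.py \
        --dataRoot /path/to/train/images \
        --maskRoot /path/to/train/masks \
        --modelsSavePath ./checkpoints \
        --logPath ./logs \
        --batchSize 16 --loadSize 256 --cropSize 256 \
        --train_epochs 500

--------------------------------------------------------------------------------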