├── MECNet
│   ├── __init__.py
│   ├── config.py
│   ├── dataset.py
│   ├── edge_connect.py
│   ├── loss.py
│   ├── metrics.py
│   ├── models.py
│   ├── networks2.py
│   └── utils.py
├── README.md
├── config.yml
├── data
│   ├── basicFunction.py
│   ├── dataloader.py
│   └── dataloader_canny.py
├── examples
│   ├── GT28-1.png
│   ├── MEDFE28-1.png
│   ├── ec28-1.png
│   ├── edge_mecnet(s)_1.png
│   ├── edge_mecnet_1.png
│   ├── gc28-1.png
│   ├── gl28-1.png
│   ├── input1.png
│   ├── input28-1.png
│   ├── ours28-1.png
│   └── pconv28-1.png
├── loss
│   └── InpaintingLoss.py
├── models
│   ├── ActivationFunction.py
│   ├── EdgeAttentionLayer.py
│   ├── LBAMModel.py
│   ├── discriminator.py
│   ├── forwardAttentionLayer.py
│   ├── reverseAttentionLayer.py
│   └── weightInitial.py
├── pytorch_ssim
│   └── __init__.py
├── test_random_batch.py
└── train.py

/MECNet/__init__.py:
--------------------------------------------------------------------------------
1 | # empty
--------------------------------------------------------------------------------
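# [annotation] Usage sketch for the Config class defined in MECNet/config.py
# below. A minimal, illustrative example only: the 'config.yml' path and the
# printed keys are assumptions, not taken from the repo.
#
#     from MECNet.config import Config
#
#     config = Config('config.yml')
#     config.print()            # dumps the raw YAML
#     print(config.LR)          # missing keys fall back to DEFAULT_CONFIG,
#     print(config.SOME_KEY)    # and unknown keys return None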
/MECNet/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | 
4 | class Config(dict):
5 |     def __init__(self, config_path):
6 |         with open(config_path, 'r') as f:
7 |             self._yaml = f.read()
8 |             self._dict = yaml.safe_load(self._yaml)
9 |             self._dict['PATH'] = "~/LBAM_GRU_version2/checkpoints/psv"
10 | 
11 |     def __getattr__(self, name):
12 |         if self._dict.get(name) is not None:
13 |             return self._dict[name]
14 | 
15 |         if DEFAULT_CONFIG.get(name) is not None:
16 |             return DEFAULT_CONFIG[name]
17 | 
18 |         return None
19 | 
20 |     def print(self):
21 |         print('Model configurations:')
22 |         print('---------------------------------')
23 |         print(self._yaml)
24 |         print('')
25 |         print('---------------------------------')
26 |         print('')
27 | 
28 | 
29 | DEFAULT_CONFIG = {
30 |     'MODE': 2,                      # 1: train, 2: test, 3: eval
31 |     'MODEL': 1,                     # 1: edge model, 2: inpaint model, 3: edge-inpaint model, 4: joint model
32 |     'MASK': 3,                      # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half)
33 |     'EDGE': 1,                      # 1: canny, 2: external
34 |     'NMS': 1,                       # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny
35 |     'SEED': 10,                     # random seed
36 |     'GPU': [0],                     # list of gpu ids
37 |     'DEBUG': 0,                     # turns on debugging mode
38 |     'VERBOSE': 0,                   # turns on verbose mode in the output console
39 | 
40 |     'LR': 0.0001,                   # learning rate
41 |     'D2G_LR': 0.1,                  # discriminator/generator learning rate ratio
42 |     'BETA1': 0.0,                   # adam optimizer beta1
43 |     'BETA2': 0.9,                   # adam optimizer beta2
44 |     'BATCH_SIZE': 8,                # input batch size for training
45 |     'INPUT_SIZE': 256,              # input image size for training (0 = original size)
46 |     'SIGMA': 2,                     # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge)
47 |     'MAX_ITERS': 2e6,               # maximum number of iterations to train the model
48 | 
49 |     'EDGE_THRESHOLD': 0.5,          # edge detection threshold
50 |     'L1_LOSS_WEIGHT': 1,            # l1 loss weight
51 |     'FM_LOSS_WEIGHT': 10,           # feature-matching loss weight
52 |     'STYLE_LOSS_WEIGHT': 1,         # style loss weight
53 |     'CONTENT_LOSS_WEIGHT': 1,       # perceptual loss weight
54 |     'INPAINT_ADV_LOSS_WEIGHT': 0.01,  # adversarial loss weight
55 | 
56 |     'GAN_LOSS': 'nsgan',            # nsgan | lsgan | hinge
57 |     'GAN_POOL_SIZE': 0,             # fake images pool size
58 | 
59 |     'SAVE_INTERVAL': 1000,          # how many iterations to wait before saving model (0: never)
60 |     'SAMPLE_INTERVAL': 1000,        # how many iterations to wait before sampling (0: never)
61 |     'SAMPLE_SIZE': 12,              # number of images to sample
62 |     'EVAL_INTERVAL': 0,             # how many iterations to wait before model evaluation (0: never)
63 |     'LOG_INTERVAL': 10,             # how many iterations to wait before logging training status (0: never)
64 | }
--------------------------------------------------------------------------------
/MECNet/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import scipy
4 | import torch
5 | import random
6 | import numpy as np
7 | import torchvision.transforms.functional as F
8 | from torch.utils.data import DataLoader
9 | from PIL import Image
10 | from scipy.misc import imread
11 | from skimage.feature import canny
12 | from skimage.color import rgb2gray, gray2rgb
13 | from .utils import create_mask
14 | import matplotlib.pyplot as plt
15 | 
16 | class Dataset(torch.utils.data.Dataset):
17 |     def __init__(self, config, flist, edge_flist, mask_flist, augment=True, training=True):
18 |         super(Dataset, self).__init__()
19 |         self.augment = augment
20 |         self.training = training
21 |         self.data = self.load_flist(flist)
22 |         self.edge_data = self.load_flist(edge_flist)
23 |         self.mask_data = self.load_flist(mask_flist)
24 | 
25 |         self.input_size = config.INPUT_SIZE
26 |         self.sigma = 2
27 |         self.edge = config.EDGE
28 |         self.mask = config.MASK
29 |         self.nms = config.NMS
30 | 
31 |         # in test mode there is a one-to-one mapping between images and masks;
32 |         # masks are loaded non-randomly
33 |         if config.MODE == 2:
34 |             self.mask = 6
35 | 
36 |     def __len__(self):
37 |         return len(self.data)
38 | 
39 |     def __getitem__(self, index):
40 |         try:
41 |             item = self.load_item(index)
42 |         except Exception:
43 |             print('loading error: ' + self.data[index])
44 |             item = self.load_item(0)
45 | 
46 |         return item
47 | 
48 |     def load_name(self, index):
49 |         name = self.data[index]
50 |         return os.path.basename(name)
51 | 
52 |     def load_item(self, index):
53 | 
54 |         size = self.input_size
55 | 
56 |         # load image
57 |         img = imread(self.data[index])
58 | 
59 |         # gray to rgb
60 |         if len(img.shape) < 3:
61 |             img = gray2rgb(img)
62 | 
63 |         # resize/crop if needed
64 |         if size != 0:
65 |             img = self.resize(img, size, size)
66 | 
67 |         # create grayscale image
68 |         img_gray = rgb2gray(img)
69 | 
70 |         # load mask
71 |         mask = self.load_mask(img, index)
72 |         # plt.imshow(mask, cmap=plt.cm.gray)   # leftover debug display: would pop a window
73 |         # plt.show()                           # and block on every __getitem__ call
74 |         # load edge
75 |         edge = self.load_edge(img_gray, index, mask)
76 |         # print(img.shape,img_gray.shape,mask.shape,edge.shape)
77 |         # augment data
78 |         if self.augment and np.random.binomial(1, 0.5) > 0:
79 |             img = img[:, ::-1, ...]
80 |             img_gray = img_gray[:, ::-1, ...]
81 |             edge = edge[:, ::-1, ...]
82 |             mask = mask[:, ::-1, ...]
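        # [annotation] the `[:, ::-1, ...]` slices above flip axis 1 (the width
        # axis), i.e. a 50% horizontal-flip augmentation; all four maps are
        # flipped together so image, grayscale, edge and mask stay aligned. The
        # negative-stride views are safe here because to_tensor() goes through
        # PIL (Image.fromarray copies the buffer) before torch sees the data.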
83 | 
84 |         return self.to_tensor(img), self.to_tensor(img_gray), self.to_tensor(edge), self.to_tensor(mask)
85 | 
86 |     def load_edge(self, img, index, mask):
87 |         sigma = self.sigma
88 | 
89 |         # in test mode images are masked (with masked regions); passing the
90 |         # 'mask' parameter prevents Canny from detecting edges inside the masked regions
91 |         mask = None if self.training else (1 - mask / 255).astype(bool)
92 | 
93 |         # canny
94 |         if self.edge == 1:
95 |             # no edge
96 |             if sigma == -1:
97 |                 return np.zeros(img.shape).astype(float)
98 | 
99 |             # random sigma
100 |             if sigma == 0:
101 |                 sigma = random.randint(1, 4)
102 | 
103 |             return canny(img, sigma=sigma, mask=mask).astype(float)
104 | 
105 |         # external
106 |         else:
107 |             imgh, imgw = img.shape[0:2]
108 |             edge = imread(self.edge_data[index])
109 |             edge = self.resize(edge, imgh, imgw)
110 | 
111 |             # non-max suppression
112 |             if self.nms == 1:
113 |                 edge = edge * canny(img, sigma=sigma, mask=mask)
114 | 
115 |             return edge
116 | 
117 |     def load_mask(self, img, index):
118 |         imgh, imgw = img.shape[0:2]
119 |         mask_type = self.mask
120 | 
121 |         # external + random block
122 |         if mask_type == 4:
123 |             mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3
124 | 
125 |         # external + random block + half
126 |         elif mask_type == 5:
127 |             mask_type = np.random.randint(1, 4)
128 | 
129 |         # random block
130 |         if mask_type == 1:
131 |             return create_mask(imgw, imgh, imgw // 2, imgh // 2)
132 | 
133 |         # half
134 |         if mask_type == 2:
135 |             # randomly choose right or left
136 |             return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
137 | 
138 |         # external
139 |         if mask_type == 3:
140 |             mask_index = random.randint(0, len(self.mask_data) - 1)
141 |             mask = imread(self.mask_data[mask_index])
142 |             mask = self.resize(mask, imgh, imgw)
143 |             mask = (mask > 0).astype(np.uint8) * 255  # threshold due to interpolation
144 |             return mask
145 | 
146 |         # test mode: load mask non-randomly
147 |         if mask_type == 6:
148 |             mask = imread(self.mask_data[index])
149 |             mask = self.resize(mask, imgh, imgw, centerCrop=False)
150 |             mask = rgb2gray(mask)
151 |             mask = (mask > 0).astype(np.uint8) * 255
152 |             return mask
153 | 
154 |     def to_tensor(self, img):
155 |         img = Image.fromarray(img)
156 |         img_t = F.to_tensor(img).float()
157 |         return img_t
158 | 
159 |     def resize(self, img, height, width, centerCrop=True):
160 |         imgh, imgw = img.shape[0:2]
161 | 
162 |         if centerCrop and imgh != imgw:
163 |             # center crop
164 |             side = np.minimum(imgh, imgw)
165 |             j = (imgh - side) // 2
166 |             i = (imgw - side) // 2
167 |             img = img[j:j + side, i:i + side, ...]
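        # [annotation] scipy.misc.imread/imresize (used here and in load_item)
        # were removed in SciPy 1.2/1.3, so this file needs an old SciPy plus
        # Pillow. On newer stacks a near drop-in for the call below would be:
        #
        #     img = np.array(Image.fromarray(img).resize((width, height), Image.BILINEAR))
        #
        # (note PIL takes (width, height) while this helper takes (height, width))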
168 | 
169 |         img = scipy.misc.imresize(img, [height, width])
170 | 
171 |         return img
172 | 
173 |     def load_flist(self, flist):
174 |         if isinstance(flist, list):
175 |             return flist
176 | 
177 |         # flist: image file path, image directory path, text file flist path
178 |         if isinstance(flist, str):
179 |             if os.path.isdir(flist):
180 |                 flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
181 |                 flist.sort()
182 |                 return flist
183 | 
184 |             if os.path.isfile(flist):
185 |                 try:
186 |                     return np.genfromtxt(flist, dtype=str, encoding='utf-8')
187 |                 except Exception:
188 |                     return [flist]
189 | 
190 |         return []
191 | 
192 |     def create_iterator(self, batch_size):
193 |         while True:
194 |             sample_loader = DataLoader(
195 |                 dataset=self,
196 |                 batch_size=batch_size,
197 |                 drop_last=True
198 |             )
199 | 
200 |             for item in sample_loader:
201 |                 yield item
--------------------------------------------------------------------------------
/MECNet/edge_connect.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from .dataset import Dataset
6 | from .models import EdgeModel, InpaintingModel
7 | from .utils import Progbar, create_dir, stitch_images, imsave
8 | from .metrics import PSNR, EdgeAccuracy
9 | import matplotlib.pyplot as plt
10 | import cv2
11 | 
12 | class EdgeConnect():
13 |     def __init__(self, config):
14 |         self.config = config
15 | 
16 |         if config.MODEL == 1:
17 |             model_name = 'edge'
18 |         elif config.MODEL == 2:
19 |             model_name = 'inpaint'
20 |         elif config.MODEL == 3:
21 |             model_name = 'edge_inpaint'
22 |         elif config.MODEL == 4:
23 |             model_name = 'joint'
24 | 
25 |         self.debug = False
26 |         self.model_name = model_name
27 |         self.edge_model = EdgeModel(config).to(config.DEVICE)
28 |         self.inpaint_model = InpaintingModel(config).to(config.DEVICE)
29 | 
30 |         self.psnr = PSNR(255.0).to(config.DEVICE)
31 |         self.edgeacc = EdgeAccuracy(config.EDGE_THRESHOLD).to(config.DEVICE)
32 | 
33 |         # test mode
34 |         if self.config.MODE == 2:
35 |             print(config.TEST_FLIST)
36 |             print(config.TEST_EDGE_FLIST)
37 |             print(config.TEST_MASK_FLIST)
38 |             self.test_dataset = Dataset(config, config.TEST_FLIST, config.TEST_EDGE_FLIST, config.TEST_MASK_FLIST, augment=False, training=False)
39 |         else:
40 |             self.train_dataset = Dataset(config, config.TRAIN_FLIST, config.TRAIN_EDGE_FLIST, config.TRAIN_MASK_FLIST, augment=True, training=True)
41 |             self.val_dataset = Dataset(config, config.VAL_FLIST, config.VAL_EDGE_FLIST, config.VAL_MASK_FLIST, augment=False, training=True)
42 |             self.sample_iterator = self.val_dataset.create_iterator(config.SAMPLE_SIZE)
43 | 
44 |         self.samples_path = os.path.join(config.PATH, 'samples')
45 |         self.results_path = os.path.join(config.PATH, 'results')
46 | 
47 |         if config.RESULTS is not None:
48 |             self.results_path = os.path.join(config.RESULTS)
49 | 
50 |         if config.DEBUG is not None and config.DEBUG != 0:
51 |             self.debug = True
52 | 
53 |         self.log_file = os.path.join(config.PATH, 'log_' + model_name + '.dat')
54 | 
55 |     def load(self):
56 |         if self.config.MODEL == 1:
57 |             self.edge_model.load()
58 | 
59 |         elif self.config.MODEL == 2:
60 |             self.inpaint_model.load()
61 | 
62 |         else:
63 |             self.edge_model.load()
64 |             self.inpaint_model.load()
65 | 
66 |     def save(self):
67 |         if self.config.MODEL == 1:
68 |             self.edge_model.save()
69 | 
70 |         elif self.config.MODEL == 2 or self.config.MODEL == 3:
71 |             self.inpaint_model.save()
72 | 
73 |         else:
74 |             self.edge_model.save()
75 | 
self.inpaint_model.save() 76 | 77 | def train(self): 78 | train_loader = DataLoader( 79 | dataset=self.train_dataset, 80 | batch_size=self.config.BATCH_SIZE, 81 | num_workers=4, 82 | drop_last=True, 83 | shuffle=True 84 | ) 85 | 86 | epoch = 0 87 | keep_training = True 88 | model = self.config.MODEL 89 | max_iteration = int(float((self.config.MAX_ITERS))) 90 | total = len(self.train_dataset) 91 | 92 | if total == 0: 93 | print('No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.') 94 | return 95 | 96 | while(keep_training): 97 | epoch += 1 98 | print('\n\nTraining epoch: %d' % epoch) 99 | 100 | progbar = Progbar(total, width=20, stateful_metrics=['epoch', 'iter']) 101 | 102 | for items in train_loader: 103 | self.edge_model.train() 104 | self.inpaint_model.train() 105 | 106 | images, images_gray, edges, masks = self.cuda(*items) 107 | 108 | # edge model 109 | if model == 1: 110 | # train 111 | outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks) 112 | 113 | # metrics 114 | precision, recall = self.edgeacc(edges * masks, outputs * masks) 115 | logs.append(('precision', precision.item())) 116 | logs.append(('recall', recall.item())) 117 | 118 | # backward 119 | self.edge_model.backward(gen_loss, dis_loss) 120 | iteration = self.edge_model.iteration 121 | 122 | 123 | # inpaint model 124 | elif model == 2: 125 | # train 126 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks) 127 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 128 | 129 | # metrics 130 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 131 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 132 | logs.append(('psnr', psnr.item())) 133 | logs.append(('mae', mae.item())) 134 | 135 | # backward 136 | self.inpaint_model.backward(gen_loss, dis_loss) 137 | iteration = self.inpaint_model.iteration 138 | 139 | 140 | # inpaint with edge model 141 | elif model == 3: 142 | # train 143 | if True or np.random.binomial(1, 0.5) > 0: 144 | outputs = self.edge_model(images_gray, edges, masks) 145 | outputs = outputs * masks + edges * (1 - masks) 146 | else: 147 | outputs = edges 148 | 149 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, outputs.detach(), masks) 150 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 151 | 152 | # metrics 153 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 154 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 155 | logs.append(('psnr', psnr.item())) 156 | logs.append(('mae', mae.item())) 157 | 158 | # backward 159 | self.inpaint_model.backward(gen_loss, dis_loss) 160 | iteration = self.inpaint_model.iteration 161 | 162 | 163 | # joint model 164 | else: 165 | # train 166 | e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, edges, masks) 167 | e_outputs = e_outputs * masks + edges * (1 - masks) 168 | i_outputs, i_gen_loss, i_dis_loss, i_logs = self.inpaint_model.process(images, e_outputs, masks) 169 | outputs_merged = (i_outputs * masks) + (images * (1 - masks)) 170 | 171 | # metrics 172 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 173 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 174 | precision, recall = self.edgeacc(edges * masks, e_outputs * masks) 175 | e_logs.append(('pre', precision.item())) 176 | e_logs.append(('rec', recall.item())) 
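                    # [annotation] `mae` above is sum-normalized: sum(|x - y|) / sum(x),
                    # a batch-relative L1 error rather than a per-pixel mean absolute
                    # error; the same formula is used in the other train/eval branches.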
177 | i_logs.append(('psnr', psnr.item())) 178 | i_logs.append(('mae', mae.item())) 179 | logs = e_logs + i_logs 180 | 181 | # backward 182 | self.inpaint_model.backward(i_gen_loss, i_dis_loss) 183 | self.edge_model.backward(e_gen_loss, e_dis_loss) 184 | iteration = self.inpaint_model.iteration 185 | 186 | 187 | if iteration >= max_iteration: 188 | keep_training = False 189 | break 190 | 191 | logs = [ 192 | ("epoch", epoch), 193 | ("iter", iteration), 194 | ] + logs 195 | 196 | progbar.add(len(images), values=logs if self.config.VERBOSE else [x for x in logs if not x[0].startswith('l_')]) 197 | 198 | # log model at checkpoints 199 | if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0: 200 | self.log(logs) 201 | 202 | # sample model at checkpoints 203 | if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0: 204 | self.sample() 205 | 206 | # evaluate model at checkpoints 207 | if self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0: 208 | print('\nstart eval...\n') 209 | self.eval() 210 | 211 | # save model at checkpoints 212 | if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0: 213 | self.save() 214 | 215 | print('\nEnd training....') 216 | 217 | def eval(self): 218 | val_loader = DataLoader( 219 | dataset=self.val_dataset, 220 | batch_size=self.config.BATCH_SIZE, 221 | drop_last=True, 222 | shuffle=True 223 | ) 224 | 225 | model = self.config.MODEL 226 | total = len(self.val_dataset) 227 | 228 | self.edge_model.eval() 229 | self.inpaint_model.eval() 230 | 231 | progbar = Progbar(total, width=20, stateful_metrics=['it']) 232 | iteration = 0 233 | 234 | for items in val_loader: 235 | iteration += 1 236 | images, images_gray, edges, masks = self.cuda(*items) 237 | 238 | # edge model 239 | if model == 1: 240 | # eval 241 | outputs, gen_loss, dis_loss, logs = self.edge_model.process(images_gray, edges, masks) 242 | 243 | # metrics 244 | precision, recall = self.edgeacc(edges * masks, outputs * masks) 245 | logs.append(('precision', precision.item())) 246 | logs.append(('recall', recall.item())) 247 | 248 | 249 | # inpaint model 250 | elif model == 2: 251 | # eval 252 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, edges, masks) 253 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 254 | 255 | # metrics 256 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 257 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 258 | logs.append(('psnr', psnr.item())) 259 | logs.append(('mae', mae.item())) 260 | 261 | 262 | # inpaint with edge model 263 | elif model == 3: 264 | # eval 265 | outputs = self.edge_model(images_gray, edges, masks) 266 | outputs = outputs * masks + edges * (1 - masks) 267 | 268 | outputs, gen_loss, dis_loss, logs = self.inpaint_model.process(images, outputs.detach(), masks) 269 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 270 | 271 | # metrics 272 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 273 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 274 | logs.append(('psnr', psnr.item())) 275 | logs.append(('mae', mae.item())) 276 | 277 | 278 | # joint model 279 | else: 280 | # eval 281 | e_outputs, e_gen_loss, e_dis_loss, e_logs = self.edge_model.process(images_gray, edges, masks) 282 | e_outputs = e_outputs * masks + edges * (1 - masks) 283 | i_outputs, i_gen_loss, i_dis_loss, i_logs = 
self.inpaint_model.process(images, e_outputs, masks) 284 | outputs_merged = (i_outputs * masks) + (images * (1 - masks)) 285 | 286 | # metrics 287 | psnr = self.psnr(self.postprocess(images), self.postprocess(outputs_merged)) 288 | mae = (torch.sum(torch.abs(images - outputs_merged)) / torch.sum(images)).float() 289 | precision, recall = self.edgeacc(edges * masks, e_outputs * masks) 290 | e_logs.append(('pre', precision.item())) 291 | e_logs.append(('rec', recall.item())) 292 | i_logs.append(('psnr', psnr.item())) 293 | i_logs.append(('mae', mae.item())) 294 | logs = e_logs + i_logs 295 | 296 | 297 | logs = [("it", iteration), ] + logs 298 | progbar.add(len(images), values=logs) 299 | 300 | def test(self): 301 | self.edge_model.eval() 302 | self.inpaint_model.eval() 303 | 304 | model = self.config.MODEL 305 | create_dir(self.results_path) 306 | 307 | test_loader = DataLoader( 308 | dataset=self.test_dataset, 309 | batch_size=1, 310 | ) 311 | 312 | index = 0 313 | for items in test_loader: 314 | name = self.test_dataset.load_name(index) 315 | images, images_gray, edges, masks = items 316 | tmp = images[0, :, :, :] 317 | tmp = np.transpose(tmp, (1, 2, 0)) 318 | # plt.imshow(tmp) 319 | # plt.show() 320 | # tmp = images_gray[0, :, :, :] 321 | # tmp = np.transpose(tmp, (1, 2, 0)) 322 | # plt.imshow(tmp[:,:,0], cmap=plt.cm.gray) 323 | # plt.show() 324 | # tmp = edges[0, :, :, :] 325 | # tmp = np.transpose(tmp, (1, 2, 0)) 326 | # plt.imshow(tmp[:,:,0], cmap=plt.cm.gray) 327 | # plt.show() 328 | 329 | # plt.imshow(edges) 330 | # plt.show() 331 | 332 | images, images_gray, edges, masks = self.cuda(*items) 333 | 334 | index += 1 335 | 336 | # edge model 337 | if model == 1: 338 | outputs = self.edge_model(images_gray, edges, masks) 339 | outputs_merged = (outputs * masks) + (edges * (1 - masks)) 340 | 341 | # inpaint model 342 | elif model == 2: 343 | outputs = self.inpaint_model(images, edges, masks) 344 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 345 | 346 | # inpaint with edge model / joint model 347 | else: 348 | edges = self.edge_model(images_gray, edges, masks).detach() 349 | outputs = self.inpaint_model(images, edges, masks) 350 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 351 | 352 | output = self.postprocess(outputs_merged)[0] 353 | path = os.path.join(self.results_path, name) 354 | print(index, name) 355 | 356 | imsave(output, path) 357 | 358 | if self.debug: 359 | edges = self.postprocess(1 - edges)[0] 360 | masked = self.postprocess(images * (1 - masks) + masks)[0] 361 | fname, fext = name.split('.') 362 | 363 | imsave(edges, os.path.join(self.results_path, fname + '_edge.' + fext)) 364 | imsave(masked, os.path.join(self.results_path, fname + '_masked.' 
+ fext)) 365 | 366 | print('\nEnd test....') 367 | 368 | def sample(self, it=None): 369 | # do not sample when validation set is empty 370 | if len(self.val_dataset) == 0: 371 | return 372 | 373 | self.edge_model.eval() 374 | self.inpaint_model.eval() 375 | 376 | model = self.config.MODEL 377 | items = next(self.sample_iterator) 378 | images, images_gray, edges, masks = self.cuda(*items) 379 | 380 | 381 | # edge model 382 | if model == 1: 383 | iteration = self.edge_model.iteration 384 | inputs = (images_gray * (1 - masks)) + masks 385 | outputs = self.edge_model(images_gray, edges, masks) 386 | outputs_merged = (outputs * masks) + (edges * (1 - masks)) 387 | 388 | # inpaint model 389 | elif model == 2: 390 | iteration = self.inpaint_model.iteration 391 | inputs = (images * (1 - masks)) + masks 392 | outputs = self.inpaint_model(images, edges, masks) 393 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 394 | 395 | # inpaint with edge model / joint model 396 | else: 397 | iteration = self.inpaint_model.iteration 398 | inputs = (images * (1 - masks)) + masks 399 | outputs = self.edge_model(images_gray, edges, masks).detach() 400 | edges = (outputs * masks + edges * (1 - masks)).detach() 401 | outputs = self.inpaint_model(images, edges, masks) 402 | outputs_merged = (outputs * masks) + (images * (1 - masks)) 403 | 404 | if it is not None: 405 | iteration = it 406 | 407 | image_per_row = 2 408 | if self.config.SAMPLE_SIZE <= 6: 409 | image_per_row = 1 410 | 411 | images = stitch_images( 412 | self.postprocess(images), 413 | self.postprocess(inputs), 414 | self.postprocess(edges), 415 | self.postprocess(outputs), 416 | self.postprocess(outputs_merged), 417 | img_per_row = image_per_row 418 | ) 419 | 420 | 421 | path = os.path.join(self.samples_path, self.model_name) 422 | name = os.path.join(path, str(iteration).zfill(5) + ".png") 423 | create_dir(path) 424 | print('\nsaving sample ' + name) 425 | images.save(name) 426 | 427 | def log(self, logs): 428 | with open(self.log_file, 'a') as f: 429 | f.write('%s\n' % ' '.join([str(item[1]) for item in logs])) 430 | 431 | def cuda(self, *args): 432 | return (item.to(self.config.DEVICE) for item in args) 433 | 434 | def postprocess(self, img): 435 | # [0, 1] => [0, 255] 436 | img = img * 255.0 437 | img = img.permute(0, 2, 3, 1) 438 | return img.int() 439 | -------------------------------------------------------------------------------- /MECNet/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | 5 | 6 | class AdversarialLoss(nn.Module): 7 | r""" 8 | Adversarial loss 9 | https://arxiv.org/abs/1711.10337 10 | """ 11 | 12 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): 13 | r""" 14 | type = nsgan | lsgan | hinge 15 | """ 16 | super(AdversarialLoss, self).__init__() 17 | 18 | self.type = type 19 | self.register_buffer('real_label', torch.tensor(target_real_label)) 20 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 21 | 22 | if type == 'nsgan': 23 | self.criterion = nn.BCELoss() 24 | 25 | elif type == 'lsgan': 26 | self.criterion = nn.MSELoss() 27 | 28 | elif type == 'hinge': 29 | self.criterion = nn.ReLU() 30 | 31 | def __call__(self, outputs, is_real, is_disc=None): 32 | if self.type == 'hinge': 33 | if is_disc: 34 | if is_real: 35 | outputs = -outputs 36 | return self.criterion(1 + outputs).mean() 37 | else: 38 | return (-outputs).mean() 39 | 40 | else: 41 | 
            labels = (self.real_label if is_real else self.fake_label).expand_as(outputs)
42 |             loss = self.criterion(outputs, labels)
43 |             return loss
44 | 
45 | 
46 | class StyleLoss(nn.Module):
47 |     r"""
48 |     Style loss (Gram-matrix matching), VGG-based
49 |     https://arxiv.org/abs/1603.08155
50 |     https://github.com/dxyang/StyleTransfer/blob/master/utils.py
51 |     """
52 | 
53 |     def __init__(self):
54 |         super(StyleLoss, self).__init__()
55 |         self.add_module('vgg', VGG19())
56 |         self.criterion = torch.nn.L1Loss()
57 | 
58 |     def compute_gram(self, x):
59 |         b, ch, h, w = x.size()
60 |         f = x.view(b, ch, w * h)
61 |         f_T = f.transpose(1, 2)
62 |         G = f.bmm(f_T) / (h * w * ch)
63 | 
64 |         return G
65 | 
66 |     def __call__(self, x, y):
67 |         # Compute features
68 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
69 | 
70 |         # Compute loss
71 |         style_loss = 0.0
72 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
73 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
74 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
75 |         style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
76 | 
77 |         return style_loss
78 | 
79 | 
80 | 
81 | class PerceptualLoss(nn.Module):
82 |     r"""
83 |     Perceptual loss, VGG-based
84 |     https://arxiv.org/abs/1603.08155
85 |     https://github.com/dxyang/StyleTransfer/blob/master/utils.py
86 |     """
87 | 
88 |     def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
89 |         super(PerceptualLoss, self).__init__()
90 |         self.add_module('vgg', VGG19())
91 |         self.criterion = torch.nn.L1Loss()
92 |         self.weights = weights
93 | 
94 |     def __call__(self, x, y):
95 |         # Compute features
96 |         x_vgg, y_vgg = self.vgg(x), self.vgg(y)
97 | 
98 |         content_loss = 0.0
99 |         content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
100 |         content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
101 |         content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
102 |         content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
103 |         content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
104 | 
105 | 
106 |         return content_loss
107 | 
108 | 
109 | 
110 | class VGG19(torch.nn.Module):
111 |     def __init__(self):
112 |         super(VGG19, self).__init__()
113 |         features = models.vgg19(pretrained=True).features
114 |         self.relu1_1 = torch.nn.Sequential()
115 |         self.relu1_2 = torch.nn.Sequential()
116 | 
117 |         self.relu2_1 = torch.nn.Sequential()
118 |         self.relu2_2 = torch.nn.Sequential()
119 | 
120 |         self.relu3_1 = torch.nn.Sequential()
121 |         self.relu3_2 = torch.nn.Sequential()
122 |         self.relu3_3 = torch.nn.Sequential()
123 |         self.relu3_4 = torch.nn.Sequential()
124 | 
125 |         self.relu4_1 = torch.nn.Sequential()
126 |         self.relu4_2 = torch.nn.Sequential()
127 |         self.relu4_3 = torch.nn.Sequential()
128 |         self.relu4_4 = torch.nn.Sequential()
129 | 
130 |         self.relu5_1 = torch.nn.Sequential()
131 |         self.relu5_2 = torch.nn.Sequential()
132 |         self.relu5_3 = torch.nn.Sequential()
133 |         self.relu5_4 = torch.nn.Sequential()
134 | 
135 |         for x in range(2):
136 |             self.relu1_1.add_module(str(x), features[x])
137 | 
138 |         for x in range(2, 4):
139 |             self.relu1_2.add_module(str(x), features[x])
140 | 
141 |         for x in range(4, 7):
142 |             self.relu2_1.add_module(str(x), features[x])
143 | 
144 |         for x in range(7, 9):
145 |             self.relu2_2.add_module(str(x), features[x])
146 | 
147 | 
for x in range(9, 12): 148 | self.relu3_1.add_module(str(x), features[x]) 149 | 150 | for x in range(12, 14): 151 | self.relu3_2.add_module(str(x), features[x]) 152 | 153 | for x in range(14, 16): 154 | self.relu3_3.add_module(str(x), features[x]) 155 | 156 | for x in range(16, 18): 157 | self.relu3_4.add_module(str(x), features[x]) 158 | 159 | for x in range(18, 21): 160 | self.relu4_1.add_module(str(x), features[x]) 161 | 162 | for x in range(21, 23): 163 | self.relu4_2.add_module(str(x), features[x]) 164 | 165 | for x in range(23, 25): 166 | self.relu4_3.add_module(str(x), features[x]) 167 | 168 | for x in range(25, 27): 169 | self.relu4_4.add_module(str(x), features[x]) 170 | 171 | for x in range(27, 30): 172 | self.relu5_1.add_module(str(x), features[x]) 173 | 174 | for x in range(30, 32): 175 | self.relu5_2.add_module(str(x), features[x]) 176 | 177 | for x in range(32, 34): 178 | self.relu5_3.add_module(str(x), features[x]) 179 | 180 | for x in range(34, 36): 181 | self.relu5_4.add_module(str(x), features[x]) 182 | 183 | # don't need the gradients, just want the features 184 | for param in self.parameters(): 185 | param.requires_grad = False 186 | 187 | def forward(self, x): 188 | relu1_1 = self.relu1_1(x) 189 | relu1_2 = self.relu1_2(relu1_1) 190 | 191 | relu2_1 = self.relu2_1(relu1_2) 192 | relu2_2 = self.relu2_2(relu2_1) 193 | 194 | relu3_1 = self.relu3_1(relu2_2) 195 | relu3_2 = self.relu3_2(relu3_1) 196 | relu3_3 = self.relu3_3(relu3_2) 197 | relu3_4 = self.relu3_4(relu3_3) 198 | 199 | relu4_1 = self.relu4_1(relu3_4) 200 | relu4_2 = self.relu4_2(relu4_1) 201 | relu4_3 = self.relu4_3(relu4_2) 202 | relu4_4 = self.relu4_4(relu4_3) 203 | 204 | relu5_1 = self.relu5_1(relu4_4) 205 | relu5_2 = self.relu5_2(relu5_1) 206 | relu5_3 = self.relu5_3(relu5_2) 207 | relu5_4 = self.relu5_4(relu5_3) 208 | 209 | out = { 210 | 'relu1_1': relu1_1, 211 | 'relu1_2': relu1_2, 212 | 213 | 'relu2_1': relu2_1, 214 | 'relu2_2': relu2_2, 215 | 216 | 'relu3_1': relu3_1, 217 | 'relu3_2': relu3_2, 218 | 'relu3_3': relu3_3, 219 | 'relu3_4': relu3_4, 220 | 221 | 'relu4_1': relu4_1, 222 | 'relu4_2': relu4_2, 223 | 'relu4_3': relu4_3, 224 | 'relu4_4': relu4_4, 225 | 226 | 'relu5_1': relu5_1, 227 | 'relu5_2': relu5_2, 228 | 'relu5_3': relu5_3, 229 | 'relu5_4': relu5_4, 230 | } 231 | return out 232 | -------------------------------------------------------------------------------- /MECNet/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class EdgeAccuracy(nn.Module): 6 | """ 7 | Measures the accuracy of the edge map 8 | """ 9 | def __init__(self, threshold=0.5): 10 | super(EdgeAccuracy, self).__init__() 11 | self.threshold = threshold 12 | 13 | def __call__(self, inputs, outputs): 14 | labels = (inputs > self.threshold) 15 | outputs = (outputs > self.threshold) 16 | 17 | relevant = torch.sum(labels.float()) 18 | selected = torch.sum(outputs.float()) 19 | 20 | if relevant == 0 and selected == 0: 21 | return torch.tensor(1), torch.tensor(1) 22 | 23 | true_positive = ((outputs == labels) * labels).float() 24 | recall = torch.sum(true_positive) / (relevant + 1e-8) 25 | precision = torch.sum(true_positive) / (selected + 1e-8) 26 | 27 | return precision, recall 28 | 29 | 30 | class PSNR(nn.Module): 31 | def __init__(self, max_val): 32 | super(PSNR, self).__init__() 33 | 34 | base10 = torch.log(torch.tensor(10.0)) 35 | max_val = torch.tensor(max_val).float() 36 | 37 | self.register_buffer('base10', base10) 38 | 
self.register_buffer('max_val', 20 * torch.log(max_val) / base10) 39 | 40 | def __call__(self, a, b): 41 | mse = torch.mean((a.float() - b.float()) ** 2) 42 | 43 | if mse == 0: 44 | return torch.tensor(0) 45 | 46 | return self.max_val - 10 * torch.log(mse) / self.base10 47 | -------------------------------------------------------------------------------- /MECNet/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from .networks2 import InpaintGenerator, EdgeGenerator, Discriminator 6 | from .loss import AdversarialLoss, PerceptualLoss, StyleLoss 7 | 8 | 9 | class BaseModel(nn.Module): 10 | def __init__(self , name, config): 11 | super(BaseModel, self).__init__() 12 | 13 | self.name = name 14 | self.config = config 15 | self.iteration = 0 16 | self.device = 0 17 | self.GPU = [0] 18 | # self.gen_weights_path ="/home/wds/First_Project/LBAM_GRU_version2/checkpoints/psv/EdgeModel_gen.pth" 19 | self.gen_weights_path = "/home/wds/First_Project/edge-connect_psp/checkpoints/psv/EdgeModel_gen_400.pth" 20 | # self.gen_weights_path = "/home/wds/First_Project/edge-connect_psp/checkpoints/psv/EdgeModel_gen_400.pth" 21 | # self.gen_weights_path = '/home/wds/LBAM_version7/EdgeModel_gen.pth' 22 | self.dis_weights_path = os.path.join(config.PATH, name + '_dis.pth') 23 | 24 | def load(self): 25 | print(os.path.exists(self.gen_weights_path)) 26 | if os.path.exists(self.gen_weights_path): 27 | print('Loading %s generator...' % self.name) 28 | 29 | if torch.cuda.is_available(): 30 | data = torch.load(self.gen_weights_path) 31 | else: 32 | data = torch.load(self.gen_weights_path, map_location=lambda storage, loc: storage) 33 | self.generator.load_state_dict(data['generator']) 34 | self.iteration = data['iteration'] 35 | # load discriminator only when training 36 | if self.config.MODE == 1 and os.path.exists(self.dis_weights_path): 37 | print('Loading %s discriminator...' 
% self.name) 38 | 39 | if torch.cuda.is_available(): 40 | data = torch.load(self.dis_weights_path) 41 | else: 42 | data = torch.load(self.dis_weights_path, map_location=lambda storage, loc: storage) 43 | 44 | self.discriminator.load_state_dict(data['discriminator']) 45 | 46 | def save(self): 47 | print('\nsaving %s...\n' % self.name) 48 | torch.save({ 49 | 'iteration': self.iteration, 50 | 'generator': self.generator.state_dict() 51 | }, self.gen_weights_path) 52 | 53 | torch.save({ 54 | 'discriminator': self.discriminator.state_dict() 55 | }, self.dis_weights_path) 56 | 57 | 58 | class EdgeModel(BaseModel): 59 | def __init__(self, config): 60 | super(EdgeModel, self).__init__('EdgeModel', config) 61 | 62 | # generator input: [grayscale(1) + edge(1) + mask(1)] 63 | # discriminator input: (grayscale(1) + edge(1)) 64 | generator = EdgeGenerator(use_spectral_norm=True) 65 | discriminator = Discriminator(in_channels=2, use_sigmoid=config.GAN_LOSS != 'hinge') 66 | self.device = config.DEVICE 67 | self.GPU = config.GPU 68 | # if len(config.GPU) > 1: 69 | # generator = nn.DataParallel(generator, config.GPU) 70 | # discriminator = nn.DataParallel(discriminator, config.GPU) 71 | l1_loss = nn.L1Loss() 72 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 73 | 74 | self.add_module('generator', generator) 75 | self.add_module('discriminator', discriminator) 76 | 77 | self.add_module('l1_loss', l1_loss) 78 | self.add_module('adversarial_loss', adversarial_loss) 79 | 80 | self.gen_optimizer = optim.Adam( 81 | params=generator.parameters(), 82 | lr=float(config.LR), 83 | betas=(config.BETA1, config.BETA2) 84 | ) 85 | 86 | self.dis_optimizer = optim.Adam( 87 | params=discriminator.parameters(), 88 | lr=float(config.LR) * float(config.D2G_LR), 89 | betas=(config.BETA1, config.BETA2) 90 | ) 91 | 92 | def process(self, images, edges, masks): 93 | self.iteration += 1 94 | 95 | 96 | # zero optimizers 97 | self.gen_optimizer.zero_grad() 98 | self.dis_optimizer.zero_grad() 99 | 100 | 101 | # process outputs 102 | outputs = self(images, edges, masks) 103 | gen_loss = 0 104 | dis_loss = 0 105 | 106 | 107 | # discriminator loss 108 | dis_input_real = torch.cat((images, edges), dim=1) 109 | dis_input_fake = torch.cat((images, outputs.detach()), dim=1) 110 | dis_real, dis_real_feat = self.discriminator(dis_input_real) # in: (grayscale(1) + edge(1)) 111 | dis_fake, dis_fake_feat = self.discriminator(dis_input_fake) # in: (grayscale(1) + edge(1)) 112 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 113 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 114 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 115 | 116 | 117 | # generator adversarial loss 118 | gen_input_fake = torch.cat((images, outputs), dim=1) 119 | gen_fake, gen_fake_feat = self.discriminator(gen_input_fake) # in: (grayscale(1) + edge(1)) 120 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) 121 | gen_loss += gen_gan_loss 122 | 123 | 124 | # generator feature matching loss 125 | gen_fm_loss = 0 126 | for i in range(len(dis_real_feat)): 127 | gen_fm_loss += self.l1_loss(gen_fake_feat[i], dis_real_feat[i].detach()) 128 | gen_fm_loss = gen_fm_loss * self.config.FM_LOSS_WEIGHT 129 | gen_loss += gen_fm_loss 130 | 131 | 132 | # create logs 133 | logs = [ 134 | ("l_d1", dis_loss.item()), 135 | ("l_g1", gen_gan_loss.item()), 136 | ("l_fm", gen_fm_loss.item()), 137 | ] 138 | 139 | return outputs, gen_loss, dis_loss, logs 140 | 141 | def forward(self, images, edges, masks): 142 | edges_masked = (edges * (1 - masks)) 
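        # [annotation] mask convention: masks == 1 marks the holes. The line
        # above zeroes the edge map inside the holes; the line below blanks the
        # grayscale image there and fills the holes with 1.0 (white), which is
        # the masked input format this generator is trained on.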
143 | images_masked = (images * (1 - masks)) + masks 144 | inputs = torch.cat((images_masked, edges_masked, masks), dim=1) 145 | outputs = self.generator(inputs) # in: [grayscale(1) + edge(1) + mask(1)] 146 | return outputs 147 | 148 | def backward(self, gen_loss=None, dis_loss=None): 149 | if dis_loss is not None: 150 | dis_loss.backward() 151 | self.dis_optimizer.step() 152 | 153 | if gen_loss is not None: 154 | gen_loss.backward() 155 | self.gen_optimizer.step() 156 | 157 | 158 | class InpaintingModel(BaseModel): 159 | def __init__(self, config): 160 | super(InpaintingModel, self).__init__('InpaintingModel', config) 161 | 162 | # generator input: [rgb(3) + edge(1)] 163 | # discriminator input: [rgb(3)] 164 | generator = InpaintGenerator() 165 | discriminator = Discriminator(in_channels=3, use_sigmoid=config.GAN_LOSS != 'hinge') 166 | if len(config.GPU) > 1: 167 | generator = nn.DataParallel(generator, config.GPU) 168 | discriminator = nn.DataParallel(discriminator , config.GPU) 169 | 170 | l1_loss = nn.L1Loss() 171 | perceptual_loss = PerceptualLoss() 172 | style_loss = StyleLoss() 173 | adversarial_loss = AdversarialLoss(type=config.GAN_LOSS) 174 | 175 | self.add_module('generator', generator) 176 | self.add_module('discriminator', discriminator) 177 | 178 | self.add_module('l1_loss', l1_loss) 179 | self.add_module('perceptual_loss', perceptual_loss) 180 | self.add_module('style_loss', style_loss) 181 | self.add_module('adversarial_loss', adversarial_loss) 182 | 183 | self.gen_optimizer = optim.Adam( 184 | params=generator.parameters(), 185 | lr=float(config.LR), 186 | betas=(config.BETA1, config.BETA2) 187 | ) 188 | 189 | self.dis_optimizer = optim.Adam( 190 | params=discriminator.parameters(), 191 | lr=float(config.LR) * float(config.D2G_LR), 192 | betas=(config.BETA1, config.BETA2) 193 | ) 194 | 195 | def process(self, images, edges, masks): 196 | self.iteration += 1 197 | 198 | # zero optimizers 199 | self.gen_optimizer.zero_grad() 200 | self.dis_optimizer.zero_grad() 201 | 202 | 203 | # process outputs 204 | outputs = self(images, edges, masks) 205 | gen_loss = 0 206 | dis_loss = 0 207 | 208 | 209 | # discriminator loss 210 | dis_input_real = images 211 | dis_input_fake = outputs.detach() 212 | dis_real, _ = self.discriminator(dis_input_real) # in: [rgb(3)] 213 | dis_fake, _ = self.discriminator(dis_input_fake) # in: [rgb(3)] 214 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 215 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 216 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 217 | 218 | 219 | # generator adversarial loss 220 | gen_input_fake = outputs 221 | gen_fake, _ = self.discriminator(gen_input_fake) # in: [rgb(3)] 222 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) * self.config.INPAINT_ADV_LOSS_WEIGHT 223 | gen_loss += gen_gan_loss 224 | 225 | 226 | # generator l1 loss 227 | gen_l1_loss = self.l1_loss(outputs, images) * self.config.L1_LOSS_WEIGHT / torch.mean(masks) 228 | gen_loss += gen_l1_loss 229 | 230 | 231 | # generator perceptual loss 232 | gen_content_loss = self.perceptual_loss(outputs, images) 233 | gen_content_loss = gen_content_loss * self.config.CONTENT_LOSS_WEIGHT 234 | gen_loss += gen_content_loss 235 | 236 | 237 | # generator style loss 238 | gen_style_loss = self.style_loss(outputs * masks, images * masks) 239 | gen_style_loss = gen_style_loss * self.config.STYLE_LOSS_WEIGHT 240 | gen_loss += gen_style_loss 241 | 242 | 243 | # create logs 244 | logs = [ 245 | ("l_d2", dis_loss.item()), 246 | ("l_g2", 
gen_gan_loss.item()), 247 | ("l_l1", gen_l1_loss.item()), 248 | ("l_per", gen_content_loss.item()), 249 | ("l_sty", gen_style_loss.item()), 250 | ] 251 | 252 | return outputs, gen_loss, dis_loss, logs 253 | 254 | def forward(self, images, edges, masks): 255 | images_masked = (images * (1 - masks).float()) + masks 256 | inputs = torch.cat((images_masked, edges), dim=1) 257 | outputs = self.generator(inputs) # in: [rgb(3) + edge(1)] 258 | return outputs 259 | 260 | def backward(self, gen_loss=None, dis_loss=None): 261 | dis_loss.backward() 262 | self.dis_optimizer.step() 263 | 264 | gen_loss.backward() 265 | self.gen_optimizer.step() 266 | -------------------------------------------------------------------------------- /MECNet/networks2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseNetwork(nn.Module): 6 | def __init__(self): 7 | super(BaseNetwork, self).__init__() 8 | 9 | def init_weights(self, init_type='normal', gain=0.02): 10 | ''' 11 | initialize network's weights 12 | init_type: normal | xavier | kaiming | orthogonal 13 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 14 | ''' 15 | 16 | def init_func(m): 17 | classname = m.__class__.__name__ 18 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 19 | if init_type == 'normal': 20 | nn.init.normal_(m.weight.data, 0.0, gain) 21 | elif init_type == 'xavier': 22 | nn.init.xavier_normal_(m.weight.data, gain=gain) 23 | elif init_type == 'kaiming': 24 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 25 | elif init_type == 'orthogonal': 26 | nn.init.orthogonal_(m.weight.data, gain=gain) 27 | 28 | if hasattr(m, 'bias') and m.bias is not None: 29 | nn.init.constant_(m.bias.data, 0.0) 30 | 31 | elif classname.find('BatchNorm2d') != -1: 32 | nn.init.normal_(m.weight.data, 1.0, gain) 33 | nn.init.constant_(m.bias.data, 0.0) 34 | 35 | self.apply(init_func) 36 | 37 | 38 | class InpaintGenerator(BaseNetwork): 39 | def __init__(self, residual_blocks=8, init_weights=True): 40 | super(InpaintGenerator, self).__init__() 41 | 42 | self.encoder = nn.Sequential( 43 | nn.ReflectionPad2d(3), 44 | nn.Conv2d(in_channels=4, out_channels=64, kernel_size=7, padding=0), 45 | nn.InstanceNorm2d(64, track_running_stats=False), 46 | nn.ReLU(True), 47 | 48 | nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), 49 | nn.InstanceNorm2d(128, track_running_stats=False), 50 | nn.ReLU(True), 51 | 52 | nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), 53 | nn.InstanceNorm2d(256, track_running_stats=False), 54 | nn.ReLU(True) 55 | ) 56 | 57 | blocks = [] 58 | for _ in range(residual_blocks): 59 | block = ResnetBlock(256, 2) 60 | blocks.append(block) 61 | 62 | self.middle = nn.Sequential(*blocks) 63 | 64 | self.decoder = nn.Sequential( 65 | nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), 66 | nn.InstanceNorm2d(128, track_running_stats=False), 67 | nn.ReLU(True), 68 | 69 | nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), 70 | nn.InstanceNorm2d(64, track_running_stats=False), 71 | nn.ReLU(True), 72 | 73 | nn.ReflectionPad2d(3), 74 | nn.Conv2d(in_channels=64, out_channels=3, kernel_size=7, padding=0), 75 | ) 76 | 77 | if init_weights: 78 | self.init_weights() 79 | 80 | def forward(self, x): 81 | x = 
self.encoder(x)
82 |         x = self.middle(x)
83 |         x = self.decoder(x)
84 |         x = (torch.tanh(x) + 1) / 2
85 | 
86 |         return x
87 | 
88 | 
89 | class EdgeGenerator(BaseNetwork):
90 |     def __init__(self, residual_blocks=8, use_spectral_norm=True, init_weights=True):
91 |         super(EdgeGenerator, self).__init__()
92 | 
93 |         self.encoder = nn.Sequential(
94 |             nn.ReflectionPad2d(3),
95 |             spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=0), use_spectral_norm),
96 |             nn.InstanceNorm2d(64, track_running_stats=False),
97 |             nn.ReLU(True),
98 | 
99 |             spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
100 |             nn.InstanceNorm2d(128, track_running_stats=False),
101 |             nn.ReLU(True),
102 | 
103 |             spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), use_spectral_norm),
104 |             nn.InstanceNorm2d(256, track_running_stats=False),
105 |             nn.ReLU(True)
106 |         )
107 |         self.pool1 = nn.AdaptiveAvgPool2d(8)
108 |         self.pool2 = nn.AdaptiveAvgPool2d(16)
109 |         self.pool3 = nn.AdaptiveAvgPool2d(32)
110 |         self.upsample1 = nn.Sequential(
111 |             nn.ConvTranspose2d(in_channels=256, out_channels=256, kernel_size=4, stride=2, padding=1),
112 |             nn.InstanceNorm2d(256, track_running_stats=False),
113 |             nn.ReLU(True)
114 |         )
115 |         self.upsample2 = nn.Sequential(
116 |             nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
117 |             nn.InstanceNorm2d(256, track_running_stats=False),
118 |             nn.ReLU(True)
119 |         )
120 |         self.upsample3 = nn.Sequential(
121 |             nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
122 |             nn.InstanceNorm2d(256, track_running_stats=False),
123 |             nn.ReLU(True)
124 |         )
125 | 
126 | 
127 |         # self.conv2
128 |         blocks = []
129 |         for _ in range(residual_blocks):
130 |             block = ResnetBlock(256, 2, use_spectral_norm=use_spectral_norm)
131 |             blocks.append(block)
132 | 
133 |         self.middle = nn.Sequential(*blocks)
134 | 
135 |         self.decoder = nn.Sequential(
136 |             spectral_norm(nn.ConvTranspose2d(in_channels=512, out_channels=128, kernel_size=4, stride=2, padding=1), use_spectral_norm),
137 |             nn.InstanceNorm2d(128, track_running_stats=False),
138 |             nn.ReLU(True),
139 | 
140 |             spectral_norm(nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), use_spectral_norm),
141 |             nn.InstanceNorm2d(64, track_running_stats=False),
142 |             nn.ReLU(True),
143 | 
144 |             nn.ReflectionPad2d(3),
145 |             nn.Conv2d(in_channels=64, out_channels=1, kernel_size=7, padding=0),
146 |         )
147 | 
148 |         if init_weights:
149 |             self.init_weights()
150 | 
151 |     def forward(self, x):
152 |         # print(x.shape)                         # debug shape traces
153 |         # for i in range(len(self.encoder)):
154 |         #     x = self.encoder[i](x)
155 |         #     print(i, 'x.shape:\n\t', x.shape)
156 |         x = self.encoder(x)
157 |         # print(x.shape)
158 |         x_8 = self.pool1(x)
159 |         x_16 = self.pool2(x)
160 |         x_32 = self.pool3(x)
161 |         # print(x_8.shape)
162 |         x = self.middle(x)                       # the residual blocks are shared across all scales
163 |         x_8 = self.middle(x_8)
164 |         x_16 = self.middle(x_16)
165 |         x_32 = self.middle(x_32)
166 |         x_8 = self.upsample1(x_8)
167 |         x_16 = torch.cat((x_16, x_8), dim=1)
168 |         x_16 = self.upsample2(x_16)
169 |         x_32 = torch.cat((x_32, x_16), dim=1)
170 |         x_32 = self.upsample3(x_32)
171 |         x = torch.cat((x, x_32), dim=1)
172 |         x = self.decoder(x)
173 |         x = torch.sigmoid(x)
174 |         return x
175 | 
176 | 
177 | class Discriminator(BaseNetwork):
178 |     def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True, init_weights=True):
179 |         super(Discriminator, self).__init__()
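        # [annotation] the five conv blocks below form a PatchGAN-style
        # discriminator: three stride-2 layers downsample by 8x, then two
        # stride-1 layers map to a 1-channel score map, so each output value
        # rates a local patch of the input rather than the whole image; with
        # use_sigmoid=False (hinge GAN loss) the raw logits are returned.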
180 |         self.use_sigmoid = use_sigmoid
181 | 
182 |         self.conv1 = self.features = nn.Sequential(
183 |             spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
184 |             nn.LeakyReLU(0.2, inplace=True),
185 |         )
186 | 
187 |         self.conv2 = nn.Sequential(
188 |             spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
189 |             nn.LeakyReLU(0.2, inplace=True),
190 |         )
191 | 
192 |         self.conv3 = nn.Sequential(
193 |             spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm),
194 |             nn.LeakyReLU(0.2, inplace=True),
195 |         )
196 | 
197 |         self.conv4 = nn.Sequential(
198 |             spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
199 |             nn.LeakyReLU(0.2, inplace=True),
200 |         )
201 | 
202 |         self.conv5 = nn.Sequential(
203 |             spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm),
204 |         )
205 | 
206 |         if init_weights:
207 |             self.init_weights()
208 | 
209 |     def forward(self, x):
210 |         conv1 = self.conv1(x)
211 |         conv2 = self.conv2(conv1)
212 |         conv3 = self.conv3(conv2)
213 |         conv4 = self.conv4(conv3)
214 |         conv5 = self.conv5(conv4)
215 | 
216 |         outputs = conv5
217 |         if self.use_sigmoid:
218 |             outputs = torch.sigmoid(conv5)
219 | 
220 |         return outputs, [conv1, conv2, conv3, conv4, conv5]
221 | 
222 | 
223 | class ResnetBlock(nn.Module):
224 |     def __init__(self, dim, dilation=1, use_spectral_norm=False):
225 |         super(ResnetBlock, self).__init__()
226 |         self.conv_block = nn.Sequential(
227 |             nn.ReflectionPad2d(dilation),
228 |             spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm),
229 |             nn.InstanceNorm2d(dim, track_running_stats=False),
230 |             nn.ReLU(True),
231 | 
232 |             nn.ReflectionPad2d(1),
233 |             spectral_norm(nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm),
234 |             nn.InstanceNorm2d(dim, track_running_stats=False),
235 |         )
236 | 
237 |     def forward(self, x):
238 |         out = x + self.conv_block(x)
239 | 
240 |         # Remove ReLU at the end of the residual block
241 |         # http://torch.ch/blog/2016/02/04/resnets.html
242 | 
243 |         return out
244 | 
245 | 
246 | def spectral_norm(module, mode=True):
247 |     if mode:
248 |         return nn.utils.spectral_norm(module)
249 | 
250 |     return module
251 | 
--------------------------------------------------------------------------------
/MECNet/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import random
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | from PIL import Image
8 | 
9 | 
10 | def create_dir(dir):
11 |     if not os.path.exists(dir):
12 |         os.makedirs(dir)
13 | 
14 | 
15 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
16 |     mask = np.zeros((height, width))
17 |     mask_x = x if x is not None else random.randint(0, width - mask_width)
18 |     mask_y = y if y is not None else random.randint(0, height - mask_height)
19 |     mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
20 |     return mask
21 | 
22 | 
23 | def stitch_images(inputs, *outputs, img_per_row=2):
24 |     gap = 5
25 |     columns = 
len(outputs) + 1 26 | 27 | width, height = inputs[0][:, :, 0].shape 28 | img = Image.new('RGB', (width * img_per_row * columns + gap * (img_per_row - 1), height * int(len(inputs) / img_per_row))) 29 | images = [inputs, *outputs] 30 | 31 | for ix in range(len(inputs)): 32 | xoffset = int(ix % img_per_row) * width * columns + int(ix % img_per_row) * gap 33 | yoffset = int(ix / img_per_row) * height 34 | 35 | for cat in range(len(images)): 36 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze() 37 | im = Image.fromarray(im) 38 | img.paste(im, (xoffset + cat * width, yoffset)) 39 | 40 | return img 41 | 42 | 43 | def imshow(img, title=''): 44 | fig = plt.gcf() 45 | fig.canvas.set_window_title(title) 46 | plt.axis('off') 47 | plt.imshow(img, interpolation='none') 48 | plt.show() 49 | 50 | 51 | def imsave(img, path): 52 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze()) 53 | im.save(path) 54 | 55 | 56 | class Progbar(object): 57 | """Displays a progress bar. 58 | 59 | Arguments: 60 | target: Total number of steps expected, None if unknown. 61 | width: Progress bar width on screen. 62 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 63 | stateful_metrics: Iterable of string names of metrics that 64 | should *not* be averaged over time. Metrics in this list 65 | will be displayed as-is. All others will be averaged 66 | by the progbar before display. 67 | interval: Minimum visual progress update interval (in seconds). 68 | """ 69 | 70 | def __init__(self, target, width=25, verbose=1, interval=0.05, 71 | stateful_metrics=None): 72 | self.target = target 73 | self.width = width 74 | self.verbose = verbose 75 | self.interval = interval 76 | if stateful_metrics: 77 | self.stateful_metrics = set(stateful_metrics) 78 | else: 79 | self.stateful_metrics = set() 80 | 81 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 82 | sys.stdout.isatty()) or 83 | 'ipykernel' in sys.modules or 84 | 'posix' in sys.modules) 85 | self._total_width = 0 86 | self._seen_so_far = 0 87 | # We use a dict + list to avoid garbage collection 88 | # issues found in OrderedDict 89 | self._values = {} 90 | self._values_order = [] 91 | self._start = time.time() 92 | self._last_update = 0 93 | 94 | def update(self, current, values=None): 95 | """Updates the progress bar. 96 | 97 | Arguments: 98 | current: Index of current step. 99 | values: List of tuples: 100 | `(name, value_for_last_step)`. 101 | If `name` is in `stateful_metrics`, 102 | `value_for_last_step` will be displayed as-is. 103 | Else, an average of the metric over time will be displayed. 
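Example (a minimal usage sketch; the metric name 'loss' is a placeholder):
    bar = Progbar(target=100)
    for step in range(1, 101):
        bar.update(step, values=[('loss', 0.1)])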
104 | """ 105 | values = values or [] 106 | for k, v in values: 107 | if k not in self._values_order: 108 | self._values_order.append(k) 109 | if k not in self.stateful_metrics: 110 | if k not in self._values: 111 | self._values[k] = [v * (current - self._seen_so_far), 112 | current - self._seen_so_far] 113 | else: 114 | self._values[k][0] += v * (current - self._seen_so_far) 115 | self._values[k][1] += (current - self._seen_so_far) 116 | else: 117 | self._values[k] = v 118 | self._seen_so_far = current 119 | 120 | now = time.time() 121 | info = ' - %.0fs' % (now - self._start) 122 | if self.verbose == 1: 123 | if (now - self._last_update < self.interval and 124 | self.target is not None and current < self.target): 125 | return 126 | 127 | prev_total_width = self._total_width 128 | if self._dynamic_display: 129 | sys.stdout.write('\b' * prev_total_width) 130 | sys.stdout.write('\r') 131 | else: 132 | sys.stdout.write('\n') 133 | 134 | if self.target is not None: 135 | numdigits = int(np.floor(np.log10(self.target))) + 1 136 | barstr = '%%%dd/%d [' % (numdigits, self.target) 137 | bar = barstr % current 138 | prog = float(current) / self.target 139 | prog_width = int(self.width * prog) 140 | if prog_width > 0: 141 | bar += ('=' * (prog_width - 1)) 142 | if current < self.target: 143 | bar += '>' 144 | else: 145 | bar += '=' 146 | bar += ('.' * (self.width - prog_width)) 147 | bar += ']' 148 | else: 149 | bar = '%7d/Unknown' % current 150 | 151 | self._total_width = len(bar) 152 | sys.stdout.write(bar) 153 | 154 | if current: 155 | time_per_unit = (now - self._start) / current 156 | else: 157 | time_per_unit = 0 158 | if self.target is not None and current < self.target: 159 | eta = time_per_unit * (self.target - current) 160 | if eta > 3600: 161 | eta_format = '%d:%02d:%02d' % (eta // 3600, 162 | (eta % 3600) // 60, 163 | eta % 60) 164 | elif eta > 60: 165 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 166 | else: 167 | eta_format = '%ds' % eta 168 | 169 | info = ' - ETA: %s' % eta_format 170 | else: 171 | if time_per_unit >= 1: 172 | info += ' %.0fs/step' % time_per_unit 173 | elif time_per_unit >= 1e-3: 174 | info += ' %.0fms/step' % (time_per_unit * 1e3) 175 | else: 176 | info += ' %.0fus/step' % (time_per_unit * 1e6) 177 | 178 | for k in self._values_order: 179 | info += ' - %s:' % k 180 | if isinstance(self._values[k], list): 181 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 182 | if abs(avg) > 1e-3: 183 | info += ' %.4f' % avg 184 | else: 185 | info += ' %.4e' % avg 186 | else: 187 | info += ' %s' % self._values[k] 188 | 189 | self._total_width += len(info) 190 | if prev_total_width > self._total_width: 191 | info += (' ' * (prev_total_width - self._total_width)) 192 | 193 | if self.target is not None and current >= self.target: 194 | info += '\n' 195 | 196 | sys.stdout.write(info) 197 | sys.stdout.flush() 198 | 199 | elif self.verbose == 2: 200 | if self.target is None or current >= self.target: 201 | for k in self._values_order: 202 | info += ' - %s:' % k 203 | avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) 204 | if avg > 1e-3: 205 | info += ' %.4f' % avg 206 | else: 207 | info += ' %.4e' % avg 208 | info += '\n' 209 | 210 | sys.stdout.write(info) 211 | sys.stdout.flush() 212 | 213 | self._last_update = now 214 | 215 | def add(self, n, values=None): 216 | self.update(self._seen_so_far + n, values) 217 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Edge-LBAM 2 | PyTorch implementation of the paper "Image Inpainting with Edge-guided Learnable Bidirectional Attention Maps" 3 | 4 | ## Description 5 | 6 | This paper is an extension of our previous work. In comparison to [LBAM](https://openaccess.thecvf.com/content_ICCV_2019/papers/Xie_Image_Inpainting_With_Learnable_Bidirectional_Attention_Maps_ICCV_2019_paper.pdf), we utilize both the mask of holes 7 | and the predicted edge map for mask-updating, resulting in our Edge-LBAM method. Moreover, we introduce a multi-scale 8 | edge completion network for effective prediction of coherent edges. 9 | 10 | ## Prerequisites 11 | 12 | - Python 3.6 13 | - PyTorch 1.1.0 14 | - CPU or NVIDIA GPU + CUDA + cuDNN 15 | 16 | ## Training 17 | 18 | 19 | To train the Edge-LBAM model (the --pretrain flag is optional, for initializing from a pretrained model): 20 | 21 | ``` 22 | python train.py --batchSize numOf_batch_size --dataRoot your_image_path \ 23 | --maskRoot your_mask_root --modelsSavePath path_to_save_your_model \ 24 | --logPath path_to_save_tensorboard_log --pretrain pretrained_model_path 25 | ``` 26 | 27 | ## Testing 28 | 29 | To test a random batch with random masks: 30 | 31 | ``` 32 | python test_random_batch.py --dataRoot your_image_path \ 33 | --maskRoot your_mask_path --batchSize numOf_batch_size --pretrain pretrained_model_path 34 | ``` 35 | 36 | ## Pretrained Models 37 | 38 | The pretrained models can be found at [Google Drive](https://drive.google.com/drive/folders/1iilIU0U7fOYjYlRB7bZjN5oLNCeLoW-R?usp=sharing). We will later release models with batch normalization removed from Edge-LBAM, which may perform better. You can also train the models yourself. 39 | 40 | ## Results 41 | 42 | #### Inpainting 43 | From left to right are the input and the results of Global&Local, PConv, and DeepFillv2. 44 | 45 | 46 | 47 | From left to right are the results of Edge-Connect, MEDFE, ours, and the ground truth (GT). 48 | 49 | 50 | 51 | ### MECNet 52 | From left to right are the input and the edge completion results of the single-scale and multi-scale networks.
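
### Inference sketch

For orientation, here is a minimal sketch of how the edge completion network and Edge-LBAM are chained at test time, mirroring `test_random_batch.py`. The checkpoint path is a placeholder, and the input tensors (`inputImgs`, `masks`, `img_gray`, `edge`, `edge_mask`, `gt`) are assumed to come from `data/dataloader.GetData`:

```python
import torch
from models.LBAMModel import LBAMModel
from MECNet.models import EdgeModel
from MECNet.config import Config

# 1) complete the edge map inside the holes with MECNet
config = Config("config.yml")
edge_model = EdgeModel(config).to(config.DEVICE)
edge_model.load()
pred_edge = edge_model(img_gray, edge, edge_mask)
# predicted edges in the holes, Canny edges elsewhere
edge_merged = pred_edge * edge_mask + edge * (1 - edge_mask)

# 2) inpaint with Edge-LBAM, guided by the completed edge map
netG = LBAMModel(5, 3)  # 4 channels (masked RGB + mask) + 1 completed edge map
netG.load_state_dict(torch.load("path/to/pretrained.pth"))  # placeholder path
netG.eval()
out = netG(torch.cat((inputImgs, edge_merged), 1), masks, edge_merged)[0]  # first return value is the image
result = gt * masks + out * (1 - masks)  # keep the known pixels, fill only the holes
```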
53 | 54 | 55 | -------------------------------------------------------------------------------- /config.yml: -------------------------------------------------------------------------------- 1 | MODE: 1 # 1: train, 2: test, 3: eval 2 | MODEL: 1 # 1: edge model 3 | MASK: 3 # 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half) 4 | EDGE: 1 # 1: canny, 2: external 5 | NMS: 1 # 0: no non-max-suppression, 1: applies non-max-suppression on the external edges by multiplying by Canny 6 | SEED: 10 # random seed 7 | GPU: [1] # list of gpu ids 8 | DEBUG: 0 # turns on debugging mode 9 | VERBOSE: 0 # turns on verbose mode in the output console 10 | 11 | TRAIN_FLIST: ./datasets/places2_train.flist 12 | VAL_FLIST: ./datasets/places2_val.flist 13 | TEST_FLIST: ./datasets/places2_test.flist 14 | 15 | TRAIN_EDGE_FLIST: ./datasets/places2_edges_train.flist 16 | VAL_EDGE_FLIST: ./datasets/places2_edges_val.flist 17 | TEST_EDGE_FLIST: ./datasets/places2_edges_test.flist 18 | 19 | TRAIN_MASK_FLIST: ./datasets/masks_train.flist 20 | VAL_MASK_FLIST: ./datasets/masks_val.flist 21 | TEST_MASK_FLIST: ./datasets/masks_test.flist 22 | 23 | LR: 0.0001 # learning rate 24 | D2G_LR: 0.1 # discriminator/generator learning rate ratio 25 | BETA1: 0.0 # adam optimizer beta1 26 | BETA2: 0.9 # adam optimizer beta2 27 | BATCH_SIZE: 8 # input batch size for training 28 | INPUT_SIZE: 256 # input image size for training 0 for original size 29 | SIGMA: 2 # standard deviation of the Gaussian filter used in Canny edge detector (0: random, -1: no edge) 30 | MAX_ITERS: 2e6 # maximum number of iterations to train the model 31 | 32 | EDGE_THRESHOLD: 0.5 # edge detection threshold 33 | L1_LOSS_WEIGHT: 1 # l1 loss weight 34 | FM_LOSS_WEIGHT: 10 # feature-matching loss weight 35 | STYLE_LOSS_WEIGHT: 250 # style loss weight 36 | CONTENT_LOSS_WEIGHT: 0.1 # perceptual loss weight 37 | INPAINT_ADV_LOSS_WEIGHT: 0.1 # adversarial loss weight 38 | 39 | GAN_LOSS: nsgan # nsgan | lsgan | hinge 40 | GAN_POOL_SIZE: 0 # fake images pool size 41 | 42 | SAVE_INTERVAL: 1000 # how many iterations to wait before saving model (0: never) 43 | SAMPLE_INTERVAL: 1000 # how many iterations to wait before sampling (0: never) 44 | SAMPLE_SIZE: 12 # number of images to sample 45 | EVAL_INTERVAL: 0 # how many iterations to wait before model evaluation (0: never) 46 | LOG_INTERVAL: 10 # how many iterations to wait before logging training status (0: never) 47 | -------------------------------------------------------------------------------- /data/__pycache__/basicFunction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/basicFunction.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/basicFunction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/basicFunction.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader.cpython-36.pyc 
-------------------------------------------------------------------------------- /data/__pycache__/dataloader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader.cpython-37.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader_canny.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader_canny.cpython-36.pyc -------------------------------------------------------------------------------- /data/__pycache__/dataloader_canny.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/data/__pycache__/dataloader_canny.cpython-37.pyc -------------------------------------------------------------------------------- /data/basicFunction.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torchvision.transforms import Compose, RandomCrop, ToTensor, ToPILImage, Resize, RandomHorizontalFlip 3 | 4 | def CheckImageFile(filename): 5 | return any(filename.endswith(extension) for extension in ['.png', '.PNG', '.jpg', '.JPG', '.jpeg', '.JPEG', '.bmp', '.BMP']) 6 | 7 | def ImageTransform(loadSize, cropSize): 8 | return Compose([ 9 | Resize(size=loadSize, interpolation=Image.BICUBIC), 10 | RandomHorizontalFlip(p=0.5), 11 | RandomCrop(size=cropSize), 12 | ToTensor(), 13 | ]) 14 | 15 | def MaskTransform(cropSize): 16 | return Compose([ 17 | Resize(size=cropSize, interpolation=Image.NEAREST), 18 | ToTensor(), 19 | ]) 20 | 21 | # This is the image transform for paired image and mask, i.e. the damaged image and its 22 | # mask come in pairs and the input image already contains the damaged area (ones or zeros). 23 | # We suggest resizing the input image with "NEAREST" rather than BICUBIC (or another) interpolation: 24 | # it is not guaranteed, but with other resize methods the damaged portion may in some cases leak outside the mask region. 25 | def PairedImageTransform(cropSize): 26 | return Compose([ 27 | Resize(size=cropSize, interpolation=Image.NEAREST), 28 | ToTensor(), 29 | ]) -------------------------------------------------------------------------------- /data/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image 3 | from os import listdir, walk 4 | from os.path import join 5 | from random import randint 6 | from data.basicFunction import CheckImageFile, ImageTransform, MaskTransform 7 | import numpy as np 8 | import torchvision.transforms.functional as F 9 | import random 10 | from skimage.feature import canny 11 | from skimage.color import rgb2gray 12 | from shutil import copyfile 13 | from scipy.misc import imread 14 | import matplotlib.pyplot as plt 15 | from torch.utils.data import Dataset 16 | 17 | 18 | class GetData(Dataset): 19 | def __init__(self, dataRoot, maskRoot, loadSize, cropSize): 20 | super(GetData, self).__init__() 21 | 22 | self.imageFiles = [join(dataRootK, files) for dataRootK, dn, filenames in walk(dataRoot) \ 23 | for files in filenames if CheckImageFile(files)] 24 | self.masks = [join(dataRootK, files) for 
dataRootK, dn, filenames in walk(maskRoot) \ 25 | for files in filenames if CheckImageFile(files)] 26 | self.numOfMasks = len(self.masks) 27 | self.loadSize = loadSize 28 | self.cropSize = cropSize 29 | self.ImgTrans = ImageTransform(loadSize, cropSize) 30 | self.maskTrans = MaskTransform(cropSize) 31 | self.sigma = 1.5 32 | 33 | def __getitem__(self, index): 34 | img = Image.open(self.imageFiles[index]) 35 | randnum = randint(0, self.numOfMasks - 1) 36 | # mask = Image.open(self.imageFiles[index].replace("GT","mask")) 37 | mask = Image.open(self.masks[randnum]) 38 | groundTruth = self.ImgTrans(img.convert('RGB')) 39 | mask = self.maskTrans(mask.convert('RGB')) 40 | # we threshold the mask to force it to binary 0/1 values 41 | # the threshold is tunable; 0.5 works well in practice 42 | threshold = 0.5 43 | ones = mask >= threshold 44 | zeros = mask < threshold 45 | 46 | mask.masked_fill_(ones, 1.0) 47 | mask.masked_fill_(zeros, 0.0) 48 | 49 | # the white values (ones) denote the area to be inpainted, 50 | # and the dark values (zeros) are the pixels to keep. 51 | # Therefore we reverse the mask (mask = 1 - mask), so that the input is groundTruth * mask. 52 | edge_mask = np.transpose(mask, (1, 2, 0)) 53 | mask = 1 - mask 54 | inputImage = groundTruth * mask 55 | edge_mask = edge_mask.numpy() 56 | 57 | edge_mask = rgb2gray(edge_mask) 58 | edge_mask2 = (edge_mask > 0).astype(np.uint8) * 255 # threshold due to interpolation 59 | 60 | tmp = np.transpose(groundTruth, (1, 2, 0)) 61 | tmp = tmp.numpy() 62 | img_gray = rgb2gray(tmp) 63 | 64 | edge = self.load_edge(img_gray, np.array(1 - edge_mask2 / 255).astype(bool)) 65 | img_gray = torch.from_numpy(img_gray.reshape((1, 256, 256))) 66 | edge_mask = torch.from_numpy((edge_mask).reshape((1, 256, 256))) 67 | 68 | edge = torch.from_numpy(edge.reshape((1, 256, 256))).float() 69 | inputImage = torch.cat((inputImage, mask[0].view(1, 256, 256)), 0) 70 | 71 | return inputImage, groundTruth, mask, img_gray, edge, edge_mask.float(), self.imageFiles[index] 72 | 73 | def __len__(self): 74 | return len(self.imageFiles) 75 | 76 | def to_tensor(self, img): 77 | img = Image.fromarray(img) 78 | img_t = F.to_tensor(img).float() 79 | return img_t 80 | 81 | def load_edge(self, img, mask): 82 | sigma = self.sigma 83 | # in test mode the images are masked (contain missing regions); 84 | # the 'mask' parameter prevents canny from detecting edges inside the masked regions 85 | 86 | # canny 87 | 88 | # no edge 89 | if sigma == -1: 90 | return np.zeros(img.shape).astype(float) 91 | 92 | # random sigma 93 | if sigma == 0: 94 | sigma = random.randint(1, 4) 95 | 96 | return canny(img, sigma=sigma, mask=mask).astype(float) 97 | 98 | 99 | 100 | 101 | -------------------------------------------------------------------------------- /data/dataloader_canny.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from PIL import Image 4 | from os import listdir, walk 5 | from os.path import join 6 | from random import randint 7 | from skimage.feature import canny 8 | from skimage.color import rgb2gray 9 | import numpy as np 10 | from data.basicFunction import CheckImageFile, ImageTransform, MaskTransform 11 | import matplotlib.pyplot as plt 12 | class GetData(Dataset): 13 | def __init__(self, dataRoot, maskRoot, loadSize, cropSize): 14 | super(GetData, self).__init__() 15 | 16 | self.imageFiles = [join (dataRootK, files) for dataRootK, dn, filenames in 
walk(dataRoot) \ 17 | for files in filenames if CheckImageFile(files)] 18 | self.masks = [join (dataRootK, files) for dataRootK, dn, filenames in walk(maskRoot) \ 19 | for files in filenames if CheckImageFile(files)] 20 | self.numOfMasks = len(self.masks) 21 | self.loadSize = loadSize 22 | self.cropSize = cropSize 23 | self.ImgTrans = ImageTransform(loadSize, cropSize) 24 | self.maskTrans = MaskTransform(cropSize) 25 | 26 | def __getitem__(self, index): 27 | img = Image.open(self.imageFiles[index]) 28 | mask = Image.open(self.masks[randint(0, self.numOfMasks - 1)]) 29 | 30 | groundTruth = self.ImgTrans(img.convert('RGB')) 31 | mask = self.maskTrans(mask.convert('RGB')) 32 | # we threshold the mask to force it to binary 0/1 values 33 | # the threshold is tunable; 0.5 works well in practice 34 | threshold = 0.5 35 | ones = mask >= threshold 36 | zeros = mask < threshold 37 | 38 | mask.masked_fill_(ones, 1.0) 39 | mask.masked_fill_(zeros, 0.0) 40 | 41 | # the white values (ones) denote the area to be inpainted, 42 | # and the dark values (zeros) are the pixels to keep. 43 | # Therefore we reverse the mask (mask = 1 - mask), so that the input is groundTruth * mask. 44 | mask = 1 - mask 45 | inputImage = groundTruth * mask 46 | tmp = np.transpose(groundTruth, (1, 2, 0)) 47 | tmp = tmp.numpy() 48 | tmp = rgb2gray(tmp) 49 | edge = canny(tmp, sigma=1.5).astype(np.float32) 50 | edge = torch.from_numpy(edge.reshape((1, 256, 256))).float() 51 | inputImage = torch.cat((inputImage, mask[0].view(1, self.cropSize[0], self.cropSize[1])), 0) 52 | 53 | return inputImage, groundTruth, mask, edge 54 | 55 | def __len__(self): 56 | return len(self.imageFiles) -------------------------------------------------------------------------------- /examples/GT28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/GT28-1.png -------------------------------------------------------------------------------- /examples/MEDFE28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/MEDFE28-1.png -------------------------------------------------------------------------------- /examples/ec28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/ec28-1.png -------------------------------------------------------------------------------- /examples/edge_mecnet(s)_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/edge_mecnet(s)_1.png -------------------------------------------------------------------------------- /examples/edge_mecnet_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/edge_mecnet_1.png -------------------------------------------------------------------------------- /examples/gc28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/gc28-1.png 
-------------------------------------------------------------------------------- /examples/gl28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/gl28-1.png -------------------------------------------------------------------------------- /examples/input1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/input1.png -------------------------------------------------------------------------------- /examples/input28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/input28-1.png -------------------------------------------------------------------------------- /examples/ours28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/ours28-1.png -------------------------------------------------------------------------------- /examples/pconv28-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/examples/pconv28-1.png -------------------------------------------------------------------------------- /loss/InpaintingLoss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch import autograd 4 | from tensorboardX import SummaryWriter 5 | from models.discriminator import DiscriminatorDoubleColumn 6 | 7 | # modified from WGAN-GP 8 | def calc_gradient_penalty(netD, real_data, fake_data, masks, cuda, Lambda): 9 | BATCH_SIZE = real_data.size()[0] 10 | DIM = real_data.size()[2] 11 | alpha = torch.rand(BATCH_SIZE, 1) 12 | alpha = alpha.expand(BATCH_SIZE, int(real_data.nelement()/BATCH_SIZE)).contiguous() 13 | alpha = alpha.view(BATCH_SIZE, 3, DIM, DIM) 14 | if cuda: 15 | alpha = alpha.cuda() 16 | 17 | fake_data = fake_data.view(BATCH_SIZE, 3, DIM, DIM) 18 | interpolates = alpha * real_data.detach() + ((1 - alpha) * fake_data.detach()) 19 | 20 | if cuda: 21 | interpolates = interpolates.cuda() 22 | interpolates.requires_grad_(True) 23 | 24 | disc_interpolates = netD(interpolates, masks) 25 | 26 | gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates, 27 | grad_outputs=torch.ones(disc_interpolates.size()).cuda() if cuda else torch.ones(disc_interpolates.size()), 28 | create_graph=True, retain_graph=True, only_inputs=True)[0] 29 | 30 | gradients = gradients.view(gradients.size(0), -1) 31 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * Lambda 32 | return gradient_penalty.sum().mean() 33 | 34 | 35 | def gram_matrix(feat): 36 | # https://github.com/pytorch/examples/blob/master/fast_neural_style/neural_style/utils.py 37 | (b, ch, h, w) = feat.size() 38 | feat = feat.view(b, ch, h * w) 39 | feat_t = feat.transpose(1, 2) 40 | gram = torch.bmm(feat, feat_t) / (ch * h * w) 41 | return gram 42 | 43 | 44 | #tv loss 45 | def total_variation_loss(image): 46 | # shift one pixel and get difference (for both x and y direction) 47 | loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \ 48 | 
torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :])) 49 | return loss 50 | 51 | 52 | 53 | class InpaintingLossWithGAN(nn.Module): 54 | def __init__(self, logPath, extractor, Lamda, lr, betasInit=(0.5, 0.9)): 55 | super(InpaintingLossWithGAN, self).__init__() 56 | self.l1 = nn.L1Loss() 57 | self.extractor = extractor 58 | self.discriminator = DiscriminatorDoubleColumn(3) 59 | self.D_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=betasInit) 60 | self.cudaAvailable = torch.cuda.is_available() 61 | self.numOfGPUs = torch.cuda.device_count() 62 | """ if (self.numOfGPUs > 1): 63 | self.discriminator = self.discriminator.cuda() 64 | self.discriminator = nn.DataParallel(self.discriminator, device_ids=range(self.numOfGPUs)) """ 65 | self.lamda = Lamda 66 | self.writer = SummaryWriter(logPath) 67 | 68 | def forward(self, input, mask, output, gt, count, epoch): 69 | self.discriminator.zero_grad() 70 | D_real = self.discriminator(gt, mask) 71 | D_real = D_real.mean().sum() * -1 72 | D_fake = self.discriminator(output, mask) 73 | D_fake = D_fake.mean().sum() 74 | gp = calc_gradient_penalty(self.discriminator, gt, output, mask, self.cudaAvailable, self.lamda) 75 | D_loss = D_fake - D_real + gp 76 | self.D_optimizer.zero_grad() 77 | D_loss.backward(retain_graph=True) 78 | self.D_optimizer.step() 79 | 80 | self.writer.add_scalar('LossD/Discriminator loss', D_loss.item(), count) 81 | 82 | output_comp = mask * input + (1 - mask) * output 83 | 84 | holeLoss = 6 * self.l1((1 - mask) * output, (1 - mask) * gt) 85 | validAreaLoss = self.l1(mask * output, mask * gt) 86 | 87 | if output.shape[1] == 3: 88 | feat_output_comp = self.extractor(output_comp) 89 | feat_output = self.extractor(output) 90 | feat_gt = self.extractor(gt) 91 | elif output.shape[1] == 1: 92 | feat_output_comp = self.extractor(torch.cat([output_comp]*3, 1)) 93 | feat_output = self.extractor(torch.cat([output]*3, 1)) 94 | feat_gt = self.extractor(torch.cat([gt]*3, 1)) 95 | else: 96 | raise ValueError('only gray (1-channel) and rgb (3-channel) outputs are supported') 97 | 98 | prcLoss = 0.0 99 | for i in range(3): 100 | prcLoss += 0.005 * self.l1(feat_output[i], feat_gt[i]) 101 | prcLoss += 0.005 * self.l1(feat_output_comp[i], feat_gt[i]) 102 | 103 | styleLoss = 0.0 104 | for i in range(3): 105 | styleLoss += 120 * self.l1(gram_matrix(feat_output[i]), 106 | gram_matrix(feat_gt[i])) 107 | styleLoss += 120 * self.l1(gram_matrix(feat_output_comp[i]), 108 | gram_matrix(feat_gt[i])) 109 | 110 | """ if self.numOfGPUs > 1: 111 | holeLoss = holeLoss.sum() / self.numOfGPUs 112 | validAreaLoss = validAreaLoss.sum() / self.numOfGPUs 113 | prcLoss = prcLoss.sum() / self.numOfGPUs 114 | styleLoss = styleLoss.sum() / self.numOfGPUs """ 115 | self.writer.add_scalar('LossG/Hole loss', holeLoss.item(), count) 116 | self.writer.add_scalar('LossG/Valid loss', validAreaLoss.item(), count) 117 | self.writer.add_scalar('LossPrc/Perceptual loss', prcLoss.item(), count) 118 | self.writer.add_scalar('LossStyle/style loss', styleLoss.item(), count) 119 | 120 | GLoss = holeLoss + validAreaLoss + prcLoss + styleLoss + 0.1 * D_fake 121 | self.writer.add_scalar('Generator/Joint loss', GLoss.item(), count) 122 | return GLoss.sum() -------------------------------------------------------------------------------- /loss/__pycache__/InpaintingLoss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/loss/__pycache__/InpaintingLoss.cpython-36.pyc 
-------------------------------------------------------------------------------- /loss/__pycache__/InpaintingLoss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/loss/__pycache__/InpaintingLoss.cpython-37.pyc -------------------------------------------------------------------------------- /models/ActivationFunction.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.nn.parameter import Parameter 4 | from torch import nn 5 | from torchvision import models 6 | 7 | # asymmetric gaussian shaped activation function g_A 8 | class GaussActivation(nn.Module): 9 | def __init__(self, a, mu, sigma1, sigma2): 10 | super(GaussActivation, self).__init__() 11 | 12 | self.a = Parameter(torch.tensor(a, dtype=torch.float32)) 13 | self.mu = Parameter(torch.tensor(mu, dtype=torch.float32)) 14 | self.sigma1 = Parameter(torch.tensor(sigma1, dtype=torch.float32)) 15 | self.sigma2 = Parameter(torch.tensor(sigma2, dtype=torch.float32)) 16 | 17 | 18 | def forward(self, inputFeatures): 19 | 20 | self.a.data = torch.clamp(self.a.data, 1.01, 6.0) 21 | self.mu.data = torch.clamp(self.mu.data, 0.1, 3.0) 22 | self.sigma1.data = torch.clamp(self.sigma1.data, 0.5, 2.0) 23 | self.sigma2.data = torch.clamp(self.sigma2.data, 0.5, 2.0) 24 | 25 | lowerThanMu = inputFeatures < self.mu 26 | largerThanMu = inputFeatures >= self.mu 27 | 28 | leftValuesActiv = self.a * torch.exp(- self.sigma1 * ( (inputFeatures - self.mu) ** 2 ) ) 29 | leftValuesActiv.masked_fill_(largerThanMu, 0.0) 30 | 31 | rightValueActiv = 1 + (self.a - 1) * torch.exp(- self.sigma2 * ( (inputFeatures - self.mu) ** 2 ) ) 32 | rightValueActiv.masked_fill_(lowerThanMu, 0.0) 33 | 34 | output = leftValuesActiv + rightValueActiv 35 | 36 | return output 37 | 38 | # mask updating functions, we recommand using alpha that is larger than 0 and lower than 1.0 39 | class MaskUpdate(nn.Module): 40 | def __init__(self, alpha): 41 | super(MaskUpdate, self).__init__() 42 | 43 | self.updateFunc = nn.ReLU(True) 44 | #self.alpha = Parameter(torch.tensor(alpha, dtype=torch.float32)) 45 | self.alpha = alpha 46 | def forward(self, inputMaskMap): 47 | """ self.alpha.data = torch.clamp(self.alpha.data, 0.6, 0.8) 48 | print(self.alpha) """ 49 | 50 | return torch.pow(self.updateFunc(inputMaskMap), self.alpha) -------------------------------------------------------------------------------- /models/EdgeAttentionLayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.weightInitial import weights_init 3 | from torch import nn 4 | import torch.nn.functional as F 5 | class ForwardEdgeAttention(nn.Module): 6 | def __init__(self, channels,outchannels): 7 | super(ForwardEdgeAttention, self).__init__() 8 | self.maskconv = nn.Sequential( 9 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1,bias=False), 10 | nn.BatchNorm2d(channels), 11 | nn.LeakyReLU(0.2,False) 12 | ) 13 | self.edgegradient = nn.Sequential( 14 | nn.Conv2d(channels,channels,kernel_size=3,stride=1,padding=1,bias=False), 15 | nn.BatchNorm2d(channels), 16 | nn.LeakyReLU(0.2,False) 17 | ) 18 | self.edgemaskcoincide = nn.Sequential( 19 | nn.Conv2d(channels*2,1,kernel_size=1,padding=0,stride=1,bias=False), 20 | nn.BatchNorm2d(1), 21 | nn.LeakyReLU(0.2,False) 22 | ) 23 | self.edgeconv = nn.Sequential( 24 | 
nn.Conv2d(channels,outchannels,kernel_size=3,padding=1,stride = 1,bias=False), 25 | nn.BatchNorm2d(outchannels), 26 | nn.LeakyReLU(0.2,False) 27 | ) 28 | self.maskconv.apply(weights_init()) 29 | self.edgegradient.apply(weights_init()) 30 | self.edgemaskcoincide.apply(weights_init()) 31 | self.edgeconv.apply(weights_init()) 32 | def forward(self, mask, edge): 33 | # print(edge.shape) 34 | output2 = F.interpolate(edge,size=[mask.shape[2],mask.shape[2]]) 35 | maskout = self.maskconv(mask) 36 | edge_gradient=self.edgegradient(output2) 37 | edge_mask_concat = torch.cat((edge_gradient,mask),1) 38 | edgeout = self.edgemaskcoincide(edge_mask_concat) 39 | maskmulti = maskout*edgeout 40 | output1 = maskmulti+mask 41 | output2 = self.edgeconv(output2) 42 | return output1,output2,maskmulti 43 | -------------------------------------------------------------------------------- /models/LBAMModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision import models 4 | from models.forwardAttentionLayer import ForwardAttention 5 | from models.reverseAttentionLayer import ReverseAttention, ReverseMaskConv 6 | from models.weightInitial import weights_init 7 | from models.EdgeAttentionLayer import ForwardEdgeAttention 8 | from models.weightInitial import weights_init 9 | #VGG16 feature extract 10 | class VGG16FeatureExtractor(nn.Module): 11 | def __init__(self): 12 | super(VGG16FeatureExtractor, self).__init__() 13 | vgg16 = models.vgg16(pretrained=True) 14 | # vgg16.load_state_dict(torch.load('../vgg16-397923af.pth')) 15 | self.enc_1 = nn.Sequential(*vgg16.features[:5]) 16 | self.enc_2 = nn.Sequential(*vgg16.features[5:10]) 17 | self.enc_3 = nn.Sequential(*vgg16.features[10:17]) 18 | # fix the encoder 19 | for i in range(3): 20 | for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, image): 24 | results = [image] 25 | for i in range(3): 26 | func = getattr(self, 'enc_{:d}'.format(i + 1)) 27 | results.append(func(results[-1])) 28 | return results[1:] 29 | 30 | class LBAMModel(nn.Module): 31 | def __init__(self, inputChannels, outputChannels): 32 | super(LBAMModel, self).__init__() 33 | self.maskconv1 = nn.Sequential( 34 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 35 | nn.BatchNorm2d(3), 36 | nn.LeakyReLU(0.2,False) 37 | ) 38 | self.maskconv2 = nn.Sequential( 39 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 40 | nn.BatchNorm2d(3), 41 | nn.LeakyReLU(0.2,False) 42 | ) 43 | self.edgeconv1 = nn.Sequential( 44 | nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False), 45 | nn.BatchNorm2d(3), 46 | nn.LeakyReLU(0.2,False) 47 | ) 48 | self.edgeconv2 = nn.Sequential( 49 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 50 | nn.BatchNorm2d(3), 51 | nn.LeakyReLU(0.2,False) 52 | ) 53 | self.maskconv3 = nn.Sequential( 54 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 55 | nn.BatchNorm2d(3), 56 | nn.LeakyReLU(0.2,False) 57 | ) 58 | self.maskconv4 = nn.Sequential( 59 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 60 | nn.BatchNorm2d(3), 61 | nn.LeakyReLU(0.2,False) 62 | ) 63 | self.edgeconv3 = nn.Sequential( 64 | nn.Conv2d(1, 3, kernel_size=3, stride=1, padding=1, bias=False), 65 | nn.BatchNorm2d(3), 66 | nn.LeakyReLU(0.2,False) 67 | ) 68 | self.edgeconv4 = nn.Sequential( 69 | nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False), 70 | 
nn.BatchNorm2d(3), 71 | nn.LeakyReLU(0.2,False) 72 | ) 73 | self.maskconv1.apply(weights_init()) 74 | self.maskconv2.apply(weights_init()) 75 | self.maskconv3.apply(weights_init()) 76 | self.maskconv4.apply(weights_init()) 77 | self.edgeconv1.apply(weights_init()) 78 | self.edgeconv2.apply(weights_init()) 79 | self.edgeconv3.apply(weights_init()) 80 | 81 | # default kernel is of size 4X4, stride 2, padding 1, 82 | # and the use of biases are set false in default ReverseAttention class. 83 | self.ec1 = ForwardAttention(5, 64, bn=False) 84 | self.ec2 = ForwardAttention(64, 128) 85 | self.ec3 = ForwardAttention(128, 256) 86 | self.ec4 = ForwardAttention(256, 512) 87 | self.edge1 = ForwardEdgeAttention(3,64) 88 | self.edge2 = ForwardEdgeAttention(64,128) 89 | self.edge3 = ForwardEdgeAttention(128,256) 90 | self.edge4 = ForwardEdgeAttention(256,512) 91 | for i in range(5, 8): 92 | name = 'ec{:d}'.format(i) 93 | setattr(self, name, ForwardAttention(512, 512)) 94 | name2 = 'edge{:d}'.format(i) 95 | setattr(self,name2,ForwardEdgeAttention(512,512)) 96 | 97 | # reverse mask conv 98 | self.reverseConv1 = ReverseMaskConv(3, 64) 99 | self.reverseConv2 = ReverseMaskConv(64, 128) 100 | self.reverseConv3 = ReverseMaskConv(128, 256) 101 | self.reverseConv4 = ReverseMaskConv(256, 512) 102 | self.reverseConv5 = ReverseMaskConv(512, 512) 103 | self.reverseConv6 = ReverseMaskConv(512, 512) 104 | self.reverseedge1 = ForwardEdgeAttention(3,64) 105 | self.reverseedge2 = ForwardEdgeAttention(64,128) 106 | self.reverseedge3 = ForwardEdgeAttention(128,256) 107 | self.reverseedge4 = ForwardEdgeAttention(256,512) 108 | self.reverseedge5 = ForwardEdgeAttention(512, 512) 109 | self.reverseedge6 = ForwardEdgeAttention(512, 512) 110 | self.dc1 = ReverseAttention(512, 512, bnChannels=1024) 111 | self.dc2 = ReverseAttention(512 * 2, 512, bnChannels=1024) 112 | self.dc3 = ReverseAttention(512 * 2, 512, bnChannels=1024) 113 | self.dc4 = ReverseAttention(512 * 2, 256, bnChannels=512) 114 | self.dc5 = ReverseAttention(256 * 2, 128, bnChannels=256) 115 | self.dc6 = ReverseAttention(128 * 2, 64, bnChannels=128) 116 | self.dc7 = nn.ConvTranspose2d(64 * 2, outputChannels, kernel_size=4, stride=2, padding=1, bias=False) 117 | 118 | self.tanh = nn.Tanh() 119 | 120 | def forward(self, inputImgs, masks,edge): 121 | mask1 = self.maskconv1(masks) 122 | mask1 =self.maskconv2(mask1) 123 | edge1 = self.edgeconv1(edge) 124 | edge1 = self.edgeconv2(edge1) 125 | maskoutput1,edgeoutput,feature1 = self.edge1(mask1,edge1) 126 | ef, mu1, skipConnect1, forwardMap1 = self.ec1(inputImgs, maskoutput1) 127 | maskoutput,edgeoutput,feature2 = self.edge2(mu1,edgeoutput) 128 | ef, mu2, skipConnect2, forwardMap2 = self.ec2(ef, maskoutput) 129 | maskoutput3,edgeoutput,feature3 = self.edge3(mu2,edgeoutput) 130 | ef, mu3, skipConnect3, forwardMap3 = self.ec3(ef, maskoutput3) 131 | maskoutput,edgeoutput,_ = self.edge4(mu3,edgeoutput) 132 | ef, mu, skipConnect4, forwardMap4 = self.ec4(ef, maskoutput) 133 | maskoutput,edgeoutput,_ = self.edge5(mu,edgeoutput) 134 | ef, mu, skipConnect5, forwardMap5 = self.ec5(ef, maskoutput) 135 | maskoutput,edgeoutput,_ = self.edge6(mu,edgeoutput) 136 | ef, mu, skipConnect6, forwardMap6 = self.ec6(ef, maskoutput) 137 | maskoutput, edgeoutput,_ = self.edge7(mu, edgeoutput) 138 | ef, _, _, _ = self.ec7(ef, maskoutput) 139 | 140 | mask2 = self.maskconv3(1-masks) 141 | mask2 = self.maskconv4(mask2) 142 | edge2 = self.edgeconv3(edge) 143 | edge2 = self.edgeconv4(edge2) 144 | maskoutput1,edgeoutput,feature1 = 
self.reverseedge1(mask2,edge2) 145 | reverseMap1, revMu = self.reverseConv1(maskoutput1) 146 | maskoutput2,edgeoutput,feature2 = self.reverseedge2(revMu,edgeoutput) 147 | reverseMap2, revMu = self.reverseConv2(maskoutput2) 148 | maskoutput3, edgeoutput,feature3 = self.reverseedge3(revMu, edgeoutput) 149 | reverseMap3, revMu = self.reverseConv3(maskoutput3) 150 | maskoutput, edgeoutput,_ = self.reverseedge4(revMu, edgeoutput) 151 | reverseMap4, revMu = self.reverseConv4(maskoutput) 152 | maskoutput, edgeoutput,_ = self.reverseedge5(revMu, edgeoutput) 153 | reverseMap5, revMu = self.reverseConv5(maskoutput) 154 | maskoutput, edgeoutput,_ = self.reverseedge6(revMu, edgeoutput) 155 | reverseMap6, _ = self.reverseConv6(maskoutput) 156 | 157 | concatMap6 = torch.cat((forwardMap6, reverseMap6), 1) 158 | dcFeatures1 = self.dc1(skipConnect6, ef, concatMap6) 159 | 160 | concatMap5 = torch.cat((forwardMap5, reverseMap5), 1) 161 | dcFeatures2 = self.dc2(skipConnect5, dcFeatures1, concatMap5) 162 | 163 | concatMap4 = torch.cat((forwardMap4, reverseMap4), 1) 164 | dcFeatures3 = self.dc3(skipConnect4, dcFeatures2, concatMap4) 165 | 166 | concatMap3 = torch.cat((forwardMap3, reverseMap3), 1) 167 | dcFeatures4 = self.dc4(skipConnect3, dcFeatures3, concatMap3) 168 | 169 | concatMap2 = torch.cat((forwardMap2, reverseMap2), 1) 170 | dcFeatures5 = self.dc5(skipConnect2, dcFeatures4, concatMap2) 171 | 172 | concatMap1 = torch.cat((forwardMap1, reverseMap1), 1) 173 | dcFeatures6 = self.dc6(skipConnect1, dcFeatures5, concatMap1) 174 | 175 | dcFeatures7 = self.dc7(dcFeatures6) 176 | 177 | output = torch.abs(self.tanh(dcFeatures7)) 178 | 179 | return output,forwardMap1,forwardMap2,forwardMap3, reverseMap1,reverseMap2,reverseMap3 -------------------------------------------------------------------------------- /models/__pycache__/ActivationFunction.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/ActivationFunction.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/ActivationFunction.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/ActivationFunction.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/EdgeAttentionLayer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/EdgeAttentionLayer.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/EdgeAttentionLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/EdgeAttentionLayer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/LBAMModel.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/LBAMModel.cpython-36.pyc 
-------------------------------------------------------------------------------- /models/__pycache__/LBAMModel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/LBAMModel.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/discriminator.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/discriminator.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/discriminator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/discriminator.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/forwardAttentionLayer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/forwardAttentionLayer.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/forwardAttentionLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/forwardAttentionLayer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/reverseAttentionLayer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/reverseAttentionLayer.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/reverseAttentionLayer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/reverseAttentionLayer.cpython-37.pyc -------------------------------------------------------------------------------- /models/__pycache__/weightInitial.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/weightInitial.cpython-36.pyc -------------------------------------------------------------------------------- /models/__pycache__/weightInitial.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/models/__pycache__/weightInitial.cpython-37.pyc -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | ##discriminator 5 | # two column discriminator 6 | class 
DiscriminatorDoubleColumn(nn.Module): 7 | def __init__(self, inputChannels): 8 | super(DiscriminatorDoubleColumn, self).__init__() 9 | 10 | self.globalConv = nn.Sequential( 11 | nn.Conv2d(inputChannels, 64, kernel_size=4, stride=2, padding=1), 12 | nn.LeakyReLU(0.2, inplace=True), 13 | 14 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), 15 | nn.BatchNorm2d(128), 16 | nn.LeakyReLU(0.2, inplace=True), 17 | 18 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), 19 | nn.BatchNorm2d(256), 20 | nn.LeakyReLU(0.2, inplace=True), 21 | 22 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), 23 | nn.BatchNorm2d(512), 24 | nn.LeakyReLU(0.2 , inplace=True), 25 | 26 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), 27 | nn.BatchNorm2d(512), 28 | nn.LeakyReLU(0.2, inplace=True), 29 | 30 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), 31 | nn.BatchNorm2d(512), 32 | nn.LeakyReLU(0.2, inplace=True), 33 | 34 | ) 35 | 36 | self.localConv = nn.Sequential( 37 | nn.Conv2d(inputChannels, 64, kernel_size=4, stride=2, padding=1), 38 | nn.LeakyReLU(0.2, inplace=True), 39 | 40 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), 41 | nn.BatchNorm2d(128), 42 | nn.LeakyReLU(0.2, inplace=True), 43 | 44 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), 45 | nn.BatchNorm2d(256), 46 | nn.LeakyReLU(0.2, inplace=True), 47 | 48 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1), 49 | nn.BatchNorm2d(512), 50 | nn.LeakyReLU(0.2 , inplace=True), 51 | 52 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), 53 | nn.BatchNorm2d(512), 54 | nn.LeakyReLU(0.2, inplace=True), 55 | 56 | nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1), 57 | nn.BatchNorm2d(512), 58 | nn.LeakyReLU(0.2, inplace=True), 59 | ) 60 | 61 | self.fusionLayer = nn.Sequential( 62 | nn.Conv2d(1024, 1, kernel_size=4), 63 | nn.Sigmoid() 64 | ) 65 | 66 | def forward(self, batches, masks): 67 | globalFt = self.globalConv(batches * masks) 68 | localFt = self.localConv(batches * (1 - masks)) 69 | 70 | concatFt = torch.cat((globalFt, localFt), 1) 71 | 72 | return self.fusionLayer(concatFt).view(batches.size()[0], -1) -------------------------------------------------------------------------------- /models/forwardAttentionLayer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from models.ActivationFunction import GaussActivation, MaskUpdate 5 | from models.weightInitial import weights_init 6 | 7 | # learnable forward attention conv layer 8 | class ForwardAttentionLayer(nn.Module): 9 | def __init__(self, inputChannels, outputChannels, kernelSize, stride, 10 | padding, dilation=1, groups=1, bias=False): 11 | super(ForwardAttentionLayer, self).__init__() 12 | 13 | self.conv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, dilation, \ 14 | groups, bias) 15 | 16 | if inputChannels == 5: 17 | self.maskConv = nn.Conv2d(3, outputChannels, kernelSize, stride, padding, dilation, \ 18 | groups, bias) 19 | else: 20 | self.maskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \ 21 | dilation, groups, bias) 22 | 23 | self.conv.apply(weights_init()) 24 | self.maskConv.apply(weights_init()) 25 | 26 | self.activationFuncG_A = GaussActivation(1.1, 2.0, 1.0, 1.0) 27 | self.updateMask = MaskUpdate(0.8) 28 | 29 | def forward(self, inputFeatures, inputMasks): 30 | convFeatures = self.conv(inputFeatures) 31 | maskFeatures = self.maskConv(inputMasks) 32 | #convFeatures_skip = 
convFeatures.clone() 33 | 34 | maskActiv = self.activationFuncG_A(maskFeatures) 35 | convOut = convFeatures * maskActiv 36 | 37 | maskUpdate = self.updateMask(maskFeatures) 38 | 39 | return convOut, maskUpdate, convFeatures, maskActiv 40 | 41 | # forward attention gather feature activation and batchnorm 42 | class ForwardAttention(nn.Module): 43 | def __init__(self, inputChannels, outputChannels, bn=True, sample='down-4', \ 44 | activ='leaky', convBias=False): 45 | super(ForwardAttention, self).__init__() 46 | 47 | if sample == 'down-4': 48 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 4, 2, 1, bias=convBias) 49 | elif sample == 'down-5': 50 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 5, 2, 2, bias=convBias) 51 | elif sample == 'down-7': 52 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 7, 2, 3, bias=convBias) 53 | elif sample == 'down-3': 54 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 2, 1, bias=convBias) 55 | else: 56 | self.conv = ForwardAttentionLayer(inputChannels, outputChannels, 3, 1, 1, bias=convBias) 57 | 58 | if bn: 59 | self.bn = nn.BatchNorm2d(outputChannels) 60 | 61 | if activ == 'leaky': 62 | self.activ = nn.LeakyReLU(0.2, False) 63 | elif activ == 'relu': 64 | self.activ = nn.ReLU() 65 | elif activ == 'sigmoid': 66 | self.activ = nn.Sigmoid() 67 | elif activ == 'tanh': 68 | self.activ = nn.Tanh() 69 | elif activ == 'prelu': 70 | self.activ = nn.PReLU() 71 | else: 72 | pass 73 | 74 | def forward(self, inputFeatures, inputMasks): 75 | features, maskUpdated, convPreF, maskActiv = self.conv(inputFeatures, inputMasks) 76 | 77 | if hasattr(self, 'bn'): 78 | features = self.bn(features) 79 | if hasattr(self, 'activ'): 80 | features = self.activ(features) 81 | 82 | return features, maskUpdated, convPreF, maskActiv -------------------------------------------------------------------------------- /models/reverseAttentionLayer.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from models.ActivationFunction import GaussActivation, MaskUpdate 5 | from models.weightInitial import weights_init 6 | 7 | 8 | # learnable reverse attention conv 9 | class ReverseMaskConv(nn.Module): 10 | def __init__(self, inputChannels, outputChannels, kernelSize=4, stride=2, 11 | padding=1, dilation=1, groups=1, convBias=False): 12 | super(ReverseMaskConv, self).__init__() 13 | 14 | self.reverseMaskConv = nn.Conv2d(inputChannels, outputChannels, kernelSize, stride, padding, \ 15 | dilation, groups, bias=convBias) 16 | 17 | self.reverseMaskConv.apply(weights_init()) 18 | 19 | self.activationFuncG_A = GaussActivation(1.1, 1.0, 0.5, 0.5) 20 | self.updateMask = MaskUpdate(0.8) 21 | 22 | def forward(self, inputMasks): 23 | maskFeatures = self.reverseMaskConv(inputMasks) 24 | 25 | maskActiv = self.activationFuncG_A(maskFeatures) 26 | 27 | maskUpdate = self.updateMask(maskFeatures) 28 | 29 | return maskActiv, maskUpdate 30 | 31 | # learnable reverse attention layer, including features activation and batchnorm 32 | class ReverseAttention(nn.Module): 33 | def __init__(self, inputChannels, outputChannels, bn=True, activ='leaky', \ 34 | kernelSize=4, stride=2, padding=1, outPadding=0,dilation=1, groups=1,convBias=False, bnChannels=512): 35 | super(ReverseAttention, self).__init__() 36 | 37 | self.conv = nn.ConvTranspose2d(inputChannels, outputChannels, kernel_size=kernelSize, \ 38 | stride=stride, padding=padding, 
output_padding=outPadding, dilation=dilation, groups=groups, bias=convBias) 39 | 40 | self.conv.apply(weights_init()) 41 | 42 | if bn: 43 | self.bn = nn.BatchNorm2d(bnChannels) 44 | 45 | if activ == 'leaky': 46 | self.activ = nn.LeakyReLU(0.2, False) 47 | elif activ == 'relu': 48 | self.activ = nn.ReLU() 49 | elif activ == 'sigmoid': 50 | self.activ = nn.Sigmoid() 51 | elif activ == 'tanh': 52 | self.activ = nn.Tanh() 53 | elif activ == 'prelu': 54 | self.activ = nn.PReLU() 55 | else: 56 | pass 57 | 58 | def forward(self, ecFeaturesSkip, dcFeatures, maskFeaturesForAttention): 59 | nextDcFeatures = self.conv(dcFeatures) 60 | 61 | # note that the encoder features come first: the forward attention map must be placed ahead 62 | # of the reverse attention map when concatenating; this is done in the LBAM model's forward function 63 | concatFeatures = torch.cat((ecFeaturesSkip, nextDcFeatures), 1) 64 | 65 | outputFeatures = concatFeatures * maskFeaturesForAttention 66 | 67 | if hasattr(self, 'bn'): 68 | outputFeatures = self.bn(outputFeatures) 69 | if hasattr(self, 'activ'): 70 | outputFeatures = self.activ(outputFeatures) 71 | 72 | return outputFeatures -------------------------------------------------------------------------------- /models/weightInitial.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | # weight initialization strategies 6 | def weights_init(init_type='gaussian'): 7 | def init_fun(m): 8 | classname = m.__class__.__name__ 9 | 10 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 11 | if init_type == 'gaussian': 12 | nn.init.normal_(m.weight, 0.0, 0.02) 13 | elif init_type == 'xavier': 14 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) 15 | elif init_type == 'kaiming': 16 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in') 17 | elif init_type == 'orthogonal': 18 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) 19 | elif init_type == 'default': 20 | pass 21 | else: 22 | assert 0, 'Unsupported initialization: {}'.format(init_type) 23 | if hasattr(m, 'bias') and m.bias is not None: 24 | nn.init.constant_(m.bias, 0.0) 25 | 26 | return init_fun -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import lpips 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 9 | return gauss / gauss.sum() 10 | 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | 19 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 20 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 21 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 22 | 23 | mu1_sq = mu1.pow(2) 24 | mu2_sq = mu2.pow(2) 25 | mu1_mu2 = mu1 * mu2 26 | 27 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 28 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 
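# the windowed E[x^2] - E[x]^2 terms above are the local variances; sigma12 below is the local
# covariance E[x*y] - E[x]*E[y], all estimated under the same Gaussian window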
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | import lpips
6 | 
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
9 |     return gauss / gauss.sum()
10 | 
11 | 
12 | def create_window(window_size, channel):
13 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
16 |     return window
17 | 
18 | 
19 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
20 |     # local means via a Gaussian-weighted convolution, one group per channel
21 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
22 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
23 | 
24 |     mu1_sq = mu1.pow(2)
25 |     mu2_sq = mu2.pow(2)
26 |     mu1_mu2 = mu1 * mu2
27 | 
28 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
29 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
30 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
31 | 
32 |     # stabilizing constants, calibrated for images with dynamic range 1
33 |     C1 = 0.01 ** 2
34 |     C2 = 0.03 ** 2
35 | 
36 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
37 | 
38 |     if size_average:
39 |         return ssim_map.mean()
40 |     else:
41 |         return ssim_map.mean(1).mean(1).mean(1)
42 | 
43 | 
44 | class SSIM(torch.nn.Module):
45 |     def __init__(self, window_size=11, size_average=True):
46 |         super(SSIM, self).__init__()
47 |         self.window_size = window_size
48 |         self.size_average = size_average
49 |         self.channel = 1
50 |         self.window = create_window(window_size, self.channel)
51 | 
52 |     def forward(self, img1, img2):
53 |         (_, channel, _, _) = img1.size()
54 | 
55 |         # reuse the cached window when the channel count and dtype still match
56 |         if channel == self.channel and self.window.data.type() == img1.data.type():
57 |             window = self.window
58 |         else:
59 |             window = create_window(self.window_size, channel)
60 | 
61 |             if img1.is_cuda:
62 |                 window = window.cuda(img1.get_device())
63 |             window = window.type_as(img1)
64 | 
65 |             self.window = window
66 |             self.channel = channel
67 | 
68 |         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
69 | 
70 | 
71 | def ssim(img1, img2, window_size=11, size_average=True):
72 |     (_, channel, _, _) = img1.size()
73 |     window = create_window(window_size, channel)
74 | 
75 |     if img1.is_cuda:
76 |         window = window.cuda(img1.get_device())
77 |     window = window.type_as(img1)
78 | 
79 |     return _ssim(img1, img2, window, window_size, channel, size_average)
80 | 
81 | def caculatelpips(img1, img2):
82 |     # note: LPIPS expects inputs scaled to [-1, 1]; constructing the network
83 |     # on every call is slow, so cache loss_fn_alex when calling this in a loop
84 |     loss_fn_alex = lpips.LPIPS(net='alex')
85 |     d = loss_fn_alex(img1, img2)
86 |     return d
--------------------------------------------------------------------------------
/pytorch_ssim/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/pytorch_ssim/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_ssim/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds1998/Edge-LBAM/ca85ddc3c245cbf6b71c599e56a19ad215689f51/pytorch_ssim/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
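The module above exposes a functional ssim, a stateful SSIM that caches its Gaussian window, and the LPIPS helper. A minimal usage sketch, assuming the lpips package is installed and using random tensors in place of real images:

import torch
import pytorch_ssim

gt = torch.rand(1, 3, 256, 256)    # values in [0, 1], shape (N, C, H, W)
pred = torch.rand(1, 3, 256, 256)

# functional form: rebuilds the Gaussian window on every call
print('ssim:', pytorch_ssim.ssim(gt, pred).item())

# module form: caches the window across calls with the same channel count
ssim_module = pytorch_ssim.SSIM(window_size=11)
print('ssim:', ssim_module(gt, pred).item())

# LPIPS expects inputs scaled to [-1, 1]
print('lpips:', pytorch_ssim.caculatelpips(gt * 2 - 1, pred * 2 - 1).item())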
/test_random_batch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import torch
5 | import torch.nn as nn
6 | import torch.backends.cudnn as cudnn
7 | from torchvision.utils import save_image
8 | from torch.utils.data import DataLoader
9 | from data.dataloader import GetData
10 | from models.LBAMModel import LBAMModel
11 | import pytorch_ssim
12 | import random
13 | import numpy as np
14 | from MECNet.models import EdgeModel
15 | from MECNet.config import Config
16 | 
17 | # fix all random seeds so the evaluated batches are reproducible
18 | torch.manual_seed(0)
19 | torch.cuda.manual_seed_all(0)
20 | random.seed(0)
21 | np.random.seed(0)
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--numOfWorkers', type=int, default=4,
24 |                     help='workers for dataloader')
25 | parser.add_argument('--local_rank', type=int, default=0)
26 | parser.add_argument('--pretrained', type=str, default='', help='pretrained models')
27 | parser.add_argument('--batchSize', type=int, default=16)
28 | parser.add_argument('--loadSize', type=int, default=350,
29 |                     help='image loading size')
30 | parser.add_argument('--cropSize', type=int, default=256,
31 |                     help='image training size')
32 | parser.add_argument('--dataRoot', type=str,
33 |                     default='')
34 | parser.add_argument('--maskRoot', type=str,
35 |                     default='')
36 | parser.add_argument('--savePath', type=str, default='./results')
37 | args = parser.parse_args()
38 | 
39 | cuda = torch.cuda.is_available()
40 | if cuda:
41 |     print('Cuda is available!')
42 |     cudnn.benchmark = True
43 | os.makedirs(os.path.join(args.savePath, "GT"), exist_ok=True)
44 | os.makedirs(os.path.join(args.savePath, "damaged"), exist_ok=True)
45 | os.makedirs(os.path.join(args.savePath, "ours"), exist_ok=True)
46 | os.makedirs(os.path.join(args.savePath, "input"), exist_ok=True)
47 | os.makedirs(os.path.join(args.savePath, "masks"), exist_ok=True)
48 | os.makedirs(os.path.join(args.savePath, "edge"), exist_ok=True)
49 | 
50 | batchSize = args.batchSize
51 | loadSize = (args.loadSize, args.loadSize)
52 | cropSize = (args.cropSize, args.cropSize)
53 | dataRoot = args.dataRoot
54 | maskRoot = args.maskRoot
55 | savePath = args.savePath
56 | 
57 | if not os.path.exists(savePath):
58 |     os.makedirs(savePath)
59 | 
60 | config = Config("config.yml")
61 | edge_model = EdgeModel(config).to(config.DEVICE)
62 | edge_model.load()
63 | edge_model.cuda()
64 | # note: GPU ids are hard-coded; adjust device_ids to match the available GPUs
65 | edge_model = nn.DataParallel(edge_model, device_ids=[0, 1])
66 | imgData = GetData(dataRoot, maskRoot, loadSize, cropSize)
67 | data_loader = DataLoader(imgData, batch_size=batchSize, shuffle=False, num_workers=1, drop_last=False)
68 | 
69 | num_epochs = 100
70 | 
71 | netG = LBAMModel(5, 3)
72 | 
73 | if args.pretrained != '':
74 |     netG.load_state_dict(torch.load(args.pretrained))
75 | else:
76 |     print('No pretrained model provided!')
77 | 
78 | if cuda:
79 |     netG = netG.cuda()
80 | 
81 | # evaluation only: freeze the generator
82 | for param in netG.parameters():
83 |     param.requires_grad = False
84 | 
85 | print('OK!')
86 | 
87 | sum_psnr = 0
88 | sum_ssim = 0
89 | count = 0
90 | sum_time = 0.0
91 | l1_loss = 0
92 | 
93 | import time
94 | start = time.time()
95 | for epoch in range(1, num_epochs + 1):
96 |     netG.eval()
97 |     for inputImgs, GT, masks, img_gray, edge, masks_over in data_loader:
98 |         # evaluate on the first 60 batches only
99 |         if count >= 60:
100 |             break
101 |         if cuda:
102 |             inputImgs = inputImgs.cuda()
103 |             img_gray = img_gray.cuda()
104 |             GT = GT.cuda()
105 |             masks = masks.cuda()
106 |             edge = edge.cuda()
107 |             masks_over = masks_over.cuda()
108 |         # predict edges inside the holes, then merge them with the known edges
109 |         outputs_2 = edge_model(img_gray, edge, masks_over)
110 |         outputs_merged = (outputs_2 * masks_over) + (edge * (1 - masks_over))
111 |         inputImgs2 = torch.cat((inputImgs, outputs_merged), 1)
112 |         fake_images = netG(inputImgs2, masks, outputs_merged)
113 | 
114 |         g_image = fake_images.data.cpu()
115 |         GT = GT.data.cpu()
116 |         mask = masks.data.cpu()
117 |         damaged = GT * mask
118 |         # composite: keep the known pixels, fill the holes with the generator output
119 |         generatedImage = GT * mask + g_image * (1 - mask)
120 |         groundTruth = GT
121 |         masksT = mask
122 |         count += 1
123 |         batch_mse = ((groundTruth - generatedImage) ** 2).mean()
124 |         # PSNR assumes image values in [0, 1], so the peak value is 1
125 |         psnr = 10 * math.log10(1 / batch_mse)
126 |         sum_psnr += psnr
127 |         print(count, ' psnr:', psnr)
128 |         # caution: the SSIM constants C1/C2 assume a dynamic range of 1,
129 |         # so the * 255 scaling below is dubious; kept as in the original
130 |         ssim = pytorch_ssim.ssim(groundTruth * 255, generatedImage * 255)
131 |         sum_ssim += ssim
132 |         print(count, ' ssim:', ssim)
133 |         l1_loss += nn.L1Loss()(generatedImage, groundTruth)
134 | 
135 |         # five panels per sample: mask, damaged, ours, GT, predicted edges
136 |         outputs = torch.Tensor(5 * GT.size()[0], GT.size()[1], cropSize[0], cropSize[1])
137 |         for k in range(GT.size()[0]):
138 |             outputs[5 * k] = masksT[k]
139 |             outputs[5 * k + 1] = damaged[k]
140 |             outputs[5 * k + 2] = generatedImage[k]
141 |             outputs[5 * k + 3] = GT[k]
142 |             outputs[5 * k + 4] = outputs_merged[k].cpu()
143 |         # save_image(outputs, os.path.join(savePath, 'results-{}'.format(count) + '.png'))
144 | 
145 |         # save mask, GT, result, input and damaged images into their subdirs
146 |         damaged = GT * mask + (1 - mask)
147 | 
148 |         for j in range(GT.size()[0]):
149 |             save_image(outputs[5 * j + 1], savePath + '/damaged/damaged{}-{}.png'.format(count, j))
150 |             outputs[5 * j + 1] = damaged[j]
151 | 
152 |         for j in range(GT.size()[0]):
153 |             outputs[5 * j] = 1 - masksT[j]
154 |             save_image(outputs[5 * j], savePath + '/masks/mask{}-{}.png'.format(count, j))
155 |             save_image(outputs[5 * j + 1], savePath + '/input/input{}-{}.png'.format(count, j))
156 |             save_image(outputs[5 * j + 2], savePath + '/ours/ours{}-{}.png'.format(count, j))
157 |             save_image(outputs[5 * j + 3], savePath + '/GT/GT{}-{}.png'.format(count, j))
158 |             save_image(outputs[5 * j + 4], savePath + '/edge/edge{}-{}.png'.format(count, j))
159 | 
160 |         end = time.time()
161 |         sum_time += (end - start) / batchSize
162 | 
163 | print('avg l1 loss:', l1_loss / count)
164 | print('average psnr:', sum_psnr / count)
165 | print('average ssim:', sum_ssim / count)
166 | print('average time cost:', sum_time / count)
--------------------------------------------------------------------------------
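The PSNR bookkeeping in the loop above reduces to a batch MSE plus a log; a standalone sketch of the same metric arithmetic, with random tensors standing in for a real batch:

import math
import torch

gt = torch.rand(4, 3, 256, 256)    # ground truth in [0, 1]
out = torch.rand(4, 3, 256, 256)   # composited generator output

mse = ((gt - out) ** 2).mean().item()
psnr = 10 * math.log10(1.0 / mse)  # peak value is 1 for [0, 1] images
print('psnr:', psnr)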
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | import torch.backends.cudnn as cudnn
7 | from torch.utils.data import DataLoader
8 | from data.dataloader_canny import GetData
9 | from loss.InpaintingLoss import InpaintingLossWithGAN
10 | from models.LBAMModel import LBAMModel, VGG16FeatureExtractor
11 | from MECNet.models import EdgeModel
12 | from MECNet.config import Config
13 | 
14 | torch.set_num_threads(6)
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--numOfWorkers', type=int, default=4,
18 |                     help='workers for dataloader')
19 | parser.add_argument('--modelsSavePath', type=str, default='',
20 |                     help='path for saving models')
21 | parser.add_argument('--logPath', type=str,
22 |                     default='')
23 | parser.add_argument('--batchSize', type=int, default=16)
24 | parser.add_argument('--loadSize', type=int, default=256,
25 |                     help='image loading size')
26 | parser.add_argument('--cropSize', type=int, default=256,
27 |                     help='image training size')
28 | parser.add_argument('--dataRoot', type=str,
29 |                     default='')
30 | parser.add_argument('--maskRoot', type=str,
31 |                     default='')
32 | parser.add_argument('--pretrained', type=str, default='', help='pretrained models for finetuning')
33 | parser.add_argument('--train_epochs', type=int, default=500, help='training epochs')
34 | args = parser.parse_args()
35 | 
36 | cuda = torch.cuda.is_available()
37 | if cuda:
38 |     print('Cuda is available!')
39 |     cudnn.enabled = True
40 |     cudnn.benchmark = True
41 | 
42 | batchSize = args.batchSize
43 | loadSize = (args.loadSize, args.loadSize)
44 | cropSize = (args.cropSize, args.cropSize)
45 | 
46 | if not os.path.exists(args.modelsSavePath):
47 |     os.makedirs(args.modelsSavePath)
48 | 
49 | config = Config("config.yml")
50 | edge_model = EdgeModel(config).to(config.DEVICE)
51 | edge_model.load()
52 | edge_model.cuda()
53 | # note: GPU ids are hard-coded; adjust device_ids to match the available GPUs
54 | edge_model = nn.DataParallel(edge_model, device_ids=[0, 1, 2, 3])
55 | dataRoot = args.dataRoot
56 | maskRoot = args.maskRoot
57 | 
58 | imgData = GetData(dataRoot, maskRoot, loadSize, cropSize)
59 | data_loader = DataLoader(imgData, batch_size=batchSize,
60 |                          shuffle=True, num_workers=args.numOfWorkers, drop_last=False, pin_memory=True)
61 | 
62 | num_epochs = args.train_epochs
63 | 
64 | netG = LBAMModel(5, 3)
65 | if args.pretrained != '':
66 |     netG.load_state_dict(torch.load(args.pretrained))
67 | 
68 | numOfGPUs = torch.cuda.device_count()
69 | 
70 | if cuda:
71 |     netG = netG.cuda()
72 |     if numOfGPUs > 1:
73 |         netG = nn.DataParallel(netG, device_ids=range(numOfGPUs))
74 | 
75 | count = 1
76 | 
77 | G_optimizer = optim.Adam(netG.parameters(), lr=0.000025, betas=(0.5, 0.9))
78 | 
79 | criterion = InpaintingLossWithGAN(args.logPath, VGG16FeatureExtractor(), lr=0.00001, betasInit=(0.0, 0.9), Lamda=10.0)
80 | 
81 | if cuda:
82 |     criterion = criterion.cuda()
83 | 
84 |     if numOfGPUs > 1:
85 |         criterion = nn.DataParallel(criterion, device_ids=range(numOfGPUs))
86 | 
87 | print('OK!')
88 | 
89 | for i in range(1, num_epochs + 1):
90 |     netG.train()
91 | 
92 |     for inputImgs, GT, masks, img_gray, edge, masks_over in data_loader:
93 | 
94 |         if cuda:
95 |             inputImgs = inputImgs.cuda()
96 |             img_gray = img_gray.cuda()
97 |             GT = GT.cuda()
98 |             masks = masks.cuda()
99 |             edge = edge.cuda()
100 |             masks_over = masks_over.cuda()
101 |         netG.zero_grad()
102 |         # predict edges inside the holes, then merge them with the known edges
103 |         outputs = edge_model(img_gray, edge, masks_over)
104 |         outputs_merged = (outputs * masks_over) + (edge * (1 - masks_over))
105 |         # concatenate the merged edge map to the masked input as an extra channel
106 |         inputImgs = torch.cat((inputImgs, outputs_merged), 1)
107 |         fake_images = netG(inputImgs, masks, outputs_merged)
108 |         G_loss = criterion(inputImgs[:, 0:3, :, :], masks, fake_images, GT, count, i)
109 |         G_loss = G_loss.sum()
110 |         G_optimizer.zero_grad()
111 |         G_loss.backward()
112 |         G_optimizer.step()
113 | 
114 |         # the log path was hard-coded to an absolute home directory in the
115 |         # original; a relative path is used here so the script is portable
116 |         with open('loss2.txt', 'a') as file:
117 |             file.write('Generator Loss of epoch{} is {}\n'.format(i, G_loss.item()))
118 | 
119 |         count += 1
120 | 
121 |         """ if (count % 4000 == 0):
122 |             torch.save(netG.module.state_dict(), args.modelsSavePath +
123 |                        '/Places_{}.pth'.format(i)) """
124 | 
125 |     # save every 10 epochs; i % 50 keeps a rotating window of five checkpoint names
126 |     if i % 10 == 0:
127 |         if numOfGPUs > 1:
128 |             torch.save(netG.module.state_dict(), args.modelsSavePath +
129 |                        '/LBAM_{}.pth'.format(i % 50))
130 |         else:
131 |             torch.save(netG.state_dict(), args.modelsSavePath +
132 |                        '/LBAM_{}.pth'.format(i % 50))
--------------------------------------------------------------------------------
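Both scripts build the same 5-channel generator input: the masked image from the dataloader plus a merged edge map in which predicted edges are kept only inside the holes. A minimal sketch of that composition step; the channel layout is an inference from LBAMModel(5, 3) rather than stated in the code, and random tensors stand in for the dataloader and EdgeModel outputs:

import torch

B, H, W = 2, 256, 256
inputImgs = torch.rand(B, 4, H, W)                    # masked RGB + mask channel (assumed layout)
edge = torch.rand(B, 1, H, W)                         # edges detected on the damaged image
predicted = torch.rand(B, 1, H, W)                    # stand-in for the EdgeModel output
masks_over = (torch.rand(B, 1, H, W) > 0.5).float()   # assumed 1 inside the (dilated) holes

# predicted edges inside the holes, known edges everywhere else
outputs_merged = predicted * masks_over + edge * (1 - masks_over)

# concatenate along the channel dimension: 4 + 1 = 5 channels, matching LBAMModel(5, 3)
netG_input = torch.cat((inputImgs, outputs_merged), 1)
print(netG_input.shape)  # torch.Size([2, 5, 256, 256])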