├── LICENSE ├── README.md ├── __init__.py ├── dataset.py ├── eval.py ├── flist.py ├── main.py ├── models.py ├── module_util.py ├── networks.py ├── pretrained_model └── x_admin.cluster.localRN-0.8RN-Net_bs_14_epoch_3.pth └── rn.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tao Yu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Region Normalization for Image Inpainting 2 | 3 | The paper can be found [here](https://arxiv.org/abs/1911.10375). If you have any question about the paper/codes, you can contact me through Email(yutao666@mail.ustc.edu.cn). 4 | 5 | Please run the codes where the python is Version 3.x and pytorch>=0.4. 6 | 7 | PS: 1) The results of this version codes are better than those in the paper. The original base inpainting model that RN uses is not very stable (the result variance is a bit large) and we only reported conservative results. However, we optimized the base model and improved its robustness after the pulication so that the results now are better. 2) RN wants to bring an insight that spatially region-wise normalization is better for some CV tasks such as inpainting. Theoretically, RN can be both BN-style or IN-style. Both have pros and cons. IN-style RN gives less blurring results and achieves style consistence to background in some extent, while suffers from spatial inconsistence if the model representation ability is limited. BN-style RN gives higher PSNR on an aligned validation data, but makes regions more blurring and causes much data-bias risk when testing data distribution has a certain shift to training data distribution. One chooses the RN style according to the specific scene. (See [issue #12](https://github.com/geekyutao/RN/issues/12)) 8 | 9 | ## Repo Update: 10 | - [04/26/2022] Support torch >= 1.7; fix old-version issues. 11 | 12 | 13 | ## Preparation 14 | Before running the codes, you should prepare training/evaluation image file list (flist) and mask file list (flist). You can refer to the folowing command to generate .flist file: 15 | ``` 16 | python flist.py --path your_dataset_folder --output xxx.flist 17 | ``` 18 | 19 | ## Training 20 | There are some hyperparameters that you can adjust in the main.py. To train the model, you can run: 21 | ``` 22 | python main.py --bs 14 --gpus 2 --prefix rn --img_flist your_training_images.flist --mask_flist your_training_masks.flist 23 | ``` 24 | PS: You can set the "--bs" and "--gpus" to any number as you like. The above is just an example. 25 | 26 | ## Evaluation 27 | To evaluate the model, you can use GPU or CPU to run. 28 | 29 | For GPU: 30 | ``` 31 | python eval.py --bs your_batch_size --model your_checkpoint_path --img_flist your_eval_images.flist --mask_flist your_eval_masks.flist 32 | ``` 33 | 34 | For CPU: 35 | ``` 36 | python eval.py --cpu --bs your_batch_size --model your_checkpoint_path --img_flist your_eval_images.flist --mask_flist your_eval_masks.flist 37 | ``` 38 | 39 | PS: The pretrained model under folder './pretrained_model/' is trained from Places2 dataset with [Irregular Mask](https://nv-adlr.github.io/publication/partialconv-inpainting) dataset. **Please train RN from scratch if you test data not from Places2 or using regular mask.** 40 | 41 | ## Cite Us 42 | Please cite us if you find this work helps. 43 | 44 | ``` 45 | @inproceedings{yu2020region, 46 | title={Region Normalization for Image Inpainting.}, 47 | author={Yu, Tao and Guo, Zongyu and Jin, Xin and Wu, Shilin and Chen, Zhibo and Li, Weiping and Zhang, Zhizheng and Liu, Sen}, 48 | booktitle={AAAI}, 49 | pages={12733--12740}, 50 | year={2020} 51 | } 52 | ``` 53 | 54 | ## Appreciation 55 | The codes refer to [EdgeConnect](https://github.com/knazeri/edge-connect). Thanks for the authors of it! 56 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | # The codes for RN are soon released. 2 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import scipy 4 | import torch 5 | import random 6 | import numpy as np 7 | import torchvision.transforms.functional as F 8 | from PIL import Image 9 | # from scipy.misc import imread 10 | import imageio 11 | import cv2 12 | from skimage.color import rgb2gray, gray2rgb 13 | from skimage.transform import resize 14 | from torch.utils.data import DataLoader 15 | 16 | def my_transforms(): 17 | transform = transforms.Compose([ 18 | transforms.Resize((256, 256)), 19 | transforms.ToTensor(), 20 | ]) 21 | return transform 22 | 23 | class Dataset(torch.utils.data.Dataset): 24 | def __init__(self, flist, mask_flist, augment, training, input_size): 25 | super(Dataset, self).__init__() 26 | self.augment = augment 27 | self.training = training 28 | self.data = self.load_flist(flist) 29 | self.mask_data = self.load_flist(mask_flist) 30 | self.input_size = input_size 31 | 32 | def __len__(self): 33 | return len(self.data) 34 | 35 | def __getitem__(self, index): 36 | item = self.load_item(index) 37 | # try: 38 | # item = self.load_item(index) 39 | # except: 40 | # print('loading error: ' + self.data[index]) 41 | # item = self.load_item(0) 42 | 43 | return item 44 | 45 | def load_name(self, index): 46 | name = self.data[index] 47 | return os.path.basename(name) 48 | 49 | def load_item(self, index): 50 | 51 | size = self.input_size 52 | 53 | # load image 54 | img = imageio.imread(self.data[index]) 55 | 56 | # gray to rgb 57 | if len(img.shape) < 3: 58 | img = gray2rgb(img) 59 | 60 | # resize/crop if needed 61 | if self.training: 62 | if size != 0: 63 | img = self.resize(img, size, size) 64 | 65 | # load mask 66 | mask = self.load_mask(img, index) 67 | 68 | # augment data 69 | if self.augment and np.random.binomial(1, 0.5) > 0: 70 | img = img[:, ::-1, ...] 71 | mask = mask[:, ::-1, ...] 72 | 73 | return self.to_tensor(img), self.to_tensor(mask), index 74 | 75 | 76 | def load_mask(self, img, index): 77 | imgh, imgw = img.shape[0:2] 78 | 79 | # external 80 | if self.training: 81 | mask_index = random.randint(0, len(self.mask_data) - 1) 82 | mask = imageio.imread(self.mask_data[mask_index]) 83 | mask = self.resize(mask, imgh, imgw) 84 | else: # in test mode, there's a one-to-one relationship between mask and image; masks are loaded non random 85 | # mask = 255 - imread(self.mask_data[index])[:,:,0] # ICME original (H,W,3) mask: 0 for hole 86 | mask = imageio.imread(self.mask_data[index]) # mask must be 255 for hole in this InpaintingModel 87 | mask = self.resize(mask, imgh, imgw, centerCrop=False) 88 | if len(mask.shape) == 3: 89 | mask = rgb2gray(mask) 90 | mask = (mask > 0).astype(np.uint8) * 255 # threshold due to interpolation 91 | return mask 92 | 93 | def to_tensor(self, img): 94 | img = Image.fromarray(img) 95 | img_t = F.to_tensor(img).float() 96 | return img_t 97 | 98 | def resize(self, img, height, width, centerCrop=True): 99 | imgh, imgw = img.shape[:2] 100 | 101 | if centerCrop and imgh != imgw: 102 | # center crop 103 | side = np.minimum(imgh, imgw) 104 | j = (imgh - side) // 2 105 | i = (imgw - side) // 2 106 | img = img[j:j + side, i:i + side, ...] 107 | 108 | # print(type(img)) # imageio.core.util.Array 109 | # img = scipy.misc.imresize(img, [height, width]) 110 | img = cv2.resize(img, (height, width)) 111 | 112 | return img 113 | 114 | def load_flist(self, flist): 115 | if isinstance(flist, list): 116 | return flist 117 | 118 | # flist: image file path, image directory path, text file flist path 119 | if isinstance(flist, str): 120 | if os.path.isdir(flist): 121 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png')) 122 | flist.sort() 123 | return flist 124 | 125 | if os.path.isfile(flist): 126 | # print(np.genfromtxt(flist, dtype=np.str)) 127 | # return np.genfromtxt(flist, dtype=np.str) 128 | try: 129 | return np.genfromtxt(flist, dtype=np.str) 130 | except: 131 | return [flist] 132 | return [] 133 | 134 | 135 | def build_dataloader(flist, mask_flist, augment, training, input_size, batch_size, \ 136 | num_workers, shuffle): 137 | 138 | dataset = Dataset( 139 | flist=flist, 140 | mask_flist=mask_flist, 141 | augment=augment, 142 | training=training, 143 | input_size=input_size 144 | ) 145 | 146 | print('Total instance number:', dataset.__len__()) 147 | 148 | dataloader = DataLoader( 149 | dataset=dataset, 150 | batch_size=batch_size, 151 | num_workers=num_workers, 152 | drop_last=True, 153 | shuffle=shuffle 154 | ) 155 | 156 | return dataloader 157 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10 4 | import numpy as np 5 | import math 6 | 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.optim as optim 12 | import torch.backends.cudnn as cudnn 13 | from torch.autograd import Variable 14 | from torch.utils.data import DataLoader 15 | import torchvision.utils as vutils 16 | 17 | from module_util import initialize_weights 18 | from dataset import build_dataloader 19 | import pdb 20 | import socket 21 | import time 22 | import skimage 23 | # from skimage.measure import compare_ssim 24 | # from skimage.measure import compare_psnr 25 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 26 | from skimage.metrics import structural_similarity as compare_ssim 27 | 28 | from models import InpaintingModel 29 | 30 | # from cal_fid import calculate_fid_given_paths 31 | 32 | import lpips 33 | 34 | loss_fn_alex = lpips.LPIPS(net='alex').cuda() 35 | 36 | # Training settings 37 | parser = argparse.ArgumentParser(description='PyTorch Video Inpainting with Background Auxilary') 38 | parser.add_argument('--bs', type=int, default=64, help='training batch size') 39 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001') 40 | parser.add_argument('--cpu', default=False, action='store_true', help='Use CPU to test') 41 | parser.add_argument('--threads', type=int, default=1, help='number of threads for data loader to use') 42 | parser.add_argument('--seed', type=int, default=67454, help='random seed to use. Default=123') 43 | parser.add_argument('--gpus', default=1, type=int, help='number of gpu') 44 | parser.add_argument('--threshold', type=float, default=0.8) 45 | parser.add_argument('--img_flist', type=str, default='/data/dataset/places2/flist/val.flist') 46 | parser.add_argument('--mask_flist', type=str, default='/data/dataset/places2/flist/3w_all.flist') 47 | parser.add_argument('--model', default='/data/yutao/Project/weights/BGNet/x_admin.cluster.localRN-0.8BGNet_bs_14_epoch_9.pth', help='sr pretrained base model') 48 | parser.add_argument('--save', default=False, action='store_true', help='If save test images') 49 | parser.add_argument('--save_path', type=str, default='./test_results') 50 | parser.add_argument('--input_size', type=int, default=256, help='input image size') 51 | parser.add_argument('--l1_weight', type=float, default=1.0) 52 | parser.add_argument('--gan_weight', type=float, default=0.1) 53 | 54 | 55 | opt = parser.parse_args() 56 | 57 | def eval(): 58 | model.eval() 59 | model.generator.eval() 60 | count = 1 61 | avg_du = 0 62 | avg_psnr, avg_ssim, avg_l1, avg_lpips = 0., 0., 0., 0. 63 | for batch in testing_data_loader: 64 | gt, mask, index = batch 65 | t_io2 = time.time() 66 | if cuda: 67 | gt = gt.cuda() 68 | mask = mask.cuda() 69 | 70 | 71 | ## The test or ensemble test 72 | 73 | # t0 = time.clock() 74 | with torch.no_grad(): 75 | prediction = model.generator(gt, mask) 76 | prediction = prediction * mask + gt * (1 - mask) 77 | batch_avg_lpips = loss_fn_alex(prediction, gt).mean() 78 | avg_lpips = avg_lpips + ((batch_avg_lpips- avg_lpips) / count) 79 | # t1 = time.clock() 80 | # du = t1 - t0 81 | # print("===> Processing: %s || Timer: %.4f sec." % (str(count), du)) 82 | 83 | # avg_du += du 84 | # print( 85 | # "Number: %05d" % (count), 86 | # " | Average time: %.4f" % (avg_du/count)) 87 | 88 | # Save the video frames 89 | batch_avg_psnr, batch_avg_ssim, batch_avg_l1 = evaluate_batch( 90 | batch_size=opt.bs, 91 | gt_batch=gt, 92 | pred_batch=prediction, 93 | mask_batch=mask, 94 | save=opt.save, 95 | path=opt.save_path, 96 | count=count, 97 | index=index 98 | ) 99 | 100 | # avg_psnr = (avg_psnr * (count - 1) + batch_avg_psnr) / count 101 | avg_psnr = avg_psnr + ((batch_avg_psnr- avg_psnr) / count) 102 | avg_ssim = avg_ssim + ((batch_avg_ssim- avg_ssim) / count) 103 | avg_l1 = avg_l1 + ((batch_avg_l1- avg_l1) / count) 104 | 105 | print( 106 | "Number: %05d" % (count * opt.bs), 107 | " | Average: PSNR: %.4f" % (avg_psnr), 108 | " SSIM: %.4f" % (avg_ssim), 109 | " LPIPS: %.4f" % (avg_lpips), 110 | " L1: %.4f" % (avg_l1), 111 | "| Current batch:", count, 112 | " PSNR: %.4f" % (batch_avg_psnr), 113 | " SSIM: %.4f" % (batch_avg_ssim), 114 | " LPIPS: %.4f" % (batch_avg_lpips), 115 | " L1: %.4f" % (batch_avg_l1), flush=True 116 | ) 117 | 118 | count+=1 119 | 120 | 121 | 122 | 123 | def save_img(path, name, img): 124 | # img (H,W,C) or (H,W) np.uint8 125 | skimage.io.imsave(path+'/'+name+'.png', img) 126 | 127 | def PSNR(pred, gt, shave_border=0): 128 | return compare_psnr(pred, gt, data_range=255) 129 | # imdff = pred - gt 130 | # rmse = math.sqrt(np.mean(imdff ** 2)) 131 | # if rmse == 0: 132 | # return 100 133 | # return 20 * math.log10(255.0 / rmse) 134 | 135 | def L1(pred, gt): 136 | return np.mean(np.abs((np.mean(pred,2) - np.mean(gt,2))/255)) 137 | 138 | def SSIM(pred, gt, data_range=255, win_size=11, multichannel=True): 139 | return compare_ssim(pred, gt, data_range=data_range, \ 140 | multichannel=multichannel, win_size=win_size) 141 | 142 | def evaluate_batch(batch_size, gt_batch, pred_batch, mask_batch, save=False, path=None, count=None, index=None): 143 | pred_batch = pred_batch * mask_batch + gt_batch * (1 - mask_batch) 144 | 145 | if save: 146 | input_batch = gt_batch * (1 - mask_batch) + mask_batch 147 | input_batch = (input_batch.detach().permute(0,2,3,1).cpu().numpy()*255).astype(np.uint8) 148 | mask_batch = (mask_batch.detach().permute(0,2,3,1).cpu().numpy()[:,:,:,0]*255).astype(np.uint8) 149 | 150 | if not os.path.exists(path): 151 | os.mkdir(path) 152 | 153 | 154 | gt_batch = (gt_batch.detach().permute(0,2,3,1).cpu().numpy()*255).astype(np.uint8) 155 | pred_batch = (pred_batch.detach().permute(0,2,3,1).cpu().numpy()*255).astype(np.uint8) 156 | 157 | psnr, ssim, l1 = 0., 0., 0. 158 | for i in range(batch_size): 159 | gt, pred, name = gt_batch[i], pred_batch[i], index[i].data.item() 160 | 161 | psnr += PSNR(pred, gt) 162 | ssim += SSIM(pred, gt) 163 | # ssim += SSIM(pred, gt, multichannel=False) 164 | l1 += L1(pred, gt) 165 | 166 | if save: 167 | save_img(path, str(count)+'_'+str(name)+'_input', input_batch[i]) 168 | save_img(path, str(count)+'_'+str(name)+'_mask', mask_batch[i]) 169 | save_img(path, str(count)+'_'+str(name)+'_output', pred_batch[i]) 170 | save_img(path, str(count)+'_'+str(name)+'_gt', gt_batch[i]) 171 | 172 | return psnr/batch_size, ssim/batch_size, l1/batch_size 173 | 174 | 175 | 176 | def print_network(net): 177 | num_params = 0 178 | learnable_num_params = 0 179 | for param in net.parameters(): 180 | num_params += param.numel() 181 | if param.requires_grad: 182 | learnable_num_params += param.numel() 183 | print(net) 184 | print('Total number of parameters: %d' % num_params) 185 | print('Learnable number of parameters: %d' % learnable_num_params) 186 | 187 | 188 | if __name__ == '__main__': 189 | if opt.cpu: 190 | print("===== Use CPU to Test! =====") 191 | else: 192 | print("===== Use GPU to Test! =====") 193 | 194 | ## Set the GPU mode 195 | gpus_list=range(opt.gpus) 196 | cuda = not opt.cpu 197 | if cuda and not torch.cuda.is_available(): 198 | raise Exception("No GPU found, please run without --cuda") 199 | 200 | 201 | # Model 202 | model = InpaintingModel(g_lr=opt.lr, d_lr=(0.1 * opt.lr), l1_weight=opt.l1_weight, gan_weight=opt.gan_weight, iter=0, threshold=opt.threshold) 203 | print('---------- Networks architecture -------------') 204 | print("Generator:") 205 | print_network(model.generator) 206 | print("Discriminator:") 207 | print_network(model.discriminator) 208 | print('----------------------------------------------') 209 | 210 | pretrained_model = torch.load(opt.model, map_location=lambda storage, loc: storage) 211 | 212 | model.generator = torch.nn.DataParallel(model.generator, device_ids=gpus_list) 213 | model.discriminator = torch.nn.DataParallel(model.discriminator, device_ids=gpus_list) 214 | model.load_state_dict(pretrained_model, strict=False) # strict=Fasle since I modify discirminator in the previous commit 215 | model.generator = model.generator.cuda() 216 | print('Pre-trained G model is loaded.') 217 | 218 | # Datasets 219 | print('===> Loading datasets') 220 | testing_data_loader = build_dataloader( 221 | flist=opt.img_flist, 222 | mask_flist=opt.mask_flist, 223 | augment=False, 224 | training=False, 225 | input_size=opt.input_size, 226 | batch_size=opt.bs, 227 | num_workers=opt.threads, 228 | shuffle=False 229 | ) 230 | print('===> Dataset loaded!') 231 | 232 | ## Eval Start!!!! 233 | eval() 234 | -------------------------------------------------------------------------------- /flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument('--path', type=str, help='path to the dataset') 7 | parser.add_argument('--output', type=str, help='path to the file list') 8 | args = parser.parse_args() 9 | 10 | ext = {'.jpg', '.png'} 11 | 12 | images = [] 13 | for root, dirs, files in os.walk(args.path): 14 | print('loading ' + root) 15 | for file in files: 16 | if os.path.splitext(file)[1] in ext: 17 | images.append(os.path.join(root, file)) 18 | 19 | images = sorted(images) 20 | np.savetxt(args.output, images, fmt='%s') -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10 4 | import numpy as np 5 | 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.optim as optim 11 | import torch.backends.cudnn as cudnn 12 | from torch.autograd import Variable 13 | from torch.utils.data import DataLoader 14 | import torchvision.utils as vutils 15 | 16 | from module_util import * 17 | from dataset import build_dataloader 18 | import pdb 19 | import socket 20 | import time 21 | from skimage import io 22 | # from skimage.measure import compare_psnr 23 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 24 | 25 | from models import InpaintingModel 26 | 27 | from tensorboardX import SummaryWriter 28 | 29 | 30 | # Training settings 31 | parser = argparse.ArgumentParser(description='Region Normalization for Image Inpainting') 32 | parser.add_argument('--bs', type=int, default=14, help='training batch size') 33 | parser.add_argument('--input_size', type=int, default=256, help='input image size') 34 | parser.add_argument('--start_epoch', type=int, default=1, help='Starting epoch for continuing training') 35 | parser.add_argument('--nEpochs', type=int, default=10, help='number of epochs to train for') 36 | parser.add_argument('--snapshots', type=int, default=1, help='Snapshots') 37 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate. Default=0.0001') 38 | parser.add_argument('--gpu_mode', type=bool, default=True) 39 | parser.add_argument('--threads', type=int, default=2, help='number of threads for data loader to use') 40 | parser.add_argument('--seed', type=int, default=67454, help='random seed to use. Default=123') 41 | parser.add_argument('--gpus', default=1, type=int, help='number of gpu') 42 | parser.add_argument('--img_flist', type=str, default='shuffled_train.flist') 43 | parser.add_argument('--mask_flist', type=str, default='all.flist') 44 | parser.add_argument('--model_type', type=str, default='RN') 45 | parser.add_argument('--threshold', type=float, default=0.8) 46 | parser.add_argument('--pretrained_sr', default='../weights/xx.pth', help='pretrained base model') 47 | parser.add_argument('--pretrained', type=bool, default=False) 48 | parser.add_argument('--save_folder', default='./ckpt/', help='Location to save checkpoint models') 49 | parser.add_argument('--prefix', default='0p1GAN0p8thre', help='Location to save checkpoint models') 50 | parser.add_argument('--print_interval', type=int, default=100, help='how many steps to print the results out') 51 | parser.add_argument('--render_interval', type=int, default=10000, help='how many steps to save a checkpoint') 52 | parser.add_argument('--l1_weight', type=float, default=1.0) 53 | parser.add_argument('--gan_weight', type=float, default=0.1) 54 | parser.add_argument('--update_weight_interval', type=int, default=5000, help='how many steps to update losses weighing') 55 | parser.add_argument('--with_test', default=False, action='store_true', help='Train with testing?') 56 | parser.add_argument('--test', default=False, action='store_true', help='Test model') 57 | parser.add_argument('--test_mask_flist', type=str, default='mask1k.flist') 58 | parser.add_argument('--test_img_flist', type=str, default='val1k.flist') 59 | parser.add_argument('--tb', default=False, action='store_true', help='Use tensorboardX?') 60 | 61 | opt = parser.parse_args() 62 | gpus_list = list(range(opt.gpus)) # the list of gpu 63 | hostname = str(socket.gethostname()) 64 | opt.save_folder += opt.prefix 65 | cudnn.benchmark = True 66 | if not os.path.exists(opt.save_folder): 67 | os.makedirs(opt.save_folder) 68 | print(opt) 69 | 70 | 71 | def train(epoch): 72 | iteration, avg_g_loss, avg_d_loss, avg_l1_loss, avg_gan_loss = 0, 0, 0, 0, 0 73 | last_l1_loss, last_gan_loss, cur_l1_loss, cur_gan_loss = 0, 0, 0, 0 74 | model.train() 75 | t0 = time.time() 76 | t_io1 = time.time() 77 | for batch in training_data_loader: 78 | gt, mask, index = batch 79 | t_io2 = time.time() 80 | if cuda: 81 | gt = gt.cuda() 82 | mask = mask.cuda() 83 | 84 | prediction = model.generator(gt, mask) 85 | merged_result = prediction * mask + gt * (1 - mask) 86 | # render(epoch, iteration, mask, prediction.detach(), gt) 87 | # os._exit() 88 | 89 | # Compute Loss 90 | g_loss, d_loss = 0, 0 91 | 92 | d_real, _ = model.discriminator(gt) 93 | d_fake, _ = model.discriminator(prediction.detach()) 94 | d_real_loss = model.adversarial_loss(d_real, True, True) 95 | d_fake_loss = model.adversarial_loss(d_fake, False, True) 96 | d_loss = d_loss + (d_real_loss + d_fake_loss) / 2 97 | 98 | # Backward D 99 | d_loss.backward() 100 | model.dis_optimizer.step() 101 | model.dis_optimizer.zero_grad() 102 | 103 | g_fake, _ = model.discriminator(prediction) 104 | g_gan_loss = model.adversarial_loss(g_fake, True, False) 105 | g_loss = g_loss + model.gan_weight * g_gan_loss 106 | g_l1_loss = model.l1_loss(gt, merged_result) / torch.mean(mask) 107 | # g_l1_loss = model.l1_loss(gt, prediction) / torch.mean(mask) 108 | g_loss = g_loss + model.l1_weight * g_l1_loss 109 | 110 | # Backward G 111 | g_loss.backward() 112 | model.gen_optimizer.step() 113 | model.gen_optimizer.zero_grad() 114 | 115 | # Record 116 | cur_l1_loss += g_l1_loss.data.item() 117 | cur_gan_loss += g_gan_loss.data.item() 118 | avg_l1_loss += g_l1_loss.data.item() 119 | avg_gan_loss += g_gan_loss.data.item() 120 | avg_g_loss += g_loss.data.item() 121 | avg_d_loss += d_loss.data.item() 122 | 123 | model.global_iter += 1 124 | iteration += 1 125 | t1 = time.time() 126 | td, t0 = t1 - t0, t1 127 | 128 | if iteration % opt.print_interval == 0: 129 | print("=> Epoch[{}]({}/{}): Avg L1 loss: {:.6f} | G loss: {:.6f} | Avg D loss: {:.6f} || Timer: {:.4f} sec. | IO: {:.4f}".format( 130 | epoch, iteration, len(training_data_loader), avg_l1_loss/opt.print_interval, avg_g_loss/opt.print_interval, avg_d_loss/opt.print_interval, td, t_io2-t_io1), flush=True) 131 | #print("=> Epoch[{}]({}/{}): Avg G loss: {:.6f} || Timer: {:.4f} sec. || IO: {:.4f}".format( 132 | # epoch, iteration, len(training_data_loader), avg_g_loss/opt.print_interval, td, t_io2-t_io1), flush=True) 133 | 134 | if opt.tb: 135 | writer.add_scalar('scalar/G_loss', avg_g_loss/opt.print_interval, model.global_iter) 136 | writer.add_scalar('scalar/D_loss', avg_d_loss/opt.print_interval, model.global_iter) 137 | writer.add_scalar('scalar/G_l1_loss', avg_l1_loss/opt.print_interval, model.global_iter) 138 | writer.add_scalar('scalar/G_gan_loss', avg_gan_loss/opt.print_interval, model.global_iter) 139 | 140 | avg_g_loss, avg_d_loss, avg_l1_loss, avg_gan_loss = 0, 0, 0, 0 141 | t_io1 = time.time() 142 | 143 | if iteration % opt.render_interval == 0: 144 | render(epoch, iteration, mask, merged_result.detach(), gt) 145 | if opt.with_test: 146 | test_num = 500 147 | print("Testing {} images...".format(test_num)) 148 | test_psnr = test(model, test_data_loader, test_num=test_num) # or 'all' 149 | print("PSNR: ", test_psnr) 150 | if opt.tb: 151 | writer.add_scalar('scalar/test_PSNR', test_psnr, model.global_iter) 152 | 153 | if iteration % 50000 == 0: 154 | checkpoint(iteration) 155 | 156 | def render(epoch, iter, mask, output, gt): 157 | diry = 'render/'+opt.prefix 158 | if not os.path.exists(diry): 159 | os.makedirs(diry) 160 | 161 | name_pre = diry+'/'+str(epoch)+'_'+str(iter)+'_' 162 | 163 | # input: (bs,3,256,256) 164 | input = gt * (1 - mask) + mask 165 | input = input[0].permute(1,2,0).cpu().numpy() 166 | io.imsave(name_pre+'input.png', (input*255).astype(np.uint8)) 167 | 168 | # mask: (bs,1,256,256) 169 | mask = mask[0,0].cpu().numpy() 170 | io.imsave(name_pre+'mask.png', (mask*255).astype(np.uint8)) 171 | 172 | # output: (bs,3,256,256) 173 | output = output[0].permute(1,2,0).cpu().numpy() 174 | io.imsave(name_pre+'output.png', (output*255).astype(np.uint8)) 175 | 176 | # gt: (bs,3,256,256) 177 | gt = gt[0].permute(1,2,0).cpu().numpy() 178 | io.imsave(name_pre+'gt.png', (gt*255).astype(np.uint8)) 179 | 180 | def test(gen, dataloader, test_num='all'): 181 | model = gen.eval() 182 | psnr = 0 183 | count = 0 184 | for batch in dataloader: 185 | gt_batch, mask_batch, index = batch 186 | if cuda: 187 | gt_batch = gt_batch.cuda() 188 | mask_batch = mask_batch.cuda() 189 | with torch.no_grad(): 190 | pred_batch = model.generator(gt_batch, mask_batch) 191 | for i in range(gt_batch.size(0)): 192 | gt, pred = gt_batch[i], pred_batch[i] 193 | psnr += compare_psnr(pred.permute(1,2,0).cpu().numpy(), gt.permute(1,2,0).cpu().numpy(),\ 194 | data_range=1) 195 | count += 1 196 | if test_num == 'all': 197 | pass 198 | elif count > test_num: 199 | break 200 | 201 | return psnr / count 202 | 203 | def checkpoint(epoch): 204 | model_out_path = opt.save_folder+'/'+'x_'+hostname + \ 205 | opt.model_type+"_"+opt.prefix + "_bs_{}_epoch_{}.pth".format(opt.bs, epoch) 206 | torch.save(model.state_dict(), model_out_path) 207 | print("Checkpoint saved to {}".format(model_out_path)) 208 | 209 | if __name__ == '__main__': 210 | if opt.tb: 211 | writer = SummaryWriter() 212 | 213 | # Set the GPU mode 214 | cuda = opt.gpu_mode 215 | if cuda and not torch.cuda.is_available(): 216 | raise Exception("No GPU found, please run without --cuda") 217 | 218 | # Set the random seed 219 | torch.manual_seed(opt.seed) 220 | if cuda: 221 | torch.cuda.manual_seed_all(opt.seed) 222 | 223 | # Model 224 | model = InpaintingModel(g_lr=opt.lr, d_lr=(0.1 * opt.lr), l1_weight=opt.l1_weight, gan_weight=opt.gan_weight, iter=0, threshold=opt.threshold) 225 | print('---------- Networks architecture -------------') 226 | print("Generator:") 227 | print_network(model.generator) 228 | print("Discriminator:") 229 | print_network(model.discriminator) 230 | print('----------------------------------------------') 231 | initialize_weights(model, scale=0.1) 232 | 233 | if cuda: 234 | model = model.cuda() 235 | if opt.gpus > 1: 236 | model.generator = torch.nn.DataParallel(model.generator, device_ids=gpus_list) 237 | model.discriminator = torch.nn.DataParallel(model.discriminator, device_ids=gpus_list) 238 | 239 | # Load the pretrain model. 240 | if opt.pretrained: 241 | model_name = os.path.join(opt.pretrained_sr) 242 | print('pretrained model: %s' % model_name) 243 | if os.path.exists(model_name): 244 | pretained_model = torch.load(model_name, map_location=lambda storage, loc: storage) 245 | model.load_state_dict(pretained_model, strict=False) # strict=Fasle since I modify discirminator in the previous commit 246 | print('Pre-trained model is loaded.') 247 | print(' Current: G learning rate:', model.g_lr, ' | L1 loss weight:', model.l1_weight, \ 248 | ' | GAN loss weight:', model.gan_weight) 249 | 250 | # Datasets 251 | print('===> Loading datasets...') 252 | training_data_loader = build_dataloader( 253 | flist=opt.img_flist, 254 | mask_flist=opt.mask_flist, 255 | augment=True, 256 | training=True, 257 | input_size=opt.input_size, 258 | batch_size=opt.bs, 259 | num_workers=opt.threads, 260 | shuffle=True 261 | ) 262 | print('===> Datasets loaded!') 263 | 264 | if opt.test or opt.with_test: 265 | test_data_loader = build_dataloader( 266 | flist=opt.test_img_flist, 267 | mask_flist=opt.test_mask_flist, 268 | augment=False, 269 | training=False, 270 | input_size=opt.input_size, 271 | batch_size=64, 272 | num_workers=opt.threads, 273 | shuffle=False 274 | ) 275 | print('===> Test datasets loaded') 276 | 277 | if opt.test: 278 | test_psnr = test(model, test_data_loader) 279 | os._exit(0) 280 | 281 | # Start training 282 | if not os.path.exists('render'): 283 | os.makedirs('render') 284 | 285 | for epoch in range(opt.start_epoch, opt.nEpochs + 1): 286 | 287 | train(epoch) 288 | 289 | count = (epoch-1) 290 | if isinstance(model, torch.nn.DataParallel): 291 | model = model.module 292 | for param_group in model.gen_optimizer.param_groups: 293 | param_group['lr'] = model.g_lr * (0.8 ** count) 294 | print('===> Current G learning rate: ', param_group['lr']) 295 | for param_group in model.dis_optimizer.param_groups: 296 | param_group['lr'] = model.d_lr * (0.8 ** count) 297 | print('===> Current D learning rate: ', param_group['lr']) 298 | 299 | if (epoch+1) % (opt.snapshots) == 0: 300 | checkpoint(epoch) 301 | if opt.tb: 302 | writer.close() 303 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch import autograd 7 | import torchvision.models as models 8 | from networks import G_Net, D_Net 9 | 10 | class AdversarialLoss(nn.Module): 11 | """ 12 | Adversarial loss 13 | https://arxiv.org/abs/1711.10337 14 | """ 15 | 16 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): 17 | """ 18 | type = nsgan | lsgan | hinge 19 | """ 20 | super(AdversarialLoss, self).__init__() 21 | 22 | self.type = type 23 | self.register_buffer('real_label', torch.tensor(target_real_label)) 24 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 25 | 26 | if type == 'nsgan': 27 | self.criterion = nn.BCELoss() 28 | 29 | elif type == 'lsgan': 30 | self.criterion = nn.MSELoss() 31 | 32 | elif type == 'hinge': 33 | self.criterion = nn.ReLU() 34 | 35 | def __call__(self, outputs, is_real, is_disc=None): 36 | if self.type == 'hinge': 37 | if is_disc: 38 | if is_real: 39 | outputs = -outputs 40 | return self.criterion(1 + outputs).mean() 41 | else: 42 | return (-outputs).mean() 43 | 44 | else: 45 | labels = (self.real_label if is_real else self.fake_label).expand_as(outputs) 46 | loss = self.criterion(outputs, labels) 47 | return loss 48 | 49 | class InpaintingModel(nn.Module): 50 | def __init__(self, g_lr, d_lr, l1_weight, gan_weight, iter=0, threshold=None): 51 | super(InpaintingModel, self).__init__() 52 | 53 | self.generator = G_Net(input_channels=3, residual_blocks=8, threshold=threshold) 54 | self.discriminator = D_Net(in_channels=3, use_sigmoid=True) 55 | 56 | self.l1_loss = nn.L1Loss() 57 | self.adversarial_loss = AdversarialLoss('nsgan') 58 | 59 | self.g_lr, self.d_lr = g_lr, d_lr 60 | 61 | self.l1_weight, self.gan_weight = l1_weight, gan_weight 62 | 63 | self.global_iter = iter 64 | 65 | self.gen_optimizer = optim.Adam( 66 | params=self.generator.parameters(), 67 | lr=float(self.g_lr), 68 | betas=(0., 0.9) 69 | ) 70 | 71 | self.dis_optimizer = optim.Adam( 72 | params=self.discriminator.parameters(), 73 | lr=float(self.d_lr), 74 | betas=(0., 0.9) 75 | ) 76 | 77 | 78 | # if __name__ == '__main__': 79 | -------------------------------------------------------------------------------- /module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def print_network(net): 8 | num_params = 0 9 | for param in net.parameters(): 10 | num_params += param.numel() 11 | print(net) 12 | print('Total number of parameters: %d' % num_params) 13 | 14 | def initialize_weights(net_l, scale=1): 15 | if not isinstance(net_l, list): 16 | net_l = [net_l] 17 | for net in net_l: 18 | for m in net.modules(): 19 | if isinstance(m, nn.Conv2d): 20 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 21 | m.weight.data *= scale # for residual block 22 | if m.bias is not None: 23 | init.normal_(m.bias, 0.0001) 24 | elif isinstance(m, nn.Linear): 25 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 26 | m.weight.data *= scale 27 | if m.bias is not None: 28 | init.normal_(m.bias, 0.0001) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | try: 31 | init.constant_(m.weight, 1) 32 | init.normal_(m.bias, 0.0001) 33 | except: 34 | print('This layer has no BN parameters:', m) 35 | 36 | def make_layer(block, n_layers): 37 | layers = [] 38 | for _ in range(n_layers): 39 | layers.append(block()) 40 | return nn.Sequential(*layers) 41 | 42 | class SELayer(nn.Module): 43 | def __init__(self, channel, reduction=16): 44 | super(SELayer, self).__init__() 45 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 46 | # self.spatialpool = SpatialPool(channel) 47 | self.fc = nn.Sequential( 48 | nn.Linear(channel, channel // reduction, bias=True), 49 | nn.ReLU(inplace=True), 50 | nn.Linear(channel // reduction, channel, bias=True), 51 | ) 52 | 53 | def forward(self, x): 54 | b, c, _, _ = x.size() 55 | y = self.avg_pool(x).view(b, c) 56 | y = self.fc(y).view(b,c,1) 57 | return y 58 | 59 | # class ResidualBlock_noBN(nn.Module): 60 | # def __init__(self, nf=64, stride=1, downsample=None, reduction=4): 61 | # super(ResidualBlock_noBN, self).__init__() 62 | # self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | # self.relu = nn.ReLU(inplace=True) 64 | # self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | # self.se = SELayer(nf, reduction) 66 | # self.sigmoid = nn.Sigmoid() 67 | 68 | # def forward(self, x): 69 | # residual = x 70 | # out = self.relu(self.conv1(x)) 71 | # out = self.conv2(out) 72 | # y = self.se(out) 73 | # y = self.sigmoid(y.view(y.size(0),-1)) 74 | # y = y.view(y.size(0), y.size(1),1,1) 75 | # out = torch.mul(out, y) 76 | # out += residual 77 | # return out 78 | 79 | class ResidualBlock_noBN(nn.Module): 80 | '''Residual block w/o BN 81 | ---Conv-ReLU-Conv-+- 82 | |________________| 83 | ''' 84 | def __init__(self, nf=64): 85 | super(ResidualBlock_noBN, self).__init__() 86 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 87 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 88 | 89 | # initialization 90 | initialize_weights([self.conv1, self.conv2], 0.1) 91 | 92 | def forward(self, x): 93 | identity = x 94 | out = F.relu(self.conv1(x), inplace=True) 95 | out = self.conv2(out) 96 | return identity + out 97 | 98 | 99 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 100 | """Warp an image or feature map with optical flow 101 | Args: 102 | x (Tensor): size (N, C, H, W) 103 | flow (Tensor): size (N, H, W, 2), normal value 104 | interp_mode (str): 'nearest' or 'bilinear' 105 | padding_mode (str): 'zeros' or 'border' or 'reflection' 106 | 107 | Returns: 108 | Tensor: warped image or feature map 109 | """ 110 | assert x.size()[-2:] == flow.size()[1:3] 111 | B, C, H, W = x.size() 112 | # mesh grid 113 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 114 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 115 | grid.requires_grad = False 116 | grid = grid.type_as(x) 117 | vgrid = grid + flow 118 | # scale grid to [-1,1] 119 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 120 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 121 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 122 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 123 | return output 124 | -------------------------------------------------------------------------------- /networks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torchvision.transforms import * 6 | import torch.nn.functional as F 7 | 8 | from rn import RN_B, RN_L 9 | 10 | 11 | class G_Net(nn.Module): 12 | def __init__(self, input_channels, residual_blocks, threshold): 13 | super(G_Net, self).__init__() 14 | 15 | # Encoder 16 | self.encoder_prePad = nn.ReflectionPad2d(3) 17 | self.encoder_conv1 = nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=7, padding=0) 18 | self.encoder_in1 = RN_B(feature_channels=64) 19 | self.encoder_relu1 = nn.ReLU(True) 20 | self.encoder_conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1) 21 | self.encoder_in2 = RN_B(feature_channels=128) 22 | self.encoder_relu2 = nn.ReLU(True) 23 | self.encoder_conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1) 24 | self.encoder_in3 = RN_B(feature_channels=256) 25 | self.encoder_relu3 = nn.ReLU(True) 26 | 27 | 28 | # Middle 29 | blocks = [] 30 | for _ in range(residual_blocks): 31 | # block = ResnetBlock(256, 2, use_spectral_norm=False) 32 | block = saRN_ResnetBlock(256, dilation=2, threshold=threshold, use_spectral_norm=False) 33 | blocks.append(block) 34 | 35 | self.middle = nn.Sequential(*blocks) 36 | 37 | 38 | # Decoder 39 | self.decoder = nn.Sequential( 40 | nn.Conv2d(256, 128*4, kernel_size=3, stride=1, padding=1), 41 | nn.PixelShuffle(2), 42 | RN_L(128), 43 | nn.ReLU(True), 44 | 45 | nn.Conv2d(128, 64*4, kernel_size=3, stride=1, padding=1), 46 | nn.PixelShuffle(2), 47 | RN_L(64), 48 | nn.ReLU(True), 49 | 50 | nn.ReflectionPad2d(3), 51 | nn.Conv2d(in_channels=64, out_channels=input_channels, kernel_size=7, padding=0) 52 | 53 | ) 54 | 55 | 56 | def encoder(self, x, mask): 57 | x = self.encoder_prePad(x) 58 | 59 | x = self.encoder_conv1(x) 60 | x = self.encoder_in1(x, mask) 61 | x = self.encoder_relu2(x) 62 | 63 | x = self.encoder_conv2(x) 64 | x = self.encoder_in2(x, mask) 65 | x = self.encoder_relu2(x) 66 | 67 | x = self.encoder_conv3(x) 68 | x = self.encoder_in3(x, mask) 69 | x = self.encoder_relu3(x) 70 | return x 71 | 72 | def forward(self, x, mask): 73 | gt = x 74 | x = (x * (1 - mask).float()) + mask 75 | # input mask: 1 for hole, 0 for valid 76 | x = self.encoder(x, mask) 77 | 78 | x = self.middle(x) 79 | 80 | x = self.decoder(x) 81 | 82 | x = (torch.tanh(x) + 1) / 2 83 | # x = x*mask+gt*(1-mask) 84 | return x 85 | 86 | 87 | # original D 88 | class D_Net(nn.Module): 89 | def __init__(self, in_channels, use_sigmoid=True, use_spectral_norm=True): 90 | super(D_Net, self).__init__() 91 | self.use_sigmoid = use_sigmoid 92 | 93 | self.conv1 = nn.Sequential( 94 | spectral_norm(nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 95 | nn.LeakyReLU(0.2, inplace=True), 96 | ) 97 | 98 | self.conv2 = nn.Sequential( 99 | spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 100 | nn.LeakyReLU(0.2, inplace=True), 101 | ) 102 | 103 | self.conv3 = nn.Sequential( 104 | spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias=not use_spectral_norm), use_spectral_norm), 105 | nn.LeakyReLU(0.2, inplace=True), 106 | ) 107 | 108 | self.conv4 = nn.Sequential( 109 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 110 | nn.LeakyReLU(0.2, inplace=True), 111 | ) 112 | 113 | self.conv5 = nn.Sequential( 114 | spectral_norm(nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=1, bias=not use_spectral_norm), use_spectral_norm), 115 | ) 116 | 117 | 118 | def forward(self, x): 119 | conv1 = self.conv1(x) 120 | conv2 = self.conv2(conv1) 121 | conv3 = self.conv3(conv2) 122 | conv4 = self.conv4(conv3) 123 | conv5 = self.conv5(conv4) 124 | 125 | outputs = conv5 126 | if self.use_sigmoid: 127 | outputs = torch.sigmoid(conv5) 128 | 129 | return outputs, [conv1, conv2, conv3, conv4, conv5] 130 | 131 | 132 | 133 | 134 | 135 | class ResnetBlock(nn.Module): 136 | def __init__(self, dim, dilation=1, use_spectral_norm=True): 137 | super(ResnetBlock, self).__init__() 138 | self.conv_block = nn.Sequential( 139 | nn.ReflectionPad2d(dilation), 140 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 141 | nn.InstanceNorm2d(256, track_running_stats=False), 142 | nn.ReLU(True), 143 | 144 | nn.ReflectionPad2d(1), 145 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 146 | nn.InstanceNorm2d(dim, track_running_stats=False), 147 | ) 148 | 149 | def forward(self, x): 150 | out = x + self.conv_block(x) 151 | 152 | # Remove ReLU at the end of the residual block 153 | # http://torch.ch/blog/2016/02/04/resnets.html 154 | 155 | return out 156 | 157 | class saRN_ResnetBlock(nn.Module): 158 | def __init__(self, dim, dilation, threshold, use_spectral_norm=True): 159 | super(saRN_ResnetBlock, self).__init__() 160 | self.conv_block = nn.Sequential( 161 | nn.ReflectionPad2d(dilation), 162 | spectral_norm(nn.Conv2d(in_channels=dim, out_channels=256, kernel_size=3, padding=0, dilation=dilation, bias=not use_spectral_norm), use_spectral_norm), 163 | # nn.InstanceNorm2d(256, track_running_stats=False), 164 | RN_L(feature_channels=256, threshold=threshold), 165 | nn.ReLU(True), 166 | 167 | nn.ReflectionPad2d(1), 168 | spectral_norm(nn.Conv2d(in_channels=256, out_channels=dim, kernel_size=3, padding=0, dilation=1, bias=not use_spectral_norm), use_spectral_norm), 169 | # nn.InstanceNorm2d(dim, track_running_stats=False), 170 | RN_L(feature_channels=dim, threshold=threshold), 171 | ) 172 | 173 | def forward(self, x): 174 | out = x + self.conv_block(x) 175 | # skimage.io.imsave('block.png', out[0].detach().permute(1,2,0).cpu().numpy()[:,:,0]) 176 | 177 | # Remove ReLU at the end of the residual block 178 | # http://torch.ch/blog/2016/02/04/resnets.html 179 | 180 | return out 181 | 182 | def spectral_norm(module, mode=True): 183 | if mode: 184 | return nn.utils.spectral_norm(module) 185 | 186 | return module 187 | 188 | 189 | if __name__ == '__main__': 190 | print("No Abnormal!") 191 | -------------------------------------------------------------------------------- /pretrained_model/x_admin.cluster.localRN-0.8RN-Net_bs_14_epoch_3.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geekyutao/RN/b0a9476bd78d0681b572aebb881766c2072f94dc/pretrained_model/x_admin.cluster.localRN-0.8RN-Net_bs_14_epoch_3.pth -------------------------------------------------------------------------------- /rn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RN_binarylabel(nn.Module): 6 | def __init__(self, feature_channels): 7 | super(RN_binarylabel, self).__init__() 8 | self.bn_norm = nn.BatchNorm2d(feature_channels, affine=False, track_running_stats=False) 9 | 10 | def forward(self, x, label): 11 | ''' 12 | input: x: (B,C,M,N), features 13 | label: (B,1,M,N), 1 for foreground regions, 0 for background regions 14 | output: _x: (B,C,M,N) 15 | ''' 16 | label = label.detach() 17 | 18 | rn_foreground_region = self.rn(x * label, label) 19 | 20 | rn_background_region = self.rn(x * (1 - label), 1 - label) 21 | 22 | return rn_foreground_region + rn_background_region 23 | 24 | def rn(self, region, mask): 25 | ''' 26 | input: region: (B,C,M,N), 0 for surroundings 27 | mask: (B,1,M,N), 1 for target region, 0 for surroundings 28 | output: rn_region: (B,C,M,N) 29 | ''' 30 | shape = region.size() 31 | 32 | sum = torch.sum(region, dim=[0,2,3]) # (B, C) -> (C) 33 | Sr = torch.sum(mask, dim=[0,2,3]) # (B, 1) -> (1) 34 | Sr[Sr==0] = 1 35 | mu = (sum / Sr) # (B, C) -> (C) 36 | 37 | return self.bn_norm(region + (1 - mask) * mu[None,:,None,None]) * \ 38 | (torch.sqrt(Sr / (shape[0] * shape[2] * shape[3])))[None,:,None,None] 39 | 40 | class RN_B(nn.Module): 41 | def __init__(self, feature_channels): 42 | super(RN_B, self).__init__() 43 | ''' 44 | input: tensor(features) x: (B,C,M,N) 45 | condition Mask: (B,1,H,W): 0 for background, 1 for foreground 46 | return: tensor RN_B(x): (N,C,M,N) 47 | --------------------------------------- 48 | args: 49 | feature_channels: C 50 | ''' 51 | # RN 52 | self.rn = RN_binarylabel(feature_channels) # need no external parameters 53 | 54 | # gamma and beta 55 | self.foreground_gamma = nn.Parameter(torch.zeros(feature_channels), requires_grad=True) 56 | self.foreground_beta = nn.Parameter(torch.zeros(feature_channels), requires_grad=True) 57 | self.background_gamma = nn.Parameter(torch.zeros(feature_channels), requires_grad=True) 58 | self.background_beta = nn.Parameter(torch.zeros(feature_channels), requires_grad=True) 59 | 60 | def forward(self, x, mask): 61 | # mask = F.adaptive_max_pool2d(mask, output_size=x.size()[2:]) 62 | mask = F.interpolate(mask, size=x.size()[2:], mode='nearest') # after down-sampling, there can be all-zero mask 63 | 64 | rn_x = self.rn(x, mask) 65 | 66 | rn_x_foreground = (rn_x * mask) * (1 + self.foreground_gamma[None,:,None,None]) + self.foreground_beta[None,:,None,None] 67 | rn_x_background = (rn_x * (1 - mask)) * (1 + self.background_gamma[None,:,None,None]) + self.background_beta[None,:,None,None] 68 | 69 | return rn_x_foreground + rn_x_background 70 | 71 | class SelfAware_Affine(nn.Module): 72 | def __init__(self, kernel_size=7): 73 | super(SelfAware_Affine, self).__init__() 74 | 75 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 76 | padding = 3 if kernel_size == 7 else 1 77 | 78 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 79 | self.sigmoid = nn.Sigmoid() 80 | 81 | self.gamma_conv = nn.Conv2d(1, 1, kernel_size, padding=padding) 82 | self.beta_conv = nn.Conv2d(1, 1, kernel_size, padding=padding) 83 | 84 | def forward(self, x): 85 | avg_out = torch.mean(x, dim=1, keepdim=True) 86 | max_out, _ = torch.max(x, dim=1, keepdim=True) 87 | x = torch.cat([avg_out, max_out], dim=1) 88 | 89 | x = self.conv1(x) 90 | importance_map = self.sigmoid(x) 91 | 92 | gamma = self.gamma_conv(importance_map) 93 | beta = self.beta_conv(importance_map) 94 | 95 | return importance_map, gamma, beta 96 | 97 | class RN_L(nn.Module): 98 | def __init__(self, feature_channels, threshold=0.8): 99 | super(RN_L, self).__init__() 100 | ''' 101 | input: tensor(features) x: (B,C,M,N) 102 | return: tensor RN_L(x): (B,C,M,N) 103 | --------------------------------------- 104 | args: 105 | feature_channels: C 106 | ''' 107 | # SelfAware_Affine 108 | self.sa = SelfAware_Affine() 109 | self.threshold = threshold 110 | 111 | # RN 112 | self.rn = RN_binarylabel(feature_channels) # need no external parameters 113 | 114 | 115 | def forward(self, x): 116 | 117 | sa_map, gamma, beta = self.sa(x) # (B,1,M,N) 118 | 119 | # m = sa_map.detach() 120 | if x.is_cuda: 121 | mask = torch.zeros_like(sa_map).cuda() 122 | else: 123 | mask = torch.zeros_like(sa_map) 124 | mask[sa_map.detach() >= self.threshold] = 1 125 | 126 | rn_x = self.rn(x, mask.expand(x.size())) 127 | 128 | rn_x = rn_x * (1 + gamma) + beta 129 | 130 | return rn_x 131 | --------------------------------------------------------------------------------