├── loss ├── __init__.py ├── hybridloss.py ├── multiCE.py ├── iou.py └── ssim.py ├── models ├── __init__.py ├── vgg.py ├── decoder.py ├── deblurnet.py ├── fpn.py └── resnet.py ├── README.md ├── cfg.py.example ├── test.py ├── utils.py ├── DataCreator ├── Extract_Object.py ├── DataGenerator.py └── Blend.py ├── .gitignore ├── blur_merge.py ├── MTransform.py ├── predict.py ├── DataSetLoader └── MDataSet.py └── train.py /loss/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Defocus-and-Motion-Blur-Detection-with-Deep-Contextual-Features 2 | An implementation of the method in *Defocus and Motion Blur Detection with Deep Contextual Features* with pytorch. 3 | -------------------------------------------------------------------------------- /cfg.py.example: -------------------------------------------------------------------------------- 1 | Configs = { 2 | "train_image_dir": "", 3 | "train_mask_dir": "", 4 | "test_image_dir": "", 5 | "test_mask_dir": "", 6 | "encoder_learning_rate": 0.0001, 7 | "decoder_lr_scale": 10, 8 | "train_batch_size": 10, 9 | "test_batch_size": 1, 10 | "epoch": 1000, 11 | "model_save_path": "", 12 | "pre_path": "", 13 | 'checkpoint_dir': "", 14 | "cross_entropy_weights": [], 15 | "device_ids": [0], 16 | "augmentation": True, 17 | "l_bce":1, 18 | "l_ssim":1, 19 | "l_IoU":1, 20 | "encoder_type": "vgg19", 21 | "skip_out_channel":128, 22 | "fpn":True, 23 | 'fpn_out': 2 24 | } 25 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | from DataSetLoader.MDataSet import BlurDataSet 3 | from torch.utils.data import DataLoader 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import os 8 | from models.fpn import FPN 9 | from models.resnet import ResNet34,ResNet50 10 | from cfg import Configs 11 | 12 | if __name__ == '__main__': 13 | dataset = BlurDataSet(Configs["test_image_dir"],Configs["test_mask_dir"],True) 14 | 15 | for i,data in enumerate(dataset): 16 | image,targets = data 17 | img = np.transpose(image.numpy(),(1,2,0)) 18 | cv2.imshow("img",img) 19 | target = np.transpose(targets[0].numpy(),(1,2,0))*120 20 | cv2.imshow("target",target) 21 | cv2.waitKey(0) 22 | exit(0) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from models.vgg import VGG19_bn, VGG19 2 | from models.resnet import ResNet152, ResNet34, ResNet50 3 | 4 | layer_channels = { 5 | "vgg19": [512, 512, 256, 128, 64], 6 | "vgg19_bn": [512, 512, 256, 128, 64], 7 | "resnet34": [512, 512, 256, 128, 64], 8 | "resnet152": [2048, 2048, 1024, 512, 256], 9 | "resnet50": [2048, 2048, 1024, 512, 256], 10 | } 11 | 12 | 13 | def get_encoder(config): 14 | if config['encoder_type'] == "vgg19": 15 | encoder = VGG19() 16 | elif config['encoder_type'] == "vgg19_bn": 17 | encoder = VGG19_bn() 18 | elif config['encoder_type'] == "resnet152": 19 | encoder = ResNet152() 20 | elif config['encoder_type'] == "resnet34": 21 | encoder = ResNet34() 22 | elif 
config['encoder_type'] == "resnet50": 23 | encoder = ResNet50() 24 | else: 25 | raise RuntimeError("invalid encoder type") 26 | 27 | return encoder 28 | 29 | def get_fpn_skip(config): 30 | return layer_channels[config['encoder_type']][::-1] 31 | -------------------------------------------------------------------------------- /DataCreator/Extract_Object.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | 5 | 6 | def object_extract(mask,img): 7 | 8 | extracted = np.copy(img) 9 | con = (mask == 0) 10 | extracted[con] = 0 11 | # cv2.normalize(extracted, extracted, 0, 255, cv2.NORM_MINMAX) 12 | return extracted.astype(np.uint8) 13 | 14 | 15 | def extract(image_dir, mask_dir, out_dir): 16 | for _,_,file_names in os.walk(image_dir): 17 | for fileName in file_names: 18 | img_path = os.path.join(image_dir, fileName) 19 | mask_path = os.path.join(mask_dir, fileName) 20 | img = cv2.imread(img_path) 21 | mask = cv2.imread(mask_path) 22 | out_path = os.path.join(out_dir, fileName) 23 | out = object_extract(mask,img) 24 | cv2.imwrite(out_path, out) 25 | print(fileName + ' extract complete') 26 | 27 | if __name__ == '__main__': 28 | p1 = 'C:\\Users\\Whale\\Documents\\LAB\\DataSet\\HKU-IS\\imgs' 29 | p2 = 'C:\\Users\\Whale\\Documents\\LAB\\DataSet\\HKU-IS\\gt' 30 | p3 = 'C:\\Users\\Whale\\Documents\\LAB\\DataSet\\HKU-IS\\objects' 31 | extract(p1, p2, p3) -------------------------------------------------------------------------------- /loss/hybridloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from loss.iou import IOU,MultiIOU 3 | from loss.ssim import SSIM 4 | from loss.multiCE import MultiCrossEntropyLoss 5 | import torch.nn.functional as F 6 | 7 | class HybridLoss(torch.nn.Module): 8 | def __init__(self, ce_weight, ssim_weight, iou_weight): 9 | super(HybridLoss, self).__init__() 10 | self.ce = MultiCrossEntropyLoss() 11 | self.ssim_loss = SSIM(window_size=11, size_average=True) 12 | self.iou_loss = MultiIOU(size_average=True) 13 | self.ce_w = ce_weight 14 | self.ssim_w = ssim_weight 15 | self.iou_w = iou_weight 16 | 17 | def forward(self, output, target): 18 | ce_loss = self.ce(output, target) 19 | onehot_target = [torch.zeros(output[i].shape).cuda(output[i].get_device()).scatter_(1, target[i].unsqueeze(1), 1) for i in range(min(len(target),len(output)))] 20 | output_prob = [F.softmax(i, dim=1) for i in output] 21 | iou_loss = self.iou_loss(output_prob, onehot_target) 22 | ssim_loss = self.ssim_loss(output_prob[0], onehot_target[0]) 23 | return self.ce_w*ce_loss + self.ssim_w*ssim_loss + self.iou_w*iou_loss, ce_loss, ssim_loss, iou_loss -------------------------------------------------------------------------------- /loss/multiCE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from cfg import Configs 3 | import torch.nn.functional as F 4 | 5 | 6 | class MultiCrossEntropyLoss(torch.nn.Module): 7 | def __init__(self): 8 | super(MultiCrossEntropyLoss, self).__init__() 9 | 10 | def forward(self, output, target): 11 | level_num = min(len(output),len(target)) 12 | total_loss = 0 13 | for i in range(level_num): 14 | level_loss = F.cross_entropy(output[i], target[i], 15 | weight=torch.FloatTensor(Configs["cross_entropy_weights"]).cuda(output[0].get_device()))/ (4 ** (level_num -1 - i)) 16 | total_loss += level_loss 17 | # loss_1 = F.cross_entropy(output[1], target[1], 18 | # 
weight=torch.FloatTensor(Configs["cross_entropy_weights"]).cuda(output[1].get_device())) / 16 19 | # loss_2 = F.cross_entropy(output[2], target[2], 20 | # weight=torch.FloatTensor(Configs["cross_entropy_weights"]).cuda(output[2].get_device())) / 4 21 | # loss_3 = F.cross_entropy(output[3], target[3], 22 | # weight=torch.FloatTensor(Configs["cross_entropy_weights"]).cuda(output[3].get_device())) 23 | # total_loss = loss_0 + loss_1 + loss_2 + loss_3 24 | return total_loss -------------------------------------------------------------------------------- /models/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models.vgg import vgg19_bn,vgg19 3 | 4 | 5 | class VGG19(torch.nn.Module): 6 | def __init__(self): 7 | super(VGG19,self).__init__() 8 | self.vgg = vgg19(pretrained=True) 9 | self.encoder_1 = self.vgg.features[0:4] 10 | self.encoder_2 = self.vgg.features[4:9] 11 | self.encoder_3 = self.vgg.features[9:18] 12 | self.encoder_4 = self.vgg.features[18:27] 13 | self.encoder_5 = self.vgg.features[27:36] 14 | 15 | def forward(self,x): 16 | skip_1 = self.encoder_1(x) 17 | skip_2 = self.encoder_2(skip_1) 18 | skip_3 = self.encoder_3(skip_2) 19 | skip_4 = self.encoder_4(skip_3) 20 | x = self.encoder_5(skip_4) 21 | return skip_1, skip_2, skip_3, skip_4, x 22 | 23 | 24 | class VGG19_bn(torch.nn.Module): 25 | def __init__(self): 26 | super(VGG19_bn,self).__init__() 27 | self.vgg = vgg19_bn(pretrained=True) 28 | 29 | self.encoder_1 = self.vgg.features[0:6] 30 | self.encoder_2 = self.vgg.features[6:13] 31 | self.encoder_3 = self.vgg.features[13:26] 32 | self.encoder_4 = self.vgg.features[26:39] 33 | self.encoder_5 = self.vgg.features[39:-1] 34 | 35 | def forward(self,x): 36 | skip_1 = self.encoder_1(x) 37 | skip_2 = self.encoder_2(skip_1) 38 | skip_3 = self.encoder_3(skip_2) 39 | skip_4 = self.encoder_4(skip_3) 40 | x = self.encoder_5(skip_4) 41 | return skip_1, skip_2, skip_3, skip_4, x 42 | -------------------------------------------------------------------------------- /loss/iou.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | 6 | def _iou(pred, target, size_average = True): 7 | 8 | b = pred.shape[0] 9 | IoU = 0.0 10 | for i in range(0,b): 11 | #compute the IoU of the foreground 12 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 13 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 14 | IoU1 = Iand1/Ior1 15 | 16 | #IoU loss is (1-IoU1) 17 | IoU = IoU + (1-IoU1) 18 | 19 | return IoU/b 20 | 21 | class IOU(torch.nn.Module): 22 | def __init__(self, size_average = True): 23 | super(IOU, self).__init__() 24 | self.size_average = size_average 25 | 26 | def forward(self, pred, target): 27 | return _iou(pred, target, self.size_average) 28 | 29 | class MultiIOU(torch.nn.Module): 30 | def __init__(self, size_average = True): 31 | super(MultiIOU, self).__init__() 32 | self.size_average = size_average 33 | 34 | 35 | def forward(self, output, target): 36 | level_num = min(len(output),len(target)) 37 | total_loss = 0 38 | for i in range(level_num): 39 | level_loss = _iou(output[i], target[i],self.size_average)/(4 ** (level_num - 1 - i)) 40 | total_loss += level_loss 41 | # loss_0 = _iou(output[0], target[0],self.size_average)/64 42 | # loss_1 = _iou(output[1], target[1],self.size_average)/16 43 | # loss_2 = _iou(output[2], target[2],self.size_average)/4 44 | # loss_3 
= _iou(output[3], target[3],self.size_average) 45 | 46 | # total_loss = loss_0 + loss_1 + loss_2 + loss_3 47 | return total_loss 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea/ 2 | submit/ 3 | 4 | cfg.py 5 | 6 | # model files 7 | *.pth 8 | 9 | runs/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /blur_merge.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import argparse 3 | from glob import glob 4 | import os 5 | import numpy as np 6 | 7 | def merge_class(img, map): 8 | class_num = 0 9 | for i in map.keys(): 10 | img[img == i] = map[i] 11 | if class_num < map[i]: 12 | class_num = map[i] 13 | return img, img*(255//(class_num)) 14 | 15 | if __name__ == '__main__': 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--prediction_dir", '-p', type=str, required=True) 18 | parser.add_argument("--target_dir", '-t', type=str, required=True) 19 | parser.add_argument("--merge_map", '-m', type=str, required=True, nargs='+') 20 | parser.add_argument("--save_dir",'-s',type=str, default='./submit') 21 | args = parser.parse_args() 22 | map = {} 23 | for i in args.merge_map: 24 | a_cls,b_cls = int(str.split(i,'-')[0]), int(str.split(i,'-')[1]) 25 | map[a_cls] = b_cls 26 | tar_paths = glob(os.path.join(args.target_dir, "*")) 27 | img_paths = glob(os.path.join(args.prediction_dir, "*")) 28 | save_gt_dir = os.path.join(args.save_dir,'merge_gt') 29 | save_pred_dir = os.path.join(args.save_dir,'merge_predict') 30 | save_view_dir = os.path.join(args.save_dir,'merge_view') 31 | os.makedirs(save_gt_dir, exist_ok=True) 32 | os.makedirs(save_pred_dir, exist_ok=True) 33 | 
os.makedirs(save_view_dir, exist_ok=True) 34 | 35 | view_preds = [] 36 | view_tars = [] 37 | 38 | for i in tar_paths: 39 | img = cv2.imread(i) 40 | image_name = os.path.basename(i) 41 | img,view = merge_class(img,map) 42 | img = img[:,:,0] 43 | view = view[:,:,0] 44 | view_tars.append(view) 45 | cv2.imwrite(os.path.join(save_gt_dir,image_name),img) 46 | 47 | for i in img_paths: 48 | img = cv2.imread(i) 49 | image_name = os.path.basename(i) 50 | img, view = merge_class(img, map) 51 | img = img[:, :, 0] 52 | view = view[:, :, 0] 53 | view_preds.append(view) 54 | cv2.imwrite(os.path.join(save_pred_dir,image_name),img) 55 | 56 | for i in zip(view_tars,view_preds,img_paths): 57 | view = np.vstack((i[0],i[1])) 58 | image_name = os.path.basename(i[2]) 59 | cv2.imwrite(os.path.join(save_view_dir,image_name),view) -------------------------------------------------------------------------------- /DataCreator/DataGenerator.py: -------------------------------------------------------------------------------- 1 | from DataCreator.Blend import * 2 | import numpy as np 3 | import os 4 | import cv2 5 | 6 | class Sampler: 7 | def __init__(self, defocus_img_dir, defocus_mask_dir, filte=None): 8 | self.image_path_list = [] 9 | self.mask_path_list = [] 10 | for _, _, file_names in os.walk(defocus_img_dir): 11 | for file in file_names: 12 | if filte is not None and not filte(file): 13 | continue 14 | self.image_path_list.append(os.path.join(defocus_img_dir,file)) 15 | for _, _, file_names in os.walk(defocus_mask_dir): 16 | for file in file_names: 17 | if filte is not None and not filte(file): 18 | continue 19 | self.mask_path_list.append(os.path.join(defocus_mask_dir, file)) 20 | 21 | def sample(self, num): 22 | index = np.array(range(len(self.image_path_list))) 23 | samples_index = np.random.choice(index,num,replace=False).tolist() 24 | samples_path = [] 25 | for i in samples_index: 26 | samples_path.append((self.image_path_list[i],self.mask_path_list[i])) 27 | return samples_path 28 | 29 | 30 | def defocus_filte(file_name): 31 | return file_name.find('out_of_focus', 0, len(file_name)) >= 0 32 | 33 | 34 | if __name__ == '__main__': 35 | defocus_sampler = Sampler("C:\\Users\\Whale\\Documents\\LAB\\DataSet\\CUHK\\image", 36 | "C:\\Users\\Whale\\Documents\\LAB\\DataSet\\CUHK\\gt", 37 | defocus_filte) 38 | motion_sampler = Sampler("C:\\Users\\Whale\\Documents\\LAB\\DataSet\\HKU-IS\\objects", 39 | "C:\\Users\\Whale\\Documents\\LAB\\DataSet\\HKU-IS\\gt") 40 | 41 | out_put_image_dir = "C:\\Users\\Whale\\Documents\\LAB\\DataSet\\SyncBlur\\image" 42 | out_put_mask_dir = "C:\\Users\\Whale\\Documents\\LAB\\DataSet\\SyncBlur\\gt2" 43 | 44 | count = 0 45 | d_samples = defocus_sampler.sample(num=564) 46 | for d_sample in d_samples: 47 | d_img = cv2.imread(d_sample[0]) 48 | d_mask = cv2.imread(d_sample[1]) 49 | m_samples = motion_sampler.sample(num=15) 50 | for m_sample in m_samples: 51 | m_img = cv2.imread(m_sample[0]) 52 | m_mask = cv2.imread(m_sample[1]) 53 | blended, motion_blur_mask, _ = both_blur_image_creator(np.copy(d_img), np.copy(m_img), np.copy(d_mask),np.copy(m_mask), (400, 300)) 54 | b_mask = blended_mask(np.copy(d_mask), np.copy(motion_blur_mask), (400, 300)) 55 | 56 | count += 1 57 | image_name = "{:0>5d}".format(count) + '.png' 58 | mask_name = "{:0>5d}".format(count) + '.png' 59 | cv2.imwrite(os.path.join(out_put_image_dir, image_name), blended) 60 | cv2.imwrite(os.path.join(out_put_mask_dir, mask_name), b_mask) 61 | 62 | print(count) 63 | print(d_sample[0]) 64 | print(m_sample[0]) 65 | 66 | 67 | pass 
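# Editorial note, not part of the original file: Sampler pairs image_path_list[i]
# with mask_path_list[i] purely by the order in which the two os.walk() calls
# enumerate files, so differing directory orderings would silently mismatch images
# and masks. A minimal safeguard, assuming masks share file names with their images
# (as Extract_Object.py produces, and as BlurDataSet already does by sorting its
# lists), would be to sort both lists at the end of __init__:
#
#     self.image_path_list.sort()
#     self.mask_path_list.sort()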
-------------------------------------------------------------------------------- /DataCreator/Blend.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from DataCreator.DataGenerator import * 3 | import random 4 | 5 | def motion_blur(image,blur_kernel): 6 | image = np.array(image) 7 | blurred = cv2.filter2D(image, -1, blur_kernel) 8 | # convert to uint8 9 | cv2.normalize(blurred, blurred, 0, 255, cv2.NORM_MINMAX) 10 | blurred = np.array(blurred, dtype=np.uint8) 11 | return blurred 12 | 13 | 14 | def motion_blur_kernel(degree, angle): 15 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 16 | kernel = np.diag(np.ones(degree)) 17 | kernel = cv2.warpAffine(kernel, M, (degree, degree)) 18 | kernel = kernel / degree 19 | return kernel 20 | 21 | 22 | def normalvariate_random_int(mean, variance, dmin, dmax): 23 | r = dmax + 1 24 | while r < dmin or r > dmax: 25 | r = int(random.normalvariate(mean, variance)) 26 | return r 27 | 28 | 29 | def uniform_random_int(dmin, dmax): 30 | r = random.randint(dmin,dmax) 31 | return r 32 | 33 | 34 | def random_blur_kernel(mean=50, variance=15, dmin=10, dmax=100): 35 | random_degree = normalvariate_random_int(mean, variance, dmin, dmax) 36 | random_angle = uniform_random_int(-180, 180) 37 | return motion_blur_kernel(random_degree,random_angle) 38 | 39 | 40 | def alpha_blending(defocus,motion,alpha): 41 | f_defocus = defocus.astype("float32") 42 | f_motion = motion.astype("float32") 43 | f_blended = f_defocus*(1-alpha) + f_motion * alpha 44 | return f_blended.astype("uint8") 45 | 46 | 47 | def blended_mask(defocus_mask,motion_blur_mask,shape): 48 | defocus_mask = cv2.resize(defocus_mask,shape) 49 | motion_blur_mask = cv2.resize(motion_blur_mask,shape) 50 | mask = np.zeros(defocus_mask.shape)[:, :, 0] 51 | defocus_mask[motion_blur_mask > 0] = 128 52 | mask[defocus_mask[:, :, 0] == 0] = 2 53 | mask[defocus_mask[:, :, 1] == 128] = 1 54 | mask[defocus_mask[:, :, 2] == 255] = 0 55 | return mask 56 | 57 | 58 | def both_blur_image_creator(defocus, extract_object, defocus_mask, object_mask, shape, motion_blur_threshold=0): 59 | kernel = random_blur_kernel(15, 2) 60 | defocus = cv2.resize(defocus,shape) 61 | extract_object = cv2.resize(extract_object, shape) 62 | defocus_mask = cv2.resize(defocus_mask,shape) 63 | object_mask = cv2.resize(object_mask,shape) 64 | 65 | motion_blur_img = motion_blur(extract_object, kernel) 66 | blur_mask = motion_blur(object_mask,kernel) 67 | motion_blur_mask = np.copy(blur_mask) 68 | motion_blur_mask[blur_mask > motion_blur_threshold] = 255 69 | 70 | alpha = blur_mask / 255 71 | blended = alpha_blending(defocus, motion_blur_img, alpha) 72 | return blended, motion_blur_mask, defocus_mask 73 | 74 | 75 | if __name__ == '__main__': 76 | img = cv2.imread('./0004_img.png') 77 | object = cv2.imread('./0004_object.png') 78 | mask = cv2.imread('./0004_mask.png') 79 | defocus = cv2.imread('./defocus.jpg') 80 | defocus_mask = cv2.imread('./defocus_mask.png') 81 | blended,motion_blur,motion_blur_mask,defocus_mask = both_blur_image_creator(defocus,object,defocus_mask,mask,(400,300)) 82 | blended_mask = blended_mask(defocus_mask,motion_blur_mask,(400,300)) 83 | cv2.imshow('mask', blended_mask) 84 | cv2.imshow('blended',blended) 85 | cv2.waitKey() 86 | exit(0) -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional 
as F 3 | 4 | class Decoder(torch.nn.Module): 5 | def __init__(self, channels): 6 | super(Decoder,self).__init__() 7 | 8 | self.deconv_1_1 = torch.nn.ConvTranspose2d(channels[0], channels[0], 3, 2, 1, 1) 9 | self.deconv_1_2 = torch.nn.Sequential(torch.nn.Conv2d(channels[0], channels[0], 3, 1, 1), torch.nn.BatchNorm2d(channels[0]), torch.nn.ReLU()) 10 | self.deconv_1_3 = torch.nn.Sequential(torch.nn.Conv2d(channels[0], channels[0], 3, 1, 1), torch.nn.BatchNorm2d(channels[0]), torch.nn.ReLU()) 11 | self.deconv_1_4 = torch.nn.Sequential(torch.nn.Conv2d(channels[0], channels[0], 3, 1, 1), torch.nn.BatchNorm2d(channels[0]), torch.nn.ReLU()) 12 | 13 | self.deconv_2_1 = torch.nn.ConvTranspose2d(channels[0], channels[1], 3, 2, 1, 1) 14 | self.deconv_2_2 = torch.nn.Sequential(torch.nn.Conv2d(channels[1], channels[1], 3, 1, 1), torch.nn.BatchNorm2d(channels[1]), torch.nn.ReLU()) 15 | self.deconv_2_3 = torch.nn.Sequential(torch.nn.Conv2d(channels[1], channels[1], 3, 1, 1), torch.nn.BatchNorm2d(channels[1]), torch.nn.ReLU()) 16 | self.deconv_2_4 = torch.nn.Sequential(torch.nn.Conv2d(channels[1], channels[1], 3, 1, 1), torch.nn.BatchNorm2d(channels[1]), torch.nn.ReLU()) 17 | 18 | self.deconv_3_1 = torch.nn.ConvTranspose2d(channels[1], channels[2], 3, 2, 1, 1) 19 | self.deconv_3_2 = torch.nn.Sequential(torch.nn.Conv2d(channels[2], channels[2], 3, 1, 1), torch.nn.BatchNorm2d(channels[2]), torch.nn.ReLU()) 20 | self.deconv_3_3 = torch.nn.Sequential(torch.nn.Conv2d(channels[2], channels[2], 3, 1, 1), torch.nn.BatchNorm2d(channels[2]), torch.nn.ReLU()) 21 | self.deconv_3_4 = torch.nn.Sequential(torch.nn.Conv2d(channels[2], channels[2], 3, 1, 1), torch.nn.BatchNorm2d(channels[2]), torch.nn.ReLU()) 22 | 23 | self.deconv_4_1 = torch.nn.ConvTranspose2d(channels[2], channels[3], 3, 2, 1, 1) 24 | self.deconv_4_2 = torch.nn.Sequential(torch.nn.Conv2d(channels[3], channels[3], 3, 1, 1), torch.nn.BatchNorm2d(channels[3]), torch.nn.ReLU()) 25 | self.deconv_4_3 = torch.nn.Sequential(torch.nn.Conv2d(channels[3], channels[3], 3, 1, 1), torch.nn.BatchNorm2d(channels[3]), torch.nn.ReLU()) 26 | 27 | self.out_layer_1 = torch.nn.Conv2d(channels[3], 3, 1, 1, 0) 28 | self.out_layer_2 = torch.nn.Conv2d(channels[2], 3, 1, 1, 0) 29 | self.out_layer_3 = torch.nn.Conv2d(channels[1], 3, 1, 1, 0) 30 | self.out_layer_4 = torch.nn.Conv2d(channels[0], 3, 1, 1, 0) 31 | 32 | def forward(self, x, skip_1, skip_2, skip_3, skip_4): 33 | x = self.deconv_1_1(x) 34 | torch.add(x, skip_4) 35 | x = self.deconv_1_2(x) 36 | x = self.deconv_1_3(x) 37 | x = self.deconv_1_4(x) 38 | out_4 = self.out_layer_4(x) 39 | x = self.deconv_2_1(x) 40 | torch.add(x, skip_3) 41 | x = self.deconv_2_2(x) 42 | x = self.deconv_2_3(x) 43 | x = self.deconv_2_4(x) 44 | out_3 = self.out_layer_3(x) 45 | x = self.deconv_3_1(x) 46 | torch.add(x, skip_2) 47 | x = self.deconv_3_2(x) 48 | x = self.deconv_3_3(x) 49 | x = self.deconv_3_4(x) 50 | out_2 = self.out_layer_2(x) 51 | x = self.deconv_4_1(x) 52 | torch.add(x, skip_1) 53 | x = self.deconv_4_2(x) 54 | x = self.deconv_4_3(x) 55 | out_1 = self.out_layer_1(x) 56 | # return F.softmax(out_1, dim=1), F.softmax(out_2, dim=1), F.softmax(out_3, dim=1), F.softmax(out_4, dim=1) 57 | return out_1, out_2, out_3, out_4 58 | -------------------------------------------------------------------------------- /models/deblurnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.decoder import Decoder 3 | from models.vgg import VGG19_bn, VGG19 4 | from models.resnet import 
ResNet152, ResNet34, ResNet50 5 | import torch.optim as optim 6 | import os 7 | from utils import get_encoder 8 | 9 | layer_channels = { 10 | "vgg19": [512, 256, 128, 64], 11 | "vgg19_bn": [512, 256, 128, 64], 12 | "resnet34": [512, 256, 128, 64], 13 | "resnet152": [2048, 1024, 512, 256], 14 | "resnet50": [2048, 1024, 512, 256], 15 | } 16 | 17 | 18 | class Net(torch.nn.Module): 19 | def __init__(self, config): 20 | super(Net,self).__init__() 21 | self.encoder_type = config["encoder_type"] 22 | self.decoder = Decoder(layer_channels[self.encoder_type]) 23 | self.encoder = get_encoder(config) 24 | self.get_skip_layer() 25 | 26 | 27 | def forward(self, x): 28 | sk_1, sk_2, sk_3, sk_4, x = self.encoder(x) 29 | sk_1 = self.skip_1(sk_1) 30 | sk_2 = self.skip_2(sk_2) 31 | sk_3 = self.skip_3(sk_3) 32 | sk_4 = self.skip_4(sk_4) 33 | return self.decoder(x, sk_1, sk_2, sk_3, sk_4) 34 | 35 | def load_model(self, prep_path, save_path): 36 | if os.path.exists(save_path): 37 | print("load from saved model:" + save_path + '...') 38 | checkpoint = torch.load(save_path) 39 | self.load_state_dict(checkpoint['model_state_dict']) 40 | ech = checkpoint['epoch'] 41 | self.eval() 42 | print("load complete") 43 | return ech 44 | else: 45 | if self.load_pre_from_local(): 46 | print("load pre-parameters:" + prep_path + '...') 47 | prep = torch.load(prep_path) 48 | model_dict = self.encoder.state_dict() 49 | prep = self.parameter_rename(prep, model_dict) 50 | pre_trained_dict = {k: v for k, v in prep.items() if k in model_dict} 51 | model_dict.update(pre_trained_dict) 52 | self.encoder.load_state_dict(model_dict) 53 | print("load complete") 54 | else: 55 | print("use pretrained model", self.encoder_type," from torchvision") 56 | return 0 57 | 58 | def save_model(self, ech, save_path): 59 | torch.save({ 60 | 'epoch': ech, 61 | 'model_state_dict': self.state_dict(), 62 | }, save_path) 63 | 64 | def parameter_rename(self, org_dict, target_dict): 65 | if self.encoder_type == "vgg19" or self.encoder_type == "vgg19_bn": 66 | org_list = [] 67 | target_list = [] 68 | for k, _ in target_dict.items(): 69 | if k.find("batches") < 0: 70 | target_list.append(k) 71 | for k, _ in org_dict.items(): 72 | if k.find("batches") < 0: 73 | org_list.append(k) 74 | replace_index = range(len(target_list)) 75 | for i in replace_index: 76 | org_dict[target_list[i]] = org_dict.pop(org_list[i]) 77 | return org_dict 78 | elif self.encoder_type == "resnet152": 79 | return target_dict 80 | 81 | 82 | def get_skip_layer(self): 83 | channel = layer_channels[self.encoder_type] 84 | if channel is not None: 85 | self.skip_1 = torch.nn.Conv2d(channel[3], channel[3], 1, 1, 0) 86 | self.skip_2 = torch.nn.Conv2d(channel[2], channel[2], 1, 1, 0) 87 | self.skip_3 = torch.nn.Conv2d(channel[1], channel[1], 1, 1, 0) 88 | self.skip_4 = torch.nn.Conv2d(channel[0], channel[0], 1, 1, 0) 89 | else: 90 | raise RuntimeError("invalid encoder type") 91 | 92 | def load_pre_from_local(self): 93 | return self.encoder_type == "vgg19" or self.encoder_type == "vgg19_bn" 94 | 95 | def optimizer_by_layer(self, encoder_lr, decoder_lr): 96 | params = [ 97 | {"params": self.encoder.parameters(), "lr": encoder_lr}, 98 | {"params": self.decoder.parameters(), "lr":decoder_lr}, 99 | {"params": self.skip_1.parameters(), "lr": encoder_lr}, 100 | {"params": self.skip_2.parameters(), "lr": encoder_lr}, 101 | {"params": self.skip_3.parameters(), "lr": encoder_lr}, 102 | {"params": self.skip_4.parameters(), "lr": encoder_lr} 103 | ] 104 | return optim.Adam(params=params, lr=encoder_lr) 
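# Editorial note, not part of the original file: train.py builds the optimizer as
# optimizer_by_layer(Configs['encoder_learning_rate'], Configs['decoder_lr_scale']),
# which passes the scale factor (10 in cfg.py.example) directly as decoder_lr, so the
# decoder parameter group would train with lr=10 while the encoder uses 1e-4. If the
# intent behind the "decoder_lr_scale" key is "decoder lr = encoder lr * scale" (an
# assumption about intent, not the author's confirmed design), a sketch of the call
# site fix is:
#
#     optimizer = model.optimizer_by_layer(
#         Configs['encoder_learning_rate'],
#         Configs['encoder_learning_rate'] * Configs['decoder_lr_scale'])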
-------------------------------------------------------------------------------- /MTransform.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import torch 3 | import math 4 | import sys 5 | import random 6 | from PIL import Image 7 | try: 8 | import accimage 9 | except ImportError: 10 | accimage = None 11 | import numpy as np 12 | import numbers 13 | import types 14 | import collections 15 | import warnings 16 | 17 | from torchvision.transforms import functional as F 18 | 19 | if sys.version_info < (3, 3): 20 | Sequence = collections.Sequence 21 | Iterable = collections.Iterable 22 | else: 23 | Sequence = collections.abc.Sequence 24 | Iterable = collections.abc.Iterable 25 | 26 | 27 | class SyncRandomCrop(object): 28 | """Crop the given PIL Image at a random location. 29 | 30 | Args: 31 | size (sequence or int): Desired output size of the crop. If size is an 32 | int instead of sequence like (h, w), a square crop (size, size) is 33 | made. 34 | padding (int or sequence, optional): Optional padding on each border 35 | of the image. Default is None, i.e no padding. If a sequence of length 36 | 4 is provided, it is used to pad left, top, right, bottom borders 37 | respectively. If a sequence of length 2 is provided, it is used to 38 | pad left/right, top/bottom borders, respectively. 39 | pad_if_needed (boolean): It will pad the image if smaller than the 40 | desired size to avoid raising an exception. Since cropping is done 41 | after padding, the padding seems to be done at a random offset. 42 | fill: Pixel fill value for constant fill. Default is 0. If a tuple of 43 | length 3, it is used to fill R, G, B channels respectively. 44 | This value is only used when the padding_mode is constant 45 | padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant. 46 | 47 | - constant: pads with a constant value, this value is specified with fill 48 | 49 | - edge: pads with the last value on the edge of the image 50 | 51 | - reflect: pads with reflection of image (without repeating the last value on the edge) 52 | 53 | padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode 54 | will result in [3, 2, 1, 2, 3, 4, 3, 2] 55 | 56 | - symmetric: pads with reflection of image (repeating the last value on the edge) 57 | 58 | padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode 59 | will result in [2, 1, 1, 2, 3, 4, 4, 3] 60 | 61 | """ 62 | 63 | def __init__(self, size, padding=None, pad_if_needed=False, fill=0, padding_mode='constant'): 64 | if isinstance(size, numbers.Number): 65 | self.size = (int(size), int(size)) 66 | else: 67 | self.size = size 68 | self.padding = padding 69 | self.pad_if_needed = pad_if_needed 70 | self.fill = fill 71 | self.padding_mode = padding_mode 72 | self.rand = True 73 | 74 | def get_params(self, img, output_size): 75 | """Get parameters for ``crop`` for a random crop. 76 | 77 | Args: 78 | img (PIL Image): Image to be cropped. 79 | output_size (tuple): Expected output size of the crop. 80 | 81 | Returns: 82 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 
83 | """ 84 | w, h = img.size 85 | th, tw = output_size 86 | if w == tw and h == th: 87 | return 0, 0, h, w 88 | 89 | if self.rand: 90 | self.i = random.randint(0, h - th) 91 | self.j = random.randint(0, w - tw) 92 | 93 | return self.i, self.j, th, tw 94 | 95 | def rand_fix(self): 96 | self.rand = False 97 | 98 | def rand_active(self): 99 | self.rand = True 100 | 101 | def __call__(self, img): 102 | """ 103 | Args: 104 | img (PIL Image): Image to be cropped. 105 | 106 | Returns: 107 | PIL Image: Cropped image. 108 | """ 109 | if self.padding is not None: 110 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 111 | 112 | # pad the width if needed 113 | if self.pad_if_needed and img.size[0] < self.size[1]: 114 | img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode) 115 | # pad the height if needed 116 | if self.pad_if_needed and img.size[1] < self.size[0]: 117 | img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode) 118 | 119 | i, j, h, w = self.get_params(img, self.size) 120 | 121 | return F.crop(img, i, j, h, w) 122 | 123 | def __repr__(self): 124 | return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) -------------------------------------------------------------------------------- /loss/ssim.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | 14 | def create_window(window_size, channel): 15 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 16 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 17 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 18 | return window 19 | 20 | 21 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 22 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 23 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 24 | 25 | mu1_sq = mu1.pow(2) 26 | mu2_sq = mu2.pow(2) 27 | mu1_mu2 = mu1*mu2 28 | 29 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 30 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 31 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 32 | 33 | C1 = 0.01**2 34 | C2 = 0.03**2 35 | 36 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 37 | 38 | if size_average: 39 | return ssim_map.mean() 40 | else: 41 | return ssim_map.mean(1).mean(1).mean(1) 42 | 43 | 44 | class SSIM(torch.nn.Module): 45 | def __init__(self, window_size = 11, size_average = True): 46 | super(SSIM, self).__init__() 47 | self.window_size = window_size 48 | self.size_average = size_average 49 | self.channel = 1 50 | self.window = create_window(window_size, self.channel) 51 | 52 | def forward(self, img1, img2): 53 | (_, channel, _, _) = img1.size() 54 | 55 | if channel == self.channel and self.window.data.type() == img1.data.type(): 56 | window = self.window 57 | else: 58 | window = create_window(self.window_size, channel) 59 
| 60 | if img1.is_cuda: 61 | window = window.cuda(img1.get_device()) 62 | window = window.type_as(img1) 63 | 64 | self.window = window 65 | self.channel = channel 66 | 67 | return 1 - _ssim(img1, img2, window, self.window_size, channel, self.size_average) 68 | 69 | 70 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 71 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 72 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 73 | 74 | mu1_sq = mu1.pow(2) 75 | mu2_sq = mu2.pow(2) 76 | mu1_mu2 = mu1*mu2 77 | 78 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 79 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 80 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 81 | 82 | C1 = 0.01**2 83 | C2 = 0.03**2 84 | 85 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 86 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 87 | ssim_map = -torch.log(ssim_map + 1e-8) 88 | 89 | if size_average: 90 | return ssim_map.mean() 91 | else: 92 | return ssim_map.mean(1).mean(1).mean(1) 93 | 94 | 95 | class LOGSSIM(torch.nn.Module): 96 | def __init__(self, window_size = 11, size_average = True): 97 | super(LOGSSIM, self).__init__() 98 | self.window_size = window_size 99 | self.size_average = size_average 100 | self.channel = 1 101 | self.window = create_window(window_size, self.channel) 102 | 103 | def forward(self, img1, img2): 104 | (_, channel, _, _) = img1.size() 105 | 106 | if channel == self.channel and self.window.data.type() == img1.data.type(): 107 | window = self.window 108 | else: 109 | window = create_window(self.window_size, channel) 110 | 111 | if img1.is_cuda: 112 | window = window.cuda(img1.get_device()) 113 | window = window.type_as(img1) 114 | 115 | self.window = window 116 | self.channel = channel 117 | 118 | 119 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 120 | 121 | 122 | def ssim(img1, img2, window_size = 11, size_average = True): 123 | (_, channel, _, _) = img1.size() 124 | window = create_window(window_size, channel) 125 | 126 | if img1.is_cuda: 127 | window = window.cuda(img1.get_device()) 128 | window = window.type_as(img1) 129 | 130 | return _ssim(img1, img2, window, window_size, channel, size_average) 131 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from models.deblurnet import Net 2 | from models.fpn import FPN 3 | import torch 4 | import numpy as np 5 | import cv2 6 | from glob import glob 7 | import torchvision.transforms as transforms 8 | from PIL import Image 9 | import os 10 | import tqdm 11 | import argparse 12 | from cfg import Configs 13 | from torch.nn import DataParallel 14 | import albumentations as albu 15 | 16 | 17 | class Predictor: 18 | def __init__(self, weight_path, config): 19 | assert weight_path != "" 20 | if config["fpn"]: 21 | net = FPN.fromConfig(config) 22 | else: 23 | net = Net(config) 24 | net = DataParallel(net) 25 | params = torch.load(weight_path) 26 | state_dict = params['model_state_dict'] 27 | net.load_state_dict(state_dict) 28 | net.eval() 29 | self.model = net.cuda() 30 | self.trans = transforms.Compose([ 31 | transforms.ToTensor()]) 32 | 33 | def paddingIfNeed(self,img): 34 | img = 
np.array(img) 35 | height, width, _ = img.shape 36 | padded_height, padded_width,_ = img.shape 37 | if padded_height % 32 != 0: 38 | padded_height = (padded_height // 32 + 1) * 32 39 | if padded_width % 32 != 0: 40 | padded_width = (padded_width // 32 + 1) * 32 41 | pad = albu.PadIfNeeded(padded_height,padded_width) 42 | crop = albu.CenterCrop(height,width) 43 | img = pad(image=img)["image"] 44 | return img, crop 45 | 46 | def predict(self, inp: str, target: str = None, merge_img=False): 47 | assert os.path.exists(inp) 48 | 49 | img = Image.open(inp, 'r') 50 | img_resize, crop = self.paddingIfNeed(img) 51 | output, output_view = self.predict_(img_resize) 52 | img_resize = crop(image=img_resize)["image"] 53 | output = crop(image=output)["image"] 54 | output_view = crop(image=output_view)["image"] 55 | if target: 56 | tar = cv2.imread(target) 57 | tar_resize = tar 58 | tar_resize_view = tar_resize* 127 59 | return np.hstack((img_resize, tar_resize_view, output_view)) if merge_img else output_view, output, tar_resize 60 | else: 61 | return np.hstack(img_resize, output_view) if merge_img else output_view, output, None 62 | 63 | def predict_(self, inp: np.array): 64 | org = np.array(inp) 65 | inp = self.trans(inp) 66 | inp = torch.unsqueeze(inp, 0).cuda() 67 | with torch.no_grad(): 68 | output = self.model(inp)[0].cpu().detach().numpy()[0] 69 | output = np.argmax(output, axis=0)[:, :, np.newaxis].repeat(3, 2) 70 | return output.astype(np.uint8)[:, :, 0], output.astype(np.uint8)*127 71 | 72 | def predict_dir(self, img_path, tar_path=None, out_dir="./submit2", merge_img=False): 73 | img_paths = glob(os.path.join(img_path, "*")) 74 | 75 | os.makedirs(os.path.join(out_dir, 'view'),exist_ok=True) 76 | os.makedirs(os.path.join(out_dir, 'gt'),exist_ok=True) 77 | os.makedirs(os.path.join(out_dir, 'predict'),exist_ok=True) 78 | if tar_path: 79 | tar_paths = glob(os.path.join(tar_path, "*")) 80 | assert len(img_paths) == len(tar_paths) 81 | 82 | img_paths.sort() 83 | tar_paths.sort() 84 | bar = tqdm.tqdm(zip(img_paths, tar_paths), total=len(img_paths)) 85 | 86 | try: 87 | for img, tar in bar: 88 | # print(img,tar) 89 | view, output, target = self.predict(img, tar, merge_img=merge_img) 90 | 91 | cv2.imwrite(os.path.join(out_dir, 'view', os.path.basename(tar)), view) 92 | cv2.imwrite(os.path.join(out_dir, 'predict', os.path.basename(tar)), output) 93 | cv2.imwrite(os.path.join(out_dir, 'gt', os.path.basename(tar)), target) 94 | except KeyboardInterrupt: 95 | bar.close() 96 | raise 97 | bar.close() 98 | else: 99 | bar = tqdm.tqdm(img_paths, total=len(img_paths)) 100 | try: 101 | for img in bar: 102 | view, output, _ = self.predict(img, merge_img=merge_img) 103 | cv2.imwrite(os.path.join(out_dir, 'view', os.path.basename(tar_path)), view) 104 | cv2.imwrite(os.path.join(out_dir, 'predict', os.path.basename(tar_path)), output) 105 | except KeyboardInterrupt: 106 | bar.close() 107 | raise 108 | bar.close() 109 | 110 | 111 | if __name__ == '__main__': 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("-o", "--out_dir", type=str, default="./submit") 114 | parser.add_argument("-m", "--merge_img", type=bool, default=True) 115 | parser.add_argument("-w", "--weight_path", type=str, default=Configs["model_save_path"]) 116 | args = parser.parse_args() 117 | predictor = Predictor(args.weight_path, Configs) 118 | predictor.predict_dir(Configs["test_image_dir"], 119 | Configs["test_mask_dir"], 120 | out_dir=args.out_dir, 121 | merge_img=args.merge_img) 122 | # predictor = Predictor("mnet.pth") 123 | # 
predictor.predict_dir(img_path='C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\test_image', 124 | # tar_path='C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\test_gt', 125 | # merge_img=True) 126 | -------------------------------------------------------------------------------- /DataSetLoader/MDataSet.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset,DataLoader 2 | from torchvision import transforms 3 | from MTransform import SyncRandomCrop 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | import torch 9 | from cfg import Configs 10 | import albumentations as albu 11 | 12 | 13 | def get_transforms(crop_size: int, size: int): 14 | pipeline = albu.Compose([albu.RandomCrop(crop_size, crop_size, always_apply=True), albu.Resize(size, size), 15 | albu.VerticalFlip(), albu.RandomRotate90(always_apply=True)], 16 | additional_targets={'target': 'image'}) 17 | 18 | def process(a, b): 19 | r = pipeline(image=a, target=b) 20 | return r['image'], r['target'] 21 | 22 | return process 23 | 24 | 25 | class BlurDataSet(Dataset): 26 | def __init__(self, data_dir, target_dir, aug): 27 | self.data_dir = data_dir 28 | self.target_dir = target_dir 29 | self.aug = aug 30 | 31 | if not os.path.exists(data_dir): 32 | raise RuntimeError("dataset error:", self.target_dir + 'not exists') 33 | if not os.path.exists(target_dir): 34 | raise RuntimeError("dataset error:", self.data_dir + 'not exists') 35 | 36 | self.data_name_list = [] 37 | self.target_name_list = [] 38 | for _, _, file_names in os.walk(self.data_dir): 39 | for fileName in file_names: 40 | self.data_name_list.append(os.path.join(self.data_dir, fileName)) 41 | for _, _, file_names in os.walk(self.target_dir): 42 | for fileName in file_names: 43 | self.target_name_list.append(os.path.join(self.target_dir, fileName)) 44 | data_len = len(self.data_name_list) 45 | target_len = len(self.target_name_list) 46 | if data_len != target_len: 47 | raise RuntimeError("different num of data and target in " + self.data_dir + ' and ' + self.target_dir) 48 | self.data_name_list.sort() 49 | self.target_name_list.sort() 50 | self.size = data_len 51 | self.transform = get_transforms(256, 224) 52 | self.multi_scale_transform = [albu.Resize(224, 224), albu.Resize(112, 112), albu.Resize(56, 56), albu.Resize(28, 28)] 53 | 54 | def __getitem__(self, item): 55 | image = np.array(Image.open(self.data_name_list[item],'r')) 56 | target = np.array(Image.open(self.target_name_list[item],'r')) 57 | if len(target.shape) > 2: 58 | target = target[:,:,0] 59 | 60 | if self.aug: 61 | check_time = 5 62 | for i in range(check_time): 63 | image_t,target_t = self.transform(image,target) 64 | if np.max(target_t) != np.min(target_t) or i == check_time - 1: 65 | image = image_t 66 | target = target_t 67 | break 68 | image = torch.from_numpy(np.transpose(image, (2, 0, 1))).float()/255 69 | 70 | targets = [] 71 | for tran in self.multi_scale_transform: 72 | resize = tran(image=target)['image'] 73 | resize = torch.from_numpy(resize).long() 74 | targets.append(resize) 75 | 76 | return image, targets 77 | 78 | def __len__(self): 79 | return self.size 80 | 81 | 82 | def test_dataset(data_set): 83 | size = len(data_set) 84 | for i in range(size): 85 | a = data_set.__getitem__(i) 86 | b = np.transpose(a[0].numpy(), (1, 2, 0)) 87 | c = a[1][0].numpy().astype(np.uint8) * 120 88 | cv2.imshow('img', b) 89 | cv2.imshow('mask', c) 90 | cv2.waitKey(0) 91 | 92 | 93 | if __name__ == '__main__': 94 | train_data_set 
= BlurDataSet("C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\train_image", 95 | "C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\train_gt") 96 | test_data_set = BlurDataSet("C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\test_image", 97 | "C:\\Users\\Whale\\Documents\\DataSets\\CUHK\\test_gt") 98 | loader = DataLoader(train_data_set,batch_size=1,shuffle=False) 99 | for batch_id, data in enumerate(loader): 100 | input = torch.cat(data[0],0) 101 | print(input.shape) 102 | target = [] 103 | for i in range(4): 104 | cur=[] 105 | for t in data[1]: 106 | cur.append(t[i]) 107 | target.append(torch.cat(cur, 0)) 108 | 109 | for i in range(input.shape[0]): 110 | b = np.transpose(input[i].numpy(), (1, 2, 0)) 111 | c = target[0][i].numpy().astype(np.uint8) 112 | # d = target[1][i].numpy().astype(np.uint8) 113 | # e = target[2][i].numpy().astype(np.uint8) 114 | # f = target[3][i].numpy().astype(np.uint8) 115 | # c[c == 1] = 0 116 | # c[c == 2] = 0 117 | # c[c > 0] = 255 118 | # d[d == 1] = 0 119 | # d[d == 2] = 0 120 | # d[d > 0] = 255 121 | # e[e == 1] = 0 122 | # e[e == 2] = 0 123 | # e[e > 0] = 255 124 | # f[f == 1] = 0 125 | # f[f == 2] = 0 126 | # f[f > 0] = 255 127 | # cv2.imshow('img', b) 128 | # cv2.imshow('mask_0', c) 129 | # cv2.imshow('mask_1', d) 130 | # cv2.imshow('mask_2', e) 131 | # cv2.imshow('mask_3', f) 132 | # cv2.waitKey(0) 133 | cv2.imwrite(os.path.join("C:\\Users\\Whale\\Projects\\result\\gt",str(batch_id)+".png"),c) 134 | 135 | # break 136 | exit(0) 137 | -------------------------------------------------------------------------------- /models/fpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as func 3 | import numpy as np 4 | from utils import get_encoder,get_fpn_skip 5 | 6 | 7 | class FPN(): 8 | @classmethod 9 | def fromConfig(cls, config): 10 | encoder = get_encoder(config) 11 | if config['fpn_out'] == 2: 12 | fpn = FPN_2(encoder, get_fpn_skip(config), config['skip_out_channel']) 13 | for param in fpn.parameters(): 14 | param.requires_grad = True 15 | elif config['fpn_out'] == 3: 16 | fpn = FPN_3(encoder, get_fpn_skip(config), config['skip_out_channel']) 17 | for param in fpn.parameters(): 18 | param.requires_grad = True 19 | return fpn 20 | 21 | 22 | class FPN_2(torch.nn.Module): 23 | def __init__(self, encoder, skip_ins, skip_out): 24 | super(FPN_2, self).__init__() 25 | self.encoder = encoder 26 | 27 | self.skip_0 = torch.nn.Conv2d(skip_ins[0], skip_out, kernel_size=1) 28 | self.skip_1 = torch.nn.Conv2d(skip_ins[1], skip_out, kernel_size=1) 29 | self.skip_2 = torch.nn.Conv2d(skip_ins[2], skip_out, kernel_size=1) 30 | self.skip_3 = torch.nn.Conv2d(skip_ins[3], skip_out, kernel_size=1) 31 | self.skip_4 = torch.nn.Conv2d(skip_ins[4], skip_out, kernel_size=1) 32 | 33 | self.conv_4 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 34 | self.conv_3 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 35 | self.conv_2 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 36 | self.conv_1 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 37 | self.conv_0 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 38 | 39 | self.smooth_1 = torch.nn.Conv2d(4 * skip_out, skip_out, kernel_size=3, padding=1) 40 | self.smooth_2 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 41 | 42 | self.out_1 = torch.nn.Conv2d(skip_out,3,kernel_size=3, padding=1) 43 | self.out_2 = torch.nn.Conv2d(skip_out,3,kernel_size=3, padding=1) 44 | 45 | def 
forward(self, x): 46 | map_0,map_1,map_2,map_3,map_4 = self.encoder(x) 47 | 48 | skip_4 = self.conv_4(self.skip_4(map_4)) 49 | skip_3 = self.conv_3(self.skip_3(map_3) + func.interpolate(input=skip_4, scale_factor=2, mode='nearest')) 50 | skip_2 = self.conv_2(self.skip_2(map_2) + func.interpolate(input=skip_3, scale_factor=2, mode='nearest')) 51 | skip_1 = self.conv_1(self.skip_1(map_1) + func.interpolate(input=skip_2, scale_factor=2, mode='nearest')) 52 | skip_0 = self.conv_0(self.skip_0(map_0) + func.interpolate(input=skip_1, scale_factor=2, mode='nearest')) 53 | 54 | upsample_4 = func.interpolate(input=skip_4, scale_factor=8, mode='nearest') 55 | upsample_3 = func.interpolate(input=skip_3, scale_factor=4, mode='nearest') 56 | upsample_2 = func.interpolate(input=skip_2, scale_factor=2, mode='nearest') 57 | upsample_1 = skip_1 58 | 59 | concat = torch.cat([upsample_4, upsample_3, upsample_2, upsample_1], dim=1) 60 | concat = self.smooth_1(concat) 61 | 62 | out = skip_0 + func.interpolate(input=concat, scale_factor=2, mode='nearest') 63 | 64 | out = self.smooth_2(out) 65 | 66 | return self.out_1(out), self.out_2(concat) 67 | 68 | 69 | class FPN_3(torch.nn.Module): 70 | def __init__(self, encoder, skip_ins, skip_out): 71 | super(FPN_3, self).__init__() 72 | self.encoder = encoder 73 | 74 | self.skip_0 = torch.nn.Conv2d(skip_ins[0], skip_out, kernel_size=1) 75 | self.skip_1 = torch.nn.Conv2d(skip_ins[1], skip_out, kernel_size=1) 76 | self.skip_2 = torch.nn.Conv2d(skip_ins[2], skip_out, kernel_size=1) 77 | self.skip_3 = torch.nn.Conv2d(skip_ins[3], skip_out, kernel_size=1) 78 | self.skip_4 = torch.nn.Conv2d(skip_ins[4], skip_out, kernel_size=1) 79 | 80 | self.conv_4 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 81 | self.conv_3 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 82 | self.conv_2 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 83 | self.conv_1 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 84 | self.conv_0 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 85 | 86 | self.smooth_0 = torch.nn.Conv2d(3 * skip_out, skip_out, kernel_size=3, padding=1) 87 | self.smooth_1 = torch.nn.Conv2d(4 * skip_out, skip_out, kernel_size=3, padding=1) 88 | self.smooth_2 = torch.nn.Conv2d(skip_out, skip_out, kernel_size=3, padding=1) 89 | 90 | self.out_0 = torch.nn.Conv2d(skip_out, 3, kernel_size=3, padding=1) 91 | self.out_1 = torch.nn.Conv2d(skip_out,3,kernel_size=3, padding=1) 92 | self.out_2 = torch.nn.Conv2d(skip_out,3,kernel_size=3, padding=1) 93 | 94 | def forward(self, x): 95 | map_0, map_1, map_2, map_3, map_4 = self.encoder(x) 96 | 97 | skip_4 = self.conv_4(self.skip_4(map_4)) 98 | skip_3 = self.conv_3(self.skip_3(map_3) + func.interpolate(input=skip_4, scale_factor=2, mode='nearest')) 99 | skip_2 = self.conv_2(self.skip_2(map_2) + func.interpolate(input=skip_3, scale_factor=2, mode='nearest')) 100 | skip_1 = self.conv_1(self.skip_1(map_1) + func.interpolate(input=skip_2, scale_factor=2, mode='nearest')) 101 | skip_0 = self.conv_0(self.skip_0(map_0) + func.interpolate(input=skip_1, scale_factor=2, mode='nearest')) 102 | 103 | # out_0 104 | upsample_0_4 = func.interpolate(input=skip_4, scale_factor=4, mode='nearest') 105 | upsample_0_3 = func.interpolate(input=skip_3, scale_factor=2, mode='nearest') 106 | upsample_0_2 = skip_2 107 | 108 | concat_0 = torch.cat([upsample_0_4,upsample_0_3,upsample_0_2], dim=1) 109 | concat_0 = self.smooth_0(concat_0) 110 | 111 | # out_1 112 | upsample_1_4 = 
func.interpolate(input=skip_4, scale_factor=8, mode='nearest') 113 | upsample_1_3 = func.interpolate(input=skip_3, scale_factor=4, mode='nearest') 114 | upsample_1_2 = func.interpolate(input=skip_2, scale_factor=2, mode='nearest') 115 | upsample_1_1 = skip_1 116 | 117 | concat_1 = torch.cat([upsample_1_4, upsample_1_3, upsample_1_2, upsample_1_1], dim=1) 118 | concat_1 = self.smooth_1(concat_1) 119 | 120 | # out_2 121 | out = skip_0 + func.interpolate(input=concat_1, scale_factor=2, mode='nearest') 122 | 123 | out = self.smooth_2(out) 124 | 125 | return self.out_2(out), self.out_1(concat_1), self.out_0(concat_0) 126 | 127 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from DataSetLoader.MDataSet import BlurDataSet 2 | from models.deblurnet import Net 3 | from models.fpn import FPN 4 | from loss.hybridloss import HybridLoss 5 | from torch.utils.data import DataLoader 6 | from torch import optim 7 | import torch 8 | import os 9 | from cfg import Configs 10 | from tensorboardX import SummaryWriter 11 | import numpy as np 12 | 13 | 14 | min_valid_loss = 100 15 | 16 | def valid(net, loader, loss_function, ech, summary, epoch_test=False): 17 | v_loss = [] 18 | c_loss = [] 19 | s_loss = [] 20 | i_loss = [] 21 | for batch_id, data in enumerate(loader): 22 | input_image = data[0].cuda(device=Configs["device_ids"][0]) 23 | target = [i.cuda(device=Configs["device_ids"][0]) for i in data[1]] 24 | 25 | output = net(input_image) 26 | 27 | valid_loss, ce_loss, ssim_loss, iou_loss = loss_function(output, target) 28 | v_loss.append(valid_loss.cpu().detach().numpy()) 29 | c_loss.append(ce_loss.cpu().detach().numpy()) 30 | s_loss.append(ssim_loss.cpu().detach().numpy()) 31 | i_loss.append(iou_loss.cpu().detach().numpy()) 32 | 33 | if not epoch_test: 34 | gt1 = data[1][0][0] 35 | gt2 = data[1][0][0] 36 | # gt3 = data[1][0][2][0] 37 | # gt4 = data[1][0][3][0] 38 | pre_mask1 = output[0].cpu()[0] 39 | pre_mask1 = torch.argmax(pre_mask1, dim=0).unsqueeze(0) 40 | pre_mask2 = output[1].cpu()[0] 41 | pre_mask2 = torch.argmax(pre_mask2, dim=0).unsqueeze(0) 42 | # pre_mask3 = output[2].cpu()[0] 43 | # pre_mask3 = torch.argmax(pre_mask3, dim=0).unsqueeze(0) 44 | # pre_mask4 = output[3].cpu()[0] 45 | # pre_mask4 = torch.argmax(pre_mask4, dim=0).unsqueeze(0) 46 | summary.add_image("scalar/validate_sample_image", input_image[0], global_step=ech) 47 | summary.add_image("gt/validate_sample_gt1", gt1.unsqueeze(0), global_step=ech) 48 | summary.add_image("gt/validate_sample_gt2", gt2.unsqueeze(0), global_step=ech) 49 | # summary.add_image("gt/validate_sample_gt3", gt3.unsqueeze(0), global_step=ech) 50 | # summary.add_image("gt/validate_sample_gt4", gt4.unsqueeze(0), global_step=ech) 51 | summary.add_image("pre/validate_sample_pre1", pre_mask1, global_step=ech) 52 | summary.add_image("pre/validate_sample_pre2", pre_mask2, global_step=ech) 53 | # summary.add_image("pre/validate_sample_pre3", pre_mask3, global_step=ech) 54 | # summary.add_image("pre/validate_sample_pre4", pre_mask4, global_step=ech) 55 | return valid_loss.cpu().detach().numpy() 56 | break 57 | summary.add_scalar("validate/total_loss", np.mean(v_loss), global_step=ech) 58 | summary.add_scalar("validate/ce_loss", np.mean(c_loss), global_step=ech) 59 | summary.add_scalar("validate/ssim_loss", np.mean(s_loss), global_step=ech) 60 | summary.add_scalar("validate/iou_loss", np.mean(i_loss), global_step=ech) 61 | 62 | 
print("validate......") 63 | global min_valid_loss 64 | if min_valid_loss > np.mean(v_loss): 65 | min_valid_loss = np.mean(v_loss) 66 | torch.save({ 67 | 'model_state_dict': model.state_dict(), 68 | }, "best.pth") 69 | 70 | print("loss:",np.mean(v_loss)) 71 | return np.mean(v_loss) 72 | 73 | 74 | 75 | 76 | def train(net, train_loader, valid_loader, loss_function, opt, ech, summary): 77 | net.train() 78 | for batch_id, data in enumerate(train_loader): 79 | input_image = data[0].cuda(device=Configs["device_ids"][0]) 80 | target = [i.cuda(device=Configs["device_ids"][0]) for i in data[1]] 81 | 82 | opt.zero_grad() 83 | output = net(input_image) 84 | total_loss, ce_loss, ssim_loss, iou_loss = loss_function(output, target) 85 | total_loss.backward() 86 | opt.step() 87 | 88 | if batch_id % 10 == 0: 89 | valid_loss = valid(net, valid_loader, 90 | loss_function, 91 | ech * len(train_loader.dataset) + batch_id * train_loader.batch_size, summary) 92 | summary.add_scalar("train/total_loss", total_loss, global_step=ech) 93 | summary.add_scalar("train/ce_loss", ce_loss, global_step=ech) 94 | summary.add_scalar("train/ssim_loss", ssim_loss, global_step=ech) 95 | summary.add_scalar("train/iou_loss", iou_loss, global_step=ech) 96 | 97 | print('Train Epoch: {} [{}/{} ({:.0f}%)] ' 98 | '\t train_loss: {:.12f} ' 99 | '\t ce_loss: {:.12f} ' 100 | '\t ssim_loss: {:.12f} ' 101 | '\t iou_loss: {:.12f} ' 102 | '\t valid_loss: {:.12f}'.format( 103 | ech, 104 | batch_id * train_loader.batch_size, 105 | len(train_loader.dataset), 106 | 100. * batch_id / len(train_loader), 107 | total_loss.data.cpu().numpy(), 108 | ce_loss.data.cpu().numpy(), 109 | ssim_loss.data.cpu().numpy(), 110 | iou_loss.data.cpu().numpy(), 111 | valid_loss 112 | )) 113 | 114 | def load_module(model,save_path): 115 | if os.path.exists(save_path): 116 | print("load from saved model:" + save_path + '...') 117 | checkpoint = torch.load(save_path) 118 | model.load_state_dict(checkpoint['model_state_dict']) 119 | ech = checkpoint['epoch'] 120 | print("load complete") 121 | return ech 122 | else: 123 | print("start from new") 124 | return 0 125 | 126 | def save_model(model, ech, save_path): 127 | torch.save({ 128 | 'epoch': ech, 129 | 'model_state_dict': model.state_dict(), 130 | }, save_path) 131 | 132 | 133 | 134 | if __name__ == '__main__': 135 | if Configs["fpn"]: 136 | model = FPN.fromConfig(Configs) 137 | optimizer = optim.Adam(model.parameters(), lr=Configs['encoder_learning_rate']) 138 | else: 139 | model = Net(Configs) 140 | optimizer = model.optimizer_by_layer(Configs['encoder_learning_rate'], Configs['decoder_lr_scale']) 141 | 142 | 143 | model = torch.nn.DataParallel(model, device_ids=Configs["device_ids"]) 144 | model = model.cuda(device=Configs["device_ids"][0]) 145 | cur_epoch = load_module(model,Configs['model_save_path']) 146 | # cur_epoch = model.module.load_model(Configs['pre_path'], Configs['model_save_path']) 147 | # optimizer = model.module.optimizer_by_layer(Configs['encoder_learning_rate'], Configs['decoder_lr_scale']) 148 | 149 | 150 | train_data = BlurDataSet(Configs['train_image_dir'], Configs['train_mask_dir'], aug=Configs['augmentation']) 151 | train_loader = DataLoader(train_data, batch_size=Configs['train_batch_size'] * len(Configs['device_ids']), 152 | shuffle=True) 153 | test_data = BlurDataSet(Configs['test_image_dir'], Configs['test_mask_dir'], True) 154 | test_loader = DataLoader(test_data, batch_size=Configs['test_batch_size'] * len(Configs['device_ids']), 155 | shuffle=True) 156 | 157 | write = SummaryWriter() 158 
| # write.add_graph(model,torch.rand(1,3,224,224).cuda()) 159 | loss_func = HybridLoss(Configs["l_bce"], Configs["l_ssim"], Configs["l_IoU"]) 160 | model.train() 161 | for epoch in range(cur_epoch, Configs['epoch']): 162 | train(model, train_loader, test_loader, loss_func, optimizer, epoch, write) 163 | valid(model,test_loader,loss_func,epoch,write,True) 164 | save_model(model, epoch, Configs['model_save_path']) 165 | write.close() 166 | exit(0) 167 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet152, resnet34, resnet50, resnet101 3 | 4 | # __all__ = ['ResNet50', 'ResNet101','ResNet152'] 5 | # 6 | # def Conv1(in_planes, places, stride=1): 7 | # return nn.Sequential( 8 | # nn.Conv2d(in_channels=in_planes,out_channels=places,kernel_size=3,stride=stride,padding=3, bias=False), 9 | # nn.BatchNorm2d(places), 10 | # nn.ReLU(inplace=True), 11 | # nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 12 | # ) 13 | # 14 | # class Bottleneck(nn.Module): 15 | # def __init__(self,in_places,places, stride=1,downsampling=False, expansion = 4): 16 | # super(Bottleneck,self).__init__() 17 | # self.expansion = expansion 18 | # self.downsampling = downsampling 19 | # 20 | # self.bottleneck = nn.Sequential( 21 | # nn.Conv2d(in_channels=in_places,out_channels=places,kernel_size=1,stride=1, bias=False), 22 | # nn.BatchNorm2d(places), 23 | # nn.ReLU(inplace=True), 24 | # nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False), 25 | # nn.BatchNorm2d(places), 26 | # nn.ReLU(inplace=True), 27 | # nn.Conv2d(in_channels=places, out_channels=places*self.expansion, kernel_size=1, stride=1, bias=False), 28 | # nn.BatchNorm2d(places*self.expansion), 29 | # ) 30 | # 31 | # if self.downsampling: 32 | # self.downsample = nn.Sequential( 33 | # nn.Conv2d(in_channels=in_places, out_channels=places*self.expansion, kernel_size=1, stride=stride, bias=False), 34 | # nn.BatchNorm2d(places*self.expansion) 35 | # ) 36 | # self.relu = nn.ReLU(inplace=True) 37 | # def forward(self, x): 38 | # residual = x 39 | # out = self.bottleneck(x) 40 | # 41 | # if self.downsampling: 42 | # residual = self.downsample(x) 43 | # 44 | # out += residual 45 | # out = self.relu(out) 46 | # return out 47 | # 48 | # class ResNet(nn.Module): 49 | # def __init__(self,blocks, num_classes=1000, expansion = 4): 50 | # super(ResNet,self).__init__() 51 | # self.expansion = expansion 52 | # 53 | # self.conv1 = Conv1(in_planes = 3, places= 64) 54 | # 55 | # self.layer1 = self.make_layer(in_places = 64, places= 64, block=blocks[0], stride=1) 56 | # self.layer2 = self.make_layer(in_places = 256,places=128, block=blocks[1], stride=2) 57 | # self.layer3 = self.make_layer(in_places=512,places=256, block=blocks[2], stride=2) 58 | # self.layer4 = self.make_layer(in_places=1024,places=512, block=blocks[3], stride=2) 59 | # 60 | # for m in self.modules(): 61 | # if isinstance(m, nn.Conv2d): 62 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 63 | # elif isinstance(m, nn.BatchNorm2d): 64 | # nn.init.constant_(m.weight, 1) 65 | # nn.init.constant_(m.bias, 0) 66 | # 67 | # def make_layer(self, in_places, places, block, stride): 68 | # layers = [] 69 | # layers.append(Bottleneck(in_places, places,stride, downsampling =True)) 70 | # for i in range(1, block): 71 | # layers.append(Bottleneck(places*self.expansion, 
places)) 72 | # 73 | # return nn.Sequential(*layers) 74 | # 75 | # 76 | # def forward(self, x): 77 | # x = self.conv1(x) 78 | # skip_1 = x 79 | # x = self.layer1(x) 80 | # skip_2 = x 81 | # x = self.layer2(x) 82 | # skip_3 = x 83 | # x = self.layer3(x) 84 | # skip_4 = x 85 | # x = self.layer4(x) 86 | # return skip_1, skip_2, skip_3, skip_4, x 87 | # 88 | # 89 | # def ResNet50(): 90 | # return ResNet([3, 4, 6, 3]) 91 | # 92 | # 93 | # def ResNet101(): 94 | # return ResNet([3, 4, 23, 3]) 95 | # 96 | # 97 | # def ResNet152(): 98 | # return ResNet([3, 8, 36, 3]) 99 | 100 | 101 | def conv3x3(in_planes, out_planes, stride=1): 102 | "3x3 convolution with padding" 103 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 104 | padding=1, bias=False) 105 | 106 | def conv1x1(in_planes, out_planes, stride=1): 107 | "3x3 convolution with padding" 108 | return torch.nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, 109 | padding=1, bias=False) 110 | 111 | 112 | class BasicBlock(torch.nn.Module): 113 | expansion = 1 114 | 115 | def __init__(self, inplanes, planes, stride=1, downsample=None): 116 | super(BasicBlock, self).__init__() 117 | self.conv1 = conv3x3(inplanes, planes, stride) 118 | self.bn1 = torch.nn.BatchNorm2d(planes) 119 | self.relu = torch.nn.ReLU(inplace=True) 120 | self.conv2 = conv3x3(planes, planes) 121 | self.bn2 = torch.nn.BatchNorm2d(planes) 122 | self.downsample = downsample 123 | self.stride = stride 124 | 125 | def forward(self, x): 126 | residual = x 127 | 128 | out = self.conv1(x) 129 | out = self.bn1(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv2(out) 133 | out = self.bn2(out) 134 | 135 | if self.downsample is not None: 136 | residual = self.downsample(x) 137 | 138 | out += residual 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet152(torch.nn.Module): 145 | def __init__(self): 146 | super(ResNet152,self).__init__() 147 | resnet = resnet152(pretrained=True) 148 | self.inconv = torch.nn.Conv2d(3, 64, 3, padding=1) 149 | self.inbn = torch.nn.BatchNorm2d(64) 150 | self.inrelu = torch.nn.ReLU(inplace=True) 151 | # stage 1 152 | self.encoder1 = resnet.layer1 # 224 153 | # stage 2 154 | self.encoder2 = resnet.layer2 # 112 155 | # stage 3 156 | self.encoder3 = resnet.layer3 # 56 157 | # stage 4 158 | self.encoder4 = resnet.layer4 # 28 159 | 160 | self.pool4 = torch.nn.MaxPool2d(2, 2, ceil_mode=True) 161 | 162 | # stage 5 163 | self.resb5_1 = BasicBlock(2048, 2048) 164 | self.resb5_2 = BasicBlock(2048, 2048) 165 | self.resb5_3 = BasicBlock(2048, 2048) # 14 166 | 167 | def forward(self, x): 168 | x = self.inconv(x) 169 | x = self.inbn(x) 170 | x = self.inrelu(x) 171 | 172 | skip1 = self.encoder1(x) 173 | skip2 = self.encoder2(skip1) 174 | skip3 = self.encoder3(skip2) 175 | skip4 = self.encoder4(skip3) 176 | x = self.pool4(skip4) 177 | x = self.resb5_1(x) 178 | x = self.resb5_2(x) 179 | x = self.resb5_3(x) 180 | return skip1, skip2, skip3, skip4, x 181 | 182 | class ResNet50(torch.nn.Module): 183 | def __init__(self): 184 | super(ResNet50,self).__init__() 185 | resnet = resnet50(pretrained=True) 186 | self.inconv = torch.nn.Conv2d(3, 64, 3, padding=1) 187 | self.inbn = torch.nn.BatchNorm2d(64) 188 | self.inrelu = torch.nn.ReLU(inplace=True) 189 | # stage 1 190 | self.encoder1 = resnet.layer1 # 224 191 | # stage 2 192 | self.encoder2 = resnet.layer2 # 112 193 | # stage 3 194 | self.encoder3 = resnet.layer3 # 56 195 | # stage 4 196 | self.encoder4 = resnet.layer4 # 28 197 | 198 | self.pool4 = 
torch.nn.MaxPool2d(2, 2, ceil_mode=True) 199 | 200 | # stage 5 201 | self.resb5_1 = BasicBlock(2048, 2048) 202 | self.resb5_2 = BasicBlock(2048, 2048) 203 | self.resb5_3 = BasicBlock(2048, 2048) # 14 204 | 205 | def forward(self, x): 206 | x = self.inconv(x) 207 | x = self.inbn(x) 208 | x = self.inrelu(x) 209 | 210 | skip1 = self.encoder1(x) 211 | skip2 = self.encoder2(skip1) 212 | skip3 = self.encoder3(skip2) 213 | skip4 = self.encoder4(skip3) 214 | x = self.pool4(skip4) 215 | x = self.resb5_1(x) 216 | x = self.resb5_2(x) 217 | x = self.resb5_3(x) 218 | return skip1, skip2, skip3, skip4, x 219 | 220 | class ResNet34(torch.nn.Module): 221 | def __init__(self): 222 | super(ResNet34,self).__init__() 223 | resnet = resnet34(pretrained=True) 224 | self.inconv = torch.nn.Conv2d(3, 64, 3, padding=1) 225 | self.inbn = torch.nn.BatchNorm2d(64) 226 | self.inrelu = torch.nn.ReLU(inplace=True) 227 | # stage 1 228 | self.encoder1 = resnet.layer1 # 224 229 | # stage 2 230 | self.encoder2 = resnet.layer2 # 112 231 | # stage 3 232 | self.encoder3 = resnet.layer3 # 56 233 | # stage 4 234 | self.encoder4 = resnet.layer4 # 28 235 | 236 | self.pool4 = torch.nn.MaxPool2d(2, 2, ceil_mode=True) 237 | 238 | # stage 5 239 | self.resb5_1 = BasicBlock(512, 512) 240 | self.resb5_2 = BasicBlock(512, 512) 241 | self.resb5_3 = BasicBlock(512, 512) # 14 242 | 243 | def forward(self, x): 244 | x = self.inconv(x) 245 | x = self.inbn(x) 246 | x = self.inrelu(x) 247 | 248 | skip1 = self.encoder1(x) 249 | skip2 = self.encoder2(skip1) 250 | skip3 = self.encoder3(skip2) 251 | skip4 = self.encoder4(skip3) 252 | x = self.pool4(skip4) 253 | x = self.resb5_1(x) 254 | x = self.resb5_2(x) 255 | x = self.resb5_3(x) 256 | return skip1, skip2, skip3, skip4, x 257 | 258 | --------------------------------------------------------------------------------
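A note on reusing the checkpoints written by `save_model` and the `best.pth` file from `valid`: because train.py wraps the network in `torch.nn.DataParallel` before saving, every key in `model_state_dict` carries a `module.` prefix. Below is a minimal sketch, assuming `Configs["fpn"]` is enabled (so the model is built with `FPN.fromConfig`) and that a checkpoint already exists at `Configs["model_save_path"]`, of loading such a checkpoint onto a bare, unwrapped model for single-device inference:

```python
import torch
from cfg import Configs
from models.fpn import FPN


def strip_module_prefix(state_dict):
    # DataParallel stores parameters as "module.<name>"; drop the prefix so a
    # plain, unwrapped model can consume the same state dict.
    return {k[len("module."):] if k.startswith("module.") else k: v
            for k, v in state_dict.items()}


if __name__ == '__main__':
    model = FPN.fromConfig(Configs)  # same constructor train.py uses
    checkpoint = torch.load(Configs['model_save_path'], map_location='cpu')
    model.load_state_dict(strip_module_prefix(checkpoint['model_state_dict']))
    model.eval()
    print("resumed weights from epoch", checkpoint['epoch'])
```

Loading through the same `DataParallel` wrapper, as `load_module` does during training, sidesteps the prefix entirely; the helper is only needed once the wrapper is dropped.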
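When `Configs["fpn"]` is false, the script relies on `Net.optimizer_by_layer`, which lives elsewhere in the repository and is not shown here. Purely as an illustration of the two-speed schedule implied by `encoder_learning_rate` and `decoder_lr_scale`, such an optimizer can be built from ordinary Adam parameter groups; the `encoder`/`decoder` attribute names below are assumptions, not the actual interface of `Net`:

```python
import torch.optim as optim


def optimizer_by_layer_sketch(model, encoder_lr, decoder_lr_scale):
    # Hypothetical helper: assumes the network exposes `encoder` and `decoder`
    # submodules. The pretrained encoder keeps the base learning rate while the
    # decoder, trained from scratch, runs `decoder_lr_scale` times faster.
    return optim.Adam([
        {'params': model.encoder.parameters(), 'lr': encoder_lr},
        {'params': model.decoder.parameters(), 'lr': encoder_lr * decoder_lr_scale},
    ])
```

It would be called the same way train.py calls the real method, e.g. `optimizer_by_layer_sketch(net, Configs['encoder_learning_rate'], Configs['decoder_lr_scale'])`.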
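The stage-5 `BasicBlock`s in models/resnet.py are instantiated with `inplanes == planes` and the default `stride=1`, so no `downsample` branch is created and the residual addition is shape-safe. A quick sanity check, assuming the 14x14 stage-5 resolution noted in the code comments:

```python
import torch
from models.resnet import BasicBlock

block = BasicBlock(512, 512)       # inplanes == planes, stride == 1: identity residual path
x = torch.randn(1, 512, 14, 14)    # stage-5 feature map size from the comments
print(block(x).shape)              # torch.Size([1, 512, 14, 14])
```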
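For reference, the five tensors returned by these encoder wrappers are presumably what the FPN/decoder consumes as skip connections. A small shape check for `ResNet34` (constructing it downloads the pretrained torchvision weights on first use):

```python
import torch
from models.resnet import ResNet34

encoder = ResNet34()
encoder.eval()
with torch.no_grad():
    outputs = encoder(torch.randn(1, 3, 224, 224))
for name, feat in zip(['skip1', 'skip2', 'skip3', 'skip4', 'bottom'], outputs):
    print(name, tuple(feat.shape))
# Expected: 64x224x224, 128x112x112, 256x56x56, 512x28x28 and a 512x14x14 bottom map.
```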