├── .gitignore
├── README.md
├── annotator.py
├── core.py
├── dataset
│   └── .gitkeep
├── evaluate.py
├── model
│   ├── __init__.py
│   ├── fcanet.py
│   ├── res2net_split.py
│   └── resnet_split.py
├── pretrained_model
│   └── .gitkeep
└── test.jpg

/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | /dataset/*
3 | /pretrained_model/*
4 | !**/.gitkeep
5 | 
6 | 
7 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FCANet
2 | The official PyTorch implementation of the CVPR 2020 paper ["Interactive Image Segmentation with First Click Attention"](http://mftp.mmcheng.net/Papers/20CvprFirstClick.pdf).
3 | 
4 | ## Requirement
5 | - PyTorch>=0.4.1
6 | - OpenCV
7 | - SciPy
8 | - Matplotlib
9 | 
10 | ## Usage
11 | Put the pretrained models into the folder "pretrained_model" and the unzipped datasets into the folder "dataset".
12 | 
13 | ### Evaluation:
14 | ```
15 | python evaluate.py --backbone [resnet,res2net] --dataset [GrabCut,Berkeley,DAVIS,VOC2012] (--sis)
16 | ```
17 | ### Demo with UI:
18 | ```
19 | python annotator.py --backbone [resnet,res2net] --input test.jpg --output test_mask.png (--sis)
20 | ```
21 | 
22 | *(\[x,y,z] denotes a choice; (x) denotes an optional flag.)* Both scripts wrap `init_model` and `predict` from `core.py`; a programmatic example is given at the end of this README.
23 | 
24 | ## Datasets
25 | - GrabCut ( [GoogleDrive](https://drive.google.com/open?id=1PhtmSoijh7THHZOmMt8_mwEF5Ii-3ofY) | [BaiduYun](https://pan.baidu.com/s/1u3LzLx3Xnr1kuU9HAnCN4w) pwd: **6qkj** )
26 | - Berkeley ( [GoogleDrive](https://drive.google.com/open?id=1vo0k3JrulK8C198lmvfbDhUok8S7gpTr) | [BaiduYun](https://pan.baidu.com/s/1B6T3aKWB2U6sIeWrQrnszQ) pwd: **6bfw** )
27 | - DAVIS ( [GoogleDrive](https://drive.google.com/open?id=1Cn0QvzYCIlFky5hFdeUYEtXaiTYW91rL) | [BaiduYun](https://pan.baidu.com/s/1qZTrFE7K_41CgsZyH5NJJw) pwd: **u5vd** )
28 | *(The three datasets above are taken from [f-BRS](https://github.com/saic-vul/fbrs_interactive_segmentation).)*
29 | - PASCAL ( [official](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) )
30 | 
31 | 
32 | ## Pretrained models
33 | - fcanet-resnet ( [GoogleDrive](https://drive.google.com/open?id=1Zjn1RqAuNfG_VXjJ4X7RjX0cvuu6vFtM) | [BaiduYun](https://pan.baidu.com/s/1_Jjxb_0SU842FAj9RAEmBw) pwd: **g989** )
34 | - fcanet-res2net ( [GoogleDrive](https://drive.google.com/open?id=12wgzK5Gx_ku3jSm9RcUx_OmEHU3HGfG_) |
35 | [BaiduYun](https://pan.baidu.com/s/157Lkno7x8YSdmQzaYms-lg) pwd: **xzmx** )
36 | 
37 | 
38 | ## Other framework
39 | We also offer the [code](https://gitee.com/mindspore/models/tree/master/research/cv/FCANet) implemented with the MindSpore framework.
40 | 
41 | 
42 | ## Citation
43 | If you find this work or code helpful in your research, please cite:
44 | ```
45 | @inproceedings{lin2020fclick,
46 |   title={Interactive image segmentation with first click attention},
47 |   author={Lin, Zheng and Zhang, Zhao and Chen, Lin-Zhuo and Cheng, Ming-Ming and Lu, Shao-Ping},
48 |   booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
49 |   pages={13339--13348},
50 |   year={2020}
51 | }
52 | ```
53 | ## Contact
54 | If you have any questions, feel free to contact me via `frazer.linzheng(at)gmail.com`.
55 | You are also welcome to visit [the project page](https://www.lin-zheng.com/fclick/) or [my home page](https://www.lin-zheng.com/).
56 | 
57 | ## License
58 | The source code is free for research and education use only. Any commercial use requires formal permission first.
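59 | 
60 | ## Programmatic use
61 | Both `evaluate.py` and `annotator.py` call `init_model` and `predict` from `core.py`, so the model can also be driven from your own scripts. A minimal sketch (the click coordinates below are placeholders; `seq_points` holds one `(x, y, label)` row per click, where label 1 is positive, 0 is negative, and the first row is treated as the first click):
62 | ```
63 | import numpy as np
64 | from PIL import Image
65 | from core import init_model, predict
66 | 
67 | # load the network and its pretrained weights (see "Pretrained models" above)
68 | model = init_model('fcanet', 'resnet', './pretrained_model/fcanet-resnet.pth', if_cuda=True)
69 | 
70 | img = np.array(Image.open('test.jpg'))
71 | seq_points = np.array([[100, 100, 1]], dtype=np.int64)  # a single positive click (placeholder coordinates)
72 | 
73 | pred = predict(model, img, seq_points, if_sis=False, if_cuda=True)  # uint8 mask of {0, 1}
74 | Image.fromarray(pred * 255).save('test_mask.png')
75 | ```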
76 | 
--------------------------------------------------------------------------------
/annotator.py:
--------------------------------------------------------------------------------
1 | 
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | from PIL import Image
6 | from pathlib import Path
7 | import matplotlib as mpl
8 | import matplotlib.pyplot as plt
9 | from core import init_model, predict
10 | 
11 | # Note: keys Q, S, F, K, L are taken by matplotlib's default keymap (quit/save/fullscreen/axis scaling), so they are not bound here.
12 | class Annotator(object):
13 |     def __init__(self, img_path, model, if_sis=False, if_cuda=True, save_path=None):
14 | 
15 |         self.model, self.if_sis, self.if_cuda, self.save_path = model, if_sis, if_cuda, save_path
16 |         self.file = Path(img_path).name
17 |         self.img = np.array(Image.open(img_path))
18 |         self.clicks = np.empty([0, 3], dtype=np.int64)
19 |         self.pred = np.zeros(self.img.shape[:2], dtype=np.uint8)
20 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
21 | 
22 |     def __gene_merge(self, pred, img, clicks, r=9, cb=2, b=2, if_first=True):
23 |         pred_mask = cv2.merge([pred*255, pred*255, np.zeros_like(pred)])
24 |         result = np.uint8(np.clip(img*0.7 + pred_mask*0.3, 0, 255))
25 |         if b > 0:
26 |             contours, _ = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
27 |             cv2.drawContours(result, contours, -1, (255, 255, 255), b)
28 |         for pt in clicks:
29 |             cv2.circle(result, tuple(pt[:2]), r, (255, 0, 0) if pt[2] == 1 else (0, 0, 255), -1)
30 |             cv2.circle(result, tuple(pt[:2]), r, (255, 255, 255), cb)
31 |         if if_first and len(clicks) != 0:
32 |             cv2.circle(result, tuple(clicks[0, :2]), r, (0, 255, 0), cb)  # highlight the first click in green
33 |         return result
34 | 
35 |     def __update(self):
36 |         self.ax1.imshow(self.merge)
37 |         self.fig.canvas.draw()
38 | 
39 |     def __reset(self):
40 |         self.clicks = np.empty([0, 3], dtype=np.int64)
41 |         self.pred = np.zeros(self.img.shape[:2], dtype=np.uint8)
42 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
43 |         self.__update()
44 | 
45 |     def __predict(self):
46 |         self.pred = predict(self.model, self.img, self.clicks, if_sis=self.if_sis, if_cuda=self.if_cuda)
47 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
48 |         self.__update()
49 | 
50 |     def __on_key_press(self, event):
51 |         if event.key == 'ctrl+z':  # undo the last click
52 |             self.clicks = self.clicks[:-1, :]
53 |             if len(self.clicks) != 0:
54 |                 self.__predict()
55 |             else:
56 |                 self.__reset()
57 |         elif event.key == 'ctrl+r':  # reset all clicks
58 |             self.__reset()
59 |         elif event.key == 'escape':  # quit without saving
60 |             plt.close()
61 |         elif event.key == 'enter':  # save the mask and quit
62 |             if self.save_path is not None:
63 |                 Image.fromarray(self.pred*255).save(self.save_path)
64 |                 print('saved mask to [{}]!'.format(self.save_path))
65 |             plt.close()
66 | 
67 |     def __on_button_press(self, event):
68 |         if (event.xdata is None) or (event.ydata is None): return
69 |         if event.button == 1 or event.button == 3:  # left button -> positive click, right button -> negative click
70 |             x, y = int(event.xdata+0.5), int(event.ydata+0.5)
71 |             self.clicks = np.append(self.clicks, np.array([[x, y, (3-event.button)/2]], dtype=np.int64), axis=0)
72 |             self.__predict()
73 | 
74 |     def main(self):
75 |         self.fig = plt.figure('Annotator')
76 |         self.fig.canvas.mpl_connect('key_press_event', self.__on_key_press)
77 |         self.fig.canvas.mpl_connect("button_press_event", self.__on_button_press)
78 |         self.fig.suptitle('( file : {} )'.format(self.file), fontsize=16)
79 |         self.ax1 = self.fig.add_subplot(1, 1, 1)
80 |         self.ax1.axis('off')
81 |         self.ax1.imshow(self.merge)
82 |         plt.show()
83 | 
84 | 
85 | if __name__ == "__main__":
86 | 
87 |     parser = argparse.ArgumentParser(description="Annotator for FCANet")
88 |     parser.add_argument('--input', type=str, default='test.jpg', help='input image')
89 |     parser.add_argument('--output', type=str, default='test_mask.png', help='output mask')
90 |     parser.add_argument('--backbone', type=str, default='resnet', choices=['resnet', 'res2net'], help='backbone name (default: resnet)')
91 |     parser.add_argument('--sis', action='store_true', default=False, help='use sis')
92 |     parser.add_argument('--cpu', action='store_true', default=False, help='use cpu (not recommended)')
93 |     args = parser.parse_args()
94 | 
95 |     model = init_model('fcanet', args.backbone, './pretrained_model/fcanet-{}.pth'.format(args.backbone), if_cuda=not args.cpu)
96 |     anno = Annotator(img_path=args.input, model=model, if_sis=args.sis, if_cuda=not args.cpu, save_path=args.output)
97 |     anno.main()
98 | 
--------------------------------------------------------------------------------
/core.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from scipy.ndimage import distance_transform_edt  # scipy.ndimage.morphology is deprecated in recent SciPy
5 | from model.fcanet import FCANet
6 | 
7 | ########################################[ Encapsulation ]########################################
8 | 
9 | def get_points_mask(size, points):
10 |     mask = np.zeros(size[::-1]).astype(np.uint8)
11 |     if len(points) != 0:
12 |         points = np.array(points)
13 |         mask[points[:, 1], points[:, 0]] = 1
14 |     return mask
15 | 
16 | def structural_integrity_strategy(pred, pos_mask):
17 |     pos_mask = ((pos_mask == 1) & (pred == 1)).astype(np.uint8)
18 |     h, w = pred.shape
19 |     mask = np.zeros([h+2, w+2], np.uint8)
20 |     pred_new = pred.copy()
21 |     pts_y, pts_x = np.where(pos_mask == 1)
22 |     pts_xy = np.concatenate((pts_x[:, np.newaxis], pts_y[:, np.newaxis]), axis=1)
23 |     for pt in pts_xy:
24 |         cv2.floodFill(pred_new, mask, tuple(pt), 2)  # keep only regions connected to a positive click
25 |     pred_new = (pred_new == 2).astype(np.uint8)
26 |     return pred_new
27 | 
28 | def img_resize_point(img, size):
29 |     (h, w) = img.shape
30 |     if not isinstance(size, tuple): size = (int(w*size), int(h*size))
31 |     M = np.array([[size[0]/w, 0, 0], [0, size[1]/h, 0]])
32 |     pts_y, pts_x = np.where(img == 1)
33 |     pts_xy = np.concatenate((pts_x[:, np.newaxis], pts_y[:, np.newaxis]), axis=1)
34 |     pts_xy_new = np.dot(np.insert(pts_xy, 2, 1, axis=1), M.T).astype(np.int64)
35 |     img_new = np.zeros(size[::-1], dtype=np.uint8)
36 |     for pt in pts_xy_new:
37 |         img_new[pt[1], pt[0]] = 1
38 |     return img_new
39 | 
40 | class Resize(object):
41 |     def __init__(self, size, mode=None, elems_point=['pos_points_mask','neg_points_mask','first_point_mask'], elems_do=None, elems_undo=[]):
42 |         self.size, self.mode = size, mode
43 |         self.elems_point = elems_point
44 |         self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
45 |     def __call__(self, sample):
46 |         for elem in sample.keys():
47 |             if self.elems_do is not None and elem not in self.elems_do: continue
48 |             if elem in self.elems_undo: continue
49 |             if elem in self.elems_point:
50 |                 sample[elem] = img_resize_point(sample[elem], self.size)
51 |                 continue
52 |             mode = self.mode if self.mode is not None else (  # fix: 'mode' was left unset when an explicit mode was given
53 |                 cv2.INTER_LINEAR if sample[elem].ndim == 3 else cv2.INTER_NEAREST)
54 |             sample[elem] = cv2.resize(sample[elem], self.size, interpolation=mode)
55 |         return sample
56 | 
57 | class CatPointMask(object):
58 |     def __init__(self, mode='NO', paras={}, if_repair=True):
59 |         self.mode, self.paras, self.if_repair = mode, paras, if_repair
60 |     def __call__(self, sample):
61 |         gt = sample['gt']
62 |         if_gt_empty = not ((gt > 127).any())
63 |         pos_points_mask, neg_points_mask = sample['pos_points_mask'], sample['neg_points_mask']
64 |         if self.mode == 'DISTANCE_POINT_MASK_SRC':
65 |             max_dist = 255
66 |             if if_gt_empty:
67 |                 pos_points_mask_dist = np.ones(gt.shape).astype(np.float64)*max_dist
68 |             else:
69 |                 pos_points_mask_dist = distance_transform_edt(1-pos_points_mask)
70 |                 pos_points_mask_dist = np.minimum(pos_points_mask_dist, max_dist)
71 |             if not neg_points_mask.any():
72 |                 neg_points_mask_dist = np.ones(gt.shape).astype(np.float64)*max_dist
73 |             else:
74 |                 neg_points_mask_dist = distance_transform_edt(1-neg_points_mask)
75 |                 neg_points_mask_dist = np.minimum(neg_points_mask_dist, max_dist)
76 |             pos_points_mask_dist, neg_points_mask_dist = pos_points_mask_dist*255, neg_points_mask_dist*255
77 |             sample['pos_mask_dist_src'] = pos_points_mask_dist
78 |             sample['neg_mask_dist_src'] = neg_points_mask_dist
79 |         return sample
80 | 
81 | class ToTensor(object):
82 |     def __init__(self, if_div=True, elems_do=None, elems_undo=[]):
83 |         self.if_div = if_div
84 |         self.elems_do, self.elems_undo = elems_do, (['meta']+elems_undo)
85 |     def __call__(self, sample):
86 |         for elem in sample.keys():
87 |             if self.elems_do is not None and elem not in self.elems_do: continue
88 |             if elem in self.elems_undo: continue
89 |             tmp = sample[elem]
90 |             tmp = tmp[np.newaxis, :, :] if tmp.ndim == 2 else tmp.transpose((2, 0, 1))  # HWC -> CHW
91 |             tmp = torch.from_numpy(tmp).float()
92 |             tmp = tmp.float().div(255) if self.if_div else tmp
93 |             sample[elem] = tmp
94 |         return sample
95 | 
96 | ########################################[ Interface ]########################################
97 | 
98 | def init_model(model_name='fcanet', backbone='resnet', pretrained_file=None, if_cuda=True):
99 |     print('Backbone is {}'.format(backbone))
100 |     if model_name == 'fcanet':
101 |         model = FCANet(backbone=backbone)
102 |     if if_cuda: model = model.cuda()
103 |     model.eval()
104 |     if pretrained_file is not None:
105 |         if if_cuda:
106 |             state_dict = torch.load(pretrained_file)
107 |         else:
108 |             state_dict = torch.load(pretrained_file, map_location=lambda storage, loc: storage)
109 |         model.load_state_dict(state_dict)
110 |         print('load from [{}]!'.format(pretrained_file))
111 |     return model
112 | 
113 | def predict(model, img, seq_points, if_sis=False, if_cuda=True):
114 |     h, w, _ = img.shape
115 |     sample = {}
116 |     sample['img'] = img.copy()
117 |     sample['gt'] = (np.ones((h, w))*255).astype(np.uint8)
118 |     sample['pos_points_mask'] = get_points_mask((w, h), seq_points[seq_points[:, 2] == 1, :2])
119 |     sample['neg_points_mask'] = get_points_mask((w, h), seq_points[seq_points[:, 2] == 0, :2])
120 |     sample['first_point_mask'] = get_points_mask((w, h), seq_points[0:1, :2])
121 |     Resize((int(w*512/min(h, w)), int(h*512/min(h, w))))(sample)  # scale so the short side is 512
122 |     CatPointMask(mode='DISTANCE_POINT_MASK_SRC', if_repair=False)(sample)
123 |     sample['pos_mask_dist_first'] = np.minimum(distance_transform_edt(1-sample['first_point_mask']), 255.0)*255.0
124 |     ToTensor()(sample)
125 |     input = [sample['img'].unsqueeze(0), sample['pos_mask_dist_src'].unsqueeze(0), sample['neg_mask_dist_src'].unsqueeze(0), sample['pos_mask_dist_first'].unsqueeze(0)]
126 |     if if_cuda:
127 |         for i in range(len(input)):
128 |             input[i] = input[i].cuda()
129 |     with torch.no_grad():
130 |         output = model(input)
131 |     result = torch.sigmoid(output.data.cpu()).numpy()[0, 0, :, :]
132 |     result = cv2.resize(result, (w, h), interpolation=cv2.INTER_LINEAR)
133 |     pred = (result > 0.5).astype(np.uint8)
134 |     if if_sis: pred = structural_integrity_strategy(pred, get_points_mask((w, h), seq_points[seq_points[:, 2] == 1, :2]))
135 |     return pred
136 | 
--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frazerlin/fcanet/e7de0edca91ad73766a6b30b24cd4ad6a3ec53fd/dataset/.gitkeep
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | 
2 | import argparse
3 | import numpy as np
4 | from PIL import Image
5 | from tqdm import tqdm
6 | from pathlib import Path
7 | from scipy.ndimage import distance_transform_edt  # scipy.ndimage.morphology is deprecated in recent SciPy
8 | from core import init_model, predict
9 | 
10 | ########################################[ Dataset ]########################################
11 | # generic dataset format: an image folder plus a ground-truth mask folder with matching file stems
12 | class Dataset():
13 |     def __init__(self, dataset_path, img_folder='img', gt_folder='gt', threshold=128, ignore_label=None):
14 |         self.index, self.threshold, self.ignore_label = 0, threshold, ignore_label
15 |         dataset_path = Path(dataset_path)
16 |         self.img_files = sorted((dataset_path/img_folder).glob('*.*'))
17 |         self.gt_files = [next((dataset_path/gt_folder).glob(t.stem+'.*')) for t in self.img_files]
18 |     def __iter__(self):
19 |         return self
20 |     def __len__(self):
21 |         return len(self.img_files)
22 |     def __next__(self):
23 |         if self.index > len(self) - 1: raise StopIteration
24 |         img_src = np.array(Image.open(self.img_files[self.index]))
25 |         gt_src = np.array(Image.open(self.gt_files[self.index]))
26 |         gt = gt_src[:, :, 0] if gt_src.ndim == 3 else gt_src
27 |         gt = np.uint8(gt >= self.threshold)
28 |         if self.ignore_label is not None: gt[gt_src == self.ignore_label] = 255
29 |         self.index += 1
30 |         return img_src, gt
31 | 
32 | # special handling for PASCAL VOC2012: one sample per annotated object instance in the val split
33 | class VOC2012():
34 |     def __init__(self, dataset_path):
35 |         self.index = 0
36 |         dataset_path = Path(dataset_path)
37 |         with open(dataset_path/'ImageSets'/'Segmentation'/'val.txt') as f:
38 |             val_ids = sorted(f.read().splitlines())
39 | 
40 |         self.img_files, self.gt_files, self.instance_indices = [], [], []
41 |         print('Preprocessing!')
42 |         for val_id in tqdm(val_ids):
43 |             gt_ins_set = sorted(set(np.array(Image.open(dataset_path/'SegmentationObject'/(val_id+'.png'))).flat))
44 |             for instance_index in gt_ins_set:
45 |                 if instance_index not in [0, 255]:
46 |                     self.img_files.append(dataset_path/'JPEGImages'/(val_id+'.jpg'))
47 |                     self.gt_files.append(dataset_path/'SegmentationObject'/(val_id+'.png'))
48 |                     self.instance_indices.append(instance_index)
49 |     def __iter__(self):
50 |         return self
51 |     def __len__(self):
52 |         return len(self.img_files)
53 |     def __next__(self):
54 |         if self.index > len(self) - 1: raise StopIteration
55 |         img_src = np.array(Image.open(self.img_files[self.index]))
56 |         gt_src = np.array(Image.open(self.gt_files[self.index]))
57 |         gt = np.uint8(gt_src == self.instance_indices[self.index])
58 |         gt[gt_src == 255] = 255
59 |         self.index += 1
60 |         return img_src, gt
61 | 
62 | 
63 | ########################################[ Evaluation ]########################################
64 | # robot-user strategy: place the next click at the center of the largest misclassified region
65 | def get_next_anno_point(pred, gt, seq_points):
66 |     fndist_map = distance_transform_edt(np.pad((gt == 1) & (pred == 0), ((1, 1), (1, 1)), 'constant'))[1:-1, 1:-1]
67 |     fpdist_map = distance_transform_edt(np.pad((gt == 0) & (pred == 1), ((1, 1), (1, 1)), 'constant'))[1:-1, 1:-1]
68 |     fndist_map[seq_points[:, 1], seq_points[:, 0]], fpdist_map[seq_points[:, 1], seq_points[:, 0]] = 0, 0
69 |     [usr_map, if_pos] = [fndist_map, 1] if fndist_map.max() > fpdist_map.max() else [fpdist_map, 0]
70 |     [y_mlist, x_mlist] = np.where(usr_map == usr_map.max())
71 |     pt_next = (x_mlist[0], y_mlist[0], if_pos)
72 |     return pt_next
73 | 
74 | datasets_kwargs = {
75 |     'GrabCut' : {'dataset_path':'dataset/GrabCut' ,'img_folder':'data_GT','gt_folder':'boundary_GT','threshold':128,'ignore_label':128 },
76 |     'Berkeley': {'dataset_path':'dataset/Berkeley','img_folder':'images' ,'gt_folder':'masks'      ,'threshold':128,'ignore_label':None},
77 |     'DAVIS'   : {'dataset_path':'dataset/DAVIS'   ,'img_folder':'img'    ,'gt_folder':'gt'         ,'threshold':0.5,'ignore_label':None},
78 |     'VOC2012' : {'dataset_path':'dataset/VOC2012'},
79 | }
80 | default_miou_targets = {'GrabCut':0.90, 'Berkeley':0.90, 'DAVIS':0.90, 'VOC2012':0.85}
81 | 
82 | def eval_dataset(model, dataset, max_point_num=20, record_point_num=20, if_sis=False, miou_target=None, if_cuda=True):
83 |     global datasets_kwargs, default_miou_targets
84 |     if dataset in datasets_kwargs:
85 |         dataset_iter = VOC2012(**datasets_kwargs[dataset]) if dataset == 'VOC2012' else Dataset(**datasets_kwargs[dataset])
86 |         miou_target = default_miou_targets[dataset] if miou_target is None else miou_target
87 |     else:  # any custom folder under dataset/ with the default img/ and gt/ sub-folders
88 |         dataset_iter = Dataset(dataset_path='dataset/{}'.format(dataset))
89 |         miou_target = 0.85 if miou_target is None else miou_target
90 | 
91 |     NoC, mIoU_NoC = 0, [0]*(record_point_num+1)  # NoC: total number of clicks needed to reach the target mIoU
92 |     for img, gt in tqdm(dataset_iter):
93 |         pred = np.zeros_like(gt)
94 |         seq_points = np.empty([0, 3], dtype=np.int64)
95 |         if_get_target = False
96 |         for point_num in range(1, max_point_num+1):
97 |             pt_next = get_next_anno_point(pred, gt, seq_points)
98 |             seq_points = np.append(seq_points, [pt_next], axis=0)
99 |             pred = predict(model, img, seq_points, if_sis=if_sis, if_cuda=if_cuda)
100 |             miou = ((pred == 1) & (gt == 1)).sum() / (((pred == 1) | (gt == 1)) & (gt != 255)).sum()
101 |             if point_num <= record_point_num:
102 |                 mIoU_NoC[point_num] += miou
103 |             if (not if_get_target) and (miou >= miou_target or point_num == max_point_num):
104 |                 NoC += point_num
105 |                 if_get_target = True
106 |             if if_get_target and point_num >= record_point_num: break
107 | 
108 |     print('dataset: [{}] {}:'.format(dataset, '(SIS)' if if_sis else ' '))
109 |     print('--> mNoC : {}'.format(NoC/len(dataset_iter)))
110 |     print('--> mIoU-NoC : {}\n\n'.format(np.array([round(i/len(dataset_iter), 3) for i in mIoU_NoC])))
111 | 
112 | if __name__ == "__main__":
113 | 
114 |     parser = argparse.ArgumentParser(description="Evaluation for FCANet")
115 |     parser.add_argument('--backbone', type=str, default='resnet', choices=['resnet', 'res2net'], help='backbone name (default: resnet)')
116 |     parser.add_argument('--dataset', type=str, default='VOC2012', help='evaluation dataset (default: VOC2012)')
117 |     parser.add_argument('--sis', action='store_true', default=False, help='use sis')
118 |     parser.add_argument('--miou', type=float, default=-1.0, help='miou_target (default: -1.0 [means automatic selection])')
119 |     parser.add_argument('--cpu', action='store_true', default=False, help='use cpu (not recommended)')
120 |     args = parser.parse_args()
121 | 
122 |     if Path('dataset/{}'.format(args.dataset)).exists():
123 |         model = init_model('fcanet', args.backbone, './pretrained_model/fcanet-{}.pth'.format(args.backbone), if_cuda=not args.cpu)
124 |         eval_dataset(model, args.dataset, if_sis=args.sis, miou_target=(None if args.miou < 0 else args.miou), if_cuda=not args.cpu)
125 |     else:
126 |         print('folder [dataset/{}] not found'.format(args.dataset))
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/frazerlin/fcanet/e7de0edca91ad73766a6b30b24cd4ad6a3ec53fd/model/__init__.py
--------------------------------------------------------------------------------
/model/fcanet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from . import resnet_split
5 | from . import res2net_split
6 | 
7 | ########################################[ Global ]########################################
8 | 
9 | def init_weight(model):
10 |     for m in model.modules():
11 |         if isinstance(m, nn.Conv2d):
12 |             torch.nn.init.kaiming_normal_(m.weight)
13 |         elif isinstance(m, nn.BatchNorm2d):
14 |             m.weight.data.fill_(1)
15 |             m.bias.data.zero_()
16 | 
17 | def get_mask_gauss(mask_dist_src, sigma):
18 |     return torch.exp(-2.772588722*(mask_dist_src**2)/(sigma**2))  # 2.772588722 = 4*ln(2), so the map falls to 0.5 at distance sigma/2
19 | 
20 | ########################################[ MultiConv ]########################################
21 | 
22 | class MultiConv(nn.Module):
23 |     def __init__(self, in_ch, channels, kernel_sizes=None, strides=None, dilations=None, paddings=None, BatchNorm=nn.BatchNorm2d):
24 |         super(MultiConv, self).__init__()
25 |         self.num = len(channels)
26 |         if kernel_sizes is None: kernel_sizes = [3 for c in channels]
27 |         if strides is None: strides = [1 for c in channels]
28 |         if dilations is None: dilations = [1 for c in channels]
29 |         if paddings is None: paddings = [((kernel_sizes[i]//2) if dilations[i]==1 else (kernel_sizes[i]//2 * dilations[i])) for i in range(self.num)]
30 |         convs_tmp = []
31 |         for i in range(self.num):
32 |             if channels[i] == 1:  # a single-channel output layer gets no BN/ReLU
33 |                 convs_tmp.append(nn.Conv2d(in_ch if i==0 else channels[i-1], channels[i], kernel_size=kernel_sizes[i], stride=strides[i], padding=paddings[i], dilation=dilations[i]))
34 |             else:
35 |                 convs_tmp.append(nn.Sequential(nn.Conv2d(in_ch if i==0 else channels[i-1], channels[i], kernel_size=kernel_sizes[i], stride=strides[i], padding=paddings[i], dilation=dilations[i], bias=False), BatchNorm(channels[i]), nn.ReLU()))
36 |         self.convs = nn.Sequential(*convs_tmp)
37 |         init_weight(self)
38 |     def forward(self, x):
39 |         return self.convs(x)
40 | 
41 | ########################################[ MyASPP ]########################################
42 | 
43 | class _ASPPModule(nn.Module):
44 |     def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
45 |         super(_ASPPModule, self).__init__()
46 |         self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
47 |         self.bn = BatchNorm(planes)
48 |         self.relu = nn.ReLU()
49 |         init_weight(self)
50 | 
51 |     def forward(self, x):
52 |         x = self.relu(self.bn(self.atrous_conv(x)))
53 |         return x
54 | 
55 | class MyASPP(nn.Module):
56 |     def __init__(self, in_ch, out_ch, dilations, BatchNorm, if_global=True):
57 |         super(MyASPP, self).__init__()
58 |         self.if_global = if_global
59 | 
60 |         self.aspp1 = _ASPPModule(in_ch, out_ch, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
61 |         self.aspp2 = _ASPPModule(in_ch, out_ch, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
62 |         self.aspp3 = _ASPPModule(in_ch, out_ch, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
63 |         self.aspp4 = _ASPPModule(in_ch, out_ch, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)
64 | 
65 |         if if_global:
66 |             self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)),
67 |                                                  nn.Conv2d(in_ch, out_ch, 1, stride=1, bias=False),
68 |                                                  BatchNorm(out_ch),
69 |                                                  nn.ReLU())
70 | 
71 |         merge_channel = out_ch*5 if if_global else out_ch*4
72 | 
73 |         self.conv1 = nn.Conv2d(merge_channel, out_ch, 1, bias=False)
74 |         self.bn1 = BatchNorm(out_ch)
75 |         self.relu = nn.ReLU()
76 |         init_weight(self)
77 | 
78 |     def forward(self, x):
79 |         x1 = self.aspp1(x)
80 |         x2 = self.aspp2(x)
81 |         x3 = self.aspp3(x)
82 |         x4 = self.aspp4(x)
83 | 
84 |         if self.if_global:
85 |             x5 = self.global_avg_pool(x)
86 |             x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
87 |             x = torch.cat((x1, x2, x3, x4, x5), dim=1)
88 |         else:
89 |             x = torch.cat((x1, x2, x3, x4), dim=1)
90 | 
91 |         x = self.conv1(x)
92 |         x = self.bn1(x)
93 |         x = self.relu(x)
94 |         return x
95 | 
96 | ########################################[ MyDecoder ]########################################
97 | 
98 | class MyDecoder(nn.Module):
99 |     def __init__(self, in_ch, in_ch_reduce, side_ch, side_ch_reduce, out_ch, BatchNorm, size_ref='side'):
100 |         super(MyDecoder, self).__init__()
101 |         self.size_ref = size_ref
102 |         self.relu = nn.ReLU()
103 |         self.in_ch_reduce, self.side_ch_reduce = in_ch_reduce, side_ch_reduce
104 | 
105 |         if in_ch_reduce is not None:
106 |             self.in_conv = nn.Sequential(nn.Conv2d(in_ch, in_ch_reduce, 1, bias=False), BatchNorm(in_ch_reduce), nn.ReLU())
107 |         if side_ch_reduce is not None:
108 |             self.side_conv = nn.Sequential(nn.Conv2d(side_ch, side_ch_reduce, 1, bias=False), BatchNorm(side_ch_reduce), nn.ReLU())
109 | 
110 |         merge_ch = (in_ch_reduce if in_ch_reduce is not None else in_ch) + (side_ch_reduce if side_ch_reduce is not None else side_ch)
111 | 
112 |         self.merge_conv = nn.Sequential(nn.Conv2d(merge_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
113 |                                         BatchNorm(out_ch),
114 |                                         nn.ReLU(),
115 |                                         nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False),
116 |                                         BatchNorm(out_ch),
117 |                                         nn.ReLU())
118 |         init_weight(self)
119 | 
120 |     def forward(self, input, side):
121 |         if self.in_ch_reduce is not None:
122 |             input = self.in_conv(input)
123 |         if self.side_ch_reduce is not None:
124 |             side = self.side_conv(side)
125 | 
126 |         if self.size_ref == 'side':
127 |             input = F.interpolate(input, size=side.size()[2:], mode='bilinear', align_corners=True)
128 |         elif self.size_ref == 'input':
129 |             side = F.interpolate(side, size=input.size()[2:], mode='bilinear', align_corners=True)
130 | 
131 |         merge = torch.cat((input, side), dim=1)
132 |         output = self.merge_conv(merge)
133 |         return output
134 | 
135 | ########################################[ PredDecoder ]########################################
136 | 
137 | class PredDecoder(nn.Module):
138 |     def __init__(self, in_ch, BatchNorm, if_sigmoid=False):
139 |         super(PredDecoder, self).__init__()
140 |         self.if_sigmoid = if_sigmoid
141 |         self.pred_conv = nn.Sequential(nn.Conv2d(in_ch, in_ch//2, kernel_size=3, stride=1, padding=1, bias=False),
142 |                                        BatchNorm(in_ch//2),
143 |                                        nn.ReLU(),
144 |                                        nn.Conv2d(in_ch//2, in_ch//2, kernel_size=3, stride=1, padding=1, bias=False),
145 |                                        BatchNorm(in_ch//2),
146 |                                        nn.ReLU(),
147 |                                        nn.Conv2d(in_ch//2, 1, kernel_size=1, stride=1))
148 |         init_weight(self)
149 |     def forward(self, input):
150 |         output = self.pred_conv(input)
151 |         if self.if_sigmoid:
152 |             output = torch.sigmoid(output)
153 |         return output
154 | 
155 | ########################################[ Net ]########################################
156 | 
157 | class FCANet(nn.Module):
158 |     def __init__(self, backbone='resnet', BatchNorm=nn.BatchNorm2d):
159 |         super(FCANet, self).__init__()
160 |         if backbone == 'resnet':
161 |             self.backbone_pre = resnet_split.ResNetSplit101(16, BatchNorm, False, [0, 1], 2)
162 |             self.backbone_last = resnet_split.ResNetSplit101(16, BatchNorm, False, [2, 4], 0)
163 |         elif backbone == 'res2net':
164 |             self.backbone_pre = res2net_split.Res2NetSplit101(16, BatchNorm, False, [0, 1], 2)
165 |             self.backbone_last = res2net_split.Res2NetSplit101(16, BatchNorm, False, [2, 4], 0)
166 | 
167 |         self.my_aspp = MyASPP(in_ch=2048+512, out_ch=256, dilations=[1, 6, 12, 18], BatchNorm=BatchNorm, if_global=True)
168 |         self.my_decoder = MyDecoder(in_ch=256, in_ch_reduce=None, side_ch=256, side_ch_reduce=48, out_ch=256, BatchNorm=BatchNorm)
169 |         self.pred_decoder = PredDecoder(in_ch=256, BatchNorm=BatchNorm)
170 |         self.first_conv = MultiConv(257, [256, 256, 256, 512, 512, 512], [3, 3, 3, 3, 3, 3], [2, 1, 1, 2, 1, 1])
171 |         self.first_pred_decoder = PredDecoder(in_ch=512, BatchNorm=BatchNorm)
172 | 
173 |     def forward(self, input):
174 |         [img, pos_mask_dist_src, neg_mask_dist_src, pos_mask_dist_first] = input
175 |         pos_mask_gauss, neg_mask_gauss = get_mask_gauss(pos_mask_dist_src, 10), get_mask_gauss(neg_mask_dist_src, 10)  # gaussian click maps
176 |         pos_mask_gauss_first = get_mask_gauss(pos_mask_dist_first, 30)  # wider gaussian for the first click
177 |         img_with_anno = torch.cat((img, pos_mask_gauss, neg_mask_gauss), dim=1)  # image + click maps as network input
178 |         l1 = self.backbone_pre(img_with_anno)
179 |         l4 = self.backbone_last(l1)
180 |         l1_first = torch.cat((l1, F.interpolate(pos_mask_gauss_first, size=l1.size()[2:], mode='bilinear', align_corners=True)), dim=1)
181 |         l1_first = self.first_conv(l1_first)  # first-click attention branch on low-level features
182 |         l4 = torch.cat((l1_first, l4), dim=1)
183 |         x = self.my_aspp(l4)
184 |         x = self.my_decoder(x, l1)
185 |         x = self.pred_decoder(x)
186 |         result = F.interpolate(x, size=img.size()[2:], mode='bilinear', align_corners=True)
187 |         return result
188 | 
189 | 
--------------------------------------------------------------------------------
/model/res2net_split.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 | 
6 | class Bottle2neck(nn.Module):
7 |     expansion = 4
8 | 
9 |     def __init__(self, inplanes, planes, stride=1, dilation=1,
10 |                  downsample=None, baseWidth=26, scale=4, stype='normal', BatchNorm=None):
11 |         """ Constructor
12 |         Args:
13 |             inplanes: input channel dimensionality
14 |             planes: output channel dimensionality
15 |             stride: conv stride. Replaces pooling layer.
16 |             downsample: None when stride = 1
17 |             baseWidth: basic width of conv3x3
18 |             scale: number of scales.
19 |             stype: 'normal': normal set. 'stage': first block of a new stage.
20 | """ 21 | super(Bottle2neck, self).__init__() 22 | 23 | width = int(math.floor(planes * (baseWidth/64.0))) 24 | self.conv1 = nn.Conv2d(inplanes, width*scale, kernel_size=1, bias=False) 25 | self.bn1 = BatchNorm(width*scale) 26 | 27 | if scale == 1: 28 | self.nums = 1 29 | else: 30 | self.nums = scale -1 31 | if stype == 'stage' and stride>1: 32 | self.pool = nn.AvgPool2d(kernel_size=3, stride = stride, padding=1) 33 | convs = [] 34 | bns = [] 35 | for i in range(self.nums): 36 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride = stride, dilation=dilation, padding=dilation, bias=False)) 37 | bns.append(BatchNorm(width)) 38 | self.convs = nn.ModuleList(convs) 39 | self.bns = nn.ModuleList(bns) 40 | 41 | self.conv3 = nn.Conv2d(width*scale, planes * self.expansion, kernel_size=1, bias=False) 42 | self.bn3 = BatchNorm(planes * self.expansion) 43 | 44 | self.relu = nn.ReLU(inplace=True) 45 | self.downsample = downsample 46 | self.stype = stype 47 | self.scale = scale 48 | self.width = width 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | residual = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | spx = torch.split(out, self.width, 1) 59 | for i in range(self.nums): 60 | if i==0 or self.stype=='stage': 61 | sp = spx[i] 62 | else: 63 | sp = sp + spx[i] 64 | sp = self.convs[i](sp) 65 | sp = self.relu(self.bns[i](sp)) 66 | if i==0: 67 | out = sp 68 | else: 69 | out = torch.cat((out, sp), 1) 70 | if self.scale != 1 and self.stype=='normal': 71 | out = torch.cat((out, spx[self.nums]),1) 72 | elif self.stride==1 and self.scale != 1 and self.stype=='stage': 73 | out = torch.cat((out, spx[self.nums]),1) 74 | elif self.scale != 1 and self.stype=='stage': 75 | out = torch.cat((out, self.pool(spx[self.nums])),1) 76 | 77 | out = self.conv3(out) 78 | out = self.bn3(out) 79 | 80 | if self.downsample is not None: 81 | residual = self.downsample(x) 82 | 83 | out += residual 84 | out = self.relu(out) 85 | 86 | return out 87 | 88 | class Res2NetSplit(nn.Module): 89 | 90 | def __init__(self, block, layers, output_stride, BatchNorm, layer_range=[0,4], aux_channel=0, baseWidth = 26, scale = 4): 91 | super(Res2NetSplit, self).__init__() 92 | self.baseWidth = baseWidth 93 | self.scale = scale 94 | blocks = [1, 2, 4] 95 | if output_stride == 16: 96 | strides = [1, 2, 2, 1] 97 | dilations = [1, 1, 1, 2] 98 | elif output_stride == 8: 99 | strides = [1, 2, 1, 1] 100 | dilations = [1, 1, 2, 4] 101 | else: 102 | raise NotImplementedError 103 | 104 | self.layer_range=layer_range 105 | self.aux_channel = aux_channel 106 | self.ori_inplanes=[3, 64, 64*block.expansion, 128*block.expansion, 256*block.expansion] 107 | self.inplanes = self.ori_inplanes[layer_range[0]]+aux_channel 108 | 109 | # Modules 110 | if layer_range[0]<=0 and 0<=layer_range[1]: 111 | self.conv1 = nn.Sequential( 112 | nn.Conv2d(self.inplanes, 32, 3, 2, 1, bias=False), 113 | BatchNorm(32), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 116 | BatchNorm(32), 117 | nn.ReLU(inplace=True), 118 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 119 | ) 120 | self.bn1 = BatchNorm(64) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.inplanes=64 124 | 125 | if layer_range[0]<=1 and 1<=layer_range[1]: 126 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 127 | 128 | if layer_range[0]<=2 and 2<=layer_range[1]: 129 | self.layer2 = 
self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 130 | 131 | if layer_range[0]<=3 and 3<=layer_range[1]: 132 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 133 | 134 | if layer_range[0]<=4 and 4<=layer_range[1]: 135 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 136 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 137 | 138 | self._init_weight() 139 | 140 | 141 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 142 | downsample = None 143 | if stride != 1 or self.inplanes != planes * block.expansion: 144 | downsample = nn.Sequential( 145 | nn.AvgPool2d(kernel_size=stride, stride=stride, 146 | ceil_mode=True, count_include_pad=False), 147 | nn.Conv2d(self.inplanes, planes * block.expansion, 148 | kernel_size=1, stride=1, bias=False), 149 | BatchNorm(planes * block.expansion), 150 | ) 151 | 152 | layers = [] 153 | layers.append(block(self.inplanes, planes, stride, dilation, downsample=downsample, 154 | BatchNorm=BatchNorm, stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 155 | self.inplanes = planes * block.expansion 156 | for i in range(1, blocks): 157 | layers.append(block(self.inplanes, planes, dilation=dilation, 158 | BatchNorm=BatchNorm, baseWidth=self.baseWidth, scale=self.scale)) 159 | 160 | return nn.Sequential(*layers) 161 | 162 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | nn.AvgPool2d(kernel_size=stride, stride=stride, 167 | ceil_mode=True, count_include_pad=False), 168 | nn.Conv2d(self.inplanes, planes * block.expansion, 169 | kernel_size=1, stride=1, bias=False), 170 | BatchNorm(planes * block.expansion), 171 | ) 172 | 173 | layers = [] 174 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 175 | downsample=downsample, BatchNorm=BatchNorm, stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 176 | self.inplanes = planes * block.expansion 177 | for i in range(1, len(blocks)): 178 | layers.append(block(self.inplanes, planes, stride=1, 179 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm, baseWidth=self.baseWidth, scale=self.scale)) 180 | 181 | return nn.Sequential(*layers) 182 | 183 | def forward(self, input, if_final=True): 184 | x=input 185 | results=[] 186 | if self.layer_range[0]<=0 and 0<=self.layer_range[1]: 187 | x = self.conv1(x) 188 | x = self.bn1(x) 189 | x = self.relu(x) 190 | x = self.maxpool(x) 191 | l0=x 192 | results.append(l0) 193 | 194 | if self.layer_range[0]<=1 and 1<=self.layer_range[1]: 195 | x = self.layer1(x) 196 | l1=x 197 | results.append(l1) 198 | 199 | 200 | if self.layer_range[0]<=2 and 2<=self.layer_range[1]: 201 | x = self.layer2(x) 202 | l2=x 203 | results.append(l2) 204 | 205 | 206 | if self.layer_range[0]<=3 and 3<=self.layer_range[1]: 207 | x = self.layer3(x) 208 | l3=x 209 | results.append(l3) 210 | 211 | if self.layer_range[0]<=4 and 4<=self.layer_range[1]: 212 | x = self.layer4(x) 213 | l4=x 214 | results.append(l4) 215 | 216 | if if_final: 217 | return results[-1] 218 | else: 219 | return results 220 | 221 | def _init_weight(self): 222 | for m in self.modules(): 223 | if isinstance(m, nn.Conv2d): 224 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 225 | m.weight.data.normal_(0, math.sqrt(2. / n)) 226 | 227 | elif isinstance(m, nn.BatchNorm2d): 228 | m.weight.data.fill_(1) 229 | m.bias.data.zero_() 230 | 231 | def _load_pretrained_model(self, pretrain_url): 232 | pretrain_dict = model_zoo.load_url(pretrain_url) 233 | model_dict = {} 234 | state_dict = self.state_dict() 235 | first_weight_names=['conv1.0.weight','layer1.0.conv1.weight','layer2.0.conv1.weight','layer3.0.conv1.weight','layer4.0.conv1.weight'] 236 | ds_weight_names=['None','layer1.0.downsample.1.weight','layer2.0.downsample.1.weight','layer3.0.downsample.1.weight','layer4.0.downsample.1.weight'] 237 | for k, v in pretrain_dict.items(): 238 | if self.aux_channel>0 and (k==first_weight_names[self.layer_range[0]] or k==ds_weight_names[self.layer_range[0]]): 239 | v_tmp=state_dict[k].clone() 240 | v_tmp[:,:self.ori_inplanes[self.layer_range[0]],:,:]=v.clone() 241 | model_dict[k] = v_tmp 242 | continue 243 | if k in state_dict: 244 | model_dict[k] = v 245 | state_dict.update(model_dict) 246 | self.load_state_dict(state_dict) 247 | 248 | 249 | def Res2NetSplit50(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 250 | model = Res2NetSplit(Bottle2neck, [3, 4, 6, 3], output_stride, BatchNorm, layer_range=layer_range, aux_channel=aux_channel) 251 | if pretrained: 252 | model._load_pretrained_model('https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth') 253 | return model 254 | 255 | def Res2NetSplit101(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 256 | model = Res2NetSplit(Bottle2neck, [3, 4, 23, 3], output_stride, BatchNorm, layer_range=layer_range, aux_channel=aux_channel) 257 | if pretrained: 258 | model._load_pretrained_model('https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth') 259 | return model 260 | 261 | 262 | -------------------------------------------------------------------------------- /model/resnet_split.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | class BasicBlock(nn.Module): 6 | expansion = 1 7 | 8 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 9 | super(BasicBlock, self).__init__() 10 | 11 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = BatchNorm(planes) 13 | self.relu = nn.ReLU(inplace=True) 14 | 15 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, dilation=dilation, padding=dilation, bias=False) 16 | self.bn2 = BatchNorm(planes) 17 | self.downsample = downsample 18 | self.stride = stride 19 | 20 | def forward(self, x): 21 | residual = x 22 | 23 | out = self.conv1(x) 24 | out = self.bn1(out) 25 | out = self.relu(out) 26 | 27 | out = self.conv2(out) 28 | out = self.bn2(out) 29 | 30 | if self.downsample is not None: 31 | residual = self.downsample(x) 32 | 33 | out += residual 34 | out = self.relu(out) 35 | 36 | return out 37 | 38 | 39 | class Bottleneck(nn.Module): 40 | expansion = 4 41 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 42 | super(Bottleneck, self).__init__() 43 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 44 | self.bn1 = BatchNorm(planes) 45 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 46 | 
dilation=dilation, padding=dilation, bias=False) 47 | self.bn2 = BatchNorm(planes) 48 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 49 | self.bn3 = BatchNorm(planes * 4) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.downsample = downsample 52 | self.stride = stride 53 | self.dilation = dilation 54 | def forward(self, x): 55 | residual = x 56 | 57 | out = self.conv1(x) 58 | out = self.bn1(out) 59 | out = self.relu(out) 60 | 61 | out = self.conv2(out) 62 | out = self.bn2(out) 63 | out = self.relu(out) 64 | 65 | out = self.conv3(out) 66 | out = self.bn3(out) 67 | 68 | if self.downsample is not None: 69 | residual = self.downsample(x) 70 | 71 | out += residual 72 | out = self.relu(out) 73 | 74 | return out 75 | 76 | 77 | class ResNetSplit(nn.Module): 78 | 79 | def __init__(self, block, layers, output_stride, BatchNorm,layer_range=[0,4],aux_channel=0): 80 | super(ResNetSplit, self).__init__() 81 | blocks = [1, 2, 4] 82 | if output_stride == 16: 83 | strides = [1, 2, 2, 1] 84 | dilations = [1, 1, 1, 2] 85 | elif output_stride == 8: 86 | strides = [1, 2, 1, 1] 87 | dilations = [1, 1, 2, 4] 88 | else: 89 | raise NotImplementedError 90 | 91 | self.layer_range=layer_range 92 | self.aux_channel = aux_channel 93 | self.ori_inplanes=[3, 64, 64*block.expansion, 128*block.expansion, 256*block.expansion] 94 | self.inplanes = self.ori_inplanes[layer_range[0]]+aux_channel 95 | 96 | # Modules 97 | if layer_range[0]<=0 and 0<=layer_range[1]: 98 | self.conv1 = nn.Conv2d(self.inplanes, 64, kernel_size=7, stride=2, padding=3, bias=False) 99 | self.bn1 = BatchNorm(64) 100 | self.relu = nn.ReLU(inplace=True) 101 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 102 | self.inplanes=64 103 | 104 | if layer_range[0]<=1 and 1<=layer_range[1]: 105 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 106 | 107 | if layer_range[0]<=2 and 2<=layer_range[1]: 108 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 109 | 110 | if layer_range[0]<=3 and 3<=layer_range[1]: 111 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 112 | 113 | if layer_range[0]<=4 and 4<=layer_range[1]: 114 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 115 | # self.layer4 = self._make_layer(block, 512, layers[3], stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 116 | 117 | self._init_weight() 118 | 119 | 120 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 121 | downsample = None 122 | if stride != 1 or self.inplanes != planes * block.expansion: 123 | downsample = nn.Sequential( 124 | nn.Conv2d(self.inplanes, planes * block.expansion, 125 | kernel_size=1, stride=stride, bias=False), 126 | BatchNorm(planes * block.expansion), 127 | ) 128 | 129 | layers = [] 130 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 131 | self.inplanes = planes * block.expansion 132 | for i in range(1, blocks): 133 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 134 | 135 | return nn.Sequential(*layers) 136 | 137 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | 
nn.Conv2d(self.inplanes, planes * block.expansion, 142 | kernel_size=1, stride=stride, bias=False), 143 | BatchNorm(planes * block.expansion), 144 | ) 145 | 146 | layers = [] 147 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 148 | downsample=downsample, BatchNorm=BatchNorm)) 149 | self.inplanes = planes * block.expansion 150 | for i in range(1, len(blocks)): 151 | layers.append(block(self.inplanes, planes, stride=1, 152 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 153 | 154 | return nn.Sequential(*layers) 155 | 156 | def forward(self, input, if_final=True): 157 | x=input 158 | results=[] 159 | if self.layer_range[0]<=0 and 0<=self.layer_range[1]: 160 | x = self.conv1(x) 161 | x = self.bn1(x) 162 | x = self.relu(x) 163 | x = self.maxpool(x) 164 | l0=x 165 | results.append(l0) 166 | 167 | if self.layer_range[0]<=1 and 1<=self.layer_range[1]: 168 | x = self.layer1(x) 169 | l1=x 170 | results.append(l1) 171 | 172 | if self.layer_range[0]<=2 and 2<=self.layer_range[1]: 173 | x = self.layer2(x) 174 | l2=x 175 | results.append(l2) 176 | 177 | if self.layer_range[0]<=3 and 3<=self.layer_range[1]: 178 | x = self.layer3(x) 179 | l3=x 180 | results.append(l3) 181 | 182 | if self.layer_range[0]<=4 and 4<=self.layer_range[1]: 183 | x = self.layer4(x) 184 | l4=x 185 | results.append(l4) 186 | 187 | if if_final: 188 | return results[-1] 189 | else: 190 | return results 191 | 192 | def _init_weight(self): 193 | for m in self.modules(): 194 | if isinstance(m, nn.Conv2d): 195 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 196 | m.weight.data.normal_(0, math.sqrt(2. / n)) 197 | elif isinstance(m, nn.BatchNorm2d): 198 | m.weight.data.fill_(1) 199 | m.bias.data.zero_() 200 | 201 | def _load_pretrained_model(self, pretrain_url): 202 | pretrain_dict = model_zoo.load_url(pretrain_url) 203 | model_dict = {} 204 | state_dict = self.state_dict() 205 | first_weight_names=['conv1.weight','layer1.0.conv1.weight','layer2.0.conv1.weight','layer3.0.conv1.weight','layer4.0.conv1.weight'] 206 | ds_weight_names=['None','layer1.0.downsample.0.weight','layer2.0.downsample.0.weight','layer3.0.downsample.0.weight','layer4.0.downsample.0.weight'] 207 | for k, v in pretrain_dict.items(): 208 | if self.aux_channel>0 and (k==first_weight_names[self.layer_range[0]] or k==ds_weight_names[self.layer_range[0]]): 209 | v_tmp=state_dict[k].clone() 210 | v_tmp[:,:self.ori_inplanes[self.layer_range[0]],:,:]=v.clone() 211 | model_dict[k] = v_tmp 212 | continue 213 | if k in state_dict: 214 | model_dict[k] = v 215 | state_dict.update(model_dict) 216 | self.load_state_dict(state_dict) 217 | 218 | def ResNetSplit18(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 219 | model = ResNetSplit(BasicBlock, [2, 2, 2, 2], output_stride, BatchNorm,layer_range=layer_range,aux_channel=aux_channel) 220 | if pretrained: 221 | model._load_pretrained_model('https://download.pytorch.org/models/resnet18-5c106cde.pth') 222 | return model 223 | 224 | def ResNetSplit34(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 225 | model = ResNetSplit(BasicBlock, [3, 4, 6, 3], output_stride, BatchNorm,layer_range=layer_range, aux_channel=aux_channel) 226 | if pretrained: 227 | model._load_pretrained_model('https://download.pytorch.org/models/resnet34-333f7ec4.pth') 228 | return model 229 | 230 | def ResNetSplit50(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 231 | model = ResNetSplit(Bottleneck, [3, 4, 6, 3], output_stride, 
BatchNorm, layer_range=layer_range, aux_channel=aux_channel) 232 | if pretrained: 233 | model._load_pretrained_model('https://download.pytorch.org/models/resnet50-19c8e357.pth') 234 | return model 235 | 236 | def ResNetSplit101(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 237 | model = ResNetSplit(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, layer_range=layer_range, aux_channel=aux_channel) 238 | if pretrained: 239 | model._load_pretrained_model('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 240 | return model 241 | 242 | def ResNetSplit152(output_stride, BatchNorm, pretrained=False,layer_range=[0,4], aux_channel=0): 243 | model = ResNetSplit(Bottleneck, [3, 8, 36, 3], output_stride, BatchNorm, layer_range=layer_range, aux_channel=aux_channel) 244 | if pretrained: 245 | model._load_pretrained_model('https://download.pytorch.org/models/resnet152-b121ed2d.pth') 246 | return model 247 | 248 | 249 | 250 | -------------------------------------------------------------------------------- /pretrained_model/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frazerlin/fcanet/e7de0edca91ad73766a6b30b24cd4ad6a3ec53fd/pretrained_model/.gitkeep -------------------------------------------------------------------------------- /test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/frazerlin/fcanet/e7de0edca91ad73766a6b30b24cd4ad6a3ec53fd/test.jpg --------------------------------------------------------------------------------