├── imgs ├── PraNet-Award.png ├── framework-final-min.png ├── qualitative_results.png ├── quantiative_results_1.png └── quantiative_results_2.png ├── utils ├── __pycache__ │ ├── utils.cpython-36.pyc │ └── dataloader.cpython-36.pyc ├── format_conversion.py ├── utils.py └── dataloader.py ├── eval ├── CalMAE.m ├── README.md ├── Fmeasure_calu.m ├── StructureMeasure.m ├── original_WFb.m ├── S_object.m ├── Enhancedmeasure.m ├── S_region.m └── main.m ├── jittor ├── MyTest.py ├── utils │ └── dataloader.py ├── README.md └── lib │ ├── Res2Net_v1b.py │ └── PraNet_Res2Net.py ├── MyTest.py ├── lib ├── ResNet.py ├── PraNet_Res2Net.py ├── Res2Net_v1b.py └── PraNet_ResNet.py ├── MyTrain.py └── README.md /imgs/PraNet-Award.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/imgs/PraNet-Award.png -------------------------------------------------------------------------------- /imgs/framework-final-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/imgs/framework-final-min.png -------------------------------------------------------------------------------- /imgs/qualitative_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/imgs/qualitative_results.png -------------------------------------------------------------------------------- /imgs/quantiative_results_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/imgs/quantiative_results_1.png -------------------------------------------------------------------------------- /imgs/quantiative_results_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/imgs/quantiative_results_2.png -------------------------------------------------------------------------------- /utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataloader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DengPingFan/PraNet/HEAD/utils/__pycache__/dataloader.cpython-36.pyc -------------------------------------------------------------------------------- /eval/CalMAE.m: -------------------------------------------------------------------------------- 1 | function mae = CalMAE(smap, gtImg) 2 | % Code Author: Wangjiang Zhu 3 | % Email: wangjiang88119@gmail.com 4 | % Date: 3/24/2014 5 | if size(smap, 1) ~= size(gtImg, 1) || size(smap, 2) ~= size(gtImg, 2) 6 | error('Saliency map and gt Image have different sizes!\n'); 7 | end 8 | 9 | if ~islogical(gtImg) 10 | gtImg = gtImg(:,:,1) > 128; 11 | end 12 | 13 | smap = im2double(smap(:,:,1)); 14 | fgPixels = smap(gtImg); 15 | fgErrSum = length(fgPixels) - sum(fgPixels); 16 | bgErrSum = sum(smap(~gtImg)); 17 | mae = (fgErrSum + bgErrSum) / numel(gtImg); -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation code 
for “PraNet: Parallel Reverse Attention Network for Polyp Segmentation” (MICCAI-2020) 2 | 3 | > Author: Deng-Ping Fan, Ge-Peng Ji, Tao Zhou, Geng Chen, Huazhu Fu, Jianbing Shen, and Ling Shao 4 | 5 | Homepage: http://dpfan.net/ 6 | 7 | Project Page: https://github.com/DengPingFan/PraNet 8 | 9 | Version: 2020-6-29 10 | 11 | Any questions please contact with dengpfan@gmail.com 12 | 13 | If you find this project is useful, please cite our work. Thanks. 14 | 15 | Title: “PraNet: Parallel Reverse Attention Network for Polyp Segmentation” MICCAI 2020 16 | 17 | ## Usage 18 | 19 | - The Dataset folder contains the five different test datasets. 20 | 21 | - The results folder consists of several compared models. 22 | 23 | - You can just run the `main.m` in the EvaluationTool folder to get the final evaluation results in the EvaluateResults folder. 24 | -------------------------------------------------------------------------------- /utils/format_conversion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from libtiff import TIFF # pip install libtiff 4 | from scipy import misc 5 | import random 6 | 7 | 8 | def tif2png(_src_path, _dst_path): 9 | """ 10 | Usage: 11 | formatting `tif/tiff` files to `jpg/png` files 12 | :param _src_path: 13 | :param _dst_path: 14 | :return: 15 | """ 16 | tif = TIFF.open(_src_path, mode='r') 17 | image = tif.read_image() 18 | misc.imsave(_dst_path, image) 19 | 20 | 21 | def data_split(src_list): 22 | """ 23 | Usage: 24 | randomly spliting dataset 25 | :param src_list: 26 | :return: 27 | """ 28 | counter_list = random.sample(range(0, len(src_list)), 550) 29 | 30 | return counter_list 31 | 32 | 33 | if __name__ == '__main__': 34 | src_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks_tif' 35 | dst_dir = '../Dataset/train_dataset/CVC-EndoSceneStill/CVC-612/test_split/masks' 36 | 37 | os.makedirs(dst_dir, exist_ok=True) 38 | for img_name in os.listdir(src_dir): 39 | tif2png(os.path.join(src_dir, img_name), 40 | os.path.join(dst_dir, img_name.replace('.tif', '.png'))) 41 | -------------------------------------------------------------------------------- /eval/Fmeasure_calu.m: -------------------------------------------------------------------------------- 1 | %% 2 | function [PreFtem, RecallFtem, SpecifTem, Dice, FmeasureF, IoU] = Fmeasure_calu(sMap, gtMap, gtsize, threshold) 3 | %threshold = 2* mean(sMap(:)) ; 4 | if ( threshold > 1 ) 5 | threshold = 1; 6 | end 7 | 8 | Label3 = zeros( gtsize ); 9 | Label3( sMap>=threshold ) = 1; 10 | 11 | NumRec = length( find( Label3==1 ) ); %FP+TP 12 | NumNoRec = length(find(Label3==0)); % FN+TN 13 | LabelAnd = Label3 & gtMap; 14 | NumAnd = length( find ( LabelAnd==1 ) ); %TP 15 | num_obj = sum(sum(gtMap)); %TP+FN 16 | num_pred = sum(sum(Label3)); % FP+TP 17 | 18 | FN = num_obj-NumAnd; 19 | FP = NumRec-NumAnd; 20 | TN = NumNoRec-FN; 21 | 22 | %SpecifTem = TN/(TN+FP) 23 | %Precision = TP/(TP+FP) 24 | 25 | if NumAnd == 0 26 | PreFtem = 0; 27 | RecallFtem = 0; 28 | FmeasureF = 0; 29 | Dice = 0; 30 | SpecifTem = 0; 31 | IoU = 0; 32 | else 33 | IoU = NumAnd/(FN+NumRec); %TP/(FN+TP+FP) 34 | PreFtem = NumAnd/NumRec; 35 | RecallFtem = NumAnd/num_obj; 36 | SpecifTem = TN/(TN+FP); 37 | Dice = 2 * NumAnd/(num_obj+num_pred); 38 | % FmeasureF = ( ( 1.3* PreFtem * RecallFtem ) / ( .3 * PreFtem + RecallFtem ) ); % beta = 0.3 39 | FmeasureF = (( 2.0 * PreFtem * RecallFtem ) / (PreFtem + RecallFtem)); % beta = 1.0 40 | end 41 | 42 | %Fmeasure = [PreFtem, 
RecallFtem, FmeasureF]; 43 | 44 | -------------------------------------------------------------------------------- /eval/StructureMeasure.m: -------------------------------------------------------------------------------- 1 | function Q = StructureMeasure(prediction,GT) 2 | % StructureMeasure computes the similarity between the foreground map and 3 | % ground truth(as proposed in "Structure-measure: A new way to evaluate 4 | % foreground maps" [Deng-Ping Fan et. al - ICCV 2017]) 5 | % Usage: 6 | % Q = StructureMeasure(prediction,GT) 7 | % Input: 8 | % prediction - Binary/Non binary foreground map with values in the range 9 | % [0 1]. Type: double. 10 | % GT - Binary ground truth. Type: logical. 11 | % Output: 12 | % Q - The computed similarity score 13 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 14 | 15 | % Check input 16 | if (~isa(prediction,'double')) 17 | error('The prediction should be double type...'); 18 | end 19 | if ((max(prediction(:))>1) || min(prediction(:))<0) 20 | error('The prediction should be in the range of [0 1]...'); 21 | end 22 | if (~islogical(GT)) 23 | error('GT should be logical type...'); 24 | end 25 | 26 | y = mean2(GT); 27 | 28 | if (y==0)% if the GT is completely black 29 | x = mean2(prediction); 30 | Q = 1.0 - x; %only calculate the area of intersection 31 | elseif(y==1)%if the GT is completely white 32 | x = mean2(prediction); 33 | Q = x; %only calcualte the area of intersection 34 | else 35 | alpha = 0.5; 36 | Q = alpha*S_object(prediction,GT)+(1-alpha)*S_region(prediction,GT); 37 | if (Q<0) 38 | Q=0; 39 | end 40 | end 41 | 42 | end 43 | -------------------------------------------------------------------------------- /eval/original_WFb.m: -------------------------------------------------------------------------------- 1 | function [Q]= original_WFb(FG,GT) 2 | % WFb Compute the Weighted F-beta measure (as proposed in "How to Evaluate 3 | % Foreground Maps?" [Margolin et. al - CVPR'14]) 4 | % Usage: 5 | % Q = FbW(FG,GT) 6 | % Input: 7 | % FG - Binary/Non binary foreground map with values in the range [0 1]. Type: double. 8 | % GT - Binary ground truth. Type: logical. 9 | % Output: 10 | % Q - The Weighted F-beta score 11 | 12 | %Check input 13 | if (~isa( FG, 'double' )) 14 | error('FG should be of type: double'); 15 | end 16 | if ((max(FG(:))>1) || min(FG(:))<0) 17 | error('FG should be in the range of [0 1]'); 18 | end 19 | if (~islogical(GT)) 20 | error('GT should be of type: logical'); 21 | end 22 | 23 | dGT = double(GT); %Use double for computations. 
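% Outline of the computation below (Margolin et al., CVPR'14):
%   1) build the per-pixel error map E = |FG - GT|;
%   2) for background pixels, substitute the error of the nearest foreground
%      pixel (via the distance-transform indices) so foreground edges are
%      treated consistently, then smooth with a Gaussian to model pixel
%      dependency and keep the smaller of the raw/smoothed error inside GT;
%   3) down-weight background errors by their distance to the foreground and
%      combine the resulting weighted precision/recall into the final score Q.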
24 | 25 | 26 | E = abs(FG-dGT); 27 | % [Ef, Et, Er] = deal(abs(FG-GT)); 28 | 29 | [Dst,IDXT] = bwdist(dGT); 30 | %Pixel dependency 31 | K = fspecial('gaussian',7,5); 32 | Et = E; 33 | Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 34 | EA = imfilter(Et,K); 35 | MIN_E_EA = E; 36 | MIN_E_EA(GT & EA {} - {}'.format(_data_name, name)) 41 | imageio.imwrite((save_path + name[0]), res) -------------------------------------------------------------------------------- /MyTest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os, argparse 5 | from scipy import misc 6 | from lib.PraNet_Res2Net import PraNet 7 | from utils.dataloader import test_dataset 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--testsize', type=int, default=352, help='testing size') 11 | parser.add_argument('--pth_path', type=str, default='./snapshots/PraNet_Res2Net/PraNet-19.pth') 12 | 13 | for _data_name in ['CVC-300', 'CVC-ClinicDB', 'Kvasir', 'CVC-ColonDB', 'ETIS-LaribPolypDB']: 14 | data_path = './data/TestDataset/{}/'.format(_data_name) 15 | save_path = './results/PraNet/{}/'.format(_data_name) 16 | opt = parser.parse_args() 17 | model = PraNet() 18 | model.load_state_dict(torch.load(opt.pth_path)) 19 | model.cuda() 20 | model.eval() 21 | 22 | os.makedirs(save_path, exist_ok=True) 23 | image_root = '{}/images/'.format(data_path) 24 | gt_root = '{}/masks/'.format(data_path) 25 | test_loader = test_dataset(image_root, gt_root, opt.testsize) 26 | 27 | for i in range(test_loader.size): 28 | image, gt, name = test_loader.load_data() 29 | gt = np.asarray(gt, np.float32) 30 | gt /= (gt.max() + 1e-8) 31 | image = image.cuda() 32 | 33 | res5, res4, res3, res2 = model(image) 34 | res = res2 35 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 36 | res = res.sigmoid().data.cpu().numpy().squeeze() 37 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 38 | misc.imsave(save_path+name, res) -------------------------------------------------------------------------------- /eval/S_object.m: -------------------------------------------------------------------------------- 1 | function Q = S_object(prediction,GT) 2 | % S_object Computes the object similarity between foreground maps and ground 3 | % truth(as proposed in "Structure-measure:A new way to evaluate foreground 4 | % maps" [Deng-Ping Fan et. al - ICCV 2017]) 5 | % Usage: 6 | % Q = S_object(prediction,GT) 7 | % Input: 8 | % prediction - Binary/Non binary foreground map with values in the range 9 | % [0 1]. Type: double. 10 | % GT - Binary ground truth. Type: logical. 
11 | % Output: 12 | % Q - The object similarity score 13 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 14 | 15 | % compute the similarity of the foreground in the object level 16 | prediction_fg = prediction; 17 | prediction_fg(~GT)=0; 18 | O_FG = Object(prediction_fg,GT); 19 | 20 | % compute the similarity of the background 21 | prediction_bg = 1.0 - prediction; 22 | prediction_bg(GT) = 0; 23 | O_BG = Object(prediction_bg,~GT); 24 | 25 | % combine the foreground measure and background measure together 26 | u = mean2(GT); 27 | Q = u * O_FG + (1 - u) * O_BG; 28 | 29 | end 30 | 31 | function score = Object(prediction,GT) 32 | 33 | % check the input 34 | if isempty(prediction) 35 | score = 0; 36 | return; 37 | end 38 | if isinteger(prediction) 39 | prediction = double(prediction); 40 | end 41 | if (~isa( prediction, 'double' )) 42 | error('prediction should be of type: double'); 43 | end 44 | if ((max(prediction(:))>1) || min(prediction(:))<0) 45 | error('prediction should be in the range of [0 1]'); 46 | end 47 | if(~islogical(GT)) 48 | error('GT should be of type: logical'); 49 | end 50 | 51 | % compute the mean of the foreground or background in prediction 52 | x = mean2(prediction(GT)); 53 | 54 | % compute the standard deviations of the foreground or background in prediction 55 | sigma_x = std(prediction(GT)); 56 | 57 | score = 2.0 * x./(x^2 + 1.0 + sigma_x + eps); 58 | end -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from thop import profile 4 | from thop import clever_format 5 | 6 | 7 | def clip_gradient(optimizer, grad_clip): 8 | """ 9 | For calibrating misalignment gradient via cliping gradient technique 10 | :param optimizer: 11 | :param grad_clip: 12 | :return: 13 | """ 14 | for group in optimizer.param_groups: 15 | for param in group['params']: 16 | if param.grad is not None: 17 | param.grad.data.clamp_(-grad_clip, grad_clip) 18 | 19 | 20 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 21 | decay = decay_rate ** (epoch // decay_epoch) 22 | for param_group in optimizer.param_groups: 23 | param_group['lr'] *= decay 24 | 25 | 26 | class AvgMeter(object): 27 | def __init__(self, num=40): 28 | self.num = num 29 | self.reset() 30 | 31 | def reset(self): 32 | self.val = 0 33 | self.avg = 0 34 | self.sum = 0 35 | self.count = 0 36 | self.losses = [] 37 | 38 | def update(self, val, n=1): 39 | self.val = val 40 | self.sum += val * n 41 | self.count += n 42 | self.avg = self.sum / self.count 43 | self.losses.append(val) 44 | 45 | def show(self): 46 | return torch.mean(torch.stack(self.losses[np.maximum(len(self.losses)-self.num, 0):])) 47 | 48 | 49 | def CalParams(model, input_tensor): 50 | """ 51 | Usage: 52 | Calculate Params and FLOPs via [THOP](https://github.com/Lyken17/pytorch-OpCounter) 53 | Necessarity: 54 | from thop import profile 55 | from thop import clever_format 56 | :param model: 57 | :param input_tensor: 58 | :return: 59 | """ 60 | flops, params = profile(model, inputs=(input_tensor,)) 61 | flops, params = clever_format([flops, params], "%.3f") 62 | print('[Statistics Information]\nFLOPs: {}\nParams: {}'.format(flops, params)) -------------------------------------------------------------------------------- /eval/Enhancedmeasure.m: -------------------------------------------------------------------------------- 1 | function [score]= 
Emeasure(FM,GT) 2 | % Emeasure Compute the Enhanced Alignment measure (as proposed in "Enhanced-alignment 3 | % Measure for Binary Foreground Map Evaluation" [Deng-Ping Fan et. al - IJCAI'18 oral paper]) 4 | % Usage: 5 | % score = Emeasure(FM,GT) 6 | % Input: 7 | % FM - Binary foreground map. Type: double. 8 | % GT - Binary ground truth. Type: double. 9 | % Output: 10 | % score - The Enhanced alignment score 11 | 12 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%Important Note:%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 13 | %The code is for academic purposes only. Please cite this paper if you make use of it: 14 | 15 | %@conference{Fan2018Enhanced, title={Enhanced-alignment Measure for Binary Foreground Map Evaluation}, 16 | % author={Fan, Deng-Ping and Gong, Cheng and Cao, Yang and Ren, Bo and Cheng, Ming-Ming and Borji, Ali}, 17 | % year = {2018}, 18 | % booktitle = {IJCAI} 19 | % } 20 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 21 | 22 | FM = logical(FM); 23 | GT = logical(GT); 24 | 25 | %Use double for computations. 26 | dFM = double(FM); 27 | dGT = double(GT); 28 | 29 | %Special case: 30 | if (sum(dGT(:))==0)% if the GT is completely black 31 | enhanced_matrix = 1.0 - dFM; %only calculate the black area of intersection 32 | elseif(sum(~dGT(:))==0)%if the GT is completely white 33 | enhanced_matrix = dFM; %only calcualte the white area of intersection 34 | else 35 | %Normal case: 36 | 37 | %1.compute alignment matrix 38 | align_matrix = AlignmentTerm(dFM,dGT); 39 | %2.compute enhanced alignment matrix 40 | enhanced_matrix = EnhancedAlignmentTerm(align_matrix); 41 | end 42 | 43 | %3.Emeasure score 44 | [w,h] = size(GT); 45 | score = sum(enhanced_matrix(:))./(w*h - 1 + eps); 46 | end 47 | 48 | % Alignment Term 49 | function [align_Matrix] = AlignmentTerm(dFM,dGT) 50 | 51 | %compute global mean 52 | mu_FM = mean2(dFM); 53 | mu_GT = mean2(dGT); 54 | 55 | %compute the bias matrix 56 | align_FM = dFM - mu_FM; 57 | align_GT = dGT - mu_GT; 58 | 59 | %compute alignment matrix 60 | align_Matrix = 2.*(align_GT.*align_FM)./(align_GT.*align_GT + align_FM.*align_FM + eps); 61 | 62 | end 63 | 64 | % Enhanced Alignment Term function. 
f(x) = 1/4*(1 + x)^2) 65 | function enhanced = EnhancedAlignmentTerm(align_Matrix) 66 | enhanced = ((align_Matrix + 1).^2)/4; 67 | end 68 | 69 | 70 | -------------------------------------------------------------------------------- /jittor/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | import jittor as jt 6 | from jittor.dataset import Dataset 7 | 8 | 9 | class PolypDataset(Dataset): 10 | def __init__(self, image_root, gt_root, trainsize): 11 | super().__init__() 12 | 13 | self.trainsize = trainsize 14 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 15 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')] 16 | self.images = sorted(self.images) 17 | self.gts = sorted(self.gts) 18 | self.filter_files() 19 | self.size = len(self.images) 20 | self.img_transform = jt.transform.Compose([ 21 | jt.transform.Resize((self.trainsize, self.trainsize)), 22 | jt.transform.ToTensor(), 23 | jt.transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 24 | self.gt_transform = jt.transform.Compose([ 25 | jt.transform.Resize((self.trainsize, self.trainsize)), 26 | jt.transform.ToTensor()]) 27 | 28 | def __getitem__(self, index): 29 | image = self.rgb_loader(self.images[index]) 30 | gt = self.binary_loader(self.gts[index]) 31 | image = self.img_transform(image) 32 | gt = self.gt_transform(gt) 33 | return (image, gt) 34 | 35 | def filter_files(self): 36 | assert (len(self.images) == len(self.gts)) 37 | images = [] 38 | gts = [] 39 | for (img_path, gt_path) in zip(self.images, self.gts): 40 | img = Image.open(img_path) 41 | gt = Image.open(gt_path) 42 | if (img.size == gt.size): 43 | images.append(img_path) 44 | gts.append(gt_path) 45 | self.images = images 46 | self.gts = gts 47 | 48 | def rgb_loader(self, path): 49 | with open(path, 'rb') as f: 50 | img = Image.open(f) 51 | return img.convert('RGB') 52 | 53 | def binary_loader(self, path): 54 | with open(path, 'rb') as f: 55 | img = Image.open(f) 56 | return img.convert('L') 57 | 58 | def resize(self, img, gt): 59 | assert (img.size == gt.size) 60 | (w, h) = img.size 61 | if ((h < self.trainsize) or (w < self.trainsize)): 62 | h = max(h, self.trainsize) 63 | w = max(w, self.trainsize) 64 | return (img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)) 65 | else: 66 | return (img, gt) 67 | 68 | def __len__(self): 69 | return self.size 70 | 71 | 72 | def get_loader(image_root, gt_root, trainsize): 73 | dataset = PolypDataset(image_root, gt_root, trainsize) 74 | return dataset 75 | 76 | 77 | class test_dataset(Dataset): 78 | 79 | def __init__(self, image_root, gt_root, testsize): 80 | super().__init__() 81 | self.testsize = testsize 82 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 83 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 84 | self.images = sorted(self.images) 85 | self.gts = sorted(self.gts) 86 | self.transform = jt.transform.Compose([ 87 | jt.transform.Resize((self.testsize, self.testsize)), 88 | jt.transform.ToTensor(), 89 | jt.transform.ImageNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 90 | self.gt_transform = jt.transform.ToTensor() 91 | self.size = len(self.images) 92 | 93 | def __getitem__(self, index): 94 | image = self.rgb_loader(self.images[index]) 95 | image = self.transform(image) 
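        # note: the ground-truth mask is returned below as an un-resized PIL image;
        # self.gt_transform defined in __init__ is not applied in __getitem__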
96 | gt = self.binary_loader(self.gts[index]) 97 | name = self.images[index].split('/')[(- 1)] 98 | if name.endswith('.jpg'): 99 | name = (name.split('.jpg')[0] + '.png') 100 | return (image, gt, name) 101 | 102 | def rgb_loader(self, path): 103 | with open(path, 'rb') as f: 104 | img = Image.open(f) 105 | return img.convert('RGB') 106 | 107 | def binary_loader(self, path): 108 | with open(path, 'rb') as f: 109 | img = Image.open(f) 110 | return img.convert('L') 111 | 112 | def __len__(self): 113 | return self.size 114 | -------------------------------------------------------------------------------- /eval/S_region.m: -------------------------------------------------------------------------------- 1 | function Q = S_region(prediction,GT) 2 | % S_region computes the region similarity between the foreground map and 3 | % ground truth(as proposed in "Structure-measure:A new way to evaluate 4 | % foreground maps" [Deng-Ping Fan et. al - ICCV 2017]) 5 | % Usage: 6 | % Q = S_region(prediction,GT) 7 | % Input: 8 | % prediction - Binary/Non binary foreground map with values in the range 9 | % [0 1]. Type: double. 10 | % GT - Binary ground truth. Type: logical. 11 | % Output: 12 | % Q - The region similarity score 13 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 14 | 15 | % find the centroid of the GT 16 | [X,Y] = centroid(GT); 17 | 18 | % divide GT into 4 regions 19 | [GT_1,GT_2,GT_3,GT_4,w1,w2,w3,w4] = divideGT(GT,X,Y); 20 | 21 | %Divede prediction into 4 regions 22 | [prediction_1,prediction_2,prediction_3,prediction_4] = Divideprediction(prediction,X,Y); 23 | 24 | %Compute the ssim score for each regions 25 | Q1 = ssim(prediction_1,GT_1); 26 | Q2 = ssim(prediction_2,GT_2); 27 | Q3 = ssim(prediction_3,GT_3); 28 | Q4 = ssim(prediction_4,GT_4); 29 | 30 | %Sum the 4 scores 31 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4; 32 | 33 | end 34 | 35 | function [X,Y] = centroid(GT) 36 | % Centroid Compute the centroid of the GT 37 | % Usage: 38 | % [X,Y] = Centroid(GT) 39 | % Input: 40 | % GT - Binary ground truth. Type: logical. 41 | % Output: 42 | % [X,Y] - The coordinates of centroid. 43 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 44 | [rows,cols] = size(GT); 45 | 46 | if(sum(GT(:))==0) 47 | X = round(cols/2); 48 | Y = round(rows/2); 49 | else 50 | total=sum(GT(:)); 51 | i=1:cols; 52 | j=(1:rows)'; 53 | X=round(sum(sum(GT,1).*i)/total); 54 | Y=round(sum(sum(GT,2).*j)/total); 55 | 56 | %dGT = double(GT); 57 | %x = ones(rows,1)*(1:cols); 58 | %y = (1:rows)'*ones(1,cols); 59 | %area = sum(dGT(:)); 60 | %X = round(sum(sum(dGT.*x))/area); 61 | %Y = round(sum(sum(dGT.*y))/area); 62 | end 63 | 64 | end 65 | 66 | % divide the GT into 4 regions according to the centroid of the GT and return the weights 67 | function [LT,RT,LB,RB,w1,w2,w3,w4] = divideGT(GT,X,Y) 68 | % LT - left top; 69 | % RT - right top; 70 | % LB - left bottom; 71 | % RB - right bottom; 72 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 73 | 74 | %width and height of the GT 75 | [hei,wid] = size(GT); 76 | area = wid * hei; 77 | 78 | %copy the 4 regions 79 | LT = GT(1:Y,1:X); 80 | RT = GT(1:Y,X+1:wid); 81 | LB = GT(Y+1:hei,1:X); 82 | RB = GT(Y+1:hei,X+1:wid); 83 | 84 | %The different weight (each block proportional to the GT foreground region). 
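%(w1..w4 are the area fractions of the four quadrants, so they sum to 1)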
85 | w1 = (X*Y)./area; 86 | w2 = ((wid-X)*Y)./area; 87 | w3 = (X*(hei-Y))./area; 88 | w4 = 1.0 - w1 - w2 - w3; 89 | end 90 | 91 | %Divide the prediction into 4 regions according to the centroid of the GT 92 | function [LT,RT,LB,RB] = Divideprediction(prediction,X,Y) 93 | 94 | %width and height of the prediction 95 | [hei,wid] = size(prediction); 96 | 97 | %copy the 4 regions 98 | LT = prediction(1:Y,1:X); 99 | RT = prediction(1:Y,X+1:wid); 100 | LB = prediction(Y+1:hei,1:X); 101 | RB = prediction(Y+1:hei,X+1:wid); 102 | 103 | end 104 | 105 | function Q = ssim(prediction,GT) 106 | % ssim computes the region similarity between foreground maps and ground 107 | % truth(as proposed in "Structure-measure: A new way to evaluate foreground 108 | % maps" [Deng-Ping Fan et. al - ICCV 2017]) 109 | % Usage: 110 | % Q = ssim(prediction,GT) 111 | % Input: 112 | % prediction - Binary/Non binary foreground map with values in the range 113 | % [0 1]. Type: double. 114 | % GT - Binary ground truth. Type: logical. 115 | % Output: 116 | % Q - The region similarity score 117 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 118 | 119 | dGT = double(GT); 120 | 121 | [hei,wid] = size(prediction); 122 | N = wid*hei; 123 | 124 | %Compute the mean of SM,GT 125 | x = mean2(prediction); 126 | y = mean2(dGT); 127 | 128 | %Compute the variance of SM,GT 129 | sigma_x2 = sum(sum((prediction - x).^2))./(N - 1 + eps);%sigma_x2 = var(prediction(:)) 130 | sigma_y2 = sum(sum((dGT - y).^2))./(N - 1 + eps); %sigma_y2 = var(dGT(:)); 131 | 132 | %Compute the covariance between SM and GT 133 | sigma_xy = sum(sum((prediction - x).*(dGT - y)))./(N - 1 + eps); 134 | 135 | alpha = 4 * x * y * sigma_xy; 136 | beta = (x.^2 + y.^2).*(sigma_x2 + sigma_y2); 137 | 138 | if(alpha ~= 0) 139 | Q = alpha./(beta + eps); 140 | elseif(alpha == 0 && beta == 0) 141 | Q = 1.0; 142 | else 143 | Q = 0; 144 | end 145 | 146 | end 147 | -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | 6 | 7 | class PolypDataset(data.Dataset): 8 | """ 9 | dataloader for polyp segmentation tasks 10 | """ 11 | def __init__(self, image_root, gt_root, trainsize): 12 | self.trainsize = trainsize 13 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 14 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.png')] 15 | self.images = sorted(self.images) 16 | self.gts = sorted(self.gts) 17 | self.filter_files() 18 | self.size = len(self.images) 19 | self.img_transform = transforms.Compose([ 20 | transforms.Resize((self.trainsize, self.trainsize)), 21 | transforms.ToTensor(), 22 | transforms.Normalize([0.485, 0.456, 0.406], 23 | [0.229, 0.224, 0.225])]) 24 | self.gt_transform = transforms.Compose([ 25 | transforms.Resize((self.trainsize, self.trainsize)), 26 | transforms.ToTensor()]) 27 | 28 | def __getitem__(self, index): 29 | image = self.rgb_loader(self.images[index]) 30 | gt = self.binary_loader(self.gts[index]) 31 | image = self.img_transform(image) 32 | gt = self.gt_transform(gt) 33 | return image, gt 34 | 35 | def filter_files(self): 36 | assert len(self.images) == len(self.gts) 37 | images = [] 38 | gts = [] 39 | for img_path, gt_path in zip(self.images, self.gts): 40 | img = Image.open(img_path) 41 | gt = 
Image.open(gt_path) 42 | if img.size == gt.size: 43 | images.append(img_path) 44 | gts.append(gt_path) 45 | self.images = images 46 | self.gts = gts 47 | 48 | def rgb_loader(self, path): 49 | with open(path, 'rb') as f: 50 | img = Image.open(f) 51 | return img.convert('RGB') 52 | 53 | def binary_loader(self, path): 54 | with open(path, 'rb') as f: 55 | img = Image.open(f) 56 | # return img.convert('1') 57 | return img.convert('L') 58 | 59 | def resize(self, img, gt): 60 | assert img.size == gt.size 61 | w, h = img.size 62 | if h < self.trainsize or w < self.trainsize: 63 | h = max(h, self.trainsize) 64 | w = max(w, self.trainsize) 65 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 66 | else: 67 | return img, gt 68 | 69 | def __len__(self): 70 | return self.size 71 | 72 | 73 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=4, pin_memory=True): 74 | 75 | dataset = PolypDataset(image_root, gt_root, trainsize) 76 | data_loader = data.DataLoader(dataset=dataset, 77 | batch_size=batchsize, 78 | shuffle=shuffle, 79 | num_workers=num_workers, 80 | pin_memory=pin_memory) 81 | return data_loader 82 | 83 | 84 | class test_dataset: 85 | def __init__(self, image_root, gt_root, testsize): 86 | self.testsize = testsize 87 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg') or f.endswith('.png')] 88 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.tif') or f.endswith('.png')] 89 | self.images = sorted(self.images) 90 | self.gts = sorted(self.gts) 91 | self.transform = transforms.Compose([ 92 | transforms.Resize((self.testsize, self.testsize)), 93 | transforms.ToTensor(), 94 | transforms.Normalize([0.485, 0.456, 0.406], 95 | [0.229, 0.224, 0.225])]) 96 | self.gt_transform = transforms.ToTensor() 97 | self.size = len(self.images) 98 | self.index = 0 99 | 100 | def load_data(self): 101 | image = self.rgb_loader(self.images[self.index]) 102 | image = self.transform(image).unsqueeze(0) 103 | gt = self.binary_loader(self.gts[self.index]) 104 | name = self.images[self.index].split('/')[-1] 105 | if name.endswith('.jpg'): 106 | name = name.split('.jpg')[0] + '.png' 107 | self.index += 1 108 | return image, gt, name 109 | 110 | def rgb_loader(self, path): 111 | with open(path, 'rb') as f: 112 | img = Image.open(f) 113 | return img.convert('RGB') 114 | 115 | def binary_loader(self, path): 116 | with open(path, 'rb') as f: 117 | img = Image.open(f) 118 | return img.convert('L') 119 | -------------------------------------------------------------------------------- /lib/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 
| out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class ResNet(nn.Module): 83 | # ResNet50 with two branches 84 | def __init__(self): 85 | # self.inplanes = 128 86 | self.inplanes = 64 87 | super(ResNet, self).__init__() 88 | 89 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 90 | bias=False) 91 | self.bn1 = nn.BatchNorm2d(64) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 95 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 96 | self.layer3 = self._make_layer(Bottleneck, 256, 6, stride=2) 97 | self.layer4 = self._make_layer(Bottleneck, 512, 3, stride=2) 98 | 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv2d): 101 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 102 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 103 | elif isinstance(m, nn.BatchNorm2d): 104 | m.weight.data.fill_(1) 105 | m.bias.data.zero_() 106 | 107 | def _make_layer(self, block, planes, blocks, stride=1): 108 | downsample = None 109 | if stride != 1 or self.inplanes != planes * block.expansion: 110 | downsample = nn.Sequential( 111 | nn.Conv2d(self.inplanes, planes * block.expansion, 112 | kernel_size=1, stride=stride, bias=False), 113 | nn.BatchNorm2d(planes * block.expansion), 114 | ) 115 | 116 | layers = [] 117 | layers.append(block(self.inplanes, planes, stride, downsample)) 118 | self.inplanes = planes * block.expansion 119 | for i in range(1, blocks): 120 | layers.append(block(self.inplanes, planes)) 121 | 122 | return nn.Sequential(*layers) 123 | 124 | def forward(self, x): 125 | x = self.conv1(x) 126 | x = self.bn1(x) 127 | x = self.relu(x) 128 | x = self.maxpool(x) 129 | 130 | x = self.layer1(x) 131 | x = self.layer2(x) 132 | x1 = self.layer3_1(x) 133 | x1 = self.layer4_1(x1) 134 | 135 | x2 = self.layer3_2(x) 136 | x2 = self.layer4_2(x2) 137 | 138 | return x1, x2 139 | -------------------------------------------------------------------------------- /MyTrain.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import os 4 | import argparse 5 | from datetime import datetime 6 | from lib.PraNet_Res2Net import PraNet 7 | from utils.dataloader import get_loader 8 | from utils.utils import clip_gradient, adjust_lr, AvgMeter 9 | import torch.nn.functional as F 10 | 11 | 12 | def structure_loss(pred, mask): 13 | weit = 1 + 5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 14 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduce='none') 15 | wbce = (weit*wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 16 | 17 | pred = torch.sigmoid(pred) 18 | inter = ((pred * mask)*weit).sum(dim=(2, 3)) 19 | union = ((pred + mask)*weit).sum(dim=(2, 3)) 20 | wiou = 1 - (inter + 1)/(union - inter+1) 21 | return (wbce + wiou).mean() 22 | 23 | 24 | def train(train_loader, model, optimizer, epoch): 25 | model.train() 26 | # ---- multi-scale training ---- 27 | size_rates = [0.75, 1, 1.25] 28 | loss_record2, loss_record3, loss_record4, loss_record5 = AvgMeter(), AvgMeter(), AvgMeter(), AvgMeter() 29 | for i, pack in enumerate(train_loader, start=1): 30 | for rate in size_rates: 31 | optimizer.zero_grad() 32 | # ---- data prepare ---- 33 | images, gts = pack 34 | images = Variable(images).cuda() 35 | gts = Variable(gts).cuda() 36 | # ---- rescale ---- 37 | trainsize = int(round(opt.trainsize*rate/32)*32) 38 | if rate != 1: 39 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 40 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 41 | # ---- forward ---- 42 | lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 = model(images) 43 | # ---- loss function ---- 44 | loss5 = structure_loss(lateral_map_5, gts) 45 | loss4 = structure_loss(lateral_map_4, gts) 46 | loss3 = structure_loss(lateral_map_3, gts) 47 | loss2 = structure_loss(lateral_map_2, gts) 48 | loss = loss2 + loss3 + loss4 + loss5 # TODO: try different weights for loss 49 | # ---- backward ---- 50 | loss.backward() 51 | clip_gradient(optimizer, opt.clip) 52 | optimizer.step() 53 | # ---- recording loss ---- 54 | if rate == 1: 55 | loss_record2.update(loss2.data, opt.batchsize) 56 | loss_record3.update(loss3.data, opt.batchsize) 57 | loss_record4.update(loss4.data, opt.batchsize) 
58 | loss_record5.update(loss5.data, opt.batchsize) 59 | # ---- train visualization ---- 60 | if i % 20 == 0 or i == total_step: 61 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], ' 62 | '[lateral-2: {:.4f}, lateral-3: {:0.4f}, lateral-4: {:0.4f}, lateral-5: {:0.4f}]'. 63 | format(datetime.now(), epoch, opt.epoch, i, total_step, 64 | loss_record2.show(), loss_record3.show(), loss_record4.show(), loss_record5.show())) 65 | save_path = 'snapshots/{}/'.format(opt.train_save) 66 | os.makedirs(save_path, exist_ok=True) 67 | if (epoch+1) % 10 == 0: 68 | torch.save(model.state_dict(), save_path + 'PraNet-%d.pth' % epoch) 69 | print('[Saving Snapshot:]', save_path + 'PraNet-%d.pth'% epoch) 70 | 71 | 72 | if __name__ == '__main__': 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--epoch', type=int, 75 | default=20, help='epoch number') 76 | parser.add_argument('--lr', type=float, 77 | default=1e-4, help='learning rate') 78 | parser.add_argument('--batchsize', type=int, 79 | default=16, help='training batch size') 80 | parser.add_argument('--trainsize', type=int, 81 | default=352, help='training dataset size') 82 | parser.add_argument('--clip', type=float, 83 | default=0.5, help='gradient clipping margin') 84 | parser.add_argument('--decay_rate', type=float, 85 | default=0.1, help='decay rate of learning rate') 86 | parser.add_argument('--decay_epoch', type=int, 87 | default=50, help='every n epochs decay learning rate') 88 | parser.add_argument('--train_path', type=str, 89 | default='./data/TrainDataset', help='path to train dataset') 90 | parser.add_argument('--train_save', type=str, 91 | default='PraNet_Res2Net') 92 | opt = parser.parse_args() 93 | 94 | # ---- build models ---- 95 | # torch.cuda.set_device(0) # set your gpu device 96 | model = PraNet().cuda() 97 | 98 | # ---- flops and params ---- 99 | # from utils.utils import CalParams 100 | # x = torch.randn(1, 3, 352, 352).cuda() 101 | # CalParams(lib, x) 102 | 103 | params = model.parameters() 104 | optimizer = torch.optim.Adam(params, opt.lr) 105 | 106 | image_root = '{}/images/'.format(opt.train_path) 107 | gt_root = '{}/masks/'.format(opt.train_path) 108 | 109 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 110 | total_step = len(train_loader) 111 | 112 | print("#"*20, "Start Training", "#"*20) 113 | 114 | for epoch in range(1, opt.epoch): 115 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 116 | train(train_loader, model, optimizer, epoch) 117 | 118 | -------------------------------------------------------------------------------- /jittor/README.md: -------------------------------------------------------------------------------- 1 | # PraNet: Parallel Reverse Attention Network for Polyp Segmentation (MICCAI2020-Oral) 2 | 3 | ## Introduction 4 | 5 | The repo provides inference code of **PraNet (MICCAI-2020)** with [Jittor deep-learning framework](https://github.com/Jittor/jittor). 6 | 7 | > **Jittor** is a high-performance deep learning framework based on JIT compiling and meta-operators. The whole framework and meta-operators are compiled just-in-time. A powerful op compiler and tuner are integrated into Jittor. It allowed us to generate high-performance code with specialized for your model. Jittor also contains a wealth of high-performance model libraries, including: image recognition, detection, segmentation, generation, differentiable rendering, geometric learning, reinforcement learning, etc. The front-end language is Python. 
Module Design and Dynamic Graph Execution is used in the front-end, which is the most popular design for deeplearning framework interface. The back-end is implemented by high performance language, such as CUDA, C++. 8 | 9 | ## Usage 10 | 11 | PraNet is also implemented in the Jittor toolbox which can be found in `./jittor`. 12 | + Create environment by `python3.7 -m pip install jittor` on Linux. 13 | As for MacOS or Windows users, using Docker `docker run --name jittor -v $PATH_TO_PROJECT:/home/PraNet -it jittor/jittor /bin/bash` 14 | is easier and necessary. 15 | A simple way to debug and run the script is running a new command in the container through `docker exec -it jittor /bin/bash` and start the experiments. (More details refer to this [installation tutorial](https://github.com/Jittor/jittor#install)) 16 | 17 | + First, run `sudo sysctl vm.overcommit_memory=1` to set the memory allocation policy. 18 | 19 | + Second, switch to the project root by `cd /home/PraNet` 20 | 21 | + For testing, run `python3.7 jittor/MyTest.py`. 22 | 23 | > Note that the Jittor model is just converted from the original PyTorch model via toolbox, and thus, the trained weights of PyTorch model can be used to the inference of Jittor model. 24 | 25 | ## Performance Comparison 26 | 27 | The performance has slight difference due to the different operator implemented between two frameworks. The download link ([Pytorch](https://drive.google.com/file/d/1tW0OOxPSuhfSbMijaMPwRDPElW1qQywz/view?usp=sharing) / [Jittor](https://drive.google.com/file/d/1qpzNTWLAhepCT0OGNdjUIk-SVMCGUEdf/view?usp=sharing)) of prediction results on four testing dataset, including Kvasir, CVC-612, CVC-ColonDB, ETIS, and CVC-T. 28 | 29 | | Kvasir dataset | mean Dice | mean IoU | $F_\beta^w$ | $S_\alpha$ | $E_\phi^max$ | M | 30 | |---------------------|-----------|----------|-------------|------------|--------------|-------| 31 | | PyTorch | 0.898 | 0.840 | 0.885 | 0.915 | 0.948 | 0.030 | 32 | | Jittor | 0.895 | 0.836 | 0.880 | 0.913 | 0.945 | 0.030 | 33 | 34 | | CVC-612 dataset | mean Dice | mean IoU | $F_\beta^w$ | $S_\alpha$ | $E_\phi^max$ | M | 35 | |---------------------|-----------|----------|-------------|------------|--------------|-------| 36 | | PyTorch | 0.899 | 0.849 | 0.896 | 0.936 | 0.979 | 0.009 | 37 | | Jittor | 0.900 | 0.850 | 0.897 | 0.937 | 0.978 | 0.009 | 38 | 39 | | CVC-ColonDB dataset | mean Dice | mean IoU | $F_\beta^w$ | $S_\alpha$ | $E_\phi^max$ | M | 40 | |---------------------|-----------|----------|-------------|------------|--------------|-------| 41 | | PyTorch | 0.709 | 0.640 | 0.696 | 0.819 | 0.869 | 0.045 | 42 | | Jittor | 0.708 | 0.637 | 0.695 | 0.817 | 0.869 | 0.044 | 43 | 44 | | ETIS dataset | mean Dice | mean IoU | $F_\beta^w$ | $S_\alpha$ | $E_\phi^max$ | M | 45 | |---------------------|-----------|----------|-------------|------------|--------------|-------| 46 | | PyTorch | 0.628 | 0.567 | 0.600 | 0.794 | 0.841 | 0.031 | 47 | | Jittor | 0.627 | 0.565 | 0.600 | 0.793 | 0.845 | 0.032 | 48 | 49 | | CVC-T dataset | mean Dice | mean IoU | $F_\beta^w$ | $S_\alpha$ | $E_\phi^max$ | M | 50 | |---------------------|-----------|----------|-------------|------------|--------------|-------| 51 | | PyTorch | 0.871 | 0.797 | 0.843 | 0.925 | 0.972 | 0.010 | 52 | | Jittor | 0.870 | 0.796 | 0.842 | 0.925 | 0.973 | 0.010 | 53 | 54 | ## Speedup 55 | 56 | The jittor-based code can speed up the inference efficiency. 
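A minimal timing sketch along these lines can be used to reproduce such a comparison (this is a hypothetical helper rather than the script behind the numbers below; it assumes it is run from the `./jittor` directory with CUDA available and simply times forward passes on random input):

```python
# fps_sketch.py -- hypothetical helper, not part of the repository
import time
import jittor as jt
from lib.PraNet_Res2Net import PraNet  # run from the ./jittor directory

jt.flags.use_cuda = 1  # use the GPU if one is available


def measure_fps(batch_size=1, runs=50, size=352):
    model = PraNet(pretrained_backbone=False)
    model.eval()
    x = jt.rand(batch_size, 3, size, size)
    for _ in range(10):          # warm-up, so JIT compilation is excluded
        model(x)
    jt.sync_all(True)
    start = time.time()
    for _ in range(runs):
        model(x)
    jt.sync_all(True)            # wait until all queued ops have finished
    return runs * batch_size / (time.time() - start)


if __name__ == '__main__':
    for bs in (1, 4, 8, 16):
        print('batch size {:2d}: {:.1f} FPS'.format(bs, measure_fps(bs)))
```

The FPS values reported below were measured by the authors; the sketch above only illustrates the measurement procedure.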
57 | 58 | | Batch Size | PyTorch | Jittor | Speedup | 59 | |----------- |---------------- |---------------- |---------------- | 60 | | 1 | 52 FPS | 67 FPS | 1.29x | 61 | | 4 | 194 FPS | 255 FPS | 1.31x | 62 | | 8 | 391 FPS | 508 FPS | 1.30x | 63 | | 16 | 476 FPS | 593 FPS | 1.25x | 64 | 65 | ## Citation 66 | 67 | If you find our work useful in your research, please consider citing: 68 | 69 | 70 | @article{fan2020pra, 71 | title={PraNet: Parallel Reverse Attention Network for Polyp Segmentation}, 72 | author={Fan, Deng-Ping and Ji, Ge-Peng and Zhou, Tao and Chen, Geng and Fu, Huazhu and Shen, Jianbing and Shao, Ling}, 73 | journal={MICCAI}, 74 | year={2020} 75 | } 76 | 77 | and the jittor framework: 78 | 79 | @article{hu2020jittor, 80 | title={Jittor: a novel deep learning framework with meta-operators and unified graph execution}, 81 | author={Hu, Shi-Min and Liang, Dun and Yang, Guo-Ye and Yang, Guo-Wei and Zhou, Wen-Yang}, 82 | journal={Science China Information Sciences}, 83 | volume={63}, 84 | number={222103}, 85 | pages={1--21}, 86 | year={2020} 87 | } 88 | 89 | 90 | # Acknowledgements 91 | 92 | Thanks to Liang Dun from Tsinghua University ([The Graphics and Geometric Computing Group](https://cg.cs.tsinghua.edu.cn/#people.htm)) for his help in the framework conversion process. 93 | -------------------------------------------------------------------------------- /jittor/lib/Res2Net_v1b.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import jittor as jt 4 | from jittor import init 5 | from jittor import nn 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 8 | model_urls = { 9 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 10 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth'} 11 | 12 | 13 | class Bottle2neck(nn.Module): 14 | expansion = 4 15 | 16 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 17 | super(Bottle2neck, self).__init__() 18 | width = int(math.floor((planes * (baseWidth / 64.0)))) 19 | self.conv1 = nn.Conv(inplanes, (width * scale), 1, bias=False) 20 | self.bn1 = nn.BatchNorm((width * scale)) 21 | if (scale == 1): 22 | self.nums = 1 23 | else: 24 | self.nums = (scale - 1) 25 | if (stype == 'stage'): 26 | self.pool = nn.Pool(3, stride=stride, padding=1, op='mean') 27 | convs = [] 28 | bns = [] 29 | for i in range(self.nums): 30 | convs.append(nn.Conv(width, width, 3, stride=stride, padding=1, bias=False)) 31 | bns.append(nn.BatchNorm(width)) 32 | self.convs = nn.ModuleList(convs) 33 | self.bns = nn.ModuleList(bns) 34 | self.conv3 = nn.Conv((width * scale), (planes * self.expansion), 1, bias=False) 35 | self.bn3 = nn.BatchNorm((planes * self.expansion)) 36 | self.relu = nn.ReLU() 37 | self.downsample = downsample 38 | self.stype = stype 39 | self.scale = scale 40 | self.width = width 41 | 42 | def execute(self, x): 43 | residual = x 44 | out = self.conv1(x) 45 | out = self.bn1(out) 46 | out = nn.relu(out) 47 | spx = jt.split(out, self.width, 1) 48 | for i in range(self.nums): 49 | if ((i == 0) or (self.stype == 'stage')): 50 | sp = spx[i] 51 | else: 52 | sp = (sp + spx[i]) 53 | sp = self.convs[i](sp) 54 | sp = nn.relu(self.bns[i](sp)) 55 | if (i == 0): 56 | out = sp 57 | else: 58 | out = jt.contrib.concat((out, sp), dim=1) 59 | if ((self.scale != 1) and (self.stype == 'normal')): 60 
| out = jt.contrib.concat((out, spx[self.nums]), dim=1) 61 | elif ((self.scale != 1) and (self.stype == 'stage')): 62 | out = jt.contrib.concat((out, self.pool(spx[self.nums])), dim=1) 63 | out = self.conv3(out) 64 | out = self.bn3(out) 65 | if (self.downsample is not None): 66 | residual = self.downsample(x) 67 | out += residual 68 | out = nn.relu(out) 69 | return out 70 | 71 | 72 | class Res2Net(nn.Module): 73 | 74 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 75 | self.inplanes = 64 76 | super(Res2Net, self).__init__() 77 | self.baseWidth = baseWidth 78 | self.scale = scale 79 | self.conv1 = nn.Sequential(nn.Conv(3, 32, 3, stride=2, padding=1, bias=False), nn.BatchNorm(32), nn.ReLU(), 80 | nn.Conv(32, 32, 3, stride=1, padding=1, bias=False), nn.BatchNorm(32), nn.ReLU(), 81 | nn.Conv(32, 64, 3, stride=1, padding=1, bias=False)) 82 | self.bn1 = nn.BatchNorm(64) 83 | self.relu = nn.ReLU() 84 | self.maxpool = nn.Pool(3, stride=2, padding=1, op='maximum') 85 | self.layer1 = self._make_layer(block, 64, layers[0]) 86 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 87 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 88 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 89 | self.avgpool = nn.AdaptiveAvgPool2d(1) 90 | self.fc = nn.Linear((512 * block.expansion), num_classes) 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv): 93 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 94 | elif isinstance(m, nn.BatchNorm): 95 | init.constant_(m.weight, value=1) 96 | init.constant_(m.bias, value=0) 97 | 98 | def _make_layer(self, block, planes, blocks, stride=1): 99 | downsample = None 100 | 101 | if ((stride != 1) or (self.inplanes != (planes * block.expansion))): 102 | downsample = nn.Sequential( 103 | nn.Pool(stride, stride=stride, ceil_mode=True, op='mean'), 104 | nn.Conv(self.inplanes, (planes * block.expansion), 1, stride=1, bias=False), 105 | nn.BatchNorm((planes * block.expansion)) 106 | ) 107 | 108 | layers = [] 109 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, stype='stage', baseWidth=self.baseWidth, 110 | scale=self.scale)) 111 | 112 | self.inplanes = (planes * block.expansion) 113 | for i in range(1, blocks): 114 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 115 | 116 | return nn.Sequential(*layers) 117 | 118 | def execute(self, x): 119 | x = self.conv1(x) 120 | x = self.bn1(x) 121 | x = nn.relu(x) 122 | x = self.maxpool(x) 123 | x = self.layer1(x) 124 | x = self.layer2(x) 125 | x = self.layer3(x) 126 | x = self.layer4(x) 127 | x = self.avgpool(x) 128 | x = x.view((x.shape[0], (- 1))) 129 | x = self.fc(x) 130 | return x 131 | 132 | 133 | def res2net50_v1b(pretrained=False, **kwargs): 134 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 135 | if pretrained: 136 | model.load(jt.load(model_urls['res2net50_v1b_26w_4s'])) 137 | return model 138 | 139 | 140 | def res2net101_v1b(pretrained=False, **kwargs): 141 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 142 | if pretrained: 143 | model.load(jt.load(model_urls['res2net101_v1b_26w_4s'])) 144 | return model 145 | 146 | 147 | def res2net50_26w_4s(pretrained=False, **kwargs): 148 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 149 | if pretrained: 150 | model.load(jt.load(model_urls['res2net50_v1b_26w_4s'])) 151 | return model 152 | 153 | 154 | def res2net101_v1b_26w_4s(pretrained=False, 
**kwargs): 155 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 156 | if pretrained: 157 | model.load(jt.load((model_urls['res2net101_v1b_26w_4s']))) 158 | return model 159 | 160 | if __name__ == '__main__': 161 | images = jt.rand(1, 3, 352, 352) 162 | model = res2net50_26w_4s(pretrained=False) 163 | model = model 164 | print(model(images).shape) 165 | -------------------------------------------------------------------------------- /eval/main.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | %Evaluation tool boxs for PraNet: Parallel Reverse Attention Network for Polyp Segmentation (MICCAI20). 3 | %Author: Deng-Ping Fan, Tao Zhou, Ge-Peng Ji, Yi Zhou, Geng Chen, Huazhu Fu, Jianbing Shen, and Ling Shao 4 | %Homepage: http://dpfan.net/ 5 | %Projectpage: https://github.com/DengPingFan/PraNet 6 | %First version: 2020-6-28 7 | %Any questions please contact with dengpfan@gmail.com. 8 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 9 | %Function: Providing several important metrics: Dice, IoU, F1, S-m (ICCV'17), Weighted-F1 (CVPR'14) 10 | % E-m (IJCAI'18), Precision, Recall, Sensitivity, Specificity, MAE. 11 | 12 | 13 | clear all; 14 | close all; 15 | clc; 16 | 17 | % ---- 1. ResultMap Path Setting ---- 18 | ResultMapPath = '../results/'; 19 | Models = {'PraNet'}; %{'UNet','UNet++','PraNet','SFA'}; 20 | modelNum = length(Models); 21 | 22 | % ---- 2. Ground-truth Datasets Setting ---- 23 | DataPath = '../data/TestDataset/'; 24 | Datasets = {'CVC-300','CVC-ClinicDB'}; %{'CVC-ClinicDB', 'CVC-ColonDB','ETIS-LaribPolypDB', 'Kvasir','CVC-300'}; 25 | 26 | % ---- 3. Evaluation Results Save Path Setting ---- 27 | ResDir = './EvaluateResults/'; 28 | ResName='_result.txt'; % You can change the result name. 29 | 30 | Thresholds = 1:-1/255:0; 31 | datasetNum = length(Datasets); 32 | 33 | for d = 1:datasetNum 34 | 35 | tic; 36 | dataset = Datasets{d} % print cur dataset name 37 | fprintf('Processing %d/%d: %s Dataset\n',d,datasetNum,dataset); 38 | 39 | ResPath = [ResDir dataset '-mat/']; % The result will be saved in *.mat file so that you can used it for the next time. 40 | if ~exist(ResPath,'dir') 41 | mkdir(ResPath); 42 | end 43 | resTxt = [ResDir dataset ResName]; % The evaluation result will be saved in `../Resluts/Result-XXXX` folder. 
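    % open the per-dataset report file; one summary line per model is appended to it below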
44 | fileID = fopen(resTxt,'w'); 45 | 46 | for m = 1:modelNum 47 | model = Models{m} % print cur model name 48 | 49 | gtPath = [DataPath dataset '/masks/']; 50 | resMapPath = [ResultMapPath '/' model '/' dataset '/']; 51 | 52 | imgFiles = dir([resMapPath '*.png']); 53 | imgNUM = length(imgFiles); 54 | 55 | [threshold_Fmeasure, threshold_Emeasure, threshold_IoU] = deal(zeros(imgNUM,length(Thresholds))); 56 | [threshold_Precion, threshold_Recall] = deal(zeros(imgNUM,length(Thresholds))); 57 | [threshold_Sensitivity, threshold_Specificity, threshold_Dice] = deal(zeros(imgNUM,length(Thresholds))); 58 | 59 | [Smeasure, wFmeasure, MAE] =deal(zeros(1,imgNUM)); 60 | 61 | for i = 1:imgNUM 62 | name = imgFiles(i).name; 63 | fprintf('Evaluating(%s Dataset,%s Model, %s Image): %d/%d\n',dataset, model, name, i,imgNUM); 64 | 65 | %load gt 66 | gt = imread([gtPath name]); 67 | 68 | if (ndims(gt)>2) 69 | gt = rgb2gray(gt); 70 | end 71 | 72 | if ~islogical(gt) 73 | gt = gt(:,:,1) > 128; 74 | end 75 | 76 | %load resMap 77 | resmap = imread([resMapPath name]); 78 | %check size 79 | if size(resmap, 1) ~= size(gt, 1) || size(resmap, 2) ~= size(gt, 2) 80 | resmap = imresize(resmap,size(gt)); 81 | imwrite(resmap,[resMapPath name]); 82 | fprintf('Resizing have been operated!! The resmap size is not math with gt in the path: %s!!!\n', [resMapPath name]); 83 | end 84 | 85 | resmap = im2double(resmap(:,:,1)); 86 | 87 | %normalize resmap to [0, 1] 88 | resmap = reshape(mapminmax(resmap(:)',0,1),size(resmap)); 89 | 90 | % S-meaure metric published in ICCV'17 (Structure measure: A New Way to Evaluate the Foreground Map.) 91 | Smeasure(i) = StructureMeasure(resmap,logical(gt)); 92 | 93 | % Weighted F-measure metric published in CVPR'14 (How to evaluate the foreground maps?) 94 | wFmeasure(i) = original_WFb(resmap,logical(gt)); 95 | 96 | MAE(i) = mean2(abs(double(logical(gt)) - resmap)); 97 | 98 | [threshold_E, threshold_F, threshold_Pr, threshold_Rec, threshold_Iou] = deal(zeros(1,length(Thresholds))); 99 | [threshold_Spe, threshold_Dic] = deal(zeros(1,length(Thresholds))); 100 | for t = 1:length(Thresholds) 101 | threshold = Thresholds(t); 102 | [threshold_Pr(t), threshold_Rec(t), threshold_Spe(t), threshold_Dic(t), threshold_F(t), threshold_Iou(t)] = Fmeasure_calu(resmap,double(gt),size(gt),threshold); 103 | 104 | Bi_resmap = zeros(size(resmap)); 105 | Bi_resmap(resmap>=threshold)=1; 106 | threshold_E(t) = Enhancedmeasure(Bi_resmap, gt); 107 | end 108 | 109 | threshold_Emeasure(i,:) = threshold_E; 110 | threshold_Fmeasure(i,:) = threshold_F; 111 | threshold_Sensitivity(i,:) = threshold_Rec; 112 | threshold_Specificity(i,:) = threshold_Spe; 113 | threshold_Dice(i,:) = threshold_Dic; 114 | threshold_IoU(i,:) = threshold_Iou; 115 | 116 | end 117 | 118 | %MAE 119 | mae = mean2(MAE); 120 | 121 | %Sm 122 | Sm = mean2(Smeasure); 123 | 124 | %wFm 125 | wFm = mean2(wFmeasure); 126 | 127 | %E-m 128 | column_E = mean(threshold_Emeasure,1); 129 | meanEm = mean(column_E); 130 | maxEm = max(column_E); 131 | 132 | %Sensitivity 133 | column_Sen = mean(threshold_Sensitivity,1); 134 | meanSen = mean(column_Sen); 135 | maxSen = max(column_Sen); 136 | 137 | %,Specificity 138 | column_Spe = mean(threshold_Specificity,1); 139 | meanSpe = mean(column_Spe); 140 | maxSpe = max(column_Spe); 141 | 142 | %Dice 143 | column_Dic = mean(threshold_Dice,1); 144 | meanDic = mean(column_Dic); 145 | maxDic = max(column_Dic); 146 | 147 | %IoU 148 | column_IoU = mean(threshold_IoU,1); 149 | meanIoU = mean(column_IoU); 150 | maxIoU = max(column_IoU); 151 | 
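        % store the per-threshold curves and summary metrics in a .mat file for later re-use,
        % then append one summary line for this model to the text report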
152 | save([ResPath model],'Sm', 'mae', 'column_Dic', 'column_Sen', 'column_Spe', 'column_E','column_IoU','maxDic','maxEm','maxSen','maxSpe','maxIoU','meanIoU','meanDic','meanEm','meanSen','meanSpe'); 153 | fprintf(fileID, '(Dataset:%s; Model:%s) meanDic:%.3f;meanIoU:%.3f;wFm:%.3f;Sm:%.3f;meanEm:%.3f;MAE:%.3f;maxEm:%.3f;maxDice:%.3f;maxIoU:%.3f;meanSen:%.3f;maxSen:%.3f;meanSpe:%.3f;maxSpe:%.3f.\n',dataset,model,meanDic,meanIoU,wFm,Sm,meanEm,mae,maxEm,maxDic,maxIoU,meanSen,maxSen,meanSpe,maxSpe); 154 | fprintf('(Dataset:%s; Model:%s) meanDic:%.3f;meanIoU:%.3f;wFm:%.3f;Sm:%.3f;meanEm:%.3f;MAE:%.3f;maxEm:%.3f;maxDice:%.3f;maxIoU:%.3f;meanSen:%.3f;maxSen:%.3f;meanSpe:%.3f;maxSpe:%.3f.\n',dataset,model,meanDic,meanIoU,wFm,Sm,meanEm,mae,maxEm,maxDic,maxIoU,meanSen,maxSen,meanSpe,maxSpe); 155 | end 156 | 157 | toc; 158 | end 159 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /jittor/lib/PraNet_Res2Net.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | 4 | from lib.Res2Net_v1b import res2net50_26w_4s 5 | 6 | class BasicConv2d(nn.Module): 7 | 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv(in_planes, out_planes, kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False) 11 | self.bn = nn.BatchNorm(out_planes) 12 | self.relu = nn.ReLU() 13 | 14 | def execute(self, x): 15 | x = self.conv(x) 16 | x = self.bn(x) 17 | return x 18 | 19 | class RFB_modified(nn.Module): 20 | 21 | def __init__(self, in_channel, out_channel): 22 | super(RFB_modified, self).__init__() 23 | self.relu = nn.ReLU() 24 | self.branch0 = nn.Sequential(BasicConv2d(in_channel, out_channel, 1)) 25 | self.branch1 = nn.Sequential(BasicConv2d(in_channel, out_channel, 1), BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)) 26 | self.branch2 = nn.Sequential(BasicConv2d(in_channel, out_channel, 1), BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)) 27 | self.branch3 = nn.Sequential(BasicConv2d(in_channel, out_channel, 1), BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)) 28 | self.conv_cat = BasicConv2d((4 * out_channel), out_channel, 3, padding=1) 29 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 30 | 31 | def execute(self, x): 32 | x0 = self.branch0(x) 33 | x1 = self.branch1(x) 34 | x2 = self.branch2(x) 35 | x3 = self.branch3(x) 36 | x_cat = self.conv_cat(jt.contrib.concat((x0, x1, x2, x3), dim=1)) 37 | x = nn.relu((x_cat + self.conv_res(x))) 38 | return x 39 | 40 | class aggregation(nn.Module): 41 | 42 | def __init__(self, channel): 43 | super(aggregation, self).__init__() 44 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 45 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 46 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 47 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 48 | self.conv_upsample4 = 
BasicConv2d(channel, channel, 3, padding=1) 49 | self.conv_upsample5 = BasicConv2d((2 * channel), (2 * channel), 3, padding=1) 50 | self.conv_concat2 = BasicConv2d((2 * channel), (2 * channel), 3, padding=1) 51 | self.conv_concat3 = BasicConv2d((3 * channel), (3 * channel), 3, padding=1) 52 | self.conv4 = BasicConv2d((3 * channel), (3 * channel), 3, padding=1) 53 | self.conv5 = nn.Conv((3 * channel), 1, 1) 54 | 55 | def execute(self, x1, x2, x3): 56 | x1_1 = x1 57 | x2_1 = (self.conv_upsample1(self.upsample(x1)) * x2) 58 | x3_1 = ((self.conv_upsample2(self.upsample(self.upsample(x1))) * self.conv_upsample3(self.upsample(x2))) * x3) 59 | x2_2 = jt.contrib.concat((x2_1, self.conv_upsample4(self.upsample(x1_1))), dim=1) 60 | x2_2 = self.conv_concat2(x2_2) 61 | x3_2 = jt.contrib.concat((x3_1, self.conv_upsample5(self.upsample(x2_2))), dim=1) 62 | x3_2 = self.conv_concat3(x3_2) 63 | x = self.conv4(x3_2) 64 | x = self.conv5(x) 65 | return x 66 | 67 | class PraNet(nn.Module): 68 | 69 | def __init__(self, channel=32, pretrained_backbone=False): 70 | super(PraNet, self).__init__() 71 | self.resnet = res2net50_26w_4s(pretrained=pretrained_backbone) 72 | self.rfb2_1 = RFB_modified(512, channel) 73 | self.rfb3_1 = RFB_modified(1024, channel) 74 | self.rfb4_1 = RFB_modified(2048, channel) 75 | self.agg1 = aggregation(channel) 76 | self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1) 77 | self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2) 78 | self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2) 79 | self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2) 80 | self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1) 81 | self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1) 82 | self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 83 | self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 84 | self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 85 | self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1) 86 | self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 87 | self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 88 | self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 89 | 90 | self.upsample1_1 = nn.Upsample(scale_factor=8, mode='bilinear') 91 | self.upsample1_2 = nn.Upsample(scale_factor=0.25, mode='bilinear') 92 | 93 | self.upsample2_1 = nn.Upsample(scale_factor=32, mode='bilinear') 94 | self.upsample2_2 = nn.Upsample(scale_factor=2, mode='bilinear') 95 | 96 | self.upsample3_1 = nn.Upsample(scale_factor=16, mode='bilinear') 97 | self.upsample3_2 = nn.Upsample(scale_factor=2, mode='bilinear') 98 | 99 | self.upsample4 = nn.Upsample(scale_factor=8, mode='bilinear') 100 | 101 | def execute(self, x): 102 | x = self.resnet.conv1(x) 103 | x = self.resnet.bn1(x) 104 | x = nn.relu(x) 105 | x = self.resnet.maxpool(x) 106 | x1 = self.resnet.layer1(x) 107 | x2 = self.resnet.layer2(x1) 108 | x3 = self.resnet.layer3(x2) 109 | x4 = self.resnet.layer4(x3) 110 | x2_rfb = self.rfb2_1(x2) 111 | x3_rfb = self.rfb3_1(x3) 112 | x4_rfb = self.rfb4_1(x4) 113 | ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb) 114 | lateral_map_5 = self.upsample1_1(ra5_feat) 115 | crop_4 = self.upsample1_2(ra5_feat) 116 | x = (((- 1) * jt.sigmoid(crop_4)) + 1) 117 | x = x.expand((- 1), 2048, (- 1), (- 1)).multiply(x4) 118 | x = self.ra4_conv1(x) 119 | x = nn.relu(self.ra4_conv2(x)) 120 | x = nn.relu(self.ra4_conv3(x)) 121 | x = nn.relu(self.ra4_conv4(x)) 122 | ra4_feat = self.ra4_conv5(x) 123 | x = (ra4_feat + crop_4) 124 | lateral_map_4 = 
self.upsample2_1(x) 125 | crop_3 = self.upsample2_2(x) 126 | x = (((- 1) * jt.sigmoid(crop_3)) + 1) 127 | x = x.expand((- 1), 1024, (- 1), (- 1)).multiply(x3) 128 | x = self.ra3_conv1(x) 129 | x = nn.relu(self.ra3_conv2(x)) 130 | x = nn.relu(self.ra3_conv3(x)) 131 | ra3_feat = self.ra3_conv4(x) 132 | x = (ra3_feat + crop_3) 133 | lateral_map_3 = self.upsample3_1(x) 134 | crop_2 = self.upsample3_2(x) 135 | x = (((- 1) * jt.sigmoid(crop_2)) + 1) 136 | x = x.expand((- 1), 512, (- 1), (- 1)).multiply(x2) 137 | x = self.ra2_conv1(x) 138 | x = nn.relu(self.ra2_conv2(x)) 139 | x = nn.relu(self.ra2_conv3(x)) 140 | ra2_feat = self.ra2_conv4(x) 141 | x = (ra2_feat + crop_2) 142 | lateral_map_2 = self.upsample4(x) 143 | return (lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2) 144 | 145 | 146 | if __name__ == '__main__': 147 | import numpy as np 148 | from time import time 149 | net = PraNet() 150 | net.eval() 151 | 152 | dump_x = jt.randn(1, 3, 352, 352) 153 | frame_rate = np.zeros((1000, 1)) 154 | for i in range(1000): 155 | start = time() 156 | y = net(dump_x) 157 | end = time() 158 | running_frame_rate = (1 * float((1 / (end - start)))) 159 | print(i, '->', running_frame_rate) 160 | frame_rate[i] = running_frame_rate 161 | print(np.mean(frame_rate)) 162 | print(y.shape) 163 | -------------------------------------------------------------------------------- /lib/PraNet_Res2Net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .Res2Net_v1b import res2net50_v1b_26w_4s 5 | 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_planes, out_planes, 11 | kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=False) 13 | self.bn = nn.BatchNorm2d(out_planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | return x 20 | 21 | 22 | class RFB_modified(nn.Module): 23 | def __init__(self, in_channel, out_channel): 24 | super(RFB_modified, self).__init__() 25 | self.relu = nn.ReLU(True) 26 | self.branch0 = nn.Sequential( 27 | BasicConv2d(in_channel, out_channel, 1), 28 | ) 29 | self.branch1 = nn.Sequential( 30 | BasicConv2d(in_channel, out_channel, 1), 31 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), 32 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), 33 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 34 | ) 35 | self.branch2 = nn.Sequential( 36 | BasicConv2d(in_channel, out_channel, 1), 37 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), 38 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), 39 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 40 | ) 41 | self.branch3 = nn.Sequential( 42 | BasicConv2d(in_channel, out_channel, 1), 43 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), 44 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), 45 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7) 46 | ) 47 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1) 48 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 49 | 50 | def forward(self, x): 51 | x0 = self.branch0(x) 52 | x1 = self.branch1(x) 53 | x2 = 
self.branch2(x) 54 | x3 = self.branch3(x) 55 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 56 | 57 | x = self.relu(x_cat + self.conv_res(x)) 58 | return x 59 | 60 | 61 | class aggregation(nn.Module): 62 | # dense aggregation, it can be replaced by other aggregation previous, such as DSS, amulet, and so on. 63 | # used after MSF 64 | def __init__(self, channel): 65 | super(aggregation, self).__init__() 66 | self.relu = nn.ReLU(True) 67 | 68 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 69 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 70 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 71 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 72 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 73 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 74 | 75 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 76 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 77 | self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 78 | self.conv5 = nn.Conv2d(3*channel, 1, 1) 79 | 80 | def forward(self, x1, x2, x3): 81 | x1_1 = x1 82 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 83 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \ 84 | * self.conv_upsample3(self.upsample(x2)) * x3 85 | 86 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1) 87 | x2_2 = self.conv_concat2(x2_2) 88 | 89 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1) 90 | x3_2 = self.conv_concat3(x3_2) 91 | 92 | x = self.conv4(x3_2) 93 | x = self.conv5(x) 94 | 95 | return x 96 | 97 | 98 | class PraNet(nn.Module): 99 | # res2net based encoder decoder 100 | def __init__(self, channel=32): 101 | super(PraNet, self).__init__() 102 | # ---- ResNet Backbone ---- 103 | self.resnet = res2net50_v1b_26w_4s(pretrained=True) 104 | # ---- Receptive Field Block like module ---- 105 | self.rfb2_1 = RFB_modified(512, channel) 106 | self.rfb3_1 = RFB_modified(1024, channel) 107 | self.rfb4_1 = RFB_modified(2048, channel) 108 | # ---- Partial Decoder ---- 109 | self.agg1 = aggregation(channel) 110 | # ---- reverse attention branch 4 ---- 111 | self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1) 112 | self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2) 113 | self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2) 114 | self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2) 115 | self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1) 116 | # ---- reverse attention branch 3 ---- 117 | self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1) 118 | self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 119 | self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 120 | self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 121 | # ---- reverse attention branch 2 ---- 122 | self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1) 123 | self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 124 | self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 125 | self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 126 | 127 | def forward(self, x): 128 | x = self.resnet.conv1(x) 129 | x = self.resnet.bn1(x) 130 | x = self.resnet.relu(x) 131 | x = self.resnet.maxpool(x) # bs, 64, 88, 88 132 | # ---- low-level features ---- 133 | x1 = self.resnet.layer1(x) # bs, 256, 88, 88 134 | x2 = self.resnet.layer2(x1) # bs, 512, 44, 44 135 | 136 | x3 = 
self.resnet.layer3(x2) # bs, 1024, 22, 22 137 | x4 = self.resnet.layer4(x3) # bs, 2048, 11, 11 138 | x2_rfb = self.rfb2_1(x2) # channel -> 32 139 | x3_rfb = self.rfb3_1(x3) # channel -> 32 140 | x4_rfb = self.rfb4_1(x4) # channel -> 32 141 | 142 | ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb) 143 | lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear') # NOTES: Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352) 144 | 145 | # ---- reverse attention branch_4 ---- 146 | crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear') 147 | x = -1*(torch.sigmoid(crop_4)) + 1 148 | x = x.expand(-1, 2048, -1, -1).mul(x4) 149 | x = self.ra4_conv1(x) 150 | x = F.relu(self.ra4_conv2(x)) 151 | x = F.relu(self.ra4_conv3(x)) 152 | x = F.relu(self.ra4_conv4(x)) 153 | ra4_feat = self.ra4_conv5(x) 154 | x = ra4_feat + crop_4 155 | lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear') # NOTES: Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352) 156 | 157 | # ---- reverse attention branch_3 ---- 158 | crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear') 159 | x = -1*(torch.sigmoid(crop_3)) + 1 160 | x = x.expand(-1, 1024, -1, -1).mul(x3) 161 | x = self.ra3_conv1(x) 162 | x = F.relu(self.ra3_conv2(x)) 163 | x = F.relu(self.ra3_conv3(x)) 164 | ra3_feat = self.ra3_conv4(x) 165 | x = ra3_feat + crop_3 166 | lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear') # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352) 167 | 168 | # ---- reverse attention branch_2 ---- 169 | crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear') 170 | x = -1*(torch.sigmoid(crop_2)) + 1 171 | x = x.expand(-1, 512, -1, -1).mul(x2) 172 | x = self.ra2_conv1(x) 173 | x = F.relu(self.ra2_conv2(x)) 174 | x = F.relu(self.ra2_conv3(x)) 175 | ra2_feat = self.ra2_conv4(x) 176 | x = ra2_feat + crop_2 177 | lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear') # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352) 178 | 179 | return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 180 | 181 | 182 | if __name__ == '__main__': 183 | ras = PraNet().cuda() 184 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 185 | 186 | out = ras(input_tensor) -------------------------------------------------------------------------------- /lib/Res2Net_v1b.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | __all__ = ['Res2Net', 'res2net50_v1b', 'res2net101_v1b', 'res2net50_v1b_26w_4s'] 8 | 9 | model_urls = { 10 | 'res2net50_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net50_v1b_26w_4s-3cf99910.pth', 11 | 'res2net101_v1b_26w_4s': 'https://shanghuagao.oss-cn-beijing.aliyuncs.com/res2net/res2net101_v1b_26w_4s-0812c246.pth', 12 | } 13 | 14 | 15 | class Bottle2neck(nn.Module): 16 | expansion = 4 17 | 18 | def __init__(self, inplanes, planes, stride=1, downsample=None, baseWidth=26, scale=4, stype='normal'): 19 | """ Constructor 20 | Args: 21 | inplanes: input channel dimensionality 22 | planes: output channel dimensionality 23 | stride: conv stride. Replaces pooling layer. 24 | downsample: None when stride = 1 25 | baseWidth: basic width of conv3x3 26 | scale: number of scale. 27 | type: 'normal': normal set. 'stage': first block of a new stage. 
28 | """ 29 | super(Bottle2neck, self).__init__() 30 | 31 | width = int(math.floor(planes * (baseWidth / 64.0))) 32 | self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False) 33 | self.bn1 = nn.BatchNorm2d(width * scale) 34 | 35 | if scale == 1: 36 | self.nums = 1 37 | else: 38 | self.nums = scale - 1 39 | if stype == 'stage': 40 | self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) 41 | convs = [] 42 | bns = [] 43 | for i in range(self.nums): 44 | convs.append(nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False)) 45 | bns.append(nn.BatchNorm2d(width)) 46 | self.convs = nn.ModuleList(convs) 47 | self.bns = nn.ModuleList(bns) 48 | 49 | self.conv3 = nn.Conv2d(width * scale, planes * self.expansion, kernel_size=1, bias=False) 50 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 51 | 52 | self.relu = nn.ReLU(inplace=True) 53 | self.downsample = downsample 54 | self.stype = stype 55 | self.scale = scale 56 | self.width = width 57 | 58 | def forward(self, x): 59 | residual = x 60 | 61 | out = self.conv1(x) 62 | out = self.bn1(out) 63 | out = self.relu(out) 64 | 65 | spx = torch.split(out, self.width, 1) 66 | for i in range(self.nums): 67 | if i == 0 or self.stype == 'stage': 68 | sp = spx[i] 69 | else: 70 | sp = sp + spx[i] 71 | sp = self.convs[i](sp) 72 | sp = self.relu(self.bns[i](sp)) 73 | if i == 0: 74 | out = sp 75 | else: 76 | out = torch.cat((out, sp), 1) 77 | if self.scale != 1 and self.stype == 'normal': 78 | out = torch.cat((out, spx[self.nums]), 1) 79 | elif self.scale != 1 and self.stype == 'stage': 80 | out = torch.cat((out, self.pool(spx[self.nums])), 1) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Res2Net(nn.Module): 95 | 96 | def __init__(self, block, layers, baseWidth=26, scale=4, num_classes=1000): 97 | self.inplanes = 64 98 | super(Res2Net, self).__init__() 99 | self.baseWidth = baseWidth 100 | self.scale = scale 101 | self.conv1 = nn.Sequential( 102 | nn.Conv2d(3, 32, 3, 2, 1, bias=False), 103 | nn.BatchNorm2d(32), 104 | nn.ReLU(inplace=True), 105 | nn.Conv2d(32, 32, 3, 1, 1, bias=False), 106 | nn.BatchNorm2d(32), 107 | nn.ReLU(inplace=True), 108 | nn.Conv2d(32, 64, 3, 1, 1, bias=False) 109 | ) 110 | self.bn1 = nn.BatchNorm2d(64) 111 | self.relu = nn.ReLU() 112 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0]) 114 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 117 | self.avgpool = nn.AdaptiveAvgPool2d(1) 118 | self.fc = nn.Linear(512 * block.expansion, num_classes) 119 | 120 | for m in self.modules(): 121 | if isinstance(m, nn.Conv2d): 122 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 123 | elif isinstance(m, nn.BatchNorm2d): 124 | nn.init.constant_(m.weight, 1) 125 | nn.init.constant_(m.bias, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | nn.AvgPool2d(kernel_size=stride, stride=stride, 132 | ceil_mode=True, count_include_pad=False), 133 | nn.Conv2d(self.inplanes, planes * block.expansion, 134 | kernel_size=1, stride=1, 
bias=False), 135 | nn.BatchNorm2d(planes * block.expansion), 136 | ) 137 | 138 | layers = [] 139 | layers.append(block(self.inplanes, planes, stride, downsample=downsample, 140 | stype='stage', baseWidth=self.baseWidth, scale=self.scale)) 141 | self.inplanes = planes * block.expansion 142 | for i in range(1, blocks): 143 | layers.append(block(self.inplanes, planes, baseWidth=self.baseWidth, scale=self.scale)) 144 | 145 | return nn.Sequential(*layers) 146 | 147 | def forward(self, x): 148 | x = self.conv1(x) 149 | x = self.bn1(x) 150 | x = self.relu(x) 151 | x = self.maxpool(x) 152 | 153 | x = self.layer1(x) 154 | x = self.layer2(x) 155 | x = self.layer3(x) 156 | x = self.layer4(x) 157 | 158 | x = self.avgpool(x) 159 | x = x.view(x.size(0), -1) 160 | x = self.fc(x) 161 | 162 | return x 163 | 164 | 165 | def res2net50_v1b(pretrained=False, **kwargs): 166 | """Constructs a Res2Net-50_v1b lib. 167 | Res2Net-50 refers to the Res2Net-50_v1b_26w_4s. 168 | Args: 169 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 170 | """ 171 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 172 | if pretrained: 173 | model.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 174 | return model 175 | 176 | 177 | def res2net101_v1b(pretrained=False, **kwargs): 178 | """Constructs a Res2Net-50_v1b_26w_4s lib. 179 | Args: 180 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 181 | """ 182 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 183 | if pretrained: 184 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 185 | return model 186 | 187 | 188 | def res2net50_v1b_26w_4s(pretrained=False, **kwargs): 189 | """Constructs a Res2Net-50_v1b_26w_4s lib. 190 | Args: 191 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 192 | """ 193 | model = Res2Net(Bottle2neck, [3, 4, 6, 3], baseWidth=26, scale=4, **kwargs) 194 | if pretrained: 195 | model_state = torch.load('/media/nercms/NERCMS/GepengJi/Medical_Seqmentation/CRANet/models/res2net50_v1b_26w_4s-3cf99910.pth') 196 | model.load_state_dict(model_state) 197 | # lib.load_state_dict(model_zoo.load_url(model_urls['res2net50_v1b_26w_4s'])) 198 | return model 199 | 200 | 201 | def res2net101_v1b_26w_4s(pretrained=False, **kwargs): 202 | """Constructs a Res2Net-50_v1b_26w_4s lib. 203 | Args: 204 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 205 | """ 206 | model = Res2Net(Bottle2neck, [3, 4, 23, 3], baseWidth=26, scale=4, **kwargs) 207 | if pretrained: 208 | model.load_state_dict(model_zoo.load_url(model_urls['res2net101_v1b_26w_4s'])) 209 | return model 210 | 211 | 212 | def res2net152_v1b_26w_4s(pretrained=False, **kwargs): 213 | """Constructs a Res2Net-50_v1b_26w_4s lib. 
214 | Args: 215 | pretrained (bool): If True, returns a lib pre-trained on ImageNet 216 | """ 217 | model = Res2Net(Bottle2neck, [3, 8, 36, 3], baseWidth=26, scale=4, **kwargs) 218 | if pretrained: 219 | model.load_state_dict(model_zoo.load_url(model_urls['res2net152_v1b_26w_4s'])) 220 | return model 221 | 222 | 223 | if __name__ == '__main__': 224 | images = torch.rand(1, 3, 224, 224).cuda(0) 225 | model = res2net50_v1b_26w_4s(pretrained=True) 226 | model = model.cuda(0) 227 | print(model(images).size()) 228 | -------------------------------------------------------------------------------- /lib/PraNet_ResNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import torch.nn.functional as F 5 | from .ResNet import ResNet 6 | import math 7 | 8 | 9 | class BasicConv2d(nn.Module): 10 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 11 | super(BasicConv2d, self).__init__() 12 | self.conv = nn.Conv2d(in_planes, out_planes, 13 | kernel_size=kernel_size, stride=stride, 14 | padding=padding, dilation=dilation, bias=False) 15 | self.bn = nn.BatchNorm2d(out_planes) 16 | self.relu = nn.ReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | x = self.bn(x) 21 | return x 22 | 23 | 24 | class RFB(nn.Module): 25 | # RFB-like multi-scale module 26 | def __init__(self, in_channel, out_channel): 27 | super(RFB, self).__init__() 28 | self.relu = nn.ReLU(True) 29 | self.branch0 = nn.Sequential( 30 | BasicConv2d(in_channel, out_channel, 1), 31 | ) 32 | self.branch1 = nn.Sequential( 33 | BasicConv2d(in_channel, out_channel, 1), 34 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)), 35 | BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)), 36 | BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3) 37 | ) 38 | self.branch2 = nn.Sequential( 39 | BasicConv2d(in_channel, out_channel, 1), 40 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)), 41 | BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)), 42 | BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5) 43 | ) 44 | self.branch3 = nn.Sequential( 45 | BasicConv2d(in_channel, out_channel, 1), 46 | BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)), 47 | BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)), 48 | BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7) 49 | ) 50 | self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1) 51 | self.conv_res = BasicConv2d(in_channel, out_channel, 1) 52 | 53 | def forward(self, x): 54 | x0 = self.branch0(x) 55 | x1 = self.branch1(x) 56 | x2 = self.branch2(x) 57 | x3 = self.branch3(x) 58 | 59 | x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1)) 60 | 61 | x = self.relu(x_cat + self.conv_res(x)) 62 | return x 63 | 64 | 65 | class aggregation(nn.Module): 66 | # dense aggregation, it can be replaced by other aggregation previous, such as DSS, amulet, and so on. 
67 | # used after MSF 68 | def __init__(self, channel): 69 | super(aggregation, self).__init__() 70 | self.relu = nn.ReLU(True) 71 | 72 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 73 | self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1) 74 | self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1) 75 | self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1) 76 | self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1) 77 | self.conv_upsample5 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 78 | 79 | self.conv_concat2 = BasicConv2d(2*channel, 2*channel, 3, padding=1) 80 | self.conv_concat3 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 81 | self.conv4 = BasicConv2d(3*channel, 3*channel, 3, padding=1) 82 | self.conv5 = nn.Conv2d(3*channel, 1, 1) 83 | 84 | def forward(self, x1, x2, x3): 85 | x1_1 = x1 86 | x2_1 = self.conv_upsample1(self.upsample(x1)) * x2 87 | x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \ 88 | * self.conv_upsample3(self.upsample(x2)) * x3 89 | 90 | x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1) 91 | x2_2 = self.conv_concat2(x2_2) 92 | 93 | x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1) 94 | x3_2 = self.conv_concat3(x3_2) 95 | 96 | x = self.conv4(x3_2) 97 | x = self.conv5(x) 98 | 99 | return x 100 | 101 | 102 | class CRANet(nn.Module): 103 | # resnet based encoder decoder 104 | def __init__(self, channel=32): 105 | super(CRANet, self).__init__() 106 | 107 | # ---- ResNet Backbone ---- 108 | self.resnet = ResNet() 109 | 110 | # Receptive Field Block 111 | self.rfb2_1 = RFB(512, channel) 112 | self.rfb3_1 = RFB(1024, channel) 113 | self.rfb4_1 = RFB(2048, channel) 114 | 115 | # Partial Decoder 116 | self.agg1 = aggregation(channel) 117 | # self.upsample = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 118 | 119 | # ---- reverse attention branch 4 ---- 120 | self.ra4_conv1 = BasicConv2d(2048, 256, kernel_size=1) 121 | self.ra4_conv2 = BasicConv2d(256, 256, kernel_size=5, padding=2) 122 | self.ra4_conv3 = BasicConv2d(256, 256, kernel_size=5, padding=2) 123 | self.ra4_conv4 = BasicConv2d(256, 256, kernel_size=5, padding=2) 124 | self.ra4_conv5 = BasicConv2d(256, 1, kernel_size=1) 125 | # self.ra4_conv5_up = nn.ConvTranspose2d(1, 1, kernel_size=64, stride=32) 126 | 127 | # ---- reverse attention branch 3 ---- 128 | # self.ra4_3 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2) 129 | self.ra3_conv1 = BasicConv2d(1024, 64, kernel_size=1) 130 | self.ra3_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 131 | self.ra3_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 132 | self.ra3_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 133 | # self.ra3_conv4_up = nn.ConvTranspose2d(1, 1, kernel_size=32, stride=16) 134 | 135 | # ---- reverse attention branch 2 ---- 136 | # self.ra3_2 = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2) 137 | self.ra2_conv1 = BasicConv2d(512, 64, kernel_size=1) 138 | self.ra2_conv2 = BasicConv2d(64, 64, kernel_size=3, padding=1) 139 | self.ra2_conv3 = BasicConv2d(64, 64, kernel_size=3, padding=1) 140 | self.ra2_conv4 = BasicConv2d(64, 1, kernel_size=3, padding=1) 141 | # self.ra2_conv4_up = nn.ConvTranspose2d(1, 1, kernel_size=16, stride=8) 142 | 143 | # self.HA = HA() 144 | if self.training: 145 | self.initialize_weights() 146 | # self.apply(CRANet.weights_init) 147 | 148 | def forward(self, x): 149 | x = self.resnet.conv1(x) 150 | x = self.resnet.bn1(x) 151 | x = 
self.resnet.relu(x) 152 | x = self.resnet.maxpool(x) # bs, 64, 88, 88 153 | x1 = self.resnet.layer1(x) # bs, 256, 88, 88 154 | x2 = self.resnet.layer2(x1) # bs, 512, 44, 44 155 | 156 | x3 = self.resnet.layer3(x2) # bs, 1024, 22, 22 157 | x4 = self.resnet.layer4(x3) # bs, 2048, 11, 11 158 | x2_rfb = self.rfb2_1(x2) # channel -> 32 159 | x3_rfb = self.rfb3_1(x3) # channel -> 32 160 | x4_rfb = self.rfb4_1(x4) # channel -> 32 161 | 162 | ra5_feat = self.agg1(x4_rfb, x3_rfb, x2_rfb) 163 | lateral_map_5 = F.interpolate(ra5_feat, scale_factor=8, mode='bilinear') # Sup-1 (bs, 1, 44, 44) -> (bs, 1, 352, 352) 164 | 165 | # ---- reverse attention branch_4 ---- 166 | crop_4 = F.interpolate(ra5_feat, scale_factor=0.25, mode='bilinear') 167 | x = -1*(torch.sigmoid(crop_4)) + 1 168 | x = x.expand(-1, 2048, -1, -1).mul(x4) 169 | x = self.ra4_conv1(x) 170 | x = F.relu(self.ra4_conv2(x)) 171 | x = F.relu(self.ra4_conv3(x)) 172 | x = F.relu(self.ra4_conv4(x)) 173 | ra4_feat = self.ra4_conv5(x) 174 | x = ra4_feat + crop_4 175 | lateral_map_4 = F.interpolate(x, scale_factor=32, mode='bilinear') # Sup-2 (bs, 1, 11, 11) -> (bs, 1, 352, 352) 176 | 177 | # ---- reverse attention branch_3 ---- 178 | # x = F.interpolate(x, scale_factor=2, mode='bilinear') 179 | crop_3 = F.interpolate(x, scale_factor=2, mode='bilinear') 180 | x = -1*(torch.sigmoid(crop_3)) + 1 181 | x = x.expand(-1, 1024, -1, -1).mul(x3) 182 | x = self.ra3_conv1(x) 183 | x = F.relu(self.ra3_conv2(x)) 184 | x = F.relu(self.ra3_conv3(x)) 185 | ra3_feat = self.ra3_conv4(x) 186 | x = ra3_feat + crop_3 187 | lateral_map_3 = F.interpolate(x, scale_factor=16, mode='bilinear') 188 | # lateral_map_3 = self.crop(self.ra3_conv4_up(x), x_size) # NOTES: Sup-3 (bs, 1, 22, 22) -> (bs, 1, 352, 352) 189 | 190 | # ---- reverse attention branch_2 ---- 191 | # x = self.ra3_2(x) 192 | # crop_2 = self.crop(x, x2.size()) 193 | crop_2 = F.interpolate(x, scale_factor=2, mode='bilinear') 194 | x = -1*(torch.sigmoid(crop_2)) + 1 195 | x = x.expand(-1, 512, -1, -1).mul(x2) 196 | x = self.ra2_conv1(x) 197 | x = F.relu(self.ra2_conv2(x)) 198 | x = F.relu(self.ra2_conv3(x)) 199 | ra2_feat = self.ra2_conv4(x) 200 | x = ra2_feat + crop_2 201 | lateral_map_2 = F.interpolate(x, scale_factor=8, mode='bilinear') 202 | # lateral_map_2 = self.crop(self.ra2_conv4_up(x), x_size) # NOTES: Sup-4 (bs, 1, 44, 44) -> (bs, 1, 352, 352) 203 | 204 | return lateral_map_5, lateral_map_4, lateral_map_3, lateral_map_2 205 | 206 | # def crop(self, upsampled, x_size): 207 | # c = (upsampled.size()[2] - x_size[2]) // 2 208 | # _c = x_size[2] - upsampled.size()[2] + c 209 | # assert(c >= 0) 210 | # if c == _c == 0: 211 | # return upsampled 212 | # return upsampled[:, :, c:_c, c:_c] 213 | 214 | def initialize_weights(self): 215 | res50 = models.resnet50(pretrained=True) 216 | pretrained_dict = res50.state_dict() 217 | all_params = {} 218 | for k, v in self.resnet.state_dict().items(): 219 | if k in pretrained_dict.keys(): 220 | v = pretrained_dict[k] 221 | all_params[k] = v 222 | assert len(all_params.keys()) == len(self.resnet.state_dict().keys()) 223 | self.resnet.load_state_dict(all_params) 224 | 225 | # @staticmethod 226 | # def weights_init(m): 227 | # if isinstance(m, nn.Conv2d): 228 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 229 | # m.weight.data.normal_(0, math.sqrt(2. / n)) 230 | # elif isinstance(m, nn.ConvTranspose2d): 231 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 232 | # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 233 | 234 | if __name__ == '__main__': 235 | ras = CRANet().cuda() 236 | input_tensor = torch.randn(1, 3, 352, 352).cuda() 237 | 238 | out = ras(input_tensor) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PraNet: Parallel Reverse Attention Network for Polyp Segmentation (MICCAI2020-Oral & MICCAI2025 Young Scientist Publication Impact Award) 2 | 3 | > **Authors:** 4 | > [Deng-Ping Fan](https://dengpingfan.github.io/), 5 | > [Ge-Peng Ji](https://scholar.google.com/citations?user=oaxKYKUAAAAJ&hl=en), 6 | > [Tao Zhou](https://taozh2017.github.io/), 7 | > [Geng Chen](https://www.researchgate.net/profile/Geng_Chen13), 8 | > [Huazhu Fu](http://hzfu.github.io/), 9 | > [Jianbing Shen](http://iitlab.bit.edu.cn/mcislab/~shenjianbing), and 10 | > [Ling Shao](http://www.inceptioniai.org/). 11 | 12 | 13 | - Honored to be selected as a [MICCAI 2025 YSPIA Awardee](https://miccai.org/index.php/about-miccai/awards/young-scientist-impact-award/) ![image](https://github.com/user-attachments/assets/3966a8b1-85e9-4398-88a0-108816f78a91) ![image](https://github.com/user-attachments/assets/9dfa8007-af53-4213-9bec-e548fe159261) 14 | 15 | 16 | - 💥 We’re excited to introduce **PraNet-V2**, bringing semantic segmentation capabilities! PraNet-V2 introduces the **Dual-Supervised Reverse Attention (DSRA)** module, enabling explicit background supervision, independent background modeling, and semantically enriched attention fusion. We have found that PraNet-V2 outperforms PraNet-V1 in **polyp segmentation** and **improves existing SOTA segmentation multi-class models through DSRA integration**. 📖 [Read the paper](https://arxiv.org/abs/2504.10986) | 🔗 [Check out the code](https://github.com/ai4colonoscopy/PraNet-V2) 17 | 18 | - :boom: We’re excited to introduce “[IntelliScope Project](https://github.com/ai4colonoscopy/IntelliScope),” which offers a deep dive into the latest advancements in intelligent colonoscopy (📖 **ColonSurvey**). We’re also pushing for three key initiatives to embrace the multimodal era in colonoscopy: a pioneering large-scale instruction tuning dataset (🏥 **ColonINST**), a colonoscopy-specific multimodal language model (🤖 **ColonGPT**), and a **multimdoal benchmark** 💯 for comparing different approaches. 19 | 20 | - We are in the MICCAI2024 Young Scientist Publication Impact Award Shortlist 21 | ![image](https://github.com/user-attachments/assets/f2871d49-235c-4554-ae89-11740e41edb8) 22 | 23 | - We receive the award of Jittor Developer Conference Distinguish Paper & Most Influential (Application) Paper 24 |

25 |
26 |

27 | 28 | 29 | ## 1. Preface 30 | 31 | - This repository provides code for "_**PraNet: Parallel Reverse Attention Network for Polyp Segmentation**_" MICCAI-2020. 32 | ([paper](https://link.springer.com/chapter/10.1007%2F978-3-030-59725-2_26) | [中文版](https://dengpingfan.github.io/papers/[2020][MICCAI]PraNet_Chinese.pdf)) 33 | 34 | - If you have any questions about our paper, feel free to contact me. And if you are using PraNet 35 | or evaluation toolbox for your research, please cite this paper ([BibTeX](#4-citation)). 36 | 37 | 38 | ### 1.1. :fire: NEWS :fire: 39 | 40 | - [2025/03/20] 🚀 **PraNet-V2 is here!** 🚀 41 | We have introduced **PraNet-V2**, an enhanced version of PRaNet with **Dual-Supervised Reverse Attention (DSRA)** for more effective **multi-class** segmentation. Check out the paper and code here: [PraNet-V2](https://github.com/ai4colonoscopy/PraNet-V2) 42 | 43 | - [2022/11/26] Our PraNet has been developed on [Huawei Ascend platform](https://e.huawei.com/hk/products/servers/ascend), where the project could be found at [Gitee](https://gitee.com/ascend/ModelZoo-PyTorch/tree/master/PyTorch/contrib/cv/semantic_segmentation/PraNet) and [CSDN introduction](https://blog.csdn.net/m0_62401440/article/details/125563697). 44 | 45 | - [2022/03/27] :boom: We release a new large-scale dataset on **Video Polyp Segmentation (VPS)** task, please enjoy it. [ProjectLink](https://github.com/GewelsJI/VPS)/ [PDF](https://arxiv.org/abs/2203.14291). 46 | 47 | - [2021/12/26] :boom: PraNet模型在[Jittor Developer Conference 2021](https://cg.cs.tsinghua.edu.cn/jittor/news/2021-12-27-15-27-00-00-jdc1/)中荣获「最具影响力计图论文(应用)奖」 48 | 49 | - [2021/09/07] The Jittor convertion of PraNet ([inference code](https://github.com/DengPingFan/PraNet/tree/master/jittor)) is available right now. It has robust inference efficiency compared to PyTorch version, please enjoy it. Many thanks to Yu-Cheng Chou for the excellent conversion from pytorch framework. 50 | 51 | - [2021/09/05] The Tensorflow (Keras) implementation of PraNet (ResNet50/MobileNetV2 version) is released in [github-link](https://github.com/Thehunk1206/PRANet-Polyps-Segmentation). Thanks Tauhid Khan. 52 | 53 | - [2021/08/18] Improved version (PraNet-V2) has been released: https://github.com/DengPingFan/Polyp-PVT. 54 | 55 | - [2021/04/23] We update the results on four [Camouflaged Object Detection (COD)](https://github.com/DengPingFan/SINet) testing dataset (i.e., COD10K, NC4K, CAMO, and CHAMELEON) of our PraNet, which is the retained on COD dataset from scratch. Download links at google drive are avaliable here: [result](https://drive.google.com/file/d/1h1sXnZA3uIeRXe9eUsH8Vp9i40VylauB/view?usp=sharing), [model weight](https://drive.google.com/file/d/1epdeolFS_JC8D8Pm_r0TaUJM-Qo4v49c/view?usp=sharing), [evaluation results](https://drive.google.com/file/d/1hY_S0-o5rezsBZCUegpDtAAmhy8jpW5N/view?usp=sharing). 56 | 57 | - [2021/01/21] :boom: Our PraNet has been used as the base segmentation model of [Prof. Michael I. Jordan](https://scholar.google.com/citations?user=yxUduqMAAAAJ&hl=zh-CN) et al's recent work (Distribution-Free, Risk-Controlling Prediction Sets, [Journal of the ACM 2021](https://arxiv.org/pdf/2101.02703.pdf)). 58 | 59 | - [2021/01/10] :boom: Our PraNet achieved the Top-1 ranking on the camouflaged object detection task ([link](https://paperswithcode.com/paper/pranet-parallel-reverse-attention-network-for)). 60 | 61 | - [2020/09/18] Upload the pre-computed maps. 62 | 63 | - [2020/05/28] Upload pre-trained weights. 
64 | 65 | - [2020/06/24] Release training/testing code. 66 | 67 | - [2020/03/24] Create repository. 68 | 69 | 70 | ### 1.2. Table of Contents 71 | 72 | - [PraNet: Parallel Reverse Attention Network for Polyp Segmentation (MICCAI2020-Oral)](#pranet-parallel-reverse-attention-network-for-polyp-segmentation-miccai2020-oral) 73 | - [1. Preface](#1-preface) 74 | - [1.1. :fire: NEWS :fire:](#11-fire-news-fire) 75 | - [1.2. Table of Contents](#12-table-of-contents) 76 | - [1.3. State-of-the-art Approaches](#13-state-of-the-art-approaches) 77 | - [2. Overview](#2-overview) 78 | - [2.1. Introduction](#21-introduction) 79 | - [2.2. Framework Overview](#22-framework-overview) 80 | - [2.3. Qualitative Results](#23-qualitative-results) 81 | - [3. Proposed Baseline](#3-proposed-baseline) 82 | - [3.1. Training/Testing](#31-trainingtesting) 83 | - [3.2 Evaluating your trained model:](#32-evaluating-your-trained-model) 84 | - [3.3 Pre-computed maps:](#33-pre-computed-maps) 85 | - [4. Citation](#4-citation) 86 | - [5. TODO LIST](#5-todo-list) 87 | - [6. FAQ](#6-faq) 88 | - [7. License](#7-license) 89 | 90 | Table of contents generated with markdown-toc 91 | 92 | ### 1.3. State-of-the-art Approaches 93 | 1. "Selective feature aggregation network with area-boundary constraints for polyp segmentation." IEEE Transactions on Medical Imaging, 2019. 94 | paper link: https://link.springer.com/chapter/10.1007/978-3-030-32239-7_34 95 | 2. "PraNet: Parallel Reverse Attention Network for Polyp Segmentation" IEEE Transactions on Medical Imaging, 2020. 96 | paper link: https://link.springer.com/chapter/10.1007%2F978-3-030-59725-2_26 97 | 3. "Hardnet-mseg: A simple encoder-decoder polyp segmentation neural network that achieves over 0.9 mean dice and 86 fps" arXiv, 2021 98 | paper link: https://arxiv.org/pdf/2101.07172.pdf 99 | 4. "TransFuse: Fusing Transformers and CNNs for Medical Image Segmentation" arXiv, 2021. 100 | paper link: https://arxiv.org/pdf/2102.08005.pdf 101 | 5. "Automatic Polyp Segmentation via Multi-scale Subtraction Network" MICCAI, 2021. paper link: https://arxiv.org/pdf/2108.05082.pdf 102 | 6. "CCBANet: Cascading Context and Balancing Attention for Polyp Segmentation" MICCAI, 2021. paper link: https://link.springer.com/book/10.1007/978-3-030-87193-2?noAccess=true 103 | 7. "Double Encoder-Decoder Networks for Gastrointestinal Polyp Segmentation" MICCAI, 2021. paper link: https://arxiv.org/pdf/2110.01939.pdf 104 | 8. "HRENet: A Hard Region Enhancement Network for Polyp Segmentation" MICCAI, 2021. paper link: https://link.springer.com/book/10.1007/978-3-030-87193-2?noAccess=true 105 | 9. "Learnable Oriented-Derivative Network for Polyp Segmentation" MICCAI, 2021. paper link: https://link.springer.com/book/10.1007/978-3-030-87193-2?noAccess=true 106 | 10. "Shallow attention network for polyp segmentation" MICCAI, 2021. paper link: https://arxiv.org/pdf/2108.00882.pdf 107 | 108 | The latest trends in image-/video-based polyp segmentation refer to [AWESOME_VPS.md](https://github.com/GewelsJI/VPS/blob/main/docs/AWESOME_VPS.md). 109 | 110 | 111 | ## 2. Overview 112 | 113 | ### 2.1. Introduction 114 | 115 | Colonoscopy is an effective technique for detecting colorectal polyps, which are highly related to colorectal cancer. 116 | In clinical practice, segmenting polyps from colonoscopy images is of great importance since it provides valuable 117 | information for diagnosis and surgery. 
However, accurate polyp segmentation is a challenging task for two major reasons: 118 | (i) polyps of the same type vary in size, color and texture; and 119 | (ii) the boundary between a polyp and its surrounding mucosa is not sharp. 120 | 121 | To address these challenges, we propose a parallel reverse attention network (PraNet) for accurate polyp segmentation in colonoscopy 122 | images. Specifically, we first aggregate the features in high-level layers using a parallel partial decoder (PPD). 123 | Based on the combined feature, we then generate a global map as the initial guidance area for the following components. 124 | In addition, we mine the boundary cues using a reverse attention (RA) module, which is able to establish the relationship between 125 | areas and boundary cues. Thanks to the recurrent cooperation mechanism between areas and boundaries, 126 | our PraNet is capable of calibrating any misaligned predictions, improving the segmentation accuracy. 127 | 128 | Quantitative and qualitative evaluations on five challenging datasets across six 129 | metrics show that our PraNet improves the segmentation accuracy significantly and presents a number of advantages in terms of generalizability 130 | and real-time segmentation efficiency (∼50 fps). 131 | 132 | ### 2.2. Framework Overview 133 | 134 |
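To make the reverse attention idea above concrete, the sketch below shows, in plain PyTorch, how a coarse side-output map can be turned into a reverse-attention weight and applied to a backbone feature. It mirrors the pattern used in `lib/PraNet_Res2Net.py`, but the function and tensor names (`reverse_attention`, `side_out`, `backbone_feat`) are illustrative only and are not part of the released code.

```python
import torch
import torch.nn.functional as F

def reverse_attention(side_out, backbone_feat):
    """Minimal sketch of PraNet-style reverse attention.

    side_out:      (B, 1, h, w) coarse polyp logits from a deeper stage
    backbone_feat: (B, C, H, W) backbone feature map to be re-weighted
    """
    # Resize the coarse prediction to the feature resolution.
    crop = F.interpolate(side_out, size=backbone_feat.shape[2:],
                         mode='bilinear', align_corners=False)
    # Reverse the foreground probability: attend to regions NOT yet segmented.
    rev = 1.0 - torch.sigmoid(crop)
    # Broadcast the single-channel weight over all feature channels.
    refined = rev.expand(-1, backbone_feat.shape[1], -1, -1) * backbone_feat
    return refined, crop
```

In the released model, the re-weighted feature then passes through a few convolutions whose single-channel output is added back onto the resized coarse map; that sum serves both as the next stage's guidance and as a deeply supervised side output.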

135 |
136 | 137 | Figure 1: Overview of the proposed PraNet, which consists of three reverse attention 138 | modules with a parallel partial decoder connection. See § 2 in the paper for details. 139 | 140 |

141 | 142 | ### 2.3. Qualitative Results 143 | 144 |

145 |
146 | 147 | Figure 2: Qualitative Results. 148 | 149 |

150 | 151 | ## 3. Proposed Baseline 152 | 153 | ### 3.1. Training/Testing 154 | 155 | The training and testing experiments are conducted using [PyTorch](https://github.com/pytorch/pytorch) with 156 | a single GeForce RTX TITAN GPU with 24 GB of memory. 157 | 158 | > Note that our model also supports low-memory GPUs, which means you can lower the batch size. 159 | 160 | 161 | 1. Configuring your environment (Prerequisites): 162 | 163 | Note that PraNet is only tested on Ubuntu OS with the following environments. 164 | It may work on other operating systems as well but we do not guarantee that it will. 165 | 166 | + Creating a virtual environment in terminal: `conda create -n PraNet python=3.6`. 167 | 168 | + Installing necessary packages: PyTorch 1.1 169 | 170 | 1. Downloading necessary data: 171 | 172 | + downloading the testing dataset and moving it into `./data/TestDataset/`, 173 | which can be found in this [Google Drive Link (327.2MB)](https://drive.google.com/file/d/1Y2z7FD5p5y31vkZwQQomXFRB0HutHyao/view?usp=sharing). It contains five sub-datasets: CVC-300 (60 test samples), CVC-ClinicDB (62 test samples), CVC-ColonDB (380 test samples), ETIS-LaribPolypDB (196 test samples), Kvasir (100 test samples). 174 | 175 | + downloading the training dataset and moving it into `./data/TrainDataset/`, 176 | which can be found in this [Google Drive Link (399.5MB)](https://drive.google.com/file/d/1YiGHLw4iTvKdvbT6MgwO9zcCv8zJ_Bnb/view?usp=sharing). It contains two sub-datasets: Kvasir-SEG (900 train samples) and CVC-ClinicDB (550 train samples). 177 | 178 | + downloading the pretrained weights and moving them into `snapshots/PraNet_Res2Net/PraNet-19.pth`, 179 | which can be found in this [Google Drive Link (124.6MB)](https://drive.google.com/file/d/1lJv8XVStsp3oNKZHaSr42tawdMOq6FLP/view?usp=sharing). 180 | 181 | + downloading the Res2Net weights [Google Drive (98.4MB)](https://drive.google.com/file/d/1FjXh_YG1hLGPPM6j-c8UxHcIWtzGGau5/view?usp=sharing). 182 | 183 | 1. Training Configuration: 184 | 185 | + Assigning your customized paths, like `--train_save` and `--train_path`, in `MyTrain.py`. 186 | 187 | + Just enjoy it! 188 | 189 | 1. Testing Configuration: 190 | 191 | + After you have downloaded all the pre-trained models and testing datasets, just run `MyTest.py` to generate the final prediction maps: 192 | replace your trained model directory (`--pth_path`). 193 | 194 | + Just enjoy it! 195 | 196 | 197 | ### 3.2 Evaluating your trained model: 198 | 199 | MATLAB: One-key evaluation is written in MATLAB code ([Google Drive Link](https://drive.google.com/file/d/1eKUpny19kLaCpZl7jjan408238h5PGIO/view?usp=sharing)); 200 | please follow the instructions in `./eval/main.m` and just run it to generate the evaluation results in `./res/`. 201 | The complete evaluation toolbox (including data, map, eval code, and res): [Google Drive Link (380.6MB)](https://drive.google.com/file/d/1FJxb9DZMzPWFffkbchU0s9Zcf5oe7qcT/view?usp=sharing). 202 | 203 | Python: Please refer to the ACMMM2021 work https://github.com/plemeri/UACANet 204 | 205 | 206 | ### 3.3 Pre-computed maps: 207 | They can be found in [Google Drive Link (61.6MB)](https://drive.google.com/file/d/1CJ6CTUdenumgiKXieuKXFohefRJwyFPY/view?usp=sharing). 208 | 209 | 210 | ## 4. 
Citation 211 | 212 | Please cite our paper if you find the work useful: 213 | 214 | @inproceedings{fan2020pranet, 215 | title={Pranet: Parallel reverse attention network for polyp segmentation}, 216 | author={Fan, Deng-Ping and Ji, Ge-Peng and Zhou, Tao and Chen, Geng and Fu, Huazhu and Shen, Jianbing and Shao, Ling}, 217 | booktitle={International conference on medical image computing and computer-assisted intervention}, 218 | pages={263--273}, 219 | year={2020}, 220 | organization={Springer} 221 | } 222 | 223 | ## 5. TODO LIST 224 | 225 | > If you want to improve the usability of this project or have any piece of advice, please feel free to contact me directly ([E-mail](mailto:gepengai.ji@gmail.com)). 226 | 227 | - [ ] Support `NVIDIA APEX` training. 228 | 229 | - [ ] Support different backbones ( 230 | VGGNet, 231 | ResNet, 232 | [ResNeXt](https://github.com/facebookresearch/ResNeXt), 233 | [iResNet](https://github.com/iduta/iresnet), 234 | and 235 | [ResNeSt](https://github.com/zhanghang1989/ResNeSt) 236 | etc.) 237 | 238 | - [ ] Support distributed training. 239 | 240 | - [ ] Support lightweight architectures and real-time inference, like MobileNet and SqueezeNet. 241 | 242 | - [ ] Add more comprehensive competitors. 243 | 244 | ## 6. FAQ 245 | 246 | 1. If the images cannot be loaded on the page (mostly in domestic network situations), see the solution below. 247 | 248 | [Solution Link](https://blog.csdn.net/weixin_42128813/article/details/102915578) 249 | 250 | ## 7. License 251 | 252 | The source code is free for research and education use only. Any commercial use should get formal permission first. 253 | 254 | --- 255 | 256 | **[⬆ back to top](#1-preface)** 257 | --------------------------------------------------------------------------------