├── eval ├── trainImgSet.mat ├── mat2txt.m ├── MAE.m ├── Smeasure.m ├── Fmeasure.m ├── Fm_th.m ├── PRCurve.m ├── wFmeasure.m ├── S_object.m ├── Emeasure.m ├── S_region.m └── main.m ├── utils.py ├── options.py ├── test.py ├── README.md ├── loss └── ssim.py ├── train.py ├── data.py └── net.py /eval/trainImgSet.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jxr326/SwinMCNet/HEAD/eval/trainImgSet.mat -------------------------------------------------------------------------------- /eval/mat2txt.m: -------------------------------------------------------------------------------- 1 | close all 2 | clear all 3 | clc 4 | 5 | 6 | load('C:\Users\Downloads\HKU-IS\HKU-IS\valImgSet.mat'); 7 | 8 | [nrows,ncols]= size(valImgSet); 9 | fid=fopen('C:\Users\Downloads\HKU-IS\HKU-IS\valImgSet.txt','w'); 10 | for row=1:nrows 11 | fprintf(fid, '%s\n', valImgSet{row,:}); 12 | end 13 | fclose(fid); 14 | -------------------------------------------------------------------------------- /eval/MAE.m: -------------------------------------------------------------------------------- 1 | function mae = MAE(smap, gtImg) 2 | 3 | if size(smap, 1) ~= size(gtImg, 1) || size(smap, 2) ~= size(gtImg, 2) 4 | error('Saliency map and gt Image have different sizes!\n'); 5 | end 6 | 7 | if ~islogical(gtImg) 8 | gtImg = gtImg(:,:,1) > 128; 9 | end 10 | 11 | fgPixels = smap(gtImg); 12 | fgErrSum = length(fgPixels) - sum(fgPixels); 13 | bgErrSum = sum(smap(~gtImg)); 14 | mae = (fgErrSum + bgErrSum) / numel(gtImg); -------------------------------------------------------------------------------- /eval/Smeasure.m: -------------------------------------------------------------------------------- 1 | function Q = Smeasure(prediction,GT) 2 | 3 | if (~isa(prediction,'double')) 4 | error('The prediction should be double type...'); 5 | end 6 | if ((max(prediction(:))>1) || min(prediction(:))<0) 7 | error('The prediction should be in the range of [0 1]...'); 8 | end 9 | if (~islogical(GT)) 10 | error('GT should be logical type...'); 11 | end 12 | 13 | y = mean2(GT); 14 | 15 | if (y==0) 16 | x = mean2(prediction); 17 | Q = 1.0 - x; 18 | elseif(y==1) 19 | x = mean2(prediction); 20 | Q = x; 21 | else 22 | alpha = 0.5; 23 | Q = alpha*S_object(prediction,GT)+(1-alpha)*S_region(prediction,GT); 24 | end 25 | 26 | end -------------------------------------------------------------------------------- /eval/Fmeasure.m: -------------------------------------------------------------------------------- 1 | function [PreFtem, RecallFtem, FmeasureF] = Fmeasure(sMap, gtMap, gtsize) 2 | 3 | sumLabel = 2* mean(sMap(:)); 4 | if (sumLabel > 1) 5 | sumLabel = 1; 6 | end 7 | 8 | Label3 = zeros( gtsize ); 9 | Label3(sMap>=sumLabel ) = 1; 10 | 11 | NumRec = length( find( Label3==1 ) ); 12 | LabelAnd = Label3 & gtMap; 13 | NumAnd = length( find ( LabelAnd==1 ) ); 14 | num_obj = sum(sum(gtMap)); 15 | 16 | if NumAnd == 0 17 | PreFtem = 0; 18 | RecallFtem = 0; 19 | FmeasureF = 0; 20 | else 21 | PreFtem = NumAnd/NumRec; 22 | RecallFtem = NumAnd/num_obj; 23 | FmeasureF = ((1.3*PreFtem*RecallFtem)/(0.3*PreFtem+RecallFtem)); 24 | end -------------------------------------------------------------------------------- /eval/Fm_th.m: -------------------------------------------------------------------------------- 1 | function [all_f_th, all_th] = Fm_th(sMap, gtMap, gtsize) 2 | 3 | sMap = 255 * sMap; 4 | all_f_th = zeros(256,1); 5 | all_th = zeros(256,1); 6 | for threshold = 0:255 7 | Label3 = zeros( gtsize ); 8 | 
Label3(sMap>=threshold ) = 1; 9 | NumRec = length( find( Label3==1 ) ); 10 | LabelAnd = Label3 & gtMap; 11 | NumAnd = length( find ( LabelAnd==1 ) ); 12 | num_obj = sum(sum(gtMap)); 13 | if NumAnd == 0 14 | PreFtem = 0; 15 | RecallFtem = 0; 16 | f_th = 0; 17 | else 18 | PreFtem = NumAnd/NumRec; 19 | RecallFtem = NumAnd/num_obj; 20 | f_th = ((1.3*PreFtem*RecallFtem)/(0.3*PreFtem+RecallFtem)); 21 | th = threshold; 22 | all_f_th(threshold+1,:) = f_th; 23 | all_th(threshold+1,:) = th; 24 | end 25 | 26 | end -------------------------------------------------------------------------------- /eval/PRCurve.m: -------------------------------------------------------------------------------- 1 | function [precision, recall] = PRCurve(smapImg, gtImg) 2 | 3 | if ~islogical(gtImg) 4 | gtImg = gtImg(:,:,1) > 128; 5 | end 6 | if any(size(smapImg) ~= size(gtImg)) 7 | error('saliency map and ground truth mask have different size'); 8 | end 9 | 10 | gtPxlNum = sum(gtImg(:)); 11 | if 0 == gtPxlNum 12 | error('no foreground region is labeled'); 13 | end 14 | 15 | targetHist = histc(smapImg(gtImg), 0:255); 16 | nontargetHist = histc(smapImg(~gtImg), 0:255); 17 | 18 | targetHist = flipud(targetHist); 19 | nontargetHist = flipud(nontargetHist); 20 | 21 | targetHist = cumsum( targetHist ); 22 | nontargetHist = cumsum( nontargetHist ); 23 | 24 | precision = targetHist ./ (targetHist + nontargetHist + eps); 25 | if any(isnan(precision)) 26 | warning('there exists NAN in precision, this is because your saliency map do not range from 0 to 255\n'); 27 | end 28 | recall = targetHist / gtPxlNum; 29 | -------------------------------------------------------------------------------- /eval/wFmeasure.m: -------------------------------------------------------------------------------- 1 | function [Q]= wFmeasure(FG,GT) 2 | 3 | if (~isa( FG, 'double' )) 4 | error('FG should be of type: double'); 5 | end 6 | if ((max(FG(:))>1) || min(FG(:))<0) 7 | error('FG should be in the range of [0 1]'); 8 | end 9 | if (~islogical(GT)) 10 | error('GT should be of type: logical'); 11 | end 12 | 13 | 14 | 15 | dGT = double(GT); 16 | if max(dGT(:)) == 0 17 | Q = 0; 18 | return 19 | end 20 | 21 | E = abs(FG-dGT); 22 | 23 | 24 | [Dst,IDXT] = bwdist(dGT); 25 | 26 | K = fspecial('gaussian',7,5); 27 | Et = E; 28 | Et(~GT)=Et(IDXT(~GT)); 29 | EA = imfilter(Et,K); 30 | MIN_E_EA = E; 31 | MIN_E_EA(GT & EA0)] 18 | if len(tmp)!=0: 19 | body[np.where(body>0)] = np.floor(tmp/np.max(tmp)*255) 20 | 21 | if not os.path.exists(datapath+'/skeleton/'): 22 | os.makedirs(datapath+'/skeleton/') 23 | cv2.imwrite(datapath+'/skeleton/'+name, body) 24 | 25 | if not os.path.exists(datapath+'/contour/'): 26 | os.makedirs(datapath+'/contour/') 27 | cv2.imwrite(datapath+'/contour/'+name, mask-body) 28 | 29 | 30 | if __name__=='__main__': 31 | split_map('../dataset') 32 | -------------------------------------------------------------------------------- /eval/S_object.m: -------------------------------------------------------------------------------- 1 | function Q = S_object(prediction,GT) 2 | 3 | prediction_fg = prediction; 4 | prediction_fg(~GT)=0; 5 | O_FG = Object(prediction_fg,GT); 6 | 7 | 8 | prediction_bg = 1.0 - prediction; 9 | prediction_bg(GT) = 0; 10 | O_BG = Object(prediction_bg,~GT); 11 | 12 | 13 | u = mean2(GT); 14 | Q = u * O_FG + (1 - u) * O_BG; 15 | 16 | end 17 | 18 | function score = Object(prediction,GT) 19 | 20 | 21 | if isempty(prediction) 22 | score = 0; 23 | return; 24 | end 25 | if isinteger(prediction) 26 | prediction = double(prediction); 27 | end 28 | 
if (~isa( prediction, 'double' )) 29 | error('prediction should be of type: double'); 30 | end 31 | if ((max(prediction(:))>1) || min(prediction(:))<0) 32 | error('prediction should be in the range of [0 1]'); 33 | end 34 | if(~islogical(GT)) 35 | error('GT should be of type: logical'); 36 | end 37 | 38 | 39 | x = mean2(prediction(GT)); 40 | 41 | 42 | sigma_x = std(prediction(GT)); 43 | 44 | score = 2.0 * x./(x^2 + 1.0 + sigma_x + eps); 45 | end -------------------------------------------------------------------------------- /eval/Emeasure.m: -------------------------------------------------------------------------------- 1 | function [score]= Emeasure(FM,GT) 2 | 3 | FM = mat2gray(FM); 4 | thd = 2 * mean(FM(:)); 5 | FM = FM > thd; 6 | 7 | 8 | FM = logical(FM); 9 | GT = logical(GT); 10 | 11 | 12 | dFM = double(FM); 13 | dGT = double(GT); 14 | 15 | 16 | if (sum(dGT(:))==0) 17 | enhanced_matrix = 1.0 - dFM; 18 | elseif(sum(~dGT(:))==0) 19 | enhanced_matrix = dFM; 20 | else 21 | 22 | 23 | 24 | align_matrix = AlignmentTerm(dFM,dGT); 25 | 26 | enhanced_matrix = EnhancedAlignmentTerm(align_matrix); 27 | end 28 | 29 | 30 | [w,h] = size(GT); 31 | score = sum(enhanced_matrix(:))./(w*h - 1 + eps); 32 | end 33 | 34 | 35 | function [align_Matrix] = AlignmentTerm(dFM,dGT) 36 | 37 | 38 | mu_FM = mean2(dFM); 39 | mu_GT = mean2(dGT); 40 | 41 | 42 | align_FM = dFM - mu_FM; 43 | align_GT = dGT - mu_GT; 44 | 45 | 46 | align_Matrix = 2.*(align_GT.*align_FM)./(align_GT.*align_GT + align_FM.*align_FM + eps); 47 | 48 | end 49 | 50 | 51 | function enhanced = EnhancedAlignmentTerm(align_Matrix) 52 | enhanced = ((align_Matrix + 1).^2)/4; 53 | end 54 | 55 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | parser = argparse.ArgumentParser() 3 | parser.add_argument('--epoch', type=int, default=48, help='epoch number') 4 | parser.add_argument('--lr', type=float, default=0.05, help='learning rate') 5 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 6 | parser.add_argument('--batchsize', type=int, default=16, help='training batch size') 7 | parser.add_argument('--trainsize', type=int, default=384, help='training dataset size') 8 | parser.add_argument('--testsize', type=int, default=384, help='testing dataset size') 9 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 10 | parser.add_argument('--decay_rate', type=float, default=5e-4, help='decay rate of learning rate') 11 | 12 | parser.add_argument('--load', type=str, default='../res/swin_base_patch4_window12_384_22k.pth') 13 | 14 | 15 | parser.add_argument('--train_data_root', type=str, default='../VT5000/Train', help='the training datasets root') 16 | parser.add_argument('--val_data_root', type=str, default='../VT5000/Test', help='the value datasets root') 17 | parser.add_argument('--test_data_root', type=str, default='../dataset/', help='the test datasets root') 18 | 19 | 20 | parser.add_argument('--save_path', type=str, default='../res/', help='the path to save models and logs') 21 | parser.add_argument('--test_model', type=str, default='../res/SwinMCNet_epoch_best.pth', help='saved model path') 22 | parser.add_argument('--maps_path', type=str, default='../maps/', help='saved model path') 23 | 24 | opt = parser.parse_args() 25 | -------------------------------------------------------------------------------- /test.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import sys 4 | sys.path.append('./models') 5 | import numpy as np 6 | import os, argparse 7 | import cv2 8 | from net import SwinMCNet 9 | from data import test_dataset 10 | from options import opt 11 | from collections import OrderedDict 12 | import time 13 | from os.path import splitext 14 | from ptflops.flops_counter import get_model_complexity_info 15 | 16 | 17 | #set device for test 18 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 19 | 20 | 21 | #load the model 22 | model = SwinMCNet() 23 | 24 | 25 | base_weights = torch.load(opt.test_model) 26 | new_state_dict = OrderedDict() 27 | for k, v in base_weights.items(): 28 | name = k[7:] # remove 'module.' 29 | new_state_dict[name] = v 30 | model.load_state_dict(new_state_dict) 31 | 32 | print('Loading base network...') 33 | 34 | 35 | model.cuda() 36 | model.eval() 37 | 38 | #test 39 | test_data_root = opt.test_data_root 40 | maps_path = opt.maps_path 41 | 42 | test_sets = ['VT5000/Test','VT1000','VT821'] 43 | 44 | for dataset in test_sets: 45 | 46 | save_path = maps_path + dataset + '/' 47 | 48 | if not os.path.exists(save_path): 49 | os.makedirs(save_path) 50 | dataset_path = test_data_root + dataset 51 | test_loader = test_dataset(dataset_path, opt.testsize) 52 | 53 | for i in range(test_loader.size): 54 | image, t, gt, (H, W), name = test_loader.load_data() 55 | image = image.cuda() 56 | t = t.cuda() 57 | shape = (W,H) 58 | 59 | outi1, outt1, out1, outi2, outt2, out2 = model(image,t,shape) 60 | 61 | res = out2 62 | res = res.sigmoid().data.cpu().numpy().squeeze() 63 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 64 | print('save img to: ',save_path + name) 65 | cv2.imwrite(save_path + name,res*255) 66 | 67 | 68 | print('Test Done!') -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## MCNet 2 | SwinMCNet: Mirror Complementary Transformer Network for RGB-thermal Salient Object Detection 3 | 4 | 5 | ## Prerequisites 6 | - [Python 3.6.8](https://www.python.org/) 7 | - [PyTorch 1.3.1](http://pytorch.org/) 8 | - [CUDA 10.0](https://developer.nvidia.com/cuda-10.0-download-archive) 9 | - [OpenCV 4.1.2](https://opencv.org/) 10 | - [Numpy 1.17.3](https://numpy.org/) 11 | - [TensorboardX 2.1](https://github.com/lanpa/tensorboardX) 12 | 13 | 14 | ## Benchmark Datasets 15 | Download the following datasets and unzip them into the `data` folder. 16 | 17 | - [VT5000](https://arxiv.org/pdf/2007.03262.pdf) 18 | - [VT1000](https://arxiv.org/pdf/1905.06741.pdf) 19 | - [VT821](https://arxiv.org/pdf/1701.02829.pdf) 20 | 21 | 22 | ## The Proposed Dataset 23 | Our proposed RGBT SOD dataset VT723 contains common challenging scenes of the real world. 24 | - [VT723](https://drive.google.com/file/d/12gEUFG2yWi3uBTjLymQ3hjnDUHGcgADq/view?usp=sharing) 25 | 26 | 27 | ## Training & Testing & Evaluation 28 | - Split the ground truth into a skeleton map and a contour map, which will be saved into `data/VT5000/skeleton` and `data/VT5000/contour`. 29 | ```shell 30 | python3 utils.py 31 | ``` 32 | 33 | - Train the model and obtain the trained weights, which will be saved into the `res` folder.
34 | ```shell 35 | python3 train.py 36 | ``` 37 | 38 | - If you just want to evaluate the performance of MCNet without training, please download the [pretrained model](https://drive.google.com/file/d/1qcZeBiwF78Lv24hXmXN4vMFbK4yC-C-y/view?usp=sharing) into `res` folder. 39 | - Test the model and get the predicted saliency maps, which will be saved into `maps` folder. 40 | ```shell 41 | python3 test.py 42 | ``` 43 | 44 | - Evaluate the predicted results. 45 | ```shell 46 | cd eval 47 | matlab 48 | main 49 | ``` 50 | 51 | ## Saliency maps & Trained model 52 | - saliency maps: [Google](https://drive.google.com/file/d/1LjLLIGmnKb_UQeGpFar35gyEs76EHFLL/view?usp=sharing) 53 | - trained model: [Google](https://drive.google.com/file/d/1D40-nIqvTmqh5CpH22c8yQA8lcee23SB/view?usp=sharing) 54 | -------------------------------------------------------------------------------- /eval/S_region.m: -------------------------------------------------------------------------------- 1 | function Q = S_region(prediction,GT) 2 | 3 | [X,Y] = centroid(GT); 4 | 5 | 6 | [GT_1,GT_2,GT_3,GT_4,w1,w2,w3,w4] = divideGT(GT,X,Y); 7 | 8 | 9 | [prediction_1,prediction_2,prediction_3,prediction_4] = Divideprediction(prediction,X,Y); 10 | 11 | 12 | Q1 = ssim(prediction_1,GT_1); 13 | Q2 = ssim(prediction_2,GT_2); 14 | Q3 = ssim(prediction_3,GT_3); 15 | Q4 = ssim(prediction_4,GT_4); 16 | 17 | 18 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4; 19 | 20 | end 21 | 22 | 23 | function [X,Y] = centroid(GT) 24 | 25 | [rows,cols] = size(GT); 26 | 27 | if(sum(GT(:))==0) 28 | X = round(cols/2); 29 | Y = round(rows/2); 30 | else 31 | dGT = double(GT); 32 | x = ones(rows,1)*(1:cols); 33 | y = (1:rows)'*ones(1,cols); 34 | area = sum(dGT(:)); 35 | X = round(sum(sum(dGT.*x))/area); 36 | Y = round(sum(sum(dGT.*y))/area); 37 | end 38 | 39 | end 40 | 41 | 42 | 43 | function [LT,RT,LB,RB,w1,w2,w3,w4] = divideGT(GT,X,Y) 44 | 45 | 46 | [hei,wid] = size(GT); 47 | area = wid * hei; 48 | 49 | 50 | LT = GT(1:Y,1:X); 51 | RT = GT(1:Y,X+1:wid); 52 | LB = GT(Y+1:hei,1:X); 53 | RB = GT(Y+1:hei,X+1:wid); 54 | 55 | 56 | w1 = (X*Y)./area; 57 | w2 = ((wid-X)*Y)./area; 58 | w3 = (X*(hei-Y))./area; 59 | w4 = 1.0 - w1 - w2 - w3; 60 | end 61 | 62 | 63 | function [LT,RT,LB,RB] = Divideprediction(prediction,X,Y) 64 | 65 | 66 | [hei,wid] = size(prediction); 67 | 68 | 69 | LT = prediction(1:Y,1:X); 70 | RT = prediction(1:Y,X+1:wid); 71 | LB = prediction(Y+1:hei,1:X); 72 | RB = prediction(Y+1:hei,X+1:wid); 73 | 74 | end 75 | 76 | function Q = ssim(prediction,GT) 77 | 78 | 79 | dGT = double(GT); 80 | 81 | [hei,wid] = size(prediction); 82 | N = wid*hei; 83 | 84 | 85 | x = mean2(prediction); 86 | y = mean2(dGT); 87 | 88 | 89 | sigma_x2 = sum(sum((prediction - x).^2))./(N - 1 + eps); 90 | sigma_y2 = sum(sum((dGT - y).^2))./(N - 1 + eps); 91 | 92 | 93 | sigma_xy = sum(sum((prediction - x).*(dGT - y)))./(N - 1 + eps); 94 | 95 | alpha = 4 * x * y * sigma_xy; 96 | beta = (x.^2 + y.^2).*(sigma_x2 + sigma_y2); 97 | 98 | if(alpha ~= 0) 99 | Q = alpha./(beta + eps); 100 | elseif(alpha == 0 && beta == 0) 101 | Q = 1.0; 102 | else 103 | Q = 0; 104 | end 105 | 106 | end -------------------------------------------------------------------------------- /eval/main.m: -------------------------------------------------------------------------------- 1 | clc; 2 | clear all; 3 | 4 | predpath = 'C:\Users\Desktop\VT723_maps\Ours\'; 5 | maskpath = 'C:\Users\Desktop\VT723\GT\'; 6 | 7 | 8 | names = dir(fullfile(maskpath,'*.png')); 9 | disp(names); 10 | names = {names.name}; 11 | wfm = 0; mae = 0; sm = 0; 
fm = 0; prec = 0; rec = 0; em = 0; 12 | score1 = 0; score2 = 0; score3 = 0; score4 = 0; score5 = 0; score6 = 0; score7 = 0; 13 | 14 | results = cell(numel(names), 6); 15 | ALLPRECISION = zeros(numel(names), 256); 16 | ALLRECALL = zeros(numel(names), 256); 17 | a_fth = zeros(numel(names), 256); 18 | a_th = zeros(numel(names), 256); 19 | file_num = false(numel(names), 1); 20 | for k = 1:numel(names) 21 | name = names{1,k}; 22 | results{k, 1} = name; 23 | file_num(k) = true; 24 | 25 | gtpath = [maskpath name]; 26 | gt = imread(gtpath); 27 | 28 | fgpath = [predpath strrep(name, '.jpg', '.png')]; 29 | fg = imread(fgpath); 30 | 31 | if length(size(fg)) == 3, fg = fg(:,:,1); end 32 | if length(size(gt)) == 3, gt = gt(:,:,1); end 33 | fg = imresize(fg, size(gt)); 34 | fg = mat2gray(fg); 35 | gt = mat2gray(gt); 36 | if max(fg(:)) == 0 || max(gt(:)) == 0, continue; end 37 | 38 | gt(gt>=0.5) = 1; gt(gt<0.5) = 0; gt = logical(gt); 39 | score1 = MAE(fg, gt); 40 | [score2, score3, score4] = Fmeasure(fg, gt, size(gt)); 41 | score5 = wFmeasure(fg, gt); 42 | score6 = Smeasure(fg, gt); 43 | score7 = Emeasure(fg, gt); 44 | mae = mae + score1; 45 | prec = prec + score2; 46 | rec = rec + score3; 47 | fm = fm + score4; 48 | wfm = wfm + score5; 49 | sm = sm + score6; 50 | em = em + score7; 51 | results{k, 2} = score1; 52 | results{k, 3} = score4; 53 | results{k, 4} = score5; 54 | results{k, 5} = score6; 55 | results{k, 6} = score7; 56 | [precision, recall] = PRCurve(fg*255, gt); 57 | ALLPRECISION(k, :) = precision; 58 | ALLRECALL(k, :) = recall; 59 | 60 | [all_f_th all_th] = Fm_th(fg, gt, size(gt)); 61 | a_fth(k, :) = all_f_th; 62 | a_th(k, :) = all_th; 63 | 64 | end 65 | m_fth = mean(a_fth, 1); 66 | prec = mean(ALLPRECISION(file_num,:), 1); 67 | rec = mean(ALLRECALL(file_num,:), 1); 68 | maxF = max(1.3*prec.*rec./(0.3*prec+rec+eps)); 69 | file_num = double(file_num); 70 | fm = fm / sum(file_num); 71 | mae = mae / sum(file_num); 72 | wfm = wfm / sum(file_num); 73 | sm = sm / sum(file_num); 74 | em = em / sum(file_num); 75 | 76 | fprintf('%6.3f, %6.3f, %6.3f, %6.3f, %6.3f, %6.3f\n', fm, maxF, wfm, mae, em, sm) 77 | 78 | save_path = 'C:\Users\Desktop\VT723_eval\PRcurve\Ours\'; 79 | fprintf('The save path is %s\n', save_path); 80 | if ~exist(save_path, 'dir'), mkdir(save_path); end 81 | save([save_path 'results.mat'], 'results'); 82 | save([save_path 'prec.mat'], 'prec'); 83 | save([save_path 'rec.mat'], 'rec'); 84 | 85 | 86 | save_path = 'C:\Users\Desktop\VT723_eval\Fmeasure\'; 87 | fprintf('The save path is %s\n', save_path); 88 | if ~exist(save_path, 'dir'), mkdir(save_path); end 89 | save([save_path 'Ours.mat'], 'm_fth'); -------------------------------------------------------------------------------- /loss/ssim.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | from math import exp 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 10 | return gauss/gauss.sum() 11 | 12 | def create_window(window_size, channel): 13 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 14 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 15 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 16 | return window 17 | 18 | def 
_ssim(img1, img2, window, window_size, channel, size_average = True): 19 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 20 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 21 | 22 | mu1_sq = mu1.pow(2) 23 | mu2_sq = mu2.pow(2) 24 | mu1_mu2 = mu1*mu2 25 | 26 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 27 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 28 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 29 | 30 | C1 = 0.01**2 31 | C2 = 0.03**2 32 | 33 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 34 | 35 | if size_average: 36 | return ssim_map.mean() 37 | else: 38 | return ssim_map.mean(1).mean(1).mean(1) 39 | 40 | class SSIM(torch.nn.Module): 41 | def __init__(self, window_size = 11, size_average = True): 42 | super(SSIM, self).__init__() 43 | self.window_size = window_size 44 | self.size_average = size_average 45 | self.channel = 1 46 | self.window = create_window(window_size, self.channel) 47 | 48 | def forward(self, img1, img2): 49 | (_, channel, _, _) = img1.size() 50 | 51 | if channel == self.channel and self.window.data.type() == img1.data.type(): 52 | window = self.window 53 | else: 54 | window = create_window(self.window_size, channel) 55 | 56 | if img1.is_cuda: 57 | window = window.cuda(img1.get_device()) 58 | window = window.type_as(img1) 59 | 60 | self.window = window 61 | self.channel = channel 62 | 63 | 64 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 65 | 66 | def _logssim(img1, img2, window, window_size, channel, size_average = True): 67 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 68 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 69 | 70 | mu1_sq = mu1.pow(2) 71 | mu2_sq = mu2.pow(2) 72 | mu1_mu2 = mu1*mu2 73 | 74 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 75 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 76 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 77 | 78 | C1 = 0.01**2 79 | C2 = 0.03**2 80 | 81 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 82 | ssim_map = (ssim_map - torch.min(ssim_map))/(torch.max(ssim_map)-torch.min(ssim_map)) 83 | ssim_map = -torch.log(ssim_map + 1e-8) 84 | 85 | if size_average: 86 | return ssim_map.mean() 87 | else: 88 | return ssim_map.mean(1).mean(1).mean(1) 89 | 90 | class LOGSSIM(torch.nn.Module): 91 | def __init__(self, window_size = 11, size_average = True): 92 | super(LOGSSIM, self).__init__() 93 | self.window_size = window_size 94 | self.size_average = size_average 95 | self.channel = 1 96 | self.window = create_window(window_size, self.channel) 97 | 98 | def forward(self, img1, img2): 99 | (_, channel, _, _) = img1.size() 100 | 101 | if channel == self.channel and self.window.data.type() == img1.data.type(): 102 | window = self.window 103 | else: 104 | window = create_window(self.window_size, channel) 105 | 106 | if img1.is_cuda: 107 | window = window.cuda(img1.get_device()) 108 | window = window.type_as(img1) 109 | 110 | self.window = window 111 | self.channel = channel 112 | 113 | 114 | return _logssim(img1, img2, window, self.window_size, channel, self.size_average) 115 | 116 | 117 | def 
ssim(img1, img2, window_size = 11, size_average = True): 118 | (_, channel, _, _) = img1.size() 119 | window = create_window(window_size, channel) 120 | 121 | if img1.is_cuda: 122 | window = window.cuda(img1.get_device()) 123 | window = window.type_as(img1) 124 | 125 | return _ssim(img1, img2, window, window_size, channel, size_average) 126 | 127 | 128 | 129 | class CEL(torch.nn.Module): 130 | def __init__(self): 131 | super(CEL, self).__init__() 132 | #print("You are using `CEL`!") 133 | self.eps = 1e-6 134 | 135 | def forward(self, pred, target): 136 | pred = pred.sigmoid() 137 | intersection = pred * target 138 | numerator = (pred - intersection).sum() + (target - intersection).sum() 139 | denominator = pred.sum() + target.sum() 140 | return numerator / (denominator + self.eps) 141 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | sys.path.append('./models') 7 | import numpy as np 8 | from datetime import datetime 9 | from torchvision.utils import make_grid 10 | from net import SwinMCNet 11 | from data import get_loader,test_dataset 12 | from utils import clip_gradient 13 | from tensorboardX import SummaryWriter 14 | import logging 15 | import torch.backends.cudnn as cudnn 16 | from options import opt 17 | from loss.ssim import SSIM 18 | 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "2,3" 20 | cudnn.benchmark = True 21 | 22 | #build the model 23 | model = SwinMCNet() 24 | if(opt.load is not None): 25 | model.load_pre(opt.load) 26 | print('load model from ',opt.load) 27 | model = nn.DataParallel(model).cuda() 28 | # model = model.cuda() 29 | 30 | 31 | base, body = [], [] 32 | for name, param in model.named_parameters(): 33 | if 'swin_image' in name or 'swin_thermal' in name: 34 | print(name) 35 | base.append(param) 36 | else: 37 | print(name) 38 | body.append(param) 39 | optimizer = torch.optim.SGD([{'params': base}, {'params': body}], lr=opt.lr, momentum=opt.momentum, 40 | weight_decay=opt.decay_rate, nesterov=True) 41 | 42 | 43 | #set the path 44 | train_root = opt.train_data_root 45 | test_root = opt.val_data_root 46 | 47 | save_path=opt.save_path 48 | 49 | if not os.path.exists(save_path): 50 | os.makedirs(save_path) 51 | 52 | #load data 53 | print('load data...') 54 | num_gpus = torch.cuda.device_count() 55 | print(f"========>num_gpus:{num_gpus}==========") 56 | train_loader = get_loader(train_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 57 | test_loader = test_dataset(test_root, opt.trainsize) 58 | total_step = len(train_loader) 59 | 60 | logging.basicConfig(filename=save_path+'log.log',format='[%(asctime)s-%(filename)s-%(levelname)s:%(message)s]', level = logging.INFO,filemode='a',datefmt='%Y-%m-%d %I:%M:%S %p') 61 | logging.info("SwinMCNet-Train") 62 | logging.info('epoch:{};lr:{};batchsize:{};trainsize:{};clip:{};decay_rate:{};load:{};save_path:{}'.format(opt.epoch,opt.lr,opt.batchsize,opt.trainsize,opt.clip,opt.decay_rate,opt.load,save_path)) 63 | 64 | # loss 65 | def iou_loss(pred, mask): 66 | pred = torch.sigmoid(pred) 67 | inter = (pred * mask).sum(dim=(2, 3)) 68 | union = (pred + mask).sum(dim=(2, 3)) 69 | iou = 1 - (inter + 1) / (union - inter + 1) 70 | return iou.mean() 71 | ssim_loss = SSIM(window_size=11, size_average=True) 72 | 73 | step=0 74 | writer = SummaryWriter(save_path+'summary') 75 | best_mae=1 76 | best_epoch=1 77 | 78 
| #train function 79 | def train(train_loader, model, optimizer, epoch,save_path): 80 | global step 81 | model.train() 82 | loss_all=0 83 | epoch_step=0 84 | try: 85 | for i, (images, ts, gts, bodys, details) in enumerate(train_loader, start=1): 86 | optimizer.zero_grad() 87 | 88 | image, t, gt, body, detail = images.cuda(), ts.cuda(), gts.cuda(), bodys.cuda(), details.cuda() 89 | 90 | outi1, outt1, out1, outi2, outt2, out2 = model(image,t) 91 | 92 | 93 | lossi1 = F.binary_cross_entropy_with_logits(outi1, body) + ssim_loss(outi1, body) 94 | losst1 = F.binary_cross_entropy_with_logits(outt1, detail) + ssim_loss(outt1, detail) 95 | loss1 = F.binary_cross_entropy_with_logits(out1, gt) + iou_loss(out1, gt) + ssim_loss(out1, gt) 96 | 97 | lossi2 = F.binary_cross_entropy_with_logits(outi2, body) + ssim_loss(outi2, body) 98 | losst2 = F.binary_cross_entropy_with_logits(outt2, detail) + ssim_loss(outt2, detail) 99 | loss2 = F.binary_cross_entropy_with_logits(out2, gt) + iou_loss(out2, gt) + ssim_loss(out2, gt) 100 | 101 | loss = (lossi1 + losst1 + loss1 + lossi2 + losst2 + loss2)/2 102 | 103 | loss.backward() 104 | 105 | clip_gradient(optimizer, opt.clip) 106 | optimizer.step() 107 | step+=1 108 | epoch_step+=1 109 | loss_all+=loss.data 110 | if i % 50 == 0 or i == total_step or i==1: 111 | print('%s | epoch:%d/%d | step:%d/%d | lr=%.6f | lossi1=%.6f | losst1=%.6f | loss1=%.6f | lossi2=%.6f | losst2=%.6f | loss2=%.6f' 112 | %(datetime.now(), epoch, opt.epoch, i, total_step, optimizer.param_groups[0]['lr'], lossi1.item(), 113 | losst1.item(), loss1.item(), lossi2.item(), losst2.item(), loss2.item())) 114 | 115 | logging.info('##TRAIN##:Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], lr_bk: {:.6f}, Loss1: {:.4f} Loss2: {:0.4f}'. 116 | format( epoch, opt.epoch, i, total_step, optimizer.param_groups[0]['lr'], loss1.data, loss2.data)) 117 | writer.add_scalar('Loss', loss.data, global_step=step) 118 | grid_image = make_grid(images[0].clone().cpu().data, 1, normalize=True) 119 | writer.add_image('RGB', grid_image, step) 120 | grid_image = make_grid(gts[0].clone().cpu().data, 1, normalize=True) 121 | writer.add_image('Ground_truth', grid_image, step) 122 | res=out1[0].clone() 123 | res = res.sigmoid().data.cpu().numpy().squeeze() 124 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 125 | writer.add_image('out1', torch.tensor(res), step,dataformats='HW') 126 | res=out2[0].clone() 127 | res = res.sigmoid().data.cpu().numpy().squeeze() 128 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 129 | writer.add_image('out2', torch.tensor(res), step,dataformats='HW') 130 | 131 | loss_all/=epoch_step 132 | logging.info('##TRAIN##:Epoch [{:03d}/{:03d}], Loss_AVG: {:.4f}'.format( epoch, opt.epoch, loss_all)) 133 | writer.add_scalar('Loss-epoch', loss_all, global_step=epoch) 134 | if (epoch) % 50 == 0: 135 | torch.save(model.state_dict(), save_path+'SwinMCNet_epoch_{}.pth'.format(epoch)) 136 | except KeyboardInterrupt: 137 | print('Keyboard Interrupt: save model and exit.') 138 | if not os.path.exists(save_path): 139 | os.makedirs(save_path) 140 | torch.save(model.state_dict(), save_path+'SwinMCNet_epoch_{}.pth'.format(epoch)) 141 | print('save checkpoints successfully!') 142 | raise 143 | 144 | #test function 145 | def test(test_loader,model,epoch,save_path): 146 | global best_mae,best_epoch 147 | model.eval() 148 | with torch.no_grad(): 149 | mae_sum=0 150 | #for i in range(1000): 151 | for i in range(test_loader.size): 152 | image, t, gt, (H, W), name = test_loader.load_data() 153 | gt = 
np.asarray(gt, np.float32) 154 | gt /= (gt.max() + 1e-8) 155 | image = image.cuda() 156 | t = t.cuda() 157 | #shape = (W,H) 158 | outi1, outt1, out1, outi2, outt2, out2 = model(image,t) 159 | res = out2 160 | res = F.interpolate(res, size=gt.shape, mode='bilinear') 161 | res = res.sigmoid().data.cpu().numpy().squeeze() 162 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 163 | mae_sum += np.sum(np.abs(res-gt))*1.0/(gt.shape[0]*gt.shape[1]) 164 | mae=mae_sum/test_loader.size 165 | writer.add_scalar('MAE', torch.tensor(mae), global_step=epoch) 166 | print('\n') 167 | print('##TEST##:Epoch: {} MAE: {}'.format(epoch,mae)) 168 | 169 | if epoch==1: 170 | best_mae=mae 171 | else: 172 | if mae> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 41 | (image_height + crop_win_height) >> 1) 42 | return img.crop(random_region), t.crop(random_region),gt.crop(random_region),body.crop(random_region),detail.crop(random_region) 43 | def randomRotation(img,t,gt,body,detail): 44 | mode=Image.BICUBIC 45 | if random.random()>0.8: 46 | random_angle = np.random.randint(-15, 15) 47 | img = img.rotate(random_angle, mode) 48 | t = t.rotate(random_angle, mode) 49 | gt = gt.rotate(random_angle, mode) 50 | body = body.rotate(random_angle, mode) 51 | detail = detail.rotate(random_angle, mode) 52 | return img,t,gt,body,detail 53 | def colorEnhance(image): 54 | bright_intensity=random.randint(5,15)/10.0 55 | image=ImageEnhance.Brightness(image).enhance(bright_intensity) 56 | contrast_intensity=random.randint(5,15)/10.0 57 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity) 58 | color_intensity=random.randint(0,20)/10.0 59 | image=ImageEnhance.Color(image).enhance(color_intensity) 60 | sharp_intensity=random.randint(0,30)/10.0 61 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity) 62 | return image 63 | '''def randomGaussian(image, mean=0.1, sigma=0.35): 64 | def gaussianNoisy(im, mean=mean, sigma=sigma): 65 | for _i in range(len(im)): 66 | im[_i] += random.gauss(mean, sigma) 67 | return im 68 | img = np.asarray(image) 69 | width, height = img.shape 70 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 71 | img = img.reshape([width, height]) 72 | return Image.fromarray(np.uint8(img))''' 73 | def randomPeper(img): 74 | img=np.array(img) 75 | noiseNum=int(0.0015*img.shape[0]*img.shape[1]) 76 | for i in range(noiseNum): 77 | randX=random.randint(0,img.shape[0]-1) 78 | randY=random.randint(0,img.shape[1]-1) 79 | if random.randint(0,1)==0: 80 | img[randX,randY]=0 81 | else: 82 | img[randX,randY]=255 83 | return Image.fromarray(img) 84 | 85 | '''def randomGamma(img): 86 | gamma_flag = random.randint(0, 2) 87 | if gamma_flag == 1: 88 | gamma_value = random.uniform(2,5) 89 | img=np.array(img) 90 | img = exposure.adjust_gamma(img, gamma_value) 91 | return Image.fromarray(np.uint8(img))''' 92 | 93 | # dataset for training 94 | class SalObjDataset(data.Dataset): 95 | def __init__(self, train_root, trainsize): 96 | self.trainsize = trainsize 97 | 98 | self.image_root = train_root + '/RGB/' 99 | self.gt_root = train_root + '/GT/' 100 | self.t_root = train_root + '/T/' 101 | self.body_root = train_root + '/body/' 102 | self.detail_root = train_root + '/detail/' 103 | 104 | 105 | self.images = [self.image_root + f for f in os.listdir(self.image_root) if f.endswith('.jpg') or f.endswith('.png')] 106 | self.gts = [self.gt_root + f for f in os.listdir(self.gt_root) if f.endswith('.png')] 107 | self.ts = [self.t_root + f for f in os.listdir(self.t_root) if f.endswith('.jpg') or 
f.endswith('.png')] 108 | self.bodys = [self.body_root + f for f in os.listdir(self.body_root) if f.endswith('.png')] 109 | self.details = [self.detail_root + f for f in os.listdir(self.detail_root) if f.endswith('.png')] 110 | 111 | self.images = sorted(self.images) 112 | self.gts = sorted(self.gts) 113 | self.ts = sorted(self.ts) 114 | self.bodys = sorted(self.bodys) 115 | self.details = sorted(self.details) 116 | 117 | # self.filter_files() 118 | self.size = len(self.images) 119 | 120 | ## RGB(VT5000 + VT1000 + VT821) 121 | # [0.525, 0.590, 0.537], [0.177, 0.167, 0.176] 122 | 123 | ## MIX(VT5000 + VT1000 + VT821) 124 | # [0.501, 0.612, 0.602], [0.173, 0.152, 0.166] 125 | 126 | 127 | self.img_transform = transforms.Compose([ 128 | transforms.Resize((self.trainsize, self.trainsize)), 129 | transforms.ToTensor(), 130 | transforms.Normalize([0.525, 0.590, 0.537], [0.177, 0.167, 0.176])]) 131 | 132 | 133 | ## T(VT5000 + VT1000 + VT821) 134 | # [0.736, 0.346, 0.339], [0.179, 0.196, 0.169] 135 | 136 | self.t_transform = transforms.Compose([ 137 | transforms.Resize((self.trainsize, self.trainsize)), 138 | transforms.ToTensor(), 139 | transforms.Normalize([0.736, 0.346, 0.339], [0.179, 0.196, 0.169])]) 140 | 141 | self.gt_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 142 | self.body_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 143 | self.detail_transform = transforms.Compose([transforms.Resize((self.trainsize, self.trainsize)),transforms.ToTensor()]) 144 | 145 | def __getitem__(self, index): 146 | image = self.rgb_loader(self.images[index]) 147 | t = self.rgb_loader(self.ts[index]) 148 | gt = self.binary_loader(self.gts[index]) 149 | body = self.binary_loader(self.bodys[index]) 150 | detail = self.binary_loader(self.details[index]) 151 | 152 | image,t,gt,body,detail = cv_random_flip(image,t,gt,body,detail) 153 | image,t,gt,body,detail = randomCrop(image,t,gt,body,detail ) 154 | image,t,gt,body,detail = randomRotation(image,t,gt,body,detail ) 155 | 156 | # image = randomGamma(image) 157 | image = colorEnhance(image) 158 | 159 | t = colorEnhance(t) 160 | # gt=randomGaussian(gt) 161 | gt = randomPeper(gt) 162 | 163 | image = self.img_transform(image) 164 | t = self.t_transform(t) 165 | gt = self.gt_transform(gt) 166 | body = self.body_transform(body) 167 | detail = self.detail_transform(detail) 168 | 169 | return image, t, gt, body, detail 170 | 171 | '''def filter_files(self): 172 | assert len(self.images) == len(self.gts) and len(self.gts)==len(self.images) 173 | images = [] 174 | gts = [] 175 | depths=[] 176 | for img_path, gt_path,depth_path in zip(self.images, self.gts, self.depths): 177 | img = Image.open(img_path) 178 | gt = Image.open(gt_path) 179 | depth= Image.open(depth_path) 180 | if img.size == gt.size and gt.size==depth.size: 181 | images.append(img_path) 182 | gts.append(gt_path) 183 | depths.append(depth_path) 184 | self.images = images 185 | self.gts = gts 186 | self.depths=depths''' 187 | 188 | def rgb_loader(self, path): 189 | with open(path, 'rb') as f: 190 | img = Image.open(f) 191 | return img.convert('RGB') 192 | 193 | def binary_loader(self, path): 194 | with open(path, 'rb') as f: 195 | img = Image.open(f) 196 | return img.convert('L') 197 | 198 | '''def resize(self, img,t,gt,body,detail): 199 | assert img.size == gt.size and gt.size==t.size 200 | w, h = img.size 201 | if h < self.trainsize or w < self.trainsize: 202 | h = max(h, self.trainsize) 203 | 
w = max(w, self.trainsize) 204 | return img.resize((w, h), Image.BILINEAR),t.resize((w, h), Image.BILINEAR),gt.resize((w, h), Image.NEAREST),body.resize((w, h), Image.NEAREST),detail.resize((w, h), Image.NEAREST) 205 | else: 206 | return img,t,gt,body,detail''' 207 | 208 | def __len__(self): 209 | return self.size 210 | 211 | #dataloader for training 212 | def get_loader(train_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=False): 213 | 214 | dataset = SalObjDataset(train_root,trainsize) 215 | data_loader = data.DataLoader(dataset=dataset, 216 | batch_size=batchsize, 217 | shuffle=shuffle, 218 | num_workers=num_workers, 219 | pin_memory=pin_memory) 220 | return data_loader 221 | 222 | #test dataset and loader 223 | class test_dataset: 224 | def __init__(self, test_root, testsize): 225 | self.testsize = testsize 226 | 227 | self.image_root = test_root + '/RGB/' 228 | self.gt_root = test_root + '/GT/' 229 | self.t_root = test_root + '/T/' 230 | 231 | self.images = [self.image_root + f for f in os.listdir(self.image_root) if f.endswith('.jpg') or f.endswith('.png')] 232 | self.gts = [self.gt_root + f for f in os.listdir(self.gt_root) if f.endswith('.png') or f.endswith('.jpg')] 233 | self.ts = [self.t_root + f for f in os.listdir(self.t_root) if f.endswith('.jpg') or f.endswith('.png')] 234 | 235 | self.images = sorted(self.images) 236 | self.gts = sorted(self.gts) 237 | self.ts = sorted(self.ts) 238 | 239 | ## RGB(VT5000 + VT1000 + VT821) 240 | # [0.525, 0.590, 0.537], [0.177, 0.167, 0.176] 241 | ## MIX(VT5000 + VT1000 + VT821) 242 | # [0.501, 0.612, 0.602], [0.173, 0.152, 0.166] 243 | 244 | ## RGB(VT1606) 245 | # [0.238, 0.271, 0.236], [0.172, 0.174, 0.174] 246 | 247 | self.img_transform = transforms.Compose([ 248 | transforms.Resize((self.testsize, self.testsize)), 249 | transforms.ToTensor(), 250 | transforms.Normalize([0.525, 0.590, 0.537], [0.177, 0.167, 0.176])]) 251 | 252 | 253 | ## T(VT5000 + VT1000 + VT821) 254 | # [0.736, 0.346, 0.339], [0.179, 0.196, 0.169] 255 | ## T(VT251) 256 | # [0.273, 0.687, 0.716], [0.148, 0.212, 0.155] 257 | 258 | ## T(VT1606) 259 | # color 260 | # [0.213, 0.656, 0.779], [0.113, 0.203, 0.122] 261 | # gray 262 | # [0.3451, 0.345, 0.345], [0.0768, 0.0768, 0.0768] 263 | 264 | 265 | self.t_transform = transforms.Compose([ 266 | transforms.Resize((self.testsize, self.testsize)), 267 | transforms.ToTensor(), 268 | transforms.Normalize([0.736, 0.346, 0.339], [0.179, 0.196, 0.169])]) 269 | 270 | self.gt_transform = transforms.ToTensor() 271 | 272 | self.size = len(self.images) 273 | self.index = 0 274 | 275 | def load_data(self): 276 | image = self.rgb_loader(self.images[self.index]) 277 | shape = image.size 278 | image = self.img_transform(image).unsqueeze(0) 279 | 280 | t = self.rgb_loader(self.ts[self.index]) 281 | t = self.t_transform(t).unsqueeze(0) 282 | 283 | gt = self.binary_loader(self.gts[self.index]) 284 | 285 | name = self.images[self.index].split('/')[-1] 286 | 287 | if name.endswith('.jpg'): 288 | name = name.split('.jpg')[0] + '.png' 289 | self.index += 1 290 | self.index = self.index % self.size 291 | return image, t, gt, shape, name 292 | # return image, t, shape, name 293 | 294 | def rgb_loader(self, path): 295 | with open(path, 'rb') as f: 296 | img = Image.open(f) 297 | return img.convert('RGB') 298 | 299 | def binary_loader(self, path): 300 | with open(path, 'rb') as f: 301 | img = Image.open(f) 302 | return img.convert('L') 303 | def __len__(self): 304 | return self.size 
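A minimal usage sketch for the loaders defined above (not part of the original repository): it assumes a dataset root laid out the way `data.py` expects, with `RGB/`, `T/` and `GT/` subfolders plus `body/` and `detail/` for training, and the placeholder paths below simply reuse the defaults from `options.py`.

```python
# Hedged sanity-check sketch for data.py; the dataset roots are placeholders.
from data import get_loader, test_dataset

train_root = '../VT5000/Train'   # assumes RGB/, T/, GT/, body/, detail/ subfolders
loader = get_loader(train_root, batchsize=2, trainsize=384, num_workers=0)
images, ts, gts, bodys, details = next(iter(loader))
print(images.shape, ts.shape, gts.shape, bodys.shape, details.shape)

val_root = '../VT5000/Test'      # assumes RGB/, T/, GT/ subfolders
val_set = test_dataset(val_root, 384)
image, t, gt, shape, name = val_set.load_data()
print(name, image.shape, t.shape, shape)  # image/t are 1x3x384x384 tensors, gt is a PIL image
```

Note that `utils.py` appears to write the split ground-truth maps into `skeleton/` and `contour/`, while `SalObjDataset` reads `body/` and `detail/`, so the folders may need to be renamed (or the folder names in `SalObjDataset` adjusted) before training.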
-------------------------------------------------------------------------------- /net.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.utils.checkpoint as checkpoint 6 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 7 | import numpy as np 8 | import torch.nn.functional as F 9 | 10 | 11 | class Decoder(nn.Module): 12 | def __init__(self): 13 | super(Decoder, self).__init__() 14 | self.conv0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 15 | self.bn0 = nn.BatchNorm2d(64) 16 | self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 17 | self.bn1 = nn.BatchNorm2d(64) 18 | self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 19 | self.bn2 = nn.BatchNorm2d(64) 20 | self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 21 | self.bn3 = nn.BatchNorm2d(64) 22 | 23 | def forward(self, input1, input2=[0,0,0,0]): 24 | out0 = F.relu(self.bn0(self.conv0(input1[0]+input2[0])), inplace=True) 25 | out0 = F.interpolate(out0, size=input1[1].size()[2:], mode='bilinear') 26 | out1 = F.relu(self.bn1(self.conv1(input1[1]+input2[1]+out0)), inplace=True) 27 | out1 = F.interpolate(out1, size=input1[2].size()[2:], mode='bilinear') 28 | out2 = F.relu(self.bn2(self.conv2(input1[2]+input2[2]+out1)), inplace=True) 29 | out2 = F.interpolate(out2, size=input1[3].size()[2:], mode='bilinear') 30 | out3 = F.relu(self.bn3(self.conv3(input1[3]+input2[3]+out2)), inplace=True) 31 | 32 | return out3 33 | 34 | 35 | 36 | class ChannelAttention(nn.Module): 37 | def __init__(self, in_planes=192, ratio=16): 38 | super(ChannelAttention, self).__init__() 39 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 40 | self.max_pool = nn.AdaptiveMaxPool2d(1) 41 | 42 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) 43 | self.relu1 = nn.ReLU() 44 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) 45 | 46 | self.sigmoid = nn.Sigmoid() 47 | 48 | def forward(self, x): 49 | avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) 50 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 51 | out = avg_out + max_out 52 | return self.sigmoid(out) 53 | 54 | 55 | class SpatialAttention(nn.Module): 56 | def __init__(self, kernel_size=7): 57 | super(SpatialAttention, self).__init__() 58 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 59 | padding = 3 if kernel_size == 7 else 1 60 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 61 | 62 | self.sigmoid = nn.Sigmoid() 63 | 64 | def forward(self, x): 65 | avg_out = torch.mean(x, dim=1, keepdim=True) 66 | max_out, _ = torch.max(x, dim=1, keepdim=True) 67 | x = torch.cat([avg_out, max_out], dim=1) 68 | x = self.conv1(x) 69 | 70 | return self.sigmoid(x) 71 | 72 | 73 | 74 | class Attention(nn.Module): 75 | def __init__(self): 76 | super(Attention , self).__init__() 77 | 78 | self.att_c = ChannelAttention() 79 | self.att_s = SpatialAttention() 80 | 81 | self.fc_i1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 82 | self.fc_t1 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 83 | 84 | self.fc_i2 = nn.Sequential(nn.Conv2d(192, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 85 | self.fc_t2 = nn.Sequential(nn.Conv2d(192, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 86 | 87 | 88 | def 
forward(self, i, t): 89 | 90 | sa_i = i.mul(self.att_s(i)) 91 | sa_t = t.mul(self.att_s(t)) 92 | 93 | i1 = self.fc_i1(i) 94 | t1 = self.fc_t1(t) 95 | 96 | mix2 = torch.cat([torch.cat([i1, torch.mul(i1, t1)], dim=1), t1], dim =1) 97 | ca = mix2.mul(self.att_c(mix2)) 98 | 99 | atti = self.fc_i2(ca) + sa_t 100 | attt = self.fc_t2(ca) + sa_i 101 | 102 | return atti, attt 103 | 104 | 105 | class BasicConv(nn.Module): 106 | 107 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False): 108 | super(BasicConv, self).__init__() 109 | self.out_channels = out_planes 110 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 111 | self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None 112 | self.relu = nn.PReLU() if relu else None 113 | 114 | def forward(self, x): 115 | x = self.conv(x) 116 | if self.bn is not None: 117 | x = self.bn(x) 118 | if self.relu is not None: 119 | x = self.relu(x) 120 | return x 121 | 122 | 123 | class LightRFB(nn.Module): 124 | def __init__(self, c_in=128, c_out=64): 125 | super(LightRFB, self).__init__() 126 | self.br2 = nn.Sequential( 127 | BasicConv(c_in, c_out, kernel_size=1,bias=False, bn=True, relu=True), 128 | 129 | BasicConv(c_out, c_out, kernel_size=3, dilation=1, padding=1, groups=c_out, bias=False, 130 | relu=False), 131 | ) 132 | self.br3 = nn.Sequential( 133 | BasicConv(c_out, c_out, kernel_size=3, dilation=1, padding=1, groups=c_out, bias=False, 134 | bn=True,relu=False), 135 | BasicConv(c_out, c_out, kernel_size=1, dilation=1, bias=False,bn=True,relu=True), 136 | 137 | BasicConv(c_out, c_out, kernel_size=3, dilation=3, padding=3, groups=c_out, bias=False, 138 | relu=False), 139 | ) 140 | self.br4 = nn.Sequential( 141 | BasicConv(c_out, c_out, kernel_size=5, dilation=1, padding=2, groups=c_out, bias=False, 142 | bn=True, relu=False), 143 | BasicConv(c_out, c_out, kernel_size=1, dilation=1, bias=False, bn=True, relu=True), 144 | 145 | BasicConv(c_out, c_out, kernel_size=3, dilation=5, padding=5, groups=c_out, bias=False, 146 | relu=False), 147 | ) 148 | self.br5 = nn.Sequential( 149 | BasicConv(c_out, c_out, kernel_size=7, dilation=1, padding=3, groups=c_out, bias=False, 150 | bn=True, relu=False), 151 | BasicConv(c_out, c_out, kernel_size=1, dilation=1, bias=False, bn=True, relu=True), 152 | 153 | BasicConv(c_out, c_out, kernel_size=3, dilation=7, padding=7, groups=c_out, bias=False, 154 | relu=False), 155 | ) 156 | 157 | 158 | self.conv1b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 159 | self.bn1b = nn.BatchNorm2d(64) 160 | self.conv2b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 161 | self.bn2b = nn.BatchNorm2d(64) 162 | self.conv3b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 163 | self.bn3b = nn.BatchNorm2d(64) 164 | self.conv4b = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 165 | self.bn4b = nn.BatchNorm2d(64) 166 | 167 | self.conv1d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 168 | self.bn1d = nn.BatchNorm2d(64) 169 | self.conv2d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 170 | self.bn2d = nn.BatchNorm2d(64) 171 | self.conv3d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 172 | self.bn3d = nn.BatchNorm2d(64) 173 | self.conv4d = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 174 | self.bn4d = nn.BatchNorm2d(64) 175 | 176 | 177 | def forward(self, x): 178 | 179 | 
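        # LightRFB.forward: br2 projects the input to c_out channels with a 1x1 conv followed
        # by a depthwise 3x3 conv; br3/br4/br5 enlarge the receptive field with dilated
        # depthwise convs (dilation 3, 5 and 7), and the output of each of br3/br4/br5 is
        # downsampled by 2x max-pooling, giving a four-level pyramid (out2, out3, out4, out5).
        # Two parallel conv-BN-ReLU heads (conv1b..conv4b and conv1d..conv4d) then map that
        # pyramid into two feature pyramids, returned coarsest-first.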
out2=self.br2(x) 180 | 181 | x3=self.br3(out2) 182 | out3 = F.max_pool2d(x3, kernel_size=2, stride=2) 183 | 184 | x4=self.br4(out3) 185 | out4 = F.max_pool2d(x4, kernel_size=2, stride=2) 186 | 187 | x5=self.br5(out4) 188 | out5 = F.max_pool2d(x5, kernel_size=2, stride=2) 189 | 190 | 191 | out1b = F.relu(self.bn1b(self.conv1b(out2)), inplace=True) 192 | out2b = F.relu(self.bn2b(self.conv2b(out3)), inplace=True) 193 | out3b = F.relu(self.bn3b(self.conv3b(out4)), inplace=True) 194 | out4b = F.relu(self.bn4b(self.conv4b(out5)), inplace=True) 195 | 196 | out1d = F.relu(self.bn1d(self.conv1d(out2)), inplace=True) 197 | out2d = F.relu(self.bn2d(self.conv2d(out3)), inplace=True) 198 | out3d = F.relu(self.bn3d(self.conv3d(out4)), inplace=True) 199 | out4d = F.relu(self.bn4d(self.conv4d(out5)), inplace=True) 200 | 201 | return (out4b, out3b, out2b, out1b), (out4d, out3d, out2d, out1d) 202 | 203 | 204 | 205 | class Mlp(nn.Module): 206 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 207 | super().__init__() 208 | out_features = out_features or in_features 209 | hidden_features = hidden_features or in_features 210 | self.fc1 = nn.Linear(in_features, hidden_features) 211 | self.act = act_layer() 212 | self.fc2 = nn.Linear(hidden_features, out_features) 213 | self.drop = nn.Dropout(drop) 214 | 215 | def forward(self, x): 216 | x = self.fc1(x) 217 | x = self.act(x) 218 | x = self.drop(x) 219 | x = self.fc2(x) 220 | x = self.drop(x) 221 | return x 222 | 223 | 224 | def window_partition(x, window_size): 225 | """ 226 | Args: 227 | x: (B, H, W, C) 228 | window_size (int): window size 229 | 230 | Returns: 231 | windows: (num_windows*B, window_size, window_size, C) 堆叠到一起形成一个长条 232 | """ 233 | B, H, W, C = x.shape 234 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 235 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 236 | return windows 237 | 238 | 239 | def window_reverse(windows, window_size, H, W): 240 | """ 241 | Args: 242 | windows: (num_windows*B, window_size, window_size, C) 243 | window_size (int): Window size 244 | H (int): Height of image 245 | W (int): Width of image 246 | 247 | Returns: 248 | x: (B, H, W, C) 249 | """ 250 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 251 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 252 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 253 | return x 254 | 255 | 256 | class WindowAttention(nn.Module): 257 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 258 | It supports both of shifted and non-shifted window. 259 | 260 | Args: 261 | dim (int): Number of input channels. 262 | window_size (tuple[int]): The height and width of the window. 263 | num_heads (int): Number of attention heads. 264 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 265 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 266 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 267 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 268 | """ 269 | 270 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 271 | 272 | super().__init__() 273 | self.dim = dim 274 | self.window_size = window_size # Wh, Ww 275 | self.num_heads = num_heads 276 | head_dim = dim // num_heads # 每一个头的通道维数 277 | self.scale = qk_scale or head_dim ** -0.5 278 | 279 | # define a parameter table of relative position bias 280 | self.relative_position_bias_table = nn.Parameter( 281 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 282 | 283 | # get pair-wise relative position index for each token inside the window 284 | coords_h = torch.arange(self.window_size[0]) 285 | coords_w = torch.arange(self.window_size[1]) 286 | 287 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 288 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 289 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 290 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 291 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 292 | relative_coords[:, :, 1] += self.window_size[1] - 1 293 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 294 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 295 | self.register_buffer("relative_position_index", relative_position_index) 296 | 297 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 298 | self.attn_drop = nn.Dropout(attn_drop) 299 | self.proj = nn.Linear(dim, dim) 300 | self.proj_drop = nn.Dropout(proj_drop) 301 | 302 | trunc_normal_(self.relative_position_bias_table, std=.02) 303 | self.softmax = nn.Softmax(dim=-1) 304 | 305 | def forward(self, x, mask=None): 306 | """ 307 | Args: 308 | x: input features with shape of (num_windows*B, N, C) 309 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 310 | """ 311 | B_, N, C = x.shape 312 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 313 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 314 | 315 | q = q * self.scale 316 | attn = (q @ k.transpose(-2, -1)) 317 | 318 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 319 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 320 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 321 | attn = attn + relative_position_bias.unsqueeze(0) 322 | 323 | if mask is not None: 324 | nW = mask.shape[0] 325 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 326 | attn = attn.view(-1, self.num_heads, N, N) 327 | attn = self.softmax(attn) 328 | else: 329 | attn = self.softmax(attn) 330 | 331 | attn = self.attn_drop(attn) 332 | 333 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 334 | x = self.proj(x) 335 | x = self.proj_drop(x) 336 | return x 337 | 338 | def extra_repr(self) -> str: 339 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 340 | 341 | def flops(self, N): 342 | # calculate flops for 1 window with token length of N 343 | flops = 0 344 | # qkv = self.qkv(x) 345 | flops += N * self.dim * 3 * self.dim 346 | # attn = (q @ k.transpose(-2, -1)) 347 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 348 | # x = (attn @ v) 349 | flops += 
self.num_heads * N * N * (self.dim // self.num_heads) 350 | # x = self.proj(x) 351 | flops += N * self.dim * self.dim 352 | return flops 353 | 354 | 355 | class SwinTransformerBlock(nn.Module): 356 | r""" Swin Transformer Block. 357 | 358 | Args: 359 | dim (int): Number of input channels. 360 | input_resolution (tuple[int]): Input resulotion. 361 | num_heads (int): Number of attention heads. 362 | window_size (int): Window size. 363 | shift_size (int): Shift size for SW-MSA. 364 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 365 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 366 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 367 | drop (float, optional): Dropout rate. Default: 0.0 368 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 369 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 370 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 371 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 372 | """ 373 | 374 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 375 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 376 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 377 | super().__init__() 378 | self.dim = dim 379 | self.input_resolution = input_resolution 380 | self.num_heads = num_heads 381 | self.window_size = window_size 382 | self.shift_size = shift_size 383 | self.mlp_ratio = mlp_ratio 384 | if min(self.input_resolution) <= self.window_size: 385 | # if window size is larger than input resolution, we don't partition windows 386 | self.shift_size = 0 387 | self.window_size = min(self.input_resolution) 388 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 389 | 390 | self.norm1 = norm_layer(dim) 391 | self.attn = WindowAttention( 392 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 393 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 394 | 395 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 396 | self.norm2 = norm_layer(dim) 397 | mlp_hidden_dim = int(dim * mlp_ratio) 398 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 399 | 400 | if self.shift_size > 0: 401 | # calculate attention mask for SW-MSA 402 | H, W = self.input_resolution 403 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1---Important!!! 
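            # Build the SW-MSA attention mask: h_slices/w_slices cut H and W at
            # -window_size and -shift_size, and each of the resulting 3x3 regions is
            # written into img_mask with its own integer id. After window_partition,
            # token pairs whose ids differ get -100 added to their attention logits,
            # so a shifted window only attends among pixels that were neighbors
            # before the cyclic shift (torch.roll) applied in forward().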
404 | h_slices = (slice(0, -self.window_size), 405 | slice(-self.window_size, -self.shift_size), 406 | slice(-self.shift_size, None)) 407 | w_slices = (slice(0, -self.window_size), 408 | slice(-self.window_size, -self.shift_size), 409 | slice(-self.shift_size, None)) 410 | cnt = 0 411 | for h in h_slices: 412 | for w in w_slices: 413 | img_mask[:, h, w, :] = cnt 414 | cnt += 1 415 | 416 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 417 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 418 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 419 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 420 | else: 421 | attn_mask = None 422 | 423 | self.register_buffer("attn_mask", attn_mask) 424 | 425 | def forward(self, x): 426 | # x here is the full feature map (before window partitioning) 427 | H, W = self.input_resolution 428 | B, L, C = x.shape 429 | assert L == H * W, "input feature has wrong size" 430 | 431 | shortcut = x 432 | x = self.norm1(x) 433 | x = x.view(B, H, W, C) 434 | 435 | # cyclic shift 436 | if self.shift_size > 0: 437 | 438 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 439 | else: 440 | shifted_x = x 441 | 442 | # partition windows 443 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 444 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 445 | 446 | # W-MSA/SW-MSA 447 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 448 | 449 | # merge windows 450 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 451 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 452 | 453 | # reverse cyclic shift 454 | if self.shift_size > 0: 455 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 456 | else: 457 | x = shifted_x 458 | x = x.view(B, H * W, C) 459 | 460 | # FFN 461 | x = shortcut + self.drop_path(x) 462 | x = x + self.drop_path(self.mlp(self.norm2(x))) 463 | # print('FFN',x.shape) 464 | return x 465 | 466 | def extra_repr(self) -> str: 467 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 468 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 469 | 470 | def flops(self): 471 | flops = 0 472 | H, W = self.input_resolution 473 | # norm1 474 | flops += self.dim * H * W 475 | # W-MSA/SW-MSA 476 | nW = H * W / self.window_size / self.window_size 477 | flops += nW * self.attn.flops(self.window_size * self.window_size) 478 | # mlp 479 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 480 | # norm2 481 | flops += self.dim * H * W 482 | return flops 483 | 484 | 485 | class PatchMerging(nn.Module): 486 | r""" Patch Merging Layer. 487 | 488 | Args: 489 | input_resolution (tuple[int]): Resolution of input feature. 490 | dim (int): Number of input channels. 491 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 492 | """ 493 | 494 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 495 | super().__init__() 496 | self.input_resolution = input_resolution 497 | self.dim = dim 498 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 499 | self.norm = norm_layer(4 * dim) 500 | 501 | def forward(self, x): 502 | """ 503 | x: B, H*W, C 504 | """ 505 | H, W = self.input_resolution 506 | B, L, C = x.shape 507 | assert L == H * W, "input feature has wrong size" 508 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even." 509 | 510 | x = x.view(B, H, W, C) 511 | 512 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 513 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 514 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 515 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 516 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 517 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 518 | 519 | x = self.norm(x) 520 | x = self.reduction(x) 521 | 522 | return x 523 | 524 | def extra_repr(self) -> str: 525 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 526 | 527 | def flops(self): 528 | H, W = self.input_resolution 529 | flops = H * W * self.dim 530 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 531 | return flops 532 | 533 | 534 | class BasicLayer(nn.Module): 535 | """ A basic Swin Transformer layer for one stage. 536 | 537 | Args: 538 | dim (int): Number of input channels. 539 | input_resolution (tuple[int]): Input resolution. 540 | depth (int): Number of blocks. 541 | num_heads (int): Number of attention heads. 542 | window_size (int): Local window size. 543 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 544 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 545 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 546 | drop (float, optional): Dropout rate. Default: 0.0 547 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 548 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 549 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 550 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 551 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
552 | """ 553 | 554 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 555 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 556 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 557 | 558 | super().__init__() 559 | self.dim = dim 560 | self.input_resolution = input_resolution 561 | self.depth = depth 562 | self.use_checkpoint = use_checkpoint 563 | 564 | # build blocks 565 | self.blocks = nn.ModuleList([ 566 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 567 | num_heads=num_heads, window_size=window_size, 568 | shift_size=0 if (i % 2 == 0) else window_size // 2, 569 | mlp_ratio=mlp_ratio, 570 | qkv_bias=qkv_bias, qk_scale=qk_scale, 571 | drop=drop, attn_drop=attn_drop, 572 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 573 | norm_layer=norm_layer) 574 | for i in range(depth)]) 575 | 576 | # patch merging layer 577 | if downsample is not None: 578 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 579 | else: 580 | self.downsample = None 581 | 582 | def forward(self, x): 583 | 584 | for blk in self.blocks: 585 | if self.use_checkpoint: 586 | x = checkpoint.checkpoint(blk, x) 587 | else: 588 | x = blk(x) 589 | if self.downsample is not None: 590 | x_down = self.downsample(x) 591 | else: 592 | x_down = x 593 | return x_down, x 594 | 595 | def extra_repr(self) -> str: 596 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 597 | 598 | def flops(self): 599 | flops = 0 600 | for blk in self.blocks: 601 | flops += blk.flops() 602 | if self.downsample is not None: 603 | flops += self.downsample.flops() 604 | return flops 605 | 606 | 607 | class PatchEmbed(nn.Module): 608 | 609 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 610 | super().__init__() 611 | img_size = to_2tuple(img_size) 612 | patch_size = to_2tuple(patch_size) 613 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 614 | self.img_size = img_size 615 | self.patch_size = patch_size 616 | self.patches_resolution = patches_resolution 617 | self.num_patches = patches_resolution[0] * patches_resolution[1] 618 | 619 | self.in_chans = in_chans # define in_chans == 3 620 | self.embed_dim = embed_dim # Swin-B embed_dim == 128 (Swin-T uses 96) 621 | 622 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # dim 3->128 623 | if norm_layer is not None: 624 | self.norm = norm_layer(embed_dim) 625 | else: 626 | self.norm = None 627 | 628 | def forward(self, x): 629 | B, C, H, W = x.shape 630 | # FIXME look at relaxing size constraints; the input size is fixed and asserted below 631 | assert H == self.img_size[0] and W == self.img_size[1], \ 632 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
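        # The projection below is a Conv2d whose kernel and stride both equal
        # patch_size, cutting the image into non-overlapping patches;
        # flatten(2).transpose(1, 2) then yields (B, Ph*Pw, embed_dim) tokens.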
633 | 634 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 635 | if self.norm is not None: 636 | x = self.norm(x) 637 | return x 638 | 639 | def flops(self): 640 | Ho, Wo = self.patches_resolution 641 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 642 | if self.norm is not None: 643 | flops += Ho * Wo * self.embed_dim 644 | return flops 645 | 646 | 647 | class SwinTransformer(nn.Module): 648 | r""" Swin Transformer 649 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 650 | https://arxiv.org/pdf/2103.14030 651 | 652 | Args: 653 | img_size (int | tuple(int)): Input image size. Default 224 654 | patch_size (int | tuple(int)): Patch size. Default: 4 655 | in_chans (int): Number of input image channels. Default: 3 656 | embed_dim (int): Patch embedding dimension. Default: 96 657 | depths (tuple(int)): Depth of each Swin Transformer layer. 658 | num_heads (tuple(int)): Number of attention heads in different layers. 659 | window_size (int): Window size. Default: 7 660 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 661 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 662 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 663 | drop_rate (float): Dropout rate. Default: 0 664 | attn_drop_rate (float): Attention dropout rate. Default: 0 665 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 666 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 667 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 668 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 669 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 670 | """ 671 | 672 | def __init__(self, img_size=384, patch_size=4, in_chans=3, 673 | embed_dim=128, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 674 | window_size=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, 675 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 676 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 677 | use_checkpoint=False, **kwargs): 678 | super().__init__() 679 | 680 | 681 | self.num_layers = len(depths) 682 | self.embed_dim = embed_dim 683 | self.ape = ape 684 | self.patch_norm = patch_norm 685 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 686 | self.mlp_ratio = mlp_ratio 687 | 688 | 689 | self.patch_embed = PatchEmbed( 690 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 691 | norm_layer=norm_layer if self.patch_norm else None) 692 | num_patches = self.patch_embed.num_patches 693 | patches_resolution = self.patch_embed.patches_resolution 694 | self.patches_resolution = patches_resolution 695 | 696 | 697 | if self.ape: 698 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 699 | trunc_normal_(self.absolute_pos_embed, std=.02) 700 | 701 | self.pos_drop = nn.Dropout(p=drop_rate) 702 | 703 | 704 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 705 | 706 | # build layers 707 | self.layers = nn.ModuleList() 708 | for i_layer in range(self.num_layers): 709 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 710 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 711 | patches_resolution[1] // (2 ** i_layer)), 712 | depth=depths[i_layer], 713 | num_heads=num_heads[i_layer], 714 | window_size=window_size, 715 | mlp_ratio=self.mlp_ratio, 716 | qkv_bias=qkv_bias, qk_scale=qk_scale, 717 | drop=drop_rate, attn_drop=attn_drop_rate, 718 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 719 | norm_layer=norm_layer, 720 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 721 | use_checkpoint=use_checkpoint) 722 | # self.layers should hold one BasicLayer per entry in depths (4 stages here) 723 | self.layers.append(layer) 724 | 725 | self.norm = norm_layer(self.num_features) 726 | self.apply(self._init_weights) 727 | 728 | def _init_weights(self, m): 729 | if isinstance(m, nn.Linear): 730 | trunc_normal_(m.weight, std=.02) 731 | if isinstance(m, nn.Linear) and m.bias is not None: 732 | nn.init.constant_(m.bias, 0) 733 | elif isinstance(m, nn.LayerNorm): 734 | nn.init.constant_(m.bias, 0) 735 | nn.init.constant_(m.weight, 1.0) 736 | 737 | @torch.jit.ignore 738 | def no_weight_decay(self): 739 | return {'absolute_pos_embed'} 740 | 741 | @torch.jit.ignore 742 | def no_weight_decay_keywords(self): 743 | return {'relative_position_bias_table'} 744 | 745 | def forward_features(self, x): 746 | layer_features = [] 747 | x = self.patch_embed(x) 748 | B,L,C = x.shape 749 | layer_features.append(x.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous()) 750 | # layer_features.append(x) 751 | if self.ape: 752 | x = x + self.absolute_pos_embed 753 | x = self.pos_drop(x) 754 | 755 | for layer in self.layers: 756 | x, x_undownsample = layer(x) 757 | B, L, C = x_undownsample.shape 758 | # print('x:', x.shape) 759 | xl = x_undownsample.view(B, int(np.sqrt(L)), int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous() 760 | # print('xl',xl.shape) 761 | layer_features.append(xl) 762 | x = self.norm(x) # B L C 763 | B, L, C = x.shape 764 | # x = self.avgpool(x.transpose(1, 2)) # B C 1 765 | x = x.view(B, int(np.sqrt(L)), 
int(np.sqrt(L)),-1).permute(0, 3, 1, 2).contiguous() 766 | # x = torch.flatten(x, 1) 767 | layer_features[-1] = x 768 | 769 | 770 | 771 | return layer_features 772 | 773 | def forward(self, x): 774 | outs = self.forward_features(x) 775 | 776 | return outs 777 | 778 | 779 | class SwinMCNet(nn.Module): 780 | def __init__(self, norm_layer = nn.LayerNorm): 781 | super(SwinMCNet, self).__init__() 782 | 783 | self.swin_image = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32]) 784 | self.swin_thermal = SwinTransformer(embed_dim=128, depths=[2,2,18,2], num_heads=[4,8,16,32]) 785 | 786 | self.bi5 = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 787 | self.bi4 = nn.Sequential(nn.Conv2d( 512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 788 | self.bi3 = nn.Sequential(nn.Conv2d( 256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 789 | self.bi2 = nn.Sequential(nn.Conv2d( 128, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 790 | 791 | self.bt5 = nn.Sequential(nn.Conv2d(1024, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 792 | self.bt4 = nn.Sequential(nn.Conv2d( 512, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 793 | self.bt3 = nn.Sequential(nn.Conv2d( 256, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 794 | self.bt2 = nn.Sequential(nn.Conv2d( 128, 64, kernel_size=1), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True)) 795 | 796 | self.att_2 = Attention() 797 | self.att_3 = Attention() 798 | self.att_4 = Attention() 799 | self.att_5 = Attention() 800 | 801 | #self.encoder = Encoder() 802 | self.rfb = LightRFB() 803 | self.decoderi = Decoder() 804 | self.decodert = Decoder() 805 | self.lineari = nn.Conv2d(64, 1, kernel_size=3, padding=1) 806 | self.lineart = nn.Conv2d(64, 1, kernel_size=3, padding=1) 807 | self.linear = nn.Sequential(nn.Conv2d(128, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 1, kernel_size=3, padding=1)) 808 | 809 | def forward(self, image, thermal, shape=None): 810 | image_list = self.swin_image(image) 811 | thermal_list = self.swin_thermal(thermal) 812 | 813 | i1,i2,i3,i4,i5 = image_list[0], image_list[1], image_list[2], image_list[3], image_list[4] 814 | t1,t2,t3,t4,t5 = thermal_list[0], thermal_list[1], thermal_list[2], thermal_list[3], thermal_list[4] 815 | 816 | i2,i3,i4,i5 = self.bi2(i2), self.bi3(i3), self.bi4(i4), self.bi5(i5) 817 | t2,t3,t4,t5 = self.bt2(t2), self.bt3(t3), self.bt4(t4), self.bt5(t5) 818 | 819 | 820 | att2i, att2t = self.att_2(i2, t2) 821 | att3i, att3t = self.att_3(i3, t3) 822 | att4i, att4t = self.att_4(i4, t4) 823 | att5i, att5t = self.att_5(i5, t5) 824 | 825 | out2i, out3i, out4i, out5i = i2 + att2i, i3 + att3i, i4 + att4i, i5 + att5i 826 | out2t, out3t, out4t, out5t = t2 + att2t, t3 + att3t, t4 + att4t, t5 + att5t 827 | 828 | outi1 = self.decoderi([out5i, out4i, out3i, out2i]) 829 | outt1 = self.decodert([out5t, out4t, out3t, out2t]) 830 | 831 | out1 = torch.cat([outi1, outt1], dim=1) 832 | 833 | outi2, outt2 = self.rfb(out1) 834 | 835 | outi2 = self.decoderi([out5i, out4i, out3i, out2i], outi2) 836 | 
outt2 = self.decodert([out5t, out4t, out3t, out2t], outt2) 837 | 838 | out2 = torch.cat([outi2, outt2], dim=1) 839 | 840 | if shape is None: 841 | shape = image.size()[2:] 842 | out1 = F.interpolate(self.linear(out1), size=shape, mode='bilinear') 843 | outi1 = F.interpolate(self.lineari(outi1), size=shape, mode='bilinear') 844 | outt1 = F.interpolate(self.lineart(outt1), size=shape, mode='bilinear') 845 | 846 | out2 = F.interpolate(self.linear(out2), size=shape, mode='bilinear') 847 | outi2 = F.interpolate(self.lineari(outi2), size=shape, mode='bilinear') 848 | outt2 = F.interpolate(self.lineart(outt2), size=shape, mode='bilinear') 849 | 850 | # return i2,i3,i4,i5,t2,t3,t4,t5 # debug-only: raw backbone features 851 | return outi1, outt1, out1, outi2, outt2, out2 852 | 853 | def load_pre(self, pre_model): 854 | self.swin_image.load_state_dict(torch.load(pre_model)['model'],strict=False) 855 | print(f"RGB SwinTransformer loading pre_model {pre_model}") 856 | self.swin_thermal.load_state_dict(torch.load(pre_model)['model'], strict=False) 857 | print(f"Thermal SwinTransformer loading pre_model {pre_model}") 858 | --------------------------------------------------------------------------------
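A minimal usage sketch for the network defined above (not part of the repository): it assumes net.py exposes SwinMCNet as shown here, that the forward pass returns the six saliency maps, and that a Swin-B 384 checkpoint has been downloaded locally; the checkpoint file name passed to load_pre is a placeholder, not a path shipped with the repo.

import torch
from net import SwinMCNet

model = SwinMCNet().eval()
# Optional: initialise both Swin backbones from pretrained weights.
# The checkpoint path below is hypothetical.
# model.load_pre('swin_base_patch4_window12_384_22k.pth')

rgb = torch.randn(1, 3, 384, 384)      # RGB input; PatchEmbed asserts a 384x384 size
thermal = torch.randn(1, 3, 384, 384)  # thermal input, same size, 3 channels

with torch.no_grad():
    outi1, outt1, out1, outi2, outt2, out2 = model(rgb, thermal)
print(out2.shape)  # torch.Size([1, 1, 384, 384]) - fused saliency logits at input resolution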