├── IFT.png
├── README.md
├── analysis_MatLab
│   ├── analysis_MI.m
│   ├── analysis_Reference.m
│   ├── analysis_SCD.m
│   ├── analysis_ms_ssim.m
│   ├── images
│   │   ├── fuse
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR1.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR10.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR11.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR12.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR13.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR14.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR15.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR16.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR17.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR18.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR19.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR2.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR20.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR21.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR3.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR4.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR5.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR6.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR7.png
│   │   │   ├── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR8.png
│   │   │   └── fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR9.png
│   │   ├── ir
│   │   │   ├── IR1.png
│   │   │   ├── IR10.png
│   │   │   ├── IR11.png
│   │   │   ├── IR12.png
│   │   │   ├── IR13.png
│   │   │   ├── IR14.png
│   │   │   ├── IR15.png
│   │   │   ├── IR16.png
│   │   │   ├── IR17.png
│   │   │   ├── IR18.png
│   │   │   ├── IR19.png
│   │   │   ├── IR2.png
│   │   │   ├── IR20.png
│   │   │   ├── IR21.png
│   │   │   ├── IR3.png
│   │   │   ├── IR4.png
│   │   │   ├── IR5.png
│   │   │   ├── IR6.png
│   │   │   ├── IR7.png
│   │   │   ├── IR8.png
│   │   │   └── IR9.png
│   │   └── vis
│   │       ├── VIS1.png
│   │       ├── VIS10.png
│   │       ├── VIS11.png
│   │       ├── VIS12.png
│   │       ├── VIS13.png
│   │       ├── VIS14.png
│   │       ├── VIS15.png
│   │       ├── VIS16.png
│   │       ├── VIS17.png
│   │       ├── VIS18.png
│   │       ├── VIS19.png
│   │       ├── VIS2.png
│   │       ├── VIS20.png
│   │       ├── VIS21.png
│   │       ├── VIS3.png
│   │       ├── VIS4.png
│   │       ├── VIS5.png
│   │       ├── VIS6.png
│   │       ├── VIS7.png
│   │       ├── VIS8.png
│   │       └── VIS9.png
│   ├── main_all.m
│   └── mef_ssim.m
├── args_fusion.py
├── checkpoint.py
├── data
│   └── kaist-rgbt
├── models
│   └── model
│       └── nestfuse_gray_1e2.model
├── net.py
├── pytorch_msssim
│   ├── __init__.py
│   └── __pycache__
│       ├── __init__.cpython-36.pyc
│       └── __init__.cpython-37.pyc
├── reference.bib
├── test_21pairs_axial.py
├── train_fusionnet_axial.py
└── utils.py

--------------------------------------------------------------------------------
/IFT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/IFT.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Image-Fusion-Transformer

[![Framework: PyTorch](https://img.shields.io/badge/Framework-PyTorch-orange.svg)](https://pytorch.org/)

[Vibashan VS](https://vibashan.github.io/), [Jeya Maria Jose](http://jeya-maria-jose.github.io/research), [Poojan Oza](https://www.linkedin.com/in/poojan-oza-a7b68350/), [Vishal M Patel](https://scholar.google.com/citations?user=AkEXTbIAAAAJ&hl=en)

[[`Personal Page`](https://viudomain.github.io/)] [[`ICIP`](https://ieeexplore.ieee.org/abstract/document/9897280)] [[`pdf`](https://arxiv.org/pdf/2107.09011.pdf)] [[`BibTeX`](https://github.com/Vibashan/Image-Fusion-Transformer/blob/main/reference.bib)]


## Platform
Python 3.7
PyTorch >= 1.0


## Training Dataset

[MS-COCO 2014](http://images.cocodataset.org/zips/train2014.zip) (T.-Y. Lin, M. Maire, S. Belongie, J. Hays, P. Perona, D. Ramanan, P. Dollar, and C. L. Zitnick. Microsoft COCO: Common Objects in Context. In ECCV, 2014.) is used to train our auto-encoder network.

[KAIST](https://sites.google.com/view/multispectral/home) (S. Hwang, J. Park, N. Kim, Y. Choi, and I. S. Kweon. Multispectral pedestrian detection: Benchmark dataset and baseline. In CVPR, 2015, pp. 1037-1045.) is used to train the RFN modules.

The testing datasets are included in "analysis_MatLab".

### Training Command:

```bash
python train_fusionnet_axial.py
```

### Testing Command:

```bash
python test_21pairs_axial.py
```

The fusion results are included in "analysis_MatLab".


If you have any questions about the code, feel free to contact me at vvishnu2@jh.edu.

## Acknowledgement
This codebase is built on top of [RFN-Nest](https://github.com/hli1221/imagefusion-rfn-nest) by [Li Hui](https://github.com/hli1221).

## Citation

If you find IFT useful in your research, please consider starring ⭐ us on GitHub and citing 📚 us!

```bibtex
@inproceedings{vs2022image,
  title={Image fusion transformer},
  author={Vs, Vibashan and Valanarasu, Jeya Maria Jose and Oza, Poojan and Patel, Vishal M},
  booktitle={2022 IEEE International Conference on Image Processing (ICIP)},
  pages={3566--3570},
  year={2022},
  organization={IEEE}
}
```

--------------------------------------------------------------------------------
/analysis_MatLab/analysis_MI.m:
--------------------------------------------------------------------------------
function MI = analysis_MI(A,B,F)

% MI_A = nmi(A,F);
% MI_B = nmi(B,F);

% MI_A = mutual_information_images(A,F);
% MI_B = mutual_information_images(B,F);

MI_A = MutualInformation(A,F);
MI_B = MutualInformation(B,F);

MI = MI_A + MI_B;

end

% by soleimani h.soleimani@ec.iut.ac.ir
% input ---> im1 and im2; they should be grayscale, in [0 255], and have the same size
function MI=mutual_information_images(im1, im2)
im1=double(im1)+1;
im2=double(im2)+1;

% find joint histogram
joint_histogram=zeros(256,256);

for i=1:min(size(im1,1),size(im2,1))
    for j=1:min(size(im1,2),size(im2,2))
        joint_histogram(im1(i,j),im2(i,j))= joint_histogram(im1(i,j),im2(i,j))+1;
    end
end

JPDF=joint_histogram/sum(joint_histogram(:)); % joint pdf of the two images
pdf_im1=sum(JPDF,2); % marginal pdf of im1 (row sums, indexed by i)
pdf_im2=sum(JPDF,1); % marginal pdf of im2 (column sums, indexed by j)
% (the two marginals were originally swapped, which crossed the indices
%  in the MI sum below)

% find MI
MI=0;
for i=1:256
    for j=1:256
        if JPDF(i,j)>0
            MI=MI+JPDF(i,j)*log2(JPDF(i,j)/(pdf_im1(i)*pdf_im2(j)));
        end
    end
end
end

% MutualInformation: returns mutual information (in bits) of the 'X' and 'Y'
% by Will Dwinnell
%
% I = MutualInformation(X,Y);
%
% I = calculated mutual information (in bits)
% X = variable(s) to be analyzed (column vector)
% Y = variable to be analyzed (column vector)
%
% Note: Multiple variables may be handled jointly as columns in matrix 'X'.
% Note: Requires the 'Entropy' and 'JointEntropy' functions.
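%
% Sketch of the quantity computed (added note, not part of Dwinnell's
% original header): for discrete variables this is the classic identity
%
%   I(X;Y) = H(X) + H(Y) - H(X,Y)   (in bits)
%
% analysis_MI above passes whole images, so size(X,2) > 1 and the columns
% of each image are handled jointly via JointEntropy (first branch below).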
%
% Last modified: Nov-12-2006

function I = MutualInformation(X,Y)

if (size(X,2) > 1)  % More than one predictor?
    % Axiom of information theory
    I = JointEntropy(X) + entropy(Y) - JointEntropy([X Y]);
else
    % Axiom of information theory
    I = entropy(X) + entropy(Y) - JointEntropy([X Y]);
end

% God bless Claude Shannon.

% EOF
end


% JointEntropy: Returns joint entropy (in bits) of each column of 'X'
% by Will Dwinnell
%
% H = JointEntropy(X)
%
% H = calculated joint entropy (in bits)
% X = data to be analyzed
%
% Last modified: Aug-29-2006

function H = JointEntropy(X)

% Sort to get identical records together
X = sortrows(X);

% Find elemental differences from predecessors
DeltaRow = (X(2:end,:) ~= X(1:end-1,:));

% Summarize by record
Delta = [1; any(DeltaRow')'];

% Generate vector symbol indices
VectorX = cumsum(Delta);

% Calculate entropy the usual way on the vector symbols
H = entropy(VectorX);

% God bless Claude Shannon.

% EOF
end

--------------------------------------------------------------------------------
/analysis_MatLab/analysis_Reference.m:
--------------------------------------------------------------------------------
function [EN,MI,SCD,MS_SSIM] = analysis_Reference(image_f,image_ir,image_vis)

[s1,s2] = size(image_ir);
imgSeq = zeros(s1, s2, 2);
imgSeq(:, :, 1) = image_ir;
imgSeq(:, :, 2) = image_vis;

image1 = im2double(image_ir);
image2 = im2double(image_vis);
image_fused = im2double(image_f);

% EN
EN = entropy(image_fused);
% MI
MI = analysis_MI(image_ir,image_vis,image_f);
% SCD
SCD = analysis_SCD(image1,image2,image_fused);
% MS_SSIM
[MS_SSIM,t1,t2] = analysis_ms_ssim(imgSeq, image_f);

end

--------------------------------------------------------------------------------
/analysis_MatLab/analysis_SCD.m:
--------------------------------------------------------------------------------
function r=analysis_SCD(img1,img2,fus)
%% Emre Bendes - ERU Computer Engineering - 2015
% Calculates the image quality metric based on THE SUM OF THE CORRELATIONS OF DIFFERENCES (SCD).
% inputs: img1 and img2 are the source images
%         fus is the fused image
%
% Please cite:
% V. Aslantas and E. Bendes,
% "A new image quality metric for image fusion: The sum of the correlations of differences,"
% AEU - International Journal of Electronics and Communications, vol. 69/12, pp. 1890-1896, 2015.

% Each difference image (fused minus one source) should correlate with the
% other source; SCD is the sum of the two correlation coefficients.
r=corr2(fus-img2,img1)+corr2(fus-img1,img2);
end

--------------------------------------------------------------------------------
/analysis_MatLab/analysis_ms_ssim.m:
--------------------------------------------------------------------------------
function [oQ, Q, qMap] = analysis_ms_ssim(imgSeq, fI, K, window, level, weight)
% ========================================================================
% Multi-exposure fused (MEF) image quality model Version 1.0
% Copyright(c) 2015 Kede Ma, Kai Zeng and Zhou Wang
% All Rights Reserved.
%
% ----------------------------------------------------------------------
% Permission to use, copy, or modify this software and its documentation
% for educational and research purposes only and without fee is hereby
% granted, provided that this copyright notice and the original authors'
% names appear on all copies and supporting documentation. This program
% shall not be used, rewritten, or adapted as the basis of a commercial
% software or hardware product without first obtaining permission of the
% authors. The authors make no representations about the suitability of
% this software for any purpose. It is provided "as is" without express
% or implied warranty.
% ----------------------------------------------------------------------
% This is an implementation of an objective image quality assessment model
% for MEF images using their corresponding input source sequences
% as reference.
%
% Please refer to the following paper:
%
% K. Ma et al., "Perceptual Quality Assessment for Multi-Exposure
% Image Fusion" submitted to IEEE Transactions on
% Image Processing.
%
% Kindly report any suggestions or corrections to k29ma@uwaterloo.ca,
% kzeng@uwaterloo.ca or zhouwang@ieee.org
%
% ----------------------------------------------------------------------
%
% Input : (1) imgSeq: source sequence being used as reference in [0-255] grayscale.
%         (2) fI: the MEF image being compared in [0-255] grayscale.
%         (3) K: constant in the SSIM index formula (see the above
%             reference). default value: K = 0.03
%         (4) window: local window for statistics. default window is
%             Gaussian given by window = fspecial('gaussian', 11, 1.5);
%         (5) level: multi-scale level used for downsampling. default value:
%             level = 3;
%         (6) weight: the weights in each scale (see the above reference).
%             default value is given by
%             weight = [0.0448 0.2856 0.3001]';
%             weight = weight / sum(weight);
%             Note that the length of weight and level should be the same.
%
% Output: (1) oQ: The overall quality score of the MEF image.
%         (2) Q: The quality scores in each scale.
%         (3) qMap: The quality maps of the MEF image in each scale.
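%
% How the scales are combined (added note, matching the code below): the
% per-scale scores are pooled geometrically,
%
%   oQ = prod_{l=1..level} Q(l)^weight(l)
%
% so with the default normalized weights the two coarser scales carry
% most of the final score.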
%
% Basic Usage:
%   Given the test MEF image and its corresponding source sequence
%
%   [oQ, Q, qMap] = analysis_ms_ssim(imgSeq, fI);
%
% ========================================================================


if (nargin < 2 || nargin > 6)
    oQ = -Inf;
    Q = -Inf;
    qMap = -Inf;
    return;
end

if (~exist('K', 'var'))
    K = 0.03;
end

if (~exist('window', 'var'))
    window = fspecial('gaussian', 11, 1.5);
end

[H, W] = size(window);

if (~exist('level','var'))
    level = 3;
end

if (~exist('weight', 'var'))
    weight = [0.0448 0.2856 0.3001]';
    weight = weight / sum(weight);
end

if level ~= length(weight)
    oQ = -Inf;
    Q = -Inf;
    qMap = -Inf;
    return;
end

[s1, s2, s3] = size(imgSeq);
minImgWidth = min(s1, s2)/(2^(level-1));
maxWinWidth = max(H, W);
if (minImgWidth < maxWinWidth)
    oQ = -Inf;
    Q = -Inf;
    qMap = -Inf;  % (was Inf; use the same failure sentinel as the other checks)
    return;
end

imgSeq = double(imgSeq);
fI = double(fI);
downsampleFilter = ones(2)./4;
Q = zeros(level,1);
qMap = cell(level,1);
if level == 1
    [Q, qMap] = mef_ssim(imgSeq, fI, K, window);
    oQ = Q;
    return;
else
    for l = 1 : level - 1
        [Q(l), qMap{l}] = mef_ssim(imgSeq, fI, K, window);
        imgSeqC = imgSeq;
        clear imgSeq;
        for i = 1:s3
            rI = squeeze(imgSeqC(:,:,i));
            dI = imfilter(rI, downsampleFilter, 'symmetric', 'same');
            imgSeq(:,:,i) = dI(1:2:end, 1:2:end);
        end
        dI = imfilter(fI, downsampleFilter, 'symmetric', 'same');
        clear fI;
        fI = dI(1:2:end, 1:2:end);
    end
    % the coarsest scale
    [Q(level), qMap{level}] = mef_ssim(imgSeq, fI, K, window);
    Q = Q(:);
    oQ = prod(Q.^weight);
end

--------------------------------------------------------------------------------
/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR1.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR10.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR11.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR12.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR13.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR14.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR15.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR16.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR17.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR18.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR19.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR2.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR20.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR21.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR3.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR4.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR5.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR6.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR7.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR8.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR8.png -------------------------------------------------------------------------------- /analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/fuse/fused_rfnnest_700_wir_6.0_wvi_3.0axial_IR9.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR1.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR10.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR11.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR12.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR13.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR14.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR15.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR16.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR17.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR17.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR18.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR19.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR2.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR20.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR21.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR3.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR4.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR5.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR6.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR7.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR8.png -------------------------------------------------------------------------------- /analysis_MatLab/images/ir/IR9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/ir/IR9.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS1.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS10.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS11.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS12.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS13.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS14.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS14.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS15.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS16.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS16.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS17.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS18.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS19.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS2.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS20.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS21.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS3.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS4.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS5.png -------------------------------------------------------------------------------- /analysis_MatLab/images/vis/VIS6.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS6.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/vis/VIS7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS7.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/vis/VIS8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS8.png

--------------------------------------------------------------------------------
/analysis_MatLab/images/vis/VIS9.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/analysis_MatLab/images/vis/VIS9.png

--------------------------------------------------------------------------------
/analysis_MatLab/main_all.m:
--------------------------------------------------------------------------------

dir_source_ir = dir("./images/ir/");
dir_source_vis = dir("./images/vis/");
dir_fused = dir("./images/fuse/");

disp("Start");
disp('---------------------------Analysis---------------------------');

all = [0,0,0,0];
std_dev = 0;  % (unused)
for i = 3:23  % dir() entries 1 and 2 are '.' and '..'; entries 3-23 are the 21 image pairs

    fused_image = imread("./images/fuse/"+dir_fused(i).name);
    source_image1 = imread("./images/ir/"+dir_source_ir(i).name);
    source_image2 = imread("./images/vis/"+dir_source_vis(i).name);

    [EN,MI,SCD,MS_SSIM] = analysis_Reference(fused_image,source_image1,source_image2);
    all = all + [EN,MI,SCD,MS_SSIM];
end

disp(all/21);
disp('Done');

--------------------------------------------------------------------------------
/analysis_MatLab/mef_ssim.m:
--------------------------------------------------------------------------------
function [Q, qMap] = mef_ssim(imgSeq, fI, K, window)

if (nargin < 2 || nargin > 4)
    Q = -Inf;
    qMap = -Inf;  % (was Inf; use a consistent failure sentinel)
    return;
end

if (~exist('K', 'var'))
    K = 0.03;
end

if (~exist('window', 'var'))
    window = fspecial('gaussian', 11, 1.5);
end


imgSeq = double(imgSeq);
fI = double(fI);
[s1, s2, s3] = size(imgSeq);
wSize = size(window,1);
sWindow = ones(wSize) / wSize^2; % square window used to calculate the distance
bd = floor(wSize/2);
mu = zeros(s1-2*bd, s2-2*bd, s3);
ed = zeros(s1-2*bd, s2-2*bd, s3);
for i = 1:s3
    img = squeeze(imgSeq(:,:,i));
    mu(:,:,i) = filter2(sWindow, img, 'valid');
    muSq = mu(:,:,i) .* mu(:,:,i);
    sigmaSq = filter2(sWindow, img.*img, 'valid') - muSq;
    ed(:,:,i) = sqrt( max( wSize^2 * sigmaSq, 0 ) ) + 0.001; % add a small constant to avoid instability
end

R = zeros(s1-2*bd,s2-2*bd); % consistency map which could be used as an output if necessary
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        vecs = reshape(imgSeq(i-bd:i+bd,j-bd:j+bd,:),[wSize*wSize, s3]);
        denominator = 0;
        for k = 1:s3
            denominator = denominator + norm(vecs(:,k) - mu(i-bd,j-bd,k));
        end
        numerator = norm(sum(vecs,2) - mean(sum(vecs,2)));
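        % By the triangle inequality the ratio below lies in [0,1] (added
        % note): it approaches 1 where the source patches share the same
        % structural direction and 0 where they disagree.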
        R(i-bd,j-bd) = (numerator + eps) / (denominator + eps);
    end
end

R(R > 1) = 1 - eps; % get rid of numerical instability
R(R < 0) = 0 + eps;


p = tan(pi/2 * R);
p( p > 10 ) = 10; % to avoid blow up (a large cap such as 10 is equivalent to taking the maximum)
p = repmat(p,[1,1,s3]);


wMap = (ed / wSize).^p + eps; % to avoid blowing up
normalizer = sum(wMap,3);
wMap = wMap ./ repmat(normalizer,[1,1,s3]);
maxEd = max(ed,[],3);

C = (K * 255)^2;
qMap = zeros(s1-2*bd, s2-2*bd);
for i = bd+1:s1-bd
    for j = bd+1:s2-bd
        blocks = imgSeq(i-bd:i+bd,j-bd:j+bd,:);
        rBlock = zeros(wSize,wSize);
        for k = 1 : s3
            rBlock = rBlock + wMap(i-bd,j-bd,k) * ( blocks(:,:,k) - mu(i-bd,j-bd,k) ) / ed(i-bd,j-bd,k);
        end
        if norm(rBlock(:)) > 0
            rBlock = rBlock / norm(rBlock(:)) * maxEd(i-bd,j-bd);
        end
        fBlock = fI(i-bd:i+bd,j-bd:j+bd);
        rVec = rBlock(:);
        fVec = fBlock(:);
        mu1 = sum( window(:) .* rVec );
        mu2 = sum( window(:) .* fVec );
        sigma1Sq = sum( window(:) .* (rVec - mu1).^2 );
        sigma2Sq = sum( window(:) .* (fVec - mu2).^2 );
        sigma12 = sum( window(:) .* (rVec - mu1) .* (fVec - mu2) );
        qMap(i-bd,j-bd) = ( 2 * sigma12 + C ) ./ ( sigma1Sq + sigma2Sq + C );
    end
end

Q = mean2(qMap);

--------------------------------------------------------------------------------
/args_fusion.py:
--------------------------------------------------------------------------------

class args():

    # training args
    epochs = 2  # "number of training epochs, default is 2"
    batch_size = 4  # "batch size for training, default is 4"
    dataset_ir = "./data/kaist-rgbt/images"
    dataset_vi = "./data/kaist-rgbt/images"

    dataset = 'medical'
    # dataset = 'focus'

    HEIGHT = 256
    WIDTH = 256

    save_fusion_model = "models/train/fusionnet/"
    save_loss_dir = './models/train/loss_fusionnet/'

    image_size = 256  # "size of training images, default is 256 x 256"
    cuda = 1  # "set it to 1 for running on GPU, 0 for CPU"
    seed = 42  # "random seed for training"

    lr = 1e-4  # "learning rate, default is 1e-4"
    log_interval = 10  # "number of batches after which the training loss is logged, default is 10"
    resume_fusion_model = None
    # nest net model
    resume_nestfuse = './models/nestfuse/nestfuse_gray_1e2.model'
    resume_vit = './imagenet21k+imagenet2012_ViT-L_16.pth'
    fusion_model = './models/rfn_twostage/'

    mode = "fusion_axial"

--------------------------------------------------------------------------------
/checkpoint.py:
--------------------------------------------------------------------------------
import os
import torch
from tensorflow.io import gfile
import numpy as np


def load_checkpoint(path):
    """ Load weights from a given checkpoint path in npz/pth """
    if path.endswith('npz'):
        keys, values = load_jax(path)
        state_dict = convert_jax_pytorch(keys, values)
    elif path.endswith('pth'):
        state_dict = torch.load(path)['state_dict']
    else:
        raise ValueError("checkpoint format {} not supported yet!".format(path.split('.')[-1]))

    return state_dict


def load_jax(path):
    """ Loads params from a npz checkpoint previously stored with `save()` in the JAX implementation """
    with gfile.GFile(path, 'rb') as f:
        ckpt_dict = np.load(f, allow_pickle=False)
        keys, values = zip(*list(ckpt_dict.items()))
    return keys, values


def save_jax_to_pytorch(jax_path, save_path):
    model_name = jax_path.split('/')[-1].split('.')[0]
    keys, values = load_jax(jax_path)
    state_dict = convert_jax_pytorch(keys, values)
    checkpoint = {'state_dict': state_dict}
    torch.save(checkpoint, os.path.join(save_path, model_name + '.pth'))


def replace_names(names):
    """ Replace JAX model names with PyTorch model names """
    new_names = []
    for name in names:
        if name == 'Transformer':
            new_names.append('transformer')
        elif name == 'encoder_norm':
            new_names.append('norm')
        elif 'encoderblock' in name:
            num = name.split('_')[-1]
            new_names.append('encoder_layers')
            new_names.append(num)
        elif 'LayerNorm' in name:
            num = name.split('_')[-1]
            if num == '0':
                new_names.append('norm{}'.format(1))
            elif num == '2':
                new_names.append('norm{}'.format(2))
        elif 'MlpBlock' in name:
            new_names.append('mlp')
        elif 'Dense' in name:
            num = name.split('_')[-1]
            new_names.append('fc{}'.format(int(num) + 1))
        elif 'MultiHeadDotProductAttention' in name:
            new_names.append('attn')
        elif name == 'kernel' or name == 'scale':
            new_names.append('weight')
        elif name == 'bias':
            new_names.append(name)
        elif name == 'posembed_input':
            new_names.append('pos_embedding')
        elif name == 'pos_embedding':
            new_names.append('pos_embedding')
        elif name == 'embedding':
            new_names.append('embedding')
        elif name == 'head':
            new_names.append('classifier')
        elif name == 'cls':
            new_names.append('cls_token')
        else:
            new_names.append(name)
    return new_names


def convert_jax_pytorch(keys, values):
    """ Convert JAX model parameters to PyTorch model parameters """
    state_dict = {}
    for key, value in zip(keys, values):

        # convert name to torch names
        names = key.split('/')
        torch_names = replace_names(names)
        torch_key = '.'.join(w for w in torch_names)

        # convert values to tensor and check shapes
        tensor_value = torch.tensor(value, dtype=torch.float)
        # check shape
        num_dim = len(tensor_value.shape)

        if num_dim == 1:
            tensor_value = tensor_value.squeeze()
        elif num_dim == 2 and torch_names[-1] == 'weight':
            # for normal weight, transpose it
            tensor_value = tensor_value.T
        elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] in ['query', 'key', 'value']:
            feat_dim, num_heads, head_dim = tensor_value.shape
            # multi-head attention q/k/v weight: kept in its original layout
            tensor_value = tensor_value
        elif num_dim == 2 and torch_names[-1] == 'bias' and torch_names[-2] in ['query', 'key', 'value']:
            # multi-head attention q/k/v bias: kept as-is
            tensor_value = tensor_value
        elif num_dim == 3 and torch_names[-1] == 'weight' and torch_names[-2] == 'out':
            # multi-head attention output-projection weight: kept as-is
            tensor_value = tensor_value
        elif num_dim == 4 and torch_names[-1] == 'weight':
            # conv kernels: HWIO (JAX) -> OIHW (PyTorch)
            tensor_value = tensor_value.permute(3, 2, 0, 1)

        # print("{}: {}".format(torch_key, tensor_value.shape))
        state_dict[torch_key] = tensor_value
    return state_dict


if __name__ == '__main__':
    save_jax_to_pytorch('/Users/leon/Downloads/jax/imagenet21k+imagenet2012_ViT-L_16-224.npz', '/Users/leon/Downloads/pytorch')
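A minimal usage sketch for these helpers (an added note, not part of the repo; the paths are placeholders, chosen to match the `resume_vit` path in `args_fusion.py`):

```python
# Hypothetical one-off conversion of a JAX ViT checkpoint, then loading it.
from checkpoint import save_jax_to_pytorch, load_checkpoint

# Writes ./imagenet21k+imagenet2012_ViT-L_16.pth (model name taken from the npz filename).
save_jax_to_pytorch('./imagenet21k+imagenet2012_ViT-L_16.npz', '.')

# load_checkpoint() accepts either .npz (converted on the fly) or .pth.
state_dict = load_checkpoint('./imagenet21k+imagenet2012_ViT-L_16.pth')
print(len(state_dict), 'tensors')
```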
--------------------------------------------------------------------------------
/data/kaist-rgbt:
--------------------------------------------------------------------------------
/media/dataset/dataset/KAIST/rgbt-ped-detection/data/kaist-rgbt

--------------------------------------------------------------------------------
/models/model/nestfuse_gray_1e2.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/models/model/nestfuse_gray_1e2.model

--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import math
import random
import pdb


EPSILON = 1e-10

def var(x, dim=0):
    x_zero_meaned = x - x.mean(dim).expand_as(x)
    return x_zero_meaned.pow(2).mean(dim)

class MultConst(nn.Module):
    def forward(self, input):
        return 255*input

class UpsampleReshape_eval(torch.nn.Module):
    def __init__(self):
        super(UpsampleReshape_eval, self).__init__()
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x1, x2):
        x2 = self.up(x2)
        shape_x1 = x1.size()
        shape_x2 = x2.size()
        left = 0
        right = 0
        top = 0
        bot = 0
        if shape_x1[3] != shape_x2[3]:
            lef_right = shape_x1[3] - shape_x2[3]
            if lef_right % 2 == 0:  # was `lef_right%2 is 0.0`, which is never true for ints
                left = int(lef_right/2)
                right = int(lef_right/2)
            else:
                left = int(lef_right / 2)
                right = int(lef_right - left)

        if shape_x1[2] != shape_x2[2]:
            top_bot = shape_x1[2] - shape_x2[2]
            if top_bot % 2 == 0:  # was `top_bot%2 is 0.0`
                top = int(top_bot/2)
                bot = int(top_bot/2)
            else:
                top = int(top_bot / 2)
                bot = int(top_bot - top)

        reflection_padding = [left, right, top, bot]
        reflection_pad = nn.ReflectionPad2d(reflection_padding)
        x2 = reflection_pad(x2)
        return x2

# Dense convolution unit
class DenseConv2d(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(DenseConv2d, self).__init__()
        self.dense_conv = ConvLayer(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        out = self.dense_conv(x)
        out = torch.cat([x, out], 1)
        return out

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class qkv_transform(nn.Conv1d):
    """Conv1d for qkv_transform"""

    def _make_layer(self, block, planes, blocks, kernel_size=56, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, groups=self.groups,
                            base_width=self.base_width, dilation=previous_dilation,
                            norm_layer=norm_layer, kernel_size=kernel_size))
        self.inplanes = planes * block.expansion
        if stride != 1:
            kernel_size = kernel_size // 2

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer, kernel_size=kernel_size))

        return nn.Sequential(*layers)

# Dense Block unit
# light version
class DenseBlock_light(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(DenseBlock_light, self).__init__()
        # out_channels_def = 16
        out_channels_def = int(in_channels / 2)
        # out_channels_def = out_channels
        denseblock = []
        denseblock += [ConvLayer(in_channels, out_channels_def, kernel_size, stride),
                       ConvLayer(out_channels_def, out_channels, 1, stride)]
        self.denseblock = nn.Sequential(*denseblock)

    def forward(self, x):
        out = self.denseblock(x)
        return out

class ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False):
        super(ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.dropout = nn.Dropout2d(p=0.5)
        self.is_last = is_last

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        if not self.is_last:  # was `self.is_last is False`
            # out = F.normalize(out)
            out = F.relu(out, inplace=True)
            # out = self.dropout(out)
        return out

# Convolution operation
class f_ConvLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, is_last=False):
        super(f_ConvLayer, self).__init__()
        reflection_padding = int(np.floor(kernel_size / 2))
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        #self.batch_norm = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(p=0.5)
        self.is_last = is_last

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        #out = self.batch_norm(out)
        out = F.relu(out, inplace=True)
        return out

class FusionBlock_res(torch.nn.Module):
    def __init__(self, channels, img_size, index):
        super(FusionBlock_res, self).__init__()

        # AxialBlock: axial (row + column) self-attention block; its definition
        # is part of the full repository and is not shown in this listing.
        self.axial_attn = AxialBlock(channels, channels//2, kernel_size=img_size)

        self.axial_fusion = nn.Sequential(f_ConvLayer(2*channels, channels, 1, 1))
        self.conv_fusion = nn.Sequential(f_ConvLayer(channels, channels, 1, 1))
        #self.conv_fusion_bn = nn.BatchNorm2d(channels)


        block = []
        block += [f_ConvLayer(2*channels, channels, 1, 1),
                  f_ConvLayer(channels, channels, 3, 1),
                  f_ConvLayer(channels, channels, 3, 1)]
        self.bottelblock = nn.Sequential(*block)
        #self.block_bn = nn.BatchNorm2d(channels)
        #self.relu = nn.ReLU(inplace=True)

    def forward(self, x_ir, x_vi):
        # initial fusion - conv
        a_cat = torch.cat([self.axial_attn(x_ir), self.axial_attn(x_vi)], 1)
        a_init = self.axial_fusion(a_cat)

        x_cvi = self.conv_fusion(x_vi)
        x_cir = self.conv_fusion(x_ir)

        out = torch.cat([x_cvi, x_cir], 1)
        out = self.bottelblock(out)
        out = a_init + out

        return out


# Fusion network, 4 groups of features
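# Rough per-scale dataflow of FusionBlock_res (added sketch; shapes follow
# the encoder features):
#   x_ir, x_vi : (B, C_k, H_k, W_k) features at scale k
#   a_init = axial_fusion( cat(axial_attn(x_ir), axial_attn(x_vi)) )   # attention path
#   out    = bottelblock( cat(conv_fusion(x_vi), conv_fusion(x_ir)) )  # conv path
#   fused  = a_init + out                                              # residual sum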
class Fusion_network(nn.Module):
    def __init__(self, nC, fs_type):
        super(Fusion_network, self).__init__()
        self.fs_type = fs_type
        img_size = [256,128,64,32]
        #img_size = [84,42,21,10]

        self.fusion_block1 = FusionBlock_res(nC[0], img_size[0], 0)
        self.fusion_block2 = FusionBlock_res(nC[1], img_size[1], 1)
        self.fusion_block3 = FusionBlock_res(nC[2], img_size[2], 2)
        self.fusion_block4 = FusionBlock_res(nC[3], img_size[3], 3)

    def forward(self, en_ir, en_vi):
        f1_0 = self.fusion_block1(en_ir[0], en_vi[0])
        f2_0 = self.fusion_block2(en_ir[1], en_vi[1])
        f3_0 = self.fusion_block3(en_ir[2], en_vi[2])
        f4_0 = self.fusion_block4(en_ir[3], en_vi[3])

        return [f1_0, f2_0, f3_0, f4_0]

class Fusion_ADD(torch.nn.Module):
    def forward(self, en_ir, en_vi):
        temp = en_ir + en_vi
        return temp

class Fusion_AVG(torch.nn.Module):
    def forward(self, en_ir, en_vi):
        temp = (en_ir + en_vi) / 2
        return temp

class Fusion_MAX(torch.nn.Module):
    def forward(self, en_ir, en_vi):
        temp = torch.max(en_ir, en_vi)
        return temp

class Fusion_SPA(torch.nn.Module):
    def forward(self, en_ir, en_vi):
        shape = en_ir.size()
        spatial_type = 'mean'
        # calculate spatial attention
        spatial1 = spatial_attention(en_ir, spatial_type)
        spatial2 = spatial_attention(en_vi, spatial_type)
        # get weight map, soft-max
        spatial_w1 = torch.exp(spatial1) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)
        spatial_w2 = torch.exp(spatial2) / (torch.exp(spatial1) + torch.exp(spatial2) + EPSILON)

        spatial_w1 = spatial_w1.repeat(1, shape[1], 1, 1)
        spatial_w2 = spatial_w2.repeat(1, shape[1], 1, 1)
        tensor_f = spatial_w1 * en_ir + spatial_w2 * en_vi
        return tensor_f

# spatial attention
def spatial_attention(tensor, spatial_type='sum'):
    spatial = []
    if spatial_type == 'mean':  # was `is 'mean'`; identity comparison on strings is fragile
        spatial = tensor.mean(dim=1, keepdim=True)
    elif spatial_type == 'sum':
        spatial = tensor.sum(dim=1, keepdim=True)
    return spatial

# fusion strategy based on nuclear-norm (channel attention from NestFuse)
class Fusion_Nuclear(torch.nn.Module):
    def forward(self, en_ir, en_vi):
        shape = en_ir.size()
        # calculate channel attention
        global_p1 = nuclear_pooling(en_ir)
        global_p2 = nuclear_pooling(en_vi)

        # get weight map
        global_p_w1 = global_p1 / (global_p1 + global_p2 + EPSILON)
        global_p_w2 = global_p2 / (global_p1 + global_p2 + EPSILON)

        global_p_w1 = global_p_w1.repeat(1, 1, shape[2], shape[3])
        global_p_w2 = global_p_w2.repeat(1, 1, shape[2], shape[3])

        tensor_f = global_p_w1 * en_ir + global_p_w2 * en_vi
        return tensor_f

# sum of singular values for each channel
def nuclear_pooling(tensor):
    shape = tensor.size()
    vectors = torch.zeros(1, shape[1], 1, 1).cuda()
    for i in range(shape[1]):
        u, s, v = torch.svd(tensor[0, i, :, :] + EPSILON)
        s_sum = torch.sum(s)
        vectors[0, i, 0, 0] = s_sum
    return vectors

# Fusion strategy (add / avg / max / spa / nuclear)
class Fusion_strategy(nn.Module):
    def __init__(self, fs_type):
        super(Fusion_strategy, self).__init__()
        self.fs_type = fs_type
        self.fusion_add = Fusion_ADD()
        self.fusion_avg = Fusion_AVG()
        self.fusion_max = Fusion_MAX()
Fusion_SPA()
289 |         self.fusion_nuc = Fusion_Nuclear()
290 | 
291 |     def forward(self, en_ir, en_vi):
292 |         if self.fs_type == 'add':
293 |             fusion_operation = self.fusion_add
294 |         elif self.fs_type == 'avg':
295 |             fusion_operation = self.fusion_avg
296 |         elif self.fs_type == 'max':
297 |             fusion_operation = self.fusion_max
298 |         elif self.fs_type == 'spa':
299 |             fusion_operation = self.fusion_spa
300 |         elif self.fs_type == 'nuclear':
301 |             fusion_operation = self.fusion_nuc
302 | 
303 |         f1_0 = fusion_operation(en_ir[0], en_vi[0])
304 |         f2_0 = fusion_operation(en_ir[1], en_vi[1])
305 |         f3_0 = fusion_operation(en_ir[2], en_vi[2])
306 |         f4_0 = fusion_operation(en_ir[3], en_vi[3])
307 |         return [f1_0, f2_0, f3_0, f4_0]
308 | 
309 | 
310 | # NestFuse network - light, no dense connections
311 | class NestFuse_light2_nodense(nn.Module):
312 |     def __init__(self, nb_filter, input_nc=1, output_nc=1, deepsupervision=True):
313 |         super(NestFuse_light2_nodense, self).__init__()
314 |         self.deepsupervision = deepsupervision
315 |         block = DenseBlock_light
316 |         output_filter = 16
317 |         kernel_size = 3
318 |         stride = 1
319 | 
320 |         self.pool = nn.MaxPool2d(2, 2)
321 |         self.up = nn.Upsample(scale_factor=2)
322 |         self.up_eval = UpsampleReshape_eval()
323 | 
324 |         # encoder
325 |         self.conv0 = ConvLayer(input_nc, output_filter, 1, stride)
326 |         self.DB1_0 = block(output_filter, nb_filter[0], kernel_size, 1)
327 |         self.DB2_0 = block(nb_filter[0], nb_filter[1], kernel_size, 1)
328 |         self.DB3_0 = block(nb_filter[1], nb_filter[2], kernel_size, 1)
329 |         self.DB4_0 = block(nb_filter[2], nb_filter[3], kernel_size, 1)
330 | 
331 |         # decoder
332 |         self.DB1_1 = block(nb_filter[0] + nb_filter[1], nb_filter[0], kernel_size, 1)
333 |         self.DB2_1 = block(nb_filter[1] + nb_filter[2], nb_filter[1], kernel_size, 1)
334 |         self.DB3_1 = block(nb_filter[2] + nb_filter[3], nb_filter[2], kernel_size, 1)
335 | 
336 |         # short connection
337 |         self.DB1_2 = block(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], kernel_size, 1)
338 |         self.DB2_2 = block(nb_filter[1] * 2 + nb_filter[2], nb_filter[1], kernel_size, 1)
339 |         self.DB1_3 = block(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], kernel_size, 1)
340 | 
341 |         if self.deepsupervision:
342 |             self.conv1 = ConvLayer(nb_filter[0], output_nc, 1, stride)
343 |             self.conv2 = ConvLayer(nb_filter[0], output_nc, 1, stride)
344 |             self.conv3 = ConvLayer(nb_filter[0], output_nc, 1, stride)
345 |             # self.conv4 = ConvLayer(nb_filter[0], output_nc, 1, stride)
346 |         else:
347 |             self.conv_out = ConvLayer(nb_filter[0], output_nc, 1, stride)
348 | 
349 |     def encoder(self, input):
350 |         x = self.conv0(input)
351 |         x1_0 = self.DB1_0(x)
352 |         x2_0 = self.DB2_0(self.pool(x1_0))
353 |         x3_0 = self.DB3_0(self.pool(x2_0))
354 |         x4_0 = self.DB4_0(self.pool(x3_0))
355 |         # x5_0 = self.DB5_0(self.pool(x4_0))
356 |         return [x1_0, x2_0, x3_0, x4_0]
357 | 
358 |     def decoder_train(self, f_en):
359 |         x1_1 = self.DB1_1(torch.cat([f_en[0], self.up(f_en[1])], 1))
360 | 
361 |         x2_1 = self.DB2_1(torch.cat([f_en[1], self.up(f_en[2])], 1))
362 |         x1_2 = self.DB1_2(torch.cat([f_en[0], x1_1, self.up(x2_1)], 1))
363 | 
364 |         x3_1 = self.DB3_1(torch.cat([f_en[2], self.up(f_en[3])], 1))
365 |         x2_2 = self.DB2_2(torch.cat([f_en[1], x2_1, self.up(x3_1)], 1))
366 |         x1_3 = self.DB1_3(torch.cat([f_en[0], x1_1, x1_2, self.up(x2_2)], 1))
367 | 
368 |         if self.deepsupervision:
369 |             output1 = self.conv1(x1_1)
370 |             output2 = self.conv2(x1_2)
371 |             output3 = self.conv3(x1_3)
372 |             # output4 = self.conv4(x1_4)
373 |             return [output1, output2, output3]
374 |         else:
375 |             output = 
self.conv_out(x1_3) 376 | return [output] 377 | 378 | def decoder_eval(self, f_en): 379 | x1_1 = self.DB1_1(torch.cat([f_en[0], self.up_eval(f_en[0], f_en[1])], 1)) 380 | 381 | x2_1 = self.DB2_1(torch.cat([f_en[1], self.up_eval(f_en[1], f_en[2])], 1)) 382 | x1_2 = self.DB1_2(torch.cat([f_en[0], x1_1, self.up_eval(f_en[0], x2_1)], 1)) 383 | 384 | x3_1 = self.DB3_1(torch.cat([f_en[2], self.up_eval(f_en[2], f_en[3])], 1)) 385 | x2_2 = self.DB2_2(torch.cat([f_en[1], x2_1, self.up_eval(f_en[1], x3_1)], 1)) 386 | 387 | x1_3 = self.DB1_3(torch.cat([f_en[0], x1_1, x1_2, self.up_eval(f_en[0], x2_2)], 1)) 388 | 389 | if self.deepsupervision: 390 | output1 = self.conv1(x1_1) 391 | output2 = self.conv2(x1_2) 392 | output3 = self.conv3(x1_3) 393 | # output4 = self.conv4(x1_4) 394 | return [output1, output2, output3] 395 | else: 396 | output = self.conv_out(x1_3) 397 | return [output] 398 | 399 | class RFN_decoder(nn.Module): 400 | def __init__(self, nb_filter, input_nc=1, output_nc=1, deepsupervision=True): 401 | super(RFN_decoder, self).__init__() 402 | self.deepsupervision = deepsupervision 403 | block = DenseBlock_light 404 | output_filter = 16 405 | kernel_size = 3 406 | stride = 1 407 | 408 | self.pool = nn.MaxPool2d(2, 2) 409 | self.up = nn.Upsample(scale_factor=2) 410 | self.up_eval = UpsampleReshape_eval() 411 | 412 | # decoder 413 | self.DB1_1 = block(nb_filter[0] + nb_filter[1], nb_filter[0], kernel_size, 1) 414 | self.DB2_1 = block(nb_filter[1] + nb_filter[2], nb_filter[1], kernel_size, 1) 415 | self.DB3_1 = block(nb_filter[2] + nb_filter[3], nb_filter[2], kernel_size, 1) 416 | 417 | # short connection 418 | self.DB1_2 = block(nb_filter[0] * 2 + nb_filter[1], nb_filter[0], kernel_size, 1) 419 | self.DB2_2 = block(nb_filter[1] * 2+ nb_filter[2], nb_filter[1], kernel_size, 1) 420 | self.DB1_3 = block(nb_filter[0] * 3 + nb_filter[1], nb_filter[0], kernel_size, 1) 421 | 422 | if self.deepsupervision: 423 | self.conv1 = ConvLayer(nb_filter[0], output_nc, 1, stride) 424 | self.conv2 = ConvLayer(nb_filter[0], output_nc, 1, stride) 425 | self.conv3 = ConvLayer(nb_filter[0], output_nc, 1, stride) 426 | # self.conv4 = ConvLayer(nb_filter[0], output_nc, 1, stride) 427 | else: 428 | self.conv_out = ConvLayer(nb_filter[0], output_nc, 1, stride) 429 | 430 | def decoder_train(self, f_en): 431 | x1_1 = self.DB1_1(torch.cat([f_en[0], self.up(f_en[1])], 1)) 432 | 433 | x2_1 = self.DB2_1(torch.cat([f_en[1], self.up(f_en[2])], 1)) 434 | x1_2 = self.DB1_2(torch.cat([f_en[0], x1_1, self.up(x2_1)], 1)) 435 | 436 | x3_1 = self.DB3_1(torch.cat([f_en[2], self.up(f_en[3])], 1)) 437 | x2_2 = self.DB2_2(torch.cat([f_en[1], x2_1, self.up(x3_1)], 1)) 438 | x1_3 = self.DB1_3(torch.cat([f_en[0], x1_1, x1_2, self.up(x2_2)], 1)) 439 | 440 | if self.deepsupervision: 441 | output1 = self.conv1(x1_1) 442 | output2 = self.conv2(x1_2) 443 | output3 = self.conv3(x1_3) 444 | # output4 = self.conv4(x1_4) 445 | return [output1, output2, output3] 446 | else: 447 | output = self.conv_out(x1_3) 448 | return [output] 449 | 450 | def decoder_eval(self, f_en): 451 | x1_1 = self.DB1_1(torch.cat([f_en[0], self.up_eval(f_en[0], f_en[1])], 1)) 452 | 453 | x2_1 = self.DB2_1(torch.cat([f_en[1], self.up_eval(f_en[1], f_en[2])], 1)) 454 | x1_2 = self.DB1_2(torch.cat([f_en[0], x1_1, self.up_eval(f_en[0], x2_1)], 1)) 455 | 456 | x3_1 = self.DB3_1(torch.cat([f_en[2], self.up_eval(f_en[2], f_en[3])], 1)) 457 | x2_2 = self.DB2_2(torch.cat([f_en[1], x2_1, self.up_eval(f_en[1], x3_1)], 1)) 458 | 459 | x1_3 = self.DB1_3(torch.cat([f_en[0], x1_1, x1_2, 
self.up_eval(f_en[0], x2_2)], 1)) 460 | 461 | if self.deepsupervision: 462 | output1 = self.conv1(x1_1) 463 | output2 = self.conv2(x1_2) 464 | output3 = self.conv3(x1_3) 465 | # output4 = self.conv4(x1_4) 466 | return [output1, output2, output3] 467 | else: 468 | output = self.conv_out(x1_3) 469 | return [output] 470 | 471 | class AxialAttention(nn.Module): 472 | def __init__(self, in_planes, out_planes, groups=8, kernel_size=56, 473 | stride=1, bias=False, width=False): 474 | assert (in_planes % groups == 0) and (out_planes % groups == 0) 475 | super(AxialAttention, self).__init__() 476 | self.in_planes = in_planes 477 | self.out_planes = out_planes 478 | self.groups = groups 479 | self.group_planes = out_planes // groups 480 | self.kernel_size = kernel_size 481 | self.stride = stride 482 | self.bias = bias 483 | self.width = width 484 | 485 | # Multi-head self attention 486 | self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1, 487 | padding=0, bias=False) 488 | self.bn_qkv = nn.BatchNorm1d(out_planes * 2) 489 | self.bn_similarity = nn.BatchNorm2d(groups * 3) 490 | 491 | self.bn_output = nn.BatchNorm1d(out_planes * 2) 492 | 493 | # Position embedding 494 | self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True) 495 | query_index = torch.arange(kernel_size).unsqueeze(0) 496 | key_index = torch.arange(kernel_size).unsqueeze(1) 497 | relative_index = key_index - query_index + kernel_size - 1 498 | self.register_buffer('flatten_index', relative_index.view(-1)) 499 | if stride > 1: 500 | self.pooling = nn.AvgPool2d(stride, stride=stride) 501 | 502 | self.reset_parameters() 503 | 504 | def forward(self, x): 505 | 506 | if self.width: 507 | x = x.permute(0, 2, 1, 3) 508 | else: 509 | x = x.permute(0, 3, 1, 2) # N, W, C, H 510 | N, W, C, H = x.shape 511 | x = x.contiguous().view(N * W, C, H) 512 | 513 | # Transformations 514 | qkv = self.bn_qkv(self.qkv_transform(x)) 515 | q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H), [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2) 516 | 517 | # Calculate position embedding 518 | all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2, self.kernel_size, self.kernel_size) 519 | q_embedding, k_embedding, v_embedding = torch.split(all_embeddings, [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=0) 520 | #pdb.set_trace() 521 | 522 | qr = torch.einsum('bgci,cij->bgij', q, q_embedding) 523 | kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3) 524 | 525 | qk = torch.einsum('bgci, bgcj->bgij', q, k) 526 | 527 | stacked_similarity = torch.cat([qk, qr, kr], dim=1) 528 | stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1) 529 | #stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk) 530 | # (N, groups, H, H, W) 531 | similarity = F.softmax(stacked_similarity, dim=3) 532 | sv = torch.einsum('bgij,bgcj->bgci', similarity, v) 533 | sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding) 534 | stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H) 535 | output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2) 536 | 537 | if self.width: 538 | output = output.permute(0, 2, 1, 3) 539 | else: 540 | output = output.permute(0, 2, 3, 1) 541 | 542 | if self.stride > 1: 543 | output = self.pooling(output) 544 | 545 | return output 
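    # --- Editor's note, not part of the original source: shape walk-through of the
    # forward pass above, for the height-attention case (width=False), assuming
    # N=1, C=in_planes=out_planes=64, H=W=kernel_size=64, groups=8 (group_planes=8):
    #   x:           (N, C, H, W) -> permute + reshape -> (N*W, C, H)
    #   qkv:         (N*W, 2*out_planes, H), split per group into
    #                q, k: (N*W, groups, group_planes//2, H) and
    #                v:    (N*W, groups, group_planes,    H)
    #   qr, kr, qk:  (N*W, groups, H, H) -- content-position and content-content
    #                similarities from the einsums
    #   similarity:  softmax over the last axis (key positions along H)
    #   sv, sve:     (N*W, groups, group_planes, H) attention-weighted values
    #   output:      reshaped and permuted back to (N, out_planes, H, W)
    # The learned self.relative table supplies relative position embeddings for
    # q, k and v along the attended axis.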
546 | 
547 |     def reset_parameters(self):
548 |         self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
549 |         #nn.init.uniform_(self.relative, -0.1, 0.1)
550 |         nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))
551 | 
552 | class AxialBlock(nn.Module):
553 |     expansion = 2
554 | 
555 |     def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
556 |                  base_width=64, dilation=1, norm_layer=None, kernel_size=56):
557 |         super(AxialBlock, self).__init__()
558 |         if norm_layer is None:
559 |             norm_layer = nn.BatchNorm2d
560 |         width = int(planes * (base_width / 64.))
561 |         # Both self.width_block and self.downsample layers downsample the input when stride != 1
562 |         self.conv_down = conv1x1(inplanes, width)
563 |         self.bn1 = norm_layer(width)
564 |         self.hight_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size)  # "hight" kept as-is: renaming it would change the state_dict keys
565 |         self.width_block = AxialAttention(width, width, groups=groups, kernel_size=kernel_size, stride=stride, width=True)
566 |         self.conv_up = conv1x1(width, planes * self.expansion)
567 |         self.bn2 = norm_layer(planes * self.expansion)
568 |         self.relu = nn.ReLU(inplace=True)
569 |         self.downsample = downsample
570 |         self.stride = stride
571 | 
572 |     def forward(self, x):
573 |         identity = x
574 | 
575 |         out = self.conv_down(x)
576 |         out = self.bn1(out)
577 |         out = self.relu(out)
578 |         # print(out.shape)
579 |         out = self.hight_block(out)
580 |         out = self.width_block(out)
581 |         out = self.relu(out)
582 | 
583 |         out = self.conv_up(out)
584 |         out = self.bn2(out)
585 | 
586 |         if self.downsample is not None:
587 |             identity = self.downsample(x)
588 | 
589 |         out += identity
590 |         out = self.relu(out)
591 | 
592 |         return out
593 | 
--------------------------------------------------------------------------------
/pytorch_msssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from math import exp
4 | import numpy as np
5 | 
6 | 
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 |     return gauss/gauss.sum()
10 | 
11 | 
12 | def create_window(window_size, channel=1):
13 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
14 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
15 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
16 |     return window
17 | 
18 | 
19 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
20 |     # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
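    # --- Editor's note, not part of the original source: L below is the dynamic
    # range of the inputs; it enters SSIM only through the stabilising constants
    #     C1 = (0.01 * L) ** 2,   C2 = (0.03 * L) ** 2.
    # E.g. for images in [0, 255]: ssim(img1, img2, val_range=255); for images
    # scaled to [0, 1]: ssim(img1, img2, val_range=1). When val_range is None the
    # range is guessed from the tensor values, which can misfire on low-contrast
    # inputs, so passing it explicitly is safer.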
21 |     if val_range is None:
22 |         if torch.max(img1) > 128:
23 |             max_val = 255
24 |         else:
25 |             max_val = 1
26 | 
27 |         if torch.min(img1) < -0.5:
28 |             min_val = -1
29 |         else:
30 |             min_val = 0
31 |         L = max_val - min_val
32 |     else:
33 |         L = val_range
34 | 
35 |     padd = 0
36 |     (_, channel, height, width) = img1.size()
37 |     if window is None:
38 |         real_size = min(window_size, height, width)
39 |         window = create_window(real_size, channel=channel).to(img1.device)
40 | 
41 |     mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
42 |     mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
43 | 
44 |     mu1_sq = mu1.pow(2)
45 |     mu2_sq = mu2.pow(2)
46 |     mu1_mu2 = mu1 * mu2
47 | 
48 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
49 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
50 |     sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
51 | 
52 |     C1 = (0.01 * L) ** 2
53 |     C2 = (0.03 * L) ** 2
54 | 
55 |     v1 = 2.0 * sigma12 + C2
56 |     v2 = sigma1_sq + sigma2_sq + C2
57 |     cs = torch.mean(v1 / v2)  # contrast sensitivity
58 | 
59 |     ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
60 | 
61 |     if size_average:
62 |         ret = ssim_map.mean()
63 |     else:
64 |         ret = ssim_map.mean(1).mean(1).mean(1)
65 | 
66 |     if full:
67 |         return ret, cs
68 |     return ret
69 | 
70 | 
71 | def msssim(img1, img2, window_size=11, size_average=True, val_range=None, normalize=False):
72 |     device = img1.device
73 |     weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333]).to(device)
74 |     levels = weights.size()[0]
75 |     mssim = []
76 |     mcs = []
77 |     for _ in range(levels):
78 |         sim, cs = ssim(img1, img2, window_size=window_size, size_average=size_average, full=True, val_range=val_range)
79 |         mssim.append(sim)
80 |         mcs.append(cs)
81 | 
82 |         img1 = F.avg_pool2d(img1, (2, 2))
83 |         img2 = F.avg_pool2d(img2, (2, 2))
84 | 
85 |     mssim = torch.stack(mssim)
86 |     mcs = torch.stack(mcs)
87 | 
88 |     # Normalize (to avoid NaNs when training unstable models; not compliant with the original MS-SSIM definition)
89 |     if normalize:
90 |         mssim = (mssim + 1) / 2
91 |         mcs = (mcs + 1) / 2
92 | 
93 |     pow1 = mcs ** weights
94 |     pow2 = mssim ** weights
95 |     # From Matlab implementation https://ece.uwaterloo.ca/~z70wang/research/iwssim/
96 |     output = torch.prod(pow1[:-1]) * pow2[-1]  # product of the contrast terms at the coarse scales times the final-scale SSIM term
97 |     return output
98 | 
99 | 
100 | # Classes to re-use window
101 | class SSIM(torch.nn.Module):
102 |     def __init__(self, window_size=11, size_average=True, val_range=None):
103 |         super(SSIM, self).__init__()
104 |         self.window_size = window_size
105 |         self.size_average = size_average
106 |         self.val_range = val_range
107 | 
108 |         # Assume 1 channel for SSIM
109 |         self.channel = 1
110 |         self.window = create_window(window_size)
111 | 
112 |     def forward(self, img1, img2):
113 |         (_, channel, _, _) = img1.size()
114 | 
115 |         if channel == self.channel and self.window.dtype == img1.dtype:
116 |             window = self.window
117 |         else:
118 |             window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
119 |             self.window = window
120 |             self.channel = channel
121 | 
122 |         return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
123 | 
124 | class MSSSIM(torch.nn.Module):
125 |     def __init__(self, window_size=11, size_average=True, channel=3):
126 |         super(MSSSIM, self).__init__()
127 |         self.window_size = window_size
128 |         self.size_average = size_average
129 |         self.channel = channel
130 | 
131 |     def forward(self, img1, img2):
132 |         # TODO: store 
window between calls if possible
133 |         return msssim(img1, img2, window_size=self.window_size, size_average=self.size_average)
134 | 
--------------------------------------------------------------------------------
/pytorch_msssim/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/pytorch_msssim/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/pytorch_msssim/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vibashan/Image-Fusion-Transformer/53db2ab720d57d8a160426cc1dfc9e605d88cddf/pytorch_msssim/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/reference.bib:
--------------------------------------------------------------------------------
1 | @inproceedings{vs2022image,
2 |   title={Image fusion transformer},
3 |   author={Vs, Vibashan and Valanarasu, Jeya Maria Jose and Oza, Poojan and Patel, Vishal M},
4 |   booktitle={2022 IEEE International Conference on Image Processing (ICIP)},
5 |   pages={3566--3570},
6 |   year={2022},
7 |   organization={IEEE}
8 | }
9 | 
--------------------------------------------------------------------------------
/test_21pairs_axial.py:
--------------------------------------------------------------------------------
1 | # test phase
2 | import os
3 | import torch
4 | from torch.autograd import Variable
5 | from net import NestFuse_light2_nodense, Fusion_network, Fusion_strategy, RFN_decoder
6 | import utils
7 | from args_fusion import args
8 | import numpy as np
9 | # from vit_model import VisionTransformer  # unused import, commented out so the script runs without vit_model.py
10 | 
11 | def load_model(path_auto, path_fusion, fs_type, flag_img):
12 |     if flag_img:
13 |         nc = 3
14 |     else:
15 |         nc = 1
16 |     input_nc = nc
17 |     output_nc = nc
18 |     nb_filter = [64, 112, 160, 208]
19 | 
20 |     nest_model = NestFuse_light2_nodense(nb_filter, input_nc, output_nc, deepsupervision=False)
21 |     nest_model.load_state_dict(torch.load(path_auto))
22 | 
23 |     fusion_model = Fusion_network(nb_filter, fs_type)
24 |     fusion_model.load_state_dict(torch.load("models/train/fusionnet/6.0/fusion_axial/fusion_axial.model"))  # note: this hard-coded path takes precedence; the path_fusion argument is ignored here
25 | 
26 |     fusion_strategy = Fusion_strategy(fs_type)
27 | 
28 |     para = sum([np.prod(list(p.size())) for p in nest_model.parameters()])
29 |     type_size = 4
30 |     print('Model {} : params: {:.4f}M'.format(nest_model._get_name(), para * type_size / 1000 / 1000))
31 | 
32 |     para = sum([np.prod(list(p.size())) for p in fusion_model.parameters()])
33 |     type_size = 4
34 |     print('Model {} : params: {:.4f}M'.format(fusion_model._get_name(), para * type_size / 1000 / 1000))
35 | 
36 |     nest_model.eval()
37 |     fusion_model.eval()
38 |     #trans_model.eval()
39 |     nest_model.cuda()
40 |     fusion_model.cuda()
41 |     #trans_model.cuda()
42 | 
43 |     return nest_model, fusion_model, fusion_strategy
44 | 
45 | 
46 | def run_demo(nest_model, fusion_model, fusion_strategy, infrared_path, visible_path, output_path_root, name_ir, fs_type, use_strategy, flag_img, alpha):
47 |     img_ir, h, w, c = utils.get_test_image(infrared_path, flag=flag_img)  # True for rgb
48 |     img_vi, h, w, c = utils.get_test_image(visible_path, flag=flag_img)
49 | 
50 |     # dim = img_ir.shape
51 |     if c == 1:
52 |         if args.cuda:
53 |             img_ir = img_ir.cuda()
54 |             img_vi = img_vi.cuda()
55 |         img_ir = Variable(img_ir, 
requires_grad=False)
56 |         img_vi = Variable(img_vi, requires_grad=False)
57 |         print(img_ir.shape)
58 |         # encoder
59 |         en_r = nest_model.encoder(img_ir)
60 |         en_v = nest_model.encoder(img_vi)
61 |         # fusion
62 |         f = fusion_model(en_r, en_v)
63 |         # decoder
64 |         img_fusion_list = nest_model.decoder_eval(f)
65 |     else:
66 |         # fuse each block
67 |         img_fusion_blocks = []
68 |         for i in range(c):
69 |             # encoder
70 |             img_vi_temp = img_vi[i]
71 |             img_ir_temp = img_ir[i]
72 |             if args.cuda:
73 |                 img_vi_temp = img_vi_temp.cuda()
74 |                 img_ir_temp = img_ir_temp.cuda()
75 |             img_vi_temp = Variable(img_vi_temp, requires_grad=False)
76 |             img_ir_temp = Variable(img_ir_temp, requires_grad=False)
77 | 
78 |             # encoder (encode the current block, not the full image)
79 |             en_r = nest_model.encoder(img_ir_temp)
80 |             en_v = nest_model.encoder(img_vi_temp)
81 |             # fusion
82 |             f = fusion_model(en_r, en_v)
83 |             # decoder
84 |             img_fusion_temp = nest_model.decoder_eval(f)
85 |             img_fusion_blocks.append(img_fusion_temp)
86 |         img_fusion_list = utils.recons_fusion_images(img_fusion_blocks, h, w)
87 | 
88 |     # ########################### multi-outputs ##############################################
89 |     output_count = 0
90 |     for img_fusion in img_fusion_list:
91 |         file_name = 'fused_' + alpha + '_' + name_ir
92 |         output_path = output_path_root + file_name
93 |         output_count += 1
94 |         # save images
95 |         utils.save_image_test(img_fusion, output_path)
96 |         print(output_path)
97 | 
98 | def main():
99 |     # False - gray
100 |     flag_img = False
101 |     # ################# gray scale ########################################
102 |     test_path = "images/21_pairs_tno/ir/"
103 |     path_auto = args.resume_nestfuse
104 |     output_path_root = "./outputs/alpha_1e4_21/"
105 |     if not os.path.exists(output_path_root):
106 |         os.mkdir(output_path_root)
107 | 
108 |     fs_type = 'res'  # res (RFN), add, avg, max, spa, nuclear
109 |     use_strategy = False  # True - static strategy; False - RFN
110 | 
111 |     path_fusion_root = args.fusion_model
112 | 
113 |     with torch.no_grad():
114 |         # alpha_list = [2500, 5000, 15000, 20000, 25000]
115 |         alpha_list = [700]
116 |         w_all_list = [[6.0, 3.0]]
117 | 
118 |         for alpha in alpha_list:
119 |             for w_all in w_all_list:
120 |                 w, w2 = w_all
121 | 
122 |                 temp = 'rfnnest_' + str(alpha) + '_wir_' + str(w) + '_wvi_' + str(w2) + 'axial'
123 |                 output_path_list = 'fused_' + temp + '_21' + '_' + fs_type
124 |                 output_path1 = output_path_root + output_path_list + '/'
125 |                 if not os.path.exists(output_path1):
126 |                     os.mkdir(output_path1)
127 |                 output_path = output_path1
128 |                 # load network
129 |                 path_fusion = path_fusion_root + str(w) + '/' + 'Final_epoch_2_alpha_' + str(alpha) + '_wir_' + str(w) + '_wvi_' + str(w2) + '_ssim_vi.model'
130 |                 model, fusion_model, fusion_strategy = load_model(path_auto, path_fusion, fs_type, flag_img)
131 |                 imgs_paths_ir, names = utils.list_images(test_path)
132 |                 num = len(imgs_paths_ir)
133 |                 for i in range(num):
134 |                     name_ir = names[i]
135 |                     infrared_path = imgs_paths_ir[i]
136 |                     visible_path = infrared_path.replace('ir/', 'vis/')
137 |                     if 'IR' in visible_path:
138 |                         visible_path = visible_path.replace('IR', 'VIS')
139 |                     else:
140 |                         visible_path = visible_path.replace('i.', 'v.')
141 |                     run_demo(model, fusion_model, fusion_strategy, infrared_path, visible_path, output_path, name_ir, fs_type, use_strategy, flag_img, temp)
142 |     print('Done......')
143 | 
144 | if __name__ == '__main__':
145 |     main()
146 | 
--------------------------------------------------------------------------------
/train_fusionnet_axial.py:
--------------------------------------------------------------------------------
1 | # Training an IFT network
2 | 
3 | import os
4 | 
5 | import sys
6 | import time
7 | from tqdm import tqdm, trange
8 | import scipy.io as scio
9 | import random
10 | import torch
11 | import pdb
12 | import torch.nn as nn
13 | from torch.optim import Adam
14 | from torch.autograd import Variable
15 | import utils
16 | from kornia.filters.sobel import sobel
17 | from net import NestFuse_light2_nodense, Fusion_network, RFN_decoder
18 | from checkpoint import load_checkpoint
19 | from args_fusion import args
20 | import pytorch_msssim
21 | #from vit_model import VisionTransformer
22 | 
23 | import warnings
24 | warnings.filterwarnings("ignore")
25 | 
26 | EPSILON = 1e-5
27 | 
28 | def main():
29 |     original_imgs_path, _ = utils.list_images(args.dataset_ir)
30 |     train_num = 80000
31 |     original_imgs_path = original_imgs_path[:train_num]
32 |     random.shuffle(original_imgs_path)
33 |     # True - RGB, False - gray
34 |     img_flag = False
35 |     alpha_list = [700]
36 |     w_all_list = [[6.0, 3.0]]
37 | 
38 |     for w_w in w_all_list:
39 |         w1, w2 = w_w
40 |         for alpha in alpha_list:
41 |             train(original_imgs_path, img_flag, alpha, w1, w2)
42 | 
43 | 
44 | def train(original_imgs_path, img_flag, alpha, w1, w2):
45 | 
46 |     batch_size = args.batch_size
47 |     # load network model
48 |     nc = 1
49 |     input_nc = nc
50 |     output_nc = nc
51 |     #nb_filter = [64, 128, 256, 512]
52 |     nb_filter = [64, 112, 160, 208]
53 |     f_type = 'res'
54 | 
55 | 
56 |     with torch.no_grad():
57 |         deepsupervision = False
58 |         nest_model = NestFuse_light2_nodense(nb_filter, input_nc, output_nc, deepsupervision)
59 |         model_path = args.resume_nestfuse
60 |         # load auto-encoder network
61 |         print('Resuming, initializing auto-encoder using weights from {}.'.format(model_path))
62 |         nest_model.load_state_dict(torch.load(model_path))
63 |         nest_model.cuda()
64 |         nest_model.eval()
65 | 
66 |     # fusion network
67 |     fusion_model = Fusion_network(nb_filter, f_type)
68 |     fusion_model.cuda()
69 |     fusion_model.train()
70 | 
71 |     if args.resume_fusion_model is not None:
72 |         print('Resuming, initializing fusion net using weights from {}.'.format(args.resume_fusion_model))
73 |         fusion_model.load_state_dict(torch.load(args.resume_fusion_model))
74 |     optimizer = Adam(fusion_model.parameters(), args.lr)
75 |     mse_loss = torch.nn.MSELoss()
76 |     ssim_loss = pytorch_msssim.msssim
77 | 
78 | 
79 |     tbar = trange(args.epochs)
80 |     print('Start training.....')
81 |     mode = args.mode
82 |     print(mode)
83 |     # creating save paths (os.makedirs handles the nested directories)
84 |     temp_path_model = os.path.join(args.save_fusion_model)
85 |     temp_path_loss = os.path.join(args.save_loss_dir)
86 |     if not os.path.exists(temp_path_model):
87 |         os.makedirs(temp_path_model)
88 | 
89 |     if not os.path.exists(temp_path_loss):
90 |         os.makedirs(temp_path_loss)
91 | 
92 |     temp_path_model_w = os.path.join(args.save_fusion_model, str(w1), mode)
93 |     temp_path_loss_w = os.path.join(args.save_loss_dir, str(w1))
94 |     if not os.path.exists(temp_path_model_w):
95 |         os.makedirs(temp_path_model_w)
96 | 
97 |     if not os.path.exists(temp_path_loss_w):
98 |         os.makedirs(temp_path_loss_w)
99 | 
100 |     Loss_feature = []
101 |     Loss_ssim = []
102 |     Loss_all = []
103 |     count_loss = 0
104 |     all_ssim_loss = 0.
105 |     all_fea_loss = 0.
106 |     sobel_loss = nn.L1Loss()
107 |     for e in tbar:
108 |         print('Epoch %d.....' 
% e) 109 | # load training database 110 | image_set_ir, batches = utils.load_dataset(original_imgs_path, batch_size) 111 | count = 0 112 | nest_model.cuda() 113 | #trans_model.cuda() 114 | fusion_model.cuda() 115 | for batch in range(batches): 116 | image_paths_ir = image_set_ir[batch * batch_size:(batch * batch_size + batch_size)] 117 | img_ir = utils.get_train_images(image_paths_ir, height=args.HEIGHT, width=args.WIDTH, flag=img_flag) 118 | 119 | image_paths_vi = [x.replace('lwir', 'visible') for x in image_paths_ir] 120 | img_vi = utils.get_train_images(image_paths_vi, height=args.HEIGHT, width=args.WIDTH, flag=img_flag) 121 | 122 | count += 1 123 | optimizer.zero_grad() 124 | 125 | img_ir = Variable(img_ir, requires_grad=False) 126 | img_vi = Variable(img_vi, requires_grad=False) 127 | 128 | img_ir = img_ir.cuda() 129 | img_vi = img_vi.cuda() 130 | 131 | # encoder 132 | en_ir = nest_model.encoder(img_ir) 133 | en_vi = nest_model.encoder(img_vi) 134 | # fusion 135 | f = fusion_model(en_ir, en_vi) 136 | # decoder 137 | outputs = nest_model.decoder_eval(f) 138 | 139 | # resolution loss: between fusion image and visible image 140 | x_ir = Variable(img_ir.data.clone(), requires_grad=False) 141 | x_vi = Variable(img_vi.data.clone(), requires_grad=False) 142 | 143 | ######################### LOSS FUNCTION ######################### 144 | loss1_value = 0. 145 | loss2_value = 0. 146 | for output in outputs: 147 | output = (output - torch.min(output)) / (torch.max(output) - torch.min(output) + EPSILON) 148 | output = output * 255 149 | # ---------------------- LOSS IMAGES ------------------------------------ 150 | # detail loss 151 | ssim_loss_temp2 = ssim_loss(output, x_vi, normalize=True) 152 | loss1_value += alpha * (1 - ssim_loss_temp2) 153 | 154 | # feature loss 155 | g2_ir_fea = en_ir 156 | g2_vi_fea = en_vi 157 | g2_fuse_fea = f 158 | 159 | w_ir = [w1, w1, w1, w1] 160 | w_vi = [w2, w2, w2, w2] 161 | w_fea = [1, 10, 100, 1000] 162 | for ii in range(4): 163 | g2_ir_temp = g2_ir_fea[ii] 164 | g2_vi_temp = g2_vi_fea[ii] 165 | g2_fuse_temp = g2_fuse_fea[ii] 166 | (bt, cht, ht, wt) = g2_ir_temp.size() 167 | loss2_value += w_fea[ii]*mse_loss(g2_fuse_temp, w_ir[ii]*g2_ir_temp + w_vi[ii]*g2_vi_temp) 168 | 169 | loss1_value /= len(outputs) 170 | loss2_value /= len(outputs) 171 | 172 | total_loss = loss1_value + loss2_value 173 | total_loss.backward() 174 | optimizer.step() 175 | 176 | all_fea_loss += loss2_value.item() # 177 | all_ssim_loss += loss1_value.item() # 178 | if (batch + 1) % args.log_interval == 0: 179 | mesg = "{}\t Alpha: {} \tW-IR: {}\tEpoch {}:\t[{}/{}]\t ssim loss: {:.6f}\t fea loss: {:.6f}\t total: {:.6f}".format( 180 | time.ctime(), alpha, w1, e + 1, count, batches, 181 | all_ssim_loss / args.log_interval, 182 | all_fea_loss / args.log_interval, 183 | (all_fea_loss + all_ssim_loss) / args.log_interval 184 | ) 185 | tbar.set_description(mesg) 186 | Loss_ssim.append( all_ssim_loss / args.log_interval) 187 | Loss_feature.append(all_fea_loss / args.log_interval) 188 | Loss_all.append((all_fea_loss + all_ssim_loss) / args.log_interval) 189 | count_loss = count_loss + 1 190 | all_ssim_loss = 0. 191 | all_fea_loss = 0. 
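        # --- Editor's note, not part of the original source: the objective
        # optimised above is, per output,
        #     total = alpha * (1 - MS_SSIM(fused, visible))
        #             + sum_i w_fea[i] * MSE(f_i, w1 * ir_i + w2 * vi_i),  i = 0..3,
        # i.e. a detail term pulling the fused image towards the visible image and
        # a feature term pulling each fused scale towards a weighted blend of the
        # infrared and visible encoder features; w_fea = [1, 10, 100, 1000] weights
        # the deeper scales more heavily. main() above uses alpha = 700 and
        # (w1, w2) = (6.0, 3.0).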
192 | 
193 |     # save model
194 |     save_model_filename = mode + ".model"
195 |     save_model_path = os.path.join(temp_path_model_w, save_model_filename)
196 |     torch.save(fusion_model.state_dict(), save_model_path)
197 | 
198 |     print("\nDone, trained model saved at", save_model_path)
199 | 
200 | 
201 | def check_paths(args):
202 |     try:
203 |         if not os.path.exists(args.vgg_model_dir):
204 |             os.makedirs(args.vgg_model_dir)
205 |         if not os.path.exists(args.save_model_dir):
206 |             os.makedirs(args.save_model_dir)
207 |     except OSError as e:
208 |         print(e)
209 |         sys.exit(1)
210 | 
211 | 
212 | if __name__ == "__main__":
213 |     main()
214 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import numpy as np
4 | import torch
5 | from args_fusion import args
6 | from scipy.misc import imread, imsave, imresize  # note: these were removed from modern SciPy; an old SciPy (~1.1) or an imageio/Pillow shim is required
7 | import matplotlib as mpl
8 | import pdb
9 | import cv2
10 | 
11 | from os import listdir
12 | from os.path import join
13 | 
14 | EPSILON = 1e-5
15 | 
16 | ###### Testing ########
17 | '''
18 | def list_images(directory):
19 |     images = []
20 |     names = []
21 |     dir = listdir(directory)
22 |     dir.sort()
23 |     for file in dir:
24 |         name = file
25 |         if name.endswith('.png'):
26 |             images.append(join(directory, file))
27 |         elif name.endswith('.jpg'):
28 |             images.append(join(directory, file))
29 |         elif name.endswith('.jpeg'):
30 |             images.append(join(directory, file))
31 |         elif name.endswith('.bmp'):
32 |             images.append(join(directory, file))
33 |         elif name.endswith('.tif'):
34 |             images.append(join(directory, file))
35 |         # name1 = name.split('.')
36 |         names.append(name)
37 |     return images, names
38 | 
39 | ###### Training ########
40 | '''
41 | def list_images(directory):  # walks the KAIST training layout (set/seq/lwir); for flat test folders use the "Testing" variant above
42 |     print(directory)
43 |     images = []
44 |     names = []
45 |     dir = os.listdir(directory)
46 |     dir.sort()
47 |     for dir0 in dir:
48 |         #print(directory, dir0)
49 |         #print(os.listdir(os.path.join(directory, dir0)).sort())
50 |         for dir1 in os.listdir(os.path.join(directory, dir0)):
51 |             req_path = os.path.join(directory, dir0, dir1, 'lwir')
52 |             for file in os.listdir(req_path):
53 |                 name = file
54 |                 if name.endswith('.png'):
55 |                     images.append(join(req_path, file))
56 |                 elif name.endswith('.jpg'):
57 |                     images.append(join(req_path, file))
58 |                 elif name.endswith('.jpeg'):
59 |                     images.append(join(req_path, file))
60 |                 elif name.endswith('.bmp'):
61 |                     images.append(join(req_path, file))
62 |                 elif name.endswith('.tif'):
63 |                     images.append(join(req_path, file))
64 |                 names.append(name)
65 |                 #print(join(req_path, file), name)
66 |     return images, names
67 | 
68 | # load training images
69 | def load_dataset(image_path, BATCH_SIZE, num_imgs=None):
70 |     if num_imgs is None:
71 |         num_imgs = len(image_path)
72 |     original_imgs_path = image_path[:num_imgs]
73 |     # random
74 |     random.shuffle(original_imgs_path)
75 |     mod = num_imgs % BATCH_SIZE
76 |     print('BATCH SIZE %d.' % BATCH_SIZE)
77 |     print('Train images number %d.' % num_imgs)
78 |     print('Batches per epoch %s.' 
% str(num_imgs / BATCH_SIZE)) 79 | 80 | if mod > 0: 81 | print('Train set has been trimmed %d samples...\n' % mod) 82 | original_imgs_path = original_imgs_path[:-mod] 83 | batches = int(len(original_imgs_path) // BATCH_SIZE) 84 | return original_imgs_path, batches 85 | 86 | 87 | def get_image(path, height=256, width=256, flag=False): 88 | if flag is True: 89 | image = imread(path, mode='RGB') 90 | else: 91 | image = imread(path, mode='L') 92 | 93 | if height is not None and width is not None: 94 | image = imresize(image, [height, width], interp='bicubic') 95 | return image 96 | 97 | 98 | # load images - test phase 99 | def get_test_image(paths, height=256, width=256, flag=False): 100 | if isinstance(paths, str): 101 | paths = [paths] 102 | images = [] 103 | for path in paths: 104 | if flag is True: 105 | image = imread(path, mode='RGB') 106 | else: 107 | image = imread(path, mode='L') 108 | # get saliency part 109 | if height is not None and width is not None: 110 | #image = imresize(image, [height, width], interp='nearest') 111 | image = imresize(image, [height, width], interp='bicubic') 112 | 113 | base_size = 512 114 | h = image.shape[0] 115 | w = image.shape[1] 116 | c = 1 117 | if h > base_size or w > base_size: 118 | c = 4 119 | if flag is True: 120 | image = np.transpose(image, (2, 0, 1)) 121 | else: 122 | image = np.reshape(image, [1, h, w]) 123 | images = get_img_parts(image, h, w) 124 | else: 125 | if flag is True: 126 | image = np.transpose(image, (2, 0, 1)) 127 | else: 128 | image = np.reshape(image, [1, image.shape[0], image.shape[1]]) 129 | images.append(image) 130 | images = np.stack(images, axis=0) 131 | images = torch.from_numpy(images).float() 132 | 133 | return images, h, w, c 134 | 135 | 136 | def get_img_parts(image, h, w): 137 | images = [] 138 | h_cen = int(np.floor(h / 2)) 139 | w_cen = int(np.floor(w / 2)) 140 | img1 = image[:, 0:h_cen + 3, 0: w_cen + 3] 141 | img1 = np.reshape(img1, [1, img1.shape[0], img1.shape[1], img1.shape[2]]) 142 | img2 = image[:, 0:h_cen + 3, w_cen - 2: w] 143 | img2 = np.reshape(img2, [1, img2.shape[0], img2.shape[1], img2.shape[2]]) 144 | img3 = image[:, h_cen - 2:h, 0: w_cen + 3] 145 | img3 = np.reshape(img3, [1, img3.shape[0], img3.shape[1], img3.shape[2]]) 146 | img4 = image[:, h_cen - 2:h, w_cen - 2: w] 147 | img4 = np.reshape(img4, [1, img4.shape[0], img4.shape[1], img4.shape[2]]) 148 | images.append(torch.from_numpy(img1).float()) 149 | images.append(torch.from_numpy(img2).float()) 150 | images.append(torch.from_numpy(img3).float()) 151 | images.append(torch.from_numpy(img4).float()) 152 | return images 153 | 154 | def recons_fusion_images(img_lists, h, w): 155 | img_f_list = [] 156 | h_cen = int(np.floor(h / 2)) 157 | w_cen = int(np.floor(w / 2)) 158 | c = img_lists[0][0].shape[1] 159 | ones_temp = torch.ones(1, c, h, w).cuda() 160 | for i in range(len(img_lists[0])): 161 | # img1, img2, img3, img4 162 | img1 = img_lists[0][i] 163 | img2 = img_lists[1][i] 164 | img3 = img_lists[2][i] 165 | img4 = img_lists[3][i] 166 | 167 | img_f = torch.zeros(1, c, h, w).cuda() 168 | count = torch.zeros(1, c, h, w).cuda() 169 | 170 | img_f[:, :, 0:h_cen + 3, 0: w_cen + 3] += img1 171 | count[:, :, 0:h_cen + 3, 0: w_cen + 3] += ones_temp[:, :, 0:h_cen + 3, 0: w_cen + 3] 172 | img_f[:, :, 0:h_cen + 3, w_cen - 2: w] += img2 173 | count[:, :, 0:h_cen + 3, w_cen - 2: w] += ones_temp[:, :, 0:h_cen + 3, w_cen - 2: w] 174 | img_f[:, :, h_cen - 2:h, 0: w_cen + 3] += img3 175 | count[:, :, h_cen - 2:h, 0: w_cen + 3] += ones_temp[:, :, h_cen - 2:h, 0: w_cen 
+ 3] 176 | img_f[:, :, h_cen - 2:h, w_cen - 2: w] += img4 177 | count[:, :, h_cen - 2:h, w_cen - 2: w] += ones_temp[:, :, h_cen - 2:h, w_cen - 2: w] 178 | img_f = img_f / count 179 | img_f_list.append(img_f) 180 | return img_f_list 181 | 182 | def save_image_test(img_fusion, output_path): 183 | img_fusion = img_fusion.float() 184 | if args.cuda: 185 | img_fusion = img_fusion.cpu().data[0].numpy() 186 | else: 187 | img_fusion = img_fusion.clamp(0, 255).data[0].numpy() 188 | 189 | img_fusion = (img_fusion - np.min(img_fusion)) / (np.max(img_fusion) - np.min(img_fusion) + EPSILON) 190 | img_fusion = img_fusion * 255 191 | img_fusion = img_fusion.transpose(1, 2, 0).astype('uint8') 192 | if img_fusion.shape[2] == 1: 193 | img_fusion = img_fusion.reshape([img_fusion.shape[0], img_fusion.shape[1]]) 194 | img_fusion = cv2.resize(img_fusion, (256,256), interpolation = cv2.INTER_CUBIC) 195 | imsave(output_path, img_fusion) 196 | 197 | 198 | def get_train_images(paths, height=256, width=256, flag=False): 199 | if isinstance(paths, str): 200 | paths = [paths] 201 | images = [] 202 | for path in paths: 203 | image = get_image(path, height, width, flag) 204 | if flag is True: 205 | image = np.transpose(image, (2, 0, 1)) 206 | else: 207 | image = np.reshape(image, [1, height, width]) 208 | images.append(image) 209 | 210 | images = np.stack(images, axis=0) 211 | images = torch.from_numpy(images).float() 212 | return images 213 | --------------------------------------------------------------------------------
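Editor's usage sketch (appended; not part of the original repository dump): the pieces above compose into a short inference loop. This is a minimal, hedged example assuming a CUDA machine, 256x256 grayscale inputs, the fusion checkpoint path hard-coded in test_21pairs_axial.py, and an auto-encoder checkpoint path that is normally supplied via args.resume_nestfuse; the image paths are placeholders.

import torch
from net import NestFuse_light2_nodense, Fusion_network
import utils

nb_filter = [64, 112, 160, 208]
nest_model = NestFuse_light2_nodense(nb_filter, input_nc=1, output_nc=1, deepsupervision=False)
nest_model.load_state_dict(torch.load('models/model/nestfuse_gray_1e2.model'))  # pre-trained auto-encoder (assumed path)
fusion_model = Fusion_network(nb_filter, 'res')
fusion_model.load_state_dict(torch.load('models/train/fusionnet/6.0/fusion_axial/fusion_axial.model'))  # path from test_21pairs_axial.py
nest_model.cuda().eval()
fusion_model.cuda().eval()

with torch.no_grad():
    # get_test_image resizes to 256x256 grayscale and returns (tensor, h, w, c)
    img_ir, h, w, c = utils.get_test_image('images/ir/IR1.png')    # placeholder path
    img_vi, _, _, _ = utils.get_test_image('images/vis/VIS1.png')  # placeholder path
    en_ir = nest_model.encoder(img_ir.cuda())  # 4 feature scales per modality
    en_vi = nest_model.encoder(img_vi.cuda())
    f = fusion_model(en_ir, en_vi)             # axial-attention fusion per scale
    out = nest_model.decoder_eval(f)[0]        # single fused image tensor
    utils.save_image_test(out, 'fused_IR1.png')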