├── common
    ├── GetMeanStdVar.m
    ├── SetSCDLParams.m
    ├── col2im_forSpatiotemporalFusion.m
    ├── im2col_forSpatiotemporalFusion.m
    └── im2col_forSpatiotemporalFusion_onlyCenterPoints.m
├── data
    ├── L7SR.05-24-01.r-g-nir.tif
    ├── L7SR.07-11-01.r-g-nir.tif
    ├── L7SR.08-12-01.r-g-nir.tif
    ├── MOD09GHK.05-24-01.r-g-nir.tif
    ├── MOD09GHK.07-11-01.r-g-nir.tif
    └── MOD09GHK.08-12-01.r-g-nir.tif
├── model
    └── Params_red_5x128x32.mat
├── predict
    ├── CSDL_predict.m
    ├── Compute_NLM_Matrix.m
    ├── Im2Patch.m
    ├── Low_rank_appro.m
    ├── ebscdl_interp.m
    ├── setPatchIdx.m
    └── soft.m
├── readme.txt
├── result
    └── readme.txt
└── train
    ├── My_kmeans.m
    ├── Proc_cls_idx.m
    ├── coupled_DL.m
    ├── dict_train.m
    └── kmeans_1.m

/common/GetMeanStdVar.m:
--------------------------------------------------------------------------------
1 | function [m,stdVar] = GetMeanStdVar(A)
2 | % GetMeanStdVar  Per-band mean and standard deviation of an image stack.
3 | %   [m, stdVar] = GetMeanStdVar(A) treats every slice A(:,:,k) as one band
4 | %   and returns 1-by-size(A,3) row vectors. The deviation is normalised by
5 | %   the number of elements (population form), not by numel-1.
6 | depths=size(A,3);
7 | A=double(A);
8 | m = zeros(1, depths);        % preallocated; also defines the outputs when A has no bands
9 | stdVar = zeros(1, depths);
10 | for color=1:depths
11 |     S=A(:,:,color);
12 |     m(color)=mean(S(:));
13 |     stdVar(color) = sqrt(sum(sum((S-m(color)).^2))/numel(S));
14 |     % fprintf(' band %d: mean:%f std_var:%f\n', color, m(color), stdVar(color));
15 | end
--------------------------------------------------------------------------------
/common/SetSCDLParams.m:
--------------------------------------------------------------------------------
1 | function param = SetSCDLParams(win, K)
2 | % SetSCDLParams  Parameters for semi-coupled dictionary learning.
3 | %   param = SetSCDLParams() returns the defaults (win = 5, K = 256).
4 | %   param = SetSCDLParams(win, K) uses the given patch width and dictionary
5 | %   size. step, L, psf, lassoParam.K and lassoParam.L are all derived from
6 | %   win/K here, so pass the values in rather than overwriting param.win or
7 | %   param.K afterwards (which would leave those derived fields stale).
8 | if nargin < 1 || isempty(win)
9 |     win = 5;
10 | end
11 | if nargin < 2 || isempty(K)
12 |     K = 256; %512;
13 | end
14 | param.win = win;
15 | param.step = floor(param.win/2);
16 | param.rho = 5e-2;
17 | param.lambda1 = 0.01;
18 | param.lambda2 = 0.1;
19 | param.mu = 0.01;
20 | param.sqrtmu = sqrt(param.mu);
21 | param.nu = 0.1;
22 | param.nIter = 10;
23 | param.epsilon = 5e-3;
24 | param.t0 = 5;
25 | param.K = K;
26 | param.L = param.win * param.win;
27 | param.psf = fspecial('gaussian', param.win+2, 2.2);   % (win+2)-by-(win+2) Gaussian, sigma 2.2
28 | param.nClass = 40;
29 | param.nKmeansIters = 1000;
30 |
31 | lassoParam.K = param.K;
32 | lassoParam.lambda = param.lambda1;
33 | lassoParam.iter=100;
34 | lassoParam.L = param.win * param.win;
35 | lassoParam.verbose = false;
36 | lassoParam.numThreads = -1; % mt
37 |
38 | param.lassoParam = lassoParam;
--------------------------------------------------------------------------------
/common/col2im_forSpatiotemporalFusion.m:
--------------------------------------------------------------------------------
1 | function im = col2im_forSpatiotemporalFusion(patches, winsize)
2 | % col2im_forSpatiotemporalFusion  Rebuild an image from overlapping patches.
3 | %   patches: one column per patch; row 2 = centre row (y), row 3 = centre
4 | %            column (x), rows 4:end = the patch pixels (reshaped to winsize
5 | %            column-major). Row 1 is not read here.
6 | %   winsize: [height width] of a patch (odd sizes expected).
7 | %   Overlapping pixels are averaged; pixels covered by no patch are 0 and a
8 | %   warning is printed. The output extends to the bottom-most and
9 | %   right-most patch borders.
10 |
11 | blockHeight = winsize(1); halfBlockHeight = floor(blockHeight/2); blockHeight = halfBlockHeight*2+1;
12 | blockWidth = winsize(2); halfBlockWidth = floor(blockWidth/2); blockWidth = halfBlockWidth*2+1;
13 |
14 | if isempty(patches)
15 |     im = zeros(0, 0);
16 |     return
17 | end
18 | % use the extreme centres rather than the last column, so the result does
19 | % not depend on the order in which the patches were generated
20 | height = max(patches(2, :)) + halfBlockHeight;
21 | width = max(patches(3, :)) + halfBlockWidth;
22 |
23 | im = zeros(height, width);
24 | count = zeros(height, width);
25 |
26 | for i = 1:size(patches, 2)
27 |     center_y = patches(2,i);
28 |     center_x = patches(3,i);
29 |     block = reshape(patches(4:end,i), winsize);
30 |     region_y = center_y-halfBlockHeight:center_y+halfBlockHeight;
31 |     region_x = center_x-halfBlockWidth:center_x+halfBlockWidth;
32 |     im(region_y, region_x) = im(region_y, region_x) + block;
33 |     count(region_y, region_x) = count(region_y, region_x) + 1;
34 | end
35 |
36 | % count(:) matters: any() on a matrix works column-wise, and "if" on the
37 | % resulting row vector is only true when EVERY column contains a gap
38 | if any(count(:) < 1)
39 |     fprintf('Warning: not all pixels are divided into patches. Check im2col function for reason.\n');
40 |     count(count<1) = 1;   % uncovered pixels stay 0 instead of becoming 0/0 = NaN
41 | end
42 | im = im ./ count;
--------------------------------------------------------------------------------
/common/im2col_forSpatiotemporalFusion.m:
--------------------------------------------------------------------------------
1 | function [patches]=im2col_forSpatiotemporalFusion(X, winsize, step, resRate)
2 |
3 | % step must be no larger than window size to avoid missing blocks
4 | if step(1)>winsize(1)
5 |     step(1) = winsize(1);
6 | end
7 | if step(2)>winsize(2)
8 |     step(2) = winsize(2);
9 | end
10 |
11 | if nargin<4
12 |     resRate = 500/30; % default for modis/landsat
13 | end
14 |
15 | [height,width] = size(X);
16 | % winsize=[17 17]; % an odd number
17 | blockHeight = winsize(1); halfBlockHeight = floor(blockHeight/2); blockHeight = halfBlockHeight*2+1;
18 | blockWidth = winsize(2); halfBlockWidth = floor(blockWidth/2); blockWidth = halfBlockWidth*2+1;
19 | if any([height width] < [blockHeight blockWidth]) % if neighborhood is larger than image
20 |     patches = zeros(blockHeight*blockWidth,0);
21 |     return
22 | end
23 |
24 | % centers of exactly mapped blocks
25 | centers_x = round(resRate/2+1:resRate:width);
26 | centers_y = round(resRate/2+1:resRate:height);
27 | centers_x(centers_x-halfBlockWidth<1 | centers_x+halfBlockWidth>width)=[];
28 | centers_y(centers_y-halfBlockHeight<1 | centers_y+halfBlockHeight>height)=[];
29 |
30 | % centers_y = [halfBlockHeight+1 centers_y height-halfBlockHeight];
31 | points_y = []; % centers of all blocks
32 | for i=1:length(centers_y)-1
33 |     % c = (centers_y(i)+centers_y(i+1))/2;
34 |     % fr = c-step(1)/2:-step(1):centers_y(i)+2;
35 |     % r = floor([fr(end:-1:1) c+step(1)/2:step(1):centers_y(i+1)-2]);
36 |     r = centers_y(i)+step(1):step(1):centers_y(i+1)-1;
37 |     if (~isempty(r))
38 |         offset = step(1)-1+centers_y(i+1)-r(end)-1;
39 |         offset = ceil(offset/2);
40 |         r = r - step(1) + 1 + offset;
41 |     end
42 |     points_y = [points_y centers_y(i) r];
43 | end
44 | fr = centers_y(1):-step(1):halfBlockHeight+1;
45 | points_y = [fr(end:-1:2) points_y centers_y(end):step(1):height-halfBlockHeight]; % add first and last
46 | if points_y(1)>halfBlockHeight+1
47 |     points_y = [halfBlockHeight+1 points_y]; % make sure that all pixels are divided into blocks
48 | end
49 | if points_y(end)halfBlockWidth+1
71 |     points_x = [halfBlockWidth+1 points_x]; % make sure that all pixels are divided into blocks
72 | end
73 | if points_x(end)width)=[];
20 | centers_y(centers_y-halfBlockHeight<1 | centers_y+halfBlockHeight>height)=[];
21 |
22 | patches=zeros(blockHeight*blockWidth, length(centers_y)*length(centers_x));
23 | for iy=1:length(centers_y)
24 |     for ix=1:length(centers_x)
25 |         block = X(centers_y(iy)-halfBlockHeight:centers_y(iy)+halfBlockHeight, centers_x(ix)-halfBlockWidth:centers_x(ix)+halfBlockWidth);
26 |         patches(:,ix+(iy-1)*length(centers_x)) = block(:);
27 |     end
28 | end
29 |
--------------------------------------------------------------------------------
/data/L7SR.05-24-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/L7SR.05-24-01.r-g-nir.tif
--------------------------------------------------------------------------------
/data/L7SR.07-11-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/L7SR.07-11-01.r-g-nir.tif
--------------------------------------------------------------------------------
/data/L7SR.08-12-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/L7SR.08-12-01.r-g-nir.tif
--------------------------------------------------------------------------------
/data/MOD09GHK.05-24-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/MOD09GHK.05-24-01.r-g-nir.tif
--------------------------------------------------------------------------------
/data/MOD09GHK.07-11-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/MOD09GHK.07-11-01.r-g-nir.tif
--------------------------------------------------------------------------------
/data/MOD09GHK.08-12-01.r-g-nir.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/data/MOD09GHK.08-12-01.r-g-nir.tif
--------------------------------------------------------------------------------
/model/Params_red_5x128x32.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/model/Params_red_5x128x32.mat
--------------------------------------------------------------------------------
/predict/CSDL_predict.m:
--------------------------------------------------------------------------------
1 | addpath('../common');
2 |
3 | band=1; % adjustable
4 | bandName={'red' 'green' 'nir'};
5 |
6 | resRate=10/3; % adjustable
7 |
8 | % read data
9 | Y01 = double(imread('../data/MOD09GHK.05-24-01.r-g-nir.tif'));
10 | X01 = double(imread('../data/L7SR.05-24-01.r-g-nir.tif'));
11 | Y02 = double(imread('../data/MOD09GHK.07-11-01.r-g-nir.tif'));
12 | X02 = double(imread('../data/L7SR.07-11-01.r-g-nir.tif'));
13 | Y03 = double(imread('../data/MOD09GHK.08-12-01.r-g-nir.tif'));
14 | X03 = double(imread('../data/L7SR.08-12-01.r-g-nir.tif'));
15 |
16 | Y01 = Y01(:,:,band);
17 | Y02 = Y02(:,:,band);
18 | Y03 = Y03(:,:,band);
19 | X01 = X01(:,:,band);
20 | X02 = X02(:,:,band);
21 | X03 = X03(:,:,band);
22 |
23 | [height, width] = size(X01);
24 | regionHeight=400;
25 | regionWidth=400;
26 | nVertRegions=floor(height/regionHeight);
27 | nHoriRegions=floor(width/regionWidth);
28 | region(nVertRegions*nHoriRegions, 1)=struct('stRow', [], 'endRow', [], 'stCol', [], 'endCol', []);
29 | for y=1:nVertRegions
30 |     for x=1:nHoriRegions
31 |         iRegion = (y-1)*nHoriRegions+x;
32 |         region(iRegion).stRow = (y-1)*regionHeight+1;
33 |         region(iRegion).endRow = y*regionHeight;
34 |         region(iRegion).stCol = (x-1)*regionWidth+1;
35 |         region(iRegion).endCol = x*regionWidth;
36 |     end
37 | end
38 |
39 | iRegion=1; % adjustable
40 | XB1=X01(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
41 | XB2=X02(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
42 | XB3=X03(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
43 | YB1=Y01(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
44 | YB2=Y02(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
45 | YB3=Y03(region(iRegion).stRow:region(iRegion).endRow, region(iRegion).stCol:region(iRegion).endCol,:);
46 |
47 | %%%%%%%%%% begin of generating medium image %%%%%%%%%%%%%%%%%%
48 | % filepath=['../model/' 'KMeans_5x5_32_Factor3_2'];
49 | param.win=5;
50 | param.K=128;
51 | nClass=32;
52 | filepath=['../model/Params_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
53 | fprintf('param file: %s\n', filepath);
54 | load(filepath, 'dictVec', 'param');
55 |
56 | param.step=2;
57 | param.dictVec = dictVec;
58 |
59 | % filepath=['../model/' 'Dict_SR_Factor3.mat'];
60 | filepath=['../model/DictSet_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
61 | fprintf('dictionary file: %s\n', filepath);
62 | load(filepath, 'Dict');
63 | param.Dict=Dict;
64 |
65 | param.lambda_3 = 1;
66 | param.lambda_4 = 10000;
67 |
68 | param.resRate = 500/30/resRate;
69 |
70 | % downsample to middle resolution
71 | Y1=imresize(YB1, 1.0/resRate);
72 | Y2=imresize(YB2, 1.0/resRate);
73 | Y3=imresize(YB3, 1.0/resRate);
74 | X1=imresize(XB1, 1.0/resRate);
75 | X2=imresize(XB2, 1.0/resRate); % for psnr evaluation only
76 | X3=imresize(XB3, 1.0/resRate);
77 |
78 | % Measurement matrix
79 | B=im2col_forSpatiotemporalFusion(imresize(X01-X03, 1.0/resRate), [param.win param.win], [param.step param.step], resRate); % NOTE(review): stage 2 passes param.resRate here; confirm resRate is intended
80 | B=B(4:end,:);
81 | B2=im2col_forSpatiotemporalFusion(imresize(Y01-Y03, 1.0/resRate), [param.win param.win], [param.step param.step], resRate);
82 | B2=B2(end-floor(param.win*param.win/2),:);
83 | [U]=pca(B');
84 | [~, pos] = min(sum((B'*U-repmat(B2',1,size(U,2))).^2, 1));
85 | param.M0 = reshape(U(:, pos), 1, []);
86 |
87 | param.originalImage=X2-X1; % only for psnr comparison
88 | XB21m = ebscdl_interp(Y2-Y1, param);%, refRec-X1);
89 |
90 | param.originalImage=X2-X3; % only for psnr comparison
91 | XB23m = ebscdl_interp(Y2-Y3, param);%, refRec-X3);
92 |
93 | X2m = ((X1+XB21m)+(X3+XB23m))/2;
94 |
95 | [psnr, mse, maxerr]=psnr_mse_maxerr(X2m(:), X2(:));
96 | fprintf('\nResult of middle image after regularization: PSNR:%f rmse:%f\n', psnr, sqrt(mse));
97 |
98 | %%%%%%%%%% end of generating medium image %%%%%%%%%%%%%%%%%%
99 | %%
100 | X2m = imresize(X2m, size(XB1));
101 |
102 | param.win=5;
103 | param.K=128;
104 | nClass=32;
105 | % nBands=length(bands);
106 | filepath=['../model/Params_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
107 | fprintf('param file: %s\n', filepath);
108 | load(filepath, 'dictVec', 'param');
109 | param.dictVec = dictVec;
110 |
111 | param.step = param.win-1;
112 | param.lambda1=0.02;
113 | param.lambda2=0;
114 | param.lambda_3 = 1;
115 | param.lambda_4 = 20;
116 |
117 | param.resRate = 500/30;%resRate;
118 |
119 | filepath=['../model/DictSet_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
120 | fprintf('dictionary file: %s\n', filepath);
121 | load(filepath, 'Dict');
122 | param.Dict=Dict;
123 |
124 | % Measurement matrix
125 | B=im2col_forSpatiotemporalFusion(X01-X03, [param.win param.win], [param.step param.step], param.resRate);
126 | B=B(4:end,:);
127 | B2=im2col_forSpatiotemporalFusion(Y01-Y03, [param.win param.win], [param.step param.step], param.resRate);
128 | B2=B2(end-floor(param.win*param.win/2),:);
129 | [U]=pca(B');
130 | [mindif, pos] = min(sum((B'*U-repmat(B2',1,size(U,2))).^2, 1));
131 | param.M0 = reshape(U(:, pos), 1, []);
132 |
133 | param.originalImage=XB2-XB1; % only for psnr comparison
134 | XB21 = ebscdl_interp(YB2-YB1, param, X2m-XB1);
135 |
136 | param.originalImage=XB2-XB3; % only for psnr comparison
137 | X2B3 = ebscdl_interp(YB2-YB3, param, X2m-XB3);
138 | %%
139 | XB2p = ((XB1+XB21)+(XB3+X2B3))/2;
140 |
141 | [psnr, mse, maxerr]=psnr_mse_maxerr(XB2p(:), XB2(:));
142 | fprintf('\nResult of CSSCDL: PSNR:%f rmse:%f\n', psnr, sqrt(mse));
143 |
144 | imwrite(uint16(XB2p),['../result/x2p-CSSF-region' num2str(iRegion) '-rate' num2str(resRate) '-' bandName{band} '-' num2str(psnr) '.tif']);
145 |
146 | rmpath('../common');
--------------------------------------------------------------------------------
/predict/Compute_NLM_Matrix.m:
--------------------------------------------------------------------------------
1 | function [mW, W1] = Compute_NLM_Matrix( im, ws, par )
2 | %----------------------------
3 | % Only for grayscale image
4 | % Apr. 13, 2010
5 | %----------------------------
6 |
7 | S = 12;
8 | f = ws;
9 | t = floor(f/2);
10 | nv = 10; %par.nblk;
11 | hp = 65;
12 |
13 | e_im = padarray( im, [t t], 'symmetric' );
14 | [h w] = size( im );
15 | nt = (nv)*h*w;
16 | R = zeros(nt,1);
17 | C = zeros(nt,1);
18 | V = zeros(nt,1);
19 |
20 | L = h*w;
21 | X = zeros(f*f, L, 'single');
22 |
23 | % For the Y component
24 | k = 0;
25 | for i = 1:f
26 |     for j = 1:f
27 |         k = k+1;
28 |         blk = e_im(i:end-f+i,j:end-f+j);
29 |         X(k,:) = blk(:)';
30 |     end
31 | end
32 |
33 | % Index image
34 | I = reshape((1:L), h, w);
35 | X = X';
36 | f2 = f^2;
37 |
38 | cnt = 1;
39 | for row = 1 : h
40 |     for col = 1 : w
41 |
42 |         off_cen = (col-1)*h + row;
43 |
44 |         rmin = max( row-S, 1 );
45 |         rmax = min( row+S, h );
46 |         cmin = max( col-S, 1 );
47 |         cmax = min( col+S, w );
48 |
49 |         idx = I(rmin:rmax, cmin:cmax);
50 |         idx = idx(:);
51 |         B = X(idx, :);
52 |         v = X(off_cen, :);
53 |
54 |
55 |         dis = (B(:,1) - v(1)).^2;
56 |         for k = 2:f2
57 |             dis = dis + (B(:,k) - v(k)).^2;
58 |         end
59 |         dis = dis./f2;
60 |         [val,ind] = sort(dis);
61 |         dis(ind(1)) = dis(ind(2));
62 |         wei = exp( -dis(ind(1:nv))./hp );
63 |
64 |         R(cnt:cnt+nv-1) = off_cen;
65 |         C(cnt:cnt+nv-1) = idx( ind(1:nv) );
66 |         V(cnt:cnt+nv-1) = wei./(sum(wei)+eps);
67 |         cnt = cnt + nv;
68 |
69 |     end
70 | end
71 | R = R(1:cnt-1);
72 | C = C(1:cnt-1);
73 | V = V(1:cnt-1);
74 | W1 = sparse(R, C, V, h*w, h*w);
75 |
76 | R = zeros(h*w,1);
77 | C = zeros(h*w,1);
78 | V = zeros(h*w,1);
79 |
80 | R(1:end) = 1:h*w;
81 | C(1:end) = 1:h*w;
82 | V(1:end) = 1;
83 | mI = sparse(R,C,V,h*w,h*w);
84 | mW = mI - W1;
85 |
86 |
87 |
88 |
89 |
--------------------------------------------------------------------------------
/predict/Im2Patch.m:
--------------------------------------------------------------------------------
1 | function X = Im2Patch( im, par )
2 | f = par.win;
3 | N = size(im,1)-f+1;
4 | M = size(im,2)-f+1;
5 | L = N*M;
6 | X = zeros(f*f, L, 'single');
7 | k = 0;
8 | for i = 1:f
9 |     for j = 1:f
10 |         k = k+1;
11 |         blk = im(i:end-f+i,j:end-f+j);
12 |         X(k,:) = blk(:)';
13 |     end
14 | end
--------------------------------------------------------------------------------
/predict/Low_rank_appro.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/predict/Low_rank_appro.m
--------------------------------------------------------------------------------
/predict/ebscdl_interp.m:
--------------------------------------------------------------------------------
1 | function hdImage = ebscdl_interp(ldImage, param, initOutImage)
2 | % ldImage: up-sampled low-resolution image
3 | % initOutImage: optional initial estimate (default: ldImage); the refined estimate is returned
4 | if (nargin<3)
5 |     initOutImage = ldImage;
6 | end
7 | hdImage = initOutImage;
8 | % param fields read here: lassoParam, Dict (DH/DL/W cells), dictVec, M0, win, step, resRate, sqrtmu, lambda_3, originalImage, optional psf
9 | lassoParam=param.lassoParam;
10 | Dict=param.Dict;
11 |
12 | nOuterLoop = 5;
13 | nInnerLoop = 1;
14 |
15 | cls_num = size(param.dictVec,1);
16 | [height, width, nBands] = size(ldImage);
17 | nBands = 1;
18 | % only one band is reconstructed, whatever size(ldImage,3) is
19 | M = param.M0;
20 | param.M0=0; %||M-M0||->||M||
21 |
22 | % Used to calculate CSNR
23 | param.height = height;
24 | param.width = width;
25 |
26 | if isfield(param, 'psf') % exist('param.psf','var') was always false: exist() does not look inside structs
27 |     psf = param.psf;
28 | else
29 |     psf = fspecial('gaussian', param.win+2, 2.2);
30 | end
31 |
32 | ldPatches = im2col_forSpatiotemporalFusion(ldImage, [param.win param.win], [param.step param.step], param.resRate);
33 | nBlocks = size(ldPatches, 2);
34 |
35 | nAtoms = size(Dict.DH{1},2);
36 | param.K=nAtoms;
37 | AL = zeros(nAtoms*nBands, nBlocks);
38 | AH = zeros(nAtoms*nBands, nBlocks);
39 |
40 | if (size(param.originalImage,1)>1)
41 |     [psnr, mse, maxerr]=psnr_mse_maxerr(hdImage(:), param.originalImage(:));
42 |     fprintf('initial psnr: %f; rmse: %f\n', psnr, sqrt(mse));
43 | end
44 |
45 | for outIter = 1 : nOuterLoop
46 |     for iter = 1 : nInnerLoop
47 |         fprintf('Iter: %d x %d\n ', outIter,iter);
48 |
49 |         % clustering
50 |         if iter == 1
51 |             XH0 = hdImage;
52 |             % XH = data2patch(conv2(XH0, psf, 'same') - XH0, param);
53 |             XH = im2col_forSpatiotemporalFusion(conv2(XH0, psf, 'same') - XH0, [param.win param.win], [param.step param.step], param.resRate);
54 |             cls_idx = setPatchIdx(XH(4:end,:), param.dictVec');
55 |             clear XH XH0;
56 |         end
57 |
58 |         % normalize both images with the current estimate's mean/std (undone after col2im below)
59 |         [m, stdev] = GetMeanStdVar(hdImage(:));
60 |         hdImage = (hdImage-m)/stdev;
61 |         ldNorm = (ldImage-m)/stdev; % normalised copy: overwriting ldImage compounded the normalisation on every outer iteration
62 |
63 |         % image to blocks
64 |         XH = im2col_forSpatiotemporalFusion(hdImage, [param.win param.win], [param.step param.step], param.resRate); % X: the high-resolution image
65 |         XL = im2col_forSpatiotemporalFusion(ldNorm, [param.win param.win], [param.step param.step], param.resRate); % Y: the low-resolution image
66 |
67 |         meanX = repmat(mean(XH(4:end,:),1), [param.win^2 1]);
68 |         XL(4:end,:) = XL(4:end,:) - meanX;
69 |         XH(4:end,:) = XH(4:end,:) - meanX;
70 |
71 |         for iClass = 1 : cls_num
72 |             idx_cluster = find(cls_idx == iClass);
73 |             length_idx = length(idx_cluster);
74 |             fprintf(' class %d(%d):', iClass, length_idx);
75 |             if (length_idx==0)
76 |                 continue;
77 |             end
78 |             start_idx = [1:10000:length_idx, length_idx+1]; % half-open chunk bounds: chunks neither overlap nor repeat
79 |             for j = 1 : length(start_idx) - 1
80 |                 idx_temp = idx_cluster(start_idx(j):start_idx(j+1)-1);
81 |                 Xl = double(XL(4:end, idx_temp));
82 |                 Xh = double(XH(4:end, idx_temp));
83 |                 Dl = Dict.DL{iClass};
84 |                 Dh = Dict.DH{iClass};
85 |                 W = Dict.W{iClass};
86 |
87 |                 if (iter == 1)
88 |                     alphaL = mexLasso(Xl, Dl, lassoParam);
89 |                     alphaH = W * alphaL;
90 |                     Xh = Dh * alphaH;
91 |                 else
92 |                     alphaH = AH(:, idx_temp);
93 |                 end
94 |
95 |                 alphaL = mexLasso([Xl;param.sqrtmu * full(alphaH)], [Dl; param.sqrtmu * W], lassoParam);
96 |
97 |                 alphaH = mexLasso([Xh;param.sqrtmu * W * full(alphaL)], [Dh; param.sqrtmu * eye(size(alphaH, 1))], lassoParam);
98 |
99 |                 AL(:, idx_temp) = alphaL;
100 |                 AH(:, idx_temp) = alphaH;
101 |                 flag_center = logical(XH(1, idx_temp)==1);
102 |
103 |                 if (any(flag_center))
104 |                     %centers: x = argmin ||x - Dh*alphaH||^2 + lambda_3*||M*x - y||^2, solved in closed form
105 |                     y = Xl(ceil(end/2),flag_center);
106 |                     B = Dh*alphaH(:,flag_center)+param.lambda_3*M'*y;
107 |                     A = eye(size(M,2))+param.lambda_3*(M'*M);
108 |                     Xh = A\B;
109 |                     XH(4:end, idx_temp(flag_center)) = Xh;
110 |                 end
111 |                 if (any(~flag_center))
112 |                     %non-centers:
113 |                     Xh = Dh * alphaH(:,~flag_center);
114 |                     XH(4:end, idx_temp(~flag_center)) = Xh;
115 |                 end
116 |             end
117 |         end
118 |         fprintf('\n');
119 |
120 |         % % update Measurement matrix
121 |         % flag_center = logical(XH(1, :)==1);
122 |         % y = XL(end-floor(param.win*param.win/2), flag_center);
123 |         % X = XH(4:end, flag_center);
124 |         % M = (param.lambda_4*param.M0+y*X')/(X*X'+param.lambda_4*eye(size(X,1)));
125 |
126 |         XH(4:end,:) = XH(4:end,:) + meanX;
127 |         for iBand=1:nBands
128 |             hdImage(:,:,iBand) = col2im_forSpatiotemporalFusion(XH, [param.win param.win]);
129 |         end
130 |
131 |         hdImage = hdImage*stdev+m;
132 |
133 |         %%%%%%%%%%%%%%%%%% regularization %%%%%%%%%%%%%%%%%%%%%
134 |         % fprintf(' start of regularization\n');
135 |
136 |         % nonlocal regularization (disabled)
137 |         % [N, ~] = Compute_NLM_Matrix( hdImage, 5, param);
138 |         % NTN = N'*N*0.05;
139 |         % im_f = sparse(double(reshape(hdImage, [], 1)));
140 |         % for i = 1 : fix(60 / iter.^2)
141 |         % im_f = im_f - NTN*im_f;
142 |         % end
143 |         % hdImage = reshape(full(im_f), height, width);
144 |
145 |         % fprintf(' end of regularization\n');
146 |         %%%%%%%%%%%%%%%%%% regularization %%%%%%%%%%%%%%%%%%%%%
147 |
148 |         if (size(param.originalImage,1)>1)
149 |             [psnr, mse, maxerr]=psnr_mse_maxerr(hdImage(:), param.originalImage(:));
150 |             fprintf(' mapping psnr: %f; rmse: %f\n', psnr, sqrt(mse));
151 |         end
152 |     end
153 | end
154 |
155 |
--------------------------------------------------------------------------------
/predict/setPatchIdx.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjia5220967/CSSFc/ff320f05e52671a9512e6f9d50c95944f4d74937/predict/setPatchIdx.m
--------------------------------------------------------------------------------
/predict/soft.m:
--------------------------------------------------------------------------------
1 | % soft: element-wise soft-thresholding, y = sign(x) .* max(|x| - tau, 0).
2 | function y = soft(x,tau)
3 | shrunk = max(abs(x) - tau, 0);   % magnitudes reduced by tau, clipped at zero
4 | y = shrunk .* sign(x);           % put the original signs back
5 |
--------------------------------------------------------------------------------
/readme.txt:
--------------------------------------------------------------------------------
1 | notice:
2 | This version is for reference ONLY as it is reorganized from the raw version without test.
3 | For better performance of this code, the sequence images should be as similar as possible.
4 |
5 | file structure:
6 | train: code for training
7 | predict: code for prediction
8 | common: auxiliary code for training and prediction
9 | data: MODIS and LandSat-7 images
10 | model: result of training
11 | result: fused images
12 |
13 | Before running the program, you should download and compile the SPAMS software, and add the "build" folder into PATH.
14 |
15 | train script: train/dict_train
16 | prediction script: predict/CSDL_predict
--------------------------------------------------------------------------------
/result/readme.txt:
--------------------------------------------------------------------------------
1 | notice:
2 | This version is for reference ONLY as it is reorganized from the raw version without test.
3 | For better performance of this code, the sequence images should be as similar as possible.
4 |
5 | file structure:
6 | train: code for training
7 | predict: code for prediction
8 | common: auxiliary code for training and prediction
9 | data: MODIS and LandSat-7 images
10 | model: result of training
11 | result: fused images
12 |
13 | Before running the program, you should download and compile the SPAMS software, and add the "build" folder into PATH.
14 |
15 | train script: train/dict_train
16 | prediction script: predict/CSDL_predict
--------------------------------------------------------------------------------
/train/My_kmeans.m:
--------------------------------------------------------------------------------
1 | function [cls_idx,vec,cls_num] = My_kmeans(Y, cls_num, itn)
2 | Y = Y';
3 | [L b2] = size(Y);
4 | P = randperm(L);
5 | P2 = P(1:cls_num);
6 | vec = Y(P2(1:end), :);
7 | m_num = 2000;
8 | % k-means on the columns of the input; at iteration itn-2, classes smaller than m_num are dropped while at least 40 remain
9 | for i = 1 : itn
10 |     cnt = zeros(1, cls_num);
11 |
12 |     v_dis = zeros(L, cls_num);
13 |     for k = 1 : cls_num
14 |         v_dis(:, k) = (Y(:,1) - vec(k,1)).^2;
15 |         for c = 2:b2
16 |             v_dis(:,k) = v_dis(:,k) + (Y(:,c) - vec(k,c)).^2;
17 |         end
18 |     end
19 |
20 |     [val cls_idx] = min(v_dis, [], 2);
21 |
22 |     [s_idx, seg] = Proc_cls_idx( cls_idx );
23 |     for k = 1 : length(seg)-1
24 |         idx = s_idx(seg(k)+1:seg(k+1));
25 |         cls = cls_idx(idx(1));
26 |         vec(cls,:) = mean(Y(idx, :), 1); % dim 1: a one-sample class must keep its vector, not collapse to a scalar
27 |         cnt(cls) = length(idx);
28 |     end
29 |
30 |     if (i==itn-2)
31 |         [val ind] = min( cnt ); % Remove classes with few samples
32 |         while (val < m_num) && (cls_num >= 40) % restored from the garbled "while (val=40)"; m_num was otherwise unused
33 |             vec(ind, :) = [];
34 |             cls_num = cls_num - 1;
35 |             cnt(ind) = [];
36 |             [val ind] = min(cnt);
37 |         end
38 |     end
39 | end
--------------------------------------------------------------------------------
/train/Proc_cls_idx.m:
--------------------------------------------------------------------------------
1 | function [s_idx, seg] = Proc_cls_idx( cls_idx )
2 | % Sort the labels; s_idx(seg(k)+1:seg(k+1)) are the indices of the k-th distinct label (cls_idx must be a column)
3 | [idx s_idx] = sort(cls_idx);
4 |
5 | idx2 = idx(1:end-1) - idx(2:end);
6 | seq = find(idx2);
7 |
8 | seg = [0; seq; length(cls_idx)];
--------------------------------------------------------------------------------
/train/coupled_DL.m:
--------------------------------------------------------------------------------
1 | function [DH, DL, W] = coupled_DL(alphaH, XH, XL, DH, DL, W, par)
2 | % Semi-Coupled Dictionary Learning
3 | % Shenlong Wang
4 | % Reference: S Wang, L Zhang, Y Liang and Q. Pan, "Semi-coupled Dictionary Learning with Applications in Super-resolution and Photo-sketch Synthesis", CVPR 2012
5 | % Inputs: initial high-res codes alphaH, training patches XH/XL (one per column), initial DH/DL/W, and par (rho, lambda1, lambda2, mu, nu, nIter, epsilon, K, L)
6 | [dimX, numX] = size(XH);
7 | dimY = size(alphaH, 1);
8 | numD = size(DH, 2);
9 | rho = par.rho;
10 | lambda1 = par.lambda1;
11 | lambda2 = par.lambda2;
12 | mu = par.mu;
13 | sqrtmu = sqrt(mu);
14 | nu = par.nu;
15 | nIter = par.nIter;
16 | t0 = par.t0;
17 | epsilon = par.epsilon;
18 | dictSize = numD; % update every atom that exists; par.K can disagree with size(DH,2)
19 |
20 | %lasso param
21 | param.lambda = lambda1; % weight of the l1 penalty
22 | param.lambda2 = lambda2;
23 | param.mode = 2; % penalized formulation
24 | param.approx=0;
25 | param.K = par.K;
26 | param.L = par.L;
27 |
28 | f = 0; % cost
29 | for t = 1 : nIter
30 |
31 |     % Alphat = mexLasso(Xt,D,param);
32 |     f_prev = f;
33 |
34 |     alphaL = mexLasso([XL;sqrtmu * full(alphaH)], [DL; sqrtmu * W], param);
35 |     alphaH = mexLasso([XH;sqrtmu * W * full(alphaL)], [DH; sqrtmu * eye(size(alphaH, 1))], param);
36 |
37 |     % Update D with K-SVD
38 |     for i=1:dictSize
39 |         ai = alphaL(i,:);
40 |         Y = XL-DL*alphaL+DL(:,i)*ai;
41 |         di = Y*ai';
42 |         di = di./(norm(di,2) + eps);
43 |         DL(:,i) = di;
44 |     end
45 |     for i=1:dictSize
46 |         ai = alphaH(i,:);
47 |         Y = XH-DH*alphaH+DH(:,i)*ai;
48 |         di = Y*ai';
49 |         di = di./(norm(di,2) + eps);
50 |         DH(:,i) = di;
51 |     end
52 |
53 |     % Update W: ridge least-squares solution blended with the previous W by rho (mrdivide instead of inv)
54 |     % Ws = Alphap * Alphas' * inv(Alphas * Alphas' + par.nu * eye(size(Alphas, 1))) ;
55 |     % Wp = Alphas * Alphap' * inv(Alphap * Alphap' + par.nu * eye(size(Alphap, 1))) ;
56 |     W = (1 - rho) * W + rho * (alphaH * alphaL') / (alphaL * alphaL' + par.nu * eye(size(alphaL, 1))) ;
57 |     % Wp = (1 - rho) * Wp + rho * alphaL * alphaH' * inv(alphaH * alphaH' + par.nu * eye(size(alphaH, 1))) ;
58 |
59 |     % Alpha = pinv(D' * D + lambda2 * eye(numD)) * D' * X;
60 |     P1 = XH - DH * alphaH;
61 |     P1 = P1(:)'*P1(:) / 2;
62 |     P2 = lambda1 * norm(alphaH, 1); % NOTE(review): matrix 1-norm (max column sum), not the element-wise l1; only affects the stopping test
63 |     % P3 = alphaL - Wp * alphaH;
64 |     % P3 = P3(:)'*P3(:) / 2;
65 |     % P4 = nu * norm(Wp, 'fro');
66 |     fp = 1 / 2 * P1 + P2;% + mu * (P3 + P4);
67 |
68 |     P1 = XL - DL * alphaL;
69 |     P1 = P1(:)'*P1(:) / 2;
70 |     P2 = lambda1 * norm(alphaL, 1);
71 |     P3 = alphaH - W * alphaL;
72 |     P3 = P3(:)'*P3(:) / 2;
73 |     P4 = nu * norm(W, 'fro');
74 |     fs = 1 / 2 * P1 + P2 + mu * (P3 + P4);
75 |     f = fp + fs;
76 |     if (abs(f_prev - f) / f < epsilon)
77 |         break;
78 |     end
79 |     % fprintf('Energy: %d\n',f);
80 |     % save tempDict_SR_NL Ds Dp Ws Wp par param i;
81 |     % fprintf('Iter: %d, E1 : %d, E2 : %d, E : %d\n', t, mu * (P1 + P2), (1 - mu) * (P3 + P4), E);
82 | end
83 |
--------------------------------------------------------------------------------
/train/dict_train.m:
--------------------------------------------------------------------------------
1 | % dict_train: k-means-cluster patches of (blur(X31) - X31), then learn a semi-coupled dictionary pair per cluster for one band.
2 | % addpath('../../spams-matlab/build');
3 | addpath('../common');
4 |
5 | % Parameters Setting (win and K are overridden below, so the fields derived from them are re-derived too)
6 | param = SetSCDLParams();
7 | nClass = 32;
8 | param.win = 5; param.L = param.win*param.win; param.step = floor(param.win/2); param.psf = fspecial('gaussian', param.win+2, 2.2);
9 | param.K = 128;
10 | param.lassoParam.K = param.K; param.lassoParam.L = param.L; % otherwise mexTrainDL keeps SetSCDLParams' default K (256 atoms)
11 | Y1 = double(imread('../data/MOD09GHK.05-24-01.r-g-nir.tif'));
12 | X1 = double(imread('../data/L7SR.05-24-01.r-g-nir.tif'));
13 | Y3 = double(imread('../data/MOD09GHK.08-12-01.r-g-nir.tif'));
14 | X3 = double(imread('../data/L7SR.08-12-01.r-g-nir.tif'));
15 |
16 | band=1;
17 | bandName={'red' 'green' 'nir'};
18 | Y1 = Y1(:,:,band);
19 | X1 = X1(:,:,band);
20 | Y3 = Y3(:,:,band);
21 | X3 = X3(:,:,band);
22 | X31 = X3-X1;
23 | Y31 = Y3-Y1;
24 |
25 | XH = im2col(conv2(X31, param.psf, 'same') - X31, [param.win param.win], 'sliding');
26 |
27 | [cls_idx0,dictVec] = kmeans_1(XH, nClass);
28 | dictVec = dictVec';
29 |
30 | filepath=['../model/Params_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
31 | save(filepath, 'dictVec','param');
32 |
33 | XH = im2col(X31, [param.win param.win], 'sliding');
34 | XL = im2col(Y31, [param.win param.win], 'sliding');
35 |
36 | % SCDL
37 | Dict=[];
38 | for iClass = 1 : nClass
39 |     XH_t = XH(:,cls_idx0==iClass);
40 |     XL_t = XL(:,cls_idx0==iClass);
41 |     XH_t = XH_t - repmat(mean(XH_t, 1), [param.win^2 1]);
42 |     XL_t = XL_t - repmat(mean(XL_t, 1), [param.win^2 1]);
43 |     fprintf('dictionary learning: Cluster: %d (%d)\n', iClass, size(XH_t,2));
44 |     D = mexTrainDL([XH_t;XL_t], param.lassoParam);
45 |     Dh = D(1:param.win^2,:);
46 |     Dl = D(param.win^2+1:end,:);
47 |     W = eye(size(Dl, 2));
48 |     alphaH = mexLasso([XH_t;XL_t], D, param.lassoParam);
49 |     alphaL = alphaH;
50 |     fprintf('Semi-Coupled dictionary learning: Cluster: %d (%d)\n', iClass, size(XH_t,2));
51 |     [Dh, Dl, W] = coupled_DL(alphaH, XH_t, XL_t, Dh, Dl, W, param);
52 |     Dict.DH{iClass} = Dh;
53 |     Dict.DL{iClass} = Dl;
54 |     Dict.W{iClass} = W;
55 | end
56 |
57 | filepath=['../model/DictSet_' bandName{band} '_' num2str(param.win) 'x' num2str(param.K) 'x' num2str(nClass)];
58 | save(filepath, 'Dict');
59 |
60 | rmpath('../common');
--------------------------------------------------------------------------------
/train/kmeans_1.m:
--------------------------------------------------------------------------------
1 | function [L,C] = kmeans_1(X,k)
2 | %KMEANS_1 Cluster multivariate data using the k-means++ algorithm.
3 | % [L,C] = kmeans_1(X,k) produces a 1-by-size(X,2) vector L with one class
4 | % label per column in X and a size(X,1)-by-k matrix C containing the
5 | % centers corresponding to each class.
6 | % (Declared as kmeans_1 to match the file name; MATLAB resolves calls by file name.)
7 | % Version: 2013-02-08
8 | % Authors: Laurent Sorber (Laurent.Sorber@cs.kuleuven.be)
9 | %
10 | % References:
11 | % [1] J. B. MacQueen, "Some Methods for Classification and Analysis of
12 | % MultiVariate Observations", in Proc. of the fifth Berkeley
13 | % Symposium on Mathematical Statistics and Probability, L. M. L. Cam
14 | % and J. Neyman, eds., vol. 1, UC Press, 1967, pp. 281-297.
15 | % [2] D. Arthur and S. Vassilvitskii, "k-means++: The Advantages of
16 | % Careful Seeding", Technical Report 2006-13, Stanford InfoLab, 2006.
17 |
18 | L = [];
19 | L1 = 0;
20 |
21 | while length(unique(L)) ~= k
22 |     L1 = 0; % forget the labels of a previous, rejected seeding
23 |     % The k-means++ initialization.
24 |     C = X(:,1+round(rand*(size(X,2)-1)));
25 |     L = ones(1,size(X,2));
26 |     for i = 2:k
27 |         D = X-C(:,L);
28 |         D = cumsum(sqrt(dot(D,D,1)));
29 |         if D(end) == 0, C(:,i:k) = X(:,ones(1,k-i+1)); return; end
30 |         C(:,i) = X(:,find(rand < D/D(end),1));
31 |         [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'),[],1);
32 |     end
33 |
34 |     % The k-means algorithm.
35 |     while any(L ~= L1)
36 |         L1 = L;
37 |         for i = 1:k, l = L==i; C(:,i) = sum(X(:,l),2)/sum(l); end
38 |         [~,L] = max(bsxfun(@minus,2*real(C'*X),dot(C,C,1).'),[],1);
39 |     end
40 |
41 | end
42 |
--------------------------------------------------------------------------------