├── .gitignore ├── L72000306_SZ_B432_30m.tif ├── L72002311_SZ_B432_30m.tif ├── LICENSE ├── MOD09_2000306_SZ_B214_250m.tif ├── MOD09_2002311_SZ_B214_250m.tif ├── README.md ├── STARFM(abandon).py └── STARFM_torch.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /L72000306_SZ_B432_30m.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endu111/remote-sensing-images-fusion/0e8f6ede2cadc3cf9ce26eff557e55e1001294d8/L72000306_SZ_B432_30m.tif -------------------------------------------------------------------------------- /L72002311_SZ_B432_30m.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endu111/remote-sensing-images-fusion/0e8f6ede2cadc3cf9ce26eff557e55e1001294d8/L72002311_SZ_B432_30m.tif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 shx951104 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation 
the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MOD09_2000306_SZ_B214_250m.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endu111/remote-sensing-images-fusion/0e8f6ede2cadc3cf9ce26eff557e55e1001294d8/MOD09_2000306_SZ_B214_250m.tif -------------------------------------------------------------------------------- /MOD09_2002311_SZ_B214_250m.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endu111/remote-sensing-images-fusion/0e8f6ede2cadc3cf9ce26eff557e55e1001294d8/MOD09_2002311_SZ_B214_250m.tif -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # remote-sensing-images-fusion 2 | Remote sensing image fusion: a task that sits between "reconstruction of missing images in remote sensing time series" and "super-resolution". 3 | 4 | Language: Python 5 | 6 | Algorithms implemented so far: 1. STARFM (vectorized for speed; supports block-wise processing of large images) 7 |
# -------------------------------------------------------------------------------- /STARFM(abandon).py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 2 23:35:07 2020

@author: shx

numpy/numba implementation of STARFM (this file is marked "abandon"; the
repository also ships STARFM_torch.py).

Several helpers below read two module-level constants that the driver script
at the end of this file defines before calling ``starfm``:
    window -- moving-window size (odd)
    long   -- row length of the padded image (image width + window - 1)
"""
import math
import numpy as np
from osgeo import gdal
# NOTE(review): the compare_* helpers are not used in this file. They were moved
# to skimage.metrics and later dropped from skimage.measure, so on recent
# scikit-image releases these imports fail with ImportError.
from skimage.measure import compare_psnr
from skimage.measure import compare_ssim
from skimage.measure import compare_nrmse
from skimage.measure import compare_mse
import matplotlib.pyplot as plt
import numba as nb
import copy


##coordination distance
@nb.vectorize()
def xydistance(x, y, xc, yc):
    """Relative spatial distance of (x, y) from the window centre (xc, yc).

    Returns ``1 + euclidean_distance / (window // 2)``. Reads the module-level
    ``window``; numba treats globals as compile-time constants, so ``window``
    must be set before the first call.
    """
    d = 1 + np.sqrt(math.pow(x - xc, 2) + math.pow(y - yc, 2)) / (window // 2)
    return d


###because there are only one windows,we only caculate one distance matrix
def windowdistance(w):
    """Return the flattened (w*w,) relative-distance vector of a w x w window."""
    x = np.arange(w)[:, None] + np.zeros(shape=w)[None, :]
    y = np.arange(w)[None, :] + np.zeros(shape=w)[:, None]
    xc = np.ones(shape=(w, w)) * (w // 2)
    yc = np.ones(shape=(w, w)) * (w // 2)
    d = xydistance(x, y, xc, yc)
    return d.reshape(w * w)


###get similar mask-matrix
@nb.vectorize()
def similar(l0, d):
    """Element-wise 1 if ``|l0| <= d`` else 0."""
    if abs(l0) <= d:
        kk = 1
    else:
        kk = 0
    return kk


###make pixels' indexs of the x-th window
def makeindexs(x):
    """Flat padded-image indexes of the window whose top-left corner is ``x``.

    Element ``b * window + a`` is ``x + a + long * b`` (row offset b, column
    offset a). Uses the module-level ``window`` and ``long``.
    """
    return (np.arange(x, x + window)[:, None] + long * np.arange(window)[None, :]).T.reshape(window * window)


####get all center's windowindexs
def loopindex(number1, number2, size):
    """Window indexes for every pixel of the whole image rows in [number1, number2).

    ``size`` is the length of one image row (the image width). Row ``i``, column
    ``j`` has its window's top-left corner at ``i * (size + window - 1) + j`` in
    the flattened padded image. Returns an int array of shape
    (n_pixels, window * window), equal to stacking ``makeindexs`` over those
    corners.

    Vectorized replacement for the former ``@nb.jit`` list-building loop, which
    called a plain Python function and cannot compile in numba's nopython mode.
    """
    first_row = int(number1 // size)
    stop_row = int(number2 // size)
    rows = np.arange(first_row, stop_row)
    cols = np.arange(size)
    # top-left corner of each window, row-major over the requested pixels
    starts = (rows[:, None] * (size + window - 1) + cols[None, :]).reshape(-1)
    # offset of each window element from its top-left corner (same order as makeindexs)
    offsets = (long * np.arange(window)[:, None] + np.arange(window)[None, :]).reshape(-1)
    return starts[:, None] + offsets[None, :]


###### padding
def makeadd(array, number):
    """Pad ``array`` by ``window // 2`` on every side with the constant ``number``.

    Returns a float array of shape (rows + window - 1, cols + window - 1).
    """
    up = array.shape[0]
    right = array.shape[1]
    a1 = np.ones(shape=(up + window - 1, right + window - 1)) * number
    # slice assignment copies the values into a1, so no deepcopy is required
    a1[window // 2:window // 2 + up, window // 2:window // 2 + right] = array
    return a1


#####make divide parts
def makedivid(numbers, times, size):
    """Split the flat range [0, numbers) into ``times`` parts.

    Every boundary except the final one is a multiple of ``size`` (one image
    row), so each part covers whole rows; the last part absorbs the remainder.
    Returns ``(left, right)``: lists of start and end offsets.
    """
    d1 = int(numbers // times)
    d2 = (d1 // size) * size
    left = []
    right = []
    for i in range(times - 1):
        left.append(i * d2)
        right.append((i + 1) * d2)
    left.append((times - 1) * d2)
    right.append(numbers)
    return left, right


#### main starfm
def starfm(landsat0, modis0, modis1, w, times):
    """Predict the fine image at t1 from (landsat0, modis0) at t0 and modis1 at t1.

    Parameters
    ----------
    landsat0, modis0, modis1 : 2-D arrays of identical shape (rows, cols).
    w : window size; must equal the module-level ``window`` (the padding and
        index helpers read that global).
    times : number of parts the image is processed in, to bound memory use.

    Returns
    -------
    2-D float array of shape (rows, cols).

    Raises
    ------
    ValueError if the assembled prediction does not have rows * cols pixels.
    """
    # image shape
    right = landsat0.shape[1]
    up = landsat0.shape[0]
    # padding: extend the image so border centre-pixels get full windows
    newup = up + w - 1
    newright = right + w - 1
    # Constants chosen so padding pixels never satisfy the similar-pixel test.
    l0 = makeadd(landsat0, -50000).reshape(newup * newright)
    m0 = makeadd(modis0, 20000).reshape(newup * newright)
    m1 = makeadd(modis1, 160000).reshape(newup * newright)
    # Divide into parts of whole rows; a row is `right` (width) pixels long.
    # (Previously `up` was used, which is only correct for square images.)
    left, one = makedivid(up * right, times, right)
    # relative thresholds for the three similar-pixel criteria
    d = [0.1, 0.1, 0.1]
    parts = [dividstarfm(one[i], left[i], l0, m0, m1, right, up, newright, newup, w, d)
             for i in range(times)]
    l1 = np.hstack(parts)

    ##reconstruct image
    if l1.shape[0] != up * right:
        raise ValueError('error: predicted {} pixels, expected {}'.format(l1.shape[0], up * right))
    return l1.reshape(up, right)


######part starfm
def dividstarfm(one, left, l0, m0, m1, right, up, newright, newup, w, d):
    '''Run STARFM on the flattened image pixels [left, one) (whole rows).

    l0, m0, m1 are the flattened padded images built by ``starfm``; ``right`` is
    the image width, ``w`` the window size and ``d`` the three relative
    similar-pixel thresholds. Returns the 1-D predictions for those pixels.
    Intermediate arrays are deleted eagerly to limit peak memory.
    '''
    print('get windowindexs')
    ##get windowindexs (an image row is `right` pixels long)
    windowindex = loopindex(left, one, right)

    print('define window features')
    ##define window features
    f1 = (abs(m1 - m0))[windowindex]
    f2 = (abs(l0 - m0))[windowindex]
    f3 = (l0)[windowindex]

    print('define center features')
    ##define center features
    # The window centre sits at flat offset (w*w)//2, and the padded arrays hold
    # the original pixel there. Centre features are therefore read from the window
    # features rather than from module-level globals. This also keeps the
    # arithmetic in float and avoids unsigned-integer wraparound.
    center = (w * w) // 2
    c1 = f1[:, center:center + 1]
    c2 = f2[:, center:center + 1]
    c3 = f3[:, center:center + 1]

    print('select similar pixels')
    ##select similar pixels
    d1, d2, d3 = d[0], d[1], d[2]
    mask0 = similar(f1 - c1, d1 * c1)
    mask1 = similar(f2 - c2, d2 * c2)
    mask2 = similar(f3 - c3, d3 * c3)
    mask = mask0 * mask1 * mask2
    del mask1, mask2, mask0
    del c1, c2, c3

    print('caculate weights')
    ##caculate weights
    f12 = np.log(abs(f1) + 2)
    f22 = np.log(abs(f2) + 2)
    del f1, f2
    distance = windowdistance(w)
    catt = f12 * f22 * distance
    weight = 1 / catt
    del f12, f22, distance
    weight = weight * mask
    data = ((l0 + m1 - m0)[windowindex]) * mask
    del mask, f3
    normweight = weight / np.sum(weight, axis=1).reshape(weight.shape[0], 1)

    print('caculate aim landsat')
    ##caculate aim landsat
    l1 = np.sum(np.multiply(data, normweight), axis=1)
    return l1


##read data(after pre-step:same projection,same .....),if dont,please use gdal,py6s.....
landsatfir = 'L72000306_SZ_B432_30m.tif'
modisfir = 'MOD09_2000306_SZ_B214_250m.tif'
aimlandsatfir = 'L72002311_SZ_B432_30m.tif'
aimmodisfir = 'MOD09_2002311_SZ_B214_250m.tif'

landsat0 = gdal.Open(landsatfir).ReadAsArray()[0]
modis0 = gdal.Open(modisfir).ReadAsArray()[0]
landsat1 = gdal.Open(aimlandsatfir).ReadAsArray()[0]
modis1 = gdal.Open(aimmodisfir).ReadAsArray()[0]

##define some global constant for better useing nb.vectorize() acc speed
# (a `global` statement at module level has no effect; kept for fidelity)
global long, window
window = 49
long = landsat0.shape[1] + window - 1


##use starfm fusion landsat and modis
l1 = starfm(landsat0, modis0, modis1, window, 3)


##plot image
plt.figure('starfm based landsat1')
plt.imshow(l1.reshape(landsat0.shape) / 10000, cmap=plt.cm.gray)

plt.figure('real landsat1')
plt.imshow(landsat1 / 10000, cmap=plt.cm.gray)
# -------------------------------------------------------------------------------- /STARFM_torch.py: --------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 17 15:15:36 2020

@author: Administrator

PyTorch (CPU or CUDA) implementation of STARFM. Tensors follow the
(1, bands, rows, cols) layout produced in ``test``.
"""


import numpy as np
import torch
import torch.nn as nn
import time
#import skimage.measure as sm
import skimage.metrics as sm
import cv2
from osgeo import gdal
import matplotlib.pyplot as plt


###img read tool###############################################################
def imgread(file, mode='gdal'):
    """Read a raster file and divide it by 10000.

    mode: 'gdal' (default) or 'cv2' (``cv2.imread(file, -1)``).
    Raises ValueError for any other mode; previously this surfaced as an
    UnboundLocalError.
    """
    if mode == 'cv2':
        return cv2.imread(file, -1) / 10000.
    if mode == 'gdal':
        return gdal.Open(file).ReadAsArray() / 10000.
    raise ValueError("unsupported mode {!r}; expected 'gdal' or 'cv2'".format(mode))


###weight caculate tools######################################################
def weight_caculate(data):
    """Return log(|data| * 10000 + 1.00001); always > 0, so its inverse is finite."""
    return torch.log((abs(data) * 10000 + 1.00001))


def caculate_weight(l1m1, m1m2):
    """Product of the spectral (fine vs coarse) and temporal (coarse t1 vs t2) log-differences."""
    #atmos difference
    wl1m1 = weight_caculate(l1m1)
    #time deference
    wm1m2 = weight_caculate(m1m2)
    return wl1m1 * wm1m2


###space distance caculate tool################################################
def indexdistance(window):
    """Relative distance ``1 + d / ((window[0]-1)//2)`` of each cell from the window centre.

    NOTE(review): written for square, odd windows; the meshgrid orientation and
    the normalisation use window[0] only.
    """
    #one window, one distance weight matrix
    [distx, disty] = np.meshgrid(np.arange(window[0]), np.arange(window[1]))
    centerlocx, centerlocy = (window[0] - 1) // 2, (window[1] - 1) // 2
    dist = 1 + (((distx - centerlocx) ** 2 + (disty - centerlocy) ** 2) ** 0.5) / ((window[0] - 1) // 2)
    return dist


###threshold select tool######################################################
def weight_bythreshold(weight, data, threshold):
    """Set ``weight`` to 1 (in place) wherever ``data <= threshold``; return it."""
    #make weight tensor
    weight[data <= threshold] = 1
    return weight


def weight_bythreshold_allbands(weight, l1m1, m1m2, thresholdmax):
    """Similar-pixel mask that requires BOTH STARFM filters to pass in EVERY band.

    l1m1 / m1m2: (bands, window_cells, pixels) spectral / temporal differences.
    thresholdmax: (spectral_threshold, temporal_threshold), each broadcastable
    to (window_cells, pixels). ``weight`` is a zero tensor shaped like l1m1 and
    is filled in place. Returns a (1, window_cells, pixels) 0/1 tensor.

    The previous version set the mask for either condition (logical OR).
    STARFM keeps a candidate only when it satisfies both the spectral and the
    temporal criterion.
    """
    passed = (l1m1 <= thresholdmax[0]) & (m1m2 <= thresholdmax[1])
    weight[passed] = 1
    allweight = (weight.sum(0).view(1, weight.shape[1], weight.shape[2])) / weight.shape[0]
    # keep only cells that passed in all bands
    allweight[allweight != 1] = 0
    return allweight


###initial similar pixels tools################################################
def spectral_similar_threshold(clusters, NIR, red):
    """Spectral-similarity thresholds ``2 * std / clusters`` for the NIR and red bands."""
    thresholdNIR = NIR.std() * 2 / clusters
    thresholdred = red.std() * 2 / clusters
    return (thresholdNIR, thresholdred)


def caculate_similar(l1, threshold, window):
    """0/1 tensor (1, window_cells, pixels): cells within ``threshold`` of their window centre.

    ``l1`` is a single-band (1, 1, H, W) padded tile. The mask is created on
    ``l1``'s device so CPU and CUDA inputs both work.
    """
    l1 = nn.functional.unfold(l1, window)
    #caculate similar
    weight = torch.zeros(l1.shape, dtype=torch.float32, device=l1.device)
    centerloc = (l1.size()[1] - 1) // 2
    weight = weight_bythreshold(weight, abs(l1 - l1[:, centerloc:centerloc + 1, :]), threshold)
    return weight


def classifier(l1):
    '''not used'''
    return


###similar pixels filter tools#################################################
def allband_arrayindex(arraylist, indexarray, rawindexshape):
    """Gather the pixels at ``indexarray`` (a pair of index grids) from every band.

    Each tensor in ``arraylist`` is (1, bands, H, W); each result has shape
    ``rawindexshape`` and lives on the same device as its source tensor.
    Previously the results were always moved with ``.cuda()``, which crashed
    on CPU-only machines.
    """
    bands = arraylist[0].shape[1]
    index = tuple(indexarray)
    datalist = []
    for array in arraylist:
        newarray = torch.zeros(rawindexshape, dtype=torch.float32, device=array.device)
        for band in range(bands):
            newarray[0, band] = array[0, band][index]
        datalist.append(newarray)
    return datalist


def similar_filter(datalist, sital, sitam):
    """Per-pixel STARFM filter thresholds, each of shape (1, pixels).

    Spectral: max over bands of |l1 - m1| + sqrt(sital^2 + sitam^2).
    Temporal: max over bands of |m2 - m1| + sqrt(2 * sitam^2).
    """
    [l1, m1, m2] = datalist
    l1m1 = abs(l1 - m1)
    m1m2 = abs(m2 - m1)
    #####
    l1m1 = nn.functional.unfold(l1m1, (1, 1)).max(1)[0] + (sital ** 2 + sitam ** 2) ** 0.5
    m1m2 = nn.functional.unfold(m1m2, (1, 1)).max(1)[0] + (sitam ** 2 + sitam ** 2) ** 0.5
    return (l1m1, m1m2)


###starfm for onepart##########################################################
def starfm_onepart(datalist, similar, thresholdmax, window, outshape, dist):
    """Predict one tile.

    datalist: padded (1, bands, h + window - 1, w + window - 1) tiles [l1, m1, m2].
    similar: (1, window_cells, h*w) initial spectral-similarity mask.
    thresholdmax: output of ``similar_filter`` for the unpadded tile.
    outshape: (h, w). dist: (1, window_cells, 1) unfolded distance weights.
    Returns a (1, bands, h, w) tensor.
    """
    #####param and data
    [l1, m1, m2] = datalist
    bandsize = l1.shape[1]
    blocksize = outshape[0] * outshape[1]
    #####img to col: (bands, window_cells, pixels)
    l1 = nn.functional.unfold(l1, window).view(bandsize, -1, blocksize)
    m1 = nn.functional.unfold(m1, window).view(bandsize, -1, blocksize)
    m2 = nn.functional.unfold(m2, window).view(bandsize, -1, blocksize)
    l1m1 = abs(l1 - m1)
    m1m2 = abs(m2 - m1)
    #####caculate weights
    #time and space weight (inverse of the combined distance)
    w = caculate_weight(l1m1, m1m2)
    w = 1 / (w * dist)
    #filter similar pixels for all bands: (1,windowsize,blocksize)
    wmask = torch.zeros(l1.shape, dtype=torch.float32, device=l1.device)
    wmasknew = weight_bythreshold_allbands(wmask, l1m1, m1m2, thresholdmax)
    #mask
    w = w * wmasknew * similar
    #normalise so the weights of each pixel sum to 1
    w = w / (w.sum(1).view(w.shape[0], 1, w.shape[2]))
    #####predicte and trans
    #predicte l2
    l2 = (l1 + m2 - m1) * w
    l2 = l2.sum(1).reshape(1, bandsize, l2.shape[2])
    #col to img
    l2 = nn.functional.fold(l2.view(1, -1, blocksize), outshape, (1, 1))
    return l2


###starfm for allpart#########################################################
def starfm_main(l1r, m1r, m2r, param=None):
    """Run STARFM over the whole image, tile by tile.

    l1r, m1r, m2r: (1, bands, rows, cols) tensors on a common device (fine
    image at t1, coarse images at t1 and t2). ``param`` keys: part_shape,
    window_size (odd, square), clusters, NIRindex, redindex, sital, sitam.
    When ``param`` is None the historical defaults are used. Returns the
    predicted (1, bands, rows, cols) fine image at t2.
    """
    if param is None:
        param = {'part_shape': (140, 140),
                 'window_size': (31, 31),
                 'clusters': 5,
                 'NIRindex': 3, 'redindex': 2,
                 'sital': 0.001, 'sitam': 0.001}
    #get start time
    time_start = time.time()
    # work on the inputs' device so CPU tensors also work when CUDA is present
    device = l1r.device
    #read parameters
    parts_shape = param['part_shape']
    window = param['window_size']
    clusters = param['clusters']
    NIRindex = param['NIRindex']
    redindex = param['redindex']
    sital = param['sital']
    sitam = param['sitam']
    #caculate initial similar pixels threshold
    threshold = spectral_similar_threshold(clusters, l1r[:, NIRindex:NIRindex + 1], l1r[:, redindex:redindex + 1])
    print('similar threshold (NIR,red)', threshold)
    ####shape
    imageshape = (l1r.shape[1], l1r.shape[2], l1r.shape[3])
    print('datashape:', imageshape)
    row = imageshape[1] // parts_shape[0] + 1
    col = imageshape[2] // parts_shape[1] + 1
    padrow = window[0] // 2
    padcol = window[1] // 2
    #####padding constant for conv;STARFM use Inverse distance weight(1/w),better to avoid 0 and NAN(1/0),or you can use another distance measure
    constant1 = 10
    constant2 = 20
    constant3 = 30
    l1 = torch.nn.functional.pad(l1r, (padrow, padcol, padrow, padcol), 'constant', constant1)
    m1 = torch.nn.functional.pad(m1r, (padrow, padcol, padrow, padcol), 'constant', constant2)
    m2 = torch.nn.functional.pad(m2r, (padrow, padcol, padrow, padcol), 'constant', constant3)
    #split parts , get index and run for every part
    row_part = np.array_split(np.arange(imageshape[1]), row, axis=0)
    col_part = np.array_split(np.arange(imageshape[2]), col, axis=0)
    print('Split into {} parts,row number: {},col number: {}'.format(len(row_part) * len(col_part), len(row_part), len(col_part)))
    dist = nn.functional.unfold(torch.tensor(indexdistance(window), dtype=torch.float32).reshape(1, 1, window[0], window[1]), window).to(device)

    row_blocks = []
    for rnumber, row_index in enumerate(row_part):
        col_blocks = []
        for cnumber, col_index in enumerate(col_part):
            ####run for part: (rnumber,cnumber)
            print('now for part{}'.format((rnumber, cnumber)))
            # Index grids are built with meshgrid's default 'xy' ordering, so each
            # tile is handled transposed (cols, rows); the final transpose undoes it.
            rawindex = np.meshgrid(row_index, col_index)
            rawindexshape = (col_index.shape[0], row_index.shape[0])
            ####the real parts_index ,for reading the padded data
            row_pad = np.arange(row_index[0], row_index[len(row_index) - 1] + window[0])
            col_pad = np.arange(col_index[0], col_index[len(col_index) - 1] + window[1])
            padindex = tuple(np.meshgrid(row_pad, col_pad))
            padindexshape = (col_pad.shape[0], row_pad.shape[0])
            ####caculate initial similar pixels
            NIR_similar = caculate_similar(l1[0, NIRindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]), threshold[0], window)
            red_similar = caculate_similar(l1[0, redindex][padindex].view(1, 1, padindexshape[0], padindexshape[1]), threshold[1], window)
            similar = NIR_similar * red_similar
            ####caculate threshold used for similar_pixels_filter
            thresholdmax = similar_filter(allband_arrayindex([l1r, m1r, m2r], rawindex, (1, imageshape[0], rawindexshape[0], rawindexshape[1])),
                                          sital, sitam)
            padded_parts = allband_arrayindex([l1, m1, m2], padindex, (1, imageshape[0], padindexshape[0], padindexshape[1]))
            col_blocks.append(starfm_onepart(padded_parts, similar, thresholdmax, window, rawindexshape, dist))
        ####Splicing each col at rnumber-th row
        row_blocks.append(torch.cat(col_blocks, 2))
    ####Splicing each row
    l2_fake = torch.cat(row_blocks, 3)

    l2_fake = l2_fake.transpose(3, 2)
    #time cost
    time_end = time.time()
    print('now over,use time {:.4f}'.format(time_end - time_start))
    return l2_fake


def test():
    """Fuse the bundled sample images, plot them, and print SSIM scores."""
    ##three band datas(sorry,just find them at home,i cant recognise the spectral response range of each band,'NIR' and 'red' are only examples)
    l1file = 'L72000306_SZ_B432_30m.tif'
    l2file = 'L72002311_SZ_B432_30m.tif'
    m1file = 'MOD09_2000306_SZ_B214_250m.tif'
    m2file = 'MOD09_2002311_SZ_B214_250m.tif'

    ##param
    param = {'part_shape': (75, 75),
             'window_size': (31, 31),
             'clusters': 5,
             'NIRindex': 1, 'redindex': 0,
             'sital': 0.001, 'sitam': 0.001}

    ##read images from files(numpy)
    l1 = imgread(l1file)
    m1 = imgread(m1file)
    m2 = imgread(m2file)
    l2_gt = imgread(l2file)

    ##numpy to tensor
    shape = l1.shape
    l1r = torch.tensor(l1.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32)
    m1r = torch.tensor(m1.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32)
    m2r = torch.tensor(m2.reshape(1, shape[0], shape[1], shape[2]), dtype=torch.float32)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    l1r = l1r.to(device)
    m1r = m1r.to(device)
    m2r = m2r.to(device)

    ##predicte(tensor input -> tensor output)
    l2_fake = starfm_main(l1r, m1r, m2r, param)
    print(l2_fake.shape)

    ##tensor to numpy (.cpu() is a no-op for CPU tensors)
    l2_fake = l2_fake[0].cpu().numpy()

    ##show results
    #transform:(chanel,H,W) to (H,W,chanel)
    l2_fake = l2_fake.transpose(1, 2, 0)
    l2_gt = l2_gt.transpose(1, 2, 0)
    l1 = l1.transpose(1, 2, 0)
    m1 = m1.transpose(1, 2, 0)
    m2 = m2.transpose(1, 2, 0)
    #plot
    plt.figure('landsat:t1')
    plt.imshow(l1)
    plt.figure('landsat:t2_fake')
    plt.imshow(l2_fake)
    plt.figure('landsat:t2_groundtrue')
    plt.imshow(l2_gt)

    ##evaluation
    # NOTE(review): `multichannel` is deprecated in newer scikit-image releases in
    # favour of `channel_axis=-1`; switch once the minimum supported version allows.
    ssim1 = sm.structural_similarity(l2_fake, l2_gt, data_range=1, multichannel=True)
    ssim2 = sm.structural_similarity(l1, l2_gt, data_range=1, multichannel=True)
    ssim3 = sm.structural_similarity(l1 + m2 - m1, l2_gt, data_range=1, multichannel=True)
    print('with-similarpixels ssim: {:.4f};landsat_t1 ssim: {:.4f};non-similarpixels ssim: {:.4f}'.format(ssim1, ssim2, ssim3))

    return


if __name__ == "__main__":
    test()
# --------------------------------------------------------------------------------