├── sqgturb.png ├── setup.py ├── compute_means.py ├── COPYING ├── sqgturb ├── __init__.py ├── sqg.py └── enkf_utils.py ├── plotobnetwork.py ├── README.md ├── upscale.py ├── sqg_run.py └── sqg_lgetkf_cv.py /sqgturb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jswhit/sqgturb/HEAD/sqgturb.png -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from distutils.core import setup 2 | short_desc = "A program for simulating surface quasi-geostropic turbulence" 3 | setup( 4 | name = 'sqgturb', 5 | version = '0.1', 6 | description = short_desc, 7 | author = 'Jeff Whitaker', 8 | author_email = 'jeffrey dot s dot whitaker at noaa dot gov', 9 | url = 'https://github.com/jswhit/sqgturb', 10 | packages = ['sqgturb'], 11 | requires = ['numpy'] 12 | ) 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /compute_means.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import sys 3 | # generate time mean statistics from standard output of sqg_enkf.py 4 | file = sys.argv[1] 5 | data = np.loadtxt(file) 6 | nskip = int(sys.argv[2]) 7 | if len(sys.argv) > 3: 8 | nend = int(sys.argv[3]) 9 | else: 10 | nend = -1 11 | if nend == -1: 12 | data2 = data[nskip:,:] 13 | else: 14 | data2 = data[nskip:nend,:] 15 | data_mean = data2.mean(axis=0) 16 | data_mean[0] = data2.shape[0] 17 | print_list = ''.join(['%g ' % x for x in data_mean]) 18 | print(print_list) 19 | 20 | -------------------------------------------------------------------------------- /COPYING: -------------------------------------------------------------------------------- 1 | copyright: 2016 by Jeffrey Whitaker. 2 | 3 | Permission to use, copy, modify, and distribute this software and 4 | its documentation for any purpose and without fee is hereby granted, 5 | provided that the above copyright notice appear in all copies and that 6 | both the copyright notice and this permission notice appear in 7 | supporting documentation. 8 | 9 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, 10 | INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO 11 | EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR 12 | CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF 13 | USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR 14 | OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR 15 | PERFORMANCE OF THIS SOFTWARE. 16 | -------------------------------------------------------------------------------- /sqgturb/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | constant PV f-plane QG turbulence (a.k.a surface QG turbulence). 3 | Doubly periodic geometry with sin(2*pi/L) jet basic state. 4 | References: 5 | http://journals.ametsoc.org/doi/pdf/10.1175/2008JAS2921.1 (section 3) 6 | http://journals.ametsoc.org/doi/pdf/10.1175/1520-0469%281978%29035%3C0774%3AUPVFPI%3E2.0.CO%3B2 7 | 8 | includes Ekman damping, linear thermal relaxation back 9 | to equilibrium jet, and hyperdiffusion. 10 | 11 | pv has units of meters per second. 12 | scale by f*theta0/g to convert to temperature. 13 | 14 | FFT spectral collocation method with 4th order Runge Kutta 15 | time stepping (dealiasing with 2/3 rule, hyperdiffusion treated implicitly). 16 | 17 | Jeff Whitaker December, 2016 18 | """ 19 | from .enkf_utils import gaspcohn, cartdist, lgetkf, lgetkf_vloc 20 | from .sqg import SQG, rfft2, irfft2 21 | 22 | __all__=['SQG','rfft2','irfft2',gaspcohn,cartdist,lgetkf,lgetkf_vloc] 23 | -------------------------------------------------------------------------------- /plotobnetwork.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('QTAgg') 3 | import matplotlib.pyplot as plt 4 | from netCDF4 import Dataset 5 | import numpy as np 6 | 7 | nc = Dataset('sqgu20_N96_6hrly.nc') 8 | ntimes, nlevs, ny, nx = nc['pv'].shape 9 | nobs = 1024 # 512 obs on each boundary 10 | 11 | rsobs = np.random.RandomState(42) 12 | indxob = np.sort(rsobs.choice(2*nx*ny,nobs,replace=False)) 13 | x = nc.variables['x'][:] 14 | y = nc.variables['y'][:] 15 | x, y = np.meshgrid(x, y) 16 | xobs = nx*np.concatenate((x.ravel(),x.ravel()))[indxob]/nc.L 17 | yobs = ny*np.concatenate((y.ravel(),y.ravel()))[indxob]/nc.L 18 | 19 | # just plot obs on lower boundary 20 | pv = nc['pv'][-1,0,...] # last time, lower boundary 21 | plt.imshow(pv,cmap=plt.cm.jet,interpolation='nearest',origin='lower') 22 | plt.scatter(xobs[:nobs//2], yobs[:nobs//2], s=5, color='black') 23 | 24 | # just plot obs on upper boundary 25 | #pv = nc['pv'][-1,1,...] # last time, upper boundary 26 | #plt.imshow(pv,cmap=plt.cm.jet,interpolation='nearest',origin='lower') 27 | #plt.scatter(xobs[nobs//2:], yobs[nobs//2:], s=5, color='black') 28 | 29 | plt.axis('off') 30 | plt.tight_layout() 31 | plt.savefig('obnetwork.png') 32 | plt.show() 33 | 34 | 35 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | constant PV f-plane QG turbulence (a.k.a surface QG turbulence). 2 | 3 | ![SQG Turbulence](sqgturb.png?raw=true "SQG Turbulence") 4 | 5 | Doubly periodic geometry with sin(2*pi/L) jet basic state. 6 | 7 | References: 8 | 9 | * http://journals.ametsoc.org/doi/pdf/10.1175/2008JAS2921.1 (section 3) 10 | * http://journals.ametsoc.org/doi/pdf/10.1175/1520-0469%281978%29035%3C0774%3AUPVFPI%3E2.0.CO%3B2 11 | 12 | includes Ekman damping, linear thermal relaxation back 13 | to equilibrium jet, and hyperdiffusion. 14 | 15 | pv has units of meters per second. 16 | scale by f*theta0/g to convert to temperature. 17 | 18 | FFT spectral collocation method with 4th order Runge Kutta 19 | time stepping (dealiasing with 2/3 rule, hyperdiffusion treated implicitly). 20 | 21 | Requires numpy (pyfftw, netcdf4-python and matplotlib highly recommended). 22 | 23 | example code to run model and animate the solution in ``sqg_run.py`` 24 | 25 | To run EnKF data assimilation: 26 | * install with ``python setup.py install``. 27 | * first run ``sqg_run.py`` to generate nature run. 28 | * then run ``sqg_lgetkf_cv.py``. (uses [local gain-form ensemble transform KF](https://doi.org/10.1175/MWR-D-19-0402.1) with cross-validation) 29 | -------------------------------------------------------------------------------- /upscale.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | from netCDF4 import Dataset 4 | import numpy as np 5 | import sys 6 | from sqgturb import rfft2, irfft2 7 | from scipy import ndimage 8 | 9 | def block_mean(ar, fact): 10 | # downsample 2d array by averaging fact x fact blocks 11 | # requires scipy.ndimage 12 | #assert isinstance(fact, int), type(fact) 13 | sx, sy = ar.shape 14 | X, Y = np.ogrid[0:sx, 0:sy] 15 | regions = sy//fact * (X//fact) + Y//fact 16 | res = ndimage.mean(ar, labels=regions, 17 | index=np.arange(regions.max() + 1)) 18 | res.shape = (sx//fact, sy//fact) 19 | return res 20 | 21 | def spectrunc(specarr,N): 22 | fact = float(N)/float(specarr.shape[1]) 23 | specarr_trunc = np.zeros((2, N, N//2+1), specarr.dtype) 24 | specarr_trunc[:,0:N//2,0:N//2] = fact**2*specarr[:,0:N//2,0:N//2] 25 | specarr_trunc[:,-N//2:,0:N//2] = fact**2*specarr[:,-N//2:,0:N//2] 26 | return specarr_trunc 27 | 28 | # spectrally truncate or block average data in filenamein, write to filenameout on Nout x Nout 29 | # grid. 30 | filenamein = sys.argv[1] 31 | filenameout = sys.argv[2] 32 | Nout = int(sys.argv[3]) 33 | blockmean = bool(int(sys.argv[4])) 34 | print('Nout, blockmean = ',Nout,blockmean) 35 | 36 | ncin = Dataset(filenamein) 37 | nc = Dataset(filenameout, mode='w', format='NETCDF4_CLASSIC') 38 | nc.r = ncin.r 39 | nc.f = ncin.f 40 | nc.U = ncin.U 41 | nc.L = ncin.L 42 | nc.H = ncin.H 43 | nc.g = ncin.g; nc.theta0 = ncin.theta0 44 | nc.nsq = ncin.nsq 45 | nc.tdiab = ncin.tdiab 46 | nc.dt = ncin.dt 47 | nc.diff_efold = ncin.diff_efold 48 | nc.diff_order = ncin.diff_order 49 | nc.symmetric = ncin.symmetric 50 | nc.dealias = ncin.dealias 51 | x = nc.createDimension('x',Nout) 52 | y = nc.createDimension('y',Nout) 53 | z = nc.createDimension('z',2) 54 | t = nc.createDimension('t',None) 55 | pvvar =\ 56 | nc.createVariable('pv',np.float32,('t','z','y','x'),zlib=True) 57 | pvvar.units = 'K' 58 | # pv scaled by g/(f*theta0) so du/dz = d(pv)/dy 59 | xvar = nc.createVariable('x',np.float32,('x',)) 60 | xvar.units = 'meters' 61 | yvar = nc.createVariable('y',np.float32,('y',)) 62 | yvar.units = 'meters' 63 | zvar = nc.createVariable('z',np.float32,('z',)) 64 | zvar.units = 'meters' 65 | tvar = nc.createVariable('t',np.float32,('t',)) 66 | tvar.units = 'seconds' 67 | xvar[:] = np.arange(0,ncin.L,ncin.L/Nout) 68 | yvar[:] = np.arange(0,ncin.L,ncin.L/Nout) 69 | zvar[0] = 0; zvar[1] = ncin.H 70 | N = ncin['pv'].shape[-1] 71 | print(N,Nout) 72 | nskip = N//Nout 73 | print('nskip = ',nskip) 74 | ntimes = len(ncin.dimensions['t']) 75 | #ntimes = 10 76 | for n in range(ntimes): 77 | tvar[n] = ncin['t'][n] 78 | pvin = ncin['pv'][n] 79 | # downsample by averaging nskip x nskip blocks of pixels 80 | if blockmean: 81 | pvout = np.empty((2,Nout,Nout), dtype=pvin.dtype) 82 | for k in range(2): 83 | pvout[k,:,:] = block_mean(pvin[k],nskip) 84 | # spectrally truncate. 85 | else: 86 | pvout = irfft2(spectrunc(rfft2(pvin),Nout)) 87 | pvvar[n] = pvout 88 | print(n,tvar[n]/86400.,pvin.shape,pvout.shape,pvin.min(),pvin.max(),pvout.min(),pvout.max()) 89 | nc.close() 90 | scalefact = ncin.f*ncin.theta0/ncin.g 91 | ncin.close() 92 | 93 | # make plot 94 | import matplotlib.pyplot as plt 95 | fig = plt.figure(figsize=(16,8)) 96 | vmin = -25; vmax= 25 97 | ax = fig.add_subplot(1,2,1) 98 | ax.axis('off') 99 | im = plt.imshow(scalefact*pvin[1],cmap=plt.cm.jet,interpolation='nearest',origin='lower',vmin=vmin,vmax=vmax) 100 | plt.title('%s x %s solution' % (N,N) ,fontsize=18,fontweight='bold') 101 | ax = fig.add_subplot(1,2,2) 102 | ax.axis('off') 103 | im = plt.imshow(scalefact*pvout[1],cmap=plt.cm.jet,interpolation='nearest',origin='lower',vmin=vmin,vmax=vmax) 104 | if blockmean: 105 | plt.title('upscaled to %s x %s (block mean)' % (Nout,Nout),fontsize=18,fontweight='bold') 106 | else: 107 | plt.title('upscaled to %s x %s (spectrally truncated)' % (Nout, Nout),fontsize=18,fontweight='bold') 108 | #plt.tight_layout() 109 | plt.savefig('upscale.png') 110 | plt.show() 111 | -------------------------------------------------------------------------------- /sqg_run.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('qtagg') 3 | from sqgturb import SQG, rfft2, irfft2 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.animation as animation 7 | import os 8 | 9 | # run SQG turbulence simulation, optionally plotting results to screen and/or saving to 10 | # netcdf file. 11 | 12 | # model parameters. 13 | 14 | #N = 512 # number of grid points in each direction (waves=N/2) 15 | #dt = 90 # time step in seconds 16 | #diff_efold = 1800. # time scale for hyperdiffusion at smallest resolved scale 17 | 18 | #N = 192 19 | #dt = 300 20 | #diff_efold = 86400./8. 21 | # 22 | #N = 128 23 | #dt = 600 24 | #diff_efold = 86400./3. 25 | 26 | N = 96 27 | dt = 900 28 | diff_efold = 86400./2. 29 | 30 | #N = 64 31 | #dt = 900 32 | #diff_efold = 86400./2. 33 | 34 | norder = 8 # order of hyperdiffusion 35 | dealias = True # dealiased with 2/3 rule? 36 | 37 | # Ekman damping coefficient r=dek*N**2/f, dek = ekman depth = sqrt(2.*Av/f)) 38 | # Av (turb viscosity) = 2.5 gives dek = sqrt(5/f) = 223 39 | # for ocean Av is 1-5, land 5-50 (Lin and Pierrehumbert, 1988) 40 | # corresponding to ekman depth of 141-316 m over ocean. 41 | # spindown time of a barotropic vortex is tau = H/(f*dek), 10 days for 42 | # H=10km, f=0.0001, dek=100m. 43 | dek = 0 # applied only at surface if symmetric=False 44 | nsq = 1.e-4; f=1.e-4; g = 9.8; theta0 = 300 45 | H = 10.e3 # lid height 46 | r = dek*nsq/f 47 | U = 20 # jet speed 48 | Lr = np.sqrt(nsq)*H/f # Rossby radius 49 | L = 20.*Lr 50 | # thermal relaxation time scale 51 | tdiab = 10.*86400 # in seconds 52 | symmetric = True # (if False, asymmetric equilibrium jet with zero wind at sfc) 53 | # parameter used to scale PV to temperature units. 54 | scalefact = f*theta0/g 55 | 56 | # create random noise 57 | pv = np.random.normal(0,100.,size=(2,N,N)).astype(np.float32) 58 | # add isolated blob on lid 59 | nexp = 20 60 | x = np.arange(0,2.*np.pi,2.*np.pi/N); y = np.arange(0.,2.*np.pi,2.*np.pi/N) 61 | x,y = np.meshgrid(x,y) 62 | x = x.astype(np.float32); y = y.astype(np.float32) 63 | pv[1] = pv[1]+2000.*(np.sin(x/2)**(2*nexp)*np.sin(y)**nexp) 64 | # remove area mean from each level. 65 | for k in range(2): 66 | pv[k] = pv[k] - pv[k].mean() 67 | 68 | # get OMP_NUM_THREADS (threads to use) from environment. 69 | threads = int(os.getenv('OMP_NUM_THREADS','1')) 70 | 71 | # single or double precision 72 | precision='single' # pyfftw FFTs twice as fast as double 73 | 74 | # initialize qg model instance 75 | model = SQG(pv,nsq=nsq,f=f,U=U,H=H,r=r,tdiab=tdiab,dt=dt, 76 | diff_order=norder,diff_efold=diff_efold, 77 | dealias=dealias,symmetric=symmetric,threads=threads, 78 | precision=precision,tstart=0) 79 | 80 | # initialize figure. 81 | outputinterval = 6.*3600. # interval between frames in seconds 82 | tmin = 100.*86400. # time to start saving data (in days) 83 | tmax = 300.*86400. # time to stop (in days) 84 | nsteps = int(tmax/outputinterval) # number of time steps to animate 85 | # set number of timesteps to integrate for each call to model.advance 86 | model.timesteps = int(outputinterval/model.dt) 87 | savedata = 'sqgu20_N%s_6hrly.nc' % N # save data plotted in a netcdf file. 88 | #savedata = None # don't save data 89 | plot = True # animate data as model is running? 90 | 91 | if savedata is not None: 92 | from netCDF4 import Dataset 93 | nc = Dataset(savedata, mode='w', format='NETCDF4_CLASSIC') 94 | nc.r = model.r 95 | nc.f = model.f 96 | nc.U = model.U 97 | nc.L = model.L 98 | nc.H = model.H 99 | nc.g = g; nc.theta0 = theta0 100 | nc.nsq = model.nsq 101 | nc.tdiab = model.tdiab 102 | nc.dt = model.dt 103 | nc.diff_efold = model.diff_efold 104 | nc.diff_order = model.diff_order 105 | nc.symmetric = int(model.symmetric) 106 | nc.dealias = int(model.dealias) 107 | x = nc.createDimension('x',N) 108 | y = nc.createDimension('y',N) 109 | z = nc.createDimension('z',2) 110 | t = nc.createDimension('t',None) 111 | pvvar =\ 112 | nc.createVariable('pv',np.float32,('t','z','y','x'),zlib=True) 113 | pvvar.units = 'K' 114 | # pv scaled by g/(f*theta0) so du/dz = d(pv)/dy 115 | xvar = nc.createVariable('x',np.float32,('x',)) 116 | xvar.units = 'meters' 117 | yvar = nc.createVariable('y',np.float32,('y',)) 118 | yvar.units = 'meters' 119 | zvar = nc.createVariable('z',np.float32,('z',)) 120 | zvar.units = 'meters' 121 | tvar = nc.createVariable('t',np.float32,('t',)) 122 | tvar.units = 'seconds' 123 | xvar[:] = np.arange(0,model.L,model.L/N) 124 | yvar[:] = np.arange(0,model.L,model.L/N) 125 | zvar[0] = 0; zvar[1] = model.H 126 | 127 | levplot = 1; nout = 0 # levplot < 0 is vertical mean PV 128 | if plot: 129 | fig = plt.figure(figsize=(14,8)) 130 | fig.subplots_adjust(left=0.05, bottom=0.05, top=0.95, right=0.95) 131 | vmin = scalefact*model.pvbar[levplot].min() 132 | vmax = scalefact*model.pvbar[levplot].max() 133 | if levplot < 0: 134 | vmin=0.8*vmin; vmax=0.8*vmax 135 | def initfig(): 136 | global im1,im2 137 | ax1 = fig.add_subplot(121) 138 | ax1.axis('off') 139 | pv = irfft2(model.pvspec[levplot]) # spectral to grid 140 | im1 = ax1.imshow(scalefact*pv,cmap=plt.cm.jet,interpolation='nearest',origin='lower',vmin=vmin,vmax=vmax) 141 | ax2 = fig.add_subplot(122) 142 | ax2.axis('off') 143 | pvspec_mean = model.meantemp() 144 | pv = irfft2(pvspec_mean) # mean pv 145 | im2 = ax2.imshow(scalefact*pv,cmap=plt.cm.jet,interpolation='nearest',origin='lower',vmin=vmin,vmax=vmax) 146 | return im1,im2, 147 | def updatefig(*args): 148 | global nout 149 | model.advance() 150 | t = model.t 151 | pv = irfft2(model.pvspec[levplot]) 152 | hr = t/3600. 153 | spd = np.sqrt(model.u[levplot]**2+model.v[levplot]**2) 154 | print(hr,spd.max(),scalefact*pv.min(),scalefact*pv.max()) 155 | im1.set_data(scalefact*pv) 156 | pvspec_mean = model.meantemp() 157 | pv = irfft2(pvspec_mean) # mean pv 158 | im2.set_data(scalefact*pv) 159 | if savedata is not None and t >= tmin: 160 | print('saving data at t = t = %g hours' % hr) 161 | pvvar[nout,:,:,:] = irfft2(model.pvspec) 162 | tvar[nout] = t 163 | nc.sync() 164 | if t >= tmax: nc.close() 165 | nout = nout + 1 166 | return im1,im2, 167 | 168 | # interval=0 means draw as fast as possible 169 | ani = animation.FuncAnimation(fig, updatefig, frames=nsteps, repeat=False,\ 170 | init_func=initfig,interval=0,blit=True) 171 | plt.show() 172 | else: 173 | t = 0.0 174 | while t < tmax: 175 | model.advance() 176 | t = model.t 177 | pv = irfft2(model.pvspec) 178 | hr = t/3600. 179 | spd = np.sqrt(model.u[levplot]**2+model.v[levplot]**2) 180 | print(hr,spd.max(),scalefact*pv.min(),scalefact*pv.max()) 181 | if savedata is not None and t >= tmin: 182 | print('saving data at t = t = %g hours' % hr) 183 | pvvar[nout,:,:,:] = pv 184 | tvar[nout] = t 185 | nc.sync() 186 | if t >= tmax: nc.close() 187 | nout = nout + 1 188 | -------------------------------------------------------------------------------- /sqgturb/sqg.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import numpy as np 3 | 4 | try: # pyfftw is *much* faster 5 | from pyfftw.interfaces import numpy_fft, cache 6 | 7 | # print('# using pyfftw...') 8 | cache.enable() 9 | rfft2 = numpy_fft.rfft2 10 | irfft2 = numpy_fft.irfft2 11 | except ImportError: # fall back on numpy fft. 12 | print("# WARNING: using numpy fft (install pyfftw for better performance)...") 13 | 14 | def rfft2(*args, **kwargs): 15 | kwargs.pop("threads", None) 16 | return np.fft.rfft2(*args, **kwargs) 17 | 18 | def irfft2(*args, **kwargs): 19 | kwargs.pop("threads", None) 20 | return np.fft.irfft2(*args, **kwargs) 21 | 22 | 23 | class SQG: 24 | def __init__( 25 | self, 26 | pv, 27 | f=1.0e-4, 28 | nsq=1.0e-4, 29 | L=20.0e6, 30 | H=10.0e3, 31 | U=30.0, 32 | r=0.0, 33 | tdiab=10.0 * 86400, 34 | diff_order=8, 35 | diff_efold=None, 36 | theta0=300, 37 | g=9.8, 38 | symmetric=True, 39 | dt=None, 40 | dealias=True, 41 | threads=1, 42 | precision="single", 43 | tstart=0, 44 | ): 45 | # initialize SQG model. 46 | if pv.shape[0] != 2: 47 | raise ValueError("1st dim of pv should be 2") 48 | N = pv.shape[1] # number of grid points in each direction 49 | # N should be even 50 | if N % 2: 51 | raise ValueError("N must be even (powers of 2 are fastest)") 52 | if dt is None: # time step must be specified 53 | raise ValueError("must specify time step") 54 | if diff_efold is None: # efolding time scale for diffusion must be specified 55 | raise ValueError("must specify efolding time scale for diffusion") 56 | # number of openmp threads to use for FFTs (only for pyfftw) 57 | self.threads = threads 58 | self.N = N 59 | if precision == "single": 60 | # ffts in single precision (faster) 61 | dtype = np.float32 62 | elif precision == "double": 63 | # ffts in double precision 64 | dtype = np.float64 65 | else: 66 | msg = "precision must be 'single' or 'double'" 67 | raise ValueError(msg) 68 | # force arrays to be float32 for precision='single' (ffts are twice as fast) 69 | self.nsq = np.array(nsq, dtype) # Brunt-Vaisalla (buoyancy) freq squared 70 | self.f = np.array(f, dtype) # coriolis 71 | self.H = np.array(H, dtype) # height of upper boundary 72 | self.U = np.array(U, dtype) # basic state velocity at z = H 73 | self.L = np.array(L, dtype) # size of square domain. 74 | # theta0,g only used to convert pv to temp units (K). 75 | self.theta0 = np.array(theta0, dtype) # mean temp 76 | self.g = np.array(g, dtype) # gravity 77 | self.dt = np.array(dt, dtype) # time step (seconds) 78 | self.dealias = dealias # if True, dealiasing applied using 2/3 rule. 79 | if r < 1.0e-10: 80 | self.ekman = False 81 | else: 82 | self.ekman = True 83 | self.r = np.array(r, dtype) # Ekman damping (at z=0) 84 | self.tdiab = np.array(tdiab, dtype) # thermal relaxation damping. 85 | self.t = tstart # initialize time counter 86 | # setup basic state pv (for thermal relaxation) 87 | self.symmetric = symmetric # symmetric jet, or jet with U=0 at sfc. 88 | y = np.arange(0, L, L / N, dtype=dtype) 89 | pvbar = np.zeros((2, N), dtype) 90 | pi = np.array(np.pi, dtype) 91 | l = 2.0 * pi / L 92 | mu = l * np.sqrt(nsq) * H / f 93 | if symmetric: 94 | # symmetric version, no difference between upper and lower 95 | # boundary. 96 | # l = 2.*pi/L and mu = l*N*H/f 97 | # u = -0.5*U*np.sin(l*y)*np.sinh(mu*(z-0.5*H)/H)*np.sin(l*y)/np.sinh(0.5*mu) 98 | # theta = (f*theta0/g)*(0.5*U*mu/(l*H))*np.cosh(mu*(z-0.5*H)/H)* 99 | # np.cos(l*y)/np.sinh(0.5*mu) 100 | # + theta0 + (theta0*nsq*z/g) 101 | pvbar[:] = ( 102 | -(mu * 0.5 * U / (l * H)) 103 | * np.cosh(0.5 * mu) 104 | * np.cos(l * y) 105 | / np.sinh(0.5 * mu) 106 | ) 107 | else: 108 | # asymmetric version, equilibrium state has no flow at surface and 109 | # temp gradient slightly weaker at sfc. 110 | # u = U*np.sin(l*y)*np.sinh(mu*z/H)*np.sin(l*y)/np.sinh(mu) 111 | # theta = (f*theta0/g)*(U*mu/(l*H))*np.cosh(mu*z/H)* 112 | # np.cos(l*y)/np.sinh(mu) 113 | # + theta0 + (theta0*nsq*z/g) 114 | pvbar[:] = -(mu * U / (l * H)) * np.cos(l * y) / np.sinh(mu) 115 | pvbar[1, :] = pvbar[0, :] * np.cosh(mu) 116 | pvbar.shape = (2, N, 1) 117 | pvbar = pvbar * np.ones((2, N, N), dtype) 118 | self.pvbar = pvbar 119 | self.pvspec_eq = rfft2(pvbar) # state to relax to with timescale tdiab 120 | self.pvspec = rfft2(pv) # initial pv field (spectral) 121 | # spectral stuff 122 | k = (N * np.fft.fftfreq(N))[0 : (N // 2) + 1] 123 | l = N * np.fft.fftfreq(N) 124 | kk, ll = np.meshgrid(k, l) 125 | k = kk.astype(dtype) 126 | l = ll.astype(dtype) 127 | # dimensionalize wavenumbers. 128 | k = 2.0 * pi * k / self.L 129 | l = 2.0 * pi * l / self.L 130 | ksqlsq = k ** 2 + l ** 2 131 | self.k = k 132 | self.l = l 133 | self.ksqlsq = ksqlsq 134 | self.ik = (1.0j * k).astype(np.complex64) 135 | self.il = (1.0j * l).astype(np.complex64) 136 | self.wavenums = np.sqrt(kk**2+ll**2) 137 | if dealias: # arrays needed for dealiasing nonlinear Jacobian 138 | k_pad = ((3 * N // 2) * np.fft.fftfreq(3 * N // 2))[0 : (3 * N // 4) + 1] 139 | l_pad = (3 * N // 2) * np.fft.fftfreq(3 * N // 2) 140 | k_pad, l_pad = np.meshgrid(k_pad, l_pad) 141 | k_pad = k_pad.astype(dtype) 142 | l_pad = l_pad.astype(dtype) 143 | k_pad = 2.0 * pi * k_pad / self.L 144 | l_pad = 2.0 * pi * l_pad / self.L 145 | self.ik_pad = (1.0j * k_pad).astype(np.complex64) 146 | self.il_pad = (1.0j * l_pad).astype(np.complex64) 147 | mu = np.sqrt(ksqlsq) * np.sqrt(self.nsq) * self.H / self.f 148 | mu = mu.clip(np.finfo(mu.dtype).eps) # clip to avoid NaN 149 | self.Hovermu = self.H / mu 150 | mu = mu.astype(np.float64) # cast to avoid overflow in sinh 151 | self.tanhmu = np.tanh(mu).astype(dtype) # cast back to original type 152 | self.sinhmu = np.sinh(mu).astype(dtype) 153 | self.diff_order = np.array(diff_order, dtype) # hyperdiffusion order 154 | self.diff_efold = np.array(diff_efold, dtype) # hyperdiff time scale 155 | ktot = np.sqrt(ksqlsq) 156 | ktotcutoff = np.array(pi * N / self.L, dtype) 157 | # integrating factor for hyperdiffusion 158 | # with efolding time scale for diffusion of shortest wave (N/2) 159 | self.hyperdiff = np.exp( 160 | (-self.dt / self.diff_efold) * (ktot / ktotcutoff) ** self.diff_order 161 | ) 162 | # number of timesteps to advance each call to 'advance' method. 163 | self.timesteps = 1 164 | 165 | def invert(self, pvspec=None): 166 | if pvspec is None: 167 | pvspec = self.pvspec 168 | # invert boundary pv to get streamfunction 169 | psispec = np.empty((2, self.N, self.N // 2 + 1), dtype=pvspec.dtype) 170 | psispec[0] = self.Hovermu * ( 171 | (pvspec[1] / self.sinhmu) - (pvspec[0] / self.tanhmu) 172 | ) 173 | psispec[1] = self.Hovermu * ( 174 | (pvspec[1] / self.tanhmu) - (pvspec[0] / self.sinhmu) 175 | ) 176 | return psispec 177 | 178 | def meantemp(self, pvspec=None): 179 | # vertical integral from 0 to H of d/dz(eqn 4) in 180 | # https://doi.org/10.1175/2008JAS2921.1, 181 | # divided by H. 182 | # Since the vertical integral of d/dz(psi eqn) is simply the 183 | # psi eqn (eqn 4) - this boils down to evaluating psi at top 184 | # minus psi at bottom and dividing by H. 185 | # This is not the same as average of temperature at top and bottom! 186 | if pvspec is None: 187 | pvspec = self.pvspec 188 | psispec = self.invert(pvspec=pvspec) 189 | pvavspec = (psispec[1]-psispec[0])/self.H 190 | return pvavspec 191 | 192 | def invert_inverse(self, psispec=None): 193 | if psispec is None: 194 | psispec = self.invert(self.pvspec) 195 | # given streamfunction, return PV 196 | pvspec = np.empty((2, self.N, self.N // 2 + 1), dtype=psispec.dtype) 197 | alpha = self.Hovermu 198 | th = self.tanhmu 199 | sh = self.sinhmu 200 | tmp1 = 1.0 / sh ** 2 - 1.0 / th ** 2 201 | tmp1[0, 0] = 1.0 202 | pvspec[0] = ((psispec[0] / th) - (psispec[1] / sh)) / (alpha * tmp1) 203 | pvspec[1] = ((psispec[0] / sh) - (psispec[1] / th)) / (alpha * tmp1) 204 | pvspec[:, 0, 0] = 0.0 # area mean PV not determined by streamfunction 205 | return pvspec 206 | 207 | def advance(self, pv=None): 208 | # given total pv on grid, advance forward 209 | # number of timesteps given by 'timesteps' instance var. 210 | # if pv not specified, use pvspec instance variable. 211 | if pv is not None: 212 | self.pvspec = rfft2(pv, threads=self.threads) 213 | for n in range(self.timesteps): 214 | self.timestep() 215 | return irfft2(self.pvspec, threads=self.threads) 216 | 217 | def specpad(self, specarr): 218 | # pad spectral arrays with zeros to get 219 | # interpolation to 3/2 larger grid using inverse fft. 220 | # take care of normalization factor for inverse transform. 221 | specarr_pad = np.zeros((2, 3 * self.N // 2, 3 * self.N // 4 + 1), specarr.dtype) 222 | specarr_pad[:, 0 : self.N // 2, 0 : self.N // 2] = ( 223 | 2.25 * specarr[:, 0 : self.N // 2, 0 : self.N // 2] 224 | ) 225 | specarr_pad[:, -self.N // 2 :, 0 : self.N // 2] = ( 226 | 2.25 * specarr[:, -self.N // 2 :, 0 : self.N // 2] 227 | ) 228 | # include negative Nyquist frequency. 229 | specarr_pad[:, 0 : self.N // 2, self.N // 2] = np.conjugate( 230 | 2.25 * specarr[:, 0 : self.N // 2, -1] 231 | ) 232 | specarr_pad[:, -self.N // 2 :, self.N // 2] = np.conjugate( 233 | 2.25 * specarr[:, -self.N // 2 :, -1] 234 | ) 235 | return specarr_pad 236 | 237 | def spectrunc(self, specarr): 238 | # truncate spectral array using 2/3 rule. 239 | specarr_trunc = np.zeros((2, self.N, self.N // 2 + 1), specarr.dtype) 240 | specarr_trunc[:, 0 : self.N // 2, 0 : self.N // 2] = specarr[ 241 | :, 0 : self.N // 2, 0 : self.N // 2 242 | ] 243 | specarr_trunc[:, -self.N // 2 :, 0 : self.N // 2] = specarr[ 244 | :, -self.N // 2 :, 0 : self.N // 2 245 | ] 246 | return specarr_trunc 247 | 248 | def xyderiv(self, specarr): 249 | if not self.dealias: 250 | xderiv = irfft2(self.ik * specarr, threads=self.threads) 251 | yderiv = irfft2(self.il * specarr, threads=self.threads) 252 | else: # pad spectral coeffs with zeros for dealiased jacobian 253 | specarr_pad = self.specpad(specarr) 254 | xderiv = irfft2(self.ik_pad * specarr_pad, threads=self.threads) 255 | yderiv = irfft2(self.il_pad * specarr_pad, threads=self.threads) 256 | return xderiv, yderiv 257 | 258 | def gettend(self, pvspec=None): 259 | # compute tendencies of pv on z=0,H 260 | # invert pv to get streamfunction 261 | if pvspec is None: 262 | pvspec = self.pvspec 263 | psispec = self.invert(pvspec) 264 | # nonlinear jacobian and thermal relaxation 265 | psix, psiy = self.xyderiv(psispec) 266 | pvx, pvy = self.xyderiv(pvspec) 267 | jacobian = psix * pvy - psiy * pvx 268 | jacobianspec = rfft2(jacobian, threads=self.threads) 269 | if self.dealias: # 2/3 rule: truncate spectral coefficients of jacobian 270 | jacobianspec = self.spectrunc(jacobianspec) 271 | dpvspecdt = (1.0 / self.tdiab) * (self.pvspec_eq - pvspec) - jacobianspec 272 | # Ekman damping at boundaries. 273 | if self.ekman: 274 | dpvspecdt[0] += self.r * self.ksqlsq * psispec[0] 275 | # for asymmetric jet (U=0 at sfc), no Ekman layer at lid 276 | if self.symmetric: 277 | dpvspecdt[1] -= self.r * self.ksqlsq * psispec[1] 278 | # save wind field 279 | self.u = -psiy 280 | self.v = psix 281 | return dpvspecdt 282 | 283 | def timestep(self): 284 | # update pv using 4th order runge-kutta time step with 285 | # implicit "integrating factor" treatment of hyperdiffusion. 286 | self.rkstep = 0 287 | k1 = self.dt * self.gettend(self.pvspec) 288 | self.rkstep = 1 289 | k2 = self.dt * self.gettend(self.pvspec + 0.5 * k1) 290 | self.rkstep = 2 291 | k3 = self.dt * self.gettend(self.pvspec + 0.5 * k2) 292 | self.rkstep = 3 293 | k4 = self.dt * self.gettend(self.pvspec + k3) 294 | pvspecnew = self.pvspec + (k1 + 2.0 * k2 + 2.0 * k3 + k4) / 6.0 295 | self.pvspec = self.hyperdiff * pvspecnew 296 | self.t += self.dt # increment time 297 | -------------------------------------------------------------------------------- /sqgturb/enkf_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.linalg import lapack, inv 3 | 4 | # function definitions. 5 | 6 | 7 | def cartdist(x1, y1, x2, y2, xmax, ymax): 8 | """cartesian distance on doubly periodic plane""" 9 | dx = np.abs(x1 - x2) 10 | dy = np.abs(y1 - y2) 11 | dx = np.where(dx > 0.5 * xmax, xmax - dx, dx) 12 | dy = np.where(dy > 0.5 * ymax, ymax - dy, dy) 13 | return np.sqrt(dx ** 2 + dy ** 2) 14 | 15 | 16 | def gaspcohn(r): 17 | """ 18 | Gaspari-Cohn taper function. 19 | very close to exp(-(r/c)**2), where c = sqrt(0.15) 20 | r should be >0 and normalized so taper = 0 at r = 1 21 | """ 22 | rr = 2.0 * r 23 | rr += 1.0e-13 # avoid divide by zero warnings from numpy 24 | taper = np.where( 25 | r <= 0.5, 26 | (((-0.25 * rr + 0.5) * rr + 0.625) * rr - 5.0 / 3.0) * rr ** 2 + 1.0, 27 | np.zeros(r.shape, r.dtype), 28 | ) 29 | taper = np.where( 30 | np.logical_and(r > 0.5, r < 1.0), 31 | ((((rr / 12.0 - 0.5) * rr + 0.625) * rr + 5.0 / 3.0) * rr - 5.0) * rr 32 | + 4.0 33 | - 2.0 / (3.0 * rr), 34 | taper, 35 | ) 36 | return taper 37 | 38 | def lgetkf(xens, hxens, obs, oberrs, covlocal, nerger=True, ngroups=None): 39 | 40 | """returns ensemble updated by LGETKF with cross-validation""" 41 | 42 | hxmean = hxens.mean(axis=0) 43 | hxprime = hxens - hxmean 44 | nanals = hxens.shape[0] 45 | ndim = covlocal.shape[-1] 46 | xmean = xens.mean(axis=0) 47 | xprime = xens - xmean 48 | xprime_b = xprime.copy() 49 | if ngroups is None: # default is "leave one out" (nanals must be multiple of ngroups) 50 | ngroups = nanals 51 | if nanals % ngroups: 52 | raise ValueError('nanals must be a multiple of ngroups') 53 | else: 54 | nanals_per_group = nanals//ngroups 55 | 56 | def getYbvecs(hx, Rlocal, oberrvar, nerger=True): 57 | normfact = np.array(np.sqrt(hx.shape[0]-1),dtype=np.float32) 58 | if nerger: 59 | # Nerger regularization 60 | hpbht = (hx**2).sum(axis=0)/normfact**2 61 | hpbhtplusR = hpbht+oberrvar 62 | Rlocalfact = (Rlocal*oberrvar/hpbhtplusR)/(1.-Rlocal*hpbht/hpbhtplusR) 63 | Rinvsqrt = np.sqrt(Rlocalfact/oberrvar) 64 | YbRinv = hx*Rinvsqrt**2/normfact 65 | YbsqrtRinv = hx*Rinvsqrt/normfact 66 | else: 67 | YbsqrtRinv = hx*np.sqrt(Rlocal/oberrvar)/normfact 68 | YbRinv = hx*(Rlocal/oberrvar)/normfact 69 | return YbsqrtRinv, YbRinv 70 | 71 | def calcwts_mean(nens, hx, Rlocal, oberrvar, ominusf, nerger=True): 72 | # nens is the original (unmodulated) ens size 73 | nobs = hx.shape[1] 74 | normfact = np.array(np.sqrt(nens-1),dtype=np.float32) 75 | # gain-form etkf solution 76 | # HZ^T = hxens * R**-1/2 77 | # compute eigenvectors/eigenvalues of A = HZ^T HZ (C=left SV) 78 | # (in Bishop paper HZ is nobs, nanals, here is it nanals, nobs) 79 | # normalize so dot product is covariance 80 | YbsqrtRinv, YbRinv = getYbvecs(hx,Rlocal,oberrvar,nerger=nerger) 81 | if nobs >= nens: 82 | a = np.dot(YbsqrtRinv,YbsqrtRinv.T) 83 | evals, evecs, info = lapack.dsyevd(a) 84 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 85 | else: 86 | a = np.dot(YbsqrtRinv.T,YbsqrtRinv) 87 | evals, evecs, info = lapack.dsyevd(a) 88 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 89 | evecs = np.dot(YbsqrtRinv,evecs/np.sqrt(evals)) 90 | # gammapI used in calculation of posterior cov in ensemble space 91 | gammapI = evals+1. 92 | # compute factor to multiply with model space ensemble perturbations 93 | # to compute analysis increment (for mean update). 94 | # This is the factor C (Gamma + I)**-1 C^T (HZ)^ T R**-1/2 (y - HXmean) 95 | # in Bishop paper (eqs 10-12). 96 | # pa = C (Gamma + I)**-1 C^T (analysis error cov in ensemble space) 97 | # wts_ensmean = C (Gamma + I)**-1 C^T (HZ)^ T R**-1/2 (y - HXmean) 98 | pa = np.dot(evecs/gammapI[np.newaxis,:],evecs.T) 99 | return np.dot(pa, np.dot(YbRinv,ominusf))/normfact 100 | 101 | def calcwts_perts(nens, hx_orig, hx, Rlocal, oberrvar,nerger=True): 102 | # hx_orig contains the ensemble for the witheld member 103 | # nens is the original (unmodulated) ens size 104 | nobs = hx.shape[1] 105 | normfact = np.array(np.sqrt(nens-1),dtype=np.float32) 106 | # gain-form etkf solution 107 | # HZ^T = hxens * R**-1/2 108 | # compute eigenvectors/eigenvalues of A = HZ^T HZ (C=left SV) 109 | # (in Bishop paper HZ is nobs, nanals, here is it nanals, nobs) 110 | # normalize so dot product is covariance 111 | YbsqrtRinv, YbRinv = getYbvecs(hx,Rlocal,oberrvar,nerger=nerger) 112 | if nobs >= nens: 113 | a = np.dot(YbsqrtRinv,YbsqrtRinv.T) 114 | evals, evecs, info = lapack.dsyevd(a) 115 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 116 | else: 117 | a = np.dot(YbsqrtRinv.T,YbsqrtRinv) 118 | evals, evecs, info = lapack.dsyevd(a) 119 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 120 | evecs = np.dot(YbsqrtRinv,evecs/np.sqrt(evals)) 121 | # gammapI used in calculation of posterior cov in ensemble space 122 | gamma_inv = 1./evals; gammapI = evals+1. 123 | # compute factor to multiply with model space ensemble perturbations 124 | # to compute analysis increment (for perturbation update), save in single precision. 125 | # This is -C [ (I - (Gamma+I)**-1/2)*Gamma**-1 ] C^T (HZ)^T R**-1/2 HXprime 126 | # in Bishop paper (eqn 29). 127 | # wts_ensperts = -C [ (I - (Gamma+I)**-1/2)*Gamma**-1 ] C^T (HZ)^T R**-1/2 HXprime 128 | pasqrt=np.dot(evecs*(1.-np.sqrt(1./gammapI[np.newaxis,:]))*gamma_inv[np.newaxis,:],evecs.T) 129 | return -np.dot(pasqrt, np.dot(YbRinv,hx_orig.T)).T/normfact # use witheld ens member here 130 | 131 | for n in range(ndim): 132 | mask = covlocal[:,n] > 1.0e-10 133 | nobs_local = mask.sum() 134 | if nobs_local > 0: 135 | Rlocal = covlocal[mask, n] 136 | oberrvar_local = oberrs[mask] 137 | ominusf_local = (obs-hxmean)[mask] 138 | hxprime_local = hxprime[:,mask] 139 | wts_ensmean = calcwts_mean(nanals, hxprime_local, Rlocal, oberrvar_local, ominusf_local, nerger=nerger) 140 | for k in range(2): 141 | xmean[k,n] += np.dot(wts_ensmean,xprime_b[:,k,n]) 142 | # update sub-ensemble groups, using cross validation. 143 | for ngrp in range(ngroups): 144 | nanal_cv = [na + ngrp*nanals_per_group for na in range(nanals_per_group)] 145 | hxprime_cv = np.delete(hxprime_local,nanal_cv,axis=0); xprime_cv = np.delete(xprime_b[:,:,n],nanal_cv,axis=0) 146 | wts_ensperts_cv = calcwts_perts(nanals-nanals//ngroups, hxprime_local[nanal_cv], hxprime_cv, Rlocal, oberrvar_local, nerger=nerger) 147 | for k in range(2): 148 | xprime[nanal_cv,k,n] += np.dot(wts_ensperts_cv,xprime_cv[:,k]) 149 | xprime_mean = xprime[:,:,n].mean(axis=0) 150 | xprime[:,:,n] -= xprime_mean # ensure zero mean 151 | xens[:,:,n] = xmean[:,n]+xprime[:,:,n] 152 | 153 | return xens 154 | 155 | def lgetkf_vloc(xens, xens2, hxens, hxens2, obs, oberrs, covlocal, nanal_index, nerger=True, ngroups=None): 156 | 157 | """returns ensemble updated by LGETKF with cross-validation""" 158 | 159 | hxmean = hxens.mean(axis=0) 160 | hxprime = hxens - hxmean 161 | hxprime2 = hxens2 - hxmean # modulated ens 162 | nanals = hxens.shape[0] 163 | nanals2 = hxens2.shape[0] # modulated ens size 164 | ndim = covlocal.shape[-1] 165 | xmean = xens.mean(axis=0) 166 | xprime = xens - xmean 167 | xprime2 = xens2 - xmean # modulated ens 168 | if ngroups is None: # default is "leave one out" (nanals must be multiple of ngroups) 169 | ngroups = nanals 170 | if nanals % ngroups: 171 | raise ValueError('nanals must be a multiple of ngroups') 172 | else: 173 | nanals_per_group = nanals//ngroups 174 | 175 | def getYbvecs(hx, Rlocal, oberrvar, nerger=True): 176 | normfact = np.array(np.sqrt(hx.shape[0]-1),dtype=np.float32) 177 | if nerger: 178 | # Nerger regularization 179 | hpbht = (hx**2).sum(axis=0)/normfact**2 180 | hpbhtplusR = hpbht+oberrvar 181 | Rlocalfact = (Rlocal*oberrvar/hpbhtplusR)/(1.-Rlocal*hpbht/hpbhtplusR) 182 | Rinvsqrt = np.sqrt(Rlocalfact/oberrvar) 183 | YbRinv = hx*Rinvsqrt**2/normfact 184 | YbsqrtRinv = hx*Rinvsqrt/normfact 185 | else: 186 | YbsqrtRinv = hx*np.sqrt(Rlocal/oberrvar)/normfact 187 | YbRinv = hx*(Rlocal/oberrvar)/normfact 188 | return YbsqrtRinv, YbRinv 189 | 190 | def calcwts_mean(nens, hx, Rlocal, oberrvar, ominusf,nerger=True): 191 | # nens is the original (unmodulated) ens size 192 | nobs = hx.shape[1] 193 | normfact = np.array(np.sqrt(nens-1),dtype=np.float32) 194 | # gain-form etkf solution 195 | # HZ^T = hxens * R**-1/2 196 | # compute eigenvectors/eigenvalues of A = HZ^T HZ (C=left SV) 197 | # (in Bishop paper HZ is nobs, nanals, here is it nanals, nobs) 198 | # normalize so dot product is covariance 199 | YbsqrtRinv, YbRinv = getYbvecs(hx,Rlocal,oberrvar,nerger=nerger) 200 | if nobs >= nens: 201 | a = np.dot(YbsqrtRinv,YbsqrtRinv.T) 202 | evals, evecs, info = lapack.dsyevd(a) 203 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 204 | else: 205 | a = np.dot(YbsqrtRinv.T,YbsqrtRinv) 206 | evals, evecs, info = lapack.dsyevd(a) 207 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 208 | evecs = np.dot(YbsqrtRinv,evecs/np.sqrt(evals)) 209 | # gammapI used in calculation of posterior cov in ensemble space 210 | gammapI = evals+1. 211 | # compute factor to multiply with model space ensemble perturbations 212 | # to compute analysis increment (for mean update). 213 | # This is the factor C (Gamma + I)**-1 C^T (HZ)^ T R**-1/2 (y - HXmean) 214 | # in Bishop paper (eqs 10-12). 215 | # pa = C (Gamma + I)**-1 C^T (analysis error cov in ensemble space) 216 | # wts_ensmean = C (Gamma + I)**-1 C^T (HZ)^ T R**-1/2 (y - HXmean) 217 | pa = np.dot(evecs/gammapI[np.newaxis,:],evecs.T) 218 | return np.dot(pa, np.dot(YbRinv,ominusf))/normfact 219 | 220 | def calcwts_perts(nens, hx_orig, hx, Rlocal, oberrvar,nerger=True): 221 | # hx_orig contains the ensemble for the witheld member 222 | # nens is the original (unmodulated) ens size 223 | nobs = hx.shape[1] 224 | normfact = np.array(np.sqrt(nens-1),dtype=np.float32) 225 | # gain-form etkf solution 226 | # HZ^T = hxens * R**-1/2 227 | # compute eigenvectors/eigenvalues of A = HZ^T HZ (C=left SV) 228 | # (in Bishop paper HZ is nobs, nanals, here is it nanals, nobs) 229 | # normalize so dot product is covariance 230 | YbsqrtRinv, YbRinv = getYbvecs(hx,Rlocal,oberrvar,nerger=nerger) 231 | if nobs >= nens: 232 | a = np.dot(YbsqrtRinv,YbsqrtRinv.T) 233 | evals, evecs, info = lapack.dsyevd(a) 234 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 235 | else: 236 | a = np.dot(YbsqrtRinv.T,YbsqrtRinv) 237 | evals, evecs, info = lapack.dsyevd(a) 238 | evals = evals.clip(min=np.finfo(evals.dtype).eps) 239 | evecs = np.dot(YbsqrtRinv,evecs/np.sqrt(evals)) 240 | # gammapI used in calculation of posterior cov in ensemble space 241 | gamma_inv = 1./evals; gammapI = evals+1. 242 | # compute factor to multiply with model space ensemble perturbations 243 | # to compute analysis increment (for perturbation update), save in single precision. 244 | # This is -C [ (I - (Gamma+I)**-1/2)*Gamma**-1 ] C^T (HZ)^T R**-1/2 HXprime 245 | # in Bishop paper (eqn 29). 246 | # wts_ensperts = -C [ (I - (Gamma+I)**-1/2)*Gamma**-1 ] C^T (HZ)^T R**-1/2 HXprime 247 | pasqrt=np.dot(evecs*(1.-np.sqrt(1./gammapI[np.newaxis,:]))*gamma_inv[np.newaxis,:],evecs.T) 248 | return -np.dot(pasqrt, np.dot(YbRinv,hx_orig.T)).T/normfact # use witheld ens member here 249 | 250 | for n in range(ndim): 251 | mask = covlocal[:,n] > 1.0e-10 252 | nobs_local = mask.sum() 253 | if nobs_local > 0: 254 | Rlocal = covlocal[mask, n] 255 | oberrvar_local = oberrs[mask] 256 | ominusf_local = (obs-hxmean)[mask] 257 | hxprime2_local = hxprime2[:,mask] 258 | hxprime_local = hxprime[:,mask] 259 | wts_ensmean = calcwts_mean(nanals, hxprime2_local, Rlocal, oberrvar_local, ominusf_local, nerger=nerger) 260 | for k in range(2): 261 | xmean[k,n] += np.dot(wts_ensmean,xprime2[:,k,n]) 262 | # update sub-ensemble groups, using cross validation. 263 | for ngrp in range(ngroups): 264 | # nanal_index has original ens index for modulated member 265 | nanal_cv = [na + ngrp*nanals_per_group for na in range(nanals_per_group)] 266 | nanals_sub = np.nonzero(np.isin(nanal_index,nanal_cv)) 267 | hxprime_cv = np.delete(hxprime2_local,nanals_sub,axis=0); xprime_cv = np.delete(xprime2[:,:,n],nanals_sub,axis=0) 268 | wts_ensperts_cv = calcwts_perts(nanals-nanals//ngroups, hxprime_local[nanal_cv], hxprime_cv, Rlocal, oberrvar_local, nerger=nerger) 269 | for k in range(2): 270 | xprime[nanal_cv,k,n] += np.dot(wts_ensperts_cv,xprime_cv[:,k]) 271 | xprime_mean = xprime[:,:,n].mean(axis=0) 272 | xprime[:,:,n] -= xprime_mean # ensure zero mean 273 | xens[:,:,n] = xmean[:,n]+xprime[:,:,n] 274 | 275 | return xens 276 | -------------------------------------------------------------------------------- /sqg_lgetkf_cv.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | from netCDF4 import Dataset 6 | import sys, time, os 7 | from sqgturb import SQG, rfft2, irfft2, cartdist, lgetkf, gaspcohn 8 | 9 | # LGETKF cycling for SQG turbulence model with boundary temp obs, 10 | # ob space horizontal localization, no vertical localization. 11 | # cross-validation update (no inflation). 12 | # Random observing network. 13 | 14 | if len(sys.argv) == 1: 15 | msg=""" 16 | python sqg_lgetkf_cv.py hcovlocal_scale covinflate> 17 | hcovlocal_scale = horizontal localization scale in km 18 | """ 19 | raise SystemExit(msg) 20 | 21 | # horizontal covariance localization length scale in meters. 22 | hcovlocal_scale = float(sys.argv[1]) 23 | exptname = os.getenv('exptname','test') 24 | threads = int(os.getenv('OMP_NUM_THREADS','1')) 25 | 26 | diff_efold = None # use diffusion from climo file 27 | 28 | profile = False # turn on profiling? 29 | 30 | read_restart = False 31 | # if savedata not None, netcdf filename will be defined by env var 'exptname' 32 | # if savedata = 'restart', only last time is saved (so expt can be restarted) 33 | #savedata = True 34 | #savedata = 'restart' 35 | savedata = None 36 | #nassim = 101 37 | #nassim_spinup = 100 38 | nassim = 600 # assimilation times to run 39 | nassim_spinup = 100 40 | 41 | nanals = 20 # ensemble members 42 | nerger = True # use Nerger regularization for R localization 43 | ngroups = nanals # number of groups for cross-validation (ngroups=nanals is "leave one out") 44 | 45 | oberrstdev = 1. # ob error standard deviation in K 46 | 47 | # nature run created using sqg_run.py. 48 | filename_climo = 'sqgu20_N96_6hrly.nc' # file name for forecast model climo 49 | # perfect model 50 | filename_truth = 'sqgu20_N96_6hrly.nc' # file name for nature run to draw obs 51 | #filename_truth = 'sqg_N256_N96_12hrly.nc' # file name for nature run to draw obs 52 | 53 | print('# filename_modelclimo=%s' % filename_climo) 54 | print('# filename_truth=%s' % filename_truth) 55 | 56 | # fix random seed for reproducibility. 57 | rsobs = np.random.RandomState(42) # fixed seed for observations 58 | #rsics = np.random.RandomState() # varying seed for initial conditions 59 | rsics = np.random.RandomState(24) # fixed seed for initial conditions 60 | 61 | # get model info 62 | nc_climo = Dataset(filename_climo) 63 | # parameter used to scale PV to temperature units. 64 | scalefact = nc_climo.f*nc_climo.theta0/nc_climo.g 65 | # initialize qg model instances for each ensemble member. 66 | x = nc_climo.variables['x'][:] 67 | y = nc_climo.variables['y'][:] 68 | x, y = np.meshgrid(x, y) 69 | nx = len(x); ny = len(y) 70 | dt = nc_climo.dt 71 | if diff_efold == None: diff_efold=nc_climo.diff_efold 72 | pvens = np.empty((nanals,2,ny,nx),np.float32) 73 | if not read_restart: 74 | pv_climo = nc_climo.variables['pv'] 75 | indxran = rsics.choice(pv_climo.shape[0],size=nanals,replace=False) 76 | else: 77 | ncinit = Dataset('%s_restart.nc' % exptname, mode='r', format='NETCDF4_CLASSIC') 78 | ncinit.set_auto_mask(False) 79 | pvens[:] = ncinit.variables['pv_b'][-1,...]/scalefact 80 | tstart = ncinit.variables['t'][-1] 81 | #for nanal in range(nanals): 82 | # print(nanal, pvens[nanal].min(), pvens[nanal].max()) 83 | # get OMP_NUM_THREADS (threads to use) from environment. 84 | models = [] 85 | for nanal in range(nanals): 86 | if not read_restart: 87 | pvens[nanal] = pv_climo[indxran[nanal]] 88 | #print(nanal, pvens[nanal].min(), pvens[nanal].max()) 89 | models.append(\ 90 | SQG(pvens[nanal], 91 | nsq=nc_climo.nsq,f=nc_climo.f,dt=dt,U=nc_climo.U,H=nc_climo.H,\ 92 | r=nc_climo.r,tdiab=nc_climo.tdiab,symmetric=nc_climo.symmetric,\ 93 | diff_order=nc_climo.diff_order,diff_efold=diff_efold,threads=threads)) 94 | if read_restart: ncinit.close() 95 | 96 | hcovlocal_km = int(hcovlocal_scale/1000.) 97 | print("# hcovlocal=%g diff_efold=%s nanals=%s ngroups=%s" %\ 98 | (hcovlocal_km,diff_efold,nanals,ngroups)) 99 | 100 | # each ob time nobs ob locations are randomly sampled (without 101 | # replacement) from the model grid 102 | #nobs = nx*ny//6 # number of obs to assimilate (randomly distributed) 103 | nobs = 1024 104 | 105 | # nature run 106 | nc_truth = Dataset(filename_truth) 107 | pv_truth = nc_truth.variables['pv'] 108 | # set up arrays for obs and localization function 109 | print('# random network nobs = %s' % nobs) 110 | 111 | oberrvar = oberrstdev**2*np.ones(nobs,np.float32) 112 | covlocal = np.empty((ny,nx),np.float32) 113 | covlocal_tmp = np.empty((nobs,nx*ny),np.float32) 114 | 115 | obtimes = nc_truth.variables['t'][:] 116 | if read_restart: 117 | timeslist = obtimes.tolist() 118 | ntstart = timeslist.index(tstart) 119 | print('# restarting from %s.nc ntstart = %s' % (exptname,ntstart)) 120 | else: 121 | ntstart = 0 122 | assim_interval = obtimes[1]-obtimes[0] 123 | assim_timesteps = int(np.round(assim_interval/models[0].dt)) 124 | print('# assim interval = %s secs (%s time steps)' % (assim_interval,assim_timesteps)) 125 | print('# ntime,pverr_a,pvsprd_a,pverr_b,pvsprd_b,obfits_b,osprd_b+R,obbias_b,tr(P^a)/tr(P^b)') 126 | 127 | # initialize model clock 128 | for nanal in range(nanals): 129 | models[nanal].t = obtimes[ntstart] 130 | models[nanal].timesteps = assim_timesteps 131 | 132 | # initialize output file. 133 | if savedata is not None: 134 | nc = Dataset('%s.nc' % exptname, mode='w', format='NETCDF4_CLASSIC') 135 | nc.r = models[0].r 136 | nc.f = models[0].f 137 | nc.U = models[0].U 138 | nc.L = models[0].L 139 | nc.H = models[0].H 140 | nc.nanals = nanals 141 | nc.hcovlocal_scale = hcovlocal_scale 142 | nc.oberrstdev = oberrstdev 143 | nc.g = nc_climo.g; nc.theta0 = nc_climo.theta0 144 | nc.nsq = models[0].nsq 145 | nc.tdiab = models[0].tdiab 146 | nc.dt = models[0].dt 147 | nc.diff_efold = models[0].diff_efold 148 | nc.diff_order = models[0].diff_order 149 | nc.filename_climo = filename_climo 150 | nc.filename_truth = filename_truth 151 | nc.symmetric = models[0].symmetric 152 | xdim = nc.createDimension('x',models[0].N) 153 | ydim = nc.createDimension('y',models[0].N) 154 | z = nc.createDimension('z',2) 155 | t = nc.createDimension('t',None) 156 | obs = nc.createDimension('obs',nobs) 157 | ens = nc.createDimension('ens',nanals) 158 | pv_t =\ 159 | nc.createVariable('pv_t',np.float32,('t','z','y','x'),zlib=True) 160 | pv_b =\ 161 | nc.createVariable('pv_b',np.float32,('t','ens','z','y','x'),zlib=True) 162 | pv_a =\ 163 | nc.createVariable('pv_a',np.float32,('t','ens','z','y','x'),zlib=True) 164 | pv_a.units = 'K' 165 | pv_b.units = 'K' 166 | pv_obs = nc.createVariable('obs',np.float32,('t','z','obs')) 167 | x_obs = nc.createVariable('x_obs',np.float32,('t','obs')) 168 | y_obs = nc.createVariable('y_obs',np.float32,('t','obs')) 169 | # eady pv scaled by g/(f*theta0) so du/dz = d(pv)/dy 170 | xvar = nc.createVariable('x',np.float32,('x',)) 171 | xvar.units = 'meters' 172 | yvar = nc.createVariable('y',np.float32,('y',)) 173 | yvar.units = 'meters' 174 | zvar = nc.createVariable('z',np.float32,('z',)) 175 | zvar.units = 'meters' 176 | tvar = nc.createVariable('t',np.float32,('t',)) 177 | tvar.units = 'seconds' 178 | ensvar = nc.createVariable('ens',np.int32,('ens',)) 179 | ensvar.units = 'dimensionless' 180 | xvar[:] = np.arange(0,models[0].L,models[0].L/models[0].N) 181 | yvar[:] = np.arange(0,models[0].L,models[0].L/models[0].N) 182 | zvar[0] = 0; zvar[1] = models[0].H 183 | ensvar[:] = np.arange(1,nanals+1) 184 | 185 | # initialize kinetic energy error/spread spectra 186 | pvspec_errmean = None; pvspec_sprdmean = None 187 | 188 | ncount = 0 189 | 190 | N = models[0].N 191 | k = np.abs((N*np.fft.fftfreq(N))[0:(N//2)+1]) 192 | l = N*np.fft.fftfreq(N) 193 | imax = len(k); jmax = len(l) 194 | k,l = np.meshgrid(k,l) 195 | ktotsq = (k**2+l**2).astype(np.int32) 196 | jmax,imax = ktotsq.shape 197 | ktot = np.sqrt(ktotsq) 198 | ktotmax = (N//2)+1 199 | 200 | for ntime in range(nassim): 201 | 202 | # check model clock 203 | if models[0].t != obtimes[ntime+ntstart]: 204 | raise ValueError('model/ob time mismatch %s vs %s' %\ 205 | (models[0].t, obtimes[ntime+ntstart])) 206 | 207 | t1 = time.time() 208 | indxob = np.sort(rsobs.choice(2*nx*ny,nobs,replace=False)) 209 | pvob = scalefact*pv_truth[ntime+ntstart,...].reshape(2*nx*ny)[indxob] 210 | pvob += rsobs.normal(scale=oberrstdev,size=nobs) # add ob errors 211 | xob = np.concatenate((x.ravel(),x.ravel()))[indxob] 212 | yob = np.concatenate((y.ravel(),y.ravel()))[indxob] 213 | # compute covariance localization function for each ob 214 | for nob in range(nobs): 215 | dist = cartdist(xob[nob],yob[nob],x,y,nc_climo.L,nc_climo.L) 216 | covlocal = gaspcohn(dist/hcovlocal_scale) 217 | covlocal_tmp[nob] = covlocal.ravel() 218 | dist = cartdist(xob[nob],yob[nob],xob,yob,nc_climo.L,nc_climo.L) 219 | 220 | # first-guess spread 221 | pvensmean = pvens.mean(axis=0) 222 | pvprime = pvens - pvensmean 223 | 224 | fsprd = (pvprime**2).sum(axis=0)/(nanals-1) 225 | 226 | # compute forward operator on modulated ensemble. 227 | # hxens is ensemble in observation space. 228 | hxens = np.empty((nanals,nobs),np.float32) 229 | 230 | for nanal in range(nanals): 231 | hxens[nanal] = scalefact*pvens[nanal,...].reshape(2*nx*ny)[indxob] # surface pv obs 232 | hxensmean_b = hxens.mean(axis=0) 233 | obsprd = ((hxens-hxensmean_b)**2).sum(axis=0)/(nanals-1) 234 | # innov stats for background 235 | obfits = pvob - hxensmean_b 236 | obfits_b = (obfits**2).mean() 237 | obbias_b = obfits.mean() 238 | obsprd_b = obsprd.mean() 239 | pvensmean_b = pvens.mean(axis=0).copy() 240 | pverr_b = (scalefact*(pvensmean_b-pv_truth[ntime+ntstart]))**2 241 | pvsprd_b = ((scalefact*(pvensmean_b-pvens))**2).sum(axis=0)/(nanals-1) 242 | 243 | if savedata is not None: 244 | if savedata == 'restart' and ntime != nassim-1: 245 | pass 246 | else: 247 | pv_t[ntime] = pv_truth[ntime+ntstart] 248 | pv_b[ntime,:,:,:] = scalefact*pvens 249 | pv_obs[ntime] = pvob 250 | x_obs[ntime] = xob 251 | y_obs[ntime] = yob 252 | 253 | # EnKF update 254 | # create 1d state vector. 255 | xens = pvens.reshape(nanals,2,nx*ny) 256 | 257 | # update state vector. 258 | 259 | # hxens,pvob are in PV units, xens is not 260 | xens = lgetkf(xens,hxens,pvob,oberrvar,covlocal_tmp,nerger=nerger,ngroups=ngroups) 261 | 262 | # back to 3d state vector 263 | pvens = xens.reshape((nanals,2,ny,nx)) 264 | t2 = time.time() 265 | if profile: print('cpu time for EnKF update',t2-t1) 266 | 267 | pvensmean_a = pvens.mean(axis=0) 268 | pvprime = pvens-pvensmean_a 269 | asprd = (pvprime**2).sum(axis=0)/(nanals-1) 270 | asprd_over_fsprd = asprd.mean()/fsprd.mean() 271 | 272 | # print out analysis error, spread and innov stats for background 273 | pverr_a = (scalefact*(pvensmean_a-pv_truth[ntime+ntstart]))**2 274 | pvsprd_a = ((scalefact*(pvensmean_a-pvens))**2).sum(axis=0)/(nanals-1) 275 | print("%s %g %g %g %g %g %g %g %g" %\ 276 | (ntime+ntstart,np.sqrt(pverr_a.mean()),np.sqrt(pvsprd_a.mean()),\ 277 | np.sqrt(pverr_b.mean()),np.sqrt(pvsprd_b.mean()),\ 278 | np.sqrt(obfits_b),np.sqrt(obsprd_b+oberrstdev**2),obbias_b, 279 | asprd_over_fsprd)) 280 | 281 | # save data. 282 | if savedata is not None: 283 | if savedata == 'restart' and ntime != nassim-1: 284 | pass 285 | else: 286 | pv_a[ntime,:,:,:] = scalefact*pvens 287 | tvar[ntime] = obtimes[ntime+ntstart] 288 | nc.sync() 289 | 290 | # run forecast ensemble to next analysis time 291 | t1 = time.time() 292 | for nanal in range(nanals): 293 | pvens[nanal] = models[nanal].advance(pvens[nanal]) 294 | t2 = time.time() 295 | if profile: print('cpu time for ens forecast',t2-t1) 296 | if not np.all(np.isfinite(pvens)): 297 | raise SystemExit('non-finite values detected after forecast, stopping...') 298 | 299 | # compute spectra of error and spread 300 | if ntime >= nassim_spinup: 301 | pvfcstmean = pvens.mean(axis=0) 302 | pverrspec = scalefact*rfft2(pvfcstmean - pv_truth[ntime+ntstart+1]) 303 | pverrspec_mag = (pverrspec*np.conjugate(pverrspec)).real 304 | if pvspec_errmean is None: 305 | pvspec_errmean = pverrspec_mag 306 | else: 307 | pvspec_errmean = pvspec_errmean + pverrspec_mag 308 | for nanal in range(nanals): 309 | pvpertspec = scalefact*rfft2(pvens[nanal] - pvfcstmean) 310 | pvpertspec_mag = (pvpertspec*np.conjugate(pvpertspec)).real/(nanals-1) 311 | if pvspec_sprdmean is None: 312 | pvspec_sprdmean = pvpertspec_mag 313 | else: 314 | pvspec_sprdmean = pvspec_sprdmean+pvpertspec_mag 315 | ncount += 1 316 | 317 | if savedata: nc.close() 318 | 319 | if ncount: 320 | pvspec_sprdmean = pvspec_sprdmean/ncount 321 | pvspec_errmean = pvspec_errmean/ncount 322 | pvspec_err = np.zeros(ktotmax,np.float32) 323 | pvspec_sprd = np.zeros(ktotmax,np.float32) 324 | for i in range(pvspec_errmean.shape[2]): 325 | for j in range(pvspec_errmean.shape[1]): 326 | totwavenum = int(np.round(ktot[j,i])) 327 | if totwavenum < ktotmax: 328 | pvspec_err[totwavenum] = pvspec_err[totwavenum] +\ 329 | pvspec_errmean[:,j,i].mean(axis=0) # average of upper/lower boundary 330 | pvspec_sprd[totwavenum] = pvspec_sprd[totwavenum] +\ 331 | pvspec_sprdmean[:,j,i].mean(axis=0) 332 | 333 | print('# mean error/spread',pvspec_errmean.sum(), pvspec_sprdmean.sum()) 334 | plt.figure() 335 | wavenums = np.arange(ktotmax,dtype=np.float32) 336 | for n in range(1,ktotmax): 337 | print('# ',wavenums[n],pvspec_err[n],pvspec_sprd[n]) 338 | plt.loglog(wavenums[1:-1],pvspec_err[1:-1],color='r') 339 | plt.loglog(wavenums[1:-1],pvspec_sprd[1:-1],color='b') 340 | plt.title('error (red) and spread (blue) l=%s' % hcovlocal_km) 341 | plt.savefig('errorspread_spectra_cv_%s.png' % exptname) 342 | --------------------------------------------------------------------------------