├── .gitignore ├── ExportSimpleLineBlobForTraining.py ├── Inn2.py ├── LICENSE ├── MiniBalancedTraining.pickle ├── README.md ├── Radynversion.ipynb ├── __init__.py ├── loss.py ├── single_pixel_inversion_example.ipynb ├── utils.py └── z.h5 /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /ExportSimpleLineBlobForTraining.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['CDF_LIB'] = '/usr/local/cdf/lib' 3 | import radynpy.RadynCdfLoader as RadynCdfLoader 4 | import numpy as np 5 | from scipy.signal import savgol_filter 6 | from scipy.interpolate import interp1d 7 | import matplotlib.pyplot as plt 8 | import torch 9 | import pickle 10 | from tqdm import tqdm 11 | from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed 12 | 13 | # We should possibly do the extrapolation against cell centres rather than 14 | # interface locations, but what I can see of the default radyn analysis scripts 15 | # is that they use the interface data everywhere. Probably doesn't really 16 | # matter for the shape of curve. Also, there can't be k interaface heights and 17 | # cell centred heights... 18 | tags = ['alamb', 'q', 'nq', 'outint', 'cont', 'time', 'tg1', 'ne1', 'zmu', 'z1', 'vz1', 'n1', 'jrad', 'irad'] 19 | 20 | folder = '/local1/scratch/RadynGrid/' 21 | files = [folder + x for x in os.listdir(folder) if x.startswith('radyn_out.val3c') and not x.endswith('em')] 22 | outputFolder = '/local0/scratch/HAlphaGridExportStatic/' 23 | 24 | start = -6.5e6 25 | expTrans = 3.5e8 26 | stop = 1.05e9 27 | # linStep = 2e5 28 | # expNum = 500 29 | # linStep = 10e5 30 | totalNum = 88 31 | expNum = totalNum // 10 32 | linNum = totalNum - expNum 33 | expNum += 1 # We have to remove the first entry of the exponential grid, because it's the same as the last of the linear grid, so add an extra point to account for this 34 | # staticAltitudeGrid = np.concatenate((np.arange(start, expTrans, linStep), np.logspace(np.log10(expTrans), np.log10(stop), num=expNum))).astype(np.float32) 35 | staticAltitudeGrid = np.concatenate((np.linspace(start, expTrans, num=linNum), np.logspace(np.log10(expTrans), np.log10(stop), num=expNum)[1:])).astype(np.float32) 36 | # staticAltitudeGrid = savgol_filter(staticAltitudeGrid, 7, 3) 37 | print('Interpolating onto %d spatial points' % len(staticAltitudeGrid)) 38 | 39 | def line_intensity_with_cont(data, kr, muIdx): 40 | if not not data.cont[kr]: 41 | print('line_intensity cannot compute bf intensity') 42 | return 43 | 44 | # Made some changes to the q indexing here, because it didn't seem quite right to me 45 | # perhaps that's idl slicing 46 | # yup, it includes the last element, and python doesn't 47 | # 1e5 is a conversion between km/s (qnorm) and cm/s (cc) 48 | wl = data.alamb[kr] / (data.q[0:data.nq[kr], kr] * data.qnorm * 1e5 / data.cc + 1) 49 | # The 1e8 in here comes from the conversion from angstrom to cm, but the square 50 | # on wavelength comes from computing the derivative of nu*lambda = c to get the 51 | # Jacobian 52 | # The second part of the parens is the continuum here, normally negligible -- especially for Halpha 53 | intens = (data.outint[:, 1:data.nq[kr]+1, muIdx, kr] + data.outint[:, 0, muIdx, kr][:, np.newaxis]) * data.cc * 1e8 / (wl**2)[np.newaxis, :] 54 | # wl is retruned in angstrom, intens in erg/cm^2/sr/A/s 55 | return wl[::-1], intens[:,::-1] 56 | 57 | def line_dict(data, export, lineIdx): 58 | 59 | line = export['lineInfo'][lineIdx] 60 | kr = line['kr'] 61 | iel = line['iel'] 62 | halfWidth = line['halfWidth'] 63 | jIdx = data.jrad[kr] - 1 64 | iIdx = data.irad[kr] - 1 65 | 66 | lineCentre = data.alamb[kr] 67 | muIdx = 4 68 | 69 | # mu = 4 # for straight on 70 | # mu = 0 # for limb 71 | trimmedLines = [] 72 | for m in range(len(data.zmu)): 73 | wl, intens = line_intensity_with_cont(data, kr, m) 74 | lowIdx = np.searchsorted(wl, lineCentre - halfWidth, side='right')-1 75 | highIdx = np.searchsorted(wl, lineCentre + halfWidth, side='left')+1 76 | trimmedLines.append(intens[:,lowIdx:highIdx]) 77 | 78 | # These statements assume that we will go through the lineIdx in ascending 79 | # order, but, at the end of the day, we will. 80 | if len(export['wavelength']) <= lineIdx: 81 | export['wavelength'].append(torch.from_numpy(wl[lowIdx:highIdx].copy())) 82 | if len(export['mu']) <= lineIdx: 83 | export['mu'].append(data.zmu[muIdx]) 84 | 85 | 86 | staticTemp = np.zeros_like(staticAltitudeGrid).astype(np.float32) 87 | staticNe = np.zeros_like(staticAltitudeGrid).astype(np.float32) 88 | staticVel = np.zeros_like(staticAltitudeGrid).astype(np.float32) 89 | 90 | if lineIdx == 0: 91 | export['nTime'].append(data.tg1.shape[0]) 92 | export['beamSpectralIndex'].append(data.beamSpectralIndex) 93 | export['totalBeamEnergy'].append(data.totalBeamEnergy) 94 | export['beamPulseType'].append(data.beamPlulseType) 95 | export['cutoffEnergy'].append(data.cutoffEnergy) 96 | 97 | for t in range(export['nTime'][-1]): 98 | export['line'][lineIdx].append(torch.from_numpy(trimmedLines[muIdx][t].copy())) 99 | 100 | z = data.z1[t, ::-1] 101 | 102 | interp_static = lambda param: interp1d(z, param, assume_sorted=True)(staticAltitudeGrid) 103 | 104 | upperLevel = data.n1[t, ::-1, jIdx, iel] 105 | staticUpperLevel = interp_static(upperLevel) 106 | 107 | lowerLevel = data.n1[t, ::-1, iIdx, iel] 108 | staticLowerLevel = interp_static(lowerLevel) 109 | export['upperLevelPop'][lineIdx].append(torch.from_numpy(staticUpperLevel.copy()).float()) 110 | export['lowerLevelPop'][lineIdx].append(torch.from_numpy(staticLowerLevel.copy()).float()) 111 | 112 | if lineIdx == 0: 113 | temp = data.tg1[t, ::-1] 114 | staticTemp = interp_static(temp) 115 | 116 | ne = data.ne1[t, ::-1] 117 | staticNe = interp_static(ne) 118 | 119 | vel = data.vz1[t, ::-1] 120 | staticVel = interp_static(vel) 121 | 122 | nhG = data.n1[t, ::-1, 0, 0] 123 | staticNhG = interp_static(nhG) 124 | 125 | nhi = data.n1[t, ::-1, :5, 0].sum(axis=1) 126 | staticNhi = interp_static(nhi) 127 | 128 | nhii = data.n1[t, ::-1, 5, 0] 129 | staticNhii = interp_static(nhii) 130 | 131 | export['temperature'].append(torch.from_numpy(staticTemp.copy()).float()) 132 | export['ne'].append(torch.from_numpy(staticNe.copy()).float()) 133 | export['vel'].append(torch.from_numpy(staticVel.copy()).float()) 134 | export['nhGround'].append(torch.from_numpy(staticNhG.copy()).float()) 135 | export['nhi'].append(torch.from_numpy(staticNhi.copy()).float()) 136 | export['nhii'].append(torch.from_numpy(staticNhii.copy()).float()) 137 | return export 138 | 139 | export = {} 140 | export['lineInfo'] = [{'line': 'H_alpha', 'kr': 4, 'iel': 0, 'halfWidth': 1.4}, {'line': 'CaII k', 'kr': 20, 'iel': 1, 'halfWidth': 1.0}] 141 | export['line'] = [[] for l in range(len(export['lineInfo']))] 142 | export['temperature'] = [] 143 | export['ne'] = [] 144 | export['nhGround'] = [] 145 | export['nhi'] = [] 146 | export['nhii'] = [] 147 | export['upperLevelPop'] = [[] for l in range(len(export['lineInfo']))] 148 | export['lowerLevelPop'] = [[] for l in range(len(export['lineInfo']))] 149 | export['vel'] = [] 150 | export['nTime'] = [] 151 | export['beamSpectralIndex'] = [] 152 | export['totalBeamEnergy'] = [] 153 | export['beamPulseType'] = [] 154 | export['cutoffEnergy'] = [] 155 | export['mu'] = [] 156 | export['wavelength'] = [] 157 | export['z'] = torch.from_numpy(staticAltitudeGrid.copy()) 158 | 159 | # for i, f in enumerate(tqdm(files)): 160 | # data = RadynCdfLoader.load_vars(f, tags) 161 | 162 | # for lineIdx in range(len(export['lineInfo'])): 163 | # export = line_dict(data, export, lineIdx) 164 | 165 | def async_load_radyn(f, tags): 166 | return RadynCdfLoader.load_vars(f, tags) 167 | 168 | # SpacePy's CDF loader doesn't seem like multiple threads at the same time and our Radyn object doesn't pickle nicely, 169 | # but we can still load the next while we process the current one 170 | with ThreadPoolExecutor(max_workers=1) as executor: 171 | radynFiles = [executor.submit(async_load_radyn, f, tags) for f in files] 172 | 173 | for res in tqdm(as_completed(radynFiles)): 174 | data = res.result() 175 | for lineIdx in range(len(export['lineInfo'])): 176 | export = line_dict(data, export, lineIdx) 177 | 178 | with open(outputFolder + 'DoublePicoGigaPickle%d.pickle' % totalNum, 'wb') as p: 179 | pickle.dump(export, p) 180 | 181 | -------------------------------------------------------------------------------- /Inn2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | from FrEIA.framework import InputNode, OutputNode, Node, ReversibleGraphNet 6 | from FrEIA.modules import rev_multiplicative_layer, permute_layer 7 | 8 | from .loss import mse, mse_tv, mmd_multiscale_on 9 | 10 | from scipy.interpolate import interp1d 11 | 12 | from copy import deepcopy 13 | from itertools import accumulate 14 | import pickle 15 | 16 | PadOp = '!!PAD' 17 | ZeroPadOp = '!!ZeroPadding' 18 | 19 | def schema_min_len(schema, zeroPadding): 20 | length = sum(s[1] if s[0] != PadOp else 0 for s in schema) \ 21 | + zeroPadding * (len([s for s in schema if s[0] != PadOp]) - 1) 22 | return length 23 | 24 | class DataSchema1D: 25 | def __init__(self, inp, minLength, zeroPadding, zero_pad_fn=torch.zeros): 26 | self.zero_pad = zero_pad_fn 27 | # Check schema is valid 28 | padCount = sum(1 if i[0] == PadOp else 0 for i in inp) 29 | for i in range(len(inp)-1): 30 | if inp[i][0] == PadOp and inp[i+1][0] == PadOp: 31 | raise ValueError('Schema cannot contain two consecutive \'!!PAD\' instructions.') 32 | # if padCount > 1: 33 | # raise ValueError('Schema can only contain one \'!!PAD\' instruction.') 34 | if len([i for i in inp if i[0] != PadOp]) > len(set([i[0] for i in inp if i[0] != PadOp])): 35 | raise ValueError('Schema names must be unique within a schema.') 36 | 37 | # Find length without extra padding (beyond normal channel separation) 38 | length = schema_min_len(inp, zeroPadding) 39 | if (minLength - length) // padCount != (minLength - length) / padCount: 40 | raise ValueError('Schema padding isn\'t divisible by number of PadOps') 41 | 42 | # Build schema 43 | schema = [] 44 | padding = (ZeroPadOp, zeroPadding) 45 | for j, i in enumerate(inp): 46 | if i[0] == PadOp: 47 | if j == len(inp) - 1: 48 | # Count the edge case where '!!PAD' is the last op and a spurious 49 | # extra padding gets inserted before it 50 | if schema[-1] == padding: 51 | del schema[-1] 52 | 53 | if length < minLength: 54 | schema.append((ZeroPadOp, (minLength - length) // padCount)) 55 | continue 56 | 57 | schema.append(i) 58 | if j != len(inp) - 1: 59 | schema.append(padding) 60 | 61 | if padCount == 0 and length < minLength: 62 | schema.append((ZeroPadOp, minLength - length)) 63 | 64 | # Fuse adjacent zero padding -- no rational way to have more than two in a row 65 | fusedSchema = [] 66 | i = 0 67 | while True: 68 | if i >= len(schema): 69 | break 70 | 71 | if i < len(schema) - 1 and schema[i][0] == ZeroPadOp and schema[i+1][0] == ZeroPadOp: 72 | fusedSchema.append((ZeroPadOp, schema[i][1] + schema[i+1][1])) 73 | i += 1 74 | else: 75 | fusedSchema.append(schema[i]) 76 | i += 1 77 | # Also remove 0-width ZeroPadding 78 | fusedSchema = [s for s in fusedSchema if s != (ZeroPadOp, 0)] 79 | self.schema = fusedSchema 80 | schemaTags = [s[0] for s in self.schema if s[0] != ZeroPadOp] 81 | tagIndices = [0] + list(accumulate([s[1] for s in self.schema])) 82 | tagRange = [(s[0], range(tagIndices[i], tagIndices[i+1])) for i, s in enumerate(self.schema) if s[0] != ZeroPadOp] 83 | for name, r in tagRange: 84 | setattr(self, name, r) 85 | self.len = tagIndices[-1] 86 | 87 | def __len__(self): 88 | return self.len 89 | 90 | def fill(self, entries, zero_pad_fn=None, batchSize=None, checkBounds=False, dev='cpu'): 91 | # Try and infer batchSize 92 | if batchSize is None: 93 | for k, v in entries.items(): 94 | if not callable(v): 95 | batchSize = v.shape[0] 96 | break 97 | else: 98 | raise ValueError('Unable to infer batchSize from entries (all fns?). Set batchSize manually.') 99 | 100 | if checkBounds: 101 | try: 102 | for s in self.schema: 103 | if s[0] == ZeroPadOp: 104 | continue 105 | entry = entries[s[0]] 106 | if not callable(entry): 107 | if len(entry.shape) != 2: 108 | raise ValueError('Entry: %s must be a 2D array or fn.' % s[0]) 109 | if entry.shape[0] != batchSize: 110 | raise ValueError('Entry: %s does not match batchSize along dim=0.' % s[0]) 111 | if entry.shape[1] != s[1]: 112 | raise ValueError('Entry: %s does not match schema dimension.' % s[0]) 113 | except KeyError as e: 114 | raise ValueError('No key present in entries to schema: ' + repr(e)) 115 | 116 | # Use different zero_pad if specified 117 | if zero_pad_fn is None: 118 | zero_pad_fn = self.zero_pad 119 | 120 | # Fill in the schema, throw exception if entry is missing 121 | reifiedSchema = [] 122 | try: 123 | for s in self.schema: 124 | if s[0] == ZeroPadOp: 125 | reifiedSchema.append(zero_pad_fn(batchSize, s[1])) 126 | else: 127 | entry = entries[s[0]] 128 | if callable(entry): 129 | reifiedSchema.append(entry(batchSize, s[1])) 130 | else: 131 | reifiedSchema.append(entry) 132 | except KeyError as e: 133 | raise ValueError('No key present in entries to schema: ' + repr(e)) 134 | 135 | reifiedSchema = torch.cat(reifiedSchema, dim=1) 136 | return reifiedSchema 137 | 138 | def __repr__(self): 139 | return repr(self.schema) 140 | 141 | class F_fully_connected_leaky(nn.Module): 142 | '''Fully connected tranformation, not reversible, but used below.''' 143 | 144 | def __init__(self, size_in, size, internal_size=None, dropout=0.0, 145 | batch_norm=False, leaky_slope=0.01): 146 | super(F_fully_connected_leaky, self).__init__() 147 | if not internal_size: 148 | internal_size = 2*size 149 | 150 | self.d1 = nn.Dropout(p=dropout) 151 | self.d2 = nn.Dropout(p=dropout) 152 | self.d2b = nn.Dropout(p=dropout) 153 | 154 | self.fc1 = nn.Linear(size_in, internal_size) 155 | self.fc2 = nn.Linear(internal_size, internal_size) 156 | self.fc2b = nn.Linear(internal_size, internal_size) 157 | # self.fc2c = nn.Linear(internal_size, internal_size) 158 | self.fc2d = nn.Linear(internal_size, internal_size) 159 | self.fc3 = nn.Linear(internal_size, size) 160 | 161 | self.nl1 = nn.LeakyReLU(negative_slope=leaky_slope) 162 | self.nl2 = nn.LeakyReLU(negative_slope=leaky_slope) 163 | self.nl2b = nn.LeakyReLU(negative_slope=leaky_slope) 164 | # self.nl2c = nn.LeakyReLU(negative_slope=leaky_slope) 165 | self.nl2d = nn.ReLU() 166 | 167 | if batch_norm: 168 | self.bn1 = nn.BatchNorm1d(internal_size) 169 | self.bn1.weight.data.fill_(1) 170 | self.bn2 = nn.BatchNorm1d(internal_size) 171 | self.bn2.weight.data.fill_(1) 172 | self.bn2b = nn.BatchNorm1d(internal_size) 173 | self.bn2b.weight.data.fill_(1) 174 | self.batch_norm = batch_norm 175 | 176 | def forward(self, x): 177 | out = self.fc1(x) 178 | if self.batch_norm: 179 | out = self.bn1(out) 180 | out = self.nl1(self.d1(out)) 181 | 182 | out = self.fc2(out) 183 | if self.batch_norm: 184 | out = self.bn2(out) 185 | out = self.nl2(self.d2(out)) 186 | 187 | out = self.fc2b(out) 188 | if self.batch_norm: 189 | out = self.bn2b(out) 190 | out = self.nl2b(self.d2b(out)) 191 | 192 | # out = self.fc2c(out) 193 | # out = self.nl2c(out) 194 | 195 | out = self.fc2d(out) 196 | out = self.nl2d(out) 197 | 198 | out = self.fc3(out) 199 | return out 200 | 201 | class RadynversionNet(ReversibleGraphNet): 202 | def __init__(self, inputs, outputs, zeroPadding=0, numInvLayers=5, dropout=0.00, minSize=None): 203 | # Determine dimensions and construct DataSchema 204 | inMinLength = schema_min_len(inputs, zeroPadding) 205 | outMinLength = schema_min_len(outputs, zeroPadding) 206 | minLength = max(inMinLength, outMinLength) 207 | if minSize is not None: 208 | minLength = max(minLength, minSize) 209 | self.inSchema = DataSchema1D(inputs, minLength, zeroPadding) 210 | self.outSchema = DataSchema1D(outputs, minLength, zeroPadding) 211 | if len(self.inSchema) != len(self.outSchema): 212 | raise ValueError('Input and output schemas do not have the same dimension.') 213 | 214 | # Build net graph 215 | inp = InputNode(len(self.inSchema), name='Input (0-pad extra channels)') 216 | nodes = [inp] 217 | 218 | for i in range(numInvLayers): 219 | nodes.append(Node([nodes[-1].out0], rev_multiplicative_layer, 220 | {'F_class': F_fully_connected_leaky, 'clamp': 2.0, 221 | 'F_args': {'dropout': 0.0}}, name='Inv%d' % i)) 222 | if (i != numInvLayers - 1): 223 | nodes.append(Node([nodes[-1].out0], permute_layer, {'seed': i}, name='Permute%d' % i)) 224 | 225 | nodes.append(OutputNode([nodes[-1].out0], name='Output')) 226 | # Build net 227 | super().__init__(nodes) 228 | 229 | 230 | class RadynversionTrainer: 231 | def __init__(self, model, atmosData, dev): 232 | self.model = model 233 | self.atmosData = atmosData 234 | self.dev = dev 235 | self.mmFns = None 236 | 237 | for mod_list in model.children(): 238 | for block in mod_list.children(): 239 | for coeff in block.children(): 240 | coeff.fc3.weight.data = 1e-3*torch.randn(coeff.fc3.weight.shape) 241 | # coeff.fc3.weight.data = 1e-2*torch.randn(coeff.fc3.weight.shape) 242 | 243 | self.model.to(dev) 244 | 245 | def training_params(self, numEpochs, lr=2e-3, miniBatchesPerEpoch=20, metaEpoch=12, miniBatchSize=None, 246 | l2Reg=2e-5, wPred=1500, wLatent=300, wRev=500, zerosNoiseScale=5e-3, fadeIn=True, 247 | loss_fit=mse, loss_latent=None, loss_backward=None): 248 | if miniBatchSize is None: 249 | miniBatchSize = self.atmosData.batchSize 250 | 251 | if loss_latent is None: 252 | loss_latent = mmd_multiscale_on(self.dev) 253 | 254 | if loss_backward is None: 255 | loss_backward = mmd_multiscale_on(self.dev) 256 | 257 | decayEpochs = (numEpochs * miniBatchesPerEpoch) // metaEpoch 258 | gamma = 0.004**(1.0 / decayEpochs) 259 | 260 | # self.optim = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.8, 0.8), 261 | # eps=1e-06, weight_decay=l2Reg) 262 | self.optim = torch.optim.Adam(self.model.parameters(), lr=lr, betas=(0.8, 0.8), 263 | eps=1e-06, weight_decay=l2Reg) 264 | self.scheduler = torch.optim.lr_scheduler.StepLR(self.optim, 265 | step_size=metaEpoch, 266 | gamma=gamma) 267 | self.wPred = wPred 268 | self.fadeIn = fadeIn 269 | self.wLatent = wLatent 270 | self.wRev = wRev 271 | self.zerosNoiseScale = zerosNoiseScale 272 | self.miniBatchSize = miniBatchSize 273 | self.miniBatchesPerEpoch = miniBatchesPerEpoch 274 | self.numEpochs = numEpochs 275 | self.loss_fit = loss_fit 276 | self.loss_latent = loss_latent 277 | self.loss_backward = loss_backward 278 | 279 | def train(self, epoch): 280 | self.model.train() 281 | 282 | lTot = 0 283 | miniBatchIdx = 0 284 | if self.fadeIn: 285 | wRevScale = min(epoch / (0.4 * self.numEpochs), 1)**3 286 | else: 287 | wRevScale = 1.0 288 | noiseScale = (1.0 - wRevScale) * self.zerosNoiseScale 289 | # noiseScale = self.zerosNoiseScale 290 | 291 | pad_fn = lambda *x: noiseScale * torch.randn(*x, device=self.dev) #+ 10 * torch.ones(*x, device=self.dev) 292 | # zeros = lambda *x: torch.zeros(*x, device=self.dev) 293 | randn = lambda *x: torch.randn(*x, device=self.dev) 294 | losses = [0, 0, 0, 0] 295 | 296 | for x, y in self.atmosData.trainLoader: 297 | miniBatchIdx += 1 298 | 299 | if miniBatchIdx > self.miniBatchesPerEpoch: 300 | break 301 | 302 | x, y = x.to(self.dev), y.to(self.dev) 303 | yClean = y.clone() 304 | 305 | xp = self.model.inSchema.fill({'ne': x[:, 0], 306 | 'temperature': x[:, 1], 307 | 'vel': x[:, 2]}, 308 | zero_pad_fn=pad_fn) 309 | yzp = self.model.outSchema.fill({'Halpha': y[:, 0], 310 | 'Ca8542': y[:, 1], 311 | 'LatentSpace': randn}, 312 | zero_pad_fn=pad_fn) 313 | 314 | self.optim.zero_grad() 315 | 316 | out = self.model(xp) 317 | 318 | # lForward = self.wPred * (self.loss_fit(y[:, 0], out[:, self.model.outSchema.Halpha]) + 319 | # self.loss_fit(y[:, 1], out[:, self.model.outSchema.Ca8542])) 320 | # lForward = self.wPred * self.loss_fit(yzp[:, :self.model.outSchema.LatentSpace[0]], out[:, :self.model.outSchema.LatentSpace[0]]) 321 | 322 | lForward = self.wPred * self.loss_fit(yzp[:, self.model.outSchema.LatentSpace[-1]+1:], 323 | out[:, self.model.outSchema.LatentSpace[-1]+1:]) 324 | losses[0] += lForward.data.item() / self.wPred 325 | 326 | 327 | outLatentGradOnly = torch.cat((out[:, self.model.outSchema.Halpha].data, 328 | out[:, self.model.outSchema.Ca8542].data, 329 | out[:, self.model.outSchema.LatentSpace]), 330 | dim=1) 331 | unpaddedTarget = torch.cat((yzp[:, self.model.outSchema.Halpha], 332 | yzp[:, self.model.outSchema.Ca8542], 333 | yzp[:, self.model.outSchema.LatentSpace]), 334 | dim=1) 335 | 336 | lForward2 = self.wLatent * self.loss_latent(outLatentGradOnly, unpaddedTarget) 337 | losses[1] += lForward2.data.item() / self.wLatent 338 | lForward += lForward2 339 | 340 | lTot += lForward.data.item() 341 | 342 | lForward.backward() 343 | 344 | yzpRev = self.model.outSchema.fill({'Halpha': yClean[:, 0], 345 | 'Ca8542': yClean[:, 1], 346 | 'LatentSpace': out[:, self.model.outSchema.LatentSpace].data}, 347 | zero_pad_fn=pad_fn) 348 | yzpRevRand = self.model.outSchema.fill({'Halpha': yClean[:, 0], 349 | 'Ca8542': yClean[:, 1], 350 | 'LatentSpace': randn}, 351 | zero_pad_fn=pad_fn) 352 | 353 | outRev = self.model(yzpRev, rev=True) 354 | outRevRand = self.model(yzpRevRand, rev=True) 355 | 356 | # THis guy should have been OUTREVRAND!!! 357 | # xBack = torch.cat((outRevRand[:, self.model.inSchema.ne], 358 | # outRevRand[:, self.model.inSchema.temperature], 359 | # outRevRand[:, self.model.inSchema.vel]), 360 | # dim=1) 361 | # lBackward = self.wRev * wRevScale * self.loss_backward(xBack, x.reshape(self.miniBatchSize, -1)) 362 | lBackward = self.wRev * wRevScale * self.loss_backward(outRevRand[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1], 363 | xp[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1]) 364 | 365 | scale = wRevScale if wRevScale != 0 else 1.0 366 | losses[2] += lBackward.data.item() / (self.wRev * scale) 367 | lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev, xp) 368 | # lBackward2 = 0.5 * self.wPred * self.loss_fit(outRev[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1], 369 | # xp[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1]) 370 | losses[3] += lBackward2.data.item() / self.wPred * 2 371 | lBackward += lBackward2 372 | 373 | lTot += lBackward.data.item() 374 | 375 | lBackward.backward() 376 | 377 | for p in self.model.parameters(): 378 | p.grad.data.clamp_(-15.0, 15.0) 379 | 380 | self.optim.step() 381 | 382 | losses = [l / miniBatchIdx for l in losses] 383 | return lTot / miniBatchIdx, losses 384 | 385 | def test(self, maxBatches=10): 386 | self.model.eval() 387 | 388 | forwardError = [] 389 | backwardError = [] 390 | 391 | batchIdx = 0 392 | 393 | if maxBatches == -1: 394 | maxBatches = len(self.atmosData.testLoader) 395 | 396 | pad_fn = lambda *x: torch.zeros(*x, device=self.dev) # 10 * torch.ones(*x, device=self.dev) 397 | randn = lambda *x: torch.randn(*x, device=self.dev) 398 | with torch.no_grad(): 399 | for x, y in self.atmosData.testLoader: 400 | batchIdx += 1 401 | if batchIdx > maxBatches: 402 | break 403 | 404 | x, y = x.to(self.dev), y.to(self.dev) 405 | 406 | inp = self.model.inSchema.fill({'ne': x[:, 0], 407 | 'temperature': x[:, 1], 408 | 'vel': x[:, 2]}, 409 | zero_pad_fn=pad_fn) 410 | inpBack = self.model.outSchema.fill({'Halpha': y[:, 0], 411 | 'Ca8542': y[:, 1], 412 | 'LatentSpace': randn}, 413 | zero_pad_fn=pad_fn) 414 | 415 | out = self.model(inp) 416 | f = self.loss_fit(out[:, self.model.outSchema.Halpha], y[:, 0]) + \ 417 | self.loss_fit(out[:, self.model.outSchema.Ca8542], y[:, 1]) 418 | forwardError.append(f) 419 | 420 | outBack = self.model(inpBack, rev=True) 421 | # b = self.loss_fit(out[:, self.model.inSchema.ne], x[:, 0]) + \ 422 | # self.loss_fit(out[:, self.model.inSchema.temperature], x[:, 1]) + \ 423 | # self.loss_fit(out[:, self.model.inSchema.vel], x[:, 2]) 424 | b = self.loss_backward(outBack, inp) 425 | backwardError.append(b) 426 | 427 | fE = torch.mean(torch.stack(forwardError)) 428 | bE = torch.mean(torch.stack(backwardError)) 429 | 430 | return fE, bE, out, outBack 431 | 432 | def review_mmd(self): 433 | with torch.no_grad(): 434 | # Latent MMD 435 | loadIter = iter(self.atmosData.testLoader) 436 | # This is fine and doesn't load the first batch in testLoader every time, as shuffle=True 437 | x1, y1 = next(loadIter) 438 | x1, y1 = x1.to(self.dev), y1.to(self.dev) 439 | pad_fn = lambda *x: torch.zeros(*x, device=self.dev) # 10 * torch.ones(*x, device=self.dev) 440 | randn = lambda *x: torch.randn(*x, device=self.dev) 441 | xp = self.model.inSchema.fill({'ne': x1[:, 0], 442 | 'temperature': x1[:, 1], 443 | 'vel': x1[:, 2]}, 444 | zero_pad_fn=pad_fn) 445 | yp = self.model.outSchema.fill({'Halpha': y1[:, 0], 446 | 'Ca8542': y1[:, 1], 447 | 'LatentSpace': randn}, 448 | zero_pad_fn=pad_fn) 449 | yFor = self.model(xp) 450 | yForNp = torch.cat((yFor[:, self.model.outSchema.Halpha], yFor[:, self.model.outSchema.Ca8542], yFor[:, self.model.outSchema.LatentSpace]), dim=1).to(self.dev) 451 | ynp = torch.cat((yp[:, self.model.outSchema.Halpha], yp[:, self.model.outSchema.Ca8542], yp[:, self.model.outSchema.LatentSpace]), dim=1).to(self.dev) 452 | 453 | # Backward MMD 454 | xBack = self.model(yp, rev=True) 455 | 456 | r = np.logspace(np.log10(0.5), np.log10(500), num=2000) 457 | mmdValsFor = [] 458 | mmdValsBack = [] 459 | if self.mmFns is None: 460 | self.mmFns = [] 461 | for a in r: 462 | mm = mmd_multiscale_on(self.dev, alphas=[float(a)]) 463 | self.mmFns.append(mm) 464 | 465 | for mm in self.mmFns: 466 | mmdValsFor.append(mm(yForNp, ynp).item()) 467 | mmdValsBack.append(mm(xp[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1], xBack[:, self.model.inSchema.ne[0]:self.model.inSchema.vel[-1]+1]).item()) 468 | 469 | 470 | def find_new_mmd_idx(a): 471 | aRev = a[::-1] 472 | for i, v in enumerate(a[-2::-1]): 473 | if v < aRev[i]: 474 | return min(len(a)-i, len(a)-1) 475 | mmdValsFor = np.array(mmdValsFor) 476 | mmdValsBack = np.array(mmdValsBack) 477 | idxFor = find_new_mmd_idx(mmdValsFor) 478 | idxBack = find_new_mmd_idx(mmdValsBack) 479 | # idxFor = np.searchsorted(r, 2.0) if idxFor is None else idxFor 480 | # idxBack = np.searchsorted(r, 2.0) if idxBack is None else idxBack 481 | idxFor = idxFor if not idxFor is None else np.searchsorted(r, 2.0) 482 | idxBack = idxBack if not idxBack is None else np.searchsorted(r, 2.0) 483 | 484 | self.loss_backward = mmd_multiscale_on(self.dev, alphas=[float(r[idxBack])]) 485 | self.loss_latent = mmd_multiscale_on(self.dev, alphas=[float(r[idxFor])]) 486 | 487 | return r, mmdValsFor, mmdValsBack, idxFor, idxBack 488 | 489 | 490 | class AtmosData: 491 | def __init__(self, dataLocations, resampleWl='ProfileLength'): 492 | if type(dataLocations) is str: 493 | dataLocations = [dataLocations] 494 | 495 | with open(dataLocations[0], 'rb') as p: 496 | data = pickle.load(p) 497 | 498 | if len(dataLocations) > 1: 499 | for dataLocation in dataLocations[1:]: 500 | with open(dataLocation, 'rb') as p: 501 | d = pickle.load(p) 502 | 503 | for k in data.keys(): 504 | if k == 'wavelength' or k == 'z' or k == 'lineInfo': 505 | continue 506 | if k == 'line': 507 | for i in range(len(data['line'])): 508 | data[k][i] += d[k][i] 509 | else: 510 | try: 511 | data[k] += d[k] 512 | except KeyError: 513 | pass 514 | 515 | self.temperature = torch.stack(data['temperature']).float().log10_() 516 | self.ne = torch.stack(data['ne']).float().log10_() 517 | vel = torch.stack(data['vel']).float() / 1e5 518 | velSign = vel / vel.abs() 519 | velSign[velSign != velSign] = 0 520 | self.vel = velSign * (vel.abs() + 1).log10() 521 | 522 | if resampleWl == 'ProfileLength': 523 | resampleWl = self.ne.shape[1] 524 | 525 | wls = [wl.float() for wl in data['wavelength']] 526 | 527 | if resampleWl is not None: 528 | wlResample = [torch.from_numpy(np.linspace(torch.min(wl), torch.max(wl), num=resampleWl, dtype=np.float32)) for wl in wls] 529 | lineResample = [] 530 | for lineIdx in range(len(data['lineInfo'])): 531 | lineProfile = [] 532 | for line in data['line'][lineIdx]: 533 | interp = interp1d(wls[lineIdx], line, assume_sorted=True, kind='cubic') 534 | lineProfile.append(torch.from_numpy(interp(wlResample[lineIdx])).float()) 535 | lineResample.append(lineProfile) 536 | 537 | lines = [torch.stack(l).float() for l in lineResample] 538 | else: 539 | wlResample = wls 540 | lines = [torch.stack(data['line'][idx]).float() for idx in range(len(wls))] 541 | 542 | self.wls = wlResample 543 | self.lines = lines 544 | 545 | # use the [0] the chuck the index vector away 546 | lineMaxs = [torch.max(l, 1, keepdim=True)[0] for l in self.lines] 547 | lineMaxs = torch.cat(lineMaxs, dim=1) 548 | lineMaxs = torch.max(lineMaxs, 1, keepdim=True)[0] 549 | 550 | self.lines = [l / lineMaxs for l in self.lines] 551 | # self.lines = [l / torch.max(l, 1, keepdim=True)[0] for l in self.lines] 552 | self.z = data['z'].float() 553 | 554 | def split_data_and_init_loaders(self, batchSize, splitSeed=41, padLines=False, linePadValue='Edge', zeroPadding=0, testingFraction=0.2): 555 | self.atmosIn = torch.stack([self.ne, self.temperature, self.vel]).permute(1, 0, 2) 556 | self.batchSize = batchSize 557 | 558 | if padLines and linePadValue == 'Edge': 559 | lPad0Size = (self.ne.shape[1] - self.lines[0].shape[1]) // 2 560 | rPad0Size = self.ne.shape[1] - self.lines[0].shape[1] - lPad0Size 561 | lPad1Size = (self.ne.shape[1] - self.lines[1].shape[1]) // 2 562 | rPad1Size = self.ne.shape[1] - self.lines[1].shape[1] - lPad1Size 563 | if any(np.array([lPad0Size, rPad0Size, lPad1Size, rPad1Size]) <= 0): 564 | raise ValueError('Cannot pad lines as they are already bigger than/same size as the profiles!') 565 | lPad0 = torch.ones(self.lines[0].shape[0], lPad0Size) * self.lines[0][:, 0].unsqueeze(1) 566 | rPad0 = torch.ones(self.lines[0].shape[0], rPad0Size) * self.lines[0][:, -1].unsqueeze(1) 567 | lPad1 = torch.ones(self.lines[1].shape[0], lPad1Size) * self.lines[1][:, 0].unsqueeze(1) 568 | rPad1 = torch.ones(self.lines[1].shape[0], rPad1Size) * self.lines[1][:, -1].unsqueeze(1) 569 | 570 | self.lineOut = torch.stack([torch.cat((lPad0, self.lines[0], rPad0), dim=1), torch.cat((lPad1, self.lines[1], rPad1), dim=1)]).permute(1, 0, 2) 571 | elif padLines: 572 | lPad0Size = (self.ne.shape[1] - self.lines[0].shape[1]) // 2 573 | rPad0Size = self.ne.shape[1] - self.lines[0].shape[1] - lPad0Size 574 | lPad1Size = (self.ne.shape[1] - self.lines[1].shape[1]) // 2 575 | rPad1Size = self.ne.shape[1] - self.lines[1].shape[1] - lPad1Size 576 | if any(np.array([lPad0Size, rPad0Size, lPad1Size, rPad1Size]) <= 0): 577 | raise ValueError('Cannot pad lines as they are already bigger than/same size as the profiles!') 578 | lPad0 = torch.ones(self.lines[0].shape[0], lPad0Size) * linePadValue 579 | rPad0 = torch.ones(self.lines[0].shape[0], rPad0Size) * linePadValue 580 | lPad1 = torch.ones(self.lines[1].shape[0], lPad1Size) * linePadValue 581 | rPad1 = torch.ones(self.lines[1].shape[0], rPad1Size) * linePadValue 582 | 583 | self.lineOut = torch.stack([torch.cat((lPad0, self.lines[0], rPad0), dim=1), torch.cat((lPad1, self.lines[1], rPad1), dim=1)]).permute(1, 0, 2) 584 | else: 585 | self.lineOut = torch.stack([self.lines[0], self.lines[1]]).permute(1, 0, 2) 586 | 587 | indices = np.arange(self.atmosIn.shape[0]) 588 | np.random.RandomState(seed=splitSeed).shuffle(indices) 589 | 590 | # split off 20% for testing 591 | maxIdx = int(self.atmosIn.shape[0] * (1.0 - testingFraction)) + 1 592 | if zeroPadding != 0: 593 | trainIn = torch.cat((self.atmosIn[indices][:maxIdx], torch.zeros(maxIdx, self.atmosIn.shape[1], zeroPadding)), dim=2) 594 | trainOut = torch.cat((self.lineOut[indices][:maxIdx], torch.zeros(maxIdx, self.lineOut.shape[1], zeroPadding)), dim=2) 595 | testIn = torch.cat((self.atmosIn[indices][maxIdx:], torch.zeros(self.atmosIn.shape[0] - maxIdx, self.atmosIn.shape[1], zeroPadding)), dim=2) 596 | testOut = torch.cat((self.lineOut[indices][maxIdx:], torch.zeros(self.atmosIn.shape[0] - maxIdx, self.lineOut.shape[1], zeroPadding)), dim=2) 597 | else: 598 | trainIn = self.atmosIn[indices][:maxIdx] 599 | trainOut = self.lineOut[indices][:maxIdx] 600 | testIn = self.atmosIn[indices][maxIdx:] 601 | testOut = self.lineOut[indices][maxIdx:] 602 | 603 | self.testLoader = torch.utils.data.DataLoader( 604 | torch.utils.data.TensorDataset(testIn, testOut), 605 | batch_size=batchSize, shuffle=True, drop_last=True) 606 | self.trainLoader = torch.utils.data.DataLoader( 607 | torch.utils.data.TensorDataset(trainIn, trainOut), 608 | batch_size=batchSize, shuffle=True, drop_last=True) 609 | 610 | 611 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 C. M. J. Osborne, J. A. Armstrong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MiniBalancedTraining.pickle: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Goobley/Radynversion/f44edc77b6eb7ef2bdbd8e8aabda3bf9822d3695/MiniBalancedTraining.pickle -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Radynversion: Learning to Invert a Solar Flare Atmosphere with Invertible Neural Networks 2 | 3 | This repository contains the code for Radynversion, an Invertible Neural Network (INN) based tool that infers solar atmospheric properties during a solar flare, based on a model trained on RADYN simulations (references not limited to Carlsson & Stein 1992, 1995, 1999, and many more..., Allred et al. 2005, 2015). RADYN is a 1D non-equilibrium radiation hydrodynamic model with good optically thick radiation treatment (under the assumption of complete redistribution). It does not consider magnetic effects. 4 | 5 | Our INN is trained to learn an approximate bijective mapping between the atmospheric properties of electron density, temperature, and bulk velocity (all as a function of altitude), and the observed Hα and Ca II λ8542 line profiles. As information is lost in the forward process of radiation transfer, this information is injected back into the model during the inverse process by means of a latent space. Thanks to the training, this latent space can now be filled using an n-dimensional unit Gaussian distribution, where n is the dimensionality of the latent space. 6 | 7 | The bijectivity of the model is assured by the building blocks of the INN, the affine coupling layers. By splitting the data into two streams, these blocks combine four arbitrarily complex non-invertible functions and apply these to the input data in a reversible manner. 8 | 9 | This processs, and its validation are described in the paper: [Osborne, Armstrong, and Fletcher (2019)](https://doi.org/10.3847/1538-4357/ab07b4). The code associated with this paper lives in this repository, but the Radynversion tool will in time be merged into the 10 | [RadynPy](https://github.com/Goobley/radynpy) python module. 11 | 12 | For an example of the model in action the `single_pixel_inversion_example.ipynb` notebook is the recommended place to start. It uses a library of functions defined in `utils.py` You will also need the model weights, which are available on the Github releases page for this project. 13 | To look at the training of the model, the reader is directed to `Radynversion.ipynb`, which calls functions from `Inn2.py` and `Loss.py`. To train your own variant of the model you can use our data extracted from the F-CHROMA RADYN simulations grid, also available on the releases page (the _ridiculously_ named `DoublePicoGigaPickle50.pickle`) or you can generate your own from Radyn simulations via `ExportSimpleLineBlobForTraining.py`. At the very least, the paths in this last script will need modifying for your system and simulation set. 14 | 15 | The two main notebooks specify their required packages. The combined requirements are: 16 | - `Python 3` 17 | - `numpy` 18 | - `scipy` 19 | - `matplotlib` 20 | - `pytorch` (currently `0.4.1`, but should also be compatible with `1.0`, though I have yet to check). 21 | - `astropy` 22 | - `scikit-image` 23 | - `palettable` (optional, only required for colourmaps, but no fail-safes in the code if not preset). 24 | - `crisPy` ([available here](https://github.com/rhero12/crisPy)) 25 | - `FrEIA` ([available here](https://github.com/VLL-HD/FrEIA)) 26 | - `RadynPy` ([available here](https://github.com/Goobley/radynpy), needed for loading RADYN outputs, so essential for making your own training set, not currently required otherwise, though Radynversion will eventually be accessible as a RadynPy module). 27 | 28 | Some of these packages will also have their own requirements, but your package manager should hopefully be able to figure most of that out! 29 | 30 | Developed by Chris Osborne & John Armstrong, University of Glasgow, Astronomy and Astrophysics (2018-2019). MIT License. 31 | Please drop us an email with comments, suggestions etc. Contact address `c.osborne.1 [at] research [dot] gla [dot] ac [dot] uk`. 32 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .Inn2 import * 2 | from .loss import * 3 | from .utils import * 4 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse(inp, target): 4 | return torch.mean((inp - target)**2) 5 | 6 | def tv_chan(inp, target): 7 | l = inp.shape[-1]-1 8 | tvIn = torch.sum(torch.abs(inp[:, :, 1:] - inp[:, :, :-1]), dim=2) 9 | tvTarget = torch.sum(torch.abs(target[:, :, 1:] - target[:, :, :-1]), dim=2) 10 | return torch.mean(torch.abs(tvTarget - tvIn)) / l 11 | 12 | def tv_no_chan(inp, target): 13 | l = inp.shape[-1]-1 14 | tvIn = torch.sum(torch.abs(inp[:, 1:] - inp[:, :-1]), dim=1) 15 | tvTarget = torch.sum(torch.abs(target[:, 1:] - target[:, :-1]), dim=1) 16 | return torch.mean(torch.abs(tvTarget - tvIn)) / l 17 | 18 | def mse_tv(inp, target): 19 | return 0.9 * mse(inp, target) + 0.1 * tv_no_chan(inp, target) 20 | 21 | 22 | # def mmd_multiscale_on(dev): 23 | # def mmd_multiscale(x, y): 24 | # xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t()) 25 | 26 | # rx = (xx.diag().unsqueeze(0).expand_as(xx)) 27 | # ry = (yy.diag().unsqueeze(0).expand_as(yy)) 28 | 29 | # dxx = rx.t() + rx - 2.*xx 30 | # dyy = ry.t() + ry - 2.*yy 31 | # dxy = rx.t() + ry - 2.*zz 32 | 33 | # XX, YY, XY = (torch.zeros(xx.shape).to(dev), 34 | 35 | # torch.zeros(xx.shape).to(dev), 36 | # torch.zeros(xx.shape).to(dev)) 37 | 38 | # for a in [0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0]: 39 | # # for a in [0.05, 0.125, 0.225, 0.325]: 40 | # XX += a**2 * (a**2 + dxx)**-1 41 | # YY += a**2 * (a**2 + dyy)**-1 42 | # XY += a**2 * (a**2 + dxy)**-1 43 | 44 | # return torch.mean(XX + YY - 2.*XY) 45 | # return mmd_multiscale 46 | 47 | def mmd_multiscale_on(dev, alphas=None): 48 | if alphas is None: 49 | alphas = [0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0] 50 | 51 | def mmd_multiscale(x, y): 52 | xx, yy, zz = torch.mm(x,x.t()), torch.mm(y,y.t()), torch.mm(x,y.t()) 53 | 54 | rx = (xx.diag().unsqueeze(0).expand_as(xx)) 55 | ry = (yy.diag().unsqueeze(0).expand_as(yy)) 56 | 57 | dxx = rx.t() + rx - 2.*xx 58 | dyy = ry.t() + ry - 2.*yy 59 | dxy = rx.t() + ry - 2.*zz 60 | 61 | XX, YY, XY = (torch.zeros(xx.shape).to(dev), 62 | 63 | torch.zeros(xx.shape).to(dev), 64 | torch.zeros(xx.shape).to(dev)) 65 | 66 | # for a in [0.2, 0.5, 0.9, 1.3, 2.4, 5.0, 10.0, 20.0, 40.0]: 67 | # for a in [0.05, 0.125, 0.225, 0.325]: 68 | for a in alphas: 69 | XX += a**2 * (a**2 + dxx)**-1 70 | YY += a**2 * (a**2 + dyy)**-1 71 | XY += a**2 * (a**2 + dxy)**-1 72 | 73 | return torch.mean(XX + YY - 2.*XY) 74 | return mmd_multiscale -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from .Inn2 import RadynversionNet 2 | import os 3 | import numpy as np 4 | from scipy.interpolate import interp1d 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import matplotlib as mpl 8 | from matplotlib.colors import PowerNorm,LinearSegmentedColormap 9 | import matplotlib.ticker 10 | 11 | __all__ = ["create_model","obs_files","interp_to_radyn_grid","normalise","inversion","inversion_plots","integrated_intensity","intensity_ratio","doppler_vel","lambda_0","variance","wing_idxs","oom_formatter","delta_lambda","lambda_0_wing","interp_fine"] 12 | 13 | def create_model(filename,dev): 14 | ''' 15 | A function to load the model to perform inversions on unseen data. This function also loads the height profile and wavelength grids from RADYN. 16 | 17 | Paramters 18 | --------- 19 | filename : str 20 | The path to the checkpoint file. 21 | dev : torch.device 22 | The hardware device to pass the model onto. 23 | 24 | Returns 25 | ------- 26 | model : RadynversionNet 27 | The model with the loaded trained weights ready to do testing. 28 | checkpoint["z"] : torch.Tensor 29 | The height profile from the RADYN grid. 30 | ''' 31 | 32 | if os.path.isfile(filename): 33 | print("=> loading checkpoint '%s'" % filename) 34 | checkpoint = torch.load(filename,map_location=dev) 35 | model = RadynversionNet(inputs=checkpoint["inRepr"],outputs=checkpoint["outRepr"],minSize=384).to(dev) 36 | model.load_state_dict(checkpoint["state_dict"]) 37 | print("=> loaded checkpoint '%s' (total number of epochs trained for %d)" % (filename,checkpoint["epoch"])) 38 | return model, checkpoint["z"] 39 | else: 40 | print("=> no checkpoint found at '%s'" % filename) 41 | 42 | def obs_files(path): 43 | ''' 44 | A function to return a list of the files of the observations. 45 | 46 | Parameters 47 | ---------- 48 | path : str 49 | The path to the observations. 50 | 51 | Returns 52 | ------- 53 | : list 54 | The list of the paths to all of the observation files. 55 | ''' 56 | 57 | return sorted([path + f for f in os.listdir(path) if f.endswith(".fits") and not f.startswith(".")]) 58 | 59 | def interp_to_radyn_grid(intensity_vector,centre_wvl,hw,wvl_range): 60 | ''' 61 | A function to linearly interpolate the observational line profiles to the number of wavelength points in the RADYN grid. 62 | 63 | Parameters 64 | ---------- 65 | intensity_vector : numpy.ndarray 66 | The intensity vector from a pixel in the CRISP image. 67 | centre_wvl : float 68 | The central measured wavelength obtained from the TWAVE1 keyword in the observartion's FITS header. 69 | hw : float 70 | The half-width of the line on the RADYN grid. 71 | wvl_range : numpy.ndarray 72 | The wavelength range from the observations. 73 | 74 | Returns 75 | ------- 76 | : list 77 | A list of the interpolated wavelengths and intensities. Each element of the list is a numpy.ndarray. 78 | ''' 79 | 80 | wvl_vector = np.linspace(centre_wvl-hw,centre_wvl+hw,num=30) 81 | interp = interp1d(wvl_range,intensity_vector,kind="linear") 82 | 83 | return [wvl_vector,interp(wvl_vector)] 84 | 85 | def normalise(new_ca,new_ha): 86 | ''' 87 | A function to normalise the spectral line profiles as the RADYN grid works on normalised profiles. 88 | 89 | Parameters 90 | ---------- 91 | new_ca : numpy.ndarray 92 | The new calcium line interpolated onto the RADYN grid. 93 | new_ha : numpy.ndarray 94 | The new hydrogen line interpolated onto the RADYN grid. 95 | 96 | Returns 97 | ------- 98 | new_ca : numpy.ndarray 99 | The interpolated calcium line normalised. 100 | new_ha : numpy.ndarray 101 | The interpolated hydrogen line normalised. 102 | ''' 103 | 104 | peak_emission = max(np.amax(new_ca[1]),np.amax(new_ha[1])) 105 | 106 | new_ca[1] /= peak_emission 107 | new_ha[1] /= peak_emission 108 | 109 | return new_ca, new_ha 110 | 111 | def inverse_velocity_conversion(out_velocities): 112 | ''' 113 | A function to convert the calculated inverse velocities from the smooth space to the actual space. 114 | 115 | Parameters 116 | ---------- 117 | out_velocities : torch.Tensor 118 | The velocity profiles obtained from the inversion. 119 | 120 | Returns 121 | ------- 122 | : torch.Tensor 123 | The velocity profiles converted back to the actual space. 124 | ''' 125 | 126 | v_sign = out_velocities / torch.abs(out_velocities) 127 | v_sign[torch.isnan(v_sign)] = 0 128 | 129 | return v_sign * (10**torch.abs(out_velocities) - 1.0) 130 | 131 | def inversion(model,dev,ca_data,ha_data,batch_size): 132 | ''' 133 | A function which performs the inversions on the spectral line profiles. 134 | 135 | Parameters 136 | ---------- 137 | model : RadynversionNet 138 | The trained inversion model. 139 | dev : torch.device 140 | The hardware device to pass the model onto. 141 | ca_data : list 142 | A concatenated list of the calcium wavelengths and intensities. 143 | ha_data : list 144 | A concatenated list of the hydrogen wavelengths and intensities. 145 | batch_size : int 146 | The number of samples to take from the latent space. 147 | 148 | Returns 149 | ------- 150 | results : dict 151 | The results of the inversions and the roundtrips on the line profiles. 152 | ''' 153 | 154 | model.eval() 155 | with torch.no_grad(): 156 | y = torch.ones((batch_size,2,ca_data[0].shape[0])) #sets up the input to the model by creating an array containing the line profiles a batch_size number of times such that the latent space can be sampled that many times for adequate confidence in the inversion 157 | y[:,0] *= torch.from_numpy(ha_data[1]).float() 158 | y[:,1] *= torch.from_numpy(ca_data[1]).float() #loads in the line profiles batch_size amount of times to be used with different samples drawn from the latent space 159 | yz = model.outSchema.fill({ 160 | "Halpha" : y[:,0], 161 | "Ca8542" : y[:,1], 162 | "LatentSpace" : torch.randn 163 | }) #constructs the [y,z] pairs for the network 164 | x_out = model(yz.to(dev),rev=True) 165 | y_round_trip = model(x_out) #uses the calculated atmospheric parameters to generate the line profiles to see if they're the same 166 | vel = inverse_velocity_conversion(x_out[:,model.inSchema.vel]) 167 | 168 | results = { 169 | "Halpha" : y_round_trip[:,model.outSchema.Halpha].cpu().numpy(), 170 | "Ca8542" : y_round_trip[:,model.outSchema.Ca8542].cpu().numpy(), 171 | "ne" : x_out[:,model.inSchema.ne].cpu().numpy(), 172 | "temperature" : x_out[:,model.inSchema.temperature].cpu().numpy(), 173 | "vel" : vel.cpu().numpy(), 174 | "Halpha_true" : yz[0,model.outSchema.Halpha].cpu().numpy(), 175 | "Ca8542_true" : yz[0,model.outSchema.Ca8542].cpu().numpy() 176 | } 177 | 178 | return results 179 | 180 | def inversion_plots(results,z,ca_data,ha_data): 181 | ''' 182 | A function to plot the results of the inversions. 183 | 184 | Parameters 185 | ---------- 186 | results : dict 187 | The results from the inversions.m the latent space. 188 | z : torch.Tensor 189 | The height profiles of the RADYN grid. 190 | ca_data : list 191 | A concatenated list of the calcium wavelengths and intensities. 192 | ha_data : list 193 | A concatenated list of the hydrogen wavelengths and intensities. 194 | ''' 195 | 196 | fig, ax = plt.subplots(nrows=2,ncols=2,figsize=(9,7),constrained_layout=True) 197 | ax2 = ax[0,0].twinx() 198 | ca_wvls = ca_data[0] 199 | ha_wvls = ha_data[0] 200 | z_local = z / 1e8 201 | 202 | z_edges = [z_local[0] - 0.5*(z_local[1]-z_local[0])] 203 | for i in range(z_local.shape[0]-1): 204 | z_edges.append(0.5*(z_local[i]+z_local[i+1])) 205 | z_edges.append(z_local[-1] + 0.5*(z_local[-1]-z_local[-2])) 206 | z_edges = [float(f) for f in z_edges] 207 | ca_edges = [ca_wvls[0] - 0.5*(ca_wvls[1]-ca_wvls[0])] 208 | for i in range(ca_wvls.shape[0]-1): 209 | ca_edges.append(0.5*(ca_wvls[i]+ca_wvls[i+1])) 210 | ca_edges.append(ca_wvls[-1] + 0.5*(ca_wvls[-1]-ca_wvls[-2])) 211 | ha_edges = [ha_wvls[0] - 0.5*(ha_wvls[1]-ha_wvls[0])] 212 | for i in range(ha_wvls.shape[0]-1): 213 | ha_edges.append(0.5*(ha_wvls[i]+ha_wvls[i+1])) 214 | ha_edges.append(ha_wvls[-1] + 0.5*(ha_wvls[-1]-ha_wvls[-2])) 215 | ne_edges = np.linspace(8,15,num=101) 216 | temp_edges = np.linspace(3,8,num=101) 217 | vel_max = 2*np.max(np.median(results["vel"],axis=0)) 218 | vel_min = np.min(np.median(results["vel"],axis=0)) 219 | vel_min = np.sign(vel_min)*np.abs(vel_min)*2 220 | vel_edges = np.linspace(vel_min,vel_max,num=101) 221 | ca_max = 1.1*np.max(np.max(results["Ca8542"],axis=0)) 222 | ca_min = 0.9*np.min(np.min(results["Ca8542"],axis=0)) 223 | ca_edges_int = np.linspace(ca_min,ca_max,num=101) 224 | ha_max = 1.1*np.max(np.max(results["Halpha"],axis=0)) 225 | ha_min = 0.9*np.min(np.min(results["Halpha"],axis=0)) 226 | ha_edges_int = np.linspace(ha_min,ha_max,num=201) 227 | 228 | 229 | cmap_ne = [(51/255,187/255,238/255,0.0), (51/255, 187/255, 238/255, 1.0)] 230 | colors_ne = LinearSegmentedColormap.from_list('ne', cmap_ne) 231 | cmap_temp = [(238/255,119/255,51/255,0.0),(238/255,119/255,51/255,1.0)] 232 | colors_temp = LinearSegmentedColormap.from_list("temp",cmap_temp) 233 | cmap_vel = [(238/255,51/255,119/255,0.0),(238/255,51/255,119/255,1.0)] 234 | cmap_vel = LinearSegmentedColormap.from_list("vel",cmap_vel) 235 | 236 | ax[0,0].hist2d(torch.cat([z_local]*results["ne"].shape[0]).cpu().numpy(),results["ne"].reshape((-1,)),bins=(z_edges,ne_edges),cmap=colors_ne,norm=PowerNorm(0.3)) 237 | ax[0,0].plot(z_local.cpu().numpy(),np.median(results["ne"],axis=0), "--",c="k") 238 | ax[0,0].set_ylabel(r"log $n_{e}$ [cm$^{-3}$]",color=(51/255,187/255,238/255)) 239 | ax[0,0].set_xlabel("z [Mm]") 240 | # ax[0,0].xaxis.set_major_formatter(oom_formatter(8)) 241 | ax2.hist2d(torch.cat([z_local]*results["temperature"].shape[0]).cpu().numpy(),results["temperature"].reshape((-1,)),bins=(z_edges,temp_edges),cmap=colors_temp,norm=PowerNorm(0.3)) 242 | ax2.plot(z_local.cpu().numpy(),np.median(results["temperature"],axis=0),"--",c="k") 243 | ax2.set_ylabel("log T [K]",color=(238/255,119/255,51/255)) 244 | ax[0,1].hist2d(torch.cat([z_local]*results["vel"].shape[0]).cpu().numpy(),results["vel"].reshape((-1,)),bins=(z_edges,vel_edges),cmap=cmap_vel,norm=PowerNorm(0.3)) 245 | ax[0,1].plot(z_local.cpu().numpy(),np.median(results["vel"],axis=0),"--",c="k") 246 | ax[0,1].set_ylabel(r"v [kms$^{-1}$]",color=(238/255,51/255,119/255)) 247 | ax[0,1].set_xlabel("z [Mm]") 248 | # ax[0,1].xaxis.set_major_formatter(oom_formatter(8)) 249 | ax[1,0].plot(ha_data[0],results["Halpha_true"],"--") 250 | ax[1,0].hist2d(np.concatenate([ha_wvls]*results["Halpha"].shape[0]),results["Halpha"].reshape((-1,)),bins=(ha_edges,ha_edges_int),cmap="gray_r",norm=PowerNorm(0.3)) 251 | ax[1,0].set_title(r"H$\alpha$") 252 | ax[1,0].set_ylabel("Normalised Intensity") 253 | ax[1,0].set_xlabel(r"Wavelength [$\AA{}$]") 254 | ax[1,0].xaxis.set_major_locator(plt.MaxNLocator(5)) 255 | ax[1,1].hist2d(np.concatenate([ca_wvls]*results["Ca8542"].shape[0]),results["Ca8542"].reshape((-1,)),bins=(ca_edges,ca_edges_int),cmap="gray_r",norm=PowerNorm(0.3)) 256 | ax[1,1].set_title(r"Ca II 8542$\AA{}$") 257 | ax[1,1].plot(ca_data[0],results["Ca8542_true"],"--") 258 | ax[1,1].set_xlabel(r"Wavelength [$\AA{}$]") 259 | ax[1,1].xaxis.set_major_locator(plt.MaxNLocator(5)) 260 | 261 | class oom_formatter(matplotlib.ticker.ScalarFormatter): 262 | ''' 263 | Matplotlib formatter for changing the number of orders of magnitude displayed on an axis as well as the number of decimal points. 264 | 265 | Adapted from: https://stackoverflow.com/questions/42656139/set-scientific-notation-with-fixed-exponent-and-significant-digits-for-multiple 266 | ''' 267 | 268 | def __init__(self,order=0,fformat="%1.1f",offset=True,math_text=True): 269 | self.oom = order 270 | self.fformat = fformat 271 | matplotlib.ticker.ScalarFormatter.__init__(self,useOffset=offset,useMathText=math_text) 272 | 273 | def _set_orderOfMagnitude(self,nothing): 274 | self.orderOfMagnitude = self.oom 275 | 276 | def _set_format(self, v_min, v_max): 277 | self.format = self.fformat 278 | if self._useMathText: 279 | self.format = "$%s$" % matplotlib.ticker._mathdefault(self.format) 280 | 281 | def integrated_intensity(idx_range,intensity_vector): 282 | ''' 283 | A function to find the integrated intensity over a wavelength range of a spectral line. 284 | 285 | Parameters 286 | ---------- 287 | idx_range : range 288 | The range of indices to integrate over. 289 | intensity_vector : numpy.ndarray 290 | The vector of spectral line intensities. 291 | ''' 292 | 293 | total = 0 294 | for idx in idx_range: 295 | total += intensity_vector[idx] 296 | 297 | return total / len(idx_range) 298 | 299 | def intensity_ratio(blue_intensity,red_intensity): 300 | ''' 301 | A function that calculates the intensity ratio of two previously integrated intensities. 302 | ''' 303 | 304 | return blue_intensity / red_intensity 305 | 306 | def doppler_vel(l,delta_l): 307 | return (delta_l / l) * 3e5 #calculates the doppler velocity in km/s 308 | 309 | def lambda_0(wvls,ints): 310 | ''' 311 | Calculates the intensity-averaged line core. 312 | ''' 313 | 314 | num = np.sum(np.multiply(ints,wvls)) 315 | den = np.sum(ints) 316 | 317 | return num / den 318 | 319 | def variance(wvls,ints,l_0): 320 | ''' 321 | Calculates the variance of the spectral line w.r.t. the intensity-averaged line core. 322 | ''' 323 | 324 | num = np.sum(np.multiply(ints,(wvls-l_0)**2)) 325 | den = np.sum(ints) 326 | 327 | return num / den 328 | 329 | def wing_idxs(wvls,ints,var,l_0): 330 | ''' 331 | A function to work out the index range for the wings of a spectral line. This is working on the definition of wings that says the wings are defined as being one standard deviation away from the intensity-averaged line core. 332 | ''' 333 | 334 | blue_wing_start = 0 #blue wing starts at the shortest wavelength 335 | red_wing_end = wvls.shape[0] - 1 #red wing ends at the longest wavelength 336 | 337 | blue_end_wvl = l_0 - np.sqrt(var) 338 | red_start_wvl = l_0 + np.sqrt(var) 339 | 340 | blue_wing_end = np.argmin(np.abs(wvls - blue_end_wvl)) 341 | red_wing_start = np.argmin(np.abs(wvls - red_start_wvl)) 342 | 343 | return range(blue_wing_start,blue_wing_end+1), range(red_wing_start,red_wing_end+1) 344 | 345 | def delta_lambda(wing_idxs,wvls): 346 | ''' 347 | Calculates the half-width wavelength of an intensity range. 348 | 349 | Parameters 350 | ---------- 351 | wing_idxs : range 352 | The range of the indices of the intensity region in question. 353 | wvls : numpy.ndarray 354 | The wavelengths corresponding to the intensity region in question. 355 | ''' 356 | 357 | return len(wing_idxs)*(wvls[1] - wvls[0])/2 358 | 359 | def lambda_0_wing(wing_idxs,wvls,delta_lambda): 360 | ''' 361 | Calculates the central wavelength of an intensity range. 362 | 363 | Parameters 364 | ---------- 365 | wing_idxs : range 366 | The range of the indices of the intensity region in question. 367 | wvls : numpy.ndarray 368 | The wavelengths corresponding to the intensity region in question. 369 | delta_lambda : float 370 | The half-width wavelength of an intensity range. 371 | ''' 372 | 373 | return wvls[list(wing_idxs)[-1]] - delta_lambda 374 | 375 | def interp_fine(spec_line): 376 | ''' 377 | Interpolates the spectral line onto a finer grid for more accurate calculations for the wing properties. 378 | ''' 379 | 380 | x, y = spec_line 381 | x_new = np.linspace(x[0],x[-1],num=1001) 382 | y_new = interp1d(x,y)(x_new) 383 | 384 | return np.array([x_new,y_new]) 385 | 386 | z = np.array([-0.065, 0.016, 0.097, 0.178, 0.259, 0.340, 0.421, 0.502, 0.583, 0.664, 0.745, 0.826, 0.907, 0.988, 1.069, 1.150, 1.231, 1.312, 1.393, 1.474, 1.555, 1.636, 1.718, 1.799, 1.880, 1.961, 2.042, 2.123, 2.204, 2.285, 2.366, 2.447, 2.528, 2.609, 2.690, 2.771, 2.852, 2.933, 3.014, 3.095, 3.176, 3.257, 3.338, 3.419, 3.500, 4.360, 5.431, 6.766, 8.429, 10.5], dtype=np.float32) 387 | -------------------------------------------------------------------------------- /z.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Goobley/Radynversion/f44edc77b6eb7ef2bdbd8e8aabda3bf9822d3695/z.h5 --------------------------------------------------------------------------------